# Classification with PySpark and PyTorch Lightning

In [1]:
# Automatically reload edited Python modules (such as the project's plantclef
# package) before each cell runs, so code changes show up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [2]:
from plantclef.utils import get_spark

# Obtain the Spark session from the project helper and display it in the notebook.
spark = get_spark()
display(spark)

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/04/09 12:50:35 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
24/04/09 12:50:35 WARN SparkConf: Note that spark.local.dir will be overridden by the value set by the cluster manager (via SPARK_LOCAL_DIRS in mesos/standalone/kubernetes and LOCAL_DIRS in YARN).


In [3]:
# GCS location of the DCT embeddings for the cropped/resized training images
gcs_path = "gs://dsgt-clef-plantclef-2024/data/process"
dct_emb_train = "training_cropped_resized_v2/dino_dct/data"
dct_gcs_path = "/".join((gcs_path, dct_emb_train))

# Load the Parquet embeddings as a Spark DataFrame and preview a few rows
dct_df = spark.read.parquet(dct_gcs_path)
dct_df.show(n=5, truncate=50)

                                                                                

+--------------------------------------------+----------+--------------------------------------------------+
|                                  image_name|species_id|                                     dct_embedding|
+--------------------------------------------+----------+--------------------------------------------------+
|170e88ca9af457daa1038092479b251c61c64f7d.jpg|   1742956|[-20648.51, 2133.689, -2555.3125, 14820.57, 685...|
|c24a2d8646f5bc7112a39908bd2f6c45bf066a71.jpg|   1356834|[-25395.82, -12564.387, 24736.02, 20483.8, 2115...|
|e1f68e5f05618921969aee2575de20e537e6d66b.jpg|   1563754|[-26178.633, -7670.404, -22552.29, -6563.006, 8...|
|b0433cd6968b57d52e5c25dc45a28e674a25e61e.jpg|   1367432|[-23662.764, -6773.8213, -8283.518, 3769.6064, ...|
|96478a0fe20a41e755b0c8d798690f2c2b7c115f.jpg|   1389010|[-22182.172, -19444.006, 23355.23, 7042.8604, -...|
+--------------------------------------------+----------+--------------------------------------------------+
only showing top 5 

### Convert to a PyTorch dataset

In [4]:
import torch
from torch.utils.data import Dataset, DataLoader


class EmbeddingDataset(Dataset):
    """Map-style PyTorch dataset of ``(embedding, one_hot_label)`` pairs.

    The DataFrame is materialised on the driver once (Spark ``toPandas()``)
    and rows are then served by position.

    Args:
        data: Spark DataFrame, or an already-converted pandas DataFrame, with
            ``dct_embedding`` (numeric sequence) and ``species_id`` columns.
        num_classes: Length of the one-hot label vector.
        species_index: Optional ``{species_id: class_index}`` mapping. Build it
            once and pass it to every split so a species gets the same label
            index in train and validation. When omitted, the sorted distinct
            species ids found in ``data`` are enumerated. This is deterministic,
            unlike the row order returned by Spark's ``distinct()``.

    Raises:
        ValueError: If any class index in the mapping is ``>= num_classes``.
    """

    def __init__(self, data, num_classes, species_index=None):
        # Spark DataFrames are collected to pandas; pandas frames are used as-is.
        self.data = data.toPandas() if hasattr(data, "toPandas") else data
        self.num_classes = num_classes  # Total number of classes
        # species_id -> class index (attribute name kept for compatibility)
        if species_index is None:
            self.species_id = self._get_species_index(self.data)
        else:
            self.species_id = dict(species_index)
        # Fail fast here rather than with an IndexError later in __getitem__.
        if self.species_id and max(self.species_id.values()) >= num_classes:
            raise ValueError(
                f"species index needs {max(self.species_id.values()) + 1} classes "
                f"but num_classes={num_classes}"
            )

    def __len__(self):
        return len(self.data)

    def _get_species_index(self, data):
        """Enumerate the sorted distinct ``species_id`` values of a pandas frame."""
        species_ids = sorted(data["species_id"].drop_duplicates().tolist())
        return {species_id: idx for idx, species_id in enumerate(species_ids)}

    def __getitem__(self, index):
        """Return ``(embedding, label)`` for the row at position ``index``.

        ``embedding`` is always float32, which matches ``nn.Linear``'s default
        weight dtype even if the column holds float64 values. ``label`` is a
        float one-hot vector of length ``num_classes``.
        """
        row = self.data.iloc[index]
        embeddings = torch.tensor(row["dct_embedding"], dtype=torch.float32)
        labels = torch.zeros(self.num_classes, dtype=torch.float)
        labels[self.species_id[row["species_id"]]] = 1.0
        return embeddings, labels

In [15]:
import pytorch_lightning as pl
from torch import nn
from torchmetrics.classification import (
    MultilabelAccuracy,
    MultilabelF1Score,
    MultilabelPrecision,
    MultilabelRecall,
)


class MultiLabelClassifier(pl.LightningModule):
    """Linear classifier over embedding vectors with independent per-class sigmoids.

    Trained with binary cross-entropy against one-hot / multi-hot targets and
    evaluated with torchmetrics' multilabel accuracy, F1, precision and recall.

    Args:
        num_features: Length of each input embedding vector.
        num_classes: Number of output labels.
        lr: Adam learning rate (defaults to the previously hard-coded 0.002).
        threshold: Probability cut-off used by the validation metrics.
    """

    def __init__(self, num_features, num_classes, lr=0.002, threshold=0.5):
        super().__init__()
        self.lr = lr
        self.layer = nn.Linear(num_features, num_classes)
        self.accuracy = MultilabelAccuracy(threshold=threshold, num_labels=num_classes)
        self.f1_score = MultilabelF1Score(threshold=threshold, num_labels=num_classes)
        self.precision = MultilabelPrecision(threshold=threshold, num_labels=num_classes)
        self.recall = MultilabelRecall(threshold=threshold, num_labels=num_classes)

    def forward(self, x):
        """Return per-class probabilities in [0, 1] (sigmoid of the logits)."""
        return torch.sigmoid(self.layer(x))

    def _bce_loss(self, logits, y):
        # BCE-with-logits fuses the sigmoid into the loss, so it stays stable
        # for large logits. Large logits are likely here: the sample DCT
        # embeddings have entries in the tens of thousands. Plain sigmoid + BCE
        # saturates to exactly 0/1, and BCE then clamps its log output.
        return nn.functional.binary_cross_entropy_with_logits(logits, y.float())

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = self._bce_loss(self.layer(x), y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self.layer(x)
        probs = torch.sigmoid(logits)
        target = y.int()  # torchmetrics documents multilabel targets as an int tensor
        self.log("val_loss", self._bce_loss(logits, y))
        # Update each metric's running state, then log the metric object itself.
        # Lightning then computes and resets it at epoch end, giving a true
        # epoch-level value. Averaging per-batch values is wrong for
        # non-decomposable metrics such as F1, precision and recall.
        for name, metric in (
            ("val_accuracy", self.accuracy),
            ("val_f1", self.f1_score),
            ("val_precision", self.precision),
            ("val_recall", self.recall),
        ):
            metric.update(probs, target)
            self.log(name, metric)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

### Prepare a subset of the data

In [12]:
from pyspark.sql import functions as F

# Randomly choose 10 species with more than 100 rows (images) each, using a seeded random order
species_counts = dct_df.groupBy("species_id").agg(F.count("*").alias("n"))
popular_species = species_counts.where(F.col("n") > 100)
sub_species_df = popular_species.orderBy(F.rand(seed=42)).limit(10).cache()
sub_species_df.show(truncate=80)

# Bring the chosen species ids back to the driver as a plain Python list
species_rows = sub_species_df.select("species_id").distinct().collect()
species_id_subset = [r.species_id for r in species_rows]

# Restrict the embeddings to those species; cached because it is reused later
subset_df = dct_df.where(F.col("species_id").isin(species_id_subset)).cache()
subset_df.show()

# Report how many rows the subset holds
print(f"subset_df count: {subset_df.count()}")

24/04/09 13:00:22 WARN CacheManager: Asked to cache already cached data.
24/04/09 13:00:22 WARN CacheManager: Asked to cache already cached data.


+----------+---+
|species_id|  n|
+----------+---+
|   1389268|291|
|   1358860|247|
|   1363409|271|
|   1396750|190|
|   1394732|243|
|   1359227|196|
|   1363259|262|
|   1393571|264|
|   1393681|248|
|   1360715|223|
+----------+---+

+--------------------+----------+--------------------+
|          image_name|species_id|       dct_embedding|
+--------------------+----------+--------------------+
|1417f00b385c9648e...|   1359227|[-21118.215, 1026...|
|1796804389a7af364...|   1393571|[-17140.287, -229...|
|f9bda2da2c8817243...|   1363409|[-18112.29, -9305...|
|3512914e3568872bc...|   1389268|[-31069.455, 1593...|
|a54095a70be70d3dd...|   1363259|[-23388.5, 1492.1...|
|2fd167d70f666ef5a...|   1393571|[-22588.898, 3876...|
|497c2b1590ed58e12...|   1393571|[-18909.512, -167...|
|af24c93fa73a97312...|   1393681|[-16674.791, -137...|
|5c88ac6797d82c04b...|   1363259|[-26085.584, 3116...|
|67053d421a3b1ef9b...|   1396750|[-28502.367, 9868...|
|37bba53333910dfab...|   1393571|[-29814.75, 8

### Train-validation split

In [13]:
# Seeded 80/20 random split into training and validation rows
train_df, valid_df = subset_df.randomSplit([0.8, 0.2], seed=42)

# Cache both splits, since each is scanned again when it is turned into a PyTorch dataset
train_df, valid_df = train_df.cache(), valid_df.cache()

for split_name, split_df in (("Train", train_df), ("Valid", valid_df)):
    print(f"{split_name} DF count: {split_df.count()}")

24/04/09 13:00:24 WARN CacheManager: Asked to cache already cached data.
24/04/09 13:00:24 WARN CacheManager: Asked to cache already cached data.


Train DF count: 1949
Valid DF count: 486


### Train the model

In [16]:
# Init params
pl.seed_everything(42)  # reproducible weight init and DataLoader shuffling
num_classes = len(species_id_subset)  # one output per sampled species
batch_size = 32
epochs = 10

# Prepare PyTorch datasets
train_data = EmbeddingDataset(data=train_df, num_classes=num_classes)
valid_data = EmbeddingDataset(data=valid_df, num_classes=num_classes)

# Each dataset builds its own species_id -> label-index mapping. If the two
# mappings differ, validation labels are silently scrambled relative to the
# training labels, so stop here instead.
assert (
    train_data.species_id == valid_data.species_id
), "train/valid species_id -> label index mappings differ"

# Prepare DataLoaders
train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
valid_dataloader = DataLoader(valid_data, batch_size=batch_size, shuffle=False)

# Model: the input width is read from the data instead of a hard-coded 64.
# Only a small number of classes is used, for testing.
num_features = train_data[0][0].numel()
model = MultiLabelClassifier(num_features=num_features, num_classes=num_classes)

trainer = pl.Trainer(max_epochs=epochs)
trainer.fit(model, train_dataloader, valid_dataloader)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/mgustine/.local/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:75: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default

  | Name      | Type                | Params
--------------------------------------------------
0 | layer     | Linear              | 650   
1 | accuracy  | MultilabelAccuracy  | 0     
2 | f1_score  | MultilabelF1Score   | 0     
3 | precision | MultilabelPrecision | 0     
4 | recall    | MultilabelRecall    | 0  

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/home/mgustine/.local/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


                                                                           

/home/mgustine/.local/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Epoch 9: 100%|██████████| 61/61 [00:00<00:00, 106.46it/s, v_num=12]

`Trainer.fit` stopped: `max_epochs=10` reached.


Epoch 9: 100%|██████████| 61/61 [00:00<00:00, 105.62it/s, v_num=12]
