# Classification with pyspark and pytorch lightning

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from plantclef.utils import get_spark
from pyspark.sql import functions as F
import os
from pathlib import Path
from pytorch_lightning.callbacks import ModelCheckpoint

spark = get_spark()
display(spark)

  from .autonotebook import tqdm as notebook_tqdm
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/04/13 22:26:04 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
24/04/13 22:26:05 WARN SparkConf: Note that spark.local.dir will be overridden by the value set by the cluster manager (via SPARK_LOCAL_DIRS in mesos/standalone/kubernetes and LOCAL_DIRS in YARN).


In [3]:
# Path and dataset names
gcs_path = "gs://dsgt-clef-plantclef-2024/data/process"
dct_emb_train = "training_cropped_resized_v2/dino_dct/data"

# Define the GCS path to the embedding files
dct_gcs_path = f"{gcs_path}/{dct_emb_train}"

# Read the Parquet file into a DataFrame
dct_df = spark.read.parquet(dct_gcs_path)

# Show the data
dct_df.show(n=5, truncate=50)

                                                                                

+--------------------------------------------+----------+--------------------------------------------------+
|                                  image_name|species_id|                                     dct_embedding|
+--------------------------------------------+----------+--------------------------------------------------+
|170e88ca9af457daa1038092479b251c61c64f7d.jpg|   1742956|[-20648.51, 2133.689, -2555.3125, 14820.57, 685...|
|c24a2d8646f5bc7112a39908bd2f6c45bf066a71.jpg|   1356834|[-25395.82, -12564.387, 24736.02, 20483.8, 2115...|
|e1f68e5f05618921969aee2575de20e537e6d66b.jpg|   1563754|[-26178.633, -7670.404, -22552.29, -6563.006, 8...|
|b0433cd6968b57d52e5c25dc45a28e674a25e61e.jpg|   1367432|[-23662.764, -6773.8213, -8283.518, 3769.6064, ...|
|96478a0fe20a41e755b0c8d798690f2c2b7c115f.jpg|   1389010|[-22182.172, -19444.006, 23355.23, 7042.8604, -...|
+--------------------------------------------+----------+--------------------------------------------------+
only showing top 5 

### convert to a PyTorch dataset

In [11]:
import torch
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl
from torch import nn
from torchmetrics.classification import (
    MulticlassAccuracy,
    MulticlassF1Score,
    MulticlassPrecision,
    MulticlassRecall,
)


class EmbeddingDataset(Dataset):
    def __init__(self, spark_df):
        self.embeddings = [x["dct_embedding"] for x in spark_df.collect()]
        self.labels = [x["species_id"] for x in spark_df.collect()]

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        embeddings = torch.tensor(self.embeddings[index], dtype=torch.float)
        labels = torch.tensor(self.labels[index], dtype=torch.long)
        return embeddings, labels


class MultiClassClassifier(pl.LightningModule):
    def __init__(self, num_features, num_classes):
        super().__init__()
        self.layer = nn.Linear(num_features, num_classes)
        self.learning_rate = 0.002
        self.accuracy = MulticlassAccuracy(num_classes=num_classes, average="weighted")
        self.f1_score = MulticlassF1Score(num_classes=num_classes, average="weighted")
        self.precision = MulticlassPrecision(
            num_classes=num_classes, average="weighted"
        )
        self.recall = MulticlassRecall(num_classes=num_classes, average="weighted")

    def forward(self, x):
        return torch.log_softmax(self.layer(x), dim=1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        # Use negative log likelihood loss for multiclass classification
        loss = torch.nn.functional.nll_loss(y_hat, y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        valid_loss = {"valid_loss": torch.nn.functional.nll_loss(y_hat, y)}
        # Update metrics
        self.log(
            "valid_accuracy", self.accuracy(y_hat, y), on_step=False, on_epoch=True
        )
        self.log("valid_f1", self.f1_score(y_hat, y), on_step=False, on_epoch=True)
        self.log(
            "valid_precision", self.precision(y_hat, y), on_step=False, on_epoch=True
        )
        self.log("valid_recall", self.recall(y_hat, y), on_step=False, on_epoch=True)
        return valid_loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

### prepare subset of data, 10 species

In [5]:
from pyspark.sql import functions as F

# Get a small subset of the dataset
sub_species_df = (
    dct_df.groupBy("species_id")
    .agg(F.count("*").alias("n"))
    .where(F.col("n") > 100)
    .orderBy(F.rand(seed=42))
    .limit(10)
)

# Collect the species_id into a list of values
species_id_subset = [
    row["species_id"]
    for row in sub_species_df.select("species_id").distinct().collect()
]

# Get subset of images to test pipeline
subset_df = dct_df.where(F.col("species_id").isin(species_id_subset))

# # Show DFs and count number of rows in subset_df
# sub_species_df.show(truncate=80)
# subset_df.show()
# print(f"subset_df count: {subset_df.count()}")

                                                                                

### prepare subset of data, species with >= 100 images

In [6]:
# Param
species_count = 100

# Transformation
grouped_df = (
    dct_df.groupBy("species_id")
    .agg(F.count("species_id").alias("n"))
    .filter(f"n >= {species_count}")
    .orderBy(F.col("n").desc())
)
grouped_count = grouped_df.count()
print(f"Grouped DF count: {grouped_count}")

# Join the grouped_df with dct_df to get the selected species_id
sub_dct_df = dct_df.join(grouped_df, "species_id", "inner").drop("n")
sub_dct_count = sub_dct_df.count()
print(f"Subset of DCT DF count: {sub_dct_count}")

# Show final DF
sub_dct_df.show()

                                                                                

Grouped DF count: 4797


                                                                                

Subset of DCT DF count: 1323820




+----------+--------------------+--------------------+
|species_id|          image_name|       dct_embedding|
+----------+--------------------+--------------------+
|   1393370|5ff9d9d52dffbaf6f...|[-26706.178, 4697...|
|   1392699|c7836a178d7ab62c4...|[-33230.715, -200...|
|   1391491|d72a8ec1a51d5f798...|[-26003.773, 2450...|
|   1356209|d558d431e8f45e3e4...|[-21498.129, 1596...|
|   1356209|838535612bd792064...|[-18489.996, -264...|
|   1360196|d0a7cd780bae40203...|[-24187.928, 1197...|
|   1391529|b8ca83f88909e7b53...|[-29997.904, 1558...|
|   1396980|da923be1e2fbd6d59...|[-22626.072, 1051...|
|   1397518|fbdd202481452a588...|[-19896.24, 6995....|
|   1743246|a1bbf2c842bba16b8...|[-23111.025, -122...|
|   1743246|bf352dd4ba63ebede...|[-21741.3, 42392....|
|   1743246|2779e8dc801c0b69a...|[-22881.133, -798...|
|   1722546|c0d8d74c08668a644...|[-26559.254, 3705...|
|   1393639|27add1b9a5fb8b802...|[-27432.734, -240...|
|   1393639|aa133fe4a64b13a0f...|[-21371.344, 1590...|
|   139361

                                                                                

### train/validation split

In [16]:
# Perform a train-validation split
train_df, valid_df = sub_dct_df.randomSplit([0.8, 0.2], seed=42)
print(f"train: {train_df.count()}, valid: {valid_df.count()}")

# Init params
NUM_CLASSES = int(sub_dct_df.select("species_id").distinct().count())
BATCH_SIZE = 32
NUM_EPOCHS = 10
print(f"Num classes: {NUM_CLASSES}")



### train the model

In [14]:
# Create PyTorch Datasets
train_dataset = EmbeddingDataset(spark_df=train_df)
valid_dataset = EmbeddingDataset(spark_df=valid_df)

# DataLoaders
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
valid_dataloader = DataLoader(valid_dataset, batch_size=BATCH_SIZE)

# Current path
curr_dir = Path(os.getcwd())

# Setup checkpoint callback
checkpoint_callback = ModelCheckpoint(
    monitor="valid_accuracy",  # Monitor validation accuracy for improvement
    dirpath=curr_dir,  # Directory path for saving checkpoints
    filename="model-{epoch:02d}-{val_accuracy:.2f}",
    save_top_k=1,  # Save only top 1 model
    mode="max",
)

# model
model = MultiClassClassifier(num_features=64, num_classes=NUM_CLASSES)

# Train model
trainer = pl.Trainer(max_epochs=NUM_EPOCHS, callbacks=[checkpoint_callback])
trainer.fit(model, train_dataloader, valid_dataloader)

                                                                                

                                                                   

GPU available: False, used: False                                               
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name      | Type                | Params
--------------------------------------------------
0 | layer     | Linear              | 311 K 
1 | accuracy  | MulticlassAccuracy  | 0     
2 | f1_score  | MulticlassF1Score   | 0     
3 | precision | MulticlassPrecision | 0     
4 | recall    | MulticlassRecall    | 0     
--------------------------------------------------
311 K     Trainable params
0         Non-trainable params
311 K     Total params
1.247     Total estimated model params size (MB)


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

IndexError: Target 1355868 is out of bounds.

### load trained model