# Classification with pyspark and pytorch lightning

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from plantclef.utils import get_spark
from pyspark.sql import functions as F
import os
from pathlib import Path
from pytorch_lightning.callbacks import ModelCheckpoint

spark = get_spark()
display(spark)

  from .autonotebook import tqdm as notebook_tqdm
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/04/14 17:50:32 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
24/04/14 17:50:32 WARN SparkConf: Note that spark.local.dir will be overridden by the value set by the cluster manager (via SPARK_LOCAL_DIRS in mesos/standalone/kubernetes and LOCAL_DIRS in YARN).


In [3]:
# Path and dataset names
gcs_path = "gs://dsgt-clef-plantclef-2024/data/process"
dct_emb_train = "training_cropped_resized_v2/dino_dct/data"

# Define the GCS path to the embedding files
dct_gcs_path = f"{gcs_path}/{dct_emb_train}"

# Read the Parquet file into a DataFrame
dct_df = spark.read.parquet(dct_gcs_path)

# Show the data
dct_df.show(n=5, truncate=50)

                                                                                

+--------------------------------------------+----------+--------------------------------------------------+
|                                  image_name|species_id|                                     dct_embedding|
+--------------------------------------------+----------+--------------------------------------------------+
|170e88ca9af457daa1038092479b251c61c64f7d.jpg|   1742956|[-20648.51, 2133.689, -2555.3125, 14820.57, 685...|
|c24a2d8646f5bc7112a39908bd2f6c45bf066a71.jpg|   1356834|[-25395.82, -12564.387, 24736.02, 20483.8, 2115...|
|e1f68e5f05618921969aee2575de20e537e6d66b.jpg|   1563754|[-26178.633, -7670.404, -22552.29, -6563.006, 8...|
|b0433cd6968b57d52e5c25dc45a28e674a25e61e.jpg|   1367432|[-23662.764, -6773.8213, -8283.518, 3769.6064, ...|
|96478a0fe20a41e755b0c8d798690f2c2b7c115f.jpg|   1389010|[-22182.172, -19444.006, 23355.23, 7042.8604, -...|
+--------------------------------------------+----------+--------------------------------------------------+
only showing top 5 

### convert to a PyTorch dataset

In [4]:
import torch
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl
from torch import nn
from torchmetrics.classification import (
    MulticlassAccuracy,
    MulticlassF1Score,
    MulticlassPrecision,
    MulticlassRecall,
)


class EmbeddingDataset(Dataset):
    def __init__(self, spark_df):
        self.embeddings = [x["dct_embedding"] for x in spark_df.collect()]
        self.labels = [x["index"] for x in spark_df.collect()]

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        embeddings = torch.tensor(self.embeddings[index], dtype=torch.float)
        labels = torch.tensor(self.labels[index], dtype=torch.long)
        return embeddings, labels


class MultiClassClassifier(pl.LightningModule):
    def __init__(self, num_features, num_classes):
        super().__init__()
        self.save_hyperparameters()  # Saves hyperparams in the checkpoints
        self.layer = nn.Linear(num_features, num_classes)
        self.learning_rate = 0.002
        self.accuracy = MulticlassAccuracy(num_classes=num_classes, average="weighted")
        self.f1_score = MulticlassF1Score(num_classes=num_classes, average="weighted")
        self.precision = MulticlassPrecision(
            num_classes=num_classes, average="weighted"
        )
        self.recall = MulticlassRecall(num_classes=num_classes, average="weighted")

    def forward(self, x):
        return torch.log_softmax(self.layer(x), dim=1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        # Use negative log likelihood loss for multiclass classification
        loss = torch.nn.functional.nll_loss(y_hat, y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        valid_loss = {"valid_loss": torch.nn.functional.nll_loss(y_hat, y)}
        # Update metrics
        self.log(
            "valid_accuracy", self.accuracy(y_hat, y), on_step=False, on_epoch=True
        )
        self.log("valid_f1", self.f1_score(y_hat, y), on_step=False, on_epoch=True)
        self.log(
            "valid_precision", self.precision(y_hat, y), on_step=False, on_epoch=True
        )
        self.log("valid_recall", self.recall(y_hat, y), on_step=False, on_epoch=True)
        return valid_loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

### prepare subset of data for testing end-to-end pipeline

In [14]:
from pyspark.sql import functions as F

# Set variable to limit number of species in subset data
LIMIT_SPECIES = 5

# Get a small subset of the dataset
sub_species_df = (
    dct_df.groupBy("species_id")
    .agg(F.count("*").alias("n"))
    .where(F.col("n") > 100)
    .orderBy(F.rand(seed=42))
    .limit(LIMIT_SPECIES)
)

# Collect the species_id into a list of values
species_id_subset = [
    row["species_id"]
    for row in sub_species_df.select("species_id").distinct().collect()
]

# Assign an idex to each species_id
sub_species_idx_df = (
    sub_species_df.select("species_id")
    .distinct()
    .withColumn("index", F.monotonically_increasing_id())
)

# Get subset of images to test pipeline
subset_df = (
    dct_df.join(F.broadcast(sub_species_idx_df), "species_id", "inner")
    .drop("n")
    .where(F.col("species_id").isin(species_id_subset))
)

# # Show DFs and count number of rows in subset_df
# sub_species_df.show(truncate=80)
# subset_df.show()
# print(f"subset_df count: {subset_df.count()}")

                                                                                

### final subset of data, species with species having >= 100 images

In [6]:
from pyspark.sql.window import Window

# Param
SPECIES_IMAGE_COUNT = 100

# Transformation
grouped_df = (
    dct_df.groupBy("species_id")
    .agg(F.count("species_id").alias("n"))
    .filter(f"n >= {SPECIES_IMAGE_COUNT}")
    .orderBy(F.col("n").desc())
)
grouped_count = grouped_df.count()
print(f"Grouped DF count: {grouped_count}")

# Assign an idex to each species_id
species_idx_df = (
    grouped_df.select("species_id")
    .distinct()
    .withColumn("index", F.monotonically_increasing_id())
)

# Join the grouped_df with dct_df to get the selected species_id
filtered_dct_df = dct_df.join(F.broadcast(species_idx_df), "species_id", "inner").drop(
    "n"
)
filtered_dct_count = filtered_dct_df.count()
print(f"Filtered DCT DF count: {filtered_dct_count}")

# Show final DF
filtered_dct_df.show()

                                                                                

Grouped DF count: 4797


                                                                                

Filtered DCT DF count: 1323820


                                                                                

+----------+--------------------+--------------------+-----+
|species_id|          image_name|       dct_embedding|index|
+----------+--------------------+--------------------+-----+
|   1742956|170e88ca9af457daa...|[-20648.51, 2133....| 3769|
|   1367432|b0433cd6968b57d52...|[-23662.764, -677...| 4656|
|   1360260|2b0301802884c465a...|[-26661.809, 1455...| 2669|
|   1378563|e60616fddca9656d3...|[-31823.447, 393....| 1494|
|   1358543|24c155bbe8f0e77ae...|[-21757.35, 2965....| 2808|
|   1390691|c469f9a672a9b1d96...|[-27213.7, 20937....|  795|
|   1361302|39092f483d3fdef57...|[-22082.877, -312...| 4744|
|   1356428|8b33e8cd6dad65d5f...|[-32586.871, 3672...|   23|
|   1586076|6e393f50b7a31dea1...|[-28792.238, 1479...|  455|
|   1360262|a6a5272e3310a299e...|[-28321.89, 11192...|  886|
|   1362293|1fd06740a23e7353a...|[-20066.348, -968...| 1007|
|   1363358|587b331459d6f4643...|[-26982.426, 2332...|   75|
|   1390699|1f9ac6784cfae5012...|[-19362.758, -649...| 3654|
|   1360978|b46920705c0d

### train/validation split

In [15]:
# Perform a train-validation split
def train_valid_split(df):
    train_df, valid_df = df.randomSplit([0.8, 0.2], seed=42)
    return train_df, valid_df


# Pass desired DF to function
train_df, valid_df = train_valid_split(df=subset_df)
print(f"train: {train_df.count()}, valid: {valid_df.count()}")

# Init params
NUM_CLASSES = int(train_df.select("species_id").distinct().count())
BATCH_SIZE = 32
NUM_EPOCHS = 10
print(f"Num classes: {NUM_CLASSES}")

                                                                                

train: 986, valid: 256




Num classes: 5


                                                                                

### prepare DataLoaders

In [16]:
# Create PyTorch Datasets
train_dataset = EmbeddingDataset(spark_df=train_df)
valid_dataset = EmbeddingDataset(spark_df=valid_df)

# DataLoaders
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
valid_dataloader = DataLoader(valid_dataset, batch_size=BATCH_SIZE)

                                                                                

### train model

In [17]:
# Current path
curr_dir = Path(os.getcwd())

# Setup checkpoint callback
checkpoint_callback = ModelCheckpoint(
    monitor="valid_accuracy",  # Monitor validation accuracy for improvement
    dirpath=curr_dir,  # Directory path for saving checkpoints
    filename="model-{epoch:02d}-{val_accuracy:.2f}",
    save_top_k=1,  # Save only top 1 model
    mode="max",
)

# model
model = MultiClassClassifier(
    num_features=64,
    num_classes=NUM_CLASSES,
)

# Train model
trainer = pl.Trainer(max_epochs=NUM_EPOCHS, callbacks=[checkpoint_callback])
trainer.fit(model, train_dataloader, valid_dataloader)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: /home/mgustine/plantclef-2024/notebooks/modeling/lightning_logs
/home/mgustine/.local/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:653: Checkpoint directory /home/mgustine/plantclef-2024/notebooks/modeling exists and is not empty.

  | Name      | Type                | Params
--------------------------------------------------
0 | layer     | Linear              | 325   
1 | accuracy  | MulticlassAccuracy  | 0     
2 | f1_score  | MulticlassF1Score   | 0     
3 | precision | MulticlassPrecision | 0     
4 | recall    | MulticlassRecall    | 0     
--------------------------------------------------
325       Trainable params
0         Non-trainable params
325       Total params
0.001     Total estimated model params size (MB)


                                                                            

/home/mgustine/.local/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.
/home/mgustine/.local/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.
/home/mgustine/.local/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py:298: The number of training batches (31) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.




Epoch 9: 100%|██████████| 31/31 [00:00<00:00, 193.04it/s, v_num=0]

`Trainer.fit` stopped: `max_epochs=10` reached.


Epoch 9: 100%|██████████| 31/31 [00:00<00:00, 190.64it/s, v_num=0]


### load trained model

In [18]:
# Assuming the best model is saved in the same directory as your script
best_model = "model-epoch=06-val_accuracy=0.00.ckpt"
checkpoint_path = curr_dir / Path(best_model)  # Adjust filename as necessary

# Load the checkpoint file
checkpoint = torch.load(
    checkpoint_path, map_location=torch.device("cpu")
)  # Use 'cpu' to avoid GPU memory issues

# Print the keys and any hyperparameters stored in the checkpoint
print(checkpoint.keys())
if "hyper_parameters" in checkpoint:
    print(checkpoint["hyper_parameters"])
else:
    print("No hyperparameters stored in this checkpoint.")

# Load the trained model from checkpoint
model = MultiClassClassifier.load_from_checkpoint(
    checkpoint_path=checkpoint_path,
    num_features=64,
    num_classes=NUM_CLASSES,
)
model.eval()  # Set the model to evaluation mode

dict_keys(['epoch', 'global_step', 'pytorch-lightning_version', 'state_dict', 'loops', 'callbacks', 'optimizer_states', 'lr_schedulers', 'hparams_name', 'hyper_parameters'])
{'num_features': 64, 'num_classes': 5}


MultiClassClassifier(
  (layer): Linear(in_features=64, out_features=5, bias=True)
  (accuracy): MulticlassAccuracy()
  (f1_score): MulticlassF1Score()
  (precision): MulticlassPrecision()
  (recall): MulticlassRecall()
)

### make predictions on validation dataset

In [19]:
def validate_model(model, dataloader):
    model.eval()  # Set the model to evaluation mode
    predictions = []
    labels = []

    with torch.no_grad():  # Turn off gradients for validation, saves memory and computations
        for batch in dataloader:
            inputs, targets = batch  # Adjust these based on how your data is structured
            outputs = model(inputs)
            _, predicted = torch.max(
                outputs.data, 1
            )  # Get the index of the max log-probability
            predictions.extend(predicted.tolist())
            labels.extend(targets.tolist())

    return predictions, labels


# Call the function
predictions, labels = validate_model(model, valid_dataloader)

### evaluate model's performance

In [20]:
from sklearn.metrics import accuracy_score, classification_report

# Accuracy
accuracy = accuracy_score(labels, predictions)

# Target names
target_names = [
    str(row["index"]) for row in train_df.select("index").distinct().collect()
]
report = classification_report(labels, predictions, target_names=target_names)

# Print scores
print(f"Validation Accuracy: {accuracy:.4f}")
print(f"Classification Report:\n{report}")

