# Classification with PySpark and PyTorch Lightning

> **Databricks Notebook:** [End-to-end distributed training with TorchDistributor](https://docs.databricks.com/en/_extras/notebooks/source/deep-learning/torch-distributor-notebook.html)

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import datetime
from pathlib import Path
from plantclef.utils import get_spark
from pyspark.sql import functions as F
from pytorch_lightning.callbacks import ModelCheckpoint


# Obtain the Spark session from the project helper and show it in the cell output
spark = get_spark()
display(spark)

  from .autonotebook import tqdm as notebook_tqdm
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/04/14 23:38:04 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
24/04/14 23:38:04 WARN SparkConf: Note that spark.local.dir will be overridden by the value set by the cluster manager (via SPARK_LOCAL_DIRS in mesos/standalone/kubernetes and LOCAL_DIRS in YARN).


In [3]:
# GCS bucket prefix and the dataset folder holding the embedding files
gcs_path = "gs://dsgt-clef-plantclef-2024/data/process"
dct_emb_train = "training_cropped_resized_v2/dino_dct/data"

# Full GCS location of the embedding Parquet data
dct_gcs_path = "/".join((gcs_path, dct_emb_train))

# Load the Parquet data into a Spark DataFrame and preview the first rows
dct_df = spark.read.parquet(dct_gcs_path)
dct_df.show(n=5, truncate=50)

                                                                                

+--------------------------------------------+----------+--------------------------------------------------+
|                                  image_name|species_id|                                     dct_embedding|
+--------------------------------------------+----------+--------------------------------------------------+
|170e88ca9af457daa1038092479b251c61c64f7d.jpg|   1742956|[-20648.51, 2133.689, -2555.3125, 14820.57, 685...|
|c24a2d8646f5bc7112a39908bd2f6c45bf066a71.jpg|   1356834|[-25395.82, -12564.387, 24736.02, 20483.8, 2115...|
|e1f68e5f05618921969aee2575de20e537e6d66b.jpg|   1563754|[-26178.633, -7670.404, -22552.29, -6563.006, 8...|
|b0433cd6968b57d52e5c25dc45a28e674a25e61e.jpg|   1367432|[-23662.764, -6773.8213, -8283.518, 3769.6064, ...|
|96478a0fe20a41e755b0c8d798690f2c2b7c115f.jpg|   1389010|[-22182.172, -19444.006, 23355.23, 7042.8604, -...|
+--------------------------------------------+----------+--------------------------------------------------+
only showing top 5 

### convert to a PyTorch dataset

In [4]:
import torch
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl
from torch import nn
from torchmetrics.classification import (
    MulticlassAccuracy,
    MulticlassF1Score,
    MulticlassPrecision,
    MulticlassRecall,
)


class EmbeddingDataset(Dataset):
    """In-memory PyTorch dataset built from a Spark DataFrame.

    All rows of ``spark_df`` are collected to the driver when the dataset is
    constructed, so this is only suitable for data that fits in driver memory.

    :param spark_df: object exposing ``collect()`` that returns rows indexable
        by column name (e.g. a Spark DataFrame)
    :param embedding_col: name of the column holding the embedding vector
    :param label_col: name of the column holding the integer class index
    """

    def __init__(self, spark_df, embedding_col="dct_embedding", label_col="index"):
        # Collect exactly once: a second collect() would launch another Spark
        # job and is not guaranteed to return rows in the same order, which
        # could silently misalign embeddings and labels.
        rows = spark_df.collect()
        self.embeddings = [row[embedding_col] for row in rows]
        self.labels = [row[label_col] for row in rows]

    def __len__(self):
        """Return the number of collected rows."""
        return len(self.labels)

    def __getitem__(self, index):
        """Return ``(embedding, label)`` as float and long tensors."""
        embeddings = torch.tensor(self.embeddings[index], dtype=torch.float)
        labels = torch.tensor(self.labels[index], dtype=torch.long)
        return embeddings, labels


class MultiClassClassifier(pl.LightningModule):
    """Single linear layer followed by log-softmax, trained with NLL loss.

    :param num_features: size of each input embedding vector
    :param num_classes: number of target classes
    :param learning_rate: learning rate passed to the Adam optimizer
    """

    def __init__(self, num_features, num_classes, learning_rate=0.002):
        super().__init__()
        self.save_hyperparameters()  # Saves hyperparams in the checkpoints
        self.layer = nn.Linear(num_features, num_classes)
        self.learning_rate = learning_rate
        self.accuracy = MulticlassAccuracy(num_classes=num_classes, average="weighted")
        self.f1_score = MulticlassF1Score(num_classes=num_classes, average="weighted")
        self.precision = MulticlassPrecision(
            num_classes=num_classes, average="weighted"
        )
        self.recall = MulticlassRecall(num_classes=num_classes, average="weighted")

    def forward(self, x):
        """Return per-class log-probabilities for a batch of embeddings."""
        return torch.log_softmax(self.layer(x), dim=1)

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch."""
        x, y = batch
        y_hat = self(x)
        # Use negative log likelihood loss for multiclass classification
        loss = torch.nn.functional.nll_loss(y_hat, y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Accumulate validation metrics and log the validation loss."""
        x, y = batch
        y_hat = self(x)
        loss = torch.nn.functional.nll_loss(y_hat, y)
        self.log("valid_loss", loss, on_step=False, on_epoch=True)

        # Update each metric's state and log the metric object itself, so the
        # epoch value is computed over the whole validation set (not as an
        # average of per-batch values) and the state is reset every epoch.
        tracked_metrics = {
            "valid_accuracy": self.accuracy,
            "valid_f1": self.f1_score,
            "valid_precision": self.precision,
            "valid_recall": self.recall,
        }
        for name, metric in tracked_metrics.items():
            metric.update(y_hat, y)
            self.log(name, metric, on_step=False, on_epoch=True)

        return {"valid_loss": loss}

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters."""
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

### prepare subset of data for testing end-to-end pipeline

In [5]:
from pyspark.sql import functions as F
from pyspark.sql import DataFrame


def prepare_species_data(
    dct_df: DataFrame,
    limit_species: int = None,
    species_image_count: int = 100,
):
    """
    Prepare species data by filtering, indexing, and joining.

    Species with fewer than ``species_image_count`` images are dropped. Each
    remaining species gets a consecutive zero-based ``index`` usable as a class
    label, and that index is joined back onto the image rows. The ``index``
    column is present whether or not ``limit_species`` is given.

    :param dct_df: DataFrame containing species data
    :param limit_species: Maximum number of species to include (None or 0 means no limit)
    :param species_image_count: Minimum number of images per species to include
    :return: DataFrame of filtered and indexed species data
    """
    # Local import so the function does not depend on another cell's imports
    from pyspark.sql import Window

    # Count images per species and keep only sufficiently represented ones
    species_counts = (
        dct_df.groupBy("species_id")
        .agg(F.count("species_id").alias("n"))
        .filter(F.col("n") >= species_image_count)
    )

    # Optionally keep a seeded random subset of the species
    if limit_species:
        species_counts = species_counts.orderBy(F.rand(seed=42)).limit(limit_species)

    # row_number() yields consecutive 1..K values, whereas
    # monotonically_increasing_id() is only guaranteed unique and increasing,
    # not consecutive, so it can produce labels >= the number of classes.
    # The un-partitioned window is acceptable: there is one row per species.
    index_window = Window.orderBy(F.col("n").desc(), F.col("species_id"))
    indexed_species = species_counts.withColumn(
        "index", F.row_number().over(index_window) - 1
    ).drop("n")

    # Use broadcast join to optimize smaller DataFrame joining
    return dct_df.join(F.broadcast(indexed_species), "species_id", "inner")

In [6]:
# Small configuration for exercising the pipeline end to end
LIMIT_SPECIES = 5
SPECIES_IMAGE_COUNT = 100

# Build the filtered, indexed subset
prepared_df = prepare_species_data(
    dct_df,
    limit_species=LIMIT_SPECIES,
    species_image_count=SPECIES_IMAGE_COUNT,
)

# Report the row count, then preview the rows
row_count = prepared_df.count()
print(f"DF count: {row_count}")
prepared_df.show()

                                                                                

DF count: 1185


[Stage 27:>                                                         (0 + 1) / 1]

+----------+--------------------+--------------------+-----+
|species_id|          image_name|       dct_embedding|index|
+----------+--------------------+--------------------+-----+
|   1358851|a5a1530acc42ee28a...|[-22140.71, -2232...|    3|
|   1392723|76056d8c5c2eabdae...|[-18462.121, -112...|    4|
|   1360938|aa65bf7e5cbbea170...|[-27158.367, -183...|    0|
|   1392723|ae436ff1f04ca5412...|[-21858.686, -435...|    4|
|   1360938|3d922d3fe00d95887...|[-25446.95, -5724...|    0|
|   1360299|c914a7f8d83a73727...|[-24541.422, 1324...|    1|
|   1360299|5b995de41dc8c507e...|[-26373.861, 1665...|    1|
|   1358851|6ceb22e1e2d2a0560...|[-24388.037, -243...|    3|
|   1358851|360605951bcdd6843...|[-26956.902, -127...|    3|
|   1360938|cc7b5743d897349af...|[-25043.629, -657...|    0|
|   1358851|43a7b8a23a79645ce...|[-24329.762, -147...|    3|
|   1392723|107d18234ccc4bf99...|[-20970.615, 7978...|    4|
|   1358851|ed49aa18677936d8f...|[-17723.512, -340...|    3|
|   1357220|d6edbca4549d

                                                                                

### train/validation split

In [7]:
# Perform a train-validation split
def train_valid_split(df, weights=(0.8, 0.2), seed=42):
    """Randomly split ``df`` into a training and a validation DataFrame.

    :param df: object exposing ``randomSplit(weights, seed=...)``
    :param weights: two relative weights ``(train, valid)`` for the split
    :param seed: random seed forwarded to ``randomSplit``
    :return: tuple ``(train_df, valid_df)``
    :raises ValueError: if ``weights`` does not contain exactly two values
    """
    split_weights = list(weights)
    if len(split_weights) != 2:
        raise ValueError("weights must contain exactly two values (train, valid)")
    train_df, valid_df = df.randomSplit(split_weights, seed=seed)
    return train_df, valid_df


# Split the prepared subset and report the size of each part
train_df, valid_df = train_valid_split(df=prepared_df)
print(f"train: {train_df.count()}, valid: {valid_df.count()}")

# Training configuration; one class per distinct species in the subset
BATCH_SIZE = 32
NUM_EPOCHS = 10
NUM_CLASSES = int(prepared_df.select("species_id").distinct().count())
print(f"Num classes: {NUM_CLASSES}")

                                                                                

train: 938, valid: 247




Num classes: 5


                                                                                

### prepare DataLoaders

In [None]:
# Wrap each Spark split in a PyTorch dataset
train_dataset, valid_dataset = (
    EmbeddingDataset(spark_df=split_df) for split_df in (train_df, valid_df)
)

# Shuffle only the training batches; validation uses the default order
train_dataloader = DataLoader(
    dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True
)
valid_dataloader = DataLoader(dataset=valid_dataset, batch_size=BATCH_SIZE)

### train model

In [None]:
# Current path
curr_dir = Path(os.getcwd())

# Setup checkpoint callback.
# The placeholder in `filename` must match the logged metric name
# ("valid_accuracy"); an unknown name is formatted as 0.00 in the file name.
checkpoint_callback = ModelCheckpoint(
    monitor="valid_accuracy",  # Monitor validation accuracy for improvement
    dirpath=curr_dir,  # Directory path for saving checkpoints
    filename="model-{epoch:02d}-{valid_accuracy:.2f}",
    save_top_k=1,  # Save only top 1 model
    mode="max",  # Higher accuracy is better
)

# model
model = MultiClassClassifier(
    num_features=64,
    num_classes=NUM_CLASSES,
)

# Train model
trainer = pl.Trainer(max_epochs=NUM_EPOCHS, callbacks=[checkpoint_callback])
trainer.fit(model, train_dataloader, valid_dataloader)

### load trained model

In [None]:
# Prefer the path recorded by the checkpoint callback during training; the
# epoch and score embedded in the file name change from run to run. Fall back
# to a file in the current directory when the callback has no saved path.
best_model = "model-epoch=08-val_accuracy=0.00.ckpt"  # Adjust filename as necessary
if checkpoint_callback.best_model_path:
    checkpoint_path = Path(checkpoint_callback.best_model_path)
else:
    checkpoint_path = curr_dir / Path(best_model)

# Load the checkpoint file
checkpoint = torch.load(
    checkpoint_path, map_location=torch.device("cpu")
)  # Use 'cpu' to avoid GPU memory issues

# Print the keys and any hyperparameters stored in the checkpoint
print(checkpoint.keys())
if "hyper_parameters" in checkpoint:
    print(checkpoint["hyper_parameters"])
else:
    print("No hyperparameters stored in this checkpoint.")

# Load the trained model from checkpoint
model = MultiClassClassifier.load_from_checkpoint(
    checkpoint_path=checkpoint_path,
    num_features=64,
    num_classes=NUM_CLASSES,
)
model.eval()  # Set the model to evaluation mode

### make predictions on validation dataset

In [None]:
def validate_model(model, dataloader):
    """Run ``model`` over ``dataloader`` and gather predictions and targets.

    :param model: callable module returning per-class scores of shape (batch, classes)
    :param dataloader: iterable yielding ``(inputs, targets)`` batches
    :return: tuple ``(predictions, labels)`` of flat Python lists
    """
    model.eval()  # Evaluation mode
    predictions, labels = [], []

    # No gradients are needed for inference
    with torch.no_grad():
        for inputs, targets in dataloader:
            scores = model(inputs)
            # Class with the highest score in each row
            predicted = scores.data.argmax(dim=1)
            predictions += predicted.tolist()
            labels += targets.tolist()

    return predictions, labels


# Run the model over the validation DataLoader; yields parallel lists of
# predicted and true class indices
predictions, labels = validate_model(model, valid_dataloader)

### evaluate model's performance

In [None]:
from sklearn.metrics import accuracy_score, classification_report

# Accuracy
accuracy = accuracy_score(labels, predictions)

# Class ids, sorted so that each display name lines up with its report row.
# distinct().collect() returns rows in no particular order.
class_ids = sorted(
    row["index"] for row in train_df.select("index").distinct().collect()
)
target_names = [str(class_id) for class_id in class_ids]

# Passing `labels=` fixes the set and order of reported classes, so a class
# that is absent from the validation split does not cause a length mismatch.
report = classification_report(
    labels, predictions, labels=class_ids, target_names=target_names
)

# Print scores
print(f"Validation Accuracy: {accuracy:.4f}")
print(f"Classification Report:\n{report}")