# Classification with PySpark and PyTorch Lightning

In [None]:
# Automatically reload imported modules before executing each cell,
# so edits to project code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from plantclef.utils import get_spark

# Obtain the Spark session via the project helper; later cells read data through `spark`.
spark = get_spark()
# Render the session summary in the notebook output.
display(spark)

In [None]:
# Bucket prefix and the relative location of the DINO DCT training embeddings
gcs_path = "gs://dsgt-clef-plantclef-2024/data/process"
dct_emb_train = "training_cropped_resized_v2/dino_dct/data"

# Full GCS URI of the parquet dataset: "<gcs_path>/<dct_emb_train>"
dct_gcs_path = "/".join([gcs_path, dct_emb_train])

# Load the embeddings as a Spark DataFrame and preview a few rows
dct_df = spark.read.parquet(dct_gcs_path)
dct_df.show(n=5, truncate=50)

### Convert to a PyTorch dataset

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader


class EmbeddingDataset(Dataset):
    """Map-style dataset yielding ``(embedding, one_hot_label)`` pairs.

    The input DataFrame is collected to the driver with ``toPandas()``, so it
    must fit in memory. It must contain a ``dct_embedding`` column (a sequence
    of numbers per row) and a ``species_id`` column.

    Args:
        data: Spark DataFrame (anything exposing ``toPandas()``).
        num_classes: Length of every one-hot label vector.
        species_index: Optional ``{species_id: label_index}`` mapping. Pass the
            training split's mapping (``train_ds.species_id``) when building the
            validation split so a species gets the same index in both. When
            omitted, the mapping is built from the sorted distinct
            ``species_id`` values present in ``data``.

    Raises:
        ValueError: if the mapping has more entries than ``num_classes``.
    """

    def __init__(self, data, num_classes, species_index=None):
        self.data = data.toPandas()  # Convert to Pandas DF
        self.num_classes = num_classes  # Total number of classes
        if species_index is None:
            species_index = self._get_species_index(self.data)
        if len(species_index) > num_classes:
            raise ValueError(
                f"{len(species_index)} distinct species_id values do not fit "
                f"into num_classes={num_classes}"
            )
        # Mapping species_id -> position in the one-hot label vector.
        self.species_id = species_index

    def __len__(self):
        return len(self.data)

    def _get_species_index(self, frame):
        # Sort so the mapping is deterministic: Spark's distinct() gives no
        # ordering guarantee, so independently built splits could disagree.
        species_ids = sorted(frame["species_id"].unique().tolist())
        return {species_id: idx for idx, species_id in enumerate(species_ids)}

    def __getitem__(self, index):
        row = self.data.iloc[index]
        # Force float32: toPandas() may yield float64 numpy arrays, which would
        # not match nn.Linear's default float32 weights.
        embeddings = torch.tensor(row["dct_embedding"], dtype=torch.float32)
        labels = torch.zeros(self.num_classes, dtype=torch.float)
        labels[self.species_id[row["species_id"]]] = 1.0
        return embeddings, labels

In [None]:
import pytorch_lightning as pl
import torch
from torch import nn
from torchmetrics import Accuracy, F1Score, HammingDistance, Precision, Recall


class MultiLabelClassifier(pl.LightningModule):
    """Single linear layer over embeddings, trained as a multi-label classifier.

    ``forward`` returns per-class probabilities (sigmoid of the logits).
    Validation metrics use torchmetrics' multilabel task API with a 0.5
    decision threshold and are accumulated over the whole validation epoch.

    Args:
        num_features: Width of each input embedding vector.
        num_classes: Number of output labels.
        lr: Adam learning rate (defaults to the previously hard-coded 0.002).
    """

    def __init__(self, num_features, num_classes, lr=0.002):
        super().__init__()
        self.lr = lr
        self.layer = nn.Linear(num_features, num_classes)
        # torchmetrics >= 0.11 requires ``task=``. The multilabel task offers
        # micro/macro/weighted averaging but no "samples" average, so macro
        # (mean over labels) is used for the averaged metrics.
        metric_kwargs = {
            "task": "multilabel",
            "num_labels": num_classes,
            "threshold": 0.5,
        }
        self.accuracy = Accuracy(average="macro", **metric_kwargs)
        self.f1_score = F1Score(average="macro", **metric_kwargs)
        # Named ``precision_metric`` rather than ``precision`` to avoid shadowing
        # an attribute of that name that some Lightning versions define.
        self.precision_metric = Precision(average="macro", **metric_kwargs)
        self.recall = Recall(average="macro", **metric_kwargs)
        # torchmetrics calls this metric HammingDistance (there is no HammingLoss).
        self.hamming_loss = HammingDistance(**metric_kwargs)

    def forward(self, x):
        # Sigmoid, not softmax: each label is scored independently.
        return torch.sigmoid(self.layer(x))

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self.layer(x)
        # The *_with_logits form fuses sigmoid + BCE in a numerically stable way;
        # it is mathematically equal to BCE applied to forward(x).
        loss = nn.functional.binary_cross_entropy_with_logits(logits, y.float())
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self.layer(x)
        probs = torch.sigmoid(logits)
        target = y.int()  # classification metrics take integer 0/1 targets
        self.log(
            "val_loss",
            nn.functional.binary_cross_entropy_with_logits(logits, y.float()),
        )
        metrics = {
            "val_accuracy": self.accuracy,
            "val_f1": self.f1_score,
            "val_precision": self.precision_metric,
            "val_recall": self.recall,
            "val_hamming_loss": self.hamming_loss,
        }
        for name, metric in metrics.items():
            # Update the metric state, then log the Metric object itself so
            # Lightning reports the value aggregated over the epoch.
            metric.update(probs, target)
            self.log(name, metric)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

### Prepare a subset of the data

In [None]:
from pyspark.sql import functions as F

# Count images per species, keep species with more than 100 images, and draw
# 10 of them at random (seeded) to get a small but non-trivial test subset.
species_counts = dct_df.groupBy("species_id").agg(F.count("*").alias("n"))
sub_species_df = (
    species_counts.where(F.col("n") > 100)
    .orderBy(F.rand(seed=42))
    .limit(10)
    .cache()
)
sub_species_df.show(truncate=80)

# Bring the sampled species ids back to the driver as a plain Python list
distinct_species_rows = sub_species_df.select("species_id").distinct().collect()
species_id_subset = [r["species_id"] for r in distinct_species_rows]

# Restrict the embeddings to the sampled species; cached because it is reused
subset_df = dct_df.where(F.col("species_id").isin(species_id_subset))
subset_df = subset_df.cache()
subset_df.show()

# Report how many rows the subset contains
subset_count = subset_df.count()
print(f"subset_df count: {subset_count}")

### Train-validation split

In [None]:
# Seeded 80/20 random split of the subset into training and validation parts;
# both parts are cached because they are counted here and consumed later.
splits = subset_df.randomSplit([0.8, 0.2], seed=42)
train_df, valid_df = (part.cache() for part in splits)

# Report the size of each split
for split_name, split_df in (("Train", train_df), ("Valid", valid_df)):
    print(f"{split_name} DF count: {split_df.count()}")

### Train the model

In [None]:
# Init params
# One output unit per sampled species. This is derived from the subset rather
# than hard-coded to 10, so it stays correct if fewer species pass the n > 100
# filter.
num_classes = len(species_id_subset)
batch_size = 32
epochs = 10

# Prepare PyTorch datasets
train_data = EmbeddingDataset(data=train_df, num_classes=num_classes)
valid_data = EmbeddingDataset(data=valid_df, num_classes=num_classes)
# Each dataset builds its own species_id -> label-index mapping from its own
# split. Reuse the training mapping so a species has the same label index in
# both splits; otherwise validation labels may not line up with model outputs.
valid_data.species_id = train_data.species_id

# Prepare DataLoaders
train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
valid_dataloader = DataLoader(valid_data, batch_size=batch_size, shuffle=False)

# Model: the input width comes from the first training embedding rather than
# being hard-coded, so it follows the embedding size stored in the parquet data.
num_features = train_data[0][0].shape[-1]
model = MultiLabelClassifier(num_features=num_features, num_classes=num_classes)

trainer = pl.Trainer(max_epochs=epochs)
trainer.fit(model, train_dataloader, valid_dataloader)