# Classification with PyTorch and Petastorm

In [None]:
# Reload imported modules (e.g. edited plantclef code) automatically before running cells.
%load_ext autoreload
%autoreload 2

In [None]:
import os
import datetime
from pathlib import Path
from pyspark.sql import DataFrame
from plantclef.utils import get_spark
from pyspark.sql import functions as F


# Get a Spark session from the project helper and render it in the notebook.
spark = get_spark()
display(spark)

In [None]:
# Processed-data root on GCS and the dino_dct embedding data for the training images.
gcs_path = "gs://dsgt-clef-plantclef-2024/data/process"
dct_emb_train = "training_cropped_resized_v2/dino_dct/data"
dct_gcs_path = "/".join((gcs_path, dct_emb_train))

# Load the embedding Parquet files and preview a few rows.
dct_df = spark.read.parquet(dct_gcs_path)
dct_df.show(n=5, truncate=50)

### Prepare a subset of the data for testing the end-to-end pipeline

In [None]:
def prepare_species_data(
    dct_df: DataFrame,
    limit_species: int = None,
    species_image_count: int = 100,
):
    """
    Prepare species data by filtering, indexing, and joining.

    Keeps only species with at least ``species_image_count`` images. If
    ``limit_species`` is set, it samples that many of them in a seeded random
    order. It then adds an ``index`` column holding consecutive 0-based
    ``long`` labels (0..k-1 for k kept species), usable directly as class
    targets.

    :param dct_df: DataFrame containing species data (needs a ``species_id`` column)
    :param limit_species: Maximum number of species to include (None or 0 means no limit)
    :param species_image_count: Minimum number of images per species to include
    :return: Rows of ``dct_df`` for the kept species plus an ``index`` label column
    """
    from pyspark.sql import Window

    # One row per species that has enough images.
    species_df = (
        dct_df.groupBy("species_id")
        .agg(F.count("species_id").alias("n"))
        .filter(F.col("n") >= species_image_count)
    )

    # Optionally keep a seeded random subset of species.
    if limit_species:
        species_df = species_df.orderBy(F.rand(seed=42)).limit(limit_species)

    # monotonically_increasing_id() is not consecutive across partitions, so
    # labels could exceed the number of classes; row_number() over a global
    # window yields 0..k-1. The per-species table is small, so the
    # single-partition window is cheap. Cast to long because nll_loss
    # requires int64 targets.
    label_window = Window.orderBy(F.col("n").desc(), F.col("species_id"))
    species_df = species_df.withColumn(
        "index", (F.row_number().over(label_window) - 1).cast("long")
    ).drop("n")

    # Broadcast the small species table to avoid shuffling the large one.
    return dct_df.join(F.broadcast(species_df), "species_id", "inner")

In [None]:
# Small subset for a quick end-to-end run: a few species with enough images each.
LIMIT_SPECIES = 5
SPECIES_IMAGE_COUNT = 100

prepared_df = prepare_species_data(
    dct_df,
    limit_species=LIMIT_SPECIES,
    species_image_count=SPECIES_IMAGE_COUNT,
)
row_count = prepared_df.count()
print(f"DF count: {row_count}")
prepared_df.show()

### Train/validation/test split

In [None]:
# Perform a train-validation-test split
def train_valid_test_split(df, train_split=0.7, valid_split=0.15, test_split=0.15):
    """Split ``df`` into train, validation and test DataFrames.

    A ``train_split`` fraction of the rows goes to training. The remainder is
    divided between validation and test in the ratio
    ``valid_split : test_split``. Both ``randomSplit`` calls use ``seed=42``.

    :return: ``(train_df, valid_df, test_df)``
    """
    holdout = valid_split + test_split
    train_part, holdout_part = df.randomSplit([train_split, 1 - train_split], seed=42)
    valid_part, test_part = holdout_part.randomSplit(
        [valid_split / holdout, test_split / holdout],
        seed=42,
    )
    return train_part, valid_part, test_part


# Split the prepared data and report how many rows landed in each part.
train_df, valid_df, test_df = train_valid_test_split(df=prepared_df)
split_sizes = {
    name: part.count()
    for name, part in zip(("train", "valid", "test"), (train_df, valid_df, test_df))
}
print(", ".join(f"{name}: {size}" for name, size in split_sizes.items()))

### Get data ready for training

In [None]:
# # Prepare data for Petastorm
# train_dir = f"/mnt/data/train_data"
# valid_dir = f"/mnt/data/valid_data"
# train_df.write.mode("overwrite").parquet(train_dir)
# valid_df.write.mode("overwrite").parquet(valid_dir)

In [None]:
def get_parquet_file_paths(directory):
    """Return ``file://`` URLs for the ``.parquet`` files directly inside ``directory``.

    :param directory: Local directory, e.g. a Spark Parquet output directory
    :return: Sorted list of URLs such as ``file:///abs/dir/part-00000.parquet``
    :raises FileNotFoundError: If ``directory`` is not an existing directory
    """
    # Explicit check instead of ``assert``, which is stripped under ``python -O``.
    if not os.path.isdir(directory):
        raise FileNotFoundError(f"Not a directory: {directory}")
    # Use an absolute path so a relative directory does not become a URL host
    # ("file://data/..."); sort because os.listdir order is arbitrary.
    base = os.path.abspath(directory)
    return [
        f"file://{os.path.join(base, name)}"
        for name in sorted(os.listdir(base))
        if name.endswith(".parquet")
    ]


# # Get URLs for train and valid sets
# train_file_paths = get_parquet_file_paths(train_dir)
# valid_file_paths = get_parquet_file_paths(valid_dir)

In [None]:
# from petastorm.spark import SparkDatasetConverter, make_spark_converter

# # Set a cache directory on DBFS FUSE for intermediate data.
# cache_dir = "file:///mnt/data/petastorm/cache"
# spark.conf.set(SparkDatasetConverter.PARENT_CACHE_DIR_URL_CONF, cache_dir)

# converter_train = make_spark_converter(train_df)
# converter_valid = make_spark_converter(valid_df)
# print(f"train: {len(converter_train)}, val: {len(converter_valid)}")

In [None]:
# Inspect the training split's columns and types before converting it for PyTorch.
train_df.printSchema()

## Training with PyTorch Lightning (TorchDistributor not used yet)

In [None]:
import torch
import pytorch_lightning as pl
from torch import nn
from petastorm.spark import SparkDatasetConverter, make_spark_converter
from pyspark.ml.functions import vector_to_array
from torchmetrics.classification import (
    MulticlassAccuracy,
    MulticlassF1Score,
    MulticlassPrecision,
    MulticlassRecall,
)


class LitClassifier(pl.LightningModule):
    """Single linear-layer classifier over precomputed feature vectors.

    ``forward`` returns log-probabilities (``log_softmax``), so every step uses
    ``nll_loss``.

    :param num_features: Length of each input feature vector
    :param num_classes: Number of classes; labels are expected in ``[0, num_classes)``
    :param learning_rate: Learning rate for the Adam optimizer
    """

    def __init__(self, num_features, num_classes, learning_rate=0.002):
        super().__init__()
        self.num_features = num_features
        self.num_classes = num_classes
        self.learning_rate = learning_rate
        self.save_hyperparameters()  # Saves hyperparams in the checkpoints
        self.layer = nn.Linear(num_features, num_classes)
        # Step-level training accuracy shown on the progress bar.
        self.accuracy = MulticlassAccuracy(num_classes=num_classes, average="weighted")
        # Separate metric instances per stage so validation/test state is never
        # mixed with training state. Logging the Metric objects (not per-batch
        # values) lets Lightning compute them over the whole epoch and reset them.
        self.valid_metrics = nn.ModuleDict(self._make_eval_metrics(num_classes))
        self.test_metrics = nn.ModuleDict(self._make_eval_metrics(num_classes))

    @staticmethod
    def _make_eval_metrics(num_classes):
        """Build the weighted multiclass metrics logged as ``<stage>_<name>``."""
        return {
            "accuracy": MulticlassAccuracy(num_classes=num_classes, average="weighted"),
            "f1": MulticlassF1Score(num_classes=num_classes, average="weighted"),
            "precision": MulticlassPrecision(
                num_classes=num_classes, average="weighted"
            ),
            "recall": MulticlassRecall(num_classes=num_classes, average="weighted"),
        }

    def forward(self, x):
        return torch.log_softmax(self.layer(x), dim=1)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

    def training_step(self, batch, batch_idx):
        x, y = batch["features"], batch["label"]
        log_probs = self(x)
        loss = torch.nn.functional.nll_loss(log_probs, y)
        self.log("train_loss", loss)
        self.log("train_acc", self.accuracy(log_probs, y), prog_bar=True)
        return loss

    def _evaluate(self, batch, stage, metrics):
        """Shared validation/test step: log the loss and update epoch-level metrics.

        :return: ``{f"{stage}_loss": loss}``, the same shape the original steps returned
        """
        x, y = batch["features"], batch["label"]
        log_probs = self(x)
        loss = torch.nn.functional.nll_loss(log_probs, y)
        self.log(f"{stage}_loss", loss, on_step=False, on_epoch=True)
        for name, metric in metrics.items():
            metric.update(log_probs, y)
            self.log(f"{stage}_{name}", metric, on_step=False, on_epoch=True)
        return {f"{stage}_loss": loss}

    def validation_step(self, batch, batch_idx):
        return self._evaluate(batch, "valid", self.valid_metrics)

    def test_step(self, batch, batch_idx):
        return self._evaluate(batch, "test", self.test_metrics)

    def predict_step(self, batch, batch_idx):
        # Only the features are needed; unpacking the feature tensor into
        # (x, y) would split it along the batch axis and fail.
        x = batch["features"]
        probabilities = torch.exp(self(x))  # forward returns log-probabilities
        predicted_labels = torch.argmax(probabilities, dim=1)
        return {"probabilities": probabilities, "pred_labels": predicted_labels}


class PetastormDataModule(pl.LightningDataModule):
    """Lightning data module that serves Spark DataFrames to PyTorch via Petastorm.

    Each DataFrame is reduced to a ``features`` column (from ``dct_embedding``)
    and a ``label`` column (from ``index``), then converted with
    ``make_spark_converter`` using ``cache_dir`` as Petastorm's parent cache
    directory.

    :param spark: Active SparkSession; its Petastorm cache-dir conf is set here
    :param cache_dir: Parent cache directory URL, e.g. ``file:///mnt/data/tmp``
    :param train_data: Training DataFrame
    :param valid_data: Validation DataFrame
    :param test_data: Optional test DataFrame served by ``test_dataloader``
    :param batch_size: Batch size for every dataloader
    :param num_partitions: Partitions each DataFrame is repartitioned to before conversion
    :param workers_count: Number of Petastorm reader workers
    :param predict_data: Optional DataFrame served by ``predict_dataloader``
    """

    def __init__(
        self,
        spark,
        cache_dir,
        train_data,
        valid_data,
        test_data=None,
        batch_size=32,
        num_partitions=32,
        workers_count=16,
        predict_data=None,
    ):
        super().__init__()
        # Pass the URL through unchanged: pathlib collapses "//", turning
        # "file:///x" into "file:/x" and "gs://bucket/x" into "gs:/bucket/x".
        spark.conf.set(SparkDatasetConverter.PARENT_CACHE_DIR_URL_CONF, str(cache_dir))
        self.spark = spark
        self.train_data = train_data
        self.valid_data = valid_data
        self.test_data = test_data
        self.predict_data = predict_data
        self.batch_size = batch_size
        self.num_partitions = num_partitions
        self.workers_count = workers_count
        # Converters are created lazily in ``setup``.
        self.converter_train = None
        self.converter_valid = None
        self.converter_test = None
        self.converter_predict = None

    def _prepare_dataframe(self, df, partitions=32):
        """Prepare the DataFrame for training by ensuring correct types and repartitioning"""
        return (
            df.withColumnRenamed("dct_embedding", "features")
            .withColumnRenamed("index", "label")
            .select("features", "label")
            .repartition(partitions)
        )

    def _make_converter(self, df):
        """Return a Petastorm converter for ``df``, or None when ``df`` is None."""
        if df is None:
            return None
        return make_spark_converter(self._prepare_dataframe(df, self.num_partitions))

    def setup(self, stage=None):
        # Lightning calls ``setup`` once per stage (fit, test, predict); build
        # each converter only once instead of re-converting every DataFrame.
        if self.converter_train is None:
            self.converter_train = self._make_converter(self.train_data)
        if self.converter_valid is None:
            self.converter_valid = self._make_converter(self.valid_data)
        if self.converter_test is None:
            self.converter_test = self._make_converter(self.test_data)
        if self.converter_predict is None:
            self.converter_predict = self._make_converter(self.predict_data)

    def _dataloader(self, converter):
        """Yield the batches of a single pass over ``converter``'s data.

        The Petastorm loader is opened when iteration starts and closed when the
        pass ends. Because each generator yields only one epoch, the Trainer
        must rebuild dataloaders every epoch.
        """
        with converter.make_torch_dataloader(
            batch_size=self.batch_size,
            num_epochs=1,
            workers_count=self.workers_count,
        ) as dataloader:
            yield from dataloader

    def train_dataloader(self):
        yield from self._dataloader(self.converter_train)

    def val_dataloader(self):
        yield from self._dataloader(self.converter_valid)

    def test_dataloader(self):
        if self.converter_test is None:
            raise RuntimeError("No converter for test")
        yield from self._dataloader(self.converter_test)

    def predict_dataloader(self):
        if self.converter_predict is None:
            raise RuntimeError("No converter for predict")
        yield from self._dataloader(self.converter_predict)

In [None]:
# lightning model parameters
num_features = 64
num_classes = int(train_df.select("index").distinct().count())

# data module parameters
cache_dir = "file:///mnt/data/tmp"
batch_size = 32
num_partitions = 32
workers_count = 16
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

### Train model

In [None]:
# define model
model = LitClassifier(
    num_features=num_features,
    num_classes=num_classes,
)

# define data module: pass the test split and the loader settings from the
# parameters cell (previously test_data was missing and the settings unused)
data_module = PetastormDataModule(
    spark=spark,
    cache_dir=cache_dir,
    train_data=train_df,
    valid_data=valid_df,
    test_data=test_df,
    batch_size=batch_size,
    num_partitions=num_partitions,
    workers_count=workers_count,
)

In [None]:
# train model
trainer = pl.Trainer(
    max_epochs=10,
    accelerator="cpu",
    reload_dataloaders_every_n_epochs=1,
)

trainer.fit(model, data_module)

In [None]:
# Predict on the held-out test split. No separate prediction DataFrame is
# configured on the data module, so predict_dataloader() has nothing to serve.
predict_dataloader = data_module.test_dataloader()

# Predict using the trained model
predictions = trainer.predict(model, dataloaders=predict_dataloader)