# Classification with pyspark and TorchDistributor

> **Databricks Notebook:** [End-to-end distributed training with TorchDistributor](https://docs.databricks.com/en/_extras/notebooks/source/deep-learning/torch-distributor-notebook.html)

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import datetime
from pathlib import Path
from pyspark.sql import DataFrame
from plantclef.utils import get_spark
from pyspark.sql import functions as F


spark = get_spark()
display(spark)

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/04/16 17:32:47 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
24/04/16 17:32:47 WARN SparkConf: Note that spark.local.dir will be overridden by the value set by the cluster manager (via SPARK_LOCAL_DIRS in mesos/standalone/kubernetes and LOCAL_DIRS in YARN).


In [3]:
# Path and dataset names
gcs_path = "gs://dsgt-clef-plantclef-2024/data/process"
dct_emb_train = "training_cropped_resized_v2/dino_dct/data"

# Define the GCS path to the embedding files
dct_gcs_path = f"{gcs_path}/{dct_emb_train}"

# Read the Parquet file into a DataFrame
dct_df = spark.read.parquet(dct_gcs_path)

# Show the data
dct_df.show(n=5, truncate=50)

                                                                                

+--------------------------------------------+----------+--------------------------------------------------+
|                                  image_name|species_id|                                     dct_embedding|
+--------------------------------------------+----------+--------------------------------------------------+
|170e88ca9af457daa1038092479b251c61c64f7d.jpg|   1742956|[-20648.51, 2133.689, -2555.3125, 14820.57, 685...|
|c24a2d8646f5bc7112a39908bd2f6c45bf066a71.jpg|   1356834|[-25395.82, -12564.387, 24736.02, 20483.8, 2115...|
|e1f68e5f05618921969aee2575de20e537e6d66b.jpg|   1563754|[-26178.633, -7670.404, -22552.29, -6563.006, 8...|
|b0433cd6968b57d52e5c25dc45a28e674a25e61e.jpg|   1367432|[-23662.764, -6773.8213, -8283.518, 3769.6064, ...|
|96478a0fe20a41e755b0c8d798690f2c2b7c115f.jpg|   1389010|[-22182.172, -19444.006, 23355.23, 7042.8604, -...|
+--------------------------------------------+----------+--------------------------------------------------+
only showing top 5 

### prepare subset of data for testing end-to-end pipeline

In [4]:
def prepare_species_data(
    dct_df: DataFrame,
    limit_species: int = None,
    species_image_count: int = 100,
):
    """
    Prepare species data by filtering, indexing, and joining.

    :param dct_df: DataFrame containing species data
    :param limit_species: Maximum number of species to include (None means no limit)
    :param species_image_count: Minimum number of images per species to include
    :return: DataFrame of filtered and indexed species data
    """
    # Aggregate and filter species based on image count
    grouped_df = (
        dct_df.groupBy("species_id")
        .agg(F.count("species_id").alias("n"))
        .filter(F.col("n") >= species_image_count)
        .orderBy(F.col("n").desc())
        .withColumn("index", F.monotonically_increasing_id())
    ).drop("n")

    # Use broadcast join to optimize smaller DataFrame joining
    filtered_dct_df = dct_df.join(F.broadcast(grouped_df), "species_id", "inner").drop(
        "index"
    )

    # Optionally limit the number of species
    if limit_species:
        limited_grouped_df = (
            (
                grouped_df.orderBy(F.rand(seed=42))
                .limit(limit_species)
                .withColumn("new_index", F.monotonically_increasing_id())
            )
            .drop("index")
            .withColumnRenamed("new_index", "index")
        )

        filtered_dct_df = filtered_dct_df.join(
            F.broadcast(limited_grouped_df), "species_id", "inner"
        )

    return filtered_dct_df

In [5]:
# Params
LIMIT_SPECIES = 5
SPECIES_IMAGE_COUNT = 100

# Call function
prepared_df = prepare_species_data(
    dct_df, limit_species=LIMIT_SPECIES, species_image_count=SPECIES_IMAGE_COUNT
)
print(f"DF count: {prepared_df.count()}")
prepared_df.show()

                                                                                

DF count: 1185


[Stage 27:>                                                         (0 + 1) / 1]

+----------+--------------------+--------------------+-----+
|species_id|          image_name|       dct_embedding|index|
+----------+--------------------+--------------------+-----+
|   1358851|a5a1530acc42ee28a...|[-22140.71, -2232...|    3|
|   1392723|76056d8c5c2eabdae...|[-18462.121, -112...|    4|
|   1360938|aa65bf7e5cbbea170...|[-27158.367, -183...|    0|
|   1392723|ae436ff1f04ca5412...|[-21858.686, -435...|    4|
|   1360938|3d922d3fe00d95887...|[-25446.95, -5724...|    0|
|   1360299|c914a7f8d83a73727...|[-24541.422, 1324...|    1|
|   1360299|5b995de41dc8c507e...|[-26373.861, 1665...|    1|
|   1358851|6ceb22e1e2d2a0560...|[-24388.037, -243...|    3|
|   1358851|360605951bcdd6843...|[-26956.902, -127...|    3|
|   1360938|cc7b5743d897349af...|[-25043.629, -657...|    0|
|   1358851|43a7b8a23a79645ce...|[-24329.762, -147...|    3|
|   1392723|107d18234ccc4bf99...|[-20970.615, 7978...|    4|
|   1358851|ed49aa18677936d8f...|[-17723.512, -340...|    3|
|   1357220|d6edbca4549d

                                                                                

### train/validation split

In [6]:
# Perform a train-validation split
def train_valid_split(df):
    train_df, valid_df = df.randomSplit([0.8, 0.2], seed=42)
    return train_df, valid_df


# Pass desired DF to function
train_df, valid_df = train_valid_split(df=prepared_df)
print(f"train: {train_df.count()}, valid: {valid_df.count()}")



train: 938, valid: 247


                                                                                

### get data ready for training

In [7]:
# # Prepare data for Petastorm
# train_dir = f"/mnt/data/train_data"
# valid_dir = f"/mnt/data/valid_data"
# train_df.write.mode("overwrite").parquet(train_dir)
# valid_df.write.mode("overwrite").parquet(valid_dir)

In [8]:
def get_parquet_file_paths(directory):
    assert os.path.exists(directory)
    files = os.listdir(directory)
    parquet_files = [file for file in files if file.endswith(".parquet")]
    full_paths = [os.path.join(f"file://{directory}", file) for file in parquet_files]
    return full_paths


# # Get URLs for train and valid sets
# train_file_paths = get_parquet_file_paths(train_dir)
# valid_file_paths = get_parquet_file_paths(valid_dir)

In [9]:
# from petastorm.spark import SparkDatasetConverter, make_spark_converter

# # Set a cache directory on DBFS FUSE for intermediate data.
# cache_dir = "file:///mnt/data/petastorm/cache"
# spark.conf.set(SparkDatasetConverter.PARENT_CACHE_DIR_URL_CONF, cache_dir)

# converter_train = make_spark_converter(train_df)
# converter_valid = make_spark_converter(valid_df)
# print(f"train: {len(converter_train)}, val: {len(converter_valid)}")

## TorchDistributor

In [10]:
import torch
import pytorch_lightning as pl
from torch import nn
from petastorm.spark import SparkDatasetConverter, make_spark_converter

from torchmetrics.classification import (
    MulticlassAccuracy,
    MulticlassF1Score,
    MulticlassPrecision,
    MulticlassRecall,
)


class LitClassifier(pl.LightningModule):
    def __init__(self, num_features, num_classes):
        super().__init__()
        self.num_features = num_features
        self.num_classes = num_classes
        self.save_hyperparameters()  # Saves hyperparams in the checkpoints
        self.layer = nn.Linear(num_features, num_classes)
        self.learning_rate = 0.002
        self.accuracy = MulticlassAccuracy(num_classes=num_classes, average="weighted")
        self.f1_score = MulticlassF1Score(num_classes=num_classes, average="weighted")
        self.precision = MulticlassPrecision(
            num_classes=num_classes, average="weighted"
        )
        self.recall = MulticlassRecall(num_classes=num_classes, average="weighted")

    def forward(self, x):
        return torch.log_softmax(self.layer(x), dim=1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = torch.nn.functional.nll_loss(logits, y)
        self.log("train_loss", loss)
        self.log("train_acc", self.accuracy(logits, y), prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        valid_loss = {"valid_loss": torch.nn.functional.nll_loss(y_hat, y)}
        # Update metrics
        self.log(
            "valid_accuracy", self.accuracy(y_hat, y), on_step=False, on_epoch=True
        )
        self.log("valid_f1", self.f1_score(y_hat, y), on_step=False, on_epoch=True)
        self.log(
            "valid_precision", self.precision(y_hat, y), on_step=False, on_epoch=True
        )
        self.log("valid_recall", self.recall(y_hat, y), on_step=False, on_epoch=True)
        return valid_loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer


class PetastormDataModule(pl.LightningDataModule):
    def __init__(
        self,
        spark,
        cache_dir,
        train_data,
        valid_data,
        batch_size=32,
        num_partitions=32,
        workers_count=16,
    ):
        super().__init__()
        spark.conf.set(
            SparkDatasetConverter.PARENT_CACHE_DIR_URL_CONF, Path(cache_dir).as_posix()
        )
        self.spark = spark
        self.train_data = train_data
        self.valid_data = valid_data
        self.batch_size = batch_size
        self.num_partitions = num_partitions
        self.workers_count = workers_count

    def _prepare_dataframe(self, df, partitions=32):
        """Prepare the DataFrame for training by ensuring correct types and repartitioning"""
        return (
            df.withColumnRenamed("dct_embedding", "features")
            .withColumnRenamed("index", "label")
            .repartition(partitions)
        )

    def setup(self, stage=None):
        # setup petastorm data conversion from Spark to PyTorch
        def make_converter(df):
            return make_spark_converter(
                self._prepare_dataframe(df, self.num_partitions).select(
                    "features", "label"
                )
            )

        self.converter_train = make_converter(self.train_data)
        self.converter_valid = make_converter(self.valid_data)

    def _dataloader(self, converter):
        with converter.make_torch_dataloader(
            batch_size=self.batch_size,
            num_epochs=1,
            workers_count=self.workers_count,
        ) as dataloader:
            for batch in dataloader:
                yield batch

    def train_dataloader(self):
        for batch in self._dataloader(self.converter_train):
            yield batch

    def val_dataloader(self):
        for batch in self._dataloader(self.converter_valid):
            yield batch

    def teardown(self, stage=None):
        if stage == "fit" or stage is None:
            self.train_reader.stop()
            self.valid_reader.stop()

  from .autonotebook import tqdm as notebook_tqdm
  from pyarrow import LocalFileSystem


In [11]:
# lightning model parameters
num_features = 64
num_classes = int(train_df.select("index").distinct().count())

# data module parameters
cache_dir = "file:///mnt/data/tmp"
batch_size = 32
num_partitions = 32
workers_count = 16
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

                                                                                

### train model

In [12]:
# define model
model = LitClassifier(
    num_features=num_features,
    num_classes=num_classes,
)

# define data module
data_module = PetastormDataModule(
    spark=spark,
    cache_dir=cache_dir,
    train_data=train_df,
    valid_data=valid_df,
)
# data_module.setup()

In [13]:
# train model
trainer = pl.Trainer(
    max_epochs=10,
    accelerator="cpu",
)

trainer.fit(model, data_module)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/mgustine/.local/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:75: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
  self._filesystem = pyarrow.localfs
Converting floating-point columns to float32
The median size 8434 B (< 50 MB) of the parquet files is too small. Total size: 272576 B. Increase the median file size by calling df.repartition(n) or df.coalesce(n), which might help improve the performance. Parquet files: file:///mnt

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

TypeError: ParquetDataset.__init__() got an unexpected keyword argument 'validate_schema'

In [None]:
from pyspark.ml.torch.distributor import TorchDistributor


# Run TorchDistributor
use_gpu = True if device.type == "cuda" else False
# model = TorchDistributor(num_processes=2, local_mode=True, use_gpu=use_gpu).run()