# Classification with PySpark and TorchDistributor

> **Databricks Notebook:** [End-to-end distributed training with TorchDistributor](https://docs.databricks.com/en/_extras/notebooks/source/deep-learning/torch-distributor-notebook.html)

In [1]:
# Enable IPython's autoreload extension; mode 2 reloads imported modules
# before each cell runs, so edits to project code apply without a restart.
%load_ext autoreload
%autoreload 2

In [2]:
import os
import datetime
from pathlib import Path
from plantclef.utils import get_spark
from pyspark.sql import functions as F
from pytorch_lightning.callbacks import ModelCheckpoint


# Get the Spark entry point from the project helper (presumably a
# SparkSession, since it is used below via spark.read) and render it.
spark = get_spark()
display(spark)

  from .autonotebook import tqdm as notebook_tqdm
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/04/15 14:13:40 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
24/04/15 14:13:40 WARN SparkConf: Note that spark.local.dir will be overridden by the value set by the cluster manager (via SPARK_LOCAL_DIRS in mesos/standalone/kubernetes and LOCAL_DIRS in YARN).


In [3]:
# Path and dataset names
gcs_path = "gs://dsgt-clef-plantclef-2024/data/process"
dct_emb_train = "training_cropped_resized_v2/dino_dct/data"

# Define the GCS path to the embedding files
dct_gcs_path = f"{gcs_path}/{dct_emb_train}"

# Read the Parquet dataset at that path into a Spark DataFrame
dct_df = spark.read.parquet(dct_gcs_path)

# Preview the first 5 rows, truncating every cell to 50 characters
dct_df.show(n=5, truncate=50)

                                                                                

+--------------------------------------------+----------+--------------------------------------------------+
|                                  image_name|species_id|                                     dct_embedding|
+--------------------------------------------+----------+--------------------------------------------------+
|170e88ca9af457daa1038092479b251c61c64f7d.jpg|   1742956|[-20648.51, 2133.689, -2555.3125, 14820.57, 685...|
|c24a2d8646f5bc7112a39908bd2f6c45bf066a71.jpg|   1356834|[-25395.82, -12564.387, 24736.02, 20483.8, 2115...|
|e1f68e5f05618921969aee2575de20e537e6d66b.jpg|   1563754|[-26178.633, -7670.404, -22552.29, -6563.006, 8...|
|b0433cd6968b57d52e5c25dc45a28e674a25e61e.jpg|   1367432|[-23662.764, -6773.8213, -8283.518, 3769.6064, ...|
|96478a0fe20a41e755b0c8d798690f2c2b7c115f.jpg|   1389010|[-22182.172, -19444.006, 23355.23, 7042.8604, -...|
+--------------------------------------------+----------+--------------------------------------------------+
only showing top 5 

### Prepare a subset of the data for testing the end-to-end pipeline

In [4]:
from pyspark.sql import functions as F
from pyspark.sql import DataFrame


def prepare_species_data(
    dct_df: DataFrame,
    limit_species: int = None,
    species_image_count: int = 100,
):
    """
    Prepare species data by filtering, indexing, and joining.

    Species with fewer than ``species_image_count`` rows are dropped. Each
    remaining species gets a contiguous, zero-based ``index`` (0..k-1) that can
    be used directly as a class label. The index is always present in the
    result, whether or not ``limit_species`` is given.

    :param dct_df: DataFrame with at least a ``species_id`` column
    :param limit_species: Maximum number of species to include, chosen by a
        seeded random ordering (None or 0 means no limit)
    :param species_image_count: Minimum number of images per species to include
    :return: Rows of ``dct_df`` for the kept species, with an extra ``index``
        column of type long
    """
    # Local import: Window is not among the notebook's top-level imports.
    from pyspark.sql import Window

    # Count images per species and keep only sufficiently represented ones.
    species_df = (
        dct_df.groupBy("species_id")
        .agg(F.count("species_id").alias("n"))
        .filter(F.col("n") >= species_image_count)
    )

    # Optionally keep a reproducible pseudo-random subset of the species.
    if limit_species:
        species_df = species_df.orderBy(F.rand(seed=42)).limit(limit_species)

    # row_number() yields consecutive integers, whereas
    # monotonically_increasing_id() only guarantees unique, increasing values
    # (gaps appear as soon as the data spans several partitions). The global
    # window pulls the species table into one partition, which is acceptable
    # because it has one row per species. Ties on the count are broken by
    # species_id so the mapping is deterministic.
    index_window = Window.orderBy(F.col("n").desc(), F.col("species_id"))
    indexed_species_df = species_df.withColumn(
        "index", (F.row_number().over(index_window) - 1).cast("long")
    ).drop("n")

    # Broadcast the small lookup table; the inner join both filters the
    # images to the kept species and attaches their class index.
    return dct_df.join(F.broadcast(indexed_species_df), "species_id", "inner")

In [5]:
# Params
LIMIT_SPECIES = 5  # maximum number of species kept in this test subset
SPECIES_IMAGE_COUNT = 100  # minimum number of images a species needs to be kept

# Call function
prepared_df = prepare_species_data(
    dct_df, limit_species=LIMIT_SPECIES, species_image_count=SPECIES_IMAGE_COUNT
)
# Report the number of rows in the subset, then preview it
print(f"DF count: {prepared_df.count()}")
prepared_df.show()

                                                                                

DF count: 1185


[Stage 27:>                                                         (0 + 1) / 1]

+----------+--------------------+--------------------+-----+
|species_id|          image_name|       dct_embedding|index|
+----------+--------------------+--------------------+-----+
|   1358851|a5a1530acc42ee28a...|[-22140.71, -2232...|    3|
|   1392723|76056d8c5c2eabdae...|[-18462.121, -112...|    4|
|   1360938|aa65bf7e5cbbea170...|[-27158.367, -183...|    0|
|   1392723|ae436ff1f04ca5412...|[-21858.686, -435...|    4|
|   1360938|3d922d3fe00d95887...|[-25446.95, -5724...|    0|
|   1360299|c914a7f8d83a73727...|[-24541.422, 1324...|    1|
|   1360299|5b995de41dc8c507e...|[-26373.861, 1665...|    1|
|   1358851|6ceb22e1e2d2a0560...|[-24388.037, -243...|    3|
|   1358851|360605951bcdd6843...|[-26956.902, -127...|    3|
|   1360938|cc7b5743d897349af...|[-25043.629, -657...|    0|
|   1358851|43a7b8a23a79645ce...|[-24329.762, -147...|    3|
|   1392723|107d18234ccc4bf99...|[-20970.615, 7978...|    4|
|   1358851|ed49aa18677936d8f...|[-17723.512, -340...|    3|
|   1357220|d6edbca4549d

                                                                                

### train/validation split

In [6]:
# Perform a train-validation split
def train_valid_split(df, weights=(0.8, 0.2), seed=42):
    """
    Randomly split a DataFrame into a training part and a validation part.

    :param df: object exposing ``randomSplit(weights, seed=...)`` (a Spark DataFrame)
    :param weights: two relative weights for the (train, validation) parts;
        the default reproduces the original 80/20 split
    :param seed: seed forwarded to ``randomSplit`` so the split is reproducible
    :return: tuple ``(train_df, valid_df)``
    """
    # list(...) so any two-item iterable is accepted while randomSplit still
    # receives the list it was originally given.
    train_df, valid_df = df.randomSplit(list(weights), seed=seed)
    return train_df, valid_df


# Pass desired DF to function
train_df, valid_df = train_valid_split(df=prepared_df)
# Report the row count of each split
print(f"train: {train_df.count()}, valid: {valid_df.count()}")



train: 938, valid: 247


                                                                                

## TorchDistributor

In [7]:
import torch
from torch import nn
import pytorch_lightning as pl
from torch.utils.data import Dataset, DataLoader
from petastorm import make_reader
from petastorm.pytorch import DataLoader


class EmbeddingDataset(Dataset):
    """Map-style dataset pairing each embedding with its class label."""

    def __init__(self, embeddings, labels):
        # Both sequences are stored as given and indexed in parallel.
        self.embeddings, self.labels = embeddings, labels

    def __len__(self):
        # The dataset size is defined by the number of labels.
        return len(self.labels)

    def __getitem__(self, index):
        """Return ``(embedding, label)`` as a float tensor and a long tensor."""
        return (
            torch.tensor(self.embeddings[index], dtype=torch.float),
            torch.tensor(self.labels[index], dtype=torch.long),
        )


class TorchClassifier(pl.LightningModule):
    """Single linear layer classifier that outputs log-probabilities."""

    def __init__(self, num_features, num_classes):
        super().__init__()
        # Store the constructor arguments as hyperparameters (saved in checkpoints)
        self.save_hyperparameters()
        self.layer = nn.Linear(in_features=num_features, out_features=num_classes)

    def forward(self, x):
        """Return log-softmax scores over dimension 1 for the input ``x``."""
        logits = self.layer(x)
        return torch.log_softmax(logits, dim=1)


def petastorm_loader(url, batch_size, num_epochs, workers_count):
    """
    Yield batches read through a Petastorm reader.

    The reader stays open for as long as the generator is being consumed and
    is closed when the generator is exhausted or closed.

    :param url: dataset URL (or URLs) passed to ``make_reader``
    :param batch_size: number of rows per yielded batch
    :param num_epochs: number of passes over the dataset performed by the reader
    :param workers_count: number of reader workers
    :return: generator of batches as produced by the Petastorm ``DataLoader``
    """
    # NOTE(review): make_reader is documented for datasets materialized by
    # Petastorm itself; plain Parquet written by Spark presumably needs
    # make_batch_reader instead. Confirm against how the data was written.
    with make_reader(url, num_epochs=num_epochs, workers_count=workers_count) as reader:
        # The reader already replays the dataset num_epochs times, so a single
        # pass over the loader covers every requested epoch. Re-iterating the
        # loader in an extra epoch loop would replay all of those passes again.
        yield from DataLoader(reader, batch_size=batch_size)


def train_one_epoch(
    model, device, dataloader, optimizer, epoch, log_interval: int = 100
):
    """
    Run one optimisation pass over ``dataloader``.

    :param model: module returning log-probabilities (trained with NLL loss)
    :param device: device the batch tensors are moved to
    :param dataloader: iterable of mapping-like batches with a
        ``"dct_embedding"`` input tensor and an ``"index"`` target tensor; it
        may be a sized loader or a plain generator
    :param optimizer: optimizer stepping the model parameters
    :param epoch: epoch number, used only in the progress message
    :param log_interval: print progress every ``log_interval`` batches
    """
    model.train()

    # Streaming loaders (generators, Petastorm readers) have no length, so the
    # total is only reported when it is actually available.
    try:
        num_batches = len(dataloader)
    except TypeError:
        num_batches = None

    for batch_idx, batch in enumerate(dataloader):
        data = batch["dct_embedding"].to(device)
        target = batch["index"].to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = torch.nn.functional.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % log_interval == 0:
            if num_batches is not None:
                print(
                    "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                        epoch,
                        batch_idx * len(data),
                        num_batches * len(data),
                        100.0 * batch_idx / num_batches,
                        loss.item(),
                    )
                )
            else:
                # Unknown total: report the batch counter instead of a percentage.
                print(
                    "Train Epoch: {} [batch {}]\tLoss: {:.6f}".format(
                        epoch, batch_idx, loss.item()
                    )
                )


def test(log_dir, model, device, dataloader, num_epochs):
    """
    Evaluate a saved checkpoint and report the average NLL loss per sample.

    :param log_dir: directory containing the checkpoint files
    :param model: model to load the weights into; a DataParallel or
        DistributedDataParallel wrapper is unwrapped first
    :param device: device used for evaluation
    :param dataloader: iterable of batches, either mappings with
        ``"dct_embedding"`` / ``"index"`` entries (as consumed by
        ``train_one_epoch``) or ``(data, target)`` pairs
    :param num_epochs: epoch number of the checkpoint to load
    :return: average loss per sample (NaN when the loader yields nothing)
    """
    # Checkpoints hold the weights of the unwrapped module, so they must be
    # loaded into the unwrapped module. Evaluating it directly also avoids
    # running the wrapper's forward pass on a single rank.
    if isinstance(
        model, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)
    ):
        model = model.module
    loaded_model = model.to(device)

    # Load model and set to evaluation
    checkpoint = load_checkpoint(log_dir, num_epochs)
    loaded_model.load_state_dict(checkpoint["model"])
    loaded_model.eval()

    total_loss = 0.0
    num_samples = 0
    with torch.no_grad():  # no gradients are needed for evaluation
        for batch in dataloader:
            if isinstance(batch, dict):
                data, target = batch["dct_embedding"], batch["index"]
            else:
                data, target = batch
            data, target = data.to(device), target.to(device)
            output = loaded_model(data)
            # Sum per-sample losses so the final division gives a true
            # per-sample average regardless of batch sizes.
            total_loss += torch.nn.functional.nll_loss(
                output, target, reduction="sum"
            ).item()
            num_samples += target.size(0)

    # Count samples while iterating: streaming loaders expose no `.dataset`.
    test_loss = total_loss / num_samples if num_samples else float("nan")
    print("Average test loss: {}".format(test_loss))
    return test_loss


def save_checkpoint(log_dir, model, optimizer, epoch):
    """
    Save model and optimizer state to ``<log_dir>/checkpoint-<epoch>.pth.tar``.

    :param log_dir: directory (str or path-like) the checkpoint is written to
    :param model: model to save; when it exposes a ``module`` attribute
        (DataParallel / DistributedDataParallel wrappers) the wrapped module's
        weights are saved, otherwise the model's own weights
    :param optimizer: optimizer whose state is saved alongside the model
    :param epoch: epoch number used in the file name
    :return: path of the written checkpoint file
    """
    filepath = os.path.join(log_dir, "checkpoint-{epoch}.pth.tar".format(epoch=epoch))
    # Save unwrapped weights so the checkpoint loads into a plain model;
    # fall back to the model itself when it is not wrapped.
    unwrapped = getattr(model, "module", model)
    state = {
        "model": unwrapped.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    torch.save(state, filepath)
    return filepath


def load_checkpoint(log_dir, epoch, map_location=None):
    """
    Load the checkpoint ``<log_dir>/checkpoint-<epoch>.pth.tar``.

    :param log_dir: directory (str or path-like) containing the checkpoint
    :param epoch: epoch number used in the file name
    :param map_location: optional device remapping forwarded to ``torch.load``
        (e.g. ``"cpu"`` to load a checkpoint saved on a GPU); the default keeps
        ``torch.load``'s own behaviour
    :return: the object stored in the checkpoint file
    """
    filepath = os.path.join(log_dir, "checkpoint-{epoch}.pth.tar".format(epoch=epoch))
    return torch.load(filepath, map_location=map_location)


def create_log_dir(exp_dir):
    """
    Create and return a new timestamped run directory inside ``exp_dir``.

    The directory name is the current local time formatted as
    ``YYYYMMDD-HHMMSS``. ``os.makedirs`` raises if the directory already exists.
    """
    timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    run_dir = os.path.join(exp_dir, timestamp)
    os.makedirs(run_dir)
    return run_dir


# For distributed training we will merge the train and test steps into 1 main function
def main_fn(
    directory,
    device,
    train_url,
    valid_url,
    num_features,
    num_classes,
    batch_size,
    num_epochs,
    learning_rate=0.001,
    workers_count=4,
):
    """
    Per-process entry point for distributed training.

    Each process joins the process group, trains a DDP-wrapped
    ``TorchClassifier`` for ``num_epochs`` epochs and, on global rank 0, saves
    a checkpoint and evaluates it after every epoch.

    :param directory: directory checkpoints are written to and read from
    :param device: requested device; only its ``type`` is used to choose the
        backend, and on CUDA each process is pinned to its own local GPU
    :param train_url: dataset URL(s) of the training data
    :param valid_url: dataset URL(s) of the validation data
    :param num_features: input size of the classifier
    :param num_classes: number of output classes
    :param batch_size: rows per batch
    :param num_epochs: number of training epochs
    :param learning_rate: Adam learning rate
    :param workers_count: number of reader workers per loader
    :return: the trained, unwrapped model
    """
    import torch.distributed as dist
    from torch.nn.parallel import DistributedDataParallel as DDP

    print("Running distributed training")
    dist.init_process_group(backend="nccl" if device.type == "cuda" else "gloo")
    try:
        local_rank = int(os.environ["LOCAL_RANK"])
        global_rank = int(os.environ["RANK"])

        # A bare "cuda" device would place every process on the default GPU
        # while DDP is told to use GPU `local_rank`; pin each process to its
        # own device instead.
        if device.type == "cuda":
            device = torch.device("cuda", local_rank)
            torch.cuda.set_device(device)

        model = TorchClassifier(num_features, num_classes).to(device)
        # Add Distributed Model
        if device.type == "cuda":
            ddp_model = DDP(model, device_ids=[local_rank], output_device=local_rank)
        else:
            ddp_model = DDP(model)

        optimizer = torch.optim.Adam(ddp_model.parameters(), lr=learning_rate)

        # NOTE(review): every rank reads the full training set here; sharding
        # the reader per rank is not exposed by petastorm_loader.
        for epoch in range(1, num_epochs + 1):
            # Build a fresh single-pass loader per epoch: one generator shared
            # by all epochs would be exhausted by the first of them.
            train_dataloader = petastorm_loader(
                train_url, batch_size, 1, workers_count
            )
            train_one_epoch(ddp_model, device, train_dataloader, optimizer, epoch)

            if global_rank == 0:
                save_checkpoint(directory, ddp_model, optimizer, epoch)
                valid_dataloader = petastorm_loader(
                    valid_url, batch_size, 1, workers_count
                )
                # Evaluate the checkpoint written for this epoch (the final
                # epoch's file does not exist yet) and use the unwrapped model,
                # whose parameter names match the saved weights.
                test(directory, model, device, valid_dataloader, epoch)
    finally:
        # Always leave the process group, even when training fails.
        dist.destroy_process_group()

    return model

### Initialize parameters

In [8]:
# Params
# Experiment outputs live in "experiments", two levels above the working directory
PYTORCH_DIR = Path(os.getcwd()).parents[1] / "experiments"
# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_features = 64  # NOTE(review): presumably the length of each dct_embedding vector — confirm
# One class per distinct species in the prepared subset
num_classes = int(prepared_df.select("species_id").distinct().count())
batch_size = 100
num_epochs = 10
log_interval = 100  # NOTE(review): not passed on by the training call visible in this notebook
learning_rate = 0.001
workers_count = os.cpu_count()  # reader workers; os.cpu_count() can return None

print(f"num_classes: {num_classes}")
# NOTE(review): the label below reads "devide"; it looks like a typo for "device"
print(f"devide: {device}")



num_classes: 5
devide: cpu


                                                                                

### get data ready for training

In [13]:
# Get data ready
# Create the timestamped directory for this run's checkpoints
folder_dir = create_log_dir(exp_dir=PYTORCH_DIR)
print("Experiment is located at: ", folder_dir)

# Prepare data for Petastorm
# NOTE(review): hard-coded absolute paths; the f-prefix has no placeholders
train_dir = f"/mnt/data/train_data"
valid_dir = f"/mnt/data/valid_data"
# Write both splits as Parquet, replacing any output from a previous run
train_df.write.mode("overwrite").parquet(train_dir)
valid_df.write.mode("overwrite").parquet(valid_dir)

Experiment is located at:  /home/mgustine/plantclef-2024/experiments/20240415-142947


In [21]:
def get_parquet_file_paths(directory):
    """
    List the Parquet files directly inside ``directory`` as ``file://`` URLs.

    Only regular files whose name ends with ``.parquet`` are returned;
    subdirectories are not searched. The result is sorted by file name so the
    order is deterministic.

    :param directory: local directory to scan
    :return: sorted list of ``file://<directory>/<name>`` URLs, or an empty
        list (after printing a message) when ``directory`` is not an existing
        directory
    """
    # Check if the directory exists
    if not os.path.isdir(directory):
        print("The specified directory does not exist.")
        return []

    # os.listdir returns entries in arbitrary order, so sort them. Also skip
    # anything that is not a regular file (e.g. a folder named "*.parquet").
    parquet_files = sorted(
        name
        for name in os.listdir(directory)
        if name.endswith(".parquet")
        and os.path.isfile(os.path.join(directory, name))
    )

    # Construct full paths
    return [os.path.join(f"file://{directory}", name) for name in parquet_files]


# Collect the file:// URLs of the Parquet files written for each split
train_file_paths = get_parquet_file_paths(train_dir)
valid_file_paths = get_parquet_file_paths(valid_dir)

### train model

In [23]:
from pyspark.ml.torch.distributor import TorchDistributor


# Run TorchDistributor
# Request GPUs only when the selected device is CUDA
use_gpu = device.type == "cuda"
distributor = TorchDistributor(num_processes=2, local_mode=True, use_gpu=use_gpu)
# Positional arguments after main_fn are forwarded to it in this order
model = distributor.run(
    main_fn,
    folder_dir,
    device,
    train_file_paths,
    valid_file_paths,
    num_features,
    num_classes,
    batch_size,
    num_epochs,
    learning_rate,
    workers_count,
)

Started local training with 2 processes


*****************************************
Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
*****************************************
Running distributed training
Running distributed training
  self._filesystem = pyarrow.localfs
  self._filesystem = pyarrow.localfs
Traceback (most recent call last):
  File "/tmp/tmptso2m9p6/train.py", line 8, in <module>
    output = train_fn(*args, **kwargs)
  File "/tmp/ipykernel_25735/848267020.py", line 137, in main_fn
  File "/tmp/ipykernel_25735/848267020.py", line 45, in train_one_epoch
  File "/tmp/ipykernel_25735/848267020.py", line 34, in petastorm_loader
  File "/home/mgustine/.local/lib/python3.10/site-packages/petastorm/reader.py", line 155, in make_reader
    dataset_metadata.get_schema_from_dataset_url(dataset_url_or_urls, hdfs_driver=hdfs_driver,
  File "/home/mgustine/.local/lib/

RuntimeError: TorchDistributor failed during training.View stdout logs for detailed error message.