# Classification with PyTorch and Petastorm

In [1]:
# Enable IPython's autoreload extension in mode 2: imported modules are
# reloaded before each cell executes, so edits to the local project package
# are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
from pathlib import Path
from pyspark.sql import DataFrame
from plantclef.utils import get_spark
from pyspark.sql import functions as F


# Obtain the session object from the project helper (presumably a SparkSession
# — confirm in plantclef.utils) and render it as this cell's output.
spark = get_spark()
display(spark)

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/04/20 23:56:07 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
24/04/20 23:56:07 WARN SparkConf: Note that spark.local.dir will be overridden by the value set by the cluster manager (via SPARK_LOCAL_DIRS in mesos/standalone/kubernetes and LOCAL_DIRS in YARN).


In [3]:
# --- input location -------------------------------------------------------
# Bucket prefix and the dataset directory underneath it.
gcs_path = "gs://dsgt-clef-plantclef-2024/data/process"
embedding_path = "training_cropped_resized_v2/dino_dct/data"
input_path = "/".join((gcs_path, embedding_path))

# --- data module parameters -----------------------------------------------
# Column holding the embedding vectors used as model features.
feature_col = "dct_embedding"
# Subset controls: number of species and images per species.
limit_species = 5
species_image_count = 100
# Loader / partitioning settings.
batch_size = 32
num_partitions = 32
# NOTE(review): workers_count is not referenced by any cell visible in this
# notebook — confirm whether it should be passed to the data module.
workers_count = 16

In [4]:
from plantclef.baseline.data import PetastormDataModule

# test data module
# Build the project's PetastormDataModule from the session and the parameters
# defined in the config cell. All arguments are passed positionally, so their
# order must match the constructor's signature.
# NOTE(review): workers_count is defined in the config cell but is not passed
# here — confirm whether the data module should receive it.
data_module = PetastormDataModule(
    spark,
    input_path,
    feature_col,
    limit_species,
    species_image_count,
    batch_size,
    num_partitions,
)
# Run setup explicitly, before any Trainer is involved; a later cell reads
# data_module.train_data, which is presumably populated here — TODO confirm.
data_module.setup()

  from .autonotebook import tqdm as notebook_tqdm
  from pyarrow import LocalFileSystem





                                                                                

+----------+--------------------+--------------------+-----+
|species_id|          image_name|       dct_embedding|index|
+----------+--------------------+--------------------+-----+
|   1358851|a5a1530acc42ee28a...|[-22140.71, -2232...|    3|
|   1392723|76056d8c5c2eabdae...|[-18462.121, -112...|    4|
|   1360938|aa65bf7e5cbbea170...|[-27158.367, -183...|    0|
|   1392723|ae436ff1f04ca5412...|[-21858.686, -435...|    4|
|   1360938|3d922d3fe00d95887...|[-25446.95, -5724...|    0|
+----------+--------------------+--------------------+-----+
only showing top 5 rows



  self._filesystem = pyarrow.localfs
Converting floating-point columns to float32



limit_species: 5, <class 'int'>
num_classes: 5, <class 'int'>



The median size 8434 B (< 50 MB) of the parquet files is too small. Total size: 272576 B. Increase the median file size by calling df.repartition(n) or df.coalesce(n), which might help improve the performance. Parquet files: file:///mnt/data/tmp/20240420235711-appid-local-1713657367922-e7c1ae54-9a13-42bb-80e1-787c2b1a1e79/part-00011-dea129d6-28d1-4593-ae80-8ba4aaaf338b-c000.parquet, ...
Converting floating-point columns to float32
The median size 2956 B (< 50 MB) of the parquet files is too small. Total size: 92294 B. Increase the median file size by calling df.repartition(n) or df.coalesce(n), which might help improve the performance. Parquet files: file:///mnt/data/tmp/20240420235743-appid-local-1713657367922-543dc068-83a2-426b-bfc6-ace5433f914b/part-00003-109dd8fb-8497-40fb-b1d6-d23723064412-c000.parquet, ...


In [5]:
# model parameters
# Infer the model's input and output sizes from the prepared training data.
train_df = data_module.train_data

# Read the configured feature column rather than a hard-coded name, so this
# cell stays in sync with `feature_col` from the config cell.
first_row = train_df.select(feature_col).first()
if first_row is None:
    # DataFrame.first() returns None when the DataFrame has no rows.
    raise ValueError("train_data is empty; cannot infer num_features")
# Length of a single embedding vector; len() already returns an int.
num_features = len(first_row[feature_col])

# Number of distinct labels; DataFrame.count() already returns an int.
num_classes = train_df.select("species_id").distinct().count()
print(f"num_features: {num_features}  num_classes: {num_classes}")



num_features: 64  num_classes: 5


                                                                                

In [None]:
from plantclef.baseline.model import LinearClassifier

import pytorch_lightning as pl

# Seed the Python, NumPy and PyTorch random number generators before the model
# is constructed, so repeated Restart-and-Run-All executions are comparable.
# NOTE(review): whether the data module's own shuffling honours this seed is
# not visible from this notebook — confirm if exact reproducibility matters.
seed = 42
pl.seed_everything(seed)

# test classifier
# Sizes come from the model-parameters cell (feature length, class count).
model = LinearClassifier(
    num_features,
    num_classes,
)

In [None]:
import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint

# Choose the device once: GPU when CUDA is available, otherwise CPU.
accelerator = "gpu" if torch.cuda.is_available() else "cpu"

# Both callbacks watch the logged "val_loss" metric: early stopping treats
# lower as better, and checkpointing additionally keeps the last checkpoint.
training_callbacks = [
    EarlyStopping(monitor="val_loss", mode="min"),
    ModelCheckpoint(monitor="val_loss", save_last=True),
]

# train model
# Dataloaders are re-created every epoch; logs and checkpoints are rooted at
# the bucket path below.
trainer = pl.Trainer(
    max_epochs=10,
    accelerator=accelerator,
    reload_dataloaders_every_n_epochs=1,
    default_root_dir="gs://dsgt-clef-plantclef-2024/models/torch-petastorm-v1",
    callbacks=training_callbacks,
)

# fit model
trainer.fit(model, data_module)