# Train Classifier

This notebook trains a classification model for marketplace images using transfer learning:

- Frozen Backbone: Uses a pre-trained DINOv2 (Vision Transformer) as a feature extractor
- Classification Head: Trains a linear or MLP classification head on top of DINOv2 embeddings
- MLflow Tracking: Logs all training metrics, hyperparameters, and model checkpoints
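
To get a feel for how small the trainable part is in this setup, the figures from the model summary printed by the training cell later in this notebook (171 K trainable, 86.6 M non-trainable, 347.020 MB) can be checked with a few lines of arithmetic. This is only a rough sketch: the counts are the rounded values as printed, and the 4 bytes per parameter (float32) used for the size estimate is an assumption.

```python
# Figures copied from the Lightning model summary in this notebook (rounded as printed).
trainable_params = 171_000     # "171 K     Trainable params" (classification head)
frozen_params = 86_600_000     # "86.6 M    Non-trainable params" (DINOv2 backbone)
reported_size_mb = 347.020     # "Total estimated model params size (MB)"

# Assumption: the size estimate counts 4 bytes (float32) per parameter.
BYTES_PER_PARAM = 4

total_params = trainable_params + frozen_params
trainable_fraction = trainable_params / total_params
estimated_size_mb = total_params * BYTES_PER_PARAM / 1e6

print(f"total params:       {total_params / 1e6:.1f} M")
print(f"trainable fraction: {trainable_fraction:.2%}")
print(f"estimated size:     {estimated_size_mb:.1f} MB (reported: {reported_size_mb} MB)")
```

Under these assumptions only a fraction of a percent of the weights receive gradient updates, which is why training the head alone is cheap compared with fine-tuning the whole backbone.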


In [1]:
import os
import sys
from pathlib import Path

from dotenv import load_dotenv

notebooks_dir = Path().absolute()
project_dir = notebooks_dir.parent
os.chdir(project_dir)
load_dotenv()
sys.path.append(str(project_dir))  # sys.path entries must be strings, not Path objects

In [2]:
import warnings

import lightning.pytorch as pl
import numpy as np
import torch
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from lightning.pytorch.loggers import MLFlowLogger

from dataset.dataset import MarketplaceDataModule
from dataset.model import DinoV2Classification

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [4]:
warnings.filterwarnings("ignore", message="xFormers is available")

### Train

In [5]:
import mlflow

print(f"MLflow Tracking URI: {mlflow.get_tracking_uri()}")
with mlflow.start_run():
    print("✓ Successfully connected to MLflow!")

MLflow Tracking URI: http://192.168.30.108:5000/
✓ Successfully connected to MLflow!
🏃 View run luminous-lark-134 at: http://192.168.30.108:5000/#/experiments/2/runs/6be4554a4e58446fab11fe136c553286
🧪 View experiment at: http://192.168.30.108:5000/#/experiments/2


In [6]:
def run_training(
    model: pl.LightningModule,
    data_module: pl.LightningDataModule,
    max_epochs: int,
    precision: str,
):
    """
    Configures and runs a Lightning training process.
    """

    # initialize trainer
    mlf_logger = MLFlowLogger(
        experiment_name="marketplace-image-rag", 
        # tracking_uri="file:./ml-runs"
    )
    early_stop = EarlyStopping(
        monitor="val_accuracy",
        patience=5,
        mode="max",
    )
    checkpoint = ModelCheckpoint(
        monitor="val_accuracy",
        mode="max",
        save_top_k=3,
        dirpath="checkpoints",
        filename="dinov2-classification-{epoch:02d}-{val_accuracy:.4f}",
    )
    trainer = pl.Trainer(
        max_epochs=max_epochs,  # maximum number of training epochs
        precision=precision,  # numerical precision for the training process
        logger=mlf_logger,  # logging training metrics to mlflow
        enable_progress_bar=True,  # show training progress
        callbacks=[early_stop, checkpoint],
    )

    # start the training process
    trainer.fit(model, data_module)

    return trainer
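
The interaction between `patience=5` in `EarlyStopping` and `save_top_k=3` in `ModelCheckpoint` can be illustrated with a small pure-Python simulation. This is a simplified sketch, not Lightning's implementation: it assumes `mode="max"`, no `min_delta`, and one validation pass per epoch. The function name `simulate_callbacks` and the accuracy values are hypothetical.

```python
import heapq


def simulate_callbacks(val_accuracies, patience=5, save_top_k=3):
    """Return (stop_epoch, kept_epochs) for a list of per-epoch validation accuracies.

    stop_epoch is None when the patience budget is never exhausted.
    """
    best = float("-inf")
    wait = 0
    top_k = []  # min-heap of (accuracy, epoch); the worst kept checkpoint sits at index 0
    stop_epoch = None

    for epoch, acc in enumerate(val_accuracies):
        # Checkpointing: keep only the save_top_k highest accuracies seen so far.
        if len(top_k) < save_top_k:
            heapq.heappush(top_k, (acc, epoch))
        elif acc > top_k[0][0]:
            heapq.heapreplace(top_k, (acc, epoch))

        # Early stopping: count consecutive epochs without a new best value.
        if acc > best:
            best, wait = acc, 0
        else:
            wait += 1
            if wait >= patience:
                stop_epoch = epoch
                break

    kept_epochs = sorted(epoch for _, epoch in top_k)
    return stop_epoch, kept_epochs


# Hypothetical accuracy curve: the best value is reached at epoch 4, then it plateaus.
curve = [0.50, 0.60, 0.65, 0.64, 0.68, 0.67, 0.66, 0.655, 0.663, 0.661, 0.70]
stop_epoch, kept_epochs = simulate_callbacks(curve)
print(f"stopped after epoch {stop_epoch}, kept checkpoints from epochs {kept_epochs}")
```

In this made-up curve the last value (0.70) is never seen, because five consecutive epochs without improvement exhaust the patience budget first. A larger `patience` trades longer training for a lower risk of stopping on a plateau.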

In [None]:
# train classification head on tensor cores (speedup) 
torch.set_float32_matmul_precision('medium')

data_module = MarketplaceDataModule(batch_size=64)
data_module.setup("fit")
num_classes = len(data_module.train_dataset.dataset.classes)

model = DinoV2Classification(num_classes=num_classes)

# Execute the training
trainer = run_training(
    model=model,
    data_module=data_module,
    max_epochs=10,  # matches the recorded output below ("max_epochs=10 reached")
    # evaluate backbone on tensor cores (speedup)
    precision="16-mixed",
)

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type               | Params | Mode  | FLOPs
--------------------------------------------------------------------
0 | model        | Sequential         | 86.8 M | train | 0    
1 | loss_fn      | CrossEntropyLoss   | 0      | train | 0    
2 | val_accuracy | MulticlassAccuracy | 0      | train | 0    
--------------------------------------------------------------------
171 K     Trainable params
86.6 M    Non-trainable params
86.8 M    Total params
347.020   Total estimated model params size (MB)
204       Modules in train mode
0         Modules in eval mode
0         Total Flops


Epoch 9: 100%|██████████| 1519/1519 [08:07<00:00,  3.12it/s, v_num=6330, val_loss=1.340, val_accuracy=0.679]

`Trainer.fit` stopped: `max_epochs=10` reached.


🏃 View run crawling-grub-159 at: http://192.168.30.108:5000/#/experiments/2/runs/6e10f2dae99c4f88ae8782159af66330
🧪 View experiment at: http://192.168.30.108:5000/#/experiments/2
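
The run above used `precision="16-mixed"`, which runs eligible operations in float16 while the weights stay in float32. The snippet below is a small standard-library illustration of why the "mixed" part matters. It only shows the limits of the float16 number format; it is not how PyTorch or Lightning implement automatic mixed precision, and the helper name `to_fp16` is hypothetical.

```python
import struct


def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


# Limited precision: 0.1 is stored only approximately.
print("0.1    ->", to_fp16(0.1))

# Limited integer range: not every integer above 2048 is representable.
print("2049.0 ->", to_fp16(2049.0))

# Underflow: very small values, such as tiny gradients, collapse to zero.
print("1e-8   ->", to_fp16(1e-8))

# Overflow: values above the largest finite float16 (65504) cannot be packed.
try:
    to_fp16(70000.0)
    overflowed = False
except OverflowError as exc:
    overflowed = True
    print("70000.0 -> overflow:", exc)
```

Keeping a float32 copy of the weights and scaling the loss before the backward pass are the usual ways mixed-precision training works around these limits.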
