# Step Recipe Classifier

#### Fine-Tuning CLIP for Image Classification

This notebook implements a complete pipeline for fine-tuning OpenAI's [CLIP](https://github.com/openai/CLIP) model on step recipe images. The pipeline includes balanced data loading, augmentation, training with validation, and inference.

#### Modular Components:

- **TrainingConfig**: Central configuration for hyperparameters, paths, and grid search settings
- **DataModule**: Handles dataset loading, preprocessing, and balanced batch sampling
- **CLIPModule**: Manages model initialization and training setup
- **CLIPTrainer**: Orchestrates training loop, validation, and metric computation
- **Utility Functions**: Standalone tools for monitoring, saving, and loading models

#### Key Features:

- **Data Management**: Custom dataset handling with balanced sampling and augmentation
- **Model Training**: CLIP fine-tuning with early stopping and metric tracking
- **Experiment Tracking**: MLflow integration for logging metrics and artifacts
- **Inference**: Production-ready inference pipeline with optimized model loading

*Author: Alejandro Guirau*  
*Last Updated: February 2025*

### Install Libraries

In [0]:
!pip install -q --upgrade pip

[43mNote: you may need to restart the kernel using %restart_python or dbutils.library.restartPython() to use updated packages.[0m


In [0]:
!pip install -U -q torch torchmetrics torchvision

[43mNote: you may need to restart the kernel using %restart_python or dbutils.library.restartPython() to use updated packages.[0m


In [0]:
# Install CLIP from GitHub repo
!pip install -q git+https://github.com/openai/CLIP.git
!pip show clip

[43mNote: you may need to restart the kernel using %restart_python or dbutils.library.restartPython() to use updated packages.[0m
Name: clip
Version: 1.0
Summary: 
Home-page: 
Author: OpenAI
Author-email: 
License: 
Location: /local_disk0/.ephemeral_nfs/envs/pythonEnv-2270c8be-a312-460a-8f71-8d9ecb5723c8/lib/python3.11/site-packages
Requires: ftfy, packaging, regex, torch, torchvision, tqdm
Required-by: 


### Utility Functions

#### Model Utils

Utility functions for model management, accuracy computation, and dataset analysis.

1. `convert_models_to_fp32`: Converts parameters to 32-bit floating point format.

2. `compute_per_class_accuracy`: Calculates accuracy for each class independently. Uses a confusion matrix to handle multi-class scenarios.

3. `class_distribution`: Analyzes dataset balance by calculating the percentage of samples per class.

In [0]:
import torch
import torchmetrics
from collections import Counter
from torch.utils.data import Dataset

def convert_models_to_fp32(model: torch.nn.Module):
    """Convert model parms and grads to fp32."""
    for p in model.parameters():
        p.data = p.data.float()
        if p.requires_grad:
            p.grad.data = p.grad.data.float()

def compute_per_class_accuracy(pred_labels: torch.Tensor, true_labels: torch.Tensor, num_classes: int, class_names: list[str]) -> dict[str, float]:
    """
    Calculate per-class accuracy using a confusion matrix.
    
    Args:
        pred_labels: Predicted class labels
        true_labels: Ground truth labels
        num_classes: Total number of classes
        class_names: List of class names

    Returns:
        Dictionary of per-class accuracy
    """
    # Compute confusion matrix
    confusion_matrix = torchmetrics.ConfusionMatrix(task="multiclass", num_classes=num_classes)
    confusion_matrix = confusion_matrix(pred_labels, true_labels)

    # Extract true positives (diag) and total sampes per class (row sum)
    true_positives = confusion_matrix.diag()
    total_samples_per_class = confusion_matrix.sum(dim=-1)

    # Calculate per-class accuracy
    per_class_accuracy = {
        class_names[i]: (true_positives[i] / total_samples_per_class[i].item() * 100) 
        if total_samples_per_class[i] > 0 else float("nan") # Handle division by zero
        for i in range(num_classes)
    }

    return per_class_accuracy

def class_distribution(dataset: Dataset = None, labels: torch.Tensor = None, class_names: list[str] = None, dataset_name: str = "") -> list[tuple[str, float]]:
    """
    Displays and returns class distribution as percentages.

    Args:
        dataset: The dataset to compute the distribution for.
        labels: Tensor of labels to compute the distribution for.
        class_names: List of class names for display.
        dataset_name: Name of the dataset for display.

    Returns:
        List of tuples (class_name, percentage)
    """
    if class_names is None:
        raise ValueError("class_names must be provided.")

    # Count ocurrences of each class
    if dataset is not None:
        label_counts = Counter()
        # Iterate through the dataset to count labels
        for _, label in dataset:
            label_counts[label] += 1
        total_samples = sum(label_counts.values())
        class_counts = [label_counts.get(i, 0) for i in range(len(class_names))]
    elif labels is not None:
        class_counts = torch.bincount(labels, minlength=len(class_names)).numpy()
        total_samples = class_counts.sum()
    else:
        raise ValueError("Either dataset or labels must be provided.")

    # Create list of tuples (class_name, percentage)
    distribution = [
        (class_names[i], (class_counts[i] / total_samples) * 100)
        for i in range(len(class_names))
        if class_counts[i] > 0  # Exclude classes with zero samples
    ]

    # Sort by percentage in descending order
    distribution.sort(key=lambda x: x[1], reverse=True)

    print(f"{dataset_name} dataset -- Class distribution:")
    for class_name, percentage in distribution:
        print(f"{class_name} -- {percentage:.2f}%")
    print("\n")

#### Logging Utils

Utility functions for visualizing model performance through confusion matrices and dataset samples. Logged as artifacts in MLflow experiments.

1. `log_confusion_matrix`: Creates normalized confusion matrices to visualize model classification performance.

2. `log_random_image`: Samples and visualizes random images from the datases.


In [0]:
import os
import numpy as np

import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA

import mlflow

def log_confusion_matrix(true_labels: torch.Tensor, pred_labels: torch.Tensor, class_names: list[str], epoch: int) -> None:
    """
    Logs a normalized confusion matrix plot to MLflow.
    
    Args:
        true_labels: Ground truth labels tensor
        pred_labels: Predicted labels tensor
        class_names: List of class names
        epoch: Current epoch
    """
    # Compute confusion matrix
    cm = confusion_matrix(true_labels.cpu().numpy(), pred_labels.cpu().numpy(), labels=range(len(class_names)))

    # Normalize confusion matrix
    with np.errstate(divide="ignore", invalid="ignore"):
        cm = cm.astype("float") / cm.sum(axis=1, keepdims=True)
        cm[np.isnan(cm)] = 0  # Replace NaN with 0

    # Plot confusion matrix
    fig, ax = plt.subplots(figsize=(20, 20))
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names)
    disp.plot(cmap=plt.cm.Blues, ax=ax, xticks_rotation=90)

    # Format numbers to 2 decimal places
    for texts in ax.texts:
        text = texts.get_text()
        texts.set_text(f"{float(text):.2f}")

    plt.title(f"Confusion Matrix (Epoch {epoch})")
    plt.xlabel("Predicted Labels")
    plt.ylabel("True Labels")
    
    # Save the plot temporarily
    temp_dir = "/tmp/confusion_matrices"
    os.makedirs(temp_dir, exist_ok=True)
    plot_path = os.path.join(temp_dir, f"{epoch:02}_confusion_matrix.png")
    plt.savefig(plot_path, bbox_inches="tight")
    plt.close(fig)

    # Log plot to MLflow
    mlflow.log_artifact(plot_path, artifact_path="confusion_matrices")

def log_random_image(dataset: Dataset, class_names: list[str], artifact_path: str = "augmented_images/train", counter: int = 1) -> None:
    """
    Log a random image from the dataset with its class label.
    
    Args:
        dataset: Dataset containing images and labels
        class_names: List of class names for display
        artifact_path: MLflow artifact diretory path
        counter: Image counter for filename
    """
    # Randomly select image
    idx = random.randint(0, len(dataset) - 1)
    image, label = dataset[idx]

    # Convert image to displayable format (PIL)
    if isinstance(image, torch.Tensor):
        image = T.functional.to_pil_image(image)

    # Plot the image
    fig, ax = plt.subplots(figsize=(6, 6))
    ax.imshow(image)
    ax.set_title(f"Class: {class_names[label]}")
    ax.axis("off")

    # Save the plot temporarily
    temp_dir = f"/tmp/{artifact_path}"
    os.makedirs(temp_dir, exist_ok=True)
    plot_path = os.path.join(temp_dir, f"{counter:02d}_augmented_image.png")
    plt.savefig(plot_path, bbox_inches="tight")
    plt.close(fig)

    # Log the image to MLflow
    mlflow.log_artifact(plot_path, artifact_path=artifact_path)

Exception ignored on calling ctypes callback function: <function _ThreadpoolInfo._find_modules_with_dl_iterate_phdr.<locals>.match_module_callback at 0xfffee64b4360>
Traceback (most recent call last):
  File "/databricks/python/lib/python3.11/site-packages/threadpoolctl.py", line 400, in match_module_callback
    self._make_module_from_path(filepath)
  File "/databricks/python/lib/python3.11/site-packages/threadpoolctl.py", line 515, in _make_module_from_path
    module = module_class(filepath, prefix, user_api, internal_api)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/databricks/python/lib/python3.11/site-packages/threadpoolctl.py", line 606, in __init__
    self.version = self.get_version()
                   ^^^^^^^^^^^^^^^^^^
  File "/databricks/python/lib/python3.11/site-packages/threadpoolctl.py", line 646, in get_version
    config = get_config().split()
             ^^^^^^^^^^^^^^^^^^
AttributeError: 'NoneType' object has no attribute 'split'


#### Memory Monitoring Utils

Utility function for tracking GPU and CPU memory usage during model training


In [0]:
import psutil

def log_memory_usage(epoch, phase="Train"):
    """Logs memory usage for debugging."""
    if torch.cuda.is_available():
        allocated = torch.cuda.memory_allocated() / (1024 ** 3)  # in GB
        reserved = torch.cuda.memory_reserved() / (1024 ** 3)  # in GB
        print(f"[{phase} Epoch {epoch}] GPU Memory: Allocated={allocated:.2f}GB, Reserved={reserved:.2f}GB")
    else:
        process = psutil.Process()
        memory_info = process.memory_info()
        rss = memory_info.rss / (1024 ** 3)  # in GB
        print(f"[{phase} Epoch {epoch}] CPU Memory: RSS={rss:.2f}GB")

### Early Stopping 
Implements early stopping functionality to prevent overfitting during model training.

It tracks validation metrics, saves the best model checkpoint, and stops training when there hasn't been an improvement for a specified number of epochs.

In [0]:
class EarlyStopping:
    """
    Early stopping to terminate training when validation metric stagnates.
    Saves best model checkpoint and tracks improvement over epochs.
    """
    def __init__(self, patience: int = 5, mode: str = "min", delta: float = 0, verbose: bool = True):
        """
        Initialize early stopping parameters.
        
        Args:
            patience: How many epochs to wait before stopping.
            mode: "min" for minimizing, "max" for maximizing.
            delta: Minimum change to qualify as an improvement.
            verbose: Whether to print verbose output.
        """
        self.patience = patience
        self.mode = mode
        self.delta = delta
        self.verbose = verbose
        self.best_score = -float("inf") if self.mode == "max" else float("inf")
        self.best_epoch = None
        self.early_stop = None
        self.epochs_no_improve = 0

    def __call__(self, metric: float, model: torch.nn.Module, optimizer: torch.optim.Optimizer,
                 path: str, epoch: int, additional_info: dict = None) -> None:
        """
        Checks early stopping conditions and saves the model if metric improves.

        Args:
            metric: Current validation metric value
            model: Model to save if metric improves
            optimizer: Optimizer state to save
            path: Path to save the model checkpoint
            epoch: Current epoch number
            additional_info (optional): Additional information to save
        """
        score = metric if self.mode == "max" else -metric

        # Check if metric improved
        improvement = False
        if self.mode == "min":
            improvement = score < self.best_score - self.delta
        elif self.mode == "max":
            improvement = score > self.best_score + self.delta
        else:
            raise ValueError(f"Mode {self.mode} is not supported. Use 'min' or 'max'.")

        if self.best_score is None:
            self.best_score = score
            self.best_epoch = epoch
            self.save_checkpoint(metric, model, optimizer, path, epoch, additional_info)
        elif improvement:
            self.best_score = score
            self.best_epoch = epoch
            self.save_checkpoint(metric, model, optimizer, path, epoch, additional_info)
            self.epochs_no_improve = 0
        else:
            self.epochs_no_improve += 1
            if self.epochs_no_improve >= self.patience:
                self.early_stop = True

    def save_checkpoint(self, metric: float, model: torch.nn.Module, optimizer: torch.optim.Optimizer,
                        path: str, epoch: int, additional_info: dict = None) -> None:
        """
        Save model when validation metric improves.
        
        Args:
            metric: Current validation metric value
            model: Model to save if metric improves
            optimizer: Optimizer state to save
            path: Path to save the model checkpoint
            epoch: Current epoch number
            additional_info (optional): Additional information to save
        """
        if self.verbose:
            print(f"[Best model] Metric improved ({self.best_score:.4f} --> {metric:.4f}). Saving model...")

        checkpoint = {
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "metric": metric,
        }

        # Add any additional info
        if additional_info:
            checkpoint.update(additional_info)

        # Save checkpoint
        torch.save(checkpoint, path)

    def reset(self) -> None:
        """Reset the early stopping variables."""
        self.best_score = -float("inf") if self.mode == "max" else float("inf")
        self.best_epoch = None
        self.epochs_no_improve = 0
        self.early_stop = False

### Data Augmentation

Provides custom data augmentation and dataset oversampling functionality for CLIP fine-tuning.
Augmentations focus on geometric transformations while preserving color information.

1. `augment` function: Applies CLIP-specific data augmentation. Uses geometric transformations (crop, affine, flip, perspective), and preserves color information for CLIP compatibility

2. `AugmentedOversampledDataset` class: Wraps an existing dataset to handle class imbalance, dynamically generating augmented samples for minority classes.

In [0]:
import random

from PIL import Image

import torchvision.transforms as T

def augment(image: Image.Image, seed: int = 42) -> Image.Image:
    """
    Apply data augmentation appropriate for CLIP fine-tuning.
    
    Focuses on geometric transformations while preserving color information
    since CLIP's preprocessing already handles color normalization.
    
    Args:
        image: Input PIL image
        seed: Random seed for reproducibility
        
    Returns:
        Augmented PIL image
    """      
    torch.manual_seed(seed)

    # Random resized crop (maintains most of original content while adding variety)
    random_resized_crop = T.RandomResizedCrop(
        size=(224, 224), # CLIP's input size
        scale=(0.8, 1.0),
        ratio=(0.9, 1.1) 
    )
    image = random_resized_crop(image)

    # RandomAffine
    random_affine = T.RandomAffine(
        degrees=15,
        translate=(0.1, 0.1),
        scale=(0.9, 1.1),
        shear=5
    )
    image = random_affine(image)

    # Horizontal Flip (if semantically appropriate for your recipe images)
    if random.random() > 0.5:
        image = T.functional.hflip(image)

    # Random perspective (subtle, to simulate different viewing angles)
    perspective_transform = T.RandomPerspective(
        distortion_scale=0.2,  # Low distortion
        p=0.5
    )
    image = perspective_transform(image)

    return image

class AugmentedOversampledDataset(Dataset):
    """Custom dataset wrapper that oversamples underrepresented classes by duplicating and augmenting their samples."""
    def __init__(self, subset_dataset: Dataset, original_dataset: Dataset, target_samples_per_class: int, seed: int = 42):
        """
        Initialize the augmented dataset.

        Args:
            subset_dataset: Dataset to augment.
            original_dataset: The original ImageFolder dataset (to access class names).
            target_samples_per_class: The target number of samples per class.
            seed: Random seed for reproducibility.
        """
        self.subset_dataset = subset_dataset
        self.original_dataset = original_dataset
        self.target_samples_per_class = target_samples_per_class
        self.seed = seed
        self.augmented_samples = []

        # Oversample and augment underrepresented classes
        self._create_augmented_samples()

    def _create_augmented_samples(self) -> None:
        """Creates augmented samples for classes with fewer than target samples."""
        # Group samples by class
        samples_by_class = {class_idx: [] for class_idx in range(len(self.original_dataset.classes))}
        for idx in self.subset_dataset.indices:
            image, label = self.original_dataset[idx]
            samples_by_class[label].append((image, label))

        # Duplicate and augment underrepresented classes
        for class_idx, samples in samples_by_class.items():
            current_count = len(samples)
            class_name = self.original_dataset.classes[class_idx]

            if current_count < self.target_samples_per_class:
                print(f"Class {class_name}: {current_count} samples -> Target: {self.target_samples_per_class}")
                
                additional_samples = []
                for _ in range (self.target_samples_per_class - current_count):
                    # Randomly pick a sample and apply augmentation
                    original_image, label = random.choice(samples)
                    augmented_image = augment(original_image, seed=self.seed)
                    additional_samples.append((augmented_image, label))

                self.augmented_samples.extend(additional_samples)
            else:
                print(f"Class {class_name}: {current_count} samples (No augmentation needed)")


    def __len__(self) -> int:
        """Retruns total length including original and augmented samples."""
        return len(self.subset_dataset) + len(self.augmented_samples)
    
    def __getitem__(self, idx: int) -> tuple[Image.Image, int]:
        """Returns an image/label pair from either original or augmented samples."""
        if idx < len(self.subset_dataset):
            # Get original sample
            image, label = self.original_dataset[self.subset_dataset.indices[idx]]
        else:
            # Get augmented sample
            image, label = self.augmented_samples[idx - len(self.subset_dataset)]
        return image, label


### Training Configuration
Module to handle configuration parameters for CLIP model training. Includes training hyperparameters, data paths, and checkpoint settings.

In [0]:
# -------------
# Configuration
# -------------

from dataclasses import dataclass

@dataclass
class TrainingConfig:
    """Configuration for CLIP model training."""
    # Training parameters
    BATCH_SIZE: int = 16
    NUM_EPOCHS: int = 2
    LEARNING_RATE: float = 1e-7
    WEIGHT_DECAY: float = 1e-4

    # Data parameters
    root_dir: str = "<PATH>"
    DATASET_PATH: str = f"{root_dir}recipe_classifier/dataset/2024_11_25/"
    DATASET_NAME: str = "dataset_sample"

    # Checkpoint parameters
    SAVE_INTERVAL: int = 10
    CHECKPOINT_SAVE_PATH: str = f"{root_dir}recipe_classifier/checkpoints/2025_01_06/dataset_sample/"
    best_model_path: str = f"{CHECKPOINT_SAVE_PATH}best_model.pth"


    def to_dict(self) -> dict[str, int | float]:
        """Convert config to dictionary for logging."""
        return {
            "BATCH_SIZE": self.BATCH_SIZE,
            "NUM_EPOCHS": self.NUM_EPOCHS,
            "LEARNING_RATE": self.LEARNING_RATE,
            "WEIGHT_DECAY": self.WEIGHT_DECAY,
        }

### Data Module
Module to handle dataset operations including loading, preprocessing, augmentation, and DataLoader creation with balanced batch sampling.

In [0]:
# -------------
# Data Module
# -------------

import torchvision
from torch.utils.data import DataLoader, Subset
from balanced_batch_sampler import BalancedBatchSampler

class DataModule:
    """
    Module for handling dataset loading, preprocessing, train/test splitting,
    augmentation, and creation of balanced DataLoaders.

    Attributes:
        config (TrainingConfig): Configuration module containing data and training parameters.
        dataset (torchvision.datasets.ImageFolder): The loaded dataset.
    """
    def __init__(self, config: TrainingConfig):
        self.config = config
        self.dataset = None

    def load_dataset(self, preprocess: callable) -> torchvision.datasets.ImageFolder:
        """
        Load and preprocess the dataset.

        Args:
            preprocess: CLIP model preprocessing function

        Returns:
            Loaded and preprocessed ImageFolder dataset
        """
        self.dataset = torchvision.datasets.ImageFolder(
            f"{self.config.DATASET_PATH}{self.config.DATASET_NAME}/", transform=preprocess
            )
        return self.dataset
    
    def prepare_data(self,
                     dataset: Dataset,
                     train_indices: list[int],
                     test_indices: list[int],
                     target_samples_per_class: int = None
                     ) -> tuple[Dataset, Dataset]:
        """
        Prepare train/test datasets with optional augmentation and oversampling.

        Args:
            dataset: The complete dataset.
            train_indices: Indices for the train dataset.
            test_indices: Indices for the test dataset.
            target_samples_per_class (optional): The number of samples to oversample each class. Defaults to no oversampling.

        Returns:
            Tuple of (train_dataset, test_dataset)
        """
        train_dataset = Subset(dataset, train_indices)
        test_dataset = Subset(dataset, test_indices)

        print("---- Before Augmentation ----")
        print(f"Train dataset size: {len(train_dataset)}")
        print(f"Test dataset size: {len(test_dataset)}")

        # Apply augmentation and oversampling
        if target_samples_per_class:
            train_dataset = AugmentedOversampledDataset(train_dataset, dataset, target_samples_per_class)

        print("\n---- After Augmentation ----")
        print(f"Train dataset size: {len(train_dataset)}")
        print(f"Test dataset size: {len(test_dataset)}\n")

        return train_dataset, test_dataset
    
    def get_labels(self, train_indices: list[int], test_indices: list[int]) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Extract labels from train and test datasets.
        
        Args:
            train_indices: Indices for the training dataset.
            test_indices: Indices for the test dataset.

        Returns:
            Tuple with labels for (train, test) datasets.
        """
        train_labels = torch.tensor([self.dataset.targets[i] for i in train_indices])
        test_labels = torch.tensor([self.dataset.targets[i] for i in test_indices])
        return train_labels, test_labels
    
    def create_dataloaders(self, 
                           train_dataset: Dataset, 
                           test_dataset: Dataset, 
                           train_indices: list[int], 
                           test_indices: list[int]
                           ) -> tuple[DataLoader, DataLoader]:
        """
        Create DataLoaders with balanced batch sampling for train/test datasets.
        
        Args:
            train_dataset: The train dataset.
            test_dataset: The test dataset.
            train_indices: Indices for the training dataset.
            test_indices: Indices for the test dataset.

        Returns:
            Tuple of (train_dataloader, test_dataloader)
        """
        # Get labels
        train_labels, test_labels = self.get_labels(train_indices, test_indices)

        # Create samplers
        train_sampler = BalancedBatchSampler(labels=train_labels, n_classes=self.config.BATCH_SIZE, n_samples=1)
        test_sampler = BalancedBatchSampler(labels=test_labels, n_classes=self.config.BATCH_SIZE, n_samples=1)

        # Create dataloaders
        train_dataloader = DataLoader(train_dataset, batch_sampler=train_sampler)
        test_dataloader = DataLoader(test_dataset, batch_sampler=test_sampler)

        return train_dataloader, test_dataloader

### Model Setup Module
Manages CLIP model initialization, configuration, and training setup including loss functions and optimizer initialization.

In [0]:
# -------------
# Model Setup
# -------------

from pathlib import Path
import clip

class CLIPModule:
    """
    Module for handling the CLIP model loading and training setup.

    Attributes:
        config (TrainingConfig): Configuration module containing data and training parameters.
        model (torch.nn.Module): The loaded CLIP model.
        preprocess (callable): Preprocessing function for the CLIP model.
        device (str): The device on which the model is loaded (cuda:0 or cpu).
    """
    def __init__(self, config: TrainingConfig):
        self.config = config
        self.model = None
        self.preprocess = None
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"

    def load_model(self, model_name: str = "ViT-B/32") -> tuple[torch.nn.Module, callable]:
        """
        Load the CLIP model and preprocessing function.
        
        Args:
            model_name (optional): The name of the CLIP model to load. Defaults to "ViT-B/32".

        Returns:
            Tuple of (model, preprocess_function)
        """
        model, preprocess = clip.load(model_name, device=self.device, jit=False) # jit=False to disable TorchScript for fine-tuning

        if self.device == "cpu":
            model.float() # Converts model to fp32
        else:
            clip.model.convert_weights(model) # Converts model to fp16 (unnecessary since CLIP already uses fp16 by default)

        # Model checkpoints
        weights_path = Path(self.config.CHECKPOINT_SAVE_PATH)
        weights_path.mkdir(exist_ok=True, parents=True)
        
        self.model = model
        self.preprocess = preprocess
        return model, preprocess
    
    def setup_training(self) -> tuple[torch.nn.Module, torch.nn.Module, torch.optim.AdamW]:
        """
        Setup the training components: loss functions and optimizer.

        Returns:
            Tuple of (image_loss, text_loss, optimizer)
        """
        # Loss functions
        loss_img = torch.nn.CrossEntropyLoss()
        loss_txt = torch.nn.CrossEntropyLoss(ignore_index=-1)
        
        # Optimizer
        params = [p for p in self.model.parameters() if p.requires_grad]
        optimizer = torch.optim.AdamW(params, lr=self.config.LEARNING_RATE, weight_decay=self.config.WEIGHT_DECAY)

        return loss_img, loss_txt, optimizer

### Model Training Module
Trainer module that handles CLIP model training, validation, metrics tracking, and MLflow logging.

In [0]:
# -------------
# Model Trainer
# -------------

import mlflow
import clip
from tqdm import tqdm

class CLIPTrainer:
    """
    Module with methods for training the CLIP model.

    Manages the training loop, validation, metric computation, checkpointing,
    and experiment tracking via MLflow.

    Attributes:
        config: Training configuration module
        model_module: CLIP model and preprocessing module
        data_module: Data loading module
        device: Device for training (GPU/CPU)
        best_accuracy: Best accuracy achieved during training
        early_stopping: Early stopping handler
    """
    def __init__(self, config: TrainingConfig, model_module: CLIPModule, data_module: DataModule):
        self.config = config
        self.model_module = model_module
        self.data_module = data_module
        self.device = model_module.device

        # Accuracy metrics
        self.best_accuracy = 0.0

        # Early stopping
        self.early_stopping = EarlyStopping(patience=36, mode="min", delta=0.01)

    def train(self, train_dataset: Dataset, test_dataset: Dataset, train_indices: list[int], test_indices: list[int]):
        """
        Complete training loop with validation.

        Args:
            train_dataset: Training dataset
            test_dataset: Test/validation dataset
            train_indices: Indices for training split
            test_indices: Indices for test/validation split
        """
        # Create dataloaders
        train_dataloader, test_dataloader = self.data_module.create_dataloaders(train_dataset, test_dataset, train_indices, test_indices)

        # Get training components
        model = self.model_module.model
        loss_img, loss_txt, optimizer = self.model_module.setup_training()

        # Compute number of batches
        num_batches_train = len(train_dataloader.dataset) / self.config.BATCH_SIZE
        num_batches_test = len(test_dataloader.dataset) / self.config.BATCH_SIZE

        # Accuracy metrics
        num_classes = len(self.data_module.dataset.classes)
        cumulative_correct_preds = torch.zeros(num_classes, dtype=torch.long)
        cumulative_total_samples = torch.zeros(num_classes, dtype=torch.long)
        cumulative_pred_labels = [] # Confusion matrix
        cumulative_true_labels = [] # Confusion matrix

        # self.early_stopping.reset()

        with mlflow.start_run():
            # Log hyperparameters
            mlflow.log_params({
                "batch_size": self.config.BATCH_SIZE,
                "num_epochs": self.config.NUM_EPOCHS,
                "learning_rate": self.config.LEARNING_RATE,
                "weight_decay": self.config.WEIGHT_DECAY
            })

            # Log random images as artifacts to review augmentation
            [log_random_image(train_dataset, self.data_module.dataset.classes, counter=i) for i in range(1, 6)]

            for epoch in range(self.config.NUM_EPOCHS):
                print(f"\nEpoch {epoch+1}/{self.config.NUM_EPOCHS}")

                log_memory_usage(epoch+1, phase="Train - Start") # DEV

                # Training phase
                epoch_train_loss = self._train_epoch(train_dataloader, model, optimizer, loss_img, loss_txt, num_batches_train)
                print(f"Epoch {epoch} train loss: {epoch_train_loss}")
                mlflow.log_metric("Loss/train", epoch_train_loss, step=epoch)

                # Save model checkpoint
                if epoch % self.config.SAVE_INTERVAL == 0:
                    checkpoint = {
                        "epoch": epoch,
                        "model_state_dict": model.state_dict(),
                        "optimizer_state_dict": optimizer.state_dict()
                        }
                    checkpoint_path = Path(self.config.CHECKPOINT_SAVE_PATH) / f"epoch_{epoch}.pt"
                    torch.save(checkpoint, checkpoint_path)
                    print(f"[Checkpoint] Saved under {checkpoint_path}\n")

                log_memory_usage(epoch+1, phase="Train - End / Test - Start") # DEV

                # Testing phase
                epoch_test_loss = self._test_epoch(
                    test_dataloader, model, optimizer, loss_img, loss_txt, num_batches_test, cumulative_correct_preds, cumulative_total_samples, cumulative_pred_labels, cumulative_true_labels, epoch
                    )
                print(f"Epoch {epoch} test loss: {epoch_test_loss}\n")
                mlflow.log_metric("Loss/test", epoch_test_loss, step=epoch)

                log_memory_usage(epoch+1, phase="Test - End") # DEV

                # Early stopping
                if self.early_stopping.early_stop:
                    print(f"Early stopping triggered after {epoch+1} epochs.")
                    break

    def _train_epoch(self, train_dataloader: DataLoader, model: torch.nn.Module, optimizer: torch.optim.Optimizer, 
                     loss_img: torch.nn.Module, loss_txt: torch.nn.Module, num_batches_train: int) -> float:
        """
        Training loop for a single epoch.

        Args:
            train_dataloader: Training dataloader
            model: CLIP model
            optimizer: Optimizer
            loss_img: Image loss function
            loss_txt: Text loss function
            num_batches_train: Number of training batches
        
        Returns:
            Average training loss for the epoch
        """
        model.train()
        epoch_train_loss = 0

        for batch in tqdm(train_dataloader, total=num_batches_train):
            optimizer.zero_grad() # Clear gradients from previous iteration

            images, label_ids = batch

            images = torch.stack([img for img in images], dim=0).to(self.device) # Stack images into a single tensor (adds an extra dim representing the batch)
            # Generate text prompts: the number of text prompts will be equal to the number of images in the batch (label_ids)
            # Scenario: matching each image with its corresponding text prompt, doesn't allow for comparison against other text prompts
            # texts = [f"A photo of a {train_dataset.dataset.classes[label_id]}" for label_id in label_ids]
            texts = [f"A photo of a {self.data_module.dataset.classes[label_id]}" for label_id in label_ids]
            text = clip.tokenize(texts).to(self.device) # Tokenize text prompts 
            
            logits_per_image, logits_per_text = model(images, text) # Forward pass

            # Ground truth labels: For each batch, the i-th image corresponds to the i-th text
            # Therefore, the i-th image should have the same label as the i-th text, i.e. [0, 1, 2, ..., BATCH_SIZE - 1]
            # The same happens for text, so we use the same ground truth for both image and text
            ground_truth = torch.arange(logits_per_image.shape[0], dtype=torch.long, device=self.device)

            # Compute loss
            total_train_loss = (loss_img(logits_per_image, ground_truth) + loss_txt(logits_per_text, ground_truth)) / 2
            total_train_loss.backward() # Backward pass
            epoch_train_loss += total_train_loss

            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) # Clip gradients to prevent exploding gradients

            if self.device == "cpu":
                optimizer.step() # Update weights
            else:
                convert_models_to_fp32(model)
                optimizer.step() # Update weights
                clip.model.convert_weights(model)

        # Average loss per epoch
        return epoch_train_loss / num_batches_train
    

    def _test_epoch(self, test_dataloader: DataLoader, model: torch.nn.Module, optimizer: torch.optim.Optimizer, 
                    loss_img: torch.nn.Module, loss_txt: torch.nn.Module, num_batches_test: int, 
                    cumulative_correct_preds: torch.Tensor, cumulative_total_samples: torch.Tensor, 
                    cumulative_pred_labels: list[torch.Tensor], cumulative_true_labels: list[torch.Tensor], epoch: int) -> float:
        """
        Validate the model and compute metrics.
        
        Args:
            test_dataloader: Test dataloader
            model: CLIP model
            optimizer: Optimizer
            loss_img: Image loss function
            loss_txt: Text loss function
            num_batches_test: Number of test batches
            cumulative_correct_preds: Cumulative number of correct predictions
            cumulative_total_samples: Cumulative number of samples
            cumulative_pred_labels: Cumulative list of predicted labels
            cumulative_true_labels: Cumulative list of true labels
            epoch: Current epoch number
            
        Returns:
            Average test loss for the epoch
        """
        model.eval()
        epoch_test_loss = 0

        # Accurary metrics
        acc_top3_list = []
        acc_top1_list = []
        all_pred_labels = []
        all_true_labels = []

        num_classes = len(self.data_module.dataset.classes)
        classes = torch.arange(num_classes, device=self.device)

        for i, batch in enumerate(tqdm(test_dataloader, total=num_batches_test)):
            images, label_ids = batch
            images = images.to(self.device)
            label_ids = label_ids.to(self.device)

            # Generate text prompts: the number of text prompts will be equal to the number of classes in the dataset (classes)
            # Scenario: classifying each image against all possible classes, allows for comparison against all classes
            texts = torch.cat([clip.tokenize(f"A photo of a {c}") for c in self.data_module.dataset.classes]).to(self.device) # Concatenate text prompts
       
            with torch.no_grad():
                image_features = model.encode_image(images)
                text_features = model.encode_text(texts)

                logits_per_image, logits_per_text = model(images, texts) # Forward pass

                # Ground truths
                ground_truth_img = torch.arange(logits_per_image.shape[0], dtype=torch.long, device=self.device)
                ground_truth_txt = -1 * torch.ones(len(classes), dtype=torch.long, device=self.device)
                for idx, class_label in enumerate(classes):
                    if class_label in label_ids:
                        ground_truth_txt[idx] = (label_ids == class_label).nonzero(as_tuple=True)[0].item()

                # Compute loss
                img_loss = loss_img(logits_per_image, ground_truth_img)
                txt_loss = loss_txt(logits_per_text, ground_truth_txt)
                total_loss = (img_loss + txt_loss) / 2
                epoch_test_loss += total_loss

                # Normalize features
                image_features /= image_features.norm(dim=-1, keepdim=True)
                text_features /= text_features.norm(dim=-1, keepdim=True)

                assert torch.equal(logits_per_image.T, logits_per_text), "Logits are not equal"

                # Compute cosine similarity
                similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)

                # [top acc] Compute top accuracy
                acc_top1 = torchmetrics.functional.accuracy(similarity, label_ids, task="multiclass", num_classes=num_classes)
                acc_top3 = torchmetrics.functional.accuracy(similarity, label_ids, task="multiclass", num_classes=num_classes, top_k=3)
                acc_top1_list.append(acc_top1)
                acc_top3_list.append(acc_top3)

                # [per-class acc] Collect predictions and labels
                predicted_labels = similarity.argmax(dim=-1)
                all_pred_labels.append(predicted_labels)
                all_true_labels.append(label_ids)

                # [confusion matrix] Collect predictions and labels
                cumulative_pred_labels.append(torch.cat(all_pred_labels))
                cumulative_true_labels.append(torch.cat(all_true_labels))

        # [per-class acc] Overall predictions, labels, and accuracy
        all_pred_labels = torch.cat(all_pred_labels)
        all_true_labels = torch.cat(all_true_labels)

        # [confusion matrix]
        accumulated_pred_labels = torch.cat(cumulative_pred_labels)
        accumulated_true_labels = torch.cat(cumulative_true_labels)

        # [confusion matrix] Log confusion matrix
        log_confusion_matrix(accumulated_true_labels, accumulated_pred_labels, self.data_module.dataset.classes, epoch)

        # [per-class acc] Update cumulative per-class counters
        for class_idx in range(num_classes):
            class_mask = (all_true_labels == class_idx)
            correct_class_preds = (all_pred_labels[class_mask] == all_true_labels[class_mask]).sum().item()
            total_class_samples = class_mask.sum().item()

            cumulative_correct_preds[class_idx] += correct_class_preds
            cumulative_total_samples[class_idx] += total_class_samples

        # [per-class acc] Calculate and log cumulative per-class accuracy
        class_accs = [] # Store class accuracies for averaging
        for class_idx, class_name in enumerate(self.data_module.dataset.classes):
            if cumulative_total_samples[class_idx] > 0:
                accuracy = (cumulative_correct_preds[class_idx].item() / cumulative_total_samples[class_idx].item()) * 100
                class_accs.append(accuracy)
            else:
                accuracy = float("nan")
            print(f"Cumulative Accuracy for {class_name}: {accuracy:.2f}%")
            mlflow.log_metric(f"Cumulative Accuracy/{class_name}", accuracy, step=epoch)

        # [top acc] Compute mean top3 and top1 accuracy
        mean_top3_accuracy = torch.stack(acc_top3_list).mean().cpu().numpy()
        print(f"\nMean Top 3 Accuracy: {mean_top3_accuracy*100:.2f}%")
        mlflow.log_metric("Test Accuracy/Top3", mean_top3_accuracy, step=epoch)
        mean_top1_accuracy = torch.stack(acc_top1_list).mean().cpu().numpy()
        print(f"Mean Top 1 Accuracy: {mean_top1_accuracy*100:.2f}%")
        mlflow.log_metric("Test Accuracy/Top1", mean_top1_accuracy, step=epoch)

        # [macro-avg acc] Compute macro avg accuracy
        macro_accuracy = sum(class_accs) / len(class_accs) if class_accs else float("nan")
        print(f"Macro-averaged Accuracy: {macro_accuracy:.2f}%")
        mlflow.log_metric("Test Accuracy/Macro-avg", macro_accuracy, step=epoch)

        # Save best model if macro average accuracy improves
        self.early_stopping(macro_accuracy, model, optimizer, self.config.best_model_path, epoch)

        # Average loss per epoch
        return epoch_test_loss / num_batches_test



### Training Pipeline

#### Main Entry Point

Main script for initializing and executing the CLIP model training pipeline with dataset preparation and augmentation.

In [0]:
# ----------
# Main
# ----------

from torch.utils.data import random_split

def run_training():
    """
    Runs complete CLIP training pipeline.

    Handles:
        - Configuration and module initialization
        - Dataset loading and splitting
        - Data augmentation and oversampling
        - Training and validation
    """
    # Set random seeds
    torch.manual_seed(0)
    random.seed(0)
    np.random.seed(0)

    # Init modules
    config = TrainingConfig()
    data_module = DataModule(config)
    model_module = CLIPModule(config)

    # Load model and dataset
    model, preprocess = model_module.load_model()
    dataset = data_module.load_dataset(preprocess)
                                       
    # Split dataset into train/test
    train_size = int(0.8 * len(dataset))
    test_size = len(dataset) - train_size
    train_indices, test_indices = random_split(range(len(dataset)), [train_size, test_size])

    # Prepare datasets with augmentation
    train_dataset, test_dataset = data_module.prepare_data(dataset, train_indices, test_indices, target_samples_per_class=10)

    # Display train/test dataset distribution
    train_labels, test_labels = data_module.get_labels(train_indices, test_indices)
    print(f"All of the dataset's classes: {dataset.classes}\n")
    class_distribution(labels=train_labels, class_names=dataset.classes, dataset_name="Train (before augmentation)")
    class_distribution(labels=test_labels, class_names=dataset.classes, dataset_name="Test (before augmentation)")

    # Oversample + Augmentation class distribution
    # class_distribution(dataset=train_dataset, class_names=dataset.classes, dataset_name="Train (after augmentation)")
    # class_distribution(dataset=test_dataset, class_names=dataset.classes, dataset_name="Test (after augmentation)")

    # Regular training
    trainer = CLIPTrainer(config, model_module, data_module)
    trainer.train(train_dataset, test_dataset, train_indices, test_indices)
    

In [0]:
run_training()

---- Before Augmentation ----
Train dataset size: 74
Test dataset size: 19

Class chopping-board: 17 samples (No augmentation needed)
Class glass-bowl-large: 1 samples -> Target: 10
Class glass-bowl-medium: 1 samples -> Target: 10
Class glass-bowl-small: 2 samples -> Target: 10
Class group_step: 8 samples -> Target: 10
Class oven-dish: 4 samples -> Target: 10
Class oven-tray: 10 samples (No augmentation needed)
Class pan: 20 samples (No augmentation needed)
Class pot-one-handle: 4 samples -> Target: 10
Class pot-two-handles-medium: 2 samples -> Target: 10
Class pot-two-handles-small: 5 samples -> Target: 10
---- After Augmentation ----
Train dataset size: 127
Test dataset size: 19

All of the dataset's classes: ['chopping-board', 'glass-bowl-large', 'glass-bowl-medium', 'glass-bowl-small', 'group_step', 'oven-dish', 'oven-tray', 'pan', 'pot-one-handle', 'pot-two-handles-medium', 'pot-two-handles-small']

Train (before augmentation) dataset -- Class distribution:
pan -- 27.03%
chopping-

  0%|          | 0/7.9375 [00:00<?, ?it/s] 13%|█▎        | 1/7.9375 [00:11<01:17, 11.18s/it] 25%|██▌       | 2/7.9375 [00:20<01:00, 10.26s/it] 38%|███▊      | 3/7.9375 [00:30<00:49, 10.02s/it] 50%|█████     | 4/7.9375 [00:40<00:38,  9.83s/it] 63%|██████▎   | 5/7.9375 [00:50<00:29, 10.06s/it] 76%|███████▌  | 6/7.9375 [01:00<00:19,  9.89s/it] 76%|███████▌  | 6/7.9375 [01:00<00:19, 10.01s/it]


Epoch 0 train loss: 1.2433167695999146
[Checkpoint] Saved under <PATH>/recipe_classifier/checkpoints/2025_01_06/dataset_sample/epoch_0.pt

[Train - End / Test - Start Epoch 0] CPU Memory: RSS=5.68GB


  0%|          | 0/1.1875 [00:00<?, ?it/s] 84%|████████▍ | 1/1.1875 [00:06<00:01,  6.23s/it]2it [00:11,  5.84s/it]                            2it [00:11,  5.90s/it]


Cumulative Accuracy for chopping-board: 100.00%
Cumulative Accuracy for glass-bowl-large: nan%
Cumulative Accuracy for glass-bowl-medium: nan%
Cumulative Accuracy for glass-bowl-small: 100.00%
Cumulative Accuracy for group_step: 0.00%
Cumulative Accuracy for oven-dish: 0.00%
Cumulative Accuracy for oven-tray: 100.00%
Cumulative Accuracy for pan: 50.00%
Cumulative Accuracy for pot-one-handle: 0.00%
Cumulative Accuracy for pot-two-handles-medium: nan%
Cumulative Accuracy for pot-two-handles-small: nan%

Mean Top 3 Accuracy: 85.71%
Mean Top 1 Accuracy: 50.00%
Macro-averaged Accuracy: 50.00%
Epoch 0 test loss: 4.3275651931762695

[Test - End Epoch 0] CPU Memory: RSS=5.69GB

Epoch 2/2
[Train - Start Epoch 1] CPU Memory: RSS=5.69GB


  0%|          | 0/7.9375 [00:00<?, ?it/s] 13%|█▎        | 1/7.9375 [00:09<01:04,  9.25s/it] 25%|██▌       | 2/7.9375 [00:18<00:56,  9.43s/it] 38%|███▊      | 3/7.9375 [00:28<00:46,  9.52s/it] 50%|█████     | 4/7.9375 [00:37<00:37,  9.54s/it] 63%|██████▎   | 5/7.9375 [00:47<00:27,  9.47s/it] 76%|███████▌  | 6/7.9375 [00:56<00:18,  9.48s/it] 76%|███████▌  | 6/7.9375 [00:56<00:18,  9.47s/it]


Epoch 1 train loss: 1.1142325401306152
[Train - End / Test - Start Epoch 1] CPU Memory: RSS=6.55GB


  0%|          | 0/1.1875 [00:00<?, ?it/s] 84%|████████▍ | 1/1.1875 [00:05<00:00,  5.00s/it]2it [00:10,  5.12s/it]                            2it [00:10,  5.10s/it]


Cumulative Accuracy for chopping-board: 100.00%
Cumulative Accuracy for glass-bowl-large: nan%
Cumulative Accuracy for glass-bowl-medium: nan%
Cumulative Accuracy for glass-bowl-small: 100.00%
Cumulative Accuracy for group_step: 0.00%
Cumulative Accuracy for oven-dish: 0.00%
Cumulative Accuracy for oven-tray: 75.00%
Cumulative Accuracy for pan: 50.00%
Cumulative Accuracy for pot-one-handle: 0.00%
Cumulative Accuracy for pot-two-handles-medium: nan%
Cumulative Accuracy for pot-two-handles-small: nan%

Mean Top 3 Accuracy: 85.71%
Mean Top 1 Accuracy: 42.86%
Macro-averaged Accuracy: 46.43%
Epoch 1 test loss: 4.722964763641357

[Test - End Epoch 1] CPU Memory: RSS=6.55GB


### Model Saving Utils
Utilities for saving complete CLIP models with architecture and creating lightweight checkpoints for deployment.

##### Complete Model Saving

Saves the full model architecture, weights, and preprocessing pipeline.

- When loading the complete model:

```python
# Load everything at once
checkpoint = torch.load('complete_model.pth')
model = checkpoint['model']
preprocess = checkpoint['preprocess']
model.eval()
```

##### Lightweight Model Saving

Saves only essential model weights (model state dictionary) for lightweight deployments. Excludes the optimizer state dictionary (only useful for training) to minize file size.

- When loading the model state dictionary:

```python
# Initialize model architecture first
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device, jit=False)
if device == "cpu":
    model.float()
else:
    clip.model.convert_weights(model)

# Load state dict
checkpoint = torch.load('best_model.pth')
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
```



In [0]:
# Complete Model Saving

# Save complete model with architecture
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Load model
model, preprocess = clip.load("ViT-B/32", device=device, jit=False)
if device == "cpu":
    model.float()
else:
    clip.model.convert_weights(model)

# Load the fine-tuned weights
root_dir: str = "<PATH>"
CHECKPOINT_SAVE_PATH = f"{root_dir}recipe_classifier/checkpoints/2025_01_09/dataset_20dec_small/grid_search/lr_1e-05_wd_0/"
checkpoint_path = f"{CHECKPOINT_SAVE_PATH}best_model.pth"
checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
model.load_state_dict(checkpoint["model_state_dict"])

# Save complete model
torch.save({
    "model": model,
    "preprocess": preprocess
}, f"{CHECKPOINT_SAVE_PATH}complete_best_model.pt")

In [0]:
# Lightweight Model Saving

# Convert complete model (including model_state_dict and optimizer_state_dict) to only include model_state_dict

root_dir: str = "<PATH>"
CHECKPOINT_SAVE_PATH = f"{root_dir}recipe_classifier/checkpoints/2025_01_18/dataset_20dec/b32/"
checkpoint_path = f"{CHECKPOINT_SAVE_PATH}best_model.pth"

checkpoint = torch.load(checkpoint_path)
print(f"Checkpoint keys:\n", checkpoint.keys())

lightweight_checkpoint = {
    "epoch": checkpoint["epoch"],
    "model_state_dict": checkpoint["model_state_dict"],
    "metric": checkpoint["metric"],
}

torch.save(lightweight_checkpoint, f"{CHECKPOINT_SAVE_PATH}lightweight_best_model.pt")

### Inference

Run inference with a fine-tuned CLIP model.

In [0]:
# ----------
# Inference
# ----------

# 1. Model Setup
#   - Set available device
#   - Load base CLIP architecture
#   - Load fine-tuned weights from checkpoint
#   - Set model to evaluation mode

device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Load model
model, preprocess = clip.load("ViT-B/32", device=device, jit=False)

# Load the fine-tuned weights
CHECKPOINT_SAVE_PATH = f"{root_dir}recipe_classifier/checkpoints/2024_12_02/dataset_sample/"
checkpoint_path = f"{CHECKPOINT_SAVE_PATH}best_model.pth"
checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()

# 2. Image Processing
#   - Load target image
#   - Apply CLIP preprocessing transform
#   - Add batch dimension and move to device

# Load image
image_path = f"{root_dir}recipe_classifier/dataset/tests/test_group.jpg"
image = Image.open(image_path)

# Preprocess image
image = preprocess(image).unsqueeze(0).to(device)

# 3. Text Prompt Generation
#   - Define a list pof possible class keywords
#   - Convert keywords to CLIP-style prompts ("A photo of a {keyword}")
#   - Tokenize prompts for model input and move to device

# Generate text prompt
keywords = ["pan",
    "oven-dish",
    "grill-plate",
    "oven-tray",
    "chopping-board",
    "medium", # chopping-board + medium
    "CP", # chopping-board + CP
    "grill-tray",
    "pot-two-handles-medium",
    "pot-two-handles-small",
    "pot-two-handles-shallow",
    "pot-one-handle",
    #"sauce-pan",
    "saucepan",
    "glass-bowl-large",
    "glass-bowl-medium",
    "glass-bowl-small",
    "finalstep",
    "group_step",
]
text_prompts = [f"A photo of a {keyword}" for keyword in keywords]
tokenized_text = clip.tokenize(text_prompts).to(device)

# 4. Feature Extraction
#   - Extract image features from preprocessed image
#   - Extract text features from tokenized prompts
#   - Normalizes both feature sets for comparison

# Generate features
with torch.no_grad():
    image_features = model.encode_image(image)
    image_features /= image_features.norm(dim=-1, keepdim=True) # Normalize features
    
    text_features = model.encode_text(tokenized_text)
    text_features /= text_features.norm(dim=-1, keepdim=True) # Normalize features

# 5. Prediction
#   - Compute cosine similarity between image and text features
#   - Apply softmax to get probability distribution
#   - Get top prediction and probability
#   - Map prediction index back to keyword

# Compute cosine similarity
similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)

# Get predicted keyword
predicted_prob, predicted_keyword_idx = similarity.topk(1, dim=-1)

# Print prediction
predicted_keyword = keywords[predicted_keyword_idx.item()]
print(f"Predicted keyword: {predicted_keyword} with probability {predicted_prob.item() * 100:.2f}%")

### MLflow Utils

Utilities to perform operations in MLflow experiments and runs.

In [0]:
def print_active_runs():
    """Lists all currently active MLflow runs across experiments."""
    try:
        active_runs = mlflow.search_runs(filter_string="attributes.status = 'RUNNING'")
        
        if len(active_runs) == 0:
            print("No active runs found.")
            return
        
        print(f"\nFound {len(active_runs)} active run(s):")
        print("-" * 80)
        
        for idx, run in active_runs.iterrows():
            experiment = mlflow.get_experiment(run.experiment_id)
            experiment_name = experiment.name if experiment else "Unknown"
            print(f"Run {idx + 1}:")
            print(f"Experiment Name: {experiment_name}")
            print(f"Run ID: {run.run_id}")
            print("-" * 80)
            
    except Exception as e:
        print(f"Error fetching active runs: {e}")

print_active_runs()

No active runs found.


In [0]:
import mlflow
from mlflow.tracking import MlflowClient

def end_active_runs():
    """Finds and terminates all active MLflow runs."""
    try:
        client = MlflowClient()

        # Search for active runs
        active_runs = mlflow.search_runs(filter_string="attributes.status = 'RUNNING'")
        
        if len(active_runs) == 0:
            print("No active runs to end.")
            return
        
        print(f"\nFound {len(active_runs)} active run(s):")
        print("-" * 80)
        
        # Iterate through each active run and end it
        for idx, run in active_runs.iterrows():
            experiment = mlflow.get_experiment(run.experiment_id)
            experiment_name = experiment.name if experiment else "Unknown"
            print(f"Ending Run {idx + 1}:")
            print(f"Experiment Name: {experiment_name}")
            print(f"Run ID: {run.run_id}")
            
            # End the run using the client
            client.set_terminated(run_id=run.run_id, status="FINISHED")
            print("Run ended successfully.")
            print("-" * 80)
    
    except Exception as e:
        print(f"Error ending active runs: {e}")

end_active_runs()

No active runs to end.
