# Utils

This notebook contains shared utilities: data preparation, a base model class, device selection, and a test-image sampler. Keeping them in one place lets me reuse this functionality across different notebooks without duplicating code.

## Setup

In [1]:
# Imports
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import v2
import matplotlib.pyplot as plt
import os
import polars as pl
import torch
import torch.nn as nn

## Data Preparation

To streamline training and validation, I implemented a custom PyTorch `Dataset` class for convenient data access.

**Data Augmentation**

To improve generalization, I applied the following augmentation techniques:

- Random horizontal flip
- Random rotation
- Random height and width shift

I intentionally excluded vertical flipping, as it would alter the semantics of some classes (e.g., upside-down baskets would lose their recognizable structure).

Given the low resolution and grayscale nature of the images, I chose not to add noise or apply color-based transformations.

**Dataset Overview**

- **Classes:** 5 (baskets, eyes, binoculars, rabbits, hands)
- **Training set:** 10,000 images per class (50,000 total)
- **Test set:** 5,000 images per class (25,000 total)
- **Image dimensions:** 28×28 pixels (784 features per image)

To improve training stability and convergence speed, I also computed the **mean and standard deviation** of the training set and used these statistics to normalize the input images. After normalization, the pixel values are approximately centered around zero with roughly unit variance, which is especially beneficial for models trained with gradient-based optimizers like SGD or Adam.
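
As a minimal sketch of what this normalization does to a single pixel, the snippet below applies `(x - mean) / std` by hand. The helper name `normalize_pixel` is hypothetical, and the statistics are the rounded training-set values that this notebook prints further down.

```python
def normalize_pixel(x, mean, std):
    """Standardize one pixel value that has already been scaled to [0, 1]."""
    return (x - mean) / std

# Rounded training-set statistics as printed later in this notebook.
MEAN, STD = 0.1982, 0.3426

lowest = normalize_pixel(0.0, MEAN, STD)   # smallest possible input value
highest = normalize_pixel(1.0, MEAN, STD)  # largest possible input value
print(round(lowest, 3), round(highest, 3))
```

With these statistics the input range [0, 1] maps to roughly [-0.58, 2.34], so the result is centered on average rather than symmetric around zero.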

In [2]:
# Dataset class for QuickDraw
class QuickDrawDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, class_label=None):
        self.img_labels = pl.read_csv(annotations_file)
        if class_label is not None:
            self.img_labels = self.img_labels.filter(pl.col("class_label") == class_label)
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        # Drop the first column; the remaining two are the image path and the class label
        img_path, label = self.img_labels.row(idx)[1:]
        img_path = os.path.join(self.img_dir, img_path)
        image = Image.open(img_path)

        if self.transform:
            image = self.transform(image)

        return image, label

In [3]:
# Calculate mean and std for normalization
to_tensor = v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])
norm_dataset = QuickDrawDataset('../dataset/train.csv', '../dataset/images', transform=to_tensor)
loader = DataLoader(norm_dataset, batch_size=64, shuffle=False, num_workers=6)

mean = 0.
std = 0.
nb_samples = 0

for data, _ in loader:
    batch_samples = data.size(0)  # images in this batch (64, except possibly the last)
    data = data.view(batch_samples, data.size(1), -1)  # flatten H and W
    mean += data.mean(2).sum(0)  # per-image channel means, summed over the batch
    std += data.std(2).sum(0)    # per-image channel stds, summed over the batch
    nb_samples += batch_samples

mean /= nb_samples
std /= nb_samples

print('Mean:', mean)
print('Std:', std)

Mean: tensor([0.1982])
Std: tensor([0.3426])
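
One caveat about the loop above: it averages the per-image standard deviations, which is generally not the same as the standard deviation taken over all pixels of the training set. The toy example below uses only the standard library and two made-up four-pixel "images" to illustrate the gap. The function names are hypothetical, and `statistics.pstdev` (population standard deviation) stands in for `torch.std`, which applies Bessel's correction by default.

```python
from statistics import fmean, pstdev

def mean_of_per_image_std(images):
    """Average the standard deviation computed separately for each image."""
    return fmean(pstdev(img) for img in images)

def pooled_std(images):
    """Standard deviation over all pixels of all images at once."""
    return pstdev([p for img in images for p in img])

# Two made-up flat images: one all zeros, one all ones.
images = [[0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0]]

print(mean_of_per_image_std(images))  # 0.0
print(pooled_std(images))             # 0.5
```

For equally sized images and population statistics, the averaged value never exceeds the pooled one, because the pooled variance also includes the variance between image means.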


In [4]:
# Data augmentation transforms
train_transforms = v2.Compose([
    v2.Grayscale(num_output_channels=1),
    v2.RandomHorizontalFlip(p=0.5),
    v2.RandomRotation(degrees=10),
    v2.RandomAffine(degrees=0, translate=(0.1, 0.1)),  # Width and height shift
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=mean, std=std)
])

test_transforms = v2.Compose([
    v2.Grayscale(num_output_channels=1),
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=mean, std=std)
])

# Classification targets & subfolder names
classes = {
    0: 'basket',
    1: 'eye',
    2: 'binoculars',
    3: 'rabbit',
    4: 'hand',
}

In [5]:
# Create Train and Test Datasets
train_data = QuickDrawDataset('../dataset/train.csv', '../dataset/images', train_transforms)
test_data = QuickDrawDataset('../dataset/test.csv', '../dataset/images', test_transforms)

## Base Module

For the classifier model, I wanted to experiment with different architectural variations and evaluate their performance. To facilitate this, I created a `BaseModule` class that provides a common structure. This base can be extended with an arbitrary number of layers, making it easy to compare different configurations systematically.

In [6]:
class BaseModule(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.layers = nn.Sequential()

    def forward(self, x):
        return self.layers(x)
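
As a sketch of how the base class might be extended, the hypothetical `TinyMLP` below registers a few layers on `self.layers`. The class name, the layer sizes, and the hidden width are assumptions chosen for illustration (28×28 inputs, 5 output classes), not an architecture taken from this project.

```python
import torch
import torch.nn as nn

class BaseModule(nn.Module):
    # Repeated from the cell above so the example is self-contained.
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.layers = nn.Sequential()

    def forward(self, x):
        return self.layers(x)

class TinyMLP(BaseModule):
    """Hypothetical subclass: a small fully connected classifier."""
    def __init__(self, hidden=64, num_classes=5):
        super().__init__()
        self.layers.add_module("flatten", nn.Flatten())
        self.layers.add_module("fc1", nn.Linear(28 * 28, hidden))
        self.layers.add_module("relu", nn.ReLU())
        self.layers.add_module("fc2", nn.Linear(hidden, num_classes))

model = TinyMLP()
logits = model(torch.zeros(8, 1, 28, 28))  # batch of 8 blank grayscale images
print(logits.shape)  # torch.Size([8, 5])
```

Because `forward` is inherited, a variant only needs to change which layers it registers.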

## Device

To speed up training, I initialized a common device that automatically uses CUDA or MPS acceleration if available. This ensures the models run on GPU when possible, falling back to CPU only if no hardware acceleration is detected.

In [7]:
DEVICE = (
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)
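
The selection order can also be written as a small function, which makes the priority (MPS, then CUDA, then CPU) explicit and testable on a machine without a GPU. `pick_device` and its boolean parameters are hypothetical names used only for this sketch.

```python
def pick_device(mps_available: bool, cuda_available: bool) -> str:
    """Return the preferred device name, checking MPS first, then CUDA, then CPU."""
    if mps_available:
        return "mps"
    if cuda_available:
        return "cuda"
    return "cpu"

print(pick_device(False, True))  # cuda
```

In this notebook it would be called as `pick_device(torch.backends.mps.is_available(), torch.cuda.is_available())`.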

## Sampling

To visually inspect the classification model’s performance, I implemented a sampling function that randomly selects *n* images per class from the test set. Each image is returned as a `(class_label, relative_path)` tuple, which allows for quick, class-balanced sampling from the dataset.

In [8]:
# Function to sample a number of images of each class from the test data
def sample_image_from_each_class(n=1):
    df = pl.read_csv("../dataset/test.csv")
    sampled_images = df.group_by("class_label").map_groups(lambda group: group.sample(n))
    return list(zip(sampled_images["class_label"].to_list(), sampled_images["relative_path"].to_list()))
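
To illustrate the class-balanced sampling idea without the CSV file, here is a standard-library sketch over in-memory rows. It swaps the polars `group_by` call for a plain dictionary. The function `sample_per_class`, the example paths, and the `seed` parameter are assumptions made for this illustration.

```python
import random
from collections import defaultdict

def sample_per_class(rows, n=1, seed=None):
    """Pick n random (class_label, relative_path) rows from every class."""
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for label, path in rows:
        by_class[label].append((label, path))
    sampled = []
    for label in sorted(by_class):
        # random.sample raises ValueError if a class has fewer than n rows
        sampled.extend(rng.sample(by_class[label], n))
    return sampled

# Made-up rows: 5 classes with 10 hypothetical image paths each.
rows = [(label, f"class_{label}/img_{i}.png") for label in range(5) for i in range(10)]
picked = sample_per_class(rows, n=2, seed=0)
print(len(picked))  # 10
```

Passing a fixed `seed` makes the selection repeatable, which can help when comparing models on the same sampled images.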

## Classwise Dataset

Training a conditional generative model did not yield the expected results, so I switched to a classwise dataset instead. To support this, the `QuickDrawDataset` class defined above accepts an optional `class_label` argument that keeps only the images belonging to that class. This allows me to train a separate generative model for each class and potentially achieve better results.


In [None]:
# Test set should contain 5000 images of class 4 (hand)
test = QuickDrawDataset('../dataset/test.csv', '../dataset/images', class_label=4)
len(test)

5000