# Image Classifier Training on Caltech-256 Subset

In this notebook, we train a Swin Transformer image classifier on a subset of the Caltech-256 dataset using the timm library and PyTorch. The main steps are dataset sampling and preprocessing, model training with k-fold cross-validation, and extraction of features and out-of-sample predicted probabilities.


In [None]:
import os
import shutil
import random

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import numpy as np
from torchvision.datasets import Caltech256, ImageFolder
from tqdm import tqdm
import timm
from sklearn.model_selection import StratifiedKFold

from PIL import Image

## Dataset Preparation

Before we start training, we need to prepare the dataset. 

This includes loading the dataset, selecting a subset of categories, and applying necessary transformations, which we define below.

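Note that we're not interested in any data augmentation techniques for this task, so we only apply deterministic preprocessing: conversion to RGB, resizing to 224×224, conversion to a tensor, and normalization with ImageNet channel statistics.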


In [None]:
def convert_to_rgb(img):
    """Return *img* in RGB mode so every image yields exactly three channels.

    The original version only converted grayscale ('L') images. Images in any
    other non-RGB mode (e.g. 'RGBA', 'P', 'CMYK') would then produce a tensor
    whose channel count does not match the 3-channel normalization statistics
    used in the transform pipeline. Images that are already RGB are returned
    unchanged (same object).
    """
    if img.mode != 'RGB':
        img = img.convert('RGB')
    return img


DATA_DIR = './data'  # Save notebook artifacts in this directory

# Deterministic preprocessing shared by training and evaluation (no augmentation).
# Images are resized directly to 224x224 (aspect ratio is not preserved). A
# CenterCrop(224) after this fixed-size resize would be a no-op, so it is omitted.
transform = transforms.Compose([
    transforms.Lambda(convert_to_rgb),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(  # ImageNet per-channel mean and std of RGB channels
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

### Selecting a Subset of Categories

We will select a specific subset of categories from the Caltech-256 dataset to create a smaller dataset for training via cross-validation.

The categories are restricted to aquatic and semi-aquatic animals: amphibians (frog, toad) and birds (ibis, penguin, swan).


In [None]:
# Categories to keep. Sorting fixes the label order:
# subset label i corresponds to categories_subset[i].
categories_subset = sorted(["080.frog", "256.toad", "158.penguin", "114.ibis-101", "207.swan"])


### Creating a Custom Dataset Class

We will create a custom dataset class named `Caltech256Subset` that inherits from Torchvision's `Caltech256` dataset class. This custom class keeps only the selected categories and copies those images into a new directory on the local filesystem. It also includes a random fraction (the clutter ratio) of the `257.clutter` images, each assigned a randomly chosen label from the subset.

This class is far from perfect, but it's good enough for our purposes.


In [None]:
class Caltech256Subset(Caltech256):
    """Caltech-256 restricted to a list of categories, materialized as an on-disk copy.

    On construction the selected images are copied (not moved) from the
    torchvision ``Caltech256`` layout into
    ``os.path.join(root, 'caltech256-subset', <category>, <file>)`` so the
    subset can later be reloaded with ``ImageFolder``.

    Each image of the ``"257.clutter"`` category is kept with probability
    ``clutter_ratio`` and filed under a randomly chosen subset category. The
    positions (in the full dataset) of the kept clutter images are recorded in
    ``self.clutter_indices``.

    Args:
        root: Directory handed to ``Caltech256``; the subset copy is written
            under ``os.path.join(root, 'caltech256-subset')``.
        categories_subset: Category folder names to keep (e.g. ``"080.frog"``);
            subset label ``i`` is ``categories_subset[i]``.
        transform, target_transform, download: Forwarded to ``Caltech256``.
        clutter_ratio: Probability of keeping each clutter image.
        seed: Seed of the private RNG used for clutter sampling and relabeling.
    """

    _CLUTTER_CATEGORY = "257.clutter"

    def __init__(self, root, categories_subset, transform=None, target_transform=None, download=False, clutter_ratio=0.1, seed=42):
        super().__init__(root, transform=transform, target_transform=target_transform, download=download)

        # Label mappings for the full dataset, as listed by Caltech256.
        self.original_categories = self.categories
        self.original_idx_to_class = dict(enumerate(self.original_categories))
        self.original_class_to_idx = {class_name: i for i, class_name in enumerate(self.original_categories)}

        # Label mappings for the subset: label i <-> categories_subset[i].
        self.categories_subset = categories_subset
        self.idx_to_class = dict(enumerate(categories_subset))
        self.class_to_idx = {class_name: i for i, class_name in enumerate(categories_subset)}

        self.subset_data_dir = os.path.join(root, 'caltech256-subset')
        os.makedirs(self.subset_data_dir, exist_ok=True)

        self.clutter_indices = []
        self.clutter_ratio = clutter_ratio

        # Private RNG: it yields the same draws as seeding the global `random`
        # module with `seed`, but does not reset random state for other code.
        self._rng = random.Random(seed)
        self._filter_and_move_data()

    def _filter_and_move_data(self):
        """Select subset/clutter images, copy them into ``subset_data_dir``, and remap labels.

        Despite the name, files are copied with ``shutil.copyfile``, not moved.
        Afterwards ``self.y`` and ``self.index`` describe only the selected
        images. The full-dataset versions are kept in ``self.y_original`` and
        ``self.index_original``. ``self.index_subset[k]`` is the full-dataset
        position of subset item ``k``.
        """
        clutter = self._CLUTTER_CATEGORY
        # Loop invariants, built once: categories to keep and the label pool for clutter.
        kept_categories = set(self.categories_subset)
        kept_categories.add(clutter)
        subset_labels = list(self.class_to_idx.values())

        subset_indices = []
        new_labels = []

        # Sanity checks: no source image is visited twice and no destination file name collides.
        seen_images = set()
        existing_subset_images = set()
        for idx, label in enumerate(self.y):
            class_name = self.categories[label]

            if class_name not in kept_categories:
                continue

            # Keep each clutter image with probability `clutter_ratio`.
            if class_name == clutter and self._rng.random() > self.clutter_ratio:
                continue

            if class_name in self.class_to_idx:
                new_label = self.class_to_idx[class_name]
            else:
                # Clutter images have no subset category; assign a random subset label.
                new_label = self._rng.choice(subset_labels)
                self.clutter_indices.append(idx)

            subset_indices.append(idx)
            new_labels.append(new_label)

            # File names follow the original "<category number>_<image number>.jpg" scheme.
            file_name = f"{label + 1:03d}_{self.index[idx]:04d}.jpg"
            src_path = os.path.join(self.root, "256_ObjectCategories", class_name, file_name)

            assert src_path not in seen_images
            seen_images.add(src_path)

            # The destination folder is the (possibly random, for clutter) subset label's category.
            dest_path = os.path.join(self.subset_data_dir, self.categories_subset[new_label], file_name)

            assert file_name not in existing_subset_images
            existing_subset_images.add(file_name)
            os.makedirs(os.path.dirname(dest_path), exist_ok=True)
            shutil.copyfile(src_path, dest_path)

        self.y_original = self.y
        self.y = new_labels

        self.index_original = self.index
        self.index_subset = subset_indices
        self.index = [self.index_original[idx] for idx in subset_indices]

    def __getitem__(self, index):
        """Return ``(image, target)`` for subset item ``index``, read from the copied file."""
        original_index = self.index_subset[index]
        target = self.y[index]
        original_y = self.y_original[original_index]
        file_name = f"{original_y + 1:03d}_{self.index_original[original_index]:04d}.jpg"
        path = os.path.join(self.subset_data_dir, self.categories_subset[target], file_name)

        # Context manager closes the underlying file once the RGB copy has been made.
        with Image.open(path) as raw_img:
            img = raw_img.convert('RGB')

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

### Setting Up the Data Directory and Removing Existing Directories

We will now set up the data directory, removing any existing directories before creating a new subset of the Caltech-256 dataset.

Then, we will load the dataset from the local filesystem with Torchvision's `ImageFolder` class, including the transformations we defined above for training and validation.


In [None]:
# # Uncomment the following lines to resample the subset of Caltech256 dataset, removing the existing subset.
# # Otherwise, go to the next cell to load the existing subset.


# if not os.path.exists(DATA_DIR):
#     os.makedirs(DATA_DIR)

# # Remove all directories in DATA_DIR/caltech256-subset if they exist
# for dir_name in os.listdir(DATA_DIR):
#     if dir_name.startswith("caltech256-subset"):
#         image_folder_root = os.path.join(DATA_DIR, dir_name)
#         for subdir_name in os.listdir(image_folder_root):
#             subdir_path = os.path.join(image_folder_root, subdir_name)
#             shutil.rmtree(subdir_path)
#             print("Removing", subdir_path)

# Create a subset of Caltech256 dataset, saving it in DATA_DIR/caltech256-subset     
# Caltech256Subset(
#     DATA_DIR,
#     categories_subset=categories_subset,
#     download=True,
#     clutter_ratio=0.025  # Adjust this value to control the proportion of clutter examples included
# )

In [None]:
# Reload the materialized Caltech-256 subset from disk. ImageFolder derives one
# class per sub-directory and stores each image's integer label in `targets`.
dataset_subset = ImageFolder(root=os.path.join(DATA_DIR, "caltech256-subset"), transform=transform)
labels_subset = np.asarray(dataset_subset.targets)  # per-image labels, used for stratified splitting
n_classes_subset = len(dataset_subset.classes)


## Model Training

With the dataset ready, we now define the hyperparameters for training and perform k-fold cross-validation to train Swin Transformer.


In [None]:
# Define hyperparameters

# Values noted for ResNet-50: batch size 64, learning rate 1e-3.
# NOTE(review): an earlier comment suggested lr 1e-4 for Swin-Transformer-patch-4-window-7-224,
# but 1e-5 is what this notebook uses; confirm which value is intended.
batch_size = 32
learning_rate = 1e-5
num_epochs = 10  # upper bound per fold; early stopping may end a fold sooner
num_folds = 5  # use 3 for faster training
patience = 2  # early-stopping patience, in epochs without validation-accuracy improvement

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_name = "swin_base_patch4_window7_224"
model_prefix = f"caltech256_subset_{model_name}"  # prefix of the per-fold checkpoint files

### K-Fold Cross-Validation and Training

In this notebook, we use k-fold cross-validation to train the model and extract out-of-sample predicted probabilities for all data points.

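During training, each fold's held-out split also serves as the validation set for early stopping and for choosing which checkpoint to save. This helps prevent overfitting. Because the same split is used to pick the checkpoint, the reported held-out accuracy (and the out-of-sample predictions computed later) may be slightly optimistic.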

While we're not specifically interested in the model artifacts themselves, we aim to get a general idea of whether the chosen model architecture is accurate enough for our purpose.

In [None]:
# Shuffled, stratified folds with a fixed seed. The training and feature-extraction
# cells each call kf.split(...) with the same data and labels, so they see identical splits.
kf = StratifiedKFold(n_splits=num_folds, random_state=42, shuffle=True)

In [None]:
# Train one model per fold. Each fold's held-out split (test_idx) is also used for
# early stopping and to choose which checkpoint is saved.
# NOTE(review): the same split later provides the "out-of-sample" predictions, so they
# are slightly optimistic; carving a separate validation split out of train_idx would avoid this.
for fold, (train_idx, test_idx) in enumerate(kf.split(dataset_subset, labels_subset)):
    print(f'Fold {fold + 1}/{num_folds}')
    print('-' * 10)

    # Define data loaders for current fold
    train_subset = torch.utils.data.Subset(dataset_subset, train_idx)
    val_subset = torch.utils.data.Subset(dataset_subset, test_idx)
    print(f'Train set size: {len(train_idx)}')
    print(f'Test set size: {len(test_idx)}')

    train_loader = torch.utils.data.DataLoader(train_subset, batch_size=batch_size, shuffle=True)
    val_loader = torch.utils.data.DataLoader(val_subset, batch_size=batch_size, shuffle=False)

    # Fresh pretrained model per fold, with a classification head sized for the subset
    model = timm.create_model(model_name, pretrained=True, num_classes=n_classes_subset)
    model = model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Start below any achievable accuracy so the first epoch always writes a checkpoint;
    # the feature-extraction cell loads one checkpoint per fold.
    best_val_accuracy = -1.0
    best_epoch = 0
    path = f'{model_prefix}_fold_{fold + 1}.pt'
    for epoch in range(num_epochs):
        print(f'Epoch {epoch + 1}/{num_epochs}')

        # Training pass
        model.train()
        running_loss = 0.0
        n_seen = 0
        for inputs, targets in tqdm(train_loader):
            inputs = inputs.to(device)
            targets = targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Weight by batch size so the (possibly smaller) last batch is not over-counted.
            running_loss += loss.item() * targets.size(0)
            n_seen += targets.size(0)
        print(f'Train loss: {running_loss / max(n_seen, 1):.4f}')

        # Evaluate on the held-out split of the current fold
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, targets in tqdm(val_loader):
                inputs = inputs.to(device)
                targets = targets.to(device)
                predicted = model(inputs).argmax(dim=1)
                total += targets.size(0)
                correct += (predicted == targets).sum().item()
        val_accuracy = 100 * correct / total if total else 0.0
        print(f'Validation accuracy: {val_accuracy:.2f}%')

        # Save model checkpoint if it is the best so far
        if val_accuracy > best_val_accuracy:
            print('Saving model...')
            torch.save(model.state_dict(), path)
            best_val_accuracy = val_accuracy
            best_epoch = epoch

        # Early stopping: stop after `patience` consecutive epochs without improvement
        if epoch - best_epoch >= patience:
            print(f'Early stopping at epoch {epoch + 1}')
            break

## Feature Extraction

After training, we compute out-of-sample predicted probabilities for the entire dataset: each image is scored by the model of the fold in which it was held out.

To keep things simple, we'll use the model trained on the first fold as a feature extractor for every image in the dataset.

These artifacts will be used by `Datalab` to inspect the dataset for potential issues.


We can also compute pretrained features without fine-tuning the model (this is not done in the cell below). Such features can be used to compare against the fine-tuned model on feature-based issue checks.


In [None]:
def _load_fold_model(fold_number):
    """Rebuild the architecture and load the checkpoint saved for ``fold_number`` (1-based).

    ``pretrained=False`` skips fetching pretrained weights that the checkpoint
    overwrites anyway (``load_state_dict`` is strict). ``map_location`` lets
    checkpoints saved on a GPU load on a CPU-only machine. Returns the model on
    ``device`` in eval mode.
    """
    fold_model = timm.create_model(model_name, pretrained=False, num_classes=n_classes_subset)
    state_dict = torch.load(f'{model_prefix}_fold_{fold_number}.pt', map_location=device)
    fold_model.load_state_dict(state_dict)
    fold_model.to(device)
    fold_model.eval()
    return fold_model


# The fold-1 model with its classification head removed is the feature extractor for every
# image. It is loaded once, instead of once per fold.
feature_extractor = _load_fold_model(1)
num_features = feature_extractor.num_features
feature_extractor.reset_classifier(0)  # forward() now returns pooled features of size num_features
feature_extractor.eval()  # re-apply eval mode in case reset_classifier created new submodules

features = np.zeros((len(dataset_subset), num_features))
pred_probs = np.zeros((len(dataset_subset), n_classes_subset))

for fold, (_, test_idx) in enumerate(kf.split(dataset_subset, labels_subset)):
    # The held-out split of this fold: its images were not used for this fold's gradient updates.
    test_subset = torch.utils.data.Subset(dataset_subset, test_idx)
    test_loader = torch.utils.data.DataLoader(test_subset, batch_size=batch_size, shuffle=False)

    fold_model = _load_fold_model(fold + 1)

    pred_probs_fold = []
    features_fold = []
    with torch.no_grad():
        # Single pass over the loader: class probabilities from this fold's model,
        # features from the shared fold-1 extractor.
        for inputs, _ in tqdm(test_loader):
            inputs = inputs.to(device)
            probs = nn.functional.softmax(fold_model(inputs), dim=1)
            pred_probs_fold.append(probs.cpu().numpy())
            features_fold.append(feature_extractor(inputs).cpu().numpy())
    pred_probs[test_idx] = np.concatenate(pred_probs_fold, axis=0)
    features[test_idx] = np.concatenate(features_fold, axis=0)

    del fold_model  # release this fold's weights before loading the next checkpoint

features_path = os.path.join(DATA_DIR, "features.npy")
pred_probs_path = os.path.join(DATA_DIR, "pred_probs.npy")

np.save(features_path, features)
np.save(pred_probs_path, pred_probs)