# Image Classifier Training on Caltech-256 Subset


[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/cleanlab/examples/blob/master/datalab_image_classification/train_image_classifier.ipynb)

In this notebook, we train a Swin Transformer model for image classification on a subset of the Caltech-256 dataset using the timm library and PyTorch.

We train the model with K-fold cross-validation and use it to produce out-of-sample predicted class probabilities for each image in our dataset, as well as a feature embedding of each image.

Please install the dependencies specified in this [requirements-train.txt](https://github.com/cleanlab/examples/blob/master/datalab_image_classification/requirements-train.txt) file before running the notebook.

In [None]:
import os

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import numpy as np
from torchvision.datasets import ImageFolder
from tqdm import tqdm
import timm
from sklearn.model_selection import StratifiedKFold

## Dataset Preparation

We'll download the dataset to disk and load it with Torchvision's `ImageFolder` dataset class (applying necessary transformations, which we define below).

In [None]:
# Fetch the dataset archive; -nc skips the download when the file is already present
!wget -nc 'https://cleanlab-public.s3.amazonaws.com/Datalab/caltech256-subset.tar.gz'
# Create the target directory if it does not exist yet, then unpack the archive into it
!mkdir -p data
!tar -xf caltech256-subset.tar.gz -C data/

In [None]:
def convert_to_rgb(img):
    """Return ``img`` as a 3-channel RGB image.

    The normalization applied later in the transform pipeline uses three
    per-channel statistics, so every image must have exactly three channels.
    Any image whose mode is not already ``'RGB'`` (grayscale ``'L'``, palette
    ``'P'``, ``'RGBA'``, ``'CMYK'``, ...) is converted.

    Args:
        img: An image object exposing a ``mode`` attribute and a
            ``convert(mode)`` method (e.g. a PIL image).

    Returns:
        ``img`` itself when it is already RGB, otherwise the result of
        ``img.convert('RGB')``.
    """
    if img.mode != 'RGB':
        img = img.convert('RGB')
    return img


# ImageNet stats for normalization (mean and std) of RGB channels
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

# Preprocessing steps, applied to every image in the order listed
preprocessing_steps = [
    transforms.Lambda(convert_to_rgb),
    transforms.Resize((224, 224)),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
]
transform = transforms.Compose(preprocessing_steps)

DATA_DIR = './data'  # Save notebook artifacts in this directory

# Load data from disk; ImageFolder is pointed at the extracted dataset directory
dataset_root = os.path.join(DATA_DIR, "caltech256-subset")
dataset = ImageFolder(dataset_root, transform=transform)

# Integer target of every image, and the number of distinct classes
labels = np.array(dataset.targets)
n_classes = len(dataset.classes)

## Model Training

With the dataset ready, we now define the hyperparameters for training and perform k-fold cross-validation to train a Swin Transformer.


In [None]:
# Training hyperparameters

# Batch size: 64 was noted for Resnet50, 32 for Swin-Transformer-patch-4-window-7-224
batch_size = 32
# NOTE(review): the earlier note listed 0.001 for Resnet50 and 0.0001 for
# Swin-Transformer-patch-4-window-7-224, but the value actually set is 1e-5.
# Confirm which learning rate is intended.
learning_rate = 1e-5
num_epochs = 10
num_folds = 5  # Use 3 for faster training
patience = 2  # read by the early-stopping check in the training loop

# Run on the GPU whenever PyTorch can see one, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

model_name = "swin_base_patch4_window7_224"
# Prefix of the checkpoint file names written during training
model_prefix = f"caltech256_subset_{model_name}"

### K-Fold Cross-Validation and Training

In this notebook, we use k-fold cross-validation to train the model and extract out-of-sample predicted probabilities for all data points.

During the training process, we'll just use the validation accuracy on the held-out fold to allow early stopping for each fold. This approach helps us prevent overfitting and obtain a better estimate of the model's performance.

While we're not specifically interested in the model artifacts themselves, we aim to get a general idea of whether the chosen model architecture is accurate enough for our purpose.

**Warning**: The training cell below may take a long time to execute and should be run with a GPU.

In [None]:
# Stratified, shuffled K-fold splitter over `num_folds` folds.
# The fixed integer `random_state` matters: `kf.split` is called once for
# training and again later when collecting predictions, and both calls must
# produce the same folds so each checkpoint is evaluated on its held-out data.
kf = StratifiedKFold(n_splits=num_folds, shuffle=True, random_state=42)

In [None]:
# Train one model per fold and keep, for each fold, the checkpoint with the
# best accuracy on that fold's held-out split.
for fold, (train_idx, test_idx) in enumerate(kf.split(dataset, labels)):
    print(f'Fold {fold + 1}/{num_folds}')
    print('-' * 10)

    # Define data loaders for current fold
    train_subset = torch.utils.data.Subset(dataset, train_idx)
    val_subset = torch.utils.data.Subset(dataset, test_idx)
    # Print train and validation set sizes
    print(f'Train set size: {len(train_idx)}')
    print(f'Test set size: {len(test_idx)}')

    train_loader = torch.utils.data.DataLoader(train_subset, batch_size=batch_size, shuffle=True)
    val_loader = torch.utils.data.DataLoader(val_subset, batch_size=batch_size, shuffle=False)

    # Initialize a fresh pretrained model for the current fold
    model = timm.create_model(model_name, pretrained=True, num_classes=n_classes)
    model = model.to(device)

    # Define loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Checkpoint file for this fold; it is read back when predictions are collected
    path = f'{model_prefix}_fold_{fold + 1}.pt'

    # Start below any reachable accuracy so the first epoch always writes a
    # checkpoint, even if the model scores 0% on the held-out split.
    best_val_accuracy = float('-inf')
    best_epoch = 0

    # Train model for current fold
    for epoch in range(num_epochs):
        print(f'Epoch {epoch + 1}/{num_epochs}')
        model.train()
        running_loss = 0.0
        n_seen = 0
        for inputs, targets in tqdm(train_loader):
            inputs = inputs.to(device)
            targets = targets.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            # Weight by batch size so a smaller final batch does not skew the mean
            running_loss += loss.item() * targets.size(0)
            n_seen += targets.size(0)
        # Report the mean loss over the whole epoch rather than the last batch only
        print(f'Train loss: {running_loss / max(n_seen, 1):.4f}')

        # Evaluate model on the held-out (validation) split of the current fold
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, targets in tqdm(val_loader):
                inputs = inputs.to(device)
                targets = targets.to(device)
                predicted = model(inputs).argmax(dim=1)
                total += targets.size(0)
                correct += (predicted == targets).sum().item()
        val_accuracy = 100 * correct / total
        print(f'Validation accuracy: {val_accuracy:.2f}%')

        # Save model checkpoint if it is the best so far
        if val_accuracy > best_val_accuracy:
            print('Saving model...')
            torch.save(model.state_dict(), path)
            best_val_accuracy = val_accuracy
            best_epoch = epoch

        # Early stopping: give up once more than `patience` epochs have passed
        # since the best epoch without any improvement
        if epoch - best_epoch > patience:
            print(f'Early stopping at epoch {epoch + 1}')
            break

## Getting Predicted Class Probabilities and Extracting Feature Embeddings 

After training, we will compute predicted class probabilities for the entire dataset using the trained models from each fold.

In addition, to keep things simple, we'll use the model trained on the first fold as a feature extractor to obtain embeddings for every image in the dataset.

These artifacts will be used by `Datalab` to inspect the dataset for potential issues, so we save them to files used in the next notebook.

In [None]:
def load_fold_model(fold_number):
    """Rebuild the classifier and restore the checkpoint saved for a fold.

    Args:
        fold_number: 1-based fold number, as used in the checkpoint file names.

    Returns:
        The restored model, moved to ``device`` and switched to eval mode.
    """
    # pretrained=False: every weight is overwritten by the checkpoint right
    # away, so fetching the pretrained weights first would be wasted work.
    fold_model = timm.create_model(model_name, pretrained=False, num_classes=n_classes)
    checkpoint_path = f'{model_prefix}_fold_{fold_number}.pt'
    # map_location lets a checkpoint written on a GPU be restored on a CPU-only machine
    fold_model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    fold_model.to(device)
    fold_model.eval()
    return fold_model


# The model trained on the first fold serves as the feature extractor for every
# image. It is loaded once, outside the loop, since it is the same for all folds.
feature_extractor = load_fold_model(1)
num_features = feature_extractor.num_features
# Remove the classification head so the forward pass yields the feature
# embedding (one vector of length `num_features` per image) instead of logits
feature_extractor.reset_classifier(0)

features = np.zeros((len(dataset), num_features))
pred_probs = np.zeros((len(dataset), n_classes))

for fold, (_, test_idx) in enumerate(kf.split(dataset, labels)):
    # Held-out split of the current fold: the model restored below never saw
    # these images during training, so its predictions are out-of-sample.
    test_subset = torch.utils.data.Subset(dataset, test_idx)
    test_loader = torch.utils.data.DataLoader(test_subset, batch_size=batch_size, shuffle=False)

    model = load_fold_model(fold + 1)

    pred_probs_fold = []
    features_fold = []
    with torch.no_grad():
        # Single pass over the data: each batch is read and preprocessed once
        # and then fed to both the fold's classifier and the feature extractor.
        for inputs, _ in tqdm(test_loader):
            inputs = inputs.to(device)
            # Predicted probabilities
            probs = nn.functional.softmax(model(inputs), dim=1)
            pred_probs_fold.append(probs.cpu().numpy())
            # Feature embeddings
            features_fold.append(feature_extractor(inputs).cpu().numpy())

    # The loader is not shuffled, so rows line up with `test_idx`
    pred_probs[test_idx] = np.concatenate(pred_probs_fold, axis=0)
    features[test_idx] = np.concatenate(features_fold, axis=0)

# Save the artifacts to disk for use in the next notebook
features_path = os.path.join(DATA_DIR, "features.npy")
pred_probs_path = os.path.join(DATA_DIR, "pred_probs.npy")

np.save(features_path, features)
np.save(pred_probs_path, pred_probs)