In [1]:
"""
This file contains the code to run the baseline experiments.

More specifically, we are investigating the performance of basic models on the CIFAR-10 and CIFAR-10H datasets. The tasks for these datasets are multi-class classification.

The basic models include:
    * ResNet-50
    * VGG-16
    * Logistic Regression
    * Random Forest
    * XGBoost
"""

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms, models
import torch.nn as nn
import torch.optim as optim
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
from sklearn.preprocessing import StandardScaler
import os
from typing import Tuple

# Loading Data

In [2]:
# Load CIFAR-10 dataset and return train, validation, and test DataLoaders
def load_cifar10() -> Tuple[Dataset, Dataset, Dataset]:
    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.ConvertImageDtype(torch.float32),
        ]
    )
    full_dataset = datasets.CIFAR10(root="../data/cifar-10", train=True, download=True, transform=transform)
    # we use the test dataset for training, similar to the CIFAR-10H experiment
    train_dataset = datasets.CIFAR10(root="../data/cifar-10", train=False, download=True, transform=transform)

    # This dataset will be used for testing and validation.
    #   30% of the data will be used for validation, and 70% for testing.
    test_size = int(0.7 * len(full_dataset))
    val_size = len(full_dataset) - test_size
    test_dataset, val_dataset = torch.utils.data.random_split(
        full_dataset, [test_size, val_size], generator=torch.Generator().manual_seed(229)
    )

    return train_dataset, test_dataset, val_dataset

In [3]:
cifar10_train_dataset, cifar10_test_dataset, cifar10_val_dataset = load_cifar10()  # Changed variable name to reflect split
cifar10_train_loader = DataLoader(cifar10_train_dataset, batch_size=128, shuffle=True)
cifar10_test_loader = DataLoader(cifar10_test_dataset, batch_size=128, shuffle=False)
cifar10_val_loader = DataLoader(cifar10_val_dataset, batch_size=128, shuffle=False)
print(
    f"CIFAR-10 dataset loaded with {len(cifar10_train_dataset)} training, {len(cifar10_test_dataset)} test, and {len(cifar10_val_dataset)} validation samples"
)

Files already downloaded and verified
Files already downloaded and verified
CIFAR-10 dataset loaded with 10000 training, 35000 test, and 15000 validation samples


# Training
Training is done on the CIFAR-10H dataset. Evaluation is done on the CIFAR-10 train set, which we use as a test set.

In [12]:
def train_model(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    criterion: nn.Module,
    optimizer: optim.Optimizer,
    num_epochs: int,
) -> nn.Module:
    device = torch.device(
        "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
    )
    print(f"Using device: {device}")
    model = model.to(device)

    best_val_acc = 0.0

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        for i, (images, labels) in enumerate(train_loader):
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        # Validation phase
        model.eval()
        correct = 0
        total = 0
        val_loss = 0.0
        with torch.no_grad():
            for images, labels in val_loader:
                images = images.to(device)
                labels = labels.to(device)
                outputs = model(images)

                if len(labels.shape) > 1:  # For soft labels
                    _, predicted = torch.max(outputs.data, 1)
                    _, labels = torch.max(labels, 1)
                else:  # For hard labels
                    _, predicted = torch.max(outputs.data, 1)

                total += labels.size(0)
                correct += (predicted == labels).sum().item()
                val_loss += criterion(outputs, labels).item()

        accuracy = 100 * correct / total
        val_loss = val_loss / len(val_loader)
        print(
            f"Epoch [{epoch+1}/{num_epochs}] Train Loss: {running_loss/len(train_loader):.4f}, Validation Loss: {val_loss:.4f}, Accuracy: {accuracy:.2f}%"
        )

        # Save model if validation accuracy improves
        if accuracy > best_val_acc:
            best_val_acc = accuracy
            torch.save(model.state_dict(), f"models/{model.__class__.__name__}_cifar10h.pth")
            print(f"Saved model with improved validation accuracy: {accuracy:.2f}%")

    return model

## Training Neural Networks

In [13]:
def train_nn_model(
    model, cifar10h_loader: DataLoader, cifar10_val_loader: DataLoader, num_epochs: int = 20, lr: float = 0.001
) -> list:
    print(f"\nTraining {model.__class__.__name__} on CIFAR-10H...")

    # Adjust the final layer for CIFAR-10
    if isinstance(model, models.ResNet):
        num_ftrs = model.fc.in_features
        model.fc = nn.Linear(num_ftrs, 10)
    elif isinstance(model, models.VGG):
        num_ftrs = model.classifier[-1].in_features
        model.classifier[-1] = nn.Linear(num_ftrs, 10)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    model = train_model(
        model=model,
        train_loader=cifar10h_loader,
        val_loader=cifar10_val_loader,
        criterion=criterion,
        optimizer=optimizer,
        num_epochs=num_epochs,
    )

def evaluate_nn_model(model, cifar10_test_loader):
    model.load_state_dict(
        torch.load(f"models/{model.__class__.__name__}_cifar10h.pth", weights_only=True)
    )
    model.eval()

    correct = 0
    total = 0
    device = next(model.parameters()).device
    with torch.no_grad():
        for images, labels in cifar10_test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f"{model.__class__.__name__} Accuracy on CIFAR-10 test set: {accuracy:.2f}%")

In [15]:
from generate_soft_labels import create_soft_label_dataloader
from soft_label_predictor import ImageHardLabelToSoftLabelModel

device = torch.device(
    "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
    )

# Load the trained model
model = ImageHardLabelToSoftLabelModel().to(device)
model.load_state_dict(torch.load("models/best_model.pt"))
model.eval()

In [16]:
# Sanity check - evaluate model trained on CIFAR-10H augmenting CIFAR-10 equivalent set. Since this is basically reproducing CIFAR-10H, we expect
# the model to perform similarly (little worse) than just training with CIFAR-10H directly.

# Convert to DataLoader with predicted soft labels
soft_label_dataloader = create_soft_label_dataloader(model, cifar10_train_loader, batch_size=128, device=device)

resnet_model = models.resnet34(weights=models.ResNet34_Weights.DEFAULT)
train_nn_model(resnet_model, soft_label_dataloader, cifar10_val_loader, lr=0.01)
evaluate_nn_model(resnet_model, cifar10_test_loader)


Training ResNet on CIFAR-10H...
Using device: mps
Epoch [1/20] Train Loss: 1.2828, Validation Loss: 1.2012, Accuracy: 60.23%
Saved model with improved validation accuracy: 60.23%
Epoch [2/20] Train Loss: 0.8598, Validation Loss: 1.0033, Accuracy: 66.71%
Saved model with improved validation accuracy: 66.71%
Epoch [3/20] Train Loss: 0.6887, Validation Loss: 0.9057, Accuracy: 69.31%
Saved model with improved validation accuracy: 69.31%
Epoch [4/20] Train Loss: 0.5440, Validation Loss: 0.8316, Accuracy: 72.31%
Saved model with improved validation accuracy: 72.31%
Epoch [5/20] Train Loss: 0.4893, Validation Loss: 0.8644, Accuracy: 71.87%
Epoch [6/20] Train Loss: 0.4712, Validation Loss: 0.8267, Accuracy: 73.95%
Saved model with improved validation accuracy: 73.95%
Epoch [7/20] Train Loss: 0.3672, Validation Loss: 0.8603, Accuracy: 72.97%
Epoch [8/20] Train Loss: 0.3285, Validation Loss: 0.8966, Accuracy: 71.89%
Epoch [9/20] Train Loss: 0.2946, Validation Loss: 0.8479, Accuracy: 73.67%
Epoc

  return self.fget.__get__(instance, owner)()


ResNet Accuracy on CIFAR-10 test set: 76.19%
