In [1]:
"""
This file contains the code to run experiments with artificial soft labels.

The experiment is:
    * Train a soft label predictor model on CIFAR-10H
    * Generate artificial soft labels for CIFAR-10
    * Train a model on CIFAR-10 with the artificial soft labels + CIFAR-10H
    * Evaluate the model on CIFAR-10
"""

import os
from typing import Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, ConcatDataset
from torchvision import datasets, transforms, models

# Loading Data

In [2]:
# Load CIFAR-10 and return the (train, test, validation) Datasets - not DataLoaders
def load_cifar10(
    root: str = "../data/cifar-10",
    test_fraction: float = 0.7,
    seed: int = 229,
) -> Tuple[Dataset, Dataset, Dataset]:
    """Build the CIFAR-10 splits used for the CIFAR-10H style experiments.

    The official CIFAR-10 *test* split is used for training (similar to the
    CIFAR-10H experiment), while the official *train* split is randomly divided
    into a test part and a validation part.

    Args:
        root: Directory where torchvision stores / downloads CIFAR-10.
        test_fraction: Fraction of the official train split assigned to the
            returned test set; the remainder becomes the validation set.
        seed: Seed for the generator driving the random split, so the split is
            reproducible across runs.

    Returns:
        ``(train_dataset, test_dataset, val_dataset)`` - note the order: test
        comes before validation.

    Raises:
        ValueError: If ``test_fraction`` is outside ``[0, 1]``.
    """
    # Fail early with a clear message instead of a size-mismatch error from random_split.
    if not 0.0 <= test_fraction <= 1.0:
        raise ValueError(f"test_fraction must be in [0, 1], got {test_fraction}")

    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.ConvertImageDtype(torch.float32),
        ]
    )
    full_dataset = datasets.CIFAR10(root=root, train=True, download=True, transform=transform)
    # we use the test dataset for training, similar to the CIFAR-10H experiment
    train_dataset = datasets.CIFAR10(root=root, train=False, download=True, transform=transform)

    # The official train split is used for testing and validation:
    #   `test_fraction` of it for testing, the rest for validation.
    test_size = int(test_fraction * len(full_dataset))
    val_size = len(full_dataset) - test_size
    test_dataset, val_dataset = torch.utils.data.random_split(
        full_dataset, [test_size, val_size], generator=torch.Generator().manual_seed(seed)
    )

    return train_dataset, test_dataset, val_dataset

In [3]:
# Unpack the three CIFAR-10 splits in the order load_cifar10() returns them.
cifar10_train_dataset, cifar10_test_dataset, cifar10_val_dataset = load_cifar10()

# Only the training loader is shuffled; test/validation keep a fixed order.
cifar10_train_loader = DataLoader(dataset=cifar10_train_dataset, batch_size=128, shuffle=True)
cifar10_test_loader = DataLoader(dataset=cifar10_test_dataset, batch_size=128, shuffle=False)
cifar10_val_loader = DataLoader(dataset=cifar10_val_dataset, batch_size=128, shuffle=False)

# Report the split sizes as a quick check of the partition.
print(f"CIFAR-10 dataset loaded with {len(cifar10_train_dataset)} training, {len(cifar10_test_dataset)} test, and {len(cifar10_val_dataset)} validation samples")

Files already downloaded and verified
Files already downloaded and verified
CIFAR-10 dataset loaded with 10000 training, 35000 test, and 15000 validation samples


In [4]:
# Load CIFAR-10 and return the (augment, train, test, validation) Datasets - not DataLoaders
def load_cifar10_experiment(
    root: str = "../data/cifar-10",
    augment_fraction: float = 0.7,
    seed: int = 229,
) -> Tuple[Dataset, Dataset, Dataset, Dataset]:
    """Build the CIFAR-10 splits for the soft-label augmentation experiment.

    The official CIFAR-10 *test* split is used for training (similar to the
    CIFAR-10H experiment). The official *train* split is randomly divided into
    an augmentation pool, a test set and a validation set.

    Args:
        root: Directory where torchvision stores / downloads CIFAR-10.
        augment_fraction: Fraction of the official train split assigned to the
            augmentation pool; the remainder is halved between validation and
            test (test receives the extra sample when the remainder is odd).
        seed: Seed for the generator driving the random split, so the split is
            reproducible across runs.

    Returns:
        ``(augment_dataset, train_dataset, test_dataset, val_dataset)``.

    Raises:
        ValueError: If ``augment_fraction`` is outside ``[0, 1]``.
    """
    # Fail early with a clear message instead of a size-mismatch error from random_split.
    if not 0.0 <= augment_fraction <= 1.0:
        raise ValueError(f"augment_fraction must be in [0, 1], got {augment_fraction}")

    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.ConvertImageDtype(torch.float32),
        ]
    )
    full_dataset = datasets.CIFAR10(root=root, train=True, download=True, transform=transform)
    # we use the test dataset for training, similar to the CIFAR-10H experiment
    train_dataset = datasets.CIFAR10(root=root, train=False, download=True, transform=transform)

    # The official train split is used for augmenting, testing, and validation.
    augment_size = int(augment_fraction * len(full_dataset))
    val_size = (len(full_dataset) - augment_size) // 2
    test_size = len(full_dataset) - augment_size - val_size
    augment_dataset, test_dataset, val_dataset = torch.utils.data.random_split(
        full_dataset, [augment_size, test_size, val_size], generator=torch.Generator().manual_seed(seed)
    )

    return augment_dataset, train_dataset, test_dataset, val_dataset

In [5]:
# Unpack the four hard-label splits in the order load_cifar10_experiment() returns them.
cifar10_hard_augment_dataset, cifar10_hard_train_dataset, cifar10_hard_test_dataset, cifar10_hard_val_dataset = (
    load_cifar10_experiment()
)

# Hard-label training set for the baseline: augmentation pool followed by the training images.
combined_hard_dataset = ConcatDataset([cifar10_hard_augment_dataset, cifar10_hard_train_dataset])

# Only the combined training loader is shuffled; the augment/test/validation loaders keep dataset order.
cifar10_hard_augment_loader = DataLoader(dataset=cifar10_hard_augment_dataset, batch_size=128, shuffle=False)
cifar10_hard_combined_loader = DataLoader(dataset=combined_hard_dataset, batch_size=128, shuffle=True)
cifar10_hard_test_loader = DataLoader(dataset=cifar10_hard_test_dataset, batch_size=128, shuffle=False)
cifar10_hard_val_loader = DataLoader(dataset=cifar10_hard_val_dataset, batch_size=128, shuffle=False)

# Report the split sizes as a quick check of the partition.
print(f"CIFAR-10 dataset loaded with {len(cifar10_hard_augment_dataset)} augment, {len(cifar10_hard_train_dataset)} training, {len(cifar10_hard_test_dataset)} test, and {len(cifar10_hard_val_dataset)} validation samples")

Files already downloaded and verified
Files already downloaded and verified
CIFAR-10 dataset loaded with 35000 augment, 10000 training, 7500 test, and 7500 validation samples


# Training
Training is done on the CIFAR-10H dataset (the CIFAR-10 test images). Validation and evaluation are done on held-out splits of the CIFAR-10 train set, which we use as validation and test data.

In [6]:
def train_model(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    criterion: nn.Module,
    optimizer: optim.Optimizer,
    num_epochs: int,
) -> nn.Module:
    """Train ``model`` and checkpoint the weights with the best validation accuracy.

    The device is picked automatically (CUDA, then MPS, then CPU). After each
    epoch the model is evaluated on ``val_loader``; whenever the validation
    accuracy improves, the state dict is written to
    ``models/<ModelClassName>_cifar10h.pth``. The file name depends only on the
    model class, so a later run with the same architecture overwrites it.

    Args:
        model: Network to train; moved to the selected device.
        train_loader: Yields ``(images, labels)`` batches for training. Labels
            are passed to ``criterion`` unchanged.
        val_loader: Yields ``(images, labels)`` batches for validation. Labels
            with more than one dimension are treated as soft labels and reduced
            to their argmax class before accuracy and loss are computed.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer already bound to ``model``'s parameters.
        num_epochs: Number of passes over ``train_loader``.

    Returns:
        The model with the weights of the *last* epoch (not necessarily the best
        checkpoint, which lives on disk).
    """
    device = torch.device(
        "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
    )
    print(f"Using device: {device}")
    model = model.to(device)

    # Make sure the checkpoint directory exists so the first save cannot fail.
    checkpoint_path = f"models/{model.__class__.__name__}_cifar10h.pth"
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

    best_val_acc = 0.0

    for epoch in range(num_epochs):
        # Training phase
        model.train()
        running_loss = 0.0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        # Validation phase
        model.eval()
        correct = 0
        total = 0
        val_loss = 0.0
        with torch.no_grad():
            for images, labels in val_loader:
                images = images.to(device)
                labels = labels.to(device)
                outputs = model(images)
                predicted = outputs.argmax(dim=1)

                # Soft labels arrive as per-class distributions; compare against their top class.
                if labels.dim() > 1:
                    labels = labels.argmax(dim=1)

                total += labels.size(0)
                correct += (predicted == labels).sum().item()
                val_loss += criterion(outputs, labels).item()

        # Guard the averages so an empty loader reports 0 instead of raising ZeroDivisionError.
        accuracy = 100 * correct / total if total else 0.0
        val_loss = val_loss / max(len(val_loader), 1)
        train_loss = running_loss / max(len(train_loader), 1)
        print(
            f"Epoch [{epoch+1}/{num_epochs}] Train Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}, Accuracy: {accuracy:.2f}%"
        )

        # Save model if validation accuracy improves
        if accuracy > best_val_acc:
            best_val_acc = accuracy
            torch.save(model.state_dict(), checkpoint_path)
            print(f"Saved model with improved validation accuracy: {accuracy:.2f}%")

    return model

## Training Neural Networks

In [7]:
def train_nn_model(model, cifar10h_loader, cifar10_val_loader, num_epochs=20):
    """Adapt a torchvision classifier to 10 classes and train it with Adam.

    For ``models.ResNet`` and ``models.VGG`` instances the final linear layer is
    replaced (in place) by a fresh 10-output layer; other model types are left
    untouched. Training uses cross-entropy loss and Adam with lr=0.001, and is
    delegated to ``train_model``.

    Args:
        model: Network to train; its head may be replaced in place.
        cifar10h_loader: Training loader yielding ``(images, labels)`` batches.
        cifar10_val_loader: Validation loader yielding ``(images, labels)`` batches.
        num_epochs: Number of training epochs.

    Returns:
        The trained model as returned by ``train_model``. Earlier versions
        computed this value and discarded it; callers that ignore the return
        value are unaffected.
    """
    print(f"\nTraining {model.__class__.__name__} on CIFAR-10H...")

    # Adjust the final layer for CIFAR-10
    if isinstance(model, models.ResNet):
        num_ftrs = model.fc.in_features
        model.fc = nn.Linear(num_ftrs, 10)
    elif isinstance(model, models.VGG):
        num_ftrs = model.classifier[-1].in_features
        model.classifier[-1] = nn.Linear(num_ftrs, 10)

    criterion = nn.CrossEntropyLoss()
    # Built after the head swap so the new layer's parameters are optimized too.
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    return train_model(
        model=model,
        train_loader=cifar10h_loader,
        val_loader=cifar10_val_loader,
        criterion=criterion,
        optimizer=optimizer,
        num_epochs=num_epochs,
    )


def evaluate_nn_model(model, cifar10_test_loader):
    """Load the best checkpoint for ``model`` and print its top-1 accuracy.

    The weights are read from ``models/<ModelClassName>_cifar10h.pth`` (the file
    written by ``train_model``) and loaded into ``model`` in place.

    Args:
        model: Network whose architecture matches the saved state dict.
        cifar10_test_loader: Loader yielding ``(images, labels)`` batches with
            hard (class-index) labels.
    """
    # Evaluate on whatever device the model already lives on.
    device = next(model.parameters()).device
    model.load_state_dict(
        torch.load(
            f"models/{model.__class__.__name__}_cifar10h.pth",
            # Remap storages so a checkpoint saved on another device type still loads.
            map_location=device,
            weights_only=True,
        )
    )
    model.eval()

    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in cifar10_test_loader:
            images, labels = images.to(device), labels.to(device)
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # Report 0% for an empty loader instead of raising ZeroDivisionError.
    accuracy = 100 * correct / total if total else 0.0
    print(f"{model.__class__.__name__} Accuracy on CIFAR-10 test set: {accuracy:.2f}%")

In [8]:
from generate_soft_labels import create_soft_label_dataloader, create_soft_label_dataset
from soft_label_predictor import ImageHardToSoftLabelModel

# Same device preference as train_model: CUDA, then MPS, then CPU.
device = torch.device(
    "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
)

# Load the trained soft-label predictor. map_location remaps the stored tensors onto
# `device`, so a checkpoint saved on a different device type still loads here.
model = ImageHardToSoftLabelModel().to(device)
model.load_state_dict(
    torch.load("models/soft_label_model.pt", map_location=device, weights_only=True)
)
# Switch to inference mode; as the cell's last expression this also displays the architecture.
model.eval()

ImageHardToSoftLabelModel(
  (image_encoder): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Sequential(
      (0): ResidualBlock(
        (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (shortcut): Sequential(
          (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1))
          (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (4): Sequential(
      (0): ResidualBlock(
    

In [9]:
# Sanity check - train on the CIFAR-10H-equivalent image set (cifar10_train_loader) using labels
# produced by the soft label predictor. Since this basically reproduces CIFAR-10H, we expect the
# model to perform similarly to (or slightly worse than) training on CIFAR-10H directly.

# Convert to DataLoader with predicted soft labels
# NOTE(review): create_soft_label_dataloader is defined in generate_soft_labels; presumably it runs
# `model` over the loader and pairs each image with the predicted distribution - confirm there.
soft_label_dataloader = create_soft_label_dataloader(
    model, cifar10_train_loader, batch_size=128, device=device
)
print(
    f"Training on {len(soft_label_dataloader.dataset)} samples and validating on {len(cifar10_val_loader.dataset)} samples"
)

# Start from torchvision's default pretrained ResNet-34 weights; train_nn_model swaps the head
# for 10 classes, and evaluate_nn_model reloads the best checkpoint before scoring.
resnet_model = models.resnet34(weights=models.ResNet34_Weights.DEFAULT)
train_nn_model(resnet_model, soft_label_dataloader, cifar10_val_loader)
evaluate_nn_model(resnet_model, cifar10_test_loader)

Training on 10000 samples and validating on 15000 samples

Training ResNet on CIFAR-10H...
Using device: mps
Epoch [1/20] Train Loss: 1.2972, Validation Loss: 1.0189, Accuracy: 64.85%
Saved model with improved validation accuracy: 64.85%
Epoch [2/20] Train Loss: 0.8447, Validation Loss: 1.0859, Accuracy: 64.45%
Epoch [3/20] Train Loss: 0.7351, Validation Loss: 1.0805, Accuracy: 63.88%
Epoch [4/20] Train Loss: 0.5767, Validation Loss: 0.8231, Accuracy: 72.91%
Saved model with improved validation accuracy: 72.91%
Epoch [5/20] Train Loss: 0.4660, Validation Loss: 1.0465, Accuracy: 67.19%
Epoch [6/20] Train Loss: 0.4312, Validation Loss: 0.9718, Accuracy: 69.98%
Epoch [7/20] Train Loss: 0.3207, Validation Loss: 0.8505, Accuracy: 73.70%
Saved model with improved validation accuracy: 73.70%
Epoch [8/20] Train Loss: 0.2987, Validation Loss: 0.9657, Accuracy: 70.81%
Epoch [9/20] Train Loss: 0.2967, Validation Loss: 0.9251, Accuracy: 71.61%
Epoch [10/20] Train Loss: 0.2800, Validation Loss: 0.8

In [10]:
# Dataset wrapper that serves hard and soft labels in one consistent format
class CIFAR10LabelDataset(Dataset):
    """Wrap an ``(image, hard_label)`` dataset so every label is a float32 class distribution.

    Without ``soft_labels`` the wrapped dataset's integer label is converted to a
    one-hot vector. With ``soft_labels`` the wrapped label is ignored and row
    ``idx`` of ``soft_labels`` is returned instead. Either way the label is a
    float32 tensor, so both kinds can be mixed in one ``ConcatDataset``.

    Args:
        dataset: Indexable dataset returning ``(image, label)`` pairs.
        soft_labels: Optional sequence/array/tensor with one label distribution
            per sample, in the same order as ``dataset``.
        num_classes: Width of the one-hot vectors built for hard labels.

    Raises:
        ValueError: If ``soft_labels`` is given but its length differs from the
            dataset's length.
    """

    def __init__(self, dataset, soft_labels=None, num_classes=10):
        # A length mismatch would silently pair images with the wrong labels (or fail late).
        if soft_labels is not None and len(soft_labels) != len(dataset):
            raise ValueError(
                f"soft_labels has {len(soft_labels)} rows but dataset has {len(dataset)} samples"
            )
        self.dataset = dataset
        self.soft_labels = soft_labels
        self.num_classes = num_classes

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        image, label = self.dataset[idx]
        if self.soft_labels is None:
            # Convert hard labels to one-hot; as_tensor accepts ints, numpy ints and tensors.
            label = F.one_hot(
                torch.as_tensor(label, dtype=torch.long), num_classes=self.num_classes
            ).float()
        else:
            # Normalize to float32 (matching the one-hot branch); clone so the returned
            # tensor never aliases the stored soft label array.
            label = torch.as_tensor(self.soft_labels[idx], dtype=torch.float32).clone()
        return image, label

In [11]:
# Experiment - evaluate a model trained on CIFAR-10H soft labels plus a larger CIFAR-10 set that has
# been augmented with predicted soft labels.

# Soft labels for CIFAR-10H, stored as a NumPy array and cast to float32.
# NOTE(review): assumes the rows are in the same order as the CIFAR-10 test split wrapped below - confirm.
cifar10h_probs_path = "../data/cifar-10h/cifar10h-probs.npy"
cifar10h_probs = np.load(cifar10h_probs_path).astype(np.float32)

# Training images paired with the CIFAR-10H soft labels (the dataset's own hard labels are ignored).
cifar10_soft_label_dataset = CIFAR10LabelDataset(cifar10_hard_train_dataset, cifar10h_probs)
# Augmentation pool labelled by the soft label predictor.
# NOTE(review): create_soft_label_dataset is defined in generate_soft_labels; its output format is not visible here.
augmented_dataset = create_soft_label_dataset(model, cifar10_hard_augment_loader, device)

# This dataset is fully soft labels
combined_train_dataset = ConcatDataset([augmented_dataset, cifar10_soft_label_dataset])
combined_train_loader = DataLoader(combined_train_dataset, batch_size=128, shuffle=True)

print(
    f"Training on {len(combined_train_loader.dataset)} samples and validating on {len(cifar10_hard_val_loader.dataset)} samples"
)

# NOTE(review): the output saved with this cell reports 50 epochs while the call below requests 30;
# the output looks stale - re-run the cell to refresh it.
resnet_model = models.resnet34(weights=models.ResNet34_Weights.DEFAULT)
train_nn_model(resnet_model, combined_train_loader, cifar10_hard_val_loader, num_epochs=30)
evaluate_nn_model(resnet_model, cifar10_hard_test_loader)

Training on 45000 samples and validating on 7500 samples

Training ResNet on CIFAR-10H...
Using device: mps
Epoch [1/50] Train Loss: 0.9432, Validation Loss: 0.7169, Accuracy: 76.21%
Saved model with improved validation accuracy: 76.21%
Epoch [2/50] Train Loss: 0.6404, Validation Loss: 0.7421, Accuracy: 75.72%
Epoch [3/50] Train Loss: 0.5254, Validation Loss: 0.6887, Accuracy: 76.56%
Saved model with improved validation accuracy: 76.56%
Epoch [4/50] Train Loss: 0.4340, Validation Loss: 0.6201, Accuracy: 80.19%
Saved model with improved validation accuracy: 80.19%
Epoch [5/50] Train Loss: 0.3499, Validation Loss: 0.6506, Accuracy: 79.09%
Epoch [6/50] Train Loss: 0.2923, Validation Loss: 0.6045, Accuracy: 80.93%
Saved model with improved validation accuracy: 80.93%
Epoch [7/50] Train Loss: 0.2665, Validation Loss: 0.6213, Accuracy: 81.21%
Saved model with improved validation accuracy: 81.21%
Epoch [8/50] Train Loss: 0.2394, Validation Loss: 0.6381, Accuracy: 81.52%
Saved model with impro

In [None]:
# Baseline 1 - evaluate model trained on cifar10 hard labels + cifar10h hard labels
# cifar10_hard_combined_loader concatenates the augmentation pool and the training images, both with
# their original hard labels, so no soft labels are involved in this run.
print(
    f"Training on {len(cifar10_hard_combined_loader.dataset)} samples and validating on {len(cifar10_hard_val_loader.dataset)} samples"
)

# Same architecture, starting weights and epoch count as the soft-label experiment, for a like-for-like comparison.
resnet_model = models.resnet34(weights=models.ResNet34_Weights.DEFAULT)
train_nn_model(resnet_model, cifar10_hard_combined_loader, cifar10_hard_val_loader, num_epochs=30)
evaluate_nn_model(resnet_model, cifar10_hard_test_loader)

In [None]:
# Baseline 2 - evaluate model trained on cifar10 hard labels + cifar10h soft labels
# The soft labels are re-loaded here so this cell does not depend on the experiment cell having run.
cifar10h_probs_path = "../data/cifar-10h/cifar10h-probs.npy"
cifar10h_probs = np.load(cifar10h_probs_path).astype(np.float32)

# This dataset is partially soft and partially hard labels
# CIFAR10LabelDataset one-hot encodes the hard labels, so both parts yield 10-dimensional float
# label vectors and can be batched together.
cifar10_soft_label_dataset = CIFAR10LabelDataset(cifar10_hard_train_dataset, cifar10h_probs)
cifar10_hard_label_dataset = CIFAR10LabelDataset(cifar10_hard_augment_dataset)
combined_train_dataset = ConcatDataset([cifar10_hard_label_dataset, cifar10_soft_label_dataset])

cifar10_baseline_train_loader = DataLoader(combined_train_dataset, batch_size=128, shuffle=True)

print(
    f"Training on {len(cifar10_baseline_train_loader.dataset)} samples and validating on {len(cifar10_hard_val_loader.dataset)} samples"
)
print(
    f"    Train set has {len(cifar10_hard_augment_dataset)} hard labels and {len(cifar10_soft_label_dataset)} soft labels"
)

# Same architecture, starting weights and epoch count as the soft-label experiment, for a like-for-like comparison.
resnet_model = models.resnet34(weights=models.ResNet34_Weights.DEFAULT)
train_nn_model(resnet_model, cifar10_baseline_train_loader, cifar10_hard_val_loader, num_epochs=30)
evaluate_nn_model(resnet_model, cifar10_hard_test_loader)