In [1]:
"""
This file contains the code to run experiments with artificial soft labels.

The experiment is:
    * Train a soft label predictor model on CIFAR-10H
    * Generate artificial soft labels for CIFAR-10
    * Train a model on CIFAR-10 with the artificial soft labels + CIFAR-10H
    * Evaluate the model on CIFAR-10
"""

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader, ConcatDataset
from torchvision import datasets, transforms, models
import torch.nn as nn
import torch.optim as optim
from typing import Tuple
import torch.nn.functional as F

# Loading Data

In [2]:
# Load CIFAR-10 dataset and return augment, train, validation, and test DataLoaders
def load_cifar10_experiment():
    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.ConvertImageDtype(torch.float32),
        ]
    )
    full_dataset = datasets.CIFAR10(root="../data/cifar-10", train=True, download=True, transform=transform)
    # we use the test dataset for training, similar to the CIFAR-10H experiment
    train_dataset = datasets.CIFAR10(root="../data/cifar-10", train=False, download=True, transform=transform)

    # This dataset will be used for augmenting, testing, and validation.
    augment_size = int(0.7 * len(full_dataset))
    val_size = (len(full_dataset) - augment_size) // 2
    test_size = len(full_dataset) - augment_size - val_size
    augment_dataset, test_dataset, val_dataset = torch.utils.data.random_split(
        full_dataset, [augment_size, test_size, val_size], generator=torch.Generator().manual_seed(229)
    )

    return augment_dataset, train_dataset, test_dataset, val_dataset

In [3]:
(
    cifar10_hard_augment_dataset,
    cifar10_hard_train_dataset,
    cifar10_hard_test_dataset,
    cifar10_hard_val_dataset,
) = load_cifar10_experiment()

cifar10_hard_test_loader = DataLoader(cifar10_hard_test_dataset, batch_size=128, shuffle=False)
cifar10_hard_val_loader = DataLoader(cifar10_hard_val_dataset, batch_size=128, shuffle=False)

print(
    f"CIFAR-10 dataset loaded with {len(cifar10_hard_augment_dataset)} augment, {len(cifar10_hard_train_dataset)} training, {len(cifar10_hard_test_dataset)} test, and {len(cifar10_hard_val_dataset)} validation samples"
)

Files already downloaded and verified
Files already downloaded and verified
CIFAR-10 dataset loaded with 35000 augment, 10000 training, 7500 test, and 7500 validation samples


# Training
Training is done on the CIFAR-10H dataset. Evaluation is done on the CIFAR-10 train set, which we use as a test set.

In [4]:
def train_model(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    criterion: nn.Module,
    optimizer: optim.Optimizer,
    num_epochs: int,
    model_path,
) -> nn.Module:
    device = torch.device(
        "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
    )
    print(f"Using device: {device}")
    model = model.to(device)

    best_val_acc = 0.0

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        for i, (images, labels) in enumerate(train_loader):
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        # Validation phase
        model.eval()
        correct = 0
        total = 0
        val_loss = 0.0
        with torch.no_grad():
            for images, labels in val_loader:
                images = images.to(device)
                labels = labels.to(device)
                outputs = model(images)

                if len(labels.shape) > 1:  # For soft labels
                    _, predicted = torch.max(outputs.data, 1)
                    _, labels = torch.max(labels, 1)
                else:  # For hard labels
                    _, predicted = torch.max(outputs.data, 1)

                total += labels.size(0)
                correct += (predicted == labels).sum().item()
                val_loss += criterion(outputs, labels).item()

        accuracy = 100 * correct / total
        val_loss = val_loss / len(val_loader)
        print(
            f"Epoch [{epoch+1}/{num_epochs}] Train Loss: {running_loss/len(train_loader):.4f}, Validation Loss: {val_loss:.4f}, Accuracy: {accuracy:.2f}%"
        )

        # Save model if validation accuracy improves
        if model_path is not None:
            if accuracy > best_val_acc:
                best_val_acc = accuracy
                torch.save(model.state_dict(), model_path)
                print(f"Saved model with improved validation accuracy: {accuracy:.2f}%")

    return model

## Training Neural Networks

In [5]:
def train_nn_model(model, cifar10h_loader, cifar10_val_loader, num_epochs=20, model_path=None):
    print(f"\nTraining {model.__class__.__name__} on CIFAR-10H...")

    # Adjust the final layer for CIFAR-10
    if isinstance(model, models.ResNet):
        num_ftrs = model.fc.in_features
        model.fc = nn.Linear(num_ftrs, 10)
    elif isinstance(model, models.VGG):
        num_ftrs = model.classifier[-1].in_features
        model.classifier[-1] = nn.Linear(num_ftrs, 10)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    model = train_model(
        model=model,
        train_loader=cifar10h_loader,
        val_loader=cifar10_val_loader,
        criterion=criterion,
        optimizer=optimizer,
        num_epochs=num_epochs,
        model_path=model_path
    )

In [6]:
from generate_soft_labels import create_soft_label_dataloader, create_soft_label_dataset
from soft_label_predictor import ImageHardToSoftLabelModel

device = torch.device(
    "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
)

# Load the trained model
soft_label_model = ImageHardToSoftLabelModel().to(device)
soft_label_model.load_state_dict(torch.load("models/soft_label_model.pt", weights_only=True))
soft_label_model.eval()

ImageHardToSoftLabelModel(
  (image_encoder): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Sequential(
      (0): ResidualBlock(
        (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (shortcut): Sequential(
          (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1))
          (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (4): Sequential(
      (0): ResidualBlock(
    

In [7]:
# Create a new class to handle both hard and soft labels consistently
class CIFAR10LabelDataset(Dataset):
    def __init__(self, dataset, soft_labels=None):
        self.dataset = dataset
        self.soft_labels = soft_labels

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        image, label = self.dataset[idx]
        if self.soft_labels is None:
            # Convert hard labels to one-hot
            label = F.one_hot(torch.tensor(label), num_classes=10).float()
        else:
            label = torch.tensor(self.soft_labels[idx])
        return image, label

In [8]:
from sklearn.metrics import precision_score, recall_score, f1_score

def evaluate_model(model, dataloader, device):
    """
    Evaluate model performance with multiple metrics.
    Returns dict with accuracy, precision, recall, f1 score and loss.
    """
    model.eval()
    total = 0
    correct = 0
    total_loss = 0
    all_preds = []
    all_labels = []
    criterion = nn.CrossEntropyLoss()
    
    with torch.no_grad():
        for images, labels in dataloader:
            images = images.to(device)
            if len(labels.shape) > 1:  # If labels are one-hot encoded
                _, labels = torch.max(labels, 1)  # Convert to class indices
            labels = labels.to(device)
            
            outputs = model(images)
            loss = criterion(outputs, labels)
            total_loss += loss.item()
            
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    
    # Calculate metrics
    accuracy = correct / total
    precision = precision_score(all_labels, all_preds, average='macro')
    recall = recall_score(all_labels, all_preds, average='macro')
    f1 = f1_score(all_labels, all_preds, average='macro')
    avg_loss = total_loss / len(dataloader)
    
    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'loss': avg_loss
    }

def run_proportion_experiment(
    full_dataset,
    soft_label_model,
    val_loader, 
    test_loader,
    soft_proportions=[0.0, 0.25, 0.5, 0.75, 1.0],
    num_epochs=20,
    device=None
):
    """
    Run experiments with different proportions of soft vs hard labels.
    
    Args:
        full_dataset: Base dataset with hard labels
        model: Model to generate soft labels
        soft_proportions: List of proportions of soft labels to use
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    results = {prop: {} for prop in soft_proportions}
    total_samples = len(full_dataset)
    
    for prop in soft_proportions:
        print(f"\nRunning experiment with {int(prop*100)}% soft labels")
        model_path = f"models/ResNet_cifar10h_soft_{int(prop*100)}percent.pth"
        
        # Randomly shuffle the dataset
        indices = torch.randperm(total_samples, generator=torch.Generator().manual_seed(42))
        
        # Calculate size for soft labels
        soft_size = int(total_samples * prop)
        
        # Create soft and hard label datasets
        soft_indices = indices[:soft_size]
        hard_indices = indices[soft_size:]
        
        # Create soft label subset
        if soft_size > 0:
            soft_subset = torch.utils.data.Subset(full_dataset, soft_indices)
            soft_loader = DataLoader(soft_subset, batch_size=128, shuffle=False)
            soft_dataset = create_soft_label_dataset(soft_label_model, soft_loader, device)
        
        # Create hard label subset
        if len(hard_indices) > 0:
            hard_subset = torch.utils.data.Subset(full_dataset, hard_indices)
            hard_dataset = CIFAR10LabelDataset(hard_subset)
        
        # Combine datasets
        if prop == 0.0:
            combined_dataset = hard_dataset
        elif prop == 1.0:
            combined_dataset = soft_dataset
        else:
            combined_dataset = ConcatDataset([hard_dataset, soft_dataset])
        
        train_loader = DataLoader(combined_dataset, batch_size=128, shuffle=True)
        
        # Train model
        model = models.resnet34(weights=models.ResNet34_Weights.DEFAULT)
        train_nn_model(model, train_loader, val_loader, num_epochs=num_epochs, model_path=model_path)
        
        # Evaluate on test set
        model.load_state_dict(torch.load(model_path, weights_only=True))
        model.eval()
        
        test_metrics = evaluate_model(model, test_loader, device)
        val_metrics = evaluate_model(model, val_loader, device)
        train_metrics_final = evaluate_model(model, train_loader, device)
        
        results[prop] = {
            'train_accuracy': train_metrics_final['accuracy'],
            'val_accuracy': val_metrics['accuracy'], 
            'test_accuracy': test_metrics['accuracy'],
            'train_loss': train_metrics_final['loss'],
            'val_loss': val_metrics['loss'],
            'train_precision': train_metrics_final['precision'],
            'val_precision': val_metrics['precision'],
            'train_recall': train_metrics_final['recall'],
            'val_recall': val_metrics['recall'],
            'train_f1': train_metrics_final['f1'],
            'val_f1': val_metrics['f1']
        }
            
    # Print results in a paper-friendly format
    print("\nAblation Study Results")
    print("=====================")
    metrics = ['train_accuracy', 'val_accuracy', 'train_loss', 'val_loss', 
              'train_precision', 'val_precision', 'train_recall', 'val_recall',
              'train_f1', 'val_f1']
    
    header = f"{'Soft Labels':>15}"
    for metric in metrics:
        header += f" {metric:>15}"
    print(header)
    print("-" * (18 + 18 * len(metrics)))
    
    for prop in soft_proportions:
        line = f"{prop*100:>11.0f}%"
        for metric in metrics:
            line += f" {results[prop][metric]:>12.3f}"
        print(line)
    
    return results

# Run the experiment
results = run_proportion_experiment(
    full_dataset=cifar10_hard_augment_dataset,
    soft_label_model=soft_label_model,
    val_loader=cifar10_hard_val_loader,
    test_loader=cifar10_hard_test_loader,
    soft_proportions=[1.0, 0.0, 0.25, 0.5, 0.75, 1.0],
    num_epochs=30,
    device=device
)


Running experiment with 100% soft labels

Training ResNet on CIFAR-10H...
Using device: mps
Epoch [1/30] Train Loss: 0.9667, Validation Loss: 0.9417, Accuracy: 67.97%
Saved model with improved validation accuracy: 67.97%
Epoch [2/30] Train Loss: 0.6524, Validation Loss: 0.6964, Accuracy: 76.00%
Saved model with improved validation accuracy: 76.00%
Epoch [3/30] Train Loss: 0.5231, Validation Loss: 0.8131, Accuracy: 74.04%
Epoch [4/30] Train Loss: 0.4250, Validation Loss: 0.7036, Accuracy: 77.04%
Saved model with improved validation accuracy: 77.04%
Epoch [5/30] Train Loss: 0.3324, Validation Loss: 0.6472, Accuracy: 78.99%
Saved model with improved validation accuracy: 78.99%
Epoch [6/30] Train Loss: 0.2914, Validation Loss: 0.7075, Accuracy: 78.35%
Epoch [7/30] Train Loss: 0.2421, Validation Loss: 0.6916, Accuracy: 79.05%
Saved model with improved validation accuracy: 79.05%
Epoch [8/30] Train Loss: 0.2172, Validation Loss: 0.7136, Accuracy: 79.33%
Saved model with improved validation 

In [10]:
import pandas as pd

# Convert results to DataFrame for pretty printing
df = pd.DataFrame(results).T * 100  # Convert proportions to percentages
df.index = [f"{idx:.0f}%" for idx in df.index * 100]  # Format index as percentages
print("\nResults by Soft Label Percentage:")
print("================================")
df


Results by Soft Label Percentage:


Unnamed: 0,train_accuracy,val_accuracy,test_accuracy,train_loss,val_loss,train_precision,val_precision,train_recall,val_recall,train_f1,val_f1
100%,98.762857,80.52,80.333333,5.283633,71.878215,98.779932,80.82606,98.76771,80.455698,98.758801,80.274028
0%,99.051429,80.146667,79.906667,2.815941,104.07961,99.058741,80.401502,99.050684,80.141505,99.051164,80.149237
25%,99.651429,81.186667,80.893333,1.601,80.658163,99.652127,81.264641,99.651832,81.162617,99.651496,81.096214
50%,99.762857,81.413333,80.853333,1.66545,73.689637,99.763687,81.414592,99.763294,81.449038,99.763167,81.371133
75%,98.94,80.706667,79.706667,4.148401,75.631304,98.947898,81.162575,98.944926,80.694255,98.941458,80.778979
