In [None]:
# -------------------------------
# Multi-seed, multi-class training
# -------------------------------

from lib.cifar10_utils import (
    set_seed, get_dataloaders, build_cifar_resnet,
    train_model, load_model, evaluate_model
)
import os
import torch
import torch.nn as nn
import torch.optim as optim

# -------------------------------
# Configuration
# -------------------------------
# Use the GPU when torch reports one as available, otherwise fall back to CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Training hyper-parameters shared by every run.
EPOCHS = 50
BATCH_SIZE = 512
NUM_CLASSES = 10

# Locations (relative to the working directory) for checkpoints and the dataset.
MODEL_DIR = "./../models/cifar10/"
DATASET_DIR = "./../data/"

# Seeds to repeat each experiment with (example values).
SEEDS = [
    42, 602, 311, 637, 800,
    543, 969, 122, 336, 93,
]

# Classes to remove, one experiment per entry: None = all classes kept,
# followed by every class index 0..NUM_CLASSES-1 (3 = "cat", 5 = "dog" etc.).
REMOVE_CLASSES = [None, *range(NUM_CLASSES)]

# Create the checkpoint directory up front; a no-op when it already exists.
os.makedirs(MODEL_DIR, exist_ok=True)

# -------------------------------
# Run experiments
# -------------------------------
# One run per (remove_class, seed) pair. Each run builds a fresh model, then
# either loads its checkpoint from MODEL_DIR (when the file already exists) or
# trains and saves one, and finally evaluates the model on the test loader.
for remove_class in REMOVE_CLASSES:
    for seed in SEEDS:
        print(f"\n=== Training: remove_class={remove_class}, seed={seed} ===")

        # Checkpoint name encodes the seed, the epoch budget and, when a class
        # is removed, that class index as an "_r<idx>" suffix.
        model_name = f"cifar_resnet_s{seed}_e{EPOCHS}" + (
            "" if remove_class is None else f"_r{remove_class}"
        )
        MODEL_PATH = os.path.join(MODEL_DIR, f"{model_name}.pth")

        # set_seed is called before the loaders and the model are constructed.
        set_seed(seed)

        # remove_class is forwarded as-is to the data-loading helper.
        trainloader, testloader = get_dataloaders(
            batch_size=BATCH_SIZE,
            remove_class=remove_class,
            dataset_dir=DATASET_DIR,
        )

        model = build_cifar_resnet(num_classes=NUM_CLASSES, device=DEVICE)

        # Loss, SGD optimizer, and a cosine LR schedule spanning all epochs.
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.SGD(
            model.parameters(),
            lr=0.1,
            momentum=0.9,
            weight_decay=5e-4,
        )
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

        # Train only when no checkpoint is present; otherwise reuse the saved one.
        if not os.path.exists(MODEL_PATH):
            print(f"Training model and saving to {MODEL_PATH}...")
            train_model(
                model,
                trainloader,
                criterion,
                optimizer,
                scheduler,
                epochs=EPOCHS,
                save_path=MODEL_PATH,
                device=DEVICE,
            )
        else:
            print(f"Model already exists at {MODEL_PATH}, skipping training.")
            load_model(model, MODEL_PATH, device=DEVICE)

        # Final evaluation of the trained or loaded model.
        print(f"Final evaluation for remove_class={remove_class}, seed={seed}:")
        preds, targets, confs = evaluate_model(
            model, testloader, criterion, device=DEVICE
        )



=== Training: remove_class=None, seed=42 ===
Model already exists at ./../models/cifar_resnet_s42_e50.pth, skipping training.
Model loaded from ./../models/cifar_resnet_s42_e50.pth
Final evaluation for remove_class=None, seed=42:
Test Loss: 0.273 | Test Acc: 92.58%

=== Training: remove_class=None, seed=602 ===
Model already exists at ./../models/cifar_resnet_s602_e50.pth, skipping training.
Model loaded from ./../models/cifar_resnet_s602_e50.pth
Final evaluation for remove_class=None, seed=602:
Test Loss: 0.281 | Test Acc: 92.59%

=== Training: remove_class=None, seed=311 ===
Model already exists at ./../models/cifar_resnet_s311_e50.pth, skipping training.
Model loaded from ./../models/cifar_resnet_s311_e50.pth
Final evaluation for remove_class=None, seed=311:
Test Loss: 0.270 | Test Acc: 92.92%

=== Training: remove_class=None, seed=637 ===
Model already exists at ./../models/cifar_resnet_s637_e50.pth, skipping training.
Model loaded from ./../models/cifar_resnet_s637_e50.pth
Final 