In [1]:
# -------------------------------
# Multi-seed, multi-class training
# -------------------------------

from lib.cifar10_utils import (
    set_seed, get_dataloaders, build_cifar_resnet,
    train_model, evaluate_model
)
import os
import torch
import torch.nn as nn
import torch.optim as optim
import time
import csv

# -------------------------------
# Configuration
# -------------------------------
# Use the GPU when torch can see one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
EPOCHS = 50
BATCH_SIZE = 512
# NOTE(review): stays at 10 even when a class is removed from the data —
# presumably the labels are not remapped; confirm against get_dataloaders.
NUM_CLASSES = 10
MODEL_DIR = "./../models/cifar10_dump/"      # trained checkpoints (*.pth)
DATASET_DIR = "./../data/"                   # dataset root handed to the loaders
TIMER_DIR = "./../analytics/CIFAR10/timer/"  # per-run timing CSVs

# Define seeds and classes to remove
SEEDS = [311, 543, 969]            # example seeds
REMOVE_CLASSES = [None, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9] # None = all classes, 3 = "cat", 5 = "dog" etc.

# Ensure every output directory exists before any run starts. The timing CSV
# is only opened after a full training run, so a missing TIMER_DIR would
# otherwise raise FileNotFoundError only once all epochs have completed.
os.makedirs(MODEL_DIR, exist_ok=True)
os.makedirs(TIMER_DIR, exist_ok=True)

# -------------------------------
# Run experiments with timer
# -------------------------------
# One complete train/evaluate cycle per (remove_class, seed) pair. The class
# loop is the outer one, so all seeds for a given class setting run back to back.
for remove_class in REMOVE_CLASSES:
    for seed in SEEDS:
        print(f"\n=== Training: remove_class={remove_class}, seed={seed} ===")

        # Seed before anything else is constructed, so the loaders and the
        # model below are always built after seeding.
        set_seed(seed)

        # Data first, then the network (same construction order on every run).
        trainloader, testloader = get_dataloaders(
            batch_size=BATCH_SIZE,
            remove_class=remove_class,
            dataset_dir=DATASET_DIR,
        )
        model = build_cifar_resnet(num_classes=NUM_CLASSES, device=DEVICE)

        # Loss, SGD with momentum and weight decay, and a cosine learning-rate
        # schedule whose period equals the number of epochs.
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.SGD(
            model.parameters(),
            lr=0.1,
            momentum=0.9,
            weight_decay=5e-4,
        )
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

        # Run name encodes seed and epoch count; the "_r<class>" suffix is
        # added only when a class is being removed.
        class_suffix = "" if remove_class is None else f"_r{remove_class}"
        model_name = f"cifar_resnet_s{seed}_e{EPOCHS}" + class_suffix
        MODEL_PATH = os.path.join(MODEL_DIR, f"{model_name}.pth")
        time_log_path = os.path.join(TIMER_DIR, f"{model_name}_time.csv")

        # NOTE(review): no check for an existing checkpoint is made here, so
        # every pair is trained again unless train_model handles that itself —
        # confirm.
        start_time = time.perf_counter()
        print(f"Training model and saving to {MODEL_PATH}...")
        train_model(
            model, trainloader, criterion, optimizer, scheduler,
            epochs=EPOCHS,
            save_path=MODEL_PATH,
            device=DEVICE,
        )
        duration = time.perf_counter() - start_time
        print(f"⏱ Training time for remove_class={remove_class}, seed={seed}: {duration/60:.2f} min")

        # One CSV per run: a header row plus a single data row (seconds, 2 dp).
        with open(time_log_path, "w", newline="") as f:
            writer = csv.writer(f)
            writer.writerows([
                ["remove_class", "seed", "duration_sec"],
                [remove_class, seed, f"{duration:.2f}"],
            ])

        # Evaluate the freshly trained model on the test loader.
        print(f"Final evaluation for remove_class={remove_class}, seed={seed}:")
        preds, targets, confs = evaluate_model(model, testloader, criterion, device=DEVICE)



=== Training: remove_class=None, seed=311 ===
Training model and saving to ./../models/cifar10_dump/cifar_resnet_s311_e50.pth...
Epoch [1/50] Train Loss: 2.089 | Train Acc: 25.76%
Epoch [2/50] Train Loss: 1.505 | Train Acc: 44.12%
Epoch [3/50] Train Loss: 1.256 | Train Acc: 54.20%
Epoch [4/50] Train Loss: 1.048 | Train Acc: 62.36%
Epoch [5/50] Train Loss: 0.874 | Train Acc: 69.21%
Epoch [6/50] Train Loss: 0.746 | Train Acc: 73.98%
Epoch [7/50] Train Loss: 0.641 | Train Acc: 77.70%
Epoch [8/50] Train Loss: 0.570 | Train Acc: 80.36%
Epoch [9/50] Train Loss: 0.517 | Train Acc: 81.90%
Epoch [10/50] Train Loss: 0.481 | Train Acc: 83.27%
Epoch [11/50] Train Loss: 0.431 | Train Acc: 85.07%
Epoch [12/50] Train Loss: 0.395 | Train Acc: 86.35%
Epoch [13/50] Train Loss: 0.367 | Train Acc: 87.03%
Epoch [14/50] Train Loss: 0.332 | Train Acc: 88.41%
Epoch [15/50] Train Loss: 0.315 | Train Acc: 89.01%
Epoch [16/50] Train Loss: 0.291 | Train Acc: 90.07%
Epoch [17/50] Train Loss: 0.269 | Train Acc: 90