In [2]:
# -------------------------------
# Multi-seed, multi-class training for CIFAR-100
# -------------------------------

from lib.cifar100_utils import (
    set_seed, get_dataloaders, 
    build_cifar_resnet, build_wideresnet,
    train_model, evaluate_model
)
import os
import torch
import torch.nn as nn
import torch.optim as optim
import time
import csv

# -------------------------------
# Configuration
# -------------------------------
# Compute device: use CUDA when PyTorch reports it as available, else the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Training hyper-parameters shared by every run in the grid below.
EPOCHS = 200
BATCH_SIZE = 512
NUM_CLASSES = 100

# Filesystem locations, all relative to the current working directory:
#   MODEL_DIR   - where trained checkpoints are written
#   DATASET_DIR - handed to get_dataloaders() as the dataset root
#   TIMER_DIR   - where per-run training-time CSV files are written
MODEL_DIR = "./../models/cifar100_dump/"
DATASET_DIR = "./../data/"
TIMER_DIR = "./../analytics/CIFAR100/timer/"

# Experiment grid: every entry of REMOVE_CLASSES is trained once per seed.
# A None entry means no class is passed for removal.
SEEDS = [311, 637, 969]   # example seeds
REMOVE_CLASSES = [None, 14, 23, 35, 49, 53, 61, 68, 72, 88, 97]

# Architecture selector; the training loop accepts "resnet18" or "wideresnet".
MODEL_NAME = "resnet18"

# Create the checkpoint directory up front (no error if it already exists).
os.makedirs(MODEL_DIR, exist_ok=True)

# -------------------------------
# Run experiments
# -------------------------------
# Train one model per (remove_class, seed) combination, record how long the
# training call took, and run a final evaluation on the test loader.
for remove_class in REMOVE_CLASSES:
    for seed in SEEDS:
        print(f"\n=== Training: model={MODEL_NAME}, remove_class={remove_class}, seed={seed} ===")

        # Seed first, so the loaders and model below are created after seeding.
        set_seed(seed)

        # Data loaders; remove_class=None is passed through unchanged for the
        # baseline run (presumably "keep all classes" - confirm in
        # lib.cifar100_utils.get_dataloaders).
        trainloader, testloader = get_dataloaders(
            batch_size=BATCH_SIZE,
            remove_class=remove_class,
            dataset_dir=DATASET_DIR
        )

        # Build model (choose based on MODEL_NAME). The output size stays at
        # NUM_CLASSES even when a class is removed from the data.
        if MODEL_NAME == "resnet18":
            model = build_cifar_resnet(num_classes=NUM_CLASSES, device=DEVICE)
        elif MODEL_NAME == "wideresnet":
            model = build_wideresnet(num_classes=NUM_CLASSES, device=DEVICE)
        else:
            raise ValueError(f"Unknown model: {MODEL_NAME}")

        # Optimizer, scheduler, criterion. The cosine schedule spans the whole
        # run (T_max=EPOCHS).
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

        # Checkpoint name encodes architecture, seed, epoch count and, when a
        # class was removed, an "_r<class>" suffix.
        model_name = f"cifar100_{MODEL_NAME}_s{seed}_e{EPOCHS}"
        if remove_class is not None:
            model_name += f"_r{remove_class}"
        MODEL_PATH = os.path.join(MODEL_DIR, model_name + ".pth")

        # Make sure the timing directory exists BEFORE the long training run:
        # otherwise open() below raises FileNotFoundError only after training
        # has finished, and the timing record for the run is lost.
        os.makedirs(TIMER_DIR, exist_ok=True)
        time_log_path = os.path.join(TIMER_DIR, model_name + "_time.csv")
        start_time = time.perf_counter()

        # Always (re)train: no check for an existing checkpoint is made here.
        # train_model receives MODEL_PATH as save_path (presumably it writes
        # the checkpoint there - confirm in lib.cifar100_utils).
        print(f"Training model and saving to {MODEL_PATH}...")
        train_model(
            model, trainloader, criterion, optimizer,
            scheduler, epochs=EPOCHS, save_path=MODEL_PATH,
            device=DEVICE
        )

        # Wall-clock duration of the train_model call only (data loading and
        # model construction above are excluded).
        duration = time.perf_counter() - start_time
        print(f"⏱ Training time for remove_class={remove_class}, seed={seed}: {duration/60:.2f} min")

        # One CSV per run: a header row followed by a single data row.
        with open(time_log_path, "w", newline="") as f:
            writer = csv.writer(f)
            writer.writerow(["remove_class", "seed", "duration_sec"])
            writer.writerow([remove_class, seed, f"{duration:.2f}"])

        # Final evaluation on the test loader; the returned values are bound
        # but not used further inside this loop.
        print(f"Final evaluation for remove_class={remove_class}, seed={seed}:")
        preds, targets, confs = evaluate_model(model, testloader, criterion, device=DEVICE)



=== Training: model=resnet18, remove_class=None, seed=311 ===
Training model and saving to ./../models/cifar100_dump/cifar100_resnet18_s311_e200.pth...
Epoch [1/200] Train Loss: 3.956 | Train Acc: 9.77%
Epoch [2/200] Train Loss: 3.232 | Train Acc: 20.86%
Epoch [3/200] Train Loss: 2.773 | Train Acc: 29.25%
Epoch [4/200] Train Loss: 2.393 | Train Acc: 36.96%
Epoch [5/200] Train Loss: 2.085 | Train Acc: 43.64%
Epoch [6/200] Train Loss: 1.870 | Train Acc: 48.40%
Epoch [7/200] Train Loss: 1.685 | Train Acc: 52.83%
Epoch [8/200] Train Loss: 1.529 | Train Acc: 56.58%
Epoch [9/200] Train Loss: 1.416 | Train Acc: 59.46%
Epoch [10/200] Train Loss: 1.309 | Train Acc: 62.18%
Epoch [11/200] Train Loss: 1.214 | Train Acc: 64.61%
Epoch [12/200] Train Loss: 1.144 | Train Acc: 66.46%
Epoch [13/200] Train Loss: 1.071 | Train Acc: 68.43%
Epoch [14/200] Train Loss: 0.995 | Train Acc: 70.60%
Epoch [15/200] Train Loss: 0.945 | Train Acc: 71.87%
Epoch [16/200] Train Loss: 0.890 | Train Acc: 73.27%
Epoch [17