In [None]:
import os, math, time, random
from dataclasses import dataclass
from typing import Tuple

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import LinearLR, CosineAnnealingLR, SequentialLR, ConstantLR
from torch.utils.data import DataLoader
from torchvision import transforms, datasets, models
from torchvision.datasets.utils import download_and_extract_archive

# Seed every RNG the notebook touches so weight init, shuffling and augmentation are
# reproducible across Restart & Run All. torch.manual_seed also seeds all CUDA devices.
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)


Device: cuda


In [None]:
# Opt-in switch for fetching TinyImageNet automatically when it is missing from
# ./data/tiny-imagenet-200. NOTE(review): confirm the data-loading code actually reads this flag.
AUTO_DOWNLOAD_TINYIN = True

@dataclass
class TrainCfg:
    """Hyper-parameters for one training run (one entry of ``cfgs``)."""

    # Architecture key dispatched on by build_model (e.g. "FC5", "VGG16", "ResNet18_C100").
    name: str
    # Dataset key handled by get_dataloaders: "MNIST", "FashionMNIST", "CIFAR10", "CIFAR100" or "TinyImageNet".
    dataset: str
    # Number of output classes (width of the final layer).
    num_classes: int
    # Side length in pixels of the square model input; also sizes the flattened input of the FC nets.
    input_size: int
    batch_size: int = 128
    # Number of training epochs; the LR scheduler is stepped once per epoch.
    epochs: int = 200
    # "SGD" or "ADAM" (compared case-insensitively).
    optimizer: str = "SGD"
    # Base learning rate.
    lr: float = 0.01
    # SGD momentum; not passed to Adam.
    momentum: float = 0.9
    weight_decay: float = 0.0
    # SGD only: cosine-anneal the LR after warmup instead of keeping it constant. Ignored for Adam.
    cosine: bool = True
    # SGD only: epochs of linear LR warmup starting from ~0 (factor 1e-8) up to ``lr``.
    warmup_epochs: int = 5
    # ResNets only: replace the stem with a 3x3 stride-1 conv and drop the max-pool.
    cifar_style_resnet: bool = False

def _build_cfgs() -> dict:
    """Return the experiment table (run key -> TrainCfg) in the order the runs are executed."""
    # Recipe shared by the fully-connected baselines on the 28x28 grayscale datasets:
    # Adam at a fixed LR, no warmup, no weight decay.
    adam_fixed = dict(optimizer="ADAM", lr=1e-4, cosine=False, warmup_epochs=0, weight_decay=0.0)
    # Plain SGD at a fixed LR (FC baselines on CIFAR-10; epochs/optimizer use the defaults).
    sgd_fixed = dict(lr=0.01, cosine=False, warmup_epochs=0, weight_decay=0.0)
    # SGD with 5 warmup epochs, cosine annealing and weight decay (convolutional nets).
    sgd_cosine = dict(cosine=True, warmup_epochs=5, weight_decay=5e-4)

    table = {}

    # MNIST and Fashion-MNIST (1x28x28): the same three FC nets on each,
    # keyed "<net>-M" and "<net>-FM" respectively.
    for suffix, dataset in (("M", "MNIST"), ("FM", "FashionMNIST")):
        for net, n_epochs in (("FC2", 5), ("FC5", 5), ("FC12", 10)):
            table[f"{net}-{suffix}"] = TrainCfg(net, dataset, 10, input_size=28, epochs=n_epochs, **adam_fixed)

    # CIFAR-10
    for net in ("FC5", "FC12"):
        table[net] = TrainCfg(net, "CIFAR10", 10, input_size=32, **sgd_fixed)
    table["VGG16"] = TrainCfg("VGG16", "CIFAR10", 10, input_size=32, lr=0.01, **sgd_cosine)
    # AlexNet consumes CIFAR-10 upsampled to 224x224.
    table["AlexNet"] = TrainCfg("AlexNet", "CIFAR10", 10, input_size=224, lr=0.01, **sgd_cosine)

    # CIFAR-100: ResNets with the CIFAR-style stem.
    for net in ("ResNet18_C100", "ResNet50_C100"):
        table[net] = TrainCfg(net, "CIFAR100", 100, input_size=32, lr=0.1, cifar_style_resnet=True, **sgd_cosine)

    # TinyImageNet (optional; dataset lives under ./data/tiny-imagenet-200)
    for net in ("ResNet18_TinyIN", "ResNet50_TinyIN"):
        table[net] = TrainCfg(net, "TinyImageNet", 200, input_size=64, lr=0.01, **sgd_cosine)

    return table


cfgs = _build_cfgs()


In [None]:
# Per-channel normalisation statistics (mean, std) consumed by transforms.Normalize in the
# transform builders below. Grayscale datasets carry one channel, RGB datasets three.
MNIST_MEAN = (0.1307,)
MNIST_STD = (0.3081,)

FASHION_MEAN = (0.2860,)
FASHION_STD = (0.3530,)

CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2470, 0.2435, 0.2616)

CIFAR100_MEAN = (0.5071, 0.4867, 0.4408)
CIFAR100_STD = (0.2675, 0.2565, 0.2761)

# Used for TinyImageNet.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

def _cifar_transforms(size: int, mean, std, resize_to_224: bool = False):
    """Build ``(train_tfms, test_tfms)`` for 3-channel CIFAR images.

    At the native 32x32 resolution (``size == 32`` and ``resize_to_224`` false), training uses
    a random 32x32 crop from a 4-pixel-padded image plus a horizontal flip, and testing applies
    no geometric transform. Otherwise both pipelines first ``Resize(size)``, and training adds
    only a horizontal flip. Every pipeline ends with ToTensor followed by Normalize(mean, std).
    """
    def finish(*leading):
        # Fresh ToTensor/Normalize instances per pipeline, appended after the given steps.
        return transforms.Compose([*leading, transforms.ToTensor(), transforms.Normalize(mean, std)])

    if resize_to_224 or size != 32:
        return (
            finish(transforms.Resize(size), transforms.RandomHorizontalFlip()),
            finish(transforms.Resize(size)),
        )

    return (
        finish(transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip()),
        finish(),
    )

def _gray_transforms(size: int, mean, std):
    """Build ``(train_tfms, test_tfms)`` for single-channel 28x28 datasets (MNIST, Fashion-MNIST).

    No augmentation is applied; train and test pipelines are identical. When ``size`` differs
    from the native 28, images are first resized to ``size`` x ``size`` so the tensors match the
    flattened input dimension the FC models are built for (``input_size ** 2``). At 28 the
    pipeline is just ToTensor + Normalize(mean, std).
    """
    native_size = 28

    def build():
        # Previously ``size`` was accepted but ignored; honour it with an explicit square resize.
        steps = [] if size == native_size else [transforms.Resize((size, size))]
        steps += [transforms.ToTensor(), transforms.Normalize(mean, std)]
        return transforms.Compose(steps)

    return build(), build()

TINY_IMAGENET_URL = "http://cs231n.stanford.edu/tiny-imagenet-200.zip"


def _ensure_tiny_imagenet(data_root: str, download: bool) -> str:
    """Return ``<data_root>/tiny-imagenet-200``, downloading/extracting it first if its
    ``train/`` directory is missing and ``download`` is true.

    Raises:
        FileNotFoundError: if the dataset is missing and ``download`` is false.
    """
    root = os.path.join(data_root, "tiny-imagenet-200")
    if not os.path.isdir(os.path.join(root, "train")):
        if not download:
            raise FileNotFoundError(
                f"TinyImageNet not found under {root!r}; extract tiny-imagenet-200.zip into "
                f"{data_root!r} or set AUTO_DOWNLOAD_TINYIN = True."
            )
        # The archive unpacks to a top-level tiny-imagenet-200/ directory inside data_root.
        download_and_extract_archive(TINY_IMAGENET_URL, download_root=data_root)
    return root


def _tiny_imagenet_val_set(val_dir: str, class_to_idx: dict, transform) -> datasets.ImageFolder:
    """Build the TinyImageNet validation set labelled with the *training* class indices.

    The official archive keeps validation images flat in ``val/images`` and lists each file's
    class (wnid) in ``val/val_annotations.txt``. Pointing ``ImageFolder`` straight at ``val/``
    therefore labels every image as the single class "images". If the directory has already
    been reorganised into ``val/<wnid>/...`` it is used as-is, with labels remapped to the
    training indices.
    """
    ann_path = os.path.join(val_dir, "val_annotations.txt")
    img_dir = os.path.join(val_dir, "images")

    # Reuse ImageFolder for its loader/transform plumbing (a torchvision class, so it pickles
    # cleanly into DataLoader worker processes); its sample list and labels are replaced below.
    val_set = datasets.ImageFolder(val_dir, transform=transform)

    if os.path.isfile(ann_path) and os.path.isdir(img_dir):
        # Each annotation line starts with "<file name> <wnid>" (bounding-box columns follow).
        samples = []
        with open(ann_path) as fh:
            for line in fh:
                fields = line.split()
                if len(fields) >= 2:
                    samples.append((os.path.join(img_dir, fields[0]), class_to_idx[fields[1]]))
    else:
        samples = [(path, class_to_idx[val_set.classes[target]]) for path, target in val_set.samples]

    val_set.samples = val_set.imgs = samples
    val_set.targets = [target for _, target in samples]
    val_set.classes = sorted(class_to_idx, key=class_to_idx.get)
    val_set.class_to_idx = dict(class_to_idx)
    return val_set


def get_dataloaders(cfg: TrainCfg, data_root: str = "./data", num_workers: int = 4,
                    download_tinyin=None) -> Tuple[DataLoader, DataLoader]:
    """Build the ``(train_loader, test_loader)`` pair for ``cfg.dataset``.

    MNIST, FashionMNIST, CIFAR10 and CIFAR100 are fetched by torchvision when missing.
    TinyImageNet is read from ``<data_root>/tiny-imagenet-200``; it is downloaded there when absent
    if ``download_tinyin`` is true. ``download_tinyin`` defaults to the ``AUTO_DOWNLOAD_TINYIN``
    flag. The TinyImageNet validation split serves as the test set.

    Args:
        cfg: experiment config; ``dataset``, ``input_size`` and ``batch_size`` are used.
        data_root: directory holding (or receiving) the datasets.
        num_workers: worker processes per DataLoader.
        download_tinyin: override for ``AUTO_DOWNLOAD_TINYIN``; ``None`` uses the flag.

    Returns:
        ``(train_loader, test_loader)``; only the training loader shuffles. Memory is pinned
        when CUDA is available.

    Raises:
        ValueError: if ``cfg.dataset`` is unknown.
        FileNotFoundError: if TinyImageNet is missing and downloading is disabled.
    """
    if cfg.dataset == "MNIST":
        train_tfms, test_tfms = _gray_transforms(cfg.input_size, MNIST_MEAN, MNIST_STD)
        train_set = datasets.MNIST(root=data_root, train=True,  download=True, transform=train_tfms)
        test_set  = datasets.MNIST(root=data_root, train=False, download=True, transform=test_tfms)

    elif cfg.dataset == "FashionMNIST":
        train_tfms, test_tfms = _gray_transforms(cfg.input_size, FASHION_MEAN, FASHION_STD)
        train_set = datasets.FashionMNIST(root=data_root, train=True,  download=True, transform=train_tfms)
        test_set  = datasets.FashionMNIST(root=data_root, train=False, download=True, transform=test_tfms)

    elif cfg.dataset == "CIFAR10":
        resize_to_224 = (cfg.input_size == 224)
        train_tfms, test_tfms = _cifar_transforms(cfg.input_size, CIFAR10_MEAN, CIFAR10_STD, resize_to_224=resize_to_224)
        train_set = datasets.CIFAR10(root=data_root, train=True,  download=True, transform=train_tfms)
        test_set  = datasets.CIFAR10(root=data_root, train=False, download=True, transform=test_tfms)

    elif cfg.dataset == "CIFAR100":
        train_tfms, test_tfms = _cifar_transforms(cfg.input_size, CIFAR100_MEAN, CIFAR100_STD, resize_to_224=False)
        train_set = datasets.CIFAR100(root=data_root, train=True,  download=True, transform=train_tfms)
        test_set  = datasets.CIFAR100(root=data_root, train=False, download=True, transform=test_tfms)

    elif cfg.dataset == "TinyImageNet":
        if download_tinyin is None:
            download_tinyin = AUTO_DOWNLOAD_TINYIN
        tiny_root = _ensure_tiny_imagenet(data_root, download_tinyin)
        train_tfms = transforms.Compose([
            transforms.RandomResizedCrop(cfg.input_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
        ])
        test_tfms = transforms.Compose([
            transforms.Resize(cfg.input_size + 8),
            transforms.CenterCrop(cfg.input_size),
            transforms.ToTensor(),
            transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
        ])
        train_set = datasets.ImageFolder(os.path.join(tiny_root, "train"), transform=train_tfms)
        # Validation labels must come from the annotation file, indexed like the training classes.
        test_set = _tiny_imagenet_val_set(os.path.join(tiny_root, "val"), train_set.class_to_idx, test_tfms)

    else:
        raise ValueError(f"Unknown dataset {cfg.dataset}")

    pin = torch.cuda.is_available()
    train_loader = DataLoader(train_set, batch_size=cfg.batch_size, shuffle=True,  num_workers=num_workers, pin_memory=pin)
    test_loader  = DataLoader(test_set,  batch_size=cfg.batch_size, shuffle=False, num_workers=num_workers, pin_memory=pin)
    return train_loader, test_loader


In [None]:
class FCNet(nn.Module):
    """Multilayer perceptron over the flattened input.

    Args:
        in_dim: length of each flattened input sample.
        widths: output width of every Linear layer, in order. The last entry is the number of
            logits. A ReLU follows each Linear except the last, and ``widths`` must be non-empty.
    """

    def __init__(self, in_dim: int, widths):
        super().__init__()
        sizes = [in_dim, *widths]
        # Read the final layer's shape up front (IndexError on empty widths); the layer itself is
        # created last so modules are constructed in network order.
        last_in, last_out = sizes[-2], sizes[-1]
        modules = []
        for fan_in, fan_out in zip(sizes[:-2], sizes[1:-1]):
            modules += (nn.Linear(fan_in, fan_out), nn.ReLU(inplace=True))
        modules.append(nn.Linear(last_in, last_out))
        self.net = nn.Sequential(*modules)

    def forward(self, x):
        # Keep the batch dimension, collapse everything else.
        return self.net(x.flatten(start_dim=1))

def _infer_in_dim(cfg: TrainCfg) -> int:
    """Flattened input length ``channels * input_size ** 2`` for ``cfg``.

    MNIST and FashionMNIST count as single-channel; every other dataset as 3-channel.
    """
    grayscale = cfg.dataset in ("MNIST", "FashionMNIST")
    channels = 1 if grayscale else 3
    return channels * cfg.input_size ** 2

def build_model(cfg: TrainCfg) -> nn.Module:
    """Create a freshly initialised model for ``cfg`` and move it to ``device``.

    Supported ``cfg.name`` values:
      * ``FC2`` / ``FC5`` / ``FC12``: ``FCNet`` MLPs over the flattened input.
      * ``AlexNet``: torchvision AlexNet with a ``cfg.num_classes``-way final Linear.
      * ``VGG16``: torchvision VGG-16 with 1x1 adaptive average pooling and a single
        512 -> ``num_classes`` Linear as the whole classifier.
      * any name containing ``ResNet18`` or ``ResNet50``: torchvision ResNet-18/50 (ResNet18
        wins if both appear). With ``cfg.cifar_style_resnet`` the stem is a 3x3 stride-1 conv
        and the max-pool is removed.

    Raises:
        ValueError: if ``cfg.name`` is not recognised.
    """
    # Hidden widths of the MLP baselines; the num_classes output layer is appended below.
    fc_hidden = {
        "FC2": [100],
        "FC5": [1000, 600, 300, 100],
        "FC12": [1000, 900, 800, 750, 700, 650, 600, 500, 400, 200, 100],
    }

    if cfg.name in fc_hidden:
        model = FCNet(_infer_in_dim(cfg), fc_hidden[cfg.name] + [cfg.num_classes])

    elif cfg.name == "AlexNet":
        # num_classes sizes the last classifier Linear directly (same layout as swapping it afterwards).
        model = models.alexnet(weights=None, num_classes=cfg.num_classes)

    elif cfg.name == "VGG16":
        model = models.vgg16(weights=None)
        model.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        model.classifier = nn.Linear(512, cfg.num_classes)

    elif "ResNet18" in cfg.name or "ResNet50" in cfg.name:
        factory = models.resnet18 if "ResNet18" in cfg.name else models.resnet50
        model = factory(weights=None, num_classes=cfg.num_classes)
        if cfg.cifar_style_resnet:
            model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
            # torchvision initialises every ResNet conv with Kaiming-normal (fan_out, relu) at
            # construction; the replacement stem is created later and would otherwise keep the
            # default Conv2d init, unlike every other conv in the network.
            nn.init.kaiming_normal_(model.conv1.weight, mode="fan_out", nonlinearity="relu")
            model.maxpool = nn.Identity()

    else:
        raise ValueError(f"Unknown model name {cfg.name}")

    return model.to(device)

def count_params(model: nn.Module):
    """Return ``(total, trainable)`` parameter-element counts of ``model``.

    ``trainable`` counts only parameters with ``requires_grad`` set.
    """
    total = trainable = 0
    for param in model.parameters():
        n_elems = param.numel()
        total += n_elems
        if param.requires_grad:
            trainable += n_elems
    return total, trainable


In [None]:
def make_optimizer_and_scheduler(model: nn.Module, cfg: "TrainCfg"):
    """Create the optimizer and per-epoch LR scheduler described by ``cfg``.

    * ``"ADAM"``: Adam at a constant ``cfg.lr``. ``cosine`` and ``warmup_epochs`` are ignored.
    * ``"SGD"``: SGD with ``cfg.momentum``, plus an optional linear warmup of
      ``cfg.warmup_epochs`` epochs from ``cfg.lr * 1e-8`` up to ``cfg.lr``. The warmup is
      followed by cosine annealing over the remaining epochs (``cfg.cosine``) or a constant LR.

    The optimizer name is matched case-insensitively. The scheduler is meant to be stepped once
    per epoch.

    Returns:
        ``(optimizer, scheduler)``.

    Raises:
        ValueError: if ``cfg.optimizer`` is neither SGD nor ADAM (previously this surfaced as an
            UnboundLocalError on ``opt``).
    """
    kind = cfg.optimizer.upper()

    if kind == "ADAM":
        opt = optim.Adam(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
        return opt, ConstantLR(opt, factor=1.0, total_iters=cfg.epochs)

    if kind != "SGD":
        raise ValueError(f"Unknown optimizer {cfg.optimizer!r}; expected 'SGD' or 'ADAM'")

    opt = optim.SGD(model.parameters(), lr=cfg.lr, momentum=cfg.momentum, weight_decay=cfg.weight_decay)
    warmup = cfg.warmup_epochs

    if cfg.cosine:
        if warmup > 0:
            warm = LinearLR(opt, start_factor=1e-8, end_factor=1.0, total_iters=warmup)
            cos  = CosineAnnealingLR(opt, T_max=max(1, cfg.epochs - warmup))
            sch  = SequentialLR(opt, schedulers=[warm, cos], milestones=[warmup])
        else:
            # Do not route a zero-length warmup through SequentialLR with milestone 0. The 1e-8
            # warmup factor would be applied at construction, and the chainable cosine would then
            # only rescale that tiny LR, leaving training at ~0 LR for the whole run.
            sch = CosineAnnealingLR(opt, T_max=max(1, cfg.epochs))
    elif warmup > 0:
        warm = LinearLR(opt, start_factor=1e-8, end_factor=1.0, total_iters=warmup)
        const = ConstantLR(opt, factor=1.0, total_iters=max(1, cfg.epochs - warmup))
        sch = SequentialLR(opt, [warm, const], milestones=[warmup])
    else:
        sch = ConstantLR(opt, factor=1.0, total_iters=cfg.epochs)

    return opt, sch

def train(model: nn.Module, train_loader: DataLoader, test_loader: DataLoader, cfg: TrainCfg, exp_tag: str):
    """Train ``model`` for ``cfg.epochs`` epochs, evaluating on ``test_loader`` after each one.

    Uses cross-entropy loss and the optimizer/scheduler from ``make_optimizer_and_scheduler``,
    stepping the scheduler once per epoch. Whenever validation accuracy improves, the model's
    ``state_dict`` is saved to ``checkpoints/<exp_tag>_best.pth`` together with the config and
    epoch index. One summary line is printed per epoch, including the LR used for that epoch.

    Returns:
        The best validation accuracy (percent) reached over all epochs.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer, scheduler = make_optimizer_and_scheduler(model, cfg)

    best_acc = 0.0
    for epoch in range(cfg.epochs):
        model.train()
        running_loss, running_correct, seen = 0.0, 0, 0
        start = time.time()
        # Record the LR this epoch trains with. After scheduler.step() the optimizer already
        # holds the next epoch's value.
        epoch_lr = optimizer.param_groups[0]["lr"]

        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad(set_to_none=True)
            logits = model(x)
            loss = criterion(logits, y)
            loss.backward()
            optimizer.step()

            # criterion returns the batch mean; re-weight by batch size for an exact epoch mean.
            running_loss += loss.item() * y.size(0)
            running_correct += (logits.argmax(1) == y).sum().item()
            seen += y.size(0)

        scheduler.step()
        train_loss, train_acc = running_loss/seen, 100.0*running_correct/seen
        val_loss, val_acc = evaluate(model, test_loader, criterion)

        if val_acc > best_acc:
            best_acc = val_acc
            os.makedirs("checkpoints", exist_ok=True)
            torch.save({"model": model.state_dict(), "cfg": cfg.__dict__, "epoch": epoch}, f"checkpoints/{exp_tag}_best.pth")

        # Scientific notation so warmup LRs (down to lr*1e-8) and small Adam LRs stay readable.
        print(f"[{epoch+1:03d}/{cfg.epochs}] "
              f"train_loss={train_loss:.4f} acc={train_acc:.2f}% | "
              f"val_loss={val_loss:.4f} acc={val_acc:.2f}% | "
              f"best={best_acc:.2f}% | lr={epoch_lr:.2e} | "
              f"time={time.time()-start:.1f}s")

    print("Best val acc:", best_acc)
    return best_acc


@torch.no_grad()
def evaluate(model: nn.Module, loader: DataLoader, criterion):
    """Return ``(mean_loss, accuracy_percent)`` of ``model`` over every batch in ``loader``.

    Runs without autograd and leaves the model in eval mode. Each batch's loss is multiplied by
    its size before averaging, which gives the exact dataset mean when ``criterion`` returns a
    batch mean (as the default ``nn.CrossEntropyLoss`` used by ``train`` does).
    """
    model.eval()
    loss_sum = 0.0
    hits = 0
    count = 0
    for inputs, targets in loader:
        inputs = inputs.to(device)
        targets = targets.to(device)
        outputs = model(inputs)
        batch = targets.size(0)
        loss_sum += criterion(outputs, targets).item() * batch
        hits += (outputs.argmax(1) == targets).sum().item()
        count += batch
    return loss_sum / count, 100.0 * hits / count

In [None]:
# Train every configuration in `cfgs`, in insertion order.
for key, model_config in cfgs.items():
    train_loader, test_loader = get_dataloaders(model_config)
    # Release the previous run's network before building the next one, so two models
    # are never resident on the device at the same time.
    model = None
    model = build_model(model_config)
    total_params, trainable_params = count_params(model)
    print(f"=== {key}: {total_params:,} params ({trainable_params:,} trainable) ===")
    tag = f"{model_config.name}_{model_config.dataset}"
    train(model, train_loader, test_loader, model_config, tag)