In [1]:
import os
import json
import math
from pathlib import Path
from typing import List, Dict, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torchvision.models import resnet50, resnet18, resnet101

In [2]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

def get_loaders(dataset_name, batch_size=128, test_batch_size=1000, data_root='./data'):
    """
    Returns: train_loader, test_loader, input_size, num_classes, meta (dict)
    """
    name = dataset_name.lower()
    meta = {}

    # Generic normalizations (safe defaults). If you want canonical stats, compute them once.
    NORM_1C = transforms.Normalize((0.5,), (0.5,))
    NORM_3C = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    if name == 'mnist':
        # (You already have this; included for completeness.)
        tfm = transforms.Compose([transforms.ToTensor(),
                                  transforms.Normalize((0.1307,), (0.3081,))])
        train = datasets.MNIST(data_root, train=True, download=True, transform=tfm)
        test  = datasets.MNIST(data_root, train=False, download=True, transform=tfm)
        inp, ncls = 28*28, 10

    elif name == 'fashionmnist':
        tfm = transforms.Compose([transforms.ToTensor(), NORM_1C])
        train = datasets.FashionMNIST(data_root, train=True, download=True, transform=tfm)
        test  = datasets.FashionMNIST(data_root, train=False, download=True, transform=tfm)
        inp, ncls = 28*28, 10
        
    elif name == 'cifar10':
        tfm = transforms.Compose([
                transforms.Resize(224),
                transforms.ToTensor(),
                NORM_3C
            ])
        train = datasets.CIFAR10(data_root, train=True,  download=True, transform=tfm)
        test  = datasets.CIFAR10(data_root, train=False, download=True, transform=tfm)
        inp, ncls = 224*224*3, 10

    elif name == 'cifar100':
        tfm = transforms.Compose([
                transforms.Resize(224),
                transforms.ToTensor(),
                NORM_3C
            ])
        train = datasets.CIFAR100(data_root, train=True,  download=True, transform=tfm)
        test  = datasets.CIFAR100(data_root, train=False, download=True, transform=tfm)
        inp, ncls = 224*224*3, 100
        
    else:
        raise ValueError(f"Unknown dataset: {dataset_name}")

    train_loader = DataLoader(train, batch_size=batch_size, shuffle=True,  num_workers=2, pin_memory=True)
    test_loader  = DataLoader(test,  batch_size=test_batch_size, shuffle=False, num_workers=2, pin_memory=True)
    return train_loader, test_loader, inp, ncls, meta


def train(model, device, train_loader, optimizer, criterion):
    model.train()
    train_loss = 0
    correct = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        
        train_loss += loss.item()
        pred = output.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()

    avg_loss = train_loss / len(train_loader)
    accuracy = 100. * correct / len(train_loader.dataset)
    return avg_loss, accuracy

def test(model, device, test_loader, criterion, times=1):
    model.eval()
    accuracy_list = []
    loss_list = []
    for _ in range(times):
        test_loss = 0
        correct = 0
        with torch.no_grad():
            for data, target in test_loader:
                data, target = data.to(device), target.to(device)
                output = model(data)
                test_loss += criterion(output, target).item()
                pred = output.argmax(dim=1, keepdim=True)
                correct += pred.eq(target.view_as(pred)).sum().item()
                
        test_loss /= len(test_loader.dataset)
        accuracy = 100. * correct / len(test_loader.dataset)
        accuracy_list.append(accuracy)
        loss_list.append(test_loss)
    if times == 1:
        return test_loss, accuracy
    else:
        return loss_list, accuracy_list, sum(accuracy_list) / times

In [None]:
datasets_name = 'cifar10'
train_loader, test_loader, input_size, num_classes, meta = get_loaders(datasets_name)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

cifar_10_model_configs = {
    'resnet18': {
        'model': resnet18,
        'pretrained': False,
        'input_size': input_size,
        'num_classes': num_classes,
        'lr': 0.01,
        'epochs': 10
    },
    'resnet50': {
        'model': resnet50,
        'pretrained': False,
        'input_size': input_size,
        'num_classes': num_classes,
        'lr': 3e-4,
        'epochs': 20,
    },
    'resnet101': {
        'model': resnet101,
        'pretrained': False,
        'input_size': input_size,
        'num_classes': num_classes,
        'lr': 3e-4,
        'epochs': 40,
    }
}

criterion = nn.CrossEntropyLoss()

In [None]:
for model_name, config in cifar_10_model_configs.items():
    model = config['model'](pretrained=config['pretrained'], num_classes=config['num_classes']).to(device)
    optimizer = optim.Adam(model.parameters())
    
    for epoch in range(1, 11):
        train_loss, train_acc = train(model, device, train_loader, optimizer, criterion)
        test_loss, test_acc = test(model, device, test_loader, criterion)
        print(f"Model: {model_name}, Epoch: {epoch}, Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}, Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.2f}")
        
    model_path = f"./models/{datasets_name}/{model_name}.pth"
    os.makedirs(os.path.dirname(model_path), exist_ok=True)
    torch.save(model.state_dict(), model_path)



Model: resnet18, Epoch: 1, Train Loss: 1.3475, Train Acc: 50.87, Test Loss: 0.0012, Test Acc: 57.43
Model: resnet18, Epoch: 2, Train Loss: 0.8452, Train Acc: 69.67, Test Loss: 0.0008, Test Acc: 71.87
Model: resnet18, Epoch: 3, Train Loss: 0.6285, Train Acc: 77.88, Test Loss: 0.0007, Test Acc: 76.46
Model: resnet18, Epoch: 4, Train Loss: 0.4978, Train Acc: 82.63, Test Loss: 0.0008, Test Acc: 72.92
Model: resnet18, Epoch: 5, Train Loss: 0.4060, Train Acc: 85.86, Test Loss: 0.0006, Test Acc: 79.86
Model: resnet18, Epoch: 6, Train Loss: 0.3159, Train Acc: 88.77, Test Loss: 0.0005, Test Acc: 82.69
Model: resnet18, Epoch: 7, Train Loss: 0.2430, Train Acc: 91.56, Test Loss: 0.0008, Test Acc: 76.91
Model: resnet18, Epoch: 8, Train Loss: 0.1775, Train Acc: 93.66, Test Loss: 0.0006, Test Acc: 80.82
Model: resnet18, Epoch: 9, Train Loss: 0.1293, Train Acc: 95.36, Test Loss: 0.0006, Test Acc: 82.66
Model: resnet18, Epoch: 10, Train Loss: 0.0931, Train Acc: 96.67, Test Loss: 0.0007, Test Acc: 82.19