<a href="https://colab.research.google.com/github/m4dhv/terafac/blob/main/level_2_code.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## **LEVEL 2: Intermediate Techniques (CIFAR-10)**

1. Importing dependencies and defining the augmentation pipelines

In [7]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
from torchvision import models
from tqdm import tqdm
import matplotlib.pyplot as plt
import pandas as pd

# Use the GPU when PyTorch can see one; otherwise everything runs on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Data-loading and schedule settings shared by every experiment.
BATCH_SIZE, NUM_WORKERS, EPOCHS = 128, 2, 12

# Per-channel statistics handed to Normalize in all three pipelines
# (presumably the CIFAR-10 training-set mean/std -- confirm the source).
mean = (0.4914, 0.4822, 0.4465)
std = (0.2470, 0.2435, 0.2616)


# Baseline training pipeline: upscale to 224x224, random flip, normalize.
train_transform_light = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ]
)

# Heavier training pipeline: random crop/flip/rotation/colour jitter on the
# image, then random erasing applied to the already-normalized tensor.
train_transform_strong = transforms.Compose(
    [
        transforms.RandomResizedCrop(224, scale=(0.7, 1.0)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
        transforms.RandomErasing(p=0.3, scale=(0.02, 0.2), ratio=(0.3, 3.3)),
    ]
)

# Evaluation pipeline: deterministic resize + normalize, no augmentation.
test_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ]
)


In [8]:
def get_loaders(train_transform):
    """Build the CIFAR-10 train and test DataLoaders.

    Args:
        train_transform: Transform pipeline applied to the training split.
            The test split always uses the module-level ``test_transform``.

    Returns:
        A ``(train_loader, test_loader)`` tuple. Only the training loader
        shuffles; both use ``BATCH_SIZE`` and ``NUM_WORKERS``.
    """
    data_root = "./data"
    # Settings common to both loaders; only `shuffle` differs between them.
    shared_kwargs = {"batch_size": BATCH_SIZE, "num_workers": NUM_WORKERS}

    # download=True fetches the dataset into data_root when it is not there yet.
    train_split = torchvision.datasets.CIFAR10(
        root=data_root, train=True, download=True, transform=train_transform
    )
    test_split = torchvision.datasets.CIFAR10(
        root=data_root, train=False, download=True, transform=test_transform
    )

    return (
        DataLoader(train_split, shuffle=True, **shared_kwargs),
        DataLoader(test_split, shuffle=False, **shared_kwargs),
    )


2. Model building and ResNet50 fine-tuning

In [9]:
def build_model(num_classes=10, unfreeze_layer4=True):
    """Create an ImageNet-pretrained ResNet50 set up for transfer learning.

    Args:
        num_classes: Output size of the replacement classification head.
        unfreeze_layer4: When true, the parameters of ``layer4`` stay
            trainable in addition to the new head; otherwise only the new
            head is trainable.

    Returns:
        The model, moved to the module-level ``device``.
    """
    net = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)

    # Start from a fully frozen backbone ...
    net.requires_grad_(False)

    # ... then optionally re-enable gradients for the last residual stage.
    if unfreeze_layer4:
        net.layer4.requires_grad_(True)

    # The new head is created after freezing, so its parameters are trainable
    # by default.
    head_inputs = net.fc.in_features
    net.fc = nn.Linear(head_inputs, num_classes)

    net = net.to(device)
    return net


3. Training and evaluation functions

In [10]:
def train_one_epoch(model, loader, optimizer, criterion):
    """Run one optimization pass over ``loader``.

    Args:
        model: Network to train; switched to training mode here.
        loader: Iterable yielding ``(images, labels)`` batches.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss callable taking ``(outputs, labels)``.

    Returns:
        ``(mean_loss, accuracy_percent)`` over all samples seen, where the
        loss is weighted by batch size.
    """
    model.train()
    loss_sum = 0
    n_correct = 0
    n_seen = 0

    progress = tqdm(loader, desc="Training", leave=False)
    for images, labels in progress:
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Weight the batch loss by batch size so the epoch value is a
        # per-sample average even when the last batch is smaller.
        loss_sum += loss.item() * images.shape[0]

        pred = torch.max(outputs, 1).indices
        n_correct += (pred == labels).sum().item()
        n_seen += labels.shape[0]

    mean_loss = loss_sum / n_seen
    accuracy = 100.0 * n_correct / n_seen
    return mean_loss, accuracy


def evaluate(model, loader, criterion):
    """Measure loss and accuracy on ``loader`` without tracking gradients.

    Args:
        model: Network to evaluate; switched to eval mode here.
        loader: Iterable yielding ``(images, labels)`` batches.
        criterion: Loss callable taking ``(outputs, labels)``.

    Returns:
        ``(mean_loss, accuracy_percent)`` over all samples seen, where the
        loss is weighted by batch size.
    """
    with torch.no_grad():
        model.eval()
        loss_sum = 0
        n_correct = 0
        n_seen = 0

        progress = tqdm(loader, desc="Evaluating", leave=False)
        for images, labels in progress:
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            # Batch-size weighting keeps the result a per-sample average.
            loss_sum += loss.item() * images.shape[0]

            pred = torch.max(outputs, 1).indices
            n_correct += (pred == labels).sum().item()
            n_seen += labels.shape[0]

        mean_loss = loss_sum / n_seen
        accuracy = 100.0 * n_correct / n_seen
        return mean_loss, accuracy


4. Experiment runner function for Ablation Study

In [11]:
def run_experiment(exp_name, train_transform, lr=3e-4, weight_decay=1e-4, unfreeze_layer4=True):
    """Train and evaluate one configuration for ``EPOCHS`` epochs.

    Args:
        exp_name: Label printed in the banner and embedded in the checkpoint
            file name (``best_<exp_name>.pth``).
        train_transform: Training-split transform passed to ``get_loaders``.
        lr: AdamW learning rate.
        weight_decay: AdamW weight decay.
        unfreeze_layer4: Forwarded to ``build_model``.

    Returns:
        ``(best_test_acc, history)`` where ``history`` maps "train_loss",
        "train_acc", "test_loss" and "test_acc" to per-epoch lists.

    Note:
        The "best" checkpoint is chosen by accuracy on the test loader, so
        the reported best accuracy is selected on the same data it measures.
    """
    banner = "=============================="
    print("\n" + banner)
    print(f"Running Experiment: {exp_name}")
    print(banner)

    train_loader, test_loader = get_loaders(train_transform)
    model = build_model(unfreeze_layer4=unfreeze_layer4)

    # Label smoothing is used here as a regularizer.
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

    # Only parameters left trainable by build_model are handed to the optimizer.
    params_to_update = [param for param in model.parameters() if param.requires_grad]
    optimizer = optim.AdamW(params_to_update, lr=lr, weight_decay=weight_decay)

    # Cosine decay of the learning rate across the whole run.
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

    history = {key: [] for key in ("train_loss", "train_acc", "test_loss", "test_acc")}
    best_test_acc = 0

    for epoch_number in range(1, EPOCHS + 1):
        print(f"\nEpoch {epoch_number}/{EPOCHS}")

        tr_loss, tr_acc = train_one_epoch(model, train_loader, optimizer, criterion)
        te_loss, te_acc = evaluate(model, test_loader, criterion)

        # The schedule advances once per epoch, after evaluation.
        scheduler.step()

        epoch_metrics = {
            "train_loss": tr_loss,
            "train_acc": tr_acc,
            "test_loss": te_loss,
            "test_acc": te_acc,
        }
        for key, value in epoch_metrics.items():
            history[key].append(value)

        print(f"Train Acc: {tr_acc:.2f}% | Test Acc: {te_acc:.2f}%")

        improved = te_acc > best_test_acc
        if not improved:
            continue

        # New best: remember it and checkpoint the weights.
        best_test_acc = te_acc
        torch.save(model.state_dict(), f"best_{exp_name}.pth")
        print("Saved best model")

    return best_test_acc, history


5. Results of Ablation study

In [None]:
# Rows of [experiment label, best test accuracy] used by the comparison table.
results = []

# 1) Baseline run: light augmentation (resize + horizontal flip only)
acc_light, hist_light = run_experiment(
    "light_aug",
    train_transform_light,
    lr=3e-4,
    weight_decay=1e-4,
    unfreeze_layer4=True,
)
results += [["Light Augmentation", acc_light]]

# 2) Comparison run: strong augmentation, all other settings identical
acc_strong, hist_strong = run_experiment(
    "strong_aug",
    train_transform_strong,
    lr=3e-4,
    weight_decay=1e-4,
    unfreeze_layer4=True,
)
results += [["Strong Augmentation", acc_strong]]



Running Experiment: light_aug

Epoch 1/12




Train Acc: 84.63% | Test Acc: 89.16%
Saved best model

Epoch 2/12




Train Acc: 93.84% | Test Acc: 91.08%
Saved best model

Epoch 3/12




Train Acc: 96.58% | Test Acc: 90.97%

Epoch 4/12




Train Acc: 98.06% | Test Acc: 91.27%
Saved best model

Epoch 5/12




Train Acc: 98.95% | Test Acc: 91.30%
Saved best model

Epoch 6/12




Train Acc: 99.28% | Test Acc: 91.82%
Saved best model

Epoch 7/12




Train Acc: 99.63% | Test Acc: 91.84%
Saved best model

Epoch 8/12




Train Acc: 99.83% | Test Acc: 92.28%
Saved best model

Epoch 9/12




Train Acc: 99.89% | Test Acc: 92.31%
Saved best model

Epoch 10/12




Train Acc: 99.94% | Test Acc: 92.64%
Saved best model

Epoch 11/12




Train Acc: 99.97% | Test Acc: 92.78%
Saved best model

Epoch 12/12




Train Acc: 99.98% | Test Acc: 92.67%

Running Experiment: strong_aug

Epoch 1/12




Train Acc: 80.59% | Test Acc: 91.05%
Saved best model

Epoch 2/12


Training:  81%|████████▏ | 318/391 [03:09<00:43,  1.67it/s]

6. Accuracy comparison table

In [None]:
# One row per experiment; row 0 (the first entry in `results`) is the baseline
# that every row's accuracy is compared against.
df = pd.DataFrame(data=results, columns=["Experiment", "Best Test Accuracy (%)"])
df["Improvement (%)"] = df["Best Test Accuracy (%)"].sub(
    df.at[0, "Best Test Accuracy (%)"]
)
df


7. Training curves visualization

In [None]:
def plot_history(history, title="Training Curves", save_path="results/level2_plots.png"):
    """Plot loss and accuracy curves for one experiment and save them to disk.

    Args:
        history: Dict with "train_loss", "test_loss", "train_acc" and
            "test_acc" lists holding one value per completed epoch (the
            structure returned by ``run_experiment``).
        title: Prefix used for the two subplot titles.
        save_path: Destination image file. Missing parent directories are
            created. Pass a distinct path per experiment so that successive
            calls do not overwrite each other's output.
    """
    import os  # local import: only needed to create the output directory

    # Take the x-axis from the recorded data instead of the global EPOCHS so
    # a shortened or interrupted run can still be plotted.
    epochs_range = range(1, len(history["train_loss"]) + 1)

    fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 5))

    loss_ax.plot(epochs_range, history["train_loss"], label="Train Loss")
    loss_ax.plot(epochs_range, history["test_loss"], label="Test Loss")
    loss_ax.set_title(title + " (Loss)")
    loss_ax.set_xlabel("Epoch")
    loss_ax.set_ylabel("Loss")
    loss_ax.legend()

    acc_ax.plot(epochs_range, history["train_acc"], label="Train Accuracy")
    acc_ax.plot(epochs_range, history["test_acc"], label="Test Accuracy")
    acc_ax.set_title(title + " (Accuracy)")
    acc_ax.set_xlabel("Epoch")
    acc_ax.set_ylabel("Accuracy (%)")
    acc_ax.legend()

    # savefig fails if the target directory does not exist yet.
    out_dir = os.path.dirname(save_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # Save BEFORE show(): once show() has displayed the figure it may be
    # closed, and a later savefig would write an empty canvas.
    fig.savefig(save_path)
    plt.show()
    plt.close(fig)


# Each experiment gets its own file; previously both calls wrote to the same
# path and the second overwrote the first.
plot_history(hist_light, "Light Augmentation", save_path="results/level2_plots_light.png")
plot_history(hist_strong, "Strong Augmentation", save_path="results/level2_plots_strong.png")
