# Fine-tuning AlexNet on Mini-ImageNet with Mouse-Calibrated Transforms

This notebook loads an ImageNet-pretrained AlexNet, freezes the first two convolutional blocks, and fine-tunes the remaining layers on Mini-ImageNet. Images are preprocessed with `mouse_transform` to emulate mouse visual statistics (blur + noise + optional grayscale), using the same settings as `scripts/train.py` (img_size=224, blur_sig=1.76, noise_std=0.25).

We will:
- Import the project utilities and set up reproducibility
- Define the mouse-calibrated `train` and `eval` transforms
- Load Mini-ImageNet train/val/test splits with these transforms
- Load ImageNet-pretrained AlexNet, freeze conv blocks 1–2, replace the classifier head
- Train from conv block 3 onwards, validate, and test
- Save the best model and visualize a batch


### Environment and setup
We add the repository to `sys.path` for imports, set seeds for reproducibility, and pick the best available device (GPU if present).


In [None]:
from pathlib import Path
import sys, os, random, time, math
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms
import matplotlib.pyplot as plt

# Add the repository root to the import path so `src.*` modules resolve.
# Override the default location with the MICE_REPO_ROOT environment variable
# (same pattern as DATA_ROOT / MINI_IMAGENET_ROOT in the dataset cell).
repo_root = Path(
    os.environ.get('MICE_REPO_ROOT', '/home/gamerio/Desktop/mousediet/mice-representation')
).resolve()
if str(repo_root) not in sys.path:
    sys.path.insert(0, str(repo_root))

# Reproducibility: seed every RNG used here (Python, NumPy, torch CPU and all GPUs).
seed = 3105
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
# cudnn: keep fast autotuned algorithms instead of forcing determinism.
# Consequence: GPU runs may not be bit-for-bit reproducible even with the seeds above.
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = True

# Device: prefer CUDA when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device


### Mouse-calibrated transforms
We use `mouse_transform` with parameters aligned to `scripts/train.py` lines 76–78: `img_size=224`, `blur_sig=1.76`, `noise_std=0.25`. With `to_gray=True`, luminance is computed using grayscale weights, but the output keeps three (RGB) channels. We also create an `UnNormalize` helper for visualization.


In [None]:
from src.pipeline.mouse_transforms import mouse_transform, UnNormalize

# Transform configuration (mirrors scripts/train.py lines 76–78).
IMG_SIZE    = 224
BLUR_SIG    = 1.76
NOISE_STD   = 0.25
TO_GRAY     = True
APPLY_BLUR  = True
APPLY_NOISE = True

# Options shared by the train and eval pipelines; only the split-specific
# flags passed in the two calls below differ.
_shared_tf_kwargs = dict(
    img_size=IMG_SIZE,
    blur_sig=BLUR_SIG,
    noise_std=NOISE_STD,
    to_gray=TO_GRAY,
    apply_blur=APPLY_BLUR,
    apply_noise=APPLY_NOISE,
)

# Training pipeline: train=True, supervised (self_supervised=False).
train_transform = mouse_transform(**_shared_tf_kwargs, train=True, self_supervised=False)

# Evaluation pipeline: train=False; self_supervised is left at mouse_transform's default.
eval_transform = mouse_transform(**_shared_tf_kwargs, train=False)

# Helper used later to undo the 'imagenet' normalization for display.
unnorm = UnNormalize('imagenet')


### Datasets and DataLoaders
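We load the Mini-ImageNet splits through `DataManager`, which returns `(image, label)` batches (`return_indices=False`). Set the `MINI_IMAGENET_ROOT` environment variable (or edit the default `DATA_ROOT` path) to point to your dataset directory. That directory must contain `train/`, `val/`, and `test/`.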


In [None]:
from src.datasets.DataManager import DataManager

# Dataset root: override with the MINI_IMAGENET_ROOT environment variable.
# The directory is expected to contain train/, val/ and test/ splits.
DATA_ROOT = Path(os.environ.get('MINI_IMAGENET_ROOT', '/home/gamerio/Documents/archive')).resolve()
if not DATA_ROOT.exists():
    raise FileNotFoundError(f"Set DATA_ROOT to your Mini-ImageNet path. Not found: {DATA_ROOT}")

# Loader config.
BATCH_SIZE  = 128
# os.cpu_count() returns None when it cannot be determined. Fall back to 1 worker
# (not 0): torch's DataLoader rejects persistent_workers=True and a custom
# prefetch_factor when num_workers == 0, both of which are requested below.
NUM_WORKERS = min(12, os.cpu_count() or 1)

# Use DataManager to create datasets and loaders with the mouse transforms.
# NOTE(review): assumes DataManager forwards the loader options to DataLoader.
dm = DataManager(
    data_path=str(DATA_ROOT),
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    train_transform=train_transform,
    eval_transform=eval_transform,
    use_cuda=(device.type == 'cuda'),
    persistent_workers=True,
    prefetch_factor=4,
    return_indices=False,
)
train_loader, val_loader, test_loader = dm.setup()

# Dataset info
num_classes = dm.num_classes
print(f"Detected {num_classes} classes")

len(dm.train_dataset), len(dm.val_dataset), len(dm.test_dataset)


### Peek at a batch
We visualize a grid of images after the mouse-calibrated transform. We un-normalize to restore displayable colors.


In [None]:
import torchvision.utils as vutils

def show_batch(loader, n=16):
    """Display the first ``n`` images of one batch from ``loader`` as a grid.

    Images are passed through the module-level ``unnorm`` helper and then
    clamped to [0, 1]. The additive noise can push pixels outside the
    displayable range, and matplotlib's imshow would otherwise clip them
    with a warning.

    Args:
        loader: iterable yielding ``(images, labels)`` batches of CPU tensors.
        n: maximum number of images to show.
    """
    imgs, _ = next(iter(loader))
    with torch.no_grad():
        imgs_disp = unnorm(imgs).clamp(0.0, 1.0)
    # Roughly square grid; guard against nrow == 0 for n < 1.
    nrow = max(1, int(math.sqrt(n)))
    grid = vutils.make_grid(imgs_disp[:n], nrow=nrow, padding=2, normalize=False)
    plt.figure(figsize=(8, 8))
    plt.axis('off')
    plt.title('Mouse-transformed batch (unnormalized)')
    # make_grid returns (C, H, W); imshow expects (H, W, C).
    plt.imshow(np.transpose(grid.numpy(), (1, 2, 0)))
    plt.show()


show_batch(train_loader)


### Model: ImageNet-pretrained AlexNet and freezing conv blocks 1–2
We load a pretrained AlexNet, replace the final classifier layer to match `num_classes`, and freeze parameters in the first two convolutional blocks (Conv1+ReLU+Pool and Conv2+ReLU+Pool). Fine-tuning begins from the third convolutional block onwards.


In [None]:
from torchvision.models import alexnet

# Load pretrained weights. Newer torchvision exposes `AlexNet_Weights` and the
# `weights=` argument; older releases only support the deprecated `pretrained=True`.
# Only a failed import selects the legacy path, so genuine errors (e.g. a failed
# weight download) surface instead of being masked by a second attempt.
try:
    from torchvision.models import AlexNet_Weights
except ImportError:
    model = alexnet(pretrained=True)
else:
    model = alexnet(weights=AlexNet_Weights.IMAGENET1K_V1)

# Replace the final classifier layer so its output size matches the dataset's
# class count (the new nn.Linear starts from PyTorch's default initialization).
in_features = model.classifier[-1].in_features
model.classifier[-1] = nn.Linear(in_features, num_classes)

# Freeze the first two convolutional blocks: in torchvision's AlexNet,
# features[0:6] is Conv1, ReLU, MaxPool, Conv2, ReLU, MaxPool.
for p in model.features[:6].parameters():
    p.requires_grad = False

model = model.to(device)

# Channels-last can speed up convs on modern GPUs.
if device.type == 'cuda':
    model = model.to(memory_format=torch.channels_last)

# Verify what will be optimized. Count scalar elements (numel) rather than
# parameter tensors; a tensor count alone badly understates the model size.
trainable_tensors = [p for p in model.parameters() if p.requires_grad]
frozen_tensors = [p for p in model.parameters() if not p.requires_grad]
num_trainable = sum(p.numel() for p in trainable_tensors)
num_frozen = sum(p.numel() for p in frozen_tensors)
print(
    f"Trainable params: {num_trainable:,} ({len(trainable_tensors)} tensors) | "
    f"Frozen params: {num_frozen:,} ({len(frozen_tensors)} tensors)"
)


### Training utilities
Simple PyTorch training/evaluation loops with AMP mixed precision and accuracy tracking. Only parameters with `requires_grad=True` are optimized.


In [None]:
from torch.cuda.amp import GradScaler
from tqdm import tqdm

def autocast_cm():
    """Return an autocast context manager for the module-level ``device``.

    Mixed precision is enabled only when ``device`` is CUDA; elsewhere the
    context is created with ``enabled=False``. ``torch.autocast`` (which accepts
    ``device_type``) is preferred. If constructing it raises (e.g. on older
    torch builds), fall back to ``torch.cuda.amp.autocast``.
    """
    use_amp = device.type == 'cuda'
    try:
        ctx = torch.autocast(device_type=device.type, enabled=use_amp)
    except Exception:
        from torch.cuda.amp import autocast as legacy_autocast
        ctx = legacy_autocast(enabled=use_amp)
    return ctx

# Loss, optimizer over trainable (non-frozen) parameters, LR schedule and AMP scaler.
criterion = nn.CrossEntropyLoss()
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable_params, lr=1e-4, weight_decay=1e-4)
# Multiply the LR by 0.1 every 10 scheduler steps (the loop steps once per epoch).
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
# Gradient scaling is only active on CUDA, where autocast is enabled.
scaler = GradScaler(enabled=(device.type == 'cuda'))


def accuracy(outputs: torch.Tensor, targets: torch.Tensor) -> float:
    """Return top-1 accuracy as a percentage.

    Args:
        outputs: per-class scores; the prediction is the argmax over dim 1.
        targets: integer class labels, one per row of ``outputs``.

    Returns:
        Percentage (0-100) of rows whose argmax equals the target label.
    """
    hits = outputs.argmax(dim=1).eq(targets)
    n_correct = int(hits.sum())
    return n_correct * 100.0 / targets.size(0)


def train_one_epoch(model, loader, optimizer, scaler):
    """Run one training epoch and return sample-weighted mean loss and accuracy.

    Uses the module-level ``criterion``, ``device``, ``accuracy`` and
    ``autocast_cm`` helpers.

    Args:
        model: network to train (switched to train mode here).
        loader: iterable of ``(images, labels)`` batches.
        optimizer: optimizer over the trainable parameters.
        scaler: ``GradScaler`` used for the mixed-precision backward/step.

    Returns:
        ``(mean_loss, mean_acc_percent)`` averaged over all samples, or
        ``(nan, nan)`` if ``loader`` yields no batches.
    """
    model.train()
    running_loss, running_acc, total = 0.0, 0.0, 0
    pbar = tqdm(loader, desc='Train', leave=False)
    for images, labels in pbar:
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        if device.type == 'cuda':
            # Match the channels_last layout the model is converted to on CUDA.
            images = images.contiguous(memory_format=torch.channels_last)
        optimizer.zero_grad(set_to_none=True)
        with autocast_cm():
            outputs = model(images)
            loss = criterion(outputs, labels)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        # Fetch the scalar once: every .item() forces a host/device sync.
        loss_val = loss.item()
        acc = accuracy(outputs.detach(), labels)
        bs = labels.size(0)
        running_loss += loss_val * bs
        running_acc += acc * bs
        total += bs
        pbar.set_postfix({"loss": f"{loss_val:.4f}", "acc": f"{acc:.2f}%"})
    if total == 0:
        # Empty loader: metrics are undefined; avoid ZeroDivisionError.
        return float('nan'), float('nan')
    return running_loss / total, running_acc / total


def evaluate(model, loader, desc='Val'):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Uses the module-level ``criterion``, ``device``, ``accuracy`` and
    ``autocast_cm`` helpers.

    Args:
        model: network to evaluate (switched to eval mode here).
        loader: iterable of ``(images, labels)`` batches.
        desc: label shown on the progress bar.

    Returns:
        ``(mean_loss, mean_acc_percent)`` weighted by batch size, or
        ``(nan, nan)`` if ``loader`` yields no batches.
    """
    model.eval()
    running_loss, running_acc, total = 0.0, 0.0, 0
    with torch.no_grad():
        pbar = tqdm(loader, desc=desc, leave=False)
        for images, labels in pbar:
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            if device.type == 'cuda':
                # Match the channels_last layout the model is converted to on CUDA.
                images = images.contiguous(memory_format=torch.channels_last)
            with autocast_cm():
                outputs = model(images)
                loss = criterion(outputs, labels)
            # Fetch the scalar once: every .item() forces a host/device sync.
            loss_val = loss.item()
            acc = accuracy(outputs, labels)
            bs = labels.size(0)
            running_loss += loss_val * bs
            running_acc += acc * bs
            total += bs
            pbar.set_postfix({"loss": f"{loss_val:.4f}", "acc": f"{acc:.2f}%"})
    if total == 0:
        # Empty loader: metrics are undefined; avoid ZeroDivisionError.
        return float('nan'), float('nan')
    return running_loss / total, running_acc / total


### Train, validate, and save best checkpoint
We train for `EPOCHS` epochs (15 by default) and track validation accuracy. The best-performing weights are saved to `artifacts/alexnet_mouse_miniimagenet_best.pth`.


In [None]:
ARTIFACTS_DIR = (repo_root / 'artifacts').resolve()
ARTIFACTS_DIR.mkdir(parents=True, exist_ok=True)
BEST_CKPT = ARTIFACTS_DIR / 'alexnet_mouse_miniimagenet_best.pth'

EPOCHS = 15
best_val_acc = -1.0  # below any achievable accuracy (0-100%)
history = []

for epoch in range(EPOCHS):
    print(f"\nEpoch {epoch+1}/{EPOCHS}")
    # Capture the LR used for *this* epoch before stepping the scheduler;
    # reading it after scheduler.step() would log the next epoch's LR.
    epoch_lr = scheduler.get_last_lr()[0]
    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, scaler)
    val_loss, val_acc = evaluate(model, val_loader, desc='Val')
    scheduler.step()

    history.append({
        'epoch': epoch+1,
        'train_loss': train_loss,
        'train_acc': train_acc,
        'val_loss': val_loss,
        'val_acc': val_acc,
        'lr': epoch_lr,
    })

    print(f"Train: loss={train_loss:.4f}, acc={train_acc:.2f}% | Val: loss={val_loss:.4f}, acc={val_acc:.2f}%")

    # Keep only the best-validation weights on disk.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save({'model': model.state_dict(), 'epoch': epoch+1, 'val_acc': best_val_acc}, BEST_CKPT)
        print(f"Saved new best checkpoint: {BEST_CKPT} (val_acc {best_val_acc:.2f}%)")

# Plot train vs. val loss and accuracy curves.
epochs_axis = [h['epoch'] for h in history]
plt.figure(figsize=(10, 4))
plt.subplot(1, 2, 1)
plt.plot(epochs_axis, [h['train_loss'] for h in history], label='train')
plt.plot(epochs_axis, [h['val_loss'] for h in history], label='val')
plt.title('Loss')
plt.xlabel('epoch')
plt.legend()
plt.subplot(1, 2, 2)
plt.plot(epochs_axis, [h['train_acc'] for h in history], label='train')
plt.plot(epochs_axis, [h['val_acc'] for h in history], label='val')
plt.title('Accuracy')
plt.xlabel('epoch')
plt.legend()
plt.show()


### Load best checkpoint and evaluate on test set
We load the best validation checkpoint and report test loss and accuracy.


In [None]:
# Restore the best-validation weights (when a checkpoint was written) before testing.
# torch.load unpickles data; only load checkpoints produced by this notebook.
if not BEST_CKPT.exists():
    print(f"Best checkpoint not found at: {BEST_CKPT}. Using current model weights.")
else:
    ckpt = torch.load(BEST_CKPT, map_location=device)
    model.load_state_dict(ckpt['model'])
    best_epoch = ckpt.get('epoch', '?')
    best_val = ckpt.get('val_acc', float('nan'))
    print(f"Loaded best checkpoint from epoch {best_epoch}, val_acc={best_val:.2f}%")

# Final held-out evaluation on the test split.
test_loss, test_acc = evaluate(model, test_loader, desc='Test')
print(f"Test: loss={test_loss:.4f}, acc={test_acc:.2f}%")


### Weight differences vs pretrained AlexNet
This cell reloads a fresh ImageNet-pretrained AlexNet and compares its parameters with the current (fine-tuned) `model`. For each parameter with matching shape, it prints:
- L2 norm of the difference
- Relative L2 (L2 difference divided by pretrained L2)
- Mean absolute difference
- Max absolute difference

Parameters with shape mismatches (e.g., replaced classifier head) are noted and skipped.


In [None]:
from torchvision.models import alexnet

# Load a fresh pretrained baseline to compare against. Only a missing
# AlexNet_Weights (older torchvision) selects the legacy `pretrained=True` path;
# other failures (e.g. a failed download) are not masked.
try:
    from torchvision.models import AlexNet_Weights
except ImportError:
    baseline = alexnet(pretrained=True)
else:
    baseline = alexnet(weights=AlexNet_Weights.IMAGENET1K_V1)

baseline_state = baseline.state_dict()
finetuned_state = model.state_dict()

# One row per tensor: (name, L2 of diff, relative L2, mean|Δ|, max|Δ|, note).
# Numeric fields are None when the tensor could not be compared.
rows = []
for name, ft_w in finetuned_state.items():
    if name not in baseline_state:
        rows.append((name, None, None, None, None, 'missing_in_baseline'))
        continue
    base_w = baseline_state[name]
    if base_w.shape != ft_w.shape:
        rows.append((name, None, None, None, None, f'shape_mismatch {tuple(base_w.shape)} vs {tuple(ft_w.shape)}'))
        continue
    # Compare in float32 on CPU so device/dtype differences don't matter.
    base_f = base_w.detach().float().cpu()
    d = ft_w.detach().float().cpu() - base_f
    l2 = d.norm().item()
    rel = l2 / (base_f.norm().item() + 1e-12)  # epsilon guards all-zero baselines
    abs_d = d.abs()
    rows.append((name, l2, rel, abs_d.mean().item(), abs_d.max().item(), ''))

# Also flag baseline tensors that have no counterpart in the fine-tuned model.
for name in sorted(baseline_state.keys() - finetuned_state.keys()):
    rows.append((name, None, None, None, None, 'missing_in_finetuned'))

# Pretty print as a simple table
hdr = f"{'parameter':60s} {'L2':>12s} {'rel_L2':>12s} {'mean|Δ|':>12s} {'max|Δ|':>12s}  note"
print(hdr)
print('-' * len(hdr))
for name, l2, rel, mean_abs, max_abs, note in rows:
    if l2 is None:
        print(f"{name:60s} {'-':>12} {'-':>12} {'-':>12} {'-':>12}  {note}")
    else:
        print(f"{name:60s} {l2:12.6f} {rel:12.6f} {mean_abs:12.6f} {max_abs:12.6f}  {note}")
