# Fine-Tuning the Baseline Model on a Small Real Dataset

This notebook fine-tunes the pretrained ResNet18 baseline model on a small, labeled dataset of real images to improve its accuracy on the real domain. The checkpoint with the highest validation accuracy is saved.

In [ ]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
import copy

# --- Data locations ------------------------------------------------------
REAL_TRAIN_DIR = 'data/real/train'  # small labeled set of real images for fine-tuning
REAL_VAL_DIR = 'data/real/val'      # held-out real images for validation

# --- Checkpoints -----------------------------------------------------------
PRETRAINED_MODEL_PATH = 'baseline_resnet18.pth'       # baseline weights to start from
FINETUNED_MODEL_PATH = 'finetuned_real_resnet18.pth'  # destination for fine-tuned weights

# --- Optimization hyperparameters ------------------------------------------
BATCH_SIZE = 16
NUM_EPOCHS = 10
LEARNING_RATE = 1e-4

# --- Hardware: prefer the GPU when one is visible to PyTorch -----------------
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [ ]:
# Data Transforms
# ImageNet channel statistics, shared by both pipelines so they cannot drift apart.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training pipeline: fixed-size resize plus random horizontal flips for augmentation.
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Validation pipeline: same preprocessing, but deterministic (no augmentation).
val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Datasets & Loaders
train_dataset = datasets.ImageFolder(REAL_TRAIN_DIR, transform=train_transform)
val_dataset = datasets.ImageFolder(REAL_VAL_DIR, transform=val_transform)

# ImageFolder derives label indices from the sorted sub-folder names of each root
# independently. If the two roots do not contain exactly the same class folders,
# the same index would mean different classes and validation accuracy would be
# silently wrong, so fail loudly instead.
if val_dataset.class_to_idx != train_dataset.class_to_idx:
    raise ValueError(
        f'Class folders differ between {REAL_TRAIN_DIR!r} and {REAL_VAL_DIR!r}: '
        f'train={train_dataset.classes}, val={val_dataset.classes}'
    )

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

classes = train_dataset.classes
print(f'Classes: {classes}')

In [ ]:
# Load pretrained baseline model
# Build a bare ResNet18 (no ImageNet weights; everything is overwritten by the
# checkpoint below) and resize the classifier head to this dataset's class count.
# NOTE(review): assumes the baseline was trained on the same class list;
# load_state_dict (strict by default) raises if the fc shapes do not match.
model = models.resnet18(weights=None)
model.fc = nn.Linear(model.fc.in_features, len(classes))

# weights_only=True restricts unpickling to tensors and plain containers, which is
# all a saved state_dict needs, so a tampered checkpoint cannot run arbitrary code.
baseline_state = torch.load(PRETRAINED_MODEL_PATH, map_location=DEVICE, weights_only=True)
model.load_state_dict(baseline_state)
model = model.to(DEVICE)

In [ ]:
# Choose which parameters to fine-tune.
# FREEZE_BACKBONE = False (default): every layer of the network is updated.
# FREEZE_BACKBONE = True: only the final classifier (`model.fc`) is updated.
# NOTE: freezing only stops gradient updates; BatchNorm running statistics in the
# backbone still change whenever the model is in train() mode.
FREEZE_BACKBONE = False

for param_name, param in model.named_parameters():
    param.requires_grad = (not FREEZE_BACKBONE) or param_name.startswith('fc.')

criterion = nn.CrossEntropyLoss()
# Give the optimizer only the parameters that will actually receive gradients.
optimizer = optim.Adam(
    [p for p in model.parameters() if p.requires_grad],
    lr=LEARNING_RATE,
)

In [ ]:
# Fine-tune for NUM_EPOCHS epochs. The epoch with the highest validation accuracy is
# tracked: its weights are written to FINETUNED_MODEL_PATH whenever they improve and
# are loaded back into `model` after the last epoch.


def _run_epoch(net, loader, loss_fn, opt=None):
    """Run one full pass of *net* over *loader* and return (mean loss, accuracy).

    When *opt* is given, *net* is put in training mode and updated after every
    batch. Otherwise it runs in eval mode with autograd disabled.

    Both returned values are Python floats averaged over every sample seen.
    """
    training = opt is not None
    net.train(training)  # train(False) is equivalent to eval()
    total_loss, total_correct, total_seen = 0.0, 0, 0
    with torch.set_grad_enabled(training):
        for inputs, labels in loader:
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
            outputs = net(inputs)
            loss = loss_fn(outputs, labels)
            if training:
                opt.zero_grad()
                loss.backward()
                opt.step()
            batch_size = inputs.size(0)
            # criterion returns the batch mean, so re-weight by batch size.
            total_loss += loss.item() * batch_size
            total_correct += (outputs.argmax(dim=1) == labels).sum().item()
            total_seen += batch_size
    return total_loss / total_seen, total_correct / total_seen


best_acc = 0.0
best_epoch = 0  # stays 0 if no epoch beats the initial best_acc
best_model_wts = copy.deepcopy(model.state_dict())

for epoch in range(1, NUM_EPOCHS + 1):
    print(f'\nEpoch {epoch}/{NUM_EPOCHS}')
    print('-' * 20)

    # Training
    epoch_loss, epoch_acc = _run_epoch(model, train_loader, criterion, optimizer)
    print(f'Train Loss: {epoch_loss:.4f} | Train Acc: {epoch_acc:.4f}')

    # Validation
    val_loss, val_acc = _run_epoch(model, val_loader, criterion)
    print(f'Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}')

    if val_acc > best_acc:
        best_acc = val_acc
        best_epoch = epoch
        best_model_wts = copy.deepcopy(model.state_dict())
        torch.save(best_model_wts, FINETUNED_MODEL_PATH)
        print(f'Saved best fine-tuned model with Val Acc: {best_acc:.4f}')

# Without this, `model` would keep the last epoch's weights instead of the best ones.
model.load_state_dict(best_model_wts)

print(f'\nFine-tuning complete. Best Val Acc: {best_acc:.4f}')
if best_epoch == 0:
    print(f'Warning: no epoch improved on Val Acc {best_acc:.4f}; '
          f'{FINETUNED_MODEL_PATH} was not written and `model` keeps its starting weights.')
else:
    print(f'Restored weights from epoch {best_epoch}.')