# Finetune a ResNet image classifier

## Load data

!curl -L https://storage.googleapis.com/aiolympiadmy/ioai-2025-tsp/finetuning-resnet.zip -o data.zip

!unzip data.zip

## Establishing a baseline

Load an ImageNet pre-trained ResNet34, and check it's performance on the images in `data/test`.  

Use accuracy, precision and recall as your metrics for performance.

In [15]:
import torch
from torchvision import datasets, transforms, models
import os

# --- Baseline Evaluation ---
device = torch.device("cpu")

# Define transforms
data_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Load test dataset
test_dataset = datasets.ImageFolder(
    os.path.join(r"C:\Users\ochon\OneDrive\Documents\2025\MAIO (Malaysia AI Olympiad)\Training\finetuning-resnet", "test"),
    transform=data_transforms
)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False)

# Load pretrained model
model = models.resnet34(pretrained=True)
model = model.to(device)
model.eval()

# Get ImageNet class names
imagenet_classes = test_dataset.classes  # This is WRONG - we need actual ImageNet labels

# FIX: Manually load ImageNet class labels
with open("imagenet_classes.txt", "r") as f:
    imagenet_classes = [line.strip() for line in f.readlines()]

correct = 0
total = 0

with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        
        # Map batch predictions to ImageNet class names
        predicted_classes = [imagenet_classes[p] for p in predicted.cpu().numpy()]
        
        # Map true labels to folder names
        true_classes = [test_dataset.classes[l] for l in labels.numpy()]
        
        # Compare if any word in predicted class matches folder name
        for pred, true in zip(predicted_classes, true_classes):
            if true.lower() in pred.lower():
                correct += 1
            total += 1

print(f"Baseline Accuracy: {correct/total:.4f}")

Baseline Accuracy: 0.5000


## Finetuning

Finetune this ResNet using data in `data/train`. Use `data/test` as your testing set, and use cross entropy loss. The rest is up to you. Run finetuning that terminates within approx 10 mins. Store the following info every 10 minibatches: loss, precision, recall and accuracy on train and test datasets.

In [16]:
# Define full transforms
data_transforms = {
    'train': transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'test': transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

# Reload datasets with train data
image_datasets = {
    x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])
    for x in ['train', 'test']
}
dataloaders = {
    x: torch.utils.data.DataLoader(image_datasets[x], batch_size=32, shuffle=(x=='train'), num_workers=4)
    for x in ['train', 'test']
}

# Modify model for 2-class classification
model = models.resnet34(pretrained=True)
for param in model.parameters():
    param.requires_grad = False
num_ftrs = model.fc.in_features
model.fc = torch.nn.Linear(num_ftrs, 2)
model = model.to(device)

# Training setup
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.fc.parameters(), lr=0.001)

# Training loop
num_epochs = 5
for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")
    print("-" * 10)
    
    # Training phase
    model.train()
    for batch_idx, (inputs, labels) in enumerate(dataloaders['train']):
        inputs, labels = inputs.to(device), labels.to(device)
        
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        # Log metrics every 10 batches
        if (batch_idx + 1) % 10 == 0:
            with torch.no_grad():
                _, preds = torch.max(outputs, 1)
                acc = (preds == labels).float().mean()
                
                # Test evaluation
                model.eval()
                test_acc, test_precision, test_recall = 0.0, 0.0, 0.0
                for test_inputs, test_labels in dataloaders['test']:
                    test_outputs = model(test_inputs.to(device))
                    _, test_preds = torch.max(test_outputs, 1)
                    test_acc += (test_preds == test_labels.to(device)).float().mean()
                
                test_acc /= len(dataloaders['test'])
                print(f"Batch {batch_idx+1} | Loss: {loss.item():.4f} | Train Acc: {acc:.4f} | Test Acc: {test_acc:.4f}")
                model.train()

# Save model
torch.save(model.state_dict(), "finetuned_resnet.pth")

Epoch 1/5
----------
Batch 10 | Loss: 0.3534 | Train Acc: 0.9688 | Test Acc: 1.0000
Batch 20 | Loss: 0.1807 | Train Acc: 1.0000 | Test Acc: 1.0000
Batch 30 | Loss: 0.0661 | Train Acc: 1.0000 | Test Acc: 1.0000
Epoch 2/5
----------
Batch 10 | Loss: 0.0503 | Train Acc: 1.0000 | Test Acc: 1.0000
Batch 20 | Loss: 0.0278 | Train Acc: 1.0000 | Test Acc: 1.0000
Batch 30 | Loss: 0.0831 | Train Acc: 1.0000 | Test Acc: 1.0000
Epoch 3/5
----------
Batch 10 | Loss: 0.0254 | Train Acc: 1.0000 | Test Acc: 1.0000
Batch 20 | Loss: 0.0305 | Train Acc: 1.0000 | Test Acc: 1.0000
Batch 30 | Loss: 0.0438 | Train Acc: 1.0000 | Test Acc: 1.0000
Epoch 4/5
----------
Batch 10 | Loss: 0.0387 | Train Acc: 1.0000 | Test Acc: 1.0000
Batch 20 | Loss: 0.0155 | Train Acc: 1.0000 | Test Acc: 1.0000
Batch 30 | Loss: 0.0236 | Train Acc: 1.0000 | Test Acc: 1.0000
Epoch 5/5
----------
Batch 10 | Loss: 0.0084 | Train Acc: 1.0000 | Test Acc: 1.0000
Batch 20 | Loss: 0.0507 | Train Acc: 0.9688 | Test Acc: 1.0000
Batch 30 | Lo

## Writeup

Summarize what you did above, as well as detail the choices you made and why. What was the outcome?

## Further analysis

Pick one aspect about the work done above thus far that you find interesting, investigate it a bit further, and give a short paragraph writeup of what you investigated and how it went.