# Transfer Learning - Vision
### This notebook applies transfer learning: a ResNet image-recognition model pretrained on ImageNet is retrained to classify CIFAR-10  
### CIFAR-10 is an image recognition dataset with 10 classes:
Airplanes, cars, birds, cats, deer, dogs, frogs, horses, ships, and trucks (lorries).

Note - This notebook runs much faster on a GPU. Google Colab's free GPU is one option if you don't have one.

In [None]:
### Run this if needed
# !pip install torch torchvision


In [None]:

import torch
from torch import nn, optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, Subset


In [None]:

# Step 1: Set the device (GPU if available, otherwise CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


Before training, we define how each image should be preprocessed.  

- **Resize(224×224)**: ResNet was trained on ImageNet images of size 224×224, so we resize our CIFAR-10 images to match.  
- **ToTensor()**: Converts the image from a PIL format into a PyTorch tensor so it can be fed into the model.  

This preprocessing puts our data into the image size and tensor format that the ResNet model accepts.  


In [None]:

# Step 2: Define image transformations
transform = transforms.Compose([
    transforms.Resize((224, 224)),   # Resize images to 224x224 (ResNet expects this)
    transforms.ToTensor(),           # Convert images to PyTorch tensors
])


We load a **small subset** of the CIFAR-10 dataset for speed.  

- Training set: 5,000 images (instead of the full 50,000).  
- Test set: 1,000 images (instead of the full 10,000).  

This makes the demo much quicker to run while still showing the training process.  
We also create **DataLoaders** to handle batching and shuffling.  

👉 CIFAR-10 has 10 classes (airplane, car, bird, etc.), which makes it a good benchmark for image classification.  


Note - decision in next cell - how many samples to use

In [None]:

# Step 3: Load a small subset of CIFAR-10 for quick training
train_dataset = Subset(
    datasets.CIFAR10(root="./data", train=True, download=True, transform=transform),
    range(5000)  # first 5,000 images; use range(500) for a faster but less accurate run
)
test_dataset = Subset(
    datasets.CIFAR10(root="./data", train=False, download=True, transform=transform),
    range(1000)  # first 1,000 images; use range(100) for a faster but noisier test
)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=16)
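
To see what `Subset` and `DataLoader` do without downloading anything, here is a small sketch on synthetic data. The tensors and the names `full_dataset`, `small_dataset` and `loader` are made up for illustration; only the batching behaviour carries over to the CIFAR-10 cell above.

```python
import torch
from torch.utils.data import DataLoader, Subset, TensorDataset

# Synthetic stand-in for CIFAR-10: 100 random "images" with labels 0-9
fake_images = torch.rand(100, 3, 32, 32)
fake_labels = torch.randint(0, 10, (100,))
full_dataset = TensorDataset(fake_images, fake_labels)

# Keep only the first 40 samples, the same way Subset(..., range(5000)) works
small_dataset = Subset(full_dataset, range(40))
loader = DataLoader(small_dataset, batch_size=16, shuffle=True)

print(len(small_dataset))   # 40 samples
print(len(loader))          # 3 batches

# The last batch is smaller because 40 is not a multiple of 16
batch_sizes = [images.shape[0] for images, labels in loader]
print(batch_sizes)          # [16, 16, 8]
```

With 5,000 training images and `batch_size=16`, the real training loader yields 313 batches per epoch, the last one holding 8 images.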


In [None]:
# ✅ Optional: Show the first training image. CIFAR-10 images are only 32x32 pixels, so they look blurry when resized to 224x224.
import matplotlib.pyplot as plt

img, label = train_dataset[0]           # get the first image and label. Change to see a different image
img = img.permute(1, 2, 0)              # convert from C x H x W to H x W x C for plotting
plt.imshow(img)                          # display the image
plt.title(f"Label: {label}")             # show the label
plt.axis('off')                          # hide axes
plt.show()

We load **ResNet18** with pretrained weights from ImageNet.  

- ResNet18 is a well-known convolutional neural network.  
- Pretrained weights mean it has already learned useful visual features (edges, shapes, textures, etc.) from millions of images.  

We will **reuse** this feature extractor instead of training from scratch, which is called **transfer learning**.  


In [None]:

# Step 4: Load the pre-trained ResNet18 base model
# This is the **base model / backbone**: it has been trained on ImageNet
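# `pretrained=True` is deprecated from torchvision 0.13 onwards; `weights=` replaces it
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)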


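Next we decide which parameters to train. There are two options:  

- **Freeze the whole backbone (Step 5)**: no gradients are calculated for ResNet18, and only the new classifier head added in Step 6 is trained. This is the fastest option.  
- **Also unfreeze the last block (Step 5 alternative)**: `layer4` is fine-tuned along with the head. This runs more slowly but can give higher accuracy.  

Freezing is efficient because the model already knows general features; we only need to adapt it to CIFAR-10 classes.  


Note - decision in next cells - freeze the whole backbone (transfer learning) or also unfreeze the last block (fine-tuning). As written, the first cell is commented out, so the alternative is what runs.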

In [None]:

# Step 5: Freeze the backbone so we only train the new head
# Freezing the backbone is transfer learning: we reuse pre-trained features
# Uncomment the loop below to use this option, and skip the alternative cell that follows
#for param in model.parameters():
#    param.requires_grad = False  # no gradients computed for these parameters


In [None]:

### Step 5 ALTERNATIVE  if higher accuracy wanted, but will run more slowly:
for name, param in model.named_parameters():
    if "layer4" in name or "fc" in name:  # fine-tune last ResNet block + head
        param.requires_grad = True
    else:
        param.requires_grad = False
### This allows the last ResNet block (layer4) to be fine-tuned, rather than just the head


ResNet18 was originally trained to classify **1,000 ImageNet classes**.  
We replace its final fully connected (fc) layer with a new one for **10 CIFAR-10 classes**.  


In [None]:

# Step 6: Replace the original ResNet head with a new classifier for our task
# This is the **new head** we are training on CIFAR-10
model.fc = nn.Linear(model.fc.in_features, 10)  # CIFAR-10 has 10 classes
model.to(device)  # move the model to GPU or CPU


We use **cross-entropy loss**, the standard choice for multi-class classification.  

This function takes the model's raw output scores (logits), applies a log-softmax internally to turn them into a predicted probability distribution, and penalises the model when the true class is given a low probability.  
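
In PyTorch, `nn.CrossEntropyLoss` expects raw logits and applies the log-softmax itself, so the model should not end in a softmax layer. The sketch below checks this on made-up logits and labels by recomputing the loss by hand.

```python
import torch
from torch import nn
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 10)            # raw model outputs: 4 samples, 10 classes
labels = torch.tensor([3, 0, 9, 3])    # true class indices

loss = nn.CrossEntropyLoss()(logits, labels)

# Same value by hand: mean negative log-probability of the true class
log_probs = F.log_softmax(logits, dim=1)
manual = -log_probs[torch.arange(4), labels].mean()

print(torch.allclose(loss, manual))    # True
```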


We use the **Adam optimiser** to update the model parameters.  

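Notice that we only pass in the parameters that have **`requires_grad=True`**:  
- With the backbone fully frozen, this is just the new head (`model.fc`).  
- With the Step 5 alternative, it also includes `layer4`, so that block is actually fine-tuned.  
- Learning rate is set to 1e-3, a typical starting value for Adam.  


In [None]:

# Step 7: Define the loss function (cross-entropy for classification)
criterion = nn.CrossEntropyLoss()

# Step 8: Define the optimiser, updating only the trainable (unfrozen) parameters
# Passing model.fc.parameters() alone would leave an unfrozen layer4 unchanged
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.Adam(trainable_params, lr=1e-3)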


We train the model for a small number of epochs (e.g. 2–10).  

Each iteration does the following:  
1. Load a batch of images and labels.  
2. Perform a forward pass through the model.  
3. Compute the loss.  
4. Backpropagate the gradients.  
5. Update the parameters passed to the optimiser (Adam).  

We also print the **running average loss every 10 batches** and an **epoch summary** at the end.  
This helps track training progress and ensures nothing is going wrong.  
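
The five steps above can be sketched on a toy model that runs in milliseconds. Everything here is a stand-in: `backbone` and `head` are small linear layers, not ResNet, and the data is random. The point is only to show that one optimiser step changes the trainable part and leaves the frozen part alone.

```python
import torch
from torch import nn, optim

torch.manual_seed(0)
# Toy stand-in: a frozen "backbone" layer followed by a trainable "head"
backbone = nn.Linear(8, 4)
head = nn.Linear(4, 3)
toy_model = nn.Sequential(backbone, nn.ReLU(), head)
for param in backbone.parameters():
    param.requires_grad = False

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(
    [p for p in toy_model.parameters() if p.requires_grad], lr=1e-3
)

backbone_before = backbone.weight.detach().clone()
head_before = head.bias.detach().clone()

inputs = torch.randn(16, 8)              # 1. a batch of inputs and labels
targets = torch.randint(0, 3, (16,))
optimizer.zero_grad()                    #    clear gradients from any earlier step
outputs = toy_model(inputs)              # 2. forward pass
loss = criterion(outputs, targets)       # 3. compute the loss
loss.backward()                          # 4. backpropagate
optimizer.step()                         # 5. update the trainable parameters

backbone_unchanged = torch.equal(backbone.weight, backbone_before)
head_changed = not torch.equal(head.bias, head_before)
print(backbone_unchanged, head_changed)  # True True
```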


Note - decision in next cell - how many epochs to use

In [None]:

# Step 9: Training loop (use 2 epochs for a quick demo; 10 gives better accuracy)
for epoch in range(10):
    model.train()  # set model to training mode
    running_loss = 0  # accumulate loss per epoch
    for i, (images, labels) in enumerate(train_loader): # Enumerating so we can print tracking
        images, labels = images.to(device), labels.to(device)  # move data to device
        optimizer.zero_grad()          # reset gradients
        outputs = model(images)        # forward pass
        loss = criterion(outputs, labels)  # compute loss
        loss.backward()                # backpropagation
        optimizer.step()               # update the optimiser's parameters

        running_loss += loss.item()
        # Print the running average loss every 10 batches
        if (i + 1) % 10 == 0:
            print(f"Epoch {epoch+1}, Batch {i+1}, Running Avg Loss: {running_loss / (i+1):.4f}")

    # Print epoch summary for tracking
    print(f"✅ Epoch {epoch+1} completed, Average Loss: {running_loss / len(train_loader):.4f}\n")



After training, we evaluate the model on the test subset.  

- Set the model to evaluation mode (`model.eval()`).  
- Disable gradients (`torch.no_grad()`), since we're not training.  
- Make predictions on each test batch and count how many are correct.  


In [None]:

# Step 10: Quick evaluation on the test subset
model.eval()  # set model to evaluation mode
correct = 0
with torch.no_grad():  # no gradients needed for evaluation
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        preds = model(images).argmax(dim=1)  # get predicted class
        correct += (preds == labels).sum().item()


Finally, we calculate and print the **test accuracy**.  

This gives a quick sense of how well our transfer learning worked on CIFAR-10.  


In [None]:

# Step 11: Print test accuracy
print("Test Accuracy:", correct / len(test_dataset))
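
A single accuracy figure can hide classes the model struggles with. The sketch below breaks accuracy down per class using made-up predictions and labels. In the notebook, `all_preds` and `all_labels` would have to be collected in the Step 10 loop (for example by appending each batch and concatenating), which is an assumption and not something the cells above do.

```python
import torch

# CIFAR-10 class names in label order (0-9), as used by torchvision
CLASS_NAMES = ["airplane", "automobile", "bird", "cat", "deer",
               "dog", "frog", "horse", "ship", "truck"]

# Hypothetical labels and predictions, for illustration only
all_labels = torch.tensor([0, 0, 1, 1, 3, 3, 3, 5, 5, 9])
all_preds = torch.tensor([0, 0, 1, 9, 3, 5, 3, 3, 5, 9])

num_classes = len(CLASS_NAMES)
# How many samples each class has, and how many of those were predicted correctly
totals = torch.bincount(all_labels, minlength=num_classes)
hits = torch.bincount(all_labels[all_preds == all_labels], minlength=num_classes)

for idx, name in enumerate(CLASS_NAMES):
    if totals[idx] > 0:   # skip classes absent from this sample
        print(f"{name:>10}: {hits[idx].item()}/{totals[idx].item()} correct")

overall = (all_preds == all_labels).float().mean().item()
print(f"Overall accuracy: {overall:.2%}")   # Overall accuracy: 70.00%
```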


If your accuracy is low:

*   Add more training data in Step 3
*   Unfreeze the last ResNet block (`layer4`) with the Step 5 alternative
*   Add more epochs (try 10) in Step 9

