# 0. What is Transfer Learning?
Transfer learning is a machine learning technique where a pre-trained model, which has been trained on a large dataset, is reused as the starting point for a new task. Instead of training a model from scratch, you leverage the knowledge the pre-trained model has already learned, which can significantly reduce the amount of data and computational resources required for the new task. This is particularly useful in deep learning applications such as image classification, where models like VGG, ResNet, or Inception can be fine-tuned for specific tasks.

In [4]:
import torch
import torchvision

# Check if a GPU is available and use it, otherwise fallback to CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}')

Using device: cuda


In [None]:
print(torch.__version__)
print(torchvision.__version__)

2.5.1
0.20.1


For this course, we want `torch` 1.12+ and `torchvision` 0.13+.

Now let's import the code we've written in the previous sections so that we don't have to write it all again.

In [11]:
# Continue with regular imports
import matplotlib.pyplot as plt
import torch
import torchvision
import os
import requests
from pathlib import Path
import zipfile

from torch import nn
from torchvision import transforms
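from torchinfo import summary  # prints model summaries (install with `pip install torchinfo`)

# Helper scripts written in 05. PyTorch Going Modular
from going_modular import data_setup, engine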

## 1. Get data
We need the pizza, steak and sushi data to build a transfer learning model.

In [15]:

data_path = Path("data/")
image_path = data_path / "pizza_steak_sushi"

if image_path.is_dir():
    print("Data already downloaded")
else:
    print(f"Did not find {image_path} directory, creating one...")
    image_path.mkdir(parents=True, exist_ok=True)
    
    # Download pizza, steak, sushi data
    with open(data_path / "pizza_steak_sushi.zip", "wb") as f:
        print("Downloading pizza, steak, sushi data...")
        request = requests.get("https://github.com/mrdbourke/pytorch-deep-learning/raw/main/data/pizza_steak_sushi.zip")
        request.raise_for_status()  # fail loudly instead of saving an HTTP error page as the .zip
        f.write(request.content)

    # Unzip pizza, steak, sushi data
    with zipfile.ZipFile(data_path / "pizza_steak_sushi.zip", "r") as zip_ref:
        print("Unzipping pizza, steak, sushi data...") 
        zip_ref.extractall(image_path)

    # Remove .zip file
    os.remove(data_path / "pizza_steak_sushi.zip")

Data already downloaded


In [16]:
train_dir = image_path / "train"
test_dir = image_path / "test"

## 2. Create Datasets and DataLoaders

Now that we've got some data, we want to turn it into PyTorch DataLoaders.

To do so, we can use the `create_dataloaders()` function from `going_modular/data_setup.py`, which we made in 05. PyTorch Going Modular.

One thing to think about when loading: how should we **transform** the data? With `torchvision` 0.13+, there are two ways to do it:
1. Manually created transforms - you define what transforms you want the data to go through.
2. Automatically created transforms - the transforms for the data are defined by the model you'd like to use.

**Important point:** When using a pretrained model, the data you pass through it (including your own custom data) must be **transformed** in the same way as the data the model was trained on.

### 2.1 Creating a transform for `torchvision.models` (manual creation)

`torchvision.models` contains pretrained models (models ready for transfer learning) right within `torchvision`.

> All pre-trained models expect input images normalized in the same way, i.e. mini-batches of 3-channel RGB images of shape (3 x H x W), where H and W are expected to be at least 224. 
> 
> The images have to be loaded into a range of `[0, 1]` and then normalized using `mean = [0.485, 0.456, 0.406]` and `std = [0.229, 0.224, 0.225]`.
> 
> You can use the following transform to normalize:
> `normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])`
>
> Some pretrained models from `torchvision.models` take inputs of sizes other than `[3, 224, 224]`; for example, some might take them in `[3, 240, 240]`. For specific input image sizes, see the documentation.
> The mean and standard deviation values above were calculated from the data: specifically, from the ImageNet dataset, by taking the means and standard deviations across a subset of images.
>
> We also don't strictly *need* to normalize. Neural networks are usually quite capable of figuring out appropriate data distributions (they'll calculate where the mean and standard deviations need to be on their own), but setting them at the start can help our networks achieve better performance more quickly.


In [19]:
from torchvision import transforms

normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225]
)

manual_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    normalize
])

In [21]:
train_dataloader, test_dataloader, class_names = data_setup.create_dataloaders(
    train_dir=train_dir,
    test_dir=test_dir,
    transform=manual_transforms,
    batch_size=32
)

train_dataloader, test_dataloader, class_names
    

(<torch.utils.data.dataloader.DataLoader at 0x22c0ddddf10>,
 <torch.utils.data.dataloader.DataLoader at 0x22c0db95cd0>,
 ['pizza', 'steak', 'sushi'])

### 2.2 Creating a transform for `torchvision.models` (auto creation)

As of `torchvision` v0.13+, there is support for automatically creating data transforms based on the pretrained model weights you're using.

In [26]:
# Get a set of pretrained model weights
weights = torchvision.models.EfficientNet_B0_Weights.DEFAULT # Default: Best available weights
weights

EfficientNet_B0_Weights.IMAGENET1K_V1

In [25]:
# Get the transforms used to create the pretrained model
auto_transforms = weights.transforms()
auto_transforms

ImageClassification(
    crop_size=[224]
    resize_size=[256]
    mean=[0.485, 0.456, 0.406]
    std=[0.229, 0.224, 0.225]
    interpolation=InterpolationMode.BICUBIC
)

In [27]:
# Create dataloaders using automatic transforms 
train_dataloader, test_dataloader, class_names = data_setup.create_dataloaders(
    train_dir=train_dir,
    test_dir=test_dir,
    transform=auto_transforms,
    batch_size=32
)

train_dataloader, test_dataloader, class_names

(<torch.utils.data.dataloader.DataLoader at 0x22c0db86b40>,
 <torch.utils.data.dataloader.DataLoader at 0x22c0db86de0>,
 ['pizza', 'steak', 'sushi'])