# ML that can See: Supervised Learning with Images 

Let's load in the libraries we will use in this notebook. We're also going to install a new package called Weights & Biases (wandb) -- more on that later in the practical.

In [None]:
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision
import torchvision.transforms as transforms

import tqdm

!pip install wandb

import wandb

# Part 1: Prepare the Data
## Inspect the Data
**Make sure you've extracted the dataset folder by right-clicking it and selecting 'Extract Archive'**. Once you've done this, look at the different folders and how the dataset is structured. Open a few images to see what they look like and get a feel for the data you're going to be working with.

## Loading the Data

This step has 2 key parts:
1. Create default transformations to apply to the data. The 3 transforms below are very standard and appear in most image-classification pipelines.
    The transformations we will use here are:
    1. [transforms.ToTensor()](https://pytorch.org/vision/stable/generated/torchvision.transforms.ToTensor.html) -- this converts a PIL image or numpy array to a tensor while scaling the pixel values to the range [0, 1].
    2. [transforms.Resize()](https://pytorch.org/vision/stable/generated/torchvision.transforms.Resize.html) -- this resizes an input image to the specified size (height, width).
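    Resizing is important because the images in the dataset come in different sizes, and a mini-batch can only be stacked into a single tensor if every image has the same shape. A fixed input size also keeps the feature-map sizes predictable from layer to layer. (ResNet ends with an adaptive average-pooling layer, so its final fully connected layer always receives 512 features, but simpler CNNs that flatten their feature maps straight into a fully connected layer do need a fixed input size.)
    3. [transforms.Normalize()](https://pytorch.org/vision/stable/generated/torchvision.transforms.Normalize.html) -- this standardizes the pixel values of a tensor image by subtracting the mean and dividing by the standard deviation, separately for each input channel.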
    
    You can then use transforms.Compose to sequentially chain multiple transforms together.
    
2. Load the dataset with [torchvision.datasets.ImageFolder](https://pytorch.org/vision/stable/generated/torchvision.datasets.ImageFolder.html) -- this loads an image dataset from a folder and assigns labels automatically based on the subdirectory each image is in, which makes it convenient for image classification. If you examine our dataset in the ```stanford_dogs_subset/``` folder, you'll notice that the images are stored in one subdirectory per class - perfect for this function!
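To make the arithmetic of ToTensor and Normalize concrete, here is a small NumPy sketch on a made-up 2x2 RGB image. It re-implements the formulas by hand for illustration and is not torchvision's own code: ToTensor divides 8-bit values by 255 and moves the channel axis first, and Normalize computes `(x - mean) / std` per channel. With mean = std = 0.5, values in [0, 1] end up in [-1, 1], and the inverse is `x * 0.5 + 0.5`, which is the `/2 + 0.5` used later when we un-normalise images for plotting.

```python
import numpy as np

# A made-up 2x2 RGB image in (height, width, channel) layout with 8-bit pixel values,
# similar to what np.asarray(PIL_image) would give you.
image_uint8 = np.array([[[0, 128, 255], [64, 64, 64]],
                        [[255, 255, 255], [0, 0, 0]]], dtype=np.uint8)

# ToTensor-style step: scale to [0, 1] and move channels first -> (channel, height, width)
x = np.moveaxis(image_uint8.astype(np.float32) / 255.0, 2, 0)

# Normalize-style step, applied per channel: (x - mean) / std
mean = np.array([0.5, 0.5, 0.5], dtype=np.float32).reshape(3, 1, 1)
std = np.array([0.5, 0.5, 0.5], dtype=np.float32).reshape(3, 1, 1)
x_norm = (x - mean) / std

print(x.shape)                      # (3, 2, 2)
print(x_norm.min(), x_norm.max())   # -1.0 1.0

# Undo the normalisation for plotting: x_norm * std + mean, i.e. x_norm / 2 + 0.5 here
x_back = x_norm * std + mean
print(np.allclose(x_back, x))       # True
```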

**Why Resize to 224x224?**
Many popular pre-trained models, such as AlexNet, VGG, and ResNet, were trained on 224x224 crops of images from the ImageNet dataset. We will use a ResNet architecture pre-trained on ImageNet in this practical, so we will use the same size.

**Why Normalize with mean and standard deviation of 0.5?**
Using 0.5 for both the mean and the standard deviation is a common default: it maps pixel values from [0, 1] to [-1, 1]. It is not always the best choice, though, and the optimal values depend on the dataset and task. In some scenarios it's better to calculate the actual per-channel mean and standard deviation of your dataset and normalise with those, especially if your images look very different from the data the default values were chosen for. For pre-trained torchvision models, another common choice is the ImageNet statistics the models were trained with: mean (0.485, 0.456, 0.406) and standard deviation (0.229, 0.224, 0.225).
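If you want to try dataset-specific statistics, one way is to accumulate per-channel sums over a DataLoader. The sketch below assumes each batch is `(images, labels)` with images shaped `(batch, channels, height, width)` and already scaled to [0, 1] by ToTensor, but *not* yet normalised. `channel_mean_std` is a hypothetical helper name, and the toy check uses random tensors rather than the dog images. It is usually best to compute these values on the training split only.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

def channel_mean_std(loader):
    """Per-channel mean and (population) standard deviation over every pixel in `loader`.

    Assumes batches of (images, labels) with images of shape (B, C, H, W) in [0, 1].
    """
    n_pixels = 0
    channel_sum = None
    channel_sq_sum = None
    for images, _ in loader:
        images = images.double()  # accumulate in float64 to limit rounding error
        if channel_sum is None:
            channel_sum = torch.zeros(images.shape[1], dtype=torch.float64)
            channel_sq_sum = torch.zeros(images.shape[1], dtype=torch.float64)
        channel_sum += images.sum(dim=(0, 2, 3))
        channel_sq_sum += (images ** 2).sum(dim=(0, 2, 3))
        n_pixels += images.shape[0] * images.shape[2] * images.shape[3]
    mean = channel_sum / n_pixels
    std = (channel_sq_sum / n_pixels - mean ** 2).sqrt()  # sqrt(E[x^2] - E[x]^2)
    return mean, std

# Toy check on random "images"; for the real data you would use an ImageFolder
# built with ToTensor + Resize only (no Normalize).
toy_images = torch.rand(20, 3, 8, 8)
toy_loader = DataLoader(TensorDataset(toy_images, torch.zeros(20)), batch_size=6)
toy_mean, toy_std = channel_mean_std(toy_loader)
print(toy_mean, toy_std)
```

The results can then be passed to `transforms.Normalize(toy_mean.tolist(), toy_std.tolist())` in place of the 0.5 values.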

In [None]:
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Resize((224, 224), antialias=True),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

entire_dataset = torchvision.datasets.ImageFolder('stanford_dogs_subset/', transform = transform)

You also might want to extract some key details about the dataset -- for example, what are the class labels and how many classes are there?

You can do that with the code below.

In [None]:
class_labels = entire_dataset.classes
num_classes = len(class_labels)

print(f'Dataset has {num_classes} classes, which are: {class_labels}')

## Split the Data

Below, I've split the total dataset into two subsets: a train+val subset and a test subset, using the torch.utils.data.random_split() function.

**Your turn:** Use [torch.utils.data.random_split()](https://pytorch.org/docs/stable/data.html#torch.utils.data.random_split) to randomly separate the ```trainval_dataset``` into a training and validation subset. 

Try using the documentation first, and if you're stuck there is extra support in the IFN680 Practical Support Sheet.

If you solved this in the Week 3 practical sheet, you should be able to adapt that code again here.
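As a hint for the API (not the only way to do it), the sketch below splits a toy stand-in dataset of 100 dummy samples with `random_split`. It computes the second size as the remainder, like the test split above, so the two sizes always add up to the dataset length. The seeded `torch.Generator` is optional; it makes the split reproducible every time you re-run the notebook. The dataset and seed are assumptions for illustration.

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Hypothetical stand-in for trainval_dataset: 100 dummy samples with 4 classes.
toy_dataset = TensorDataset(torch.arange(100).float(), torch.arange(100) % 4)

train_portion = 0.7
train_size = int(train_portion * len(toy_dataset))
val_size = len(toy_dataset) - train_size  # the remainder, so the sizes always sum correctly

# A seeded generator makes the random split the same on every run.
generator = torch.Generator().manual_seed(42)
toy_train, toy_val = random_split(toy_dataset, [train_size, val_size], generator=generator)

print(len(toy_train), len(toy_val))  # 70 30
```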

In [None]:
test_portion = 0.3
test_size = int(test_portion*len(entire_dataset))
trainval_size = len(entire_dataset)-test_size
trainval_dataset, test_dataset = torch.utils.data.random_split(entire_dataset, [trainval_size, test_size])

train_portion = 0.7
##### Your code goes here ######


print(f'Size of train dataset: {len(train_dataset)}')
print(f'Size of val dataset: {len(val_dataset)}')
print(f'Size of test dataset: {len(test_dataset)}')

## Visualise the data and class distribution

It's important to get a feel for the data by visualising it, and also to understand any underlying characteristics. For example, is there any bias you can detect in the data that might limit how it can be used in the future? What is the balance of different classes in the dataset? It's also good to check that the images still look similar to the image files in the folders, to confirm that nothing has gone wrong in the resizing or normalisation process.

In the cell below, I'm using matplotlib.pyplot to visualise some images with the subplot() function -- there is more information on this in the IFN680 Practical Support Sheet. I'm just going to visualise the first 5 images in the dataset -- you could do more, or randomly sample images from the dataset, to get a more representative view.

**Your turn**: Use a histogram function to visualise the distribution of class labels in the training and validation datasets, and check how consistent it is between these two subsets. If the split is very imbalanced, you may want to regenerate the random train/val split in the cell above. You can use code from the Week 3 practical sheet, and also look at the IFN680 Practical Support Sheet for more details on the plt.hist() function.
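Looping over `train_dataset[i]` to collect labels loads and transforms every image, which is slow. A faster approach, sketched below, reads labels directly: `random_split` returns `Subset` objects with `.dataset` and `.indices`, and `ImageFolder` stores one integer label per image in `.targets`. Because our train subset is a split *of* the train+val subset, the helper follows the nesting recursively. `subset_labels` and `FakeImageFolder` are hypothetical names; the fake class only mimics ImageFolder's `.targets` so the sketch runs without the image files.

```python
import numpy as np
import torch
from torch.utils.data import Dataset, Subset, random_split

def subset_labels(ds):
    """Integer labels of a dataset or (possibly nested) Subset, without loading any images.

    Assumes the base dataset exposes a `.targets` list, as ImageFolder does.
    """
    if isinstance(ds, Subset):
        parent_labels = subset_labels(ds.dataset)
        return [parent_labels[i] for i in ds.indices]
    return list(ds.targets)

class FakeImageFolder(Dataset):
    """Hypothetical stand-in for ImageFolder: only `.targets` matters here."""
    def __init__(self, targets):
        self.targets = targets
    def __len__(self):
        return len(self.targets)
    def __getitem__(self, idx):
        return torch.zeros(3, 4, 4), self.targets[idx]

base = FakeImageFolder([i % 3 for i in range(30)])
toy_trainval, toy_test = random_split(base, [21, 9], generator=torch.Generator().manual_seed(0))
toy_train, toy_val = random_split(toy_trainval, [14, 7], generator=torch.Generator().manual_seed(0))

train_counts = np.bincount(subset_labels(toy_train), minlength=3)
val_counts = np.bincount(subset_labels(toy_val), minlength=3)
print(train_counts, val_counts)
```

With the real subsets, `subset_labels(train_dataset)` can be passed straight to `plt.hist`, or the `np.bincount(..., minlength=num_classes)` counts can be drawn with `plt.bar`.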


In [None]:
fig, ax = plt.subplots(1, 5)
for idx in range(5):
    train_image = (train_dataset[idx][0].numpy())/2 + 0.5
    label = class_labels[train_dataset[idx][1]]
    train_image = np.moveaxis(train_image, 0, 2)
    ax[idx].imshow(train_image)
    ax[idx].set_axis_off()
    ax[idx].set_title(label.split('-')[-1])
plt.tight_layout()
plt.show()

#### Your code goes below to create the histogram of class distributions between the training and validation data subsets


# Part 2: Initialise the model, dataloaders, loss function, and optimiser

Now that we've inspected the data and done our initial pre-processing, we need to initialise some other important things before we can begin training. These include:
1. Initialise the model we will use for classification.
2. Adapt the model for transfer learning.
3. Initialise the loss function we will use to supervise the training of our model.
4. Initialise the optimiser (stochastic gradient descent) that will update the parameters for us.
5. Initialise the dataloaders, which will split our datasets into mini-batches for training, validation, and testing.

## Initialise the model
We will use a ResNet18 that has been pre-trained on ImageNet. This is very easy to do in PyTorch -- ```torchvision.models.resnet18``` loads the architecture, and passing ```weights=ResNet18_Weights.IMAGENET1K_V1``` loads the parameters the model learned when it was trained on ImageNet.

Below, I'm printing out the model so that you can see the architecture. Look through it and try to identify the elements we talked about in the lecture: convolutional layers, ReLU, pooling layers, batch normalisation, and the final linear (or fully connected) layer. You can read more about these layers in torch in the IFN680 Practical Support Sheet.

In [None]:
model = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1)

print(model)

## Adapt the model architecture for transfer learning

In the printed model definition above, the final Linear layer is taking an input of size 512 (in_features) and using 1000 neurons (out_features) to create 1000 class scores. The model was created this way because it was trained to perform image classification on ImageNet, which has 1000 classes.

We have a variable ```num_classes``` that stores how many classes our new dataset has, which is also how many class scores we want the model to generate.

**Your turn**: Complete the code block below, changing the last layer of the model to only produce ```num_classes``` class scores.

Hint: You can re-assign the final layer by creating a new linear layer using [nn.Linear](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html). The number of input features will not change, but the number of neurons or out_features should. When you ```print(model.fc)```, you should see a Linear layer with 512 in_features and 13 out_features.

In [None]:
### Your code goes here to change the final layer of our model
model.fc = ...

print(model.fc)


Now that our model is adapted for transfer learning, we should also load the model onto our GPU if it is available -- this will massively speed up training.

In [None]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') #this line checks if we have a GPU available
model = model.to(device)

## Initialise the loss function, data loaders and optimiser

In the cell below, I've created a ```criterion``` variable that holds an instance of [nn.CrossEntropyLoss()](https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html). You can read more about it in the documentation or in the IFN680 Practical Support Sheet.
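One detail worth checking for yourself: `nn.CrossEntropyLoss` expects raw class scores (logits), not probabilities. It applies log-softmax internally, takes the negative log-probability of the correct class, and averages over the batch. This is why the ResNet needs no softmax layer at the end. The scores and labels below are made up purely to show the equivalence.

```python
import torch
import torch.nn as nn

ce_loss = nn.CrossEntropyLoss()

# Made-up class scores (logits) for a batch of 2 images and 3 classes, plus their true labels.
logits = torch.tensor([[2.0, 0.5, -1.0],
                       [0.1, 0.2, 3.0]])
labels = torch.tensor([0, 2])

loss = ce_loss(logits, labels)

# The same quantity by hand: log-softmax over classes, pick the true class, negate, average.
log_probs = torch.log_softmax(logits, dim=1)
manual_loss = -log_probs[torch.arange(len(labels)), labels].mean()

print(loss.item(), manual_loss.item())
```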

I've also created the Data Loaders -- these split the dataset into mini-batches, ready for stochastic gradient descent. You can see the IFN680 Practical Support Sheet for more details about Data Loaders, and the IFN680 Week 3 lecture slides for more details about mini-batches. I've chosen a batch size of 8. Larger batches give less noisy gradient estimates and make better use of the GPU, but if a batch is too big it won't fit in GPU memory; 8 seemed to work well with our cluster GPUs. A batch size between 4 and 128 is usually reasonable, depending on the size of your GPU and the size of your images.
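To see what a DataLoader actually yields, the sketch below wraps 20 dummy "images" (random 3x4x4 tensors, an assumption for illustration) in a `TensorDataset` and batches them with `batch_size=8`. Each iteration returns one mini-batch with the batch dimension first, and because 20 is not a multiple of 8, the last batch is smaller (pass `drop_last=True` if you'd rather discard it).

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# 20 dummy "images" (3x4x4) with integer labels, standing in for train_dataset.
toy_images = torch.randn(20, 3, 4, 4)
toy_labels = torch.arange(20) % 5
toy_loader = DataLoader(TensorDataset(toy_images, toy_labels), batch_size=8, shuffle=True)

# Each iteration yields one mini-batch: 20 samples with batch_size=8 gives batches of 8, 8, 4.
batch_shapes = [tuple(batch_inputs.shape) for batch_inputs, _ in toy_loader]
print(len(toy_loader), batch_shapes)  # 3 [(8, 3, 4, 4), (8, 3, 4, 4), (4, 3, 4, 4)]
```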

Finally, I've instantiated a stochastic gradient descent optimizer using [optim.SGD](https://pytorch.org/docs/stable/generated/torch.optim.SGD.html). You can also read more about it in the IFN680 Practical Support Sheet. I'm starting with a learning rate of 0.001 and a momentum of 0.9 -- these are reasonable starting values, but we may have to play around with these down the line to get best performance.

In [None]:
criterion = nn.CrossEntropyLoss()


batch_size = 8

trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size,
                                          shuffle=True, num_workers = 1)
valloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size,
                                          shuffle=False, num_workers = 1)
testloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size,
                                          shuffle=False, num_workers = 1)

optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
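If you're curious what `momentum=0.9` does, the sketch below steps a toy two-element parameter with `optim.SGD` and replicates the update by hand. With PyTorch's default settings (no dampening, no Nesterov, no weight decay), the velocity is `v = momentum * v + grad` and the update is `w = w - lr * v`, so past gradients keep contributing to each step. The toy parameter and loss are assumptions for illustration only.

```python
import torch
import torch.optim as optim

lr, momentum = 0.001, 0.9
w = torch.tensor([1.0, -2.0], requires_grad=True)  # a toy "parameter"
toy_optimizer = optim.SGD([w], lr=lr, momentum=momentum)

# Track the same update by hand alongside the optimizer.
w_manual = w.detach().clone()
velocity = torch.zeros_like(w_manual)

for step in range(3):
    toy_optimizer.zero_grad()
    toy_loss = (w ** 2).sum()        # toy loss with gradient 2 * w
    toy_loss.backward()
    grad = w.grad.detach().clone()   # gradient before the parameter changes
    toy_optimizer.step()

    # Momentum update: v <- momentum * v + grad ; w <- w - lr * v
    # (starting v at zero makes the first step v = grad, as PyTorch does)
    velocity = momentum * velocity + grad
    w_manual = w_manual - lr * velocity

print(w.detach(), w_manual)
```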

# Part 3: Transfer Learning with our new dataset

## Optional: Initialise a Weights and Biases visualisation 

Last week in the practical, you used matplotlib.pyplot and the plot() function to visualise the loss and accuracy of the model across training epochs. This is suitable for small-scale projects, but when training runs take longer, you might want to watch these curves while the model is still training! For example, what if your training destabilised and you didn't realise until you'd spent 5 hours training? You also might want to experiment with many different hyperparameters, such as the learning rate, and need a useful way to organise the different results. There are many ways to do this, and the Weights & Biases library is one of them. This is **optional** -- you can continue to use the plot() function of pyplot if you'd prefer. Please also be prepared to engage with the many well-documented tutorials and guides on the wandb website -- for example, this 'quick start' guide: https://docs.wandb.ai/quickstart

To use ```wandb```, you must **first set up an account**: follow this link to 'sign up' and create an account - https://wandb.ai/site

To create the project, I'm doing 2 things:
1. Check that I am logged in -- if this is your first time, you will need to enter your login credentials.
2. Initialise a training run with relevant details. For example, you can pass in an overall project name (‘Week 4 Practical’, or ‘Project 1’), a name for the training run you are about to conduct ('transfer learning fine-tune'), and a config detailing the hyperparameters that you are going to use during training (learning rate, batch size, other optimizer hyperparameters).

You can log in to your wandb account on their website to view plots from different training runs of a model, or click the links printed in the cell output after running the code.

**Your turn:** update the values in the config dictionary to match what you have initialised above!


In [None]:
wandb.login()


#Complete the below to match what you initialised in the previous cell
config = {
    "learning_rate": 0,
    "momentum": 0,
    "batch_size": 0
    }

wandb.init(
    # Set the project where this run will be logged
    project="Week 4 practical worksheet", 
    # We pass a run name (otherwise it’ll be randomly assigned, like sunshine-lollypop-10)
    name="transfer_learning_finetune", 
    # Track hyperparameters and run metadata
    config=config )

## Transfer Learning with fine-tuning
Now that we've initialised our model, adapted its architecture for our new training dataset, and initialised the other elements of training (stochastic gradient descent optimizer and cross-entropy loss), we can start to train our model.

We're first going to use a fine-tuning approach, where the model's parameters are adjusted slightly to adapt its learned features to the specific nuances of the new task or domain. We're going to adjust the parameters in every layer of the network (i.e. we don't freeze any layers).

The code in the cell below is very similar to what we had in Week 3 -- this is because it is the general training process you can always use for deep learning with PyTorch. It has the following steps:

**For each epoch (i.e. one full pass over the training dataset):**
1. Set the model to 'train' mode. The mode ('train' or 'eval') controls layers that behave differently during training and evaluation (testing), most commonly dropout and batch normalisation layers.
2. Grab the next batch in the trainloader (multiple images and their ground-truth labels)
    1. Move the data to the GPU if it is available.
    2. Zero the parameter gradients stored in the optimizer (stops gradients of previous batches affecting optimization).
    3. Pass the input data through the model to produce predictions.
    4. Calculate the loss by passing the predictions and ground-truth labels through the instantiated loss function.
    5. Complete a backward pass that calculates the gradient of the current loss with respect to each parameter.
    6. Use gradient descent (via the optimizer) to adjust the parameters based on the gradients from the prior step.
    7. Store any data that describes training process as necessary (e.g. loss or accuracy or other performance metrics).
3. For any data you're using for training curves, calculate the epoch-representative value and then store or log this data. Here I am logging with ```wandb.log()```. This will use the previously initialised run, where I pass in a python dictionary of the values to plot, as well as the current step in training (i.e. epoch).
    

**If the concepts described above do not feel familiar to you theoretically, please revise the Week 3 and Week 4 lectures.**

This code will take about 2 minutes to run -- make sure you check the accuracy and loss curves on the [wandb website](https://wandb.ai/).

In [None]:
total_epochs = 5

for epoch in range(total_epochs):    
    #1
    model.train()
    
    train_loss = []
    correct = 0
    total = 0
    #2 
    for i, data in  tqdm.tqdm(enumerate(trainloader, 0), total = len(trainloader), desc = f'Epoch {epoch+1} - training phase'):
        inputs, labels = data
        
        #A. move the inputs and labels to the GPU if available
        inputs = inputs.to(device)
        labels = labels.to(device)

        #B. zero the parameter gradients
        optimizer.zero_grad()

        #C.  forward pass to find the outputs
        outputs = model(inputs)
        
        #D. calculate the loss
        loss = criterion(outputs, labels)
        
        #E. backward pass to calculate the gradient
        loss.backward()
        
        #F. take a step with gradient descent to change the parameters
        optimizer.step()

        #G. let's keep track of the loss and the accuracy
        train_loss += [loss.cpu().item()]
        
        # To find the accuracy, we need to know which label is predicted by our model. This is the class with the highest class score.
        predicted = torch.argmax(outputs, dim=1)
        
        #now we'll count how many were correct vs how many there were - we can use this later to find the total accuracy of the epoch
        correct += torch.sum(predicted == labels).cpu().item()
        total += len(labels)

    #3. record the mean loss and accuracy over the entire epoch and training dataset
    mean_train_loss = np.mean(train_loss)
    train_accuracy = correct/total
    
    #log with wandb
    wandb.log({"training_loss": mean_train_loss, "training_accuracy": train_accuracy}, step = epoch)
    
    
    #### YOUR CODE FOR THE VALIDATION DATASET GOES BELOW
    

## Your turn: Add in the validation dataset check!

The above loop trains our model, and if you visualise the results of the loss curves, it should look like it is doing really well on the training dataset. After 1 epoch, I have a training accuracy of approximately 96.5%. So is our dataset really easy, or are we overfitting on the training data?

To answer this, we need to check the performance on the validation dataset.

Go back to the cell above, and add in the validation loop. It should have the following steps:

**For each epoch (i.e. one full pass over the training dataset) -- this loop already exists above, so nest the code below inside it:**
1. Set the model to 'eval' mode with ```model.eval()```. As described above, the mode controls layers such as dropout and batch normalisation that behave differently during training and evaluation.
2. Grab the next batch in the valloader (multiple images and their ground-truth labels)
    1. Move the data to the GPU if it is available.
    2. Pass the input data through the model to produce predictions.
    3. Calculate the loss by passing the predictions and ground-truth labels through the instantiated loss function.
    4. Store any data that describes validation performance as necessary (e.g. loss or accuracy or other performance metrics).
3. For any data you're using for validation curves, calculate the epoch-representative value and then store or log this data.

For the above steps, you can follow a similar approach to the training loop **BUT** make sure you correctly name your variables, and at no point should you use the optimizer with your validation data (i.e. update parameters based on performance on the validation data).
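One possible shape for this, sketched as a hypothetical `evaluate` helper and checked on a tiny linear model with random data (both assumptions), is below. Besides `model.eval()`, it wraps the loop in `torch.no_grad()`: no gradients are needed for validation, so this saves memory and time, and there is no `optimizer` anywhere in it.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

def evaluate(model, loader, criterion, device):
    """Mean batch loss and accuracy of `model` on `loader`, with no parameter updates."""
    model.eval()                     # dropout off, batch norm uses its running statistics
    losses, correct, total = [], 0, 0
    with torch.no_grad():            # no gradient tracking needed for validation
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            losses.append(criterion(outputs, labels).item())
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)
    return sum(losses) / len(losses), correct / total

# Toy check: a tiny linear "model" on random data.
toy_eval_model = nn.Linear(4, 3)
toy_val_loader = DataLoader(TensorDataset(torch.randn(10, 4), torch.randint(0, 3, (10,))), batch_size=4)
val_loss, val_accuracy = evaluate(toy_eval_model, toy_val_loader, nn.CrossEntropyLoss(), torch.device('cpu'))
print(val_loss, val_accuracy)
```

Inside the notebook's epoch loop this would be called as `evaluate(model, valloader, criterion, device)`, followed by a `wandb.log` of the two values with `step=epoch`. The `model.train()` call at the top of each epoch switches the model back to training mode.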

## Check-in: Using the same training hyperparameters, what validation accuracy did your model achieve?

Was your model overfitting to the training dataset? How did the validation accuracy and loss curve compare? 

I found that my validation accuracy was fluctuating around 83-85% -- there's definitely some overfitting going on as the training accuracy was much higher.

## Re-loading our training dataset with data augmentations

There are a couple of things that can help with overfitting. One is to make sure we've added a weight regularisation term (weight decay) to our optimizer.
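In `optim.SGD`, weight regularisation is the `weight_decay` argument: before each step it adds `weight_decay * w` to the gradient, which is the gradient of an extra `(weight_decay / 2) * ||w||^2` penalty on the loss. The sketch below checks this on a toy parameter with momentum left at 0 for clarity. For the notebook, something like `optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-4)` would be one starting point to experiment with; 1e-4 is an assumed value, not a tuned one.

```python
import torch
import torch.optim as optim

lr, weight_decay = 0.1, 0.01
w = torch.tensor([1.0, -2.0], requires_grad=True)  # a toy "parameter"
toy_optimizer = optim.SGD([w], lr=lr, weight_decay=weight_decay)  # momentum 0 for clarity

w_before = w.detach().clone()
toy_optimizer.zero_grad()
toy_loss = (3 * w).sum()   # toy loss whose gradient is 3 for every element
toy_loss.backward()
toy_optimizer.step()

# SGD with weight decay steps along (grad + weight_decay * w), shrinking weights toward zero.
expected = w_before - lr * (3.0 + weight_decay * w_before)
print(w.detach(), expected)
```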

In this cell, we're going to try something new and add some data augmentation transforms. I've already added one interesting transformation to the list of transforms for the training dataset. Go through the additional data transformation functions below, and choose some more transforms to apply. Remember to tailor the transformations to the characteristics of the dataset. We only apply data augmentation to the training set; the validation and test sets keep using the standard transform, which does not include data augmentation.

**Your turn: Pick additional data transforms and add them to the train_transforms list!**

1. (already implemented below) [transforms.RandomResizedCrop](https://pytorch.org/vision/stable/generated/torchvision.transforms.RandomResizedCrop.html) -- this function randomly crops a portion of the image and then resizes it to the desired image size. By default, the crop can cover anywhere between 8% and 100% of the original image area -- this is quite aggressive, so I'm going to restrict it to between 50% and 100% of the image area.
2. [transforms.RandomHorizontalFlip](https://pytorch.org/vision/stable/generated/torchvision.transforms.RandomHorizontalFlip.html) -- randomly flips an image horizontally. Useful for tasks where horizontal orientation doesn't change the meaning.
3. [transforms.RandomVerticalFlip](https://pytorch.org/vision/stable/generated/torchvision.transforms.RandomVerticalFlip.html) -- randomly flips an image vertically. Useful for tasks where vertical orientation doesn't change the meaning.
4. [transforms.RandomRotation](https://pytorch.org/vision/stable/generated/torchvision.transforms.RandomRotation.html) -- rotates an image by a random angle drawn from a specified range of degrees. Can simulate variations in viewpoint.
5. [transforms.ColorJitter](https://pytorch.org/vision/stable/generated/torchvision.transforms.ColorJitter.html) -- randomly changes brightness, contrast, saturation, and hue of an image. Helps the model to be robust to different lighting conditions.

Try to complete this task using the PyTorch documentation. Otherwise if you're stuck, there is additional support information in the IFN680 Practical Support Sheet.

Once we've done this, there are a few ways to use this transform -- we can apply it inside the training loop (what we will do), or you can make a custom dataset that automatically applies this transformation to the training subset (see here: https://discuss.pytorch.org/t/transforms-on-subset/166836)
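A minimal sketch of the custom-dataset option is below; `TransformedSubset` is a hypothetical name. One difference worth knowing: when a torchvision 'Random' transform is applied to a whole batch tensor, as in the training loop, it draws one set of random parameters per call, so every image in that batch gets the same crop. Wrapping the training subset instead applies the transform to each image separately. Since `entire_dataset` already applies ToTensor, Resize and Normalize, the wrapped transform would receive normalised tensors, just as in the loop approach. The toy check swaps in a hand-written random flip so it runs without torchvision.

```python
import torch
from torch.utils.data import Dataset, TensorDataset, DataLoader

class TransformedSubset(Dataset):
    """Hypothetical wrapper: applies `transform` to each image of `subset` individually."""
    def __init__(self, subset, transform=None):
        self.subset = subset
        self.transform = transform

    def __len__(self):
        return len(self.subset)

    def __getitem__(self, idx):
        image, label = self.subset[idx]
        if self.transform is not None:
            image = self.transform(image)   # fresh random parameters for every image
        return image, label

# Toy stand-in "augmentation": flip the image horizontally with probability 0.5.
def random_hflip(image):
    return image.flip(-1) if torch.rand(1).item() < 0.5 else image

toy = TensorDataset(torch.arange(16.0).reshape(4, 1, 2, 2), torch.arange(4))
augmented = TransformedSubset(toy, transform=random_hflip)
toy_aug_loader = DataLoader(augmented, batch_size=2, shuffle=True)
for batch_images, batch_labels in toy_aug_loader:
    print(batch_images.shape, batch_labels)
```

With the real data this would look like `TransformedSubset(train_dataset, transform=train_transform)`, with the trainloader built from it and the `inputs = train_transform(inputs)` line removed from the loop.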

**What's going on here?**

Each of the 'Random' transformations is applied to an input image in sequence, and each one draws its own random parameters (for example, the crop size and position) every time it is called. When you chain together several different 'Random' transformations, you can generate a huge variety of different images from the same training dataset.

In [None]:
train_transform = transforms.Compose(
    [transforms.RandomResizedCrop((224, 224), scale = (0.5, 1)),
    ])


#visualise the train dataset with these transforms
data = next(iter(trainloader))
fig, ax = plt.subplots(1, 5)
for idx in range(5):
    im = data[0][idx]
    lbl = data[1][idx]
    im = train_transform(im)
    train_image = (im.numpy())/2 + 0.5
    label = class_labels[lbl]
    train_image = np.moveaxis(train_image, 0, 2)
    ax[idx].imshow(train_image)
    ax[idx].set_axis_off()
    ax[idx].set_title(label.split('-')[-1])
plt.tight_layout()
plt.show()


## Train with Data Augmentation

First, let's re-initialise a ```wandb``` run to record this new network's hyperparameters and how it differs from the previous run. Make sure you add the details of any extra transformations you are using to the config variable.

Then, we also need to re-initialise our model and optimizer -- our current ResNet18 has already been trained, and we want to check how performance changes when we start again from the pre-trained weights with this new set of data augmentations! We also need to make sure the new optimizer is constructed with the new model's parameters.

In [None]:
#Update the augmentation list below to match the transforms you added to train_transform
config = {
    "learning_rate": 0.001,
    "momentum": 0.9,
    "batch_size": 8,
    "augmentation": ['random resized crop, scale 0.5-1']
    }

wandb.init(
    # Set the project where this run will be logged
    project="Week 4 practical worksheet", 
    # We pass a run name (otherwise it’ll be randomly assigned, like sunshine-lollypop-10)
    name="transfer_learning_finetune_data_augmentation", 
    # Track hyperparameters and run metadata
    config=config)

model = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1)
in_features = model.fc.in_features
model.fc = nn.Linear(in_features, num_classes)

model.to(device)

optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

Below is the training code from above, except now I am also applying the transform to the inputs before they are fed into the model! Because we are adding transformations, you may find you also need to train the model for slightly longer -- I've increased the number of ```total_epochs``` to 10.

Copy and paste your code for testing and logging performance on the validation dataset into this loop.

In [None]:
total_epochs = 10

for epoch in range(total_epochs):    
    #1
    model.train()
    
    train_loss = []
    correct = 0
    total = 0
    #2 
    for i, data in  tqdm.tqdm(enumerate(trainloader, 0), total = len(trainloader), desc = f'Epoch {epoch+1} - training phase'):
        inputs, labels = data
        
        #A. move the inputs and labels to the GPU if available
        inputs = inputs.to(device)
        labels = labels.to(device)

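        #NEW: Apply our data augmentation here. On a batched tensor, each torchvision 'Random' transform
        #draws one set of random parameters per call, so every image in this batch gets the same augmentation
        inputs = train_transform(inputs)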
        
        #B. zero the parameter gradients
        optimizer.zero_grad()

        #C.  forward pass to find the outputs
        outputs = model(inputs)
        
        #D. calculate the loss
        loss = criterion(outputs, labels)
        
        #E. backward pass to calculate the gradient
        loss.backward()
        
        #F. take a step with gradient descent to change the parameters
        optimizer.step()

        #G. let's keep track of the loss and the accuracy
        train_loss += [loss.cpu().item()]
        
        # To find the accuracy, we need to know which label is predicted by our model. This is the class with the highest class score.
        predicted = torch.argmax(outputs, dim=1)
        
        #now we'll count how many were correct vs how many there were - we can use this later to find the total accuracy of the epoch
        correct += torch.sum(predicted == labels).cpu().item()
        total += len(labels)

    #3. record the mean loss and accuracy over the entire epoch and training dataset
    mean_train_loss = np.mean(train_loss)
    train_accuracy = correct/total
    
    #log with wandb
    wandb.log({"training_loss": mean_train_loss, "training_accuracy": train_accuracy}, step = epoch)
    
    
    ### COPY AND PASTE VALIDATION LOOP CODE HERE

## Food for thought
1. Did your data augmentations improve generalisation? Mine didn't -- I got a lower validation accuracy than before (only 82%). This can happen! It suggests that the augmentation I chose didn't help generalisation on this dataset -- I only used RandomResizedCrop. There is always some experimentation in this process, and often your first attempt doesn't work.
2. Can you justify why you picked the data augmentation transformations that you used? Why would these transformations provide better generalisation for this dataset? This will be an important thing to think about for Project 1.

## Transfer Learning with freezing layers

Freezing layers in a PyTorch model involves setting the ```requires_grad``` (i.e. requires gradient) attribute of the parameters in those layers to ```False```. This prevents the parameters from being updated during training by the optimizer. You can selectively freeze layers and train only the desired layers, such as the last fully connected layer. 

Below, I'm first re-initialising a fresh instance of ResNet18. I'm then going through every parameter in the model and setting its ```requires_grad``` attribute to False. Then, I set ```requires_grad``` back to True for the parameters of the last fully-connected layer only.

If you want to look at the names of the model's layers, you can always ```print(model)```; ```model.named_parameters()``` lists the individual parameter names.
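The toy example below applies the same freezing recipe to a tiny stand-in network (`TinyNet` is hypothetical, with a `backbone` layer and a final `fc` layer, so it runs quickly without downloading ResNet weights). It counts the trainable parameters, gives the optimizer only those, and checks after one step that the frozen weights did not move while the final layer did. Passing only the trainable parameters is one common choice; SGD also skips parameters whose `.grad` is `None`, so passing `model.parameters()` works too.

```python
import torch
import torch.nn as nn
import torch.optim as optim

class TinyNet(nn.Module):
    """Hypothetical stand-in for ResNet18: a 'backbone' followed by a final layer `fc`."""
    def __init__(self, num_classes=3):
        super().__init__()
        self.backbone = nn.Linear(8, 4)
        self.fc = nn.Linear(4, num_classes)

    def forward(self, x):
        return self.fc(torch.relu(self.backbone(x)))

toy_model = TinyNet()
for param in toy_model.parameters():      # freeze everything...
    param.requires_grad = False
for param in toy_model.fc.parameters():   # ...then unfreeze the final layer
    param.requires_grad = True

toy_optimizer = optim.SGD([p for p in toy_model.parameters() if p.requires_grad], lr=0.1)

n_trainable = sum(p.numel() for p in toy_model.parameters() if p.requires_grad)
n_total = sum(p.numel() for p in toy_model.parameters())
print(f'{n_trainable} of {n_total} parameters are trainable')  # 15 of 51 parameters are trainable

backbone_before = toy_model.backbone.weight.detach().clone()
fc_bias_before = toy_model.fc.bias.detach().clone()
toy_loss = nn.CrossEntropyLoss()(toy_model(torch.randn(5, 8)), torch.randint(0, 3, (5,)))
toy_optimizer.zero_grad()
toy_loss.backward()
toy_optimizer.step()
print(torch.equal(toy_model.backbone.weight, backbone_before))  # True: the frozen layer is unchanged
```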

In [None]:
model = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1)
in_features = model.fc.in_features
model.fc = nn.Linear(in_features, num_classes)

model.to(device)


for param in model.parameters():
    param.requires_grad = False

# Unfreeze the parameters of the last fully connected layer
for param in model.fc.parameters():
    param.requires_grad = True

**Your turn:** Below, follow these steps to get results for transfer learning with frozen layers:
1. Re-initialise optimizer
2. Re-initialise a run with a suitable name in wandb
3. Copy and paste the training and validation loops from earlier to train the model.

In [None]:
### Your code goes below

## Check in -- how did transfer learning with freezing compare to transfer learning with fine-tuning?

I found that transfer learning with fine-tuning and with freezing both achieved a validation accuracy of around 85%, so there was not a big difference. This will change depending on the dataset, so it's always a good idea to experiment with both.

## Saving the best weights

Once you've trained a model, you might want to save the model parameters so that you can use it at a later date or even do more training later! You can do this with the [torch.save](https://pytorch.org/docs/stable/generated/torch.save.html) function. 

I've demonstrated how to save the model parameters, and how you can re-load them at a later date.

**Your turn:** Rather than saving the weights after the model is finished training, could you incorporate this into the training process so that you save the weights which achieve the highest validation accuracy? e.g. What if you train for 10 epochs, but your best validation accuracy was on epoch 8?

Want extra support? Read the IFN680 Practical Support Sheet.


In [None]:
#save the model weights
torch.save(model.state_dict(), 'Week4_ResNet_freezed.pth')

#load the model weights back into the model
model.load_state_dict(torch.load('Week4_ResNet_freezed.pth'))
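One way to approach the "save the best weights" task is sketched below. It assumes you already compute a validation accuracy each epoch; here a tiny model and a list of made-up accuracies stand in for real training, so the numbers are hypothetical. Whenever the validation accuracy improves, the sketch keeps an in-memory copy of the weights and also writes them to disk. The copy is needed because `state_dict()` returns tensors that share storage with the live model, so without it the snapshot would keep changing as training continues. After training, it reloads the best weights rather than the last ones.

```python
import copy
import os
import tempfile
import torch
import torch.nn as nn

# Stand-ins so the sketch runs on its own: a tiny model and hypothetical validation accuracies.
toy_model = nn.Linear(4, 2)
fake_val_accuracies = [0.70, 0.78, 0.83, 0.81, 0.85, 0.84]
checkpoint_path = os.path.join(tempfile.gettempdir(), 'best_model_sketch.pth')

best_val_accuracy = -1.0
best_epoch = None
for epoch, val_accuracy in enumerate(fake_val_accuracies):
    # ... the real training and validation steps for this epoch would go here ...
    with torch.no_grad():
        toy_model.weight.add_(0.01)   # pretend that training changed the weights
    if val_accuracy > best_val_accuracy:
        best_val_accuracy = val_accuracy
        best_epoch = epoch
        best_state = copy.deepcopy(toy_model.state_dict())   # snapshot, not a live reference
        torch.save(best_state, checkpoint_path)

print(f'best validation accuracy {best_val_accuracy} at epoch {best_epoch + 1}')

# After training, restore the best weights rather than the final ones.
toy_model.load_state_dict(torch.load(checkpoint_path))
```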

## Food for thought

1. Our code above could be cleaned up significantly to allow for many different test iterations -- how could you move some of the code into functions to create a cleaner script (cleaner scripts, fewer bugs!)? For example, could you have a generic init_model() function? What about a train_step() function?
2. It's important to get a feel for how the learning rate and number of epochs change performance -- even if it seems like the first number you use works well, always try training for a little longer and always try a few different learning rate values.
3. What's the performance on the test dataset? Can you use a confusion matrix to see which classes are being confused for each other? 
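As a starting point for the refactoring question, here is a sketch of a hypothetical `train_one_epoch` function that packages the notebook's training loop, including the optional augmentation step, and returns the epoch's mean loss and accuracy. The toy usage trains a linear classifier on a made-up, linearly separable problem, which is an assumption for illustration. A matching `init_model()` could wrap the ResNet18 loading and `model.fc` replacement cells in the same way.

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

def train_one_epoch(model, loader, criterion, optimizer, device, transform=None):
    """One pass over `loader` with parameter updates; returns (mean batch loss, accuracy).

    `transform` is an optional augmentation applied to each batch, as in the notebook's loop.
    """
    model.train()
    losses, correct, total = [], 0, 0
    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)
        if transform is not None:
            inputs = transform(inputs)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        total += labels.size(0)
    return sum(losses) / len(losses), correct / total

# Toy usage: the label is 1 when the first feature is positive.
torch.manual_seed(0)
toy_x = torch.randn(64, 2)
toy_y = (toy_x[:, 0] > 0).long()
toy_train_loader = DataLoader(TensorDataset(toy_x, toy_y), batch_size=8, shuffle=True)
toy_model = nn.Linear(2, 2)
toy_optimizer = optim.SGD(toy_model.parameters(), lr=0.1, momentum=0.9)
history = [train_one_epoch(toy_model, toy_train_loader, nn.CrossEntropyLoss(), toy_optimizer, torch.device('cpu'))
           for _ in range(10)]
print(history[0], history[-1])
```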