# Semi-supervised learning using student-teacher techniques

This week, we're going to use student-teacher techniques for semi-supervised learning.

As a reminder from the Week 5 lecture content:
**Semi-supervised learning** uses both labelled and unlabelled data for training, often leading to better model performance than using labelled data alone, particularly when there are only small quantities of labelled data.
In the context of semi-supervised learning, **student-teacher techniques** involve using a "teacher" model (trained on the labelled data) to generate pseudo-labels for unlabelled data, which then guide the training of a "student" model. You can see an illustration of this below.

![Student-teacher techniques](images/Student-teacher-Training.jpg)

To do this, we're going to follow these steps:
1. Initialise the notebook by loading necessary libraries.
2. Load the labelled data and prepare for training.
3. Initialise and train a teacher network on the labelled data.
4. Load the unlabelled data.
5. Use the teacher network to pseudo-label the unlabelled dataset.
6. Combine the pseudo-labelled data and labelled data into one dataset for training.
7. Train a student network on both the labelled and pseudo-labelled data. (This is the pseudo-labelling semi-supervision technique.)

If you finish this practical early, you could also investigate the questions below (some resources are provided for (1), and some food for thought is provided for (2) and (3) if you're interested).
1. Can we use noisy student-teacher learning? Does this improve performance?
2. Can you more carefully select which pseudo-labels to use? Perhaps by checking the confidence scores?
3. Can we do multiple iterations of student-teacher learning to improve performance?

# 1) Initialise the notebook with libraries

Let's load in any libraries we will use in this notebook. We're also going to install Weights & Biases (wandb), as it is not included in this environment by default.

In [None]:
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision
import torchvision.transforms as transforms

import glob
from PIL import Image

import tqdm

!pip install wandb

import wandb

# 2) Load the labelled data and prepare for training
## 2a) Inspect the Data
**Make sure you've extracted the dataset folder by right-clicking and selecting 'Extract Archive'**. Once you've done this, look at the different folders and how the dataset is structured. Open a few images to see what they look like and get a feel for the data you're going to be working with.

As you can see, we have the following structure:
```
└── stanford_dogs_semi-supervised
    ├── labelled                     
    |   ├── Train                     # training and validation data folder, with classes separated into sub-directories
    |   |   ├── Class 1
    |   |   ├── Class 2
    |   |   └── ....
    |   └── Test                      # test data folder, with classes separated into their own folders as sub-directories
    |       ├── Class 1
    |       ├── Class 2
    |       └── ....
    └── unlabelled                     #approx. 1000 images, not sorted into class folders
```
## 2b) Loading the Labelled Data

Last week, we saw that we can load data with this format by:
1. Applying transformations -- the most basic are [transforms.ToTensor()](https://pytorch.org/vision/stable/generated/torchvision.transforms.ToTensor.html), [transforms.Resize()](https://pytorch.org/vision/stable/generated/torchvision.transforms.Resize.html) and [transforms.Normalize()](https://pytorch.org/vision/stable/generated/torchvision.transforms.Normalize.html).
2. Loading the datasets with [torchvision.datasets.ImageFolder](https://pytorch.org/vision/stable/generated/torchvision.datasets.ImageFolder.html).

**If this is unfamiliar, please review the Week 4 practical.**

In [None]:
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Resize((224, 224), antialias = True),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainval_dataset = torchvision.datasets.ImageFolder('stanford_dogs_semi-supervised/labelled/Train/', transform = transform)
test_dataset = torchvision.datasets.ImageFolder('stanford_dogs_semi-supervised/labelled/Test/', transform = transform)

class_labels = trainval_dataset.classes
num_classes = len(class_labels)

print(f'Dataset has {num_classes} classes, which are: {class_labels}')

## 2c) Split the Labelled Data and Visualise

## **IMPORTANT NEW INFORMATION**

Previously, we've used random functions to split our data into training and validation data splits. This randomness is good, but it also means that every time we run it we may get something different. This could be problematic if you save a model from some initial tests, then come back to the workbook and run more tests and re-split the training/validation dataset, because you might get data leakage (training and validation data get mixed up).

We can set the "random" seed for the libraries where we use random functions. In JupyterLab, setting the random seed ensures reproducibility across different runs of the entire notebook; this means that every time you restart and run the notebook from the beginning, you will get the same outcomes for operations involving randomness. However, within a single notebook session, if you repeatedly run a function that involves randomness without resetting the seed in each call, the outcomes can vary.

**Your turn:** Use [torch.utils.data.random_split()](https://pytorch.org/docs/stable/data.html#torch.utils.data.random_split) to randomly separate the ```trainval_dataset``` into a training and validation subset. 

You should have completed this in the Week 3 and Week 4 practicals. You can copy and paste that here, or if you're stuck there is extra support in the IFN680 Practical Support Sheet.

Once you've done this, we'll visualise the data and the split of classes between the training and validation dataset.
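
As a rough illustration of the seeding behaviour described above, the sketch below splits a small toy dataset several times. The toy `TensorDataset`, its size, and the helper name `split_indices` are assumptions made for this example only; they are not part of the practical's dataset.

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Hypothetical toy stand-in for trainval_dataset: 10 samples with dummy labels
toy_dataset = TensorDataset(torch.arange(10).float().unsqueeze(1),
                            torch.zeros(10, dtype=torch.long))

train_portion = 0.6
n_train = round(train_portion * len(toy_dataset))
n_val = len(toy_dataset) - n_train

def split_indices(seed=None):
    """Split the toy dataset, optionally resetting the seed first."""
    if seed is not None:
        torch.manual_seed(seed)
    train_subset, val_subset = random_split(toy_dataset, [n_train, n_val])
    return train_subset.indices, val_subset.indices

# Resetting the seed before each call gives the same split every time
train_a, val_a = split_indices(seed=1)
train_b, val_b = split_indices(seed=1)
print(train_a == train_b)

# Without resetting the seed, a repeated call continues the random stream,
# so the split will usually be different
train_c, val_c = split_indices()
print(train_a == train_c)
```

The same pattern should carry over to the real `trainval_dataset`: compute the two lengths from `train_portion`, and call `torch.manual_seed` immediately before `random_split`.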

In [None]:
torch.manual_seed(1) # this is important to use every time you re-run the random_split function-- it means you'll continually get the same random split of the training and validation dataset

train_portion = 0.6

##### Your code goes here ######

#####################################################


print(f'Size of train dataset: {len(train_dataset)}')
print(f'Size of val dataset: {len(val_dataset)}')
print(f'Size of test dataset: {len(test_dataset)}')

fig, ax = plt.subplots(1, 5)
for idx in range(5):
    train_image = (train_dataset[idx][0].numpy())/2 + 0.5
    label = class_labels[train_dataset[idx][1]]
    train_image = np.moveaxis(train_image, 0, 2)
    ax[idx].imshow(train_image)
    ax[idx].set_axis_off()
    ax[idx].set_title(label.split('-')[-1])
plt.tight_layout()
plt.show()

# histogram of class distributions between the training and validation data subsets
train_labels = [data[1] for data in train_dataset]
val_labels = [data[1] for data in val_dataset]

# one bin per class, centred on the integer class index, so the bars line up with the tick labels
bins = np.arange(num_classes + 1) - 0.5
plt.hist([train_labels, val_labels], bins = bins, label = ['Train', 'Val'], density = True)
plt.xlabel('Class Label')
plt.xticks([i for i in range(num_classes)], class_labels, rotation=90)
plt.ylabel('Density')
plt.legend()
plt.show()

## 2d) Initialise the data loaders

**Your turn:** Using a batch size of 8, initialise the data loaders for the ```train_dataset```, ```val_dataset```, and ```test_dataset``` below. The training data loader should shuffle the data. Use a single data worker per dataloader. 

If you are confused, you can check the Week 3 and Week 4 practical sheets or review the IFN680 practical support sheet.
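
If it helps, here is a minimal sketch of the data loader pattern on a toy dataset. The random tensors, the `toy_` names and the constants are assumptions for illustration; `NUM_WORKERS` is set to 0 here only so the sketch runs in the main process, whereas the exercise above asks for a single worker.

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

# Hypothetical toy dataset: 20 fake RGB images of size 8x8 with labels from 3 classes
toy_images = torch.rand(20, 3, 8, 8)
toy_labels = torch.randint(0, 3, (20,))
toy_dataset = TensorDataset(toy_images, toy_labels)

BATCH_SIZE = 8
NUM_WORKERS = 0  # assumption for this sketch only; the exercise asks for 1

# shuffle=True re-orders the samples every epoch, which is what we want for training
toy_trainloader = DataLoader(toy_dataset, batch_size=BATCH_SIZE,
                             shuffle=True, num_workers=NUM_WORKERS)
# validation and test loaders are usually left unshuffled
toy_valloader = DataLoader(toy_dataset, batch_size=BATCH_SIZE,
                           shuffle=False, num_workers=NUM_WORKERS)

inputs, labels = next(iter(toy_trainloader))
print(inputs.shape, labels.shape)   # one batch of images and labels
print(len(toy_trainloader))         # number of batches per epoch
```

Note that the last batch of an epoch can be smaller than the batch size when the dataset size is not a multiple of it.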

In [None]:
######### Your code goes here ################


# 3) Initialise and train a teacher network on the labelled data

## 3a) Initialise the teacher model, loss function and optimiser
Now that we've inspected the data and done our initial pre-processing, we need to initialise some other important things before we can begin training. These include:
1. Initialise the teacher model we will use for classification, including initialising the weights with the pre-trained ImageNet weights.
2. Adapt the model for transfer learning.
3. Move the model to the GPU if it is available.
4. Initialise the optimiser (stochastic gradient descent) that will update parameters for us.
5. Initialise the loss function we will use to supervise the training of our model.

These should all be familiar from the last practical. Some of these are easy to do and don't need to be changed during the training process, e.g. the loss function initialisation. Others, e.g. the model initialisation and optimiser initialisation, might need to be flexible as we iteratively test over and over. We're now going to create a number of functions that do these things, so that we can call them repeatedly without having to copy and paste code.

### Function 1: Create the model, ready for training
Assumed input: the number of classes of our new dataset.
1. Initialise the model we will use for classification, including initialising the weights with the pre-trained ImageNet weights.
2. Adapt the model for transfer learning.
3. Move the model to the GPU if it is available.

Assumed output: the model.

### Function 2: Initialise an optimiser with certain parameters
Assumed input: the model, learning rate and momentum parameters for the optimiser.
1. Initialise the optimiser (stochastic gradient descent) that will update parameters for us.

Assumed output: an optimiser with the given parameters.

If this is unfamiliar, you can review the Week 4 practical sheet or read more in the IFN680 Practical Support Sheet.

**Your turn:** I have completed Function 1. Try adapting the code to create Function 2.
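
If you get stuck, one possible shape for Function 2 is sketched below. The argument order `(model, lr, momentum)` is taken from how `init_optimiser` is called in the training cell later in this notebook; the tiny `nn.Linear` model is a hypothetical stand-in used only to show the call.

```python
import torch.nn as nn
import torch.optim as optim

def init_optimiser(model, lr, momentum):
    """Return an SGD optimiser over all parameters of `model`."""
    return optim.SGD(model.parameters(), lr=lr, momentum=momentum)

# Hypothetical tiny model, used only to demonstrate the function
toy_model = nn.Linear(4, 2)
toy_optimiser = init_optimiser(toy_model, lr=0.01, momentum=0.9)
print(toy_optimiser.defaults['lr'], toy_optimiser.defaults['momentum'])
```

Keeping this in a function means a fresh optimiser is created whenever a new model is created, so no optimiser state from a previous experiment is reused by accident.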

In [None]:
def create_classifier(nc):
    #load the model and initialise with pre-trained weights
    model = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1)
    
    #adapt the architecture to the correct number of classes
    in_features = model.fc.in_features #number of features going into the final layer
    model.fc = nn.Linear(in_features, nc)
    
    #move the model to the GPU
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') #this line checks if we have a GPU available
    model = model.to(device)
    
    return model

#this is fairly simple code -- we don't need to re-initialise this, so I won't put it into a function.
criterion = nn.CrossEntropyLoss()  

######################### Your code goes below #######################################  
def init_optimiser():
    #complete this function, including the assumed input arguments

    return

## 3b) Train the teacher model on the training dataset

There is a cell below that will train the teacher model on the training dataset. 
It follows these steps:
1. Set the hyperparameters for training -- hyperparameters relating to the optimizer, or any other part of the training process.
2. Initialise the model and optimizer using our new functions.
3. Initialise the visualisation tool. We recommend using ```wandb```. If this is unfamiliar, review the Week 4 practical sheet or read more in the IFN680 practical support sheet.
4. For the given number of epochs:
    1. Train on the training dataset, measuring accuracy and loss over the whole dataset.
    2. Measure performance on the validation dataset, measuring accuracy and loss over the whole dataset.
    3. If the validation accuracy is higher than it previously was, save these weights as the 'best' weights.
    4. Log the performance from the training and validation dataset.
    
**Your turn:** Try to adapt the parameters so that you create a teacher model that converges during training. I was able to achieve a validation accuracy of 89% - aim for this as a sign that the model has converged. You should try changing the learning rate and number of epochs to achieve this.   
*Hint: Remember that this will be training the model on a very small amount of data. We're trying to use fine-tuning here, so if you observe poor performance, it's probably because your learning rate is too high and you're "training away" the good features learnt from ImageNet.*

If any part of this is unclear, you should review the Week 4 practical sheet and IFN680 Practical Support Sheet.

In [None]:
# set the hyperparameters for training
lr = 0.01 # learning rate for SGD - 0.01
momentum = 0.9 # momentum for SGD
num_epochs = 5 # how many times to train over entire training dataset - 5

#create the teacher model ResNet18, ready for fine-tuning, with the new function
teacher_model = create_classifier(num_classes)

#create the optimiser
optimizer = init_optimiser(teacher_model, lr, momentum)


#Initialise wandb
wandb.login()

config = {
    "learning_rate": lr,
    "momentum": momentum,
    "batch_size": 8,
    "epochs": num_epochs
}

run = wandb.init(
    # Set the project where this run will be logged
    project="Week 5 practical worksheet", 
    # We pass a run name (otherwise it’ll be randomly assigned, like sunshine-lollypop-10)
    name=f"semisupervised_teacher_{lr}_{num_epochs}", 
    # Track hyperparameters and run metadata
    config=config, reinit = True)


best_accuracy = 0 # this is the running best accuracy on the validation dataset

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') #this line checks if we have a GPU available

for epoch in range(num_epochs):    
    
    teacher_model.train()
    
    train_loss = []
    
    train_correct = 0
    train_total = 0
    for i, data in  tqdm.tqdm(enumerate(trainloader, 0), total = len(trainloader), desc = f'Epoch {epoch+1} - training phase'):
        inputs, labels = data
        
        #A. move the inputs and labels to the GPU if available
        inputs = inputs.to(device)
        labels = labels.to(device)
        
        #B. zero the parameter gradients
        optimizer.zero_grad()

        #C.  forward pass to find the outputs
        outputs = teacher_model(inputs)
        
        #D. calculate the loss
        loss = criterion(outputs, labels)
        
        #E. backward pass to calculate the gradient
        loss.backward()
        
        #F. take a step with gradient descent to change the parameters
        optimizer.step()

        #G. let's keep track of the loss and the accuracy
        train_loss += [loss.cpu().item()]
        
        # To find the accuracy, we need to know which label is predicted by our model. This is the class with the highest class score.
        predicted = torch.argmax(outputs, axis = 1)
        
        #now we'll count how many were correct vs how many there were - we can use this later to find the total accuracy of the epoch
        train_correct += torch.sum(predicted == labels).cpu().item()
        train_total += len(labels)

    #3. record the mean loss and accuracy over the entire epoch and training dataset
    mean_train_loss = np.mean(train_loss)
    train_accuracy = train_correct/train_total
    
    #log with wandb
    wandb.log({"training_loss": mean_train_loss, "training_accuracy": train_accuracy}, step = epoch)
    
    ### VALIDATION PHASE
        
    teacher_model.eval()
    
    val_loss = []
    val_correct = 0
    val_total = 0
    for i, data in  tqdm.tqdm(enumerate(valloader, 0), total = len(valloader), desc = f'Epoch {epoch+1} - validation phase'):
        inputs, labels = data
        
        #A. move the inputs and labels to the GPU if available
        inputs = inputs.to(device)
        labels = labels.to(device)

        #B.  forward pass to find the outputs
        outputs = teacher_model(inputs)
        
        #C. calculate the loss
        loss = criterion(outputs, labels)
        
        #D. let's keep track of the loss and the accuracy
        val_loss += [loss.cpu().item()]
        
        # To find the accuracy, we need to know which label is predicted by our model. This is the class with the highest class score.
        predicted = torch.argmax(outputs, axis = 1)
        
        #now we'll count how many were correct vs how many there were - we can use this later to find the total accuracy of the epoch
        val_correct += torch.sum(predicted == labels).cpu().item()
        
        val_total += len(labels)

    #3. record the mean loss and accuracy over the entire epoch and validation dataset
    mean_val_loss = np.mean(val_loss)
    val_accuracy = val_correct/val_total
    
    if val_accuracy > best_accuracy:
        torch.save(teacher_model.state_dict(), 'Week5_ResNet_teacher_best.pth')
        best_accuracy = val_accuracy
        print(f'Model saved for val accuracy {best_accuracy*100:.1f}% at epoch {epoch+1}')
    
    #log with wandb
    wandb.log({"val_loss": mean_val_loss, "val_accuracy": val_accuracy}, step = epoch)
    
run.finish()

# 4) Load the unlabelled data.

In our dataset, we have an ```unlabelled``` folder that is filled with images with no class label. We can't use the standard [torchvision.datasets.ImageFolder](https://pytorch.org/vision/stable/generated/torchvision.datasets.ImageFolder.html) function to create a torchvision dataset in this case.

Instead, we must create a custom torch dataset. You can read through the below to understand what is going on, or look in the IFN680 Practical Support Sheet for more information.

In [None]:
class UnlabelledDataset(torch.utils.data.Dataset):
    def __init__(self, im_dir, transform):
        self.im_paths = glob.glob(im_dir) #this creates a list of all the images in im_dir
        self.transform = transform #this is the transform we want to apply to all images -- this should be the same as in the training dataset.
        self.targets = len(self.im_paths)*[-1] #initialise as invalid value (-1) because we have no labels at this point, but we want to populate this with pseudolabels
        
    def __len__(self):
        #custom datasets need a __len__ function to return the number of data points
        return len(self.im_paths)
    
    def __getitem__(self, idx):
        #custom datasets need a __getitem__ function to return the data point and label at a given index
        
        im_path = self.im_paths[idx] #find the directory of image at idx
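        image = Image.open(im_path).convert('RGB') #use the PIL.Image.open function to open image as a PIL.Image; converting to RGB ensures every image has 3 channels, matching the 3-channel Normalize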
        
        transformed_image = self.transform(image) #apply the relevant transformation
        
        label = self.targets[idx] #grab the classification label of image at idx
        
        return transformed_image, label #return the transformed image and its respective label
    
unlabelled_dataset = UnlabelledDataset('stanford_dogs_semi-supervised/unlabelled/*', transform)  #create the dataset using images in the unlabelled folder and our initial transform for resizing and normalizing images.

# 5) Use the teacher network to pseudo-label the unlabelled dataset.

## 5a) Load the best weights into the teacher model
In the cell below, add code to load the ```Week5_ResNet_teacher_best.pth``` weights into the ```teacher_model```. You can review the Week 4 practical sheet for how to do this, or check the IFN680 Practical Support Sheet.
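
As a reminder of the general save-and-load pattern, here is a small self-contained sketch. The tiny `nn.Linear` models and the temporary file name are hypothetical stand-ins; in the notebook you would use the ```teacher_model``` and the ```Week5_ResNet_teacher_best.pth``` file instead.

```python
import os
import tempfile
import torch
import torch.nn as nn

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Hypothetical tiny model standing in for the trained teacher
saved_model = nn.Linear(4, 2).to(device)

with tempfile.TemporaryDirectory() as tmp_dir:
    weights_path = os.path.join(tmp_dir, 'toy_best.pth')  # hypothetical filename
    torch.save(saved_model.state_dict(), weights_path)

    # A freshly initialised model with the same architecture
    restored_model = nn.Linear(4, 2).to(device)
    # map_location makes sure the weights land on the device we are using now
    state_dict = torch.load(weights_path, map_location=device)
    restored_model.load_state_dict(state_dict)

print(torch.equal(saved_model.weight, restored_model.weight))
```

The architecture must match the one used when the weights were saved, otherwise `load_state_dict` will raise an error about missing or unexpected keys.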

In [None]:
#### Your code goes below ####


## 5b) Test the teacher model on the unlabelled dataset
In the cell below, we're going through the unlabelled dataset and finding the predicted class label with the ```teacher_model```.

**Your turn:** Change the class label (held in the ```targets``` field) for the data point at ```idx``` in the ```unlabelled_dataset``` to be equal to the ```predicted``` class label from the teacher model.
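
The idea of overwriting the placeholder labels can be sketched on a toy example. The `ToyUnlabelledDataset` class, the random feature vectors and the 3-class `nn.Linear` "teacher" are all assumptions for illustration; only the `targets` list mirrors the ```UnlabelledDataset``` defined earlier.

```python
import torch
import torch.nn as nn

class ToyUnlabelledDataset(torch.utils.data.Dataset):
    """Hypothetical stand-in for UnlabelledDataset: random vectors plus a `targets` list."""
    def __init__(self, num_items):
        self.items = torch.rand(num_items, 4)
        self.targets = num_items * [-1]  # -1 marks "no label yet"

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        return self.items[idx], self.targets[idx]

toy_dataset = ToyUnlabelledDataset(5)
toy_teacher = nn.Linear(4, 3)  # hypothetical 3-class "teacher"
toy_teacher.eval()

with torch.no_grad():  # no gradients are needed when we are only predicting
    for idx in range(len(toy_dataset)):
        item, _ = toy_dataset[idx]
        output = toy_teacher(item.unsqueeze(0))        # add a batch dimension
        predicted = torch.argmax(output, dim=1).item() # class with the highest score
        toy_dataset.targets[idx] = predicted           # overwrite the placeholder

print(toy_dataset.targets)
```

In the notebook cell, `predicted` is a NumPy integer, so wrapping it as `int(predicted)` before storing it keeps `targets` a list of plain Python integers.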

In [None]:
teacher_model.eval()

for idx in range(len(unlabelled_dataset)):
    im, lbl = unlabelled_dataset[idx] #grab the data at idx
    im = im.to(device).unsqueeze(0) #the model expects a batch of images, so we add another dimension to the image to create a batch with 1 image

    output = teacher_model(im).detach().cpu().numpy() # we need to detach the output from the gradient calculator in PyTorch, and move it back to the CPU, then convert to a numpy array
      
    # To choose a pseudo-label, we need to know which label is predicted by our model. This is the class with the highest class score.
    predicted = np.argmax(output)
    
    #Enter your code below, to change the dataset label at idx to the predicted class.


## 5c) Visualise the distribution of pseudo-labels

It's worth checking what our teacher model predicted on the unlabelled data -- how many times did it predict each class? How does this relate to the distribution of classes in the training dataset?

There's no correct answer here, but sometimes we can observe weird biases towards certain classes at this stage.

**Your turn:** Below, create a histogram of the ```unlabelled_dataset.targets``` and look for trends in the class predictions. If you're going well for time, compare this to the labels in the ```train_dataset```. Is there a big change in distribution between the training labels and pseudo-labels?
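
One way to compare the two distributions is to count the labels per class and plot the fractions side by side. The sketch below uses randomly generated labels as hypothetical stand-ins for ```unlabelled_dataset.targets``` and the training labels, and assumes 4 classes purely for illustration.

```python
import numpy as np
import matplotlib.pyplot as plt

num_classes = 4  # hypothetical; use the real num_classes in the notebook
rng = np.random.default_rng(0)
pseudo_labels = rng.integers(0, num_classes, size=100)  # stand-in for unlabelled_dataset.targets
train_labels = rng.integers(0, num_classes, size=60)    # stand-in for the training labels

# fraction of each dataset assigned to each class
pseudo_fraction = np.bincount(pseudo_labels, minlength=num_classes) / len(pseudo_labels)
train_fraction = np.bincount(train_labels, minlength=num_classes) / len(train_labels)

class_idx = np.arange(num_classes)
bar_width = 0.4
plt.bar(class_idx - bar_width / 2, train_fraction, width=bar_width, label='Train labels')
plt.bar(class_idx + bar_width / 2, pseudo_fraction, width=bar_width, label='Pseudo-labels')
plt.xticks(class_idx)
plt.xlabel('Class Label')
plt.ylabel('Fraction of dataset')
plt.legend()
plt.show()
```

Plotting fractions rather than raw counts makes the comparison fair when the two datasets have different sizes.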

In [None]:
#### Enter your code below


# 6) Combine the pseudo-labelled data and labelled data into one dataset for training.

Now that we have our pseudo-labelled ```unlabelled_dataset```, we can use the [torch.utils.data.ConcatDataset](https://pytorch.org/docs/stable/data.html#torch.utils.data.ConcatDataset) to combine this with our original ```train_dataset```.

**Your turn:** Load the newly created ```student_train_dataset``` into a dataloader called ```student_trainloader```, using a batch size of 8 and shuffling the data.
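
To see what `ConcatDataset` does, here is a small sketch with two toy datasets. The random tensors and `toy_` names are assumptions for illustration, standing in for the pseudo-labelled and labelled datasets.

```python
import torch
from torch.utils.data import TensorDataset, ConcatDataset, DataLoader

# Hypothetical toy datasets: 12 "pseudo-labelled" and 6 "labelled" samples
toy_pseudo = TensorDataset(torch.rand(12, 3, 8, 8), torch.randint(0, 3, (12,)))
toy_labelled = TensorDataset(torch.rand(6, 3, 8, 8), torch.randint(0, 3, (6,)))

# The combined dataset simply chains the two: indices 0-11 come from the first, 12-17 from the second
toy_student_dataset = ConcatDataset([toy_pseudo, toy_labelled])
print(len(toy_student_dataset))

# Shuffling mixes pseudo-labelled and labelled samples within each batch
toy_student_loader = DataLoader(toy_student_dataset, batch_size=8, shuffle=True)
inputs, labels = next(iter(toy_student_loader))
print(inputs.shape, labels.shape)
```

Both datasets need to return items of the same form (here an image tensor and an integer label) so that they can be batched together.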

In [None]:
student_train_dataset = torch.utils.data.ConcatDataset([unlabelled_dataset, train_dataset])

#### Enter your code below.


# 7) Train a student network on both the labelled and pseudo-labelled data.

**Your turn:** Use the code developed in Step 3b to train the student network. See what validation accuracy you can achieve -- I could get to 91.8%, which is nearly 3 percentage points better than the teacher model. Make sure you choose a sensible learning rate and number of epochs to train for.

## **IMPORTANT INFORMATION**
It is very important that if you copy and paste code, you change the following:
1. Rather than creating a ```teacher_model```, you should create a ```student_model```.
2. Everywhere that had ```teacher_model``` should be changed to your newly created ```student_model```. **This is seriously important, I recommend using ctrl+f to check you haven't missed any.**
3. Update your ```wandb.init()``` to show that you are training the student model.
4. Use the ```student_trainloader```, not the ```trainloader```.
5. Update the section that saves the weights to use a filename that includes "student", not "teacher".
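
An alternative to copy and paste, if you prefer, is to wrap the two phases of the loop in functions that take the model as an argument, which removes the risk of leaving a stray ```teacher_model``` behind. The sketch below is one possible arrangement, not the only one; the function names `train_one_epoch` and `evaluate`, the toy data and the tiny `nn.Linear` student are assumptions for illustration, and the wandb logging and weight saving are left out for brevity.

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader

def train_one_epoch(model, loader, optimizer, criterion, device):
    """Train for one pass over `loader`; return mean loss and accuracy."""
    model.train()
    losses, correct, total = [], 0, 0
    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
        correct += (torch.argmax(outputs, dim=1) == labels).sum().item()
        total += len(labels)
    return sum(losses) / len(losses), correct / total

def evaluate(model, loader, criterion, device):
    """Measure mean loss and accuracy over `loader` without updating the model."""
    model.eval()
    losses, correct, total = [], 0, 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            losses.append(criterion(outputs, labels).item())
            correct += (torch.argmax(outputs, dim=1) == labels).sum().item()
            total += len(labels)
    return sum(losses) / len(losses), correct / total

# Hypothetical toy setup, only to show how the functions are called
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
toy_loader = DataLoader(TensorDataset(torch.rand(16, 4), torch.randint(0, 3, (16,))),
                        batch_size=8, shuffle=True)
student_model = nn.Linear(4, 3).to(device)
student_optimizer = optim.SGD(student_model.parameters(), lr=0.01, momentum=0.9)
criterion = nn.CrossEntropyLoss()

for epoch in range(2):
    train_loss, train_acc = train_one_epoch(student_model, toy_loader, student_optimizer, criterion, device)
    val_loss, val_acc = evaluate(student_model, toy_loader, criterion, device)
    print(f'Epoch {epoch+1}: train loss {train_loss:.3f}, val accuracy {val_acc:.2f}')
```

In the notebook, the training call would use the ```student_trainloader``` and the validation call would use the validation loader.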

In [None]:
#### Enter the code for student training here.


# Next Step (if you have time): Noisy student-teacher training 
With the above setup, our student model is always going to be somewhat limited by the teacher model -- it will learn the teacher model's mistakes, and probably learn the knowledge in a similar way. This means we can get performance that is generally at least equal to, and somewhat better than, the teacher model, but it's hard to get much better.
 
To overcome this, we can use noisy student-teacher training, which introduces noise into the student model's training so that it becomes more robust and generalises better while learning from both the labelled data and the teacher's pseudo-labelled data. One way to introduce noise is by adding random data augmentations! You can see a refresher on some random data augmentations below. As a start, try using my ```noise_transform``` below and see if it improves performance when you add it into the above student training loop. Make sure to apply the ```noise_transform``` to the data inside the training loop, and to change the name of your run to include ```noisystudent```. I found that RandomHorizontalFlip() could achieve an accuracy of 91.4%, very similar to the standard student network result, so it is likely not quite enough noise to make noisy student training helpful. If this works, try adding more noisy transforms to see how you can optimise performance.

1. [transforms.RandomResizedCrop](https://pytorch.org/vision/stable/generated/torchvision.transforms.RandomResizedCrop.html) -- this function randomly grabs a portion of the image (crops) and then resizes to the desired image size. By default, the crop can be anywhere between 8% and 100% of the original image area.
2. (already implemented below) [transforms.RandomHorizontalFlip](https://pytorch.org/vision/stable/generated/torchvision.transforms.RandomHorizontalFlip.html) -- randomly flips an image horizontally. Useful for tasks where horizontal orientation doesn't change the meaning.
3. [transforms.RandomVerticalFlip](https://pytorch.org/vision/stable/generated/torchvision.transforms.RandomVerticalFlip.html) -- randomly flips an image vertically. Useful for tasks where vertical orientation doesn't change the meaning.
4. [transforms.RandomRotation](https://pytorch.org/vision/stable/generated/torchvision.transforms.RandomRotation.html) -- randomly rotates an image by an angle drawn from a specified range. Can simulate variations in viewpoint.
5. [transforms.ColorJitter](https://pytorch.org/vision/stable/generated/torchvision.transforms.ColorJitter.html) -- randomly changes brightness, contrast, saturation, and hue of an image. Helps the model to be robust to different lighting conditions.


In [None]:
noise_transform = transforms.Compose(
    [
     transforms.RandomHorizontalFlip(),
    ])


#visualise the train dataset with these transforms
data = next(iter(trainloader))
fig, ax = plt.subplots(1, 5)
for idx in range(5):
    im = data[0][idx]
    lbl = data[1][idx]
    im = noise_transform(im)
    train_image = (im.numpy())/2 + 0.5
    label = class_labels[lbl]
    train_image = np.moveaxis(train_image, 0, 2)
    ax[idx].imshow(train_image)
    ax[idx].set_axis_off()
    ax[idx].set_title(label.split('-')[-1])
plt.tight_layout()
plt.show()

In [None]:
#### Enter the code for noisy student training here.


# Further food for thought

## Improving the quality of pseudo-labels
In the lecture, we discussed how thresholding the confidence of pseudo-labels can help ensure that the model is trained on high-quality, likely correct pseudo-labels, thereby preventing the amplification of errors and maintaining the model's performance and generalization capabilities. How would you alter the code that generates the pseudo-labels to achieve this?
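
As a starting point, here is a minimal sketch of confidence thresholding on some made-up class scores. The threshold of 0.9 and the toy logits are assumptions chosen for illustration, not recommended values; the softmax converts the raw class scores into probabilities, and the largest probability is treated as the confidence.

```python
import torch
import torch.nn.functional as F

CONFIDENCE_THRESHOLD = 0.9  # hypothetical value; worth tuning for your own data

# Hypothetical raw class scores from a teacher for 3 images and 3 classes
toy_logits = torch.tensor([[4.0, 0.0, 0.0],
                           [0.1, 0.0, 0.2],
                           [0.0, 5.0, 0.0]])

probabilities = F.softmax(toy_logits, dim=1)             # each row sums to 1
confidence, predicted = torch.max(probabilities, dim=1)  # top probability and its class

keep_mask = confidence >= CONFIDENCE_THRESHOLD
kept_indices = torch.nonzero(keep_mask).flatten().tolist()
kept_labels = predicted[keep_mask].tolist()
print(kept_indices, kept_labels)  # prints [0, 2] [0, 1]
```

The kept indices could then be passed to `torch.utils.data.Subset` so that only the confident pseudo-labelled images are combined with the labelled training data.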

## Iterative student-teacher training
In the lecture, we discussed how iterative training (using the student as the new teacher and repeating the pseudo-labelling process) progressively refines pseudo-labels and model robustness by cycling improvements between the student and teacher, enhancing overall performance over time. How would you alter the code above to incorporate iterative training?
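
The overall shape of the loop can be sketched on a toy problem. Everything here is an assumption for illustration: the 2D random data, the linear model, the helper names `make_model`, `fit` and `pseudo_label`, and the choice of 3 rounds. It shows only the control flow and makes no claim about how much iteration helps.

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)

def make_model():
    """Hypothetical tiny classifier: 2 input features, 2 classes."""
    return nn.Linear(2, 2)

def fit(model, x, y, epochs=50, lr=0.1):
    """Train `model` on (x, y) with full-batch SGD and return it."""
    optimizer = optim.SGD(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    model.train()
    for _ in range(epochs):
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()
    return model

def pseudo_label(model, x):
    """Predict a class for every row of x using the current teacher."""
    model.eval()
    with torch.no_grad():
        return torch.argmax(model(x), dim=1)

# Toy data: a small labelled set and a larger unlabelled set
x_labelled = torch.randn(20, 2)
y_labelled = (x_labelled[:, 0] > 0).long()
x_unlabelled = torch.randn(200, 2)

# Round 0: the first teacher only sees the labelled data
teacher = fit(make_model(), x_labelled, y_labelled)

for round_idx in range(3):
    # 1. the current teacher pseudo-labels the unlabelled data
    pseudo = pseudo_label(teacher, x_unlabelled)
    # 2. a fresh student trains on labelled plus pseudo-labelled data
    x_all = torch.cat([x_labelled, x_unlabelled])
    y_all = torch.cat([y_labelled, pseudo])
    student = fit(make_model(), x_all, y_all)
    # 3. the student becomes the teacher for the next round
    teacher = student
    print(f'Round {round_idx + 1}: pseudo-labelled {len(pseudo)} samples')
```

In the notebook, `fit` would correspond to the training loop from Step 3b and `pseudo_label` to the loop from Step 5b, with the validation accuracy checked after each round to decide when to stop.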
