# Noisy Student Training Implementation

This is an implementation of the ["Noisy Student"](https://arxiv.org/abs/1911.04252) paper, in which a teacher model is trained on a small quantity of labeled data and is then used to produce pseudolabels for a much larger quantity of unlabeled data. The pseudolabeled data is then perturbed (noised) and used, together with the labeled data, to train a student model, which in turn produces new labels for the data. This cycle of learning, labeling, and re-learning on perturbed data can be repeated to increase the overall accuracy of the model and improve its performance on new data. I have used the [Painter by Numbers](https://www.kaggle.com/competitions/painter-by-numbers/data) dataset from Kaggle in this implementation.

In [None]:
import torch
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torch.nn as nn
import pandas as pd

from helpers import ImageDataset

Set the locations of the labeled training dataset and the unlabeled dataset.

In [None]:
# Labeled set: image directory, plus a CSV with rows of [index, filename, encoded label].
labeled_image_root, labeled_annotations = r'', r''
# Unlabeled set: image directory, plus a CSV with rows of [index, filename].
unlabeled_image_root, unlabeled_annotations = r'', r''

In [None]:
# Both CSVs have no header row, so the column names are supplied explicitly.
# The unlabeled CSV is documented (next to unlabeled_annotations) as holding only
# [index, filename], so its 'style' column is presumably all NaN.
labeled_df, unlabeled_df = (
    pd.read_csv(csv_path, header=None, names=['index', 'filename', 'style'])
    for csv_path in (labeled_annotations, unlabeled_annotations)
)

In [None]:
# Directory handed to train() as save_path when models are saved.
save_folder = ''

Set up the transform that converts each image to a tensor, resizes it to the 32×32 size the model expects, and normalises it.

In [None]:
# Clean (un-augmented) preprocessing: convert to a tensor, resize to 32x32,
# then normalise every channel as (x - 0.5) / 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(size=(32, 32)),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

Set the batch size

In [None]:
batch_size: int = 4  # images per DataLoader batch, shared by every loader below

Load the labeled data, applying the transform above, and split it into training, validation and test sets (60/20/20).

In [None]:
labeled_full_dataset = ImageDataset(
    root_dir=labeled_image_root,
    annotations=labeled_df,
    transform=transform
)

from sklearn.model_selection import train_test_split  # not used in this cell; kept for any later cell that relies on it
from torch.utils.data import random_split

# 60/20/20 train/validation/test split. random_split returns lazy Subsets.
# Calling sklearn's train_test_split on a Dataset indexes every item instead,
# which loads (and transforms) every image into memory up front.
n_total = len(labeled_full_dataset)
n_test = int(0.2 * n_total)        # 20% of the data for a final test set
n_validation = int(0.2 * n_total)  # 20% of the data for a validation set
labeled_train_dataset, labeled_validation_dataset, labeled_test_dataset = random_split(
    labeled_full_dataset,
    [n_total - n_validation - n_test, n_validation, n_test],
    generator=torch.Generator().manual_seed(0),  # reproducible split
)

labeled_train_data = DataLoader(dataset=labeled_train_dataset, shuffle=True, batch_size=batch_size)
# Evaluation order does not matter, so the validation and test loaders are not shuffled.
labeled_validation_data = DataLoader(dataset=labeled_validation_dataset, shuffle=False, batch_size=batch_size)

labeled_test_data = DataLoader(dataset=labeled_test_dataset, shuffle=False, batch_size=batch_size)

Import the model and the training function.

In [None]:
from models import CNN
from train import train

Visualise one sample going into the model to check that everything looks correct.

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

train_features, train_labels, image_id = next(iter(labeled_train_data))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
print(train_features[0].shape)
img = train_features[0].squeeze()
label = train_labels[0]
image_id = image_id[0]
print(f"Label: {label}")
print(f'Image ID: {image_id}')
plt.imshow(img.T)
plt.show()


In [None]:
# Training configuration shared by the teacher and student runs.
device = "cuda" if torch.cuda.is_available() else "cpu"
lr, momentum = 1e-3, 0.9  # forwarded unchanged to train()
epochs = 2
criterion = nn.CrossEntropyLoss()

In [None]:
train(model=CNN(), training_data=labeled_train_data, validation_data=labeled_validation_data, device=device, criterion=criterion, lr=lr, momentum=momentum, epochs=epochs, save=False, save_path=save_folder)

## Using the teacher to produce pseudo-labeled data

In [None]:
unlabeled_dataset = ImageDataset(
    root_dir=unlabeled_image_root,
    annotations=unlabeled_df,
    transform=transform,
    labels=False
)

unlabeled_data = DataLoader(dataset=unlabeled_dataset, shuffle=True, batch_size=batch_size)

Check that we're inputting the images correctly

In [None]:
unlabl_features, unlabl_image_id = next(iter(unlabeled_data))
print(f"Feature batch shape: {unlabl_features.size()}")
#print(f"Labels batch shape: {train_labels.size()}")
print(unlabl_features[0].shape)
img_u = unlabl_features[0].squeeze()
#label = train_labels[0]
#print(f"Label: {label}")
plt.imshow(img_u.T)
plt.show()

Use a normal (i.e., not noised) teacher model to generate soft or hard pseudo labels for clean (i.e., not distorted) unlabeled images

In [None]:
model = CNN()
model.eval()

pseudolabel_list = []

with torch.no_grad():
    for i, data in enumerate(unlabeled_data, 0):
        outputs = model(data[0])

        _, predicted = torch.max(outputs, 1)

        for j in list(zip(data[1], predicted.tolist())):
            pseudolabel_list.append(j)

        #print(f'i: {data[1]}, Predicted: {predicted}')


pseudolabel_df = pd.DataFrame(pseudolabel_list, columns=['filename', 'style'])
pseudolabel_df.insert(loc=0, column='index', value=pseudolabel_df.index)

In [None]:
# Display the teacher's pseudolabels: one row per unlabeled image (index, filename, predicted style).
pseudolabel_df

In [None]:
# Stack the ground-truth rows and the teacher's pseudolabels into one training
# table, then renumber the 'index' column so it runs 0..N-1 across both sources.
combined_df = pd.concat([labeled_df, pseudolabel_df], ignore_index=True)
combined_df['index'] = range(len(combined_df))

In [None]:
# Display the combined training table: labeled rows first, then the pseudolabeled rows.
combined_df

## Train a student model that minimises the cross-entropy loss on a combination of labeled and pseudo-labeled images, with noise (data augmentation) added to the student's inputs

In [None]:
from torch.utils.data import Subset

# Input noise for the student: TrivialAugmentWide applies one randomly chosen
# augmentation per image each time the image is loaded.
noisy_transform = transforms.Compose(
    [transforms.TrivialAugmentWide(),
     transforms.Resize((32, 32)),
     transforms.ToTensor(),
     # Same normalisation as `transform`, so the student trains on the same input
     # range it sees when it relabels images loaded with `transform`.
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
     ])

# NOTE(review): combined_df mixes labeled and pseudolabeled filenames, but only
# labeled_image_root is passed. Pseudolabeled images are presumably looked up in the
# wrong directory unless both roots are the same — confirm against ImageDataset.
hardlabel_noisy_dataset = ImageDataset(
    root_dir=labeled_image_root,
    annotations=combined_df,
    transform=noisy_transform
)
# The same rows without augmentation, used for validation so the metric is not noised.
hardlabel_clean_dataset = ImageDataset(
    root_dir=labeled_image_root,
    annotations=combined_df,
    transform=transform
)

# Split indices instead of calling sklearn's train_test_split on the dataset. That
# indexes every item up front, loading all images into memory and freezing one
# random augmentation per image for the whole run.
shuffled_indices = torch.randperm(
    len(hardlabel_noisy_dataset), generator=torch.Generator().manual_seed(0)
).tolist()
n_valid = int(0.2 * len(shuffled_indices))  # We take 20% of the data for a validation set
hardlabel_train_dataset = Subset(hardlabel_noisy_dataset, shuffled_indices[n_valid:])
hardlabel_valid_dataset = Subset(hardlabel_clean_dataset, shuffled_indices[:n_valid])

hardlabel_train_data = DataLoader(dataset=hardlabel_train_dataset, shuffle=True, batch_size=batch_size)
hardlabel_valid_data = DataLoader(dataset=hardlabel_valid_dataset, shuffle=False, batch_size=batch_size)



In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

train_features, train_labels, image_id = next(iter(hardlabel_train_data))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
print(train_features[0].shape)
img = train_features[0].squeeze()
label = train_labels[0]
image_id = image_id[0]
print(img.dtype)
print(f"Label: {label}")
print(f'Image ID: {image_id}')
plt.imshow(img.T)
plt.show()

In [None]:
train(model=CNN(), training_data=hardlabel_train_data, validation_data=hardlabel_valid_data, device=device, criterion=criterion, lr=lr, momentum=momentum, epochs=epochs, save=False)

In [None]:
model.eval()

pseudolabel_list = []

with torch.no_grad():
    for i, data in enumerate(unlabeled_data, 0):
        outputs = model(data[0])

        _, predicted = torch.max(outputs, 1)

        for j in list(zip(data[1], predicted.tolist())):
            pseudolabel_list.append(j)

combined_df = pd.DataFrame(pseudolabel_list, columns=['filename', 'style'])
combined_df.insert(loc=0, column='index', value=pseudolabel_df.index)

combined_df[combined_df.iloc[:, 1].isin(labeled_df.iloc[:, 1])] = labeled_df
combined_df['index'] = combined_df.index