# Library Set-up

In [1]:
!pip install warmup_scheduler_pytorch
!pip install datasets
import torch
import torchvision
from torchvision import transforms, models
from torch.utils.data import DataLoader, Dataset
from scipy.ndimage import zoom
from math import sqrt

!pip install torchinfo
from torchinfo import summary

Collecting warmup_scheduler_pytorch
  Downloading warmup_scheduler_pytorch-0.1.2-py3-none-any.whl.metadata (3.3 kB)
Downloading warmup_scheduler_pytorch-0.1.2-py3-none-any.whl (5.7 kB)
Installing collected packages: warmup_scheduler_pytorch
Successfully installed warmup_scheduler_pytorch-0.1.2
Collecting datasets
  Downloading datasets-3.0.0-py3-none-any.whl.metadata (19 kB)
Collecting pyarrow>=15.0.0 (from datasets)
  Downloading pyarrow-17.0.0-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (3.3 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess (from datasets)
  Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)
Downloading datasets-3.0.0-py3-none-any.whl (474 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m474.3/474.3 kB[0m [31m2

# Load Oxford Pet dataset from HuggingFace

In [2]:
#Download Oxford Pets dataset from HuggingFace
from datasets import load_dataset

In [3]:
ds = load_dataset("timm/oxford-iiit-pet")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md:   0%|          | 0.00/2.50k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/378M [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/413M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/3680 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/3669 [00:00<?, ? examples/s]

In [4]:
#Create a list of label names - label_names[0] reutrns the breed name associated with label 0
#Creates a list of the indexes of the first insance of each label
indexes = [ds['train']['label'].index(x) for x in set(ds['train']['label'])]
label_names = []
#Iterates through the unique labels and returns the label_id name with numbers removed from the end. This leaves the breed name
for i in range(len(indexes)):
  unique_image_id = ds['train']['image_id'][indexes[i]]
  cleaned = unique_image_id.rstrip('0123456789_')
  label_names.append(cleaned)

In [5]:
#As all images in the Oxford Pet dataset have a resolution larger than 96x96 they need to undergo the following transforms as outlined in Kolesnikov et al. 2019 (page 6)
data_transform = transforms.Compose([
    # Resize the images to 448x448
    transforms.Resize(size=(448, 448)),
    # Randomly crops the images to 384x384
    transforms.RandomCrop(size=(384, 384)),
    # Flip the images randomly on the horizontal
    transforms.RandomHorizontalFlip(p=0.5),
    # Turns the images into torch.Tensor format
    transforms.ToTensor(),
    #Converts any images in RGBA format to RGB
    transforms.Lambda(lambda x: x[:3]),
])

In [6]:
#Create custom dataset with just the images in tensor format and image label
class OxfordPets(Dataset):

    def __init__(self, dataset=ds['train'], transform=None) -> None:

        # Setup transforms
        self.transform = transform
        # Create images and labels attribute
        self.images = dataset['image']
        self.labels = dataset['label']

    #Function to load images
    def load_image(self, index: int):
        '''Opens an image'''
        im = self.images[index]
        return im

    #Overwrites the __len__() method
    def __len__(self) -> int:
        '''Returns the total number of samples'''
        return len(self.images)

    #Overwrites the __getitem__() method
    def __getitem__(self, index: int):
        '''Returns one sample of data, data and label (X, y)'''
        im = self.images[index]
        label = self.labels[index]

        # Transform if necessary
        if self.transform:
            return self.transform(im), label # return data, label (X, y)
        else:
            return im, label # return data, label (X, y)

In [7]:
#Create Training and test datasets
train_data = OxfordPets(dataset=ds['train'], transform=data_transform)
test_data = OxfordPets(dataset=ds['test'], transform=data_transform)

# Vision Transformer Model

In [8]:
#Equivalent to equation 1 in Dosovitsky et al. 2021 (explanation found in section 3.1 of the paper)
class Embeddings_Set_Up(torch.nn.Module):
  '''
  Class that performs the embeddings set-up for the ViT model in Dosovitskiy et al. 2021
  Output of module is the transformer block input for the pretraining version of ViT

  Input: image or batch of images as tensor. [batch_size, num_of_colour_channels, height, width]
         for default values as laid out in Vision_Transformer_Pretraining the input size is [batch_size, 3, 224, 224]

  Output: embeddings tensor of shape [batch_size, num_patches+1, latent_vector_size]
          for default values as laid out in Vision_Transformer_Pretraining the input size is [batch_size, 197, 768]
  '''
  def __init__(self, image_resolution, num_image_channels, patch_size, latent_vector_size):
    super().__init__()
    self.unfold = torch.nn.Unfold(kernel_size=(patch_size,patch_size), stride=(patch_size,patch_size))
    self.linear_projection = torch.nn.Linear((patch_size**2)*num_image_channels,latent_vector_size,bias=True)
    self.class_embedding = torch.nn.parameter.Parameter(torch.randn(latent_vector_size))
    self.position_embeddings = torch.nn.parameter.Parameter(torch.randn(int((image_resolution/patch_size)**2+1),latent_vector_size))

  def forward(self, images):
    #Step 1 - Turn image into patches and flatten them
    image_patches = self.unfold(images)
    image_patches = image_patches.transpose(1,2)
    #Step 2 - Map each patch to a vector of length D (latent vector size)
    patch_embeddings = self.linear_projection(image_patches)
    #Step 3 - Prepend a learnable embedding for the image class to the sequence of patches
    unsqueezed_class_embedding = self.class_embedding.unsqueeze(0).unsqueeze(0)
    batched_class_embeddings = torch.cat([unsqueezed_class_embedding for i in range(patch_embeddings.shape[0])],dim=0)
    class_and_patch_embeddings = torch.cat([batched_class_embeddings,patch_embeddings],dim=1)
    #Step 4 - Add learnable position embeddings to the class and patch embeddings.
    position_embeddings = self.position_embeddings.unsqueeze(0)
    batched_position_embeddings = torch.cat([position_embeddings for i in range(patch_embeddings.shape[0])],dim=0)
    transformer_input = class_and_patch_embeddings + batched_position_embeddings

    return transformer_input

In [9]:
class Embeddings_Set_Up_Finetuning(torch.nn.Module):
  '''
  Class that performs the embeddings set-up for the ViT model in Dosovitskiy et al. 2021
  Output of module is the transformer block input for the finetuning version of ViT

  Input: image or batch of images as tensor. [batch_size, num_of_colour_channels, height, width]
         for default values as laid out in Vision_Transformer_Finetuning the input size is [batch_size, 3, 384, 384]

  Output: embeddings tensor of shape [batch_size, num_patches+1, latent_vector_size]
          for default values as laid out in Vision_Transformer_Finetuning the input size is [batch_size, 577, 768]
  '''
  def __init__(self, pretrained_model, pretraining_image_resolution, finetuning_image_resolution, num_image_channels, patch_size, latent_vector_size):
    super().__init__()
    self.unfold = torch.nn.Unfold(kernel_size=(patch_size,patch_size), stride=(patch_size,patch_size))
    self.linear_projection = torch.nn.Linear((patch_size**2)*num_image_channels,latent_vector_size,bias=True)
    #Class embedding is initialised as the pretrained model class embedding
    self.class_embedding = pretrained_model.embeddings.class_embedding
    #Position embeddings are initialised as the 2D interpolation of the pretrained model position embeddings
    self.position_embeddings = torch.nn.parameter.Parameter(interpolate_position_embedding(pretrained_model.embeddings.position_embeddings,pretraining_image_resolution,finetuning_image_resolution,patch_size,latent_vector_size))

  def forward(self, images):
    #Step 1 - Turn finetuning image into patches and flatten them
    image_patches = self.unfold(images)
    image_patches = image_patches.transpose(1,2)
    #Step 2 - Map each patch to a vector of length D (latent vector size)
    patch_embeddings = self.linear_projection(image_patches)
    #Step 3 - Prepend the pretrained class embedding to the sequence of patches
    unsqueezed_class_embedding = self.class_embedding.unsqueeze(0).unsqueeze(0)
    class_and_patch_embeddings = torch.cat([unsqueezed_class_embedding,patch_embeddings],dim=1)
    #Step 4 - Add the pretrained position embeddings to the class_and_patch embeddings
    position_embeddings = self.position_embeddings.unsqueeze(0)
    transformer_input = class_and_patch_embeddings + position_embeddings

    return transformer_input

In [11]:
#Equivalent to equations 2 & 3 of Dosovitsky et al. 2021
class Transformer_Layer(torch.nn.Module):
  '''
  Transformer Layer for ViT model from Dosovitskiy et al. 2021

  Input: embeddings tensor of shape [batch_size, num_patches+1, latent_vector_size]

  Output: embeddings tensor of shape [batch_size, num_patches+1, latent_vector_size]
  '''
  def __init__(self, num_patches, latent_vector_size, num_MSA_heads, MLP_hidden_layer_size, dropout=0.1):
    super().__init__()

    #Step 5 - Layer Normalisation
    self.layer_norm = torch.nn.LayerNorm(latent_vector_size)
    #Step 6 - Multi-headed Self Attention (MSA)
    self.MSA_module = torch.nn.MultiheadAttention(latent_vector_size, num_MSA_heads, batch_first=True)

    #Step 8 - Layer Normalisation
    self.layer_norm_2 = torch.nn.LayerNorm(latent_vector_size)
    #Step 9 - Multi-Layer Perceptron (MLP)
    self.MLP_module = torch.nn.Sequential(
        torch.nn.Linear(latent_vector_size,MLP_hidden_layer_size),
        torch.nn.Dropout(dropout), #dropout after every dense layer (from Appendix B.1 - 'Training')
        torch.nn.GELU(),
        torch.nn.Linear(MLP_hidden_layer_size,latent_vector_size),
        torch.nn.Dropout(dropout) #dropout after every dense layer (from Appendix B.1 - 'Training')
    )

  def forward(self, transformer_input):
    layer_norm = self.layer_norm(transformer_input)
    MSA_module_output, MSA_weights = self.MSA_module(layer_norm,layer_norm,layer_norm) #Step 6 - Multi-headed Self Attention (MSA)
    z_apostraphe = MSA_module_output + transformer_input #Step 7 - Residual Connection
    transformer_output = self.MLP_module(self.layer_norm_2(z_apostraphe)) + z_apostraphe #Step 10 - Residual Connection

    return transformer_output

In [12]:
class Vision_Transformer_Pretraining(torch.nn.Module):
  '''
  Replication of Vision Transformer model from Dosovitskiy et al. (2021)
  Pretraining version where classification head is a MLP with one hidden layer

  Input: image or batch of images as tensor. [batch_size, num_of_colour_channels, height, width]
         for default values, the input size is [batch_size, 3, 224, 224]

  Output: classification probabilities tensor of shape [batch_size, num_labels]
          for Oxford Pets Dataset output size is [batch_size, 37]
  '''
  def __init__(self, image_resolution=224, num_image_channels=3, num_labels=1000, patch_size=16, latent_vector_size=768, num_transformer_layers=12,
               num_MSA_heads=12,MLP_hidden_layer_size=3072,dropout=0.0):
    super().__init__()

    # Dropout
    self.dropout = torch.nn.Dropout(dropout)
    # Create patch embeddings to feed into Transformer block
    self.embeddings = Embeddings_Set_Up(image_resolution, num_image_channels, patch_size, latent_vector_size)
    # Create a Trarnsformer block with L layers, L = 'num_transformer_layers'
    self.Transformer = torch.nn.ModuleList([Transformer_Layer(int((image_resolution/patch_size)**2), latent_vector_size, num_MSA_heads, MLP_hidden_layer_size,dropout) for i in range(num_transformer_layers)])
    #Final Layer Normalisation
    self.layer_norm = torch.nn.LayerNorm(latent_vector_size)
    # Create MLP classification head
    self.classification_head = torch.nn.Sequential(
        torch.nn.Linear(latent_vector_size,MLP_hidden_layer_size),
        torch.nn.Dropout(dropout), #dropout after every dense layer (from Appendix B.1 - 'Training')
        torch.nn.Tanh(),
        torch.nn.Linear(MLP_hidden_layer_size,num_labels),
        torch.nn.Dropout(dropout), #dropout after every dense layer (from Appendix B.1 - 'Training')
    )

  def forward(self, images):
    x = self.embeddings(images)
    x = self.dropout(x)
    for layer in self.Transformer:
      x = layer(x)
    x = self.layer_norm(x)
    y = self.classification_head(x[:,0,:]) #Classification layer only applies to class embedding
    return y

In [13]:
class Vision_Transformer_Finetuning(torch.nn.Module):
  '''
  Replication of Vision Transformer model from Dosovitskiy et al. (2021)
  Finetuning version where classification head is replaced with a linear layer

  Input: image or batch of images as tensor. [batch_size, num_of_colour_channels, height, width]
         for default values, the input size is [batch_size, 3, 384, 384]

  Output: classification probabilities tensor of shape [batch_size, num_labels]
          for Oxford Pets Dataset output size is [batch_size, 37]
  '''
  def __init__(self, image_resolution=384, num_image_channels=3, num_labels=37, patch_size=16, latent_vector_size=768, num_transformer_layers=12,
               num_MSA_heads=12,MLP_hidden_layer_size=3072,dropout=0.1):
    super().__init__()

    # Dropout
    self.dropout = torch.nn.Dropout(dropout)
    # Create patch embeddings to feed into Transformer block
    self.embeddings = Embeddings_Set_Up(image_resolution, num_image_channels, patch_size, latent_vector_size)
    # Create a Trarnsformer block with L layers, L = 'num_transformer_layers'
    self.Transformer = torch.nn.ModuleList([Transformer_Layer(int((image_resolution/patch_size)**2), latent_vector_size, num_MSA_heads, MLP_hidden_layer_size,dropout) for i in range(num_transformer_layers)])
    #Final Layer Normalisation
    self.layer_norm = torch.nn.LayerNorm(latent_vector_size)
    # Create Linear classification layer to apply to class embedding
    self.classification = torch.nn.Linear(latent_vector_size,num_labels)
    # Initialise Classification Linear Layer weight
    torch.nn.init.zeros_(self.classification.weight)
    # Initialise Classification Linear Layer bias
    torch.nn.init.zeros_(self.classification.bias)


  def forward(self, images):
    x = self.embeddings(images)
    x = self.dropout(x)
    for layer in self.Transformer:
      x = layer(x)
    x = self.layer_norm(x)
    y = self.classification(x[:,0,:]) # Classification layer only applies to class embedding
    y = self.dropout(y)
    return y

# Pretraining to Finetuning helper functions

In [14]:
def vitB16_model_to_pretraining_model(pretrained_model, initialised_model):
  '''
  Transfers the model weights from the pretrained model vitB16 to an initialised pretraining model
  The output therefore has the right parameter names to be fed into Vision_Transformer_Finetuning
  '''
  initialised_model.embeddings.class_embedding = torch.nn.parameter.Parameter(pretrained_model.class_token.squeeze().squeeze())
  initialised_model.embeddings.position_embeddings = torch.nn.parameter.Parameter(pretrained_model.encoder.pos_embedding.squeeze())
  initialised_model.embeddings.linear_projection.weight = torch.nn.parameter.Parameter(pretrained_model.conv_proj.weight.flatten(1,3))
  initialised_model.embeddings.linear_projection.bias = torch.nn.parameter.Parameter(pretrained_model.conv_proj.bias)
  for i in range(12):
    initialised_model.Transformer[i].layer_norm.weight = torch.nn.parameter.Parameter(pretrained_model.encoder.layers[i].ln_1.weight)
    initialised_model.Transformer[i].layer_norm.bias = torch.nn.parameter.Parameter(pretrained_model.encoder.layers[i].ln_1.bias)
    initialised_model.Transformer[i].MSA_module.in_proj_weight = torch.nn.parameter.Parameter(pretrained_model.encoder.layers[i].self_attention.in_proj_weight)
    initialised_model.Transformer[i].MSA_module.in_proj_bias = torch.nn.parameter.Parameter(pretrained_model.encoder.layers[i].self_attention.in_proj_bias)
    initialised_model.Transformer[i].MSA_module.out_proj.weight = torch.nn.parameter.Parameter(pretrained_model.encoder.layers[i].self_attention.out_proj.weight)
    initialised_model.Transformer[i].MSA_module.out_proj.bias = torch.nn.parameter.Parameter(pretrained_model.encoder.layers[i].self_attention.out_proj.bias)
    initialised_model.Transformer[i].layer_norm_2.weight = torch.nn.parameter.Parameter(pretrained_model.encoder.layers[i].ln_2.weight)
    initialised_model.Transformer[i].layer_norm_2.bias = torch.nn.parameter.Parameter(pretrained_model.encoder.layers[i].ln_2.bias)
    initialised_model.Transformer[i].MLP_module[0].weight = torch.nn.parameter.Parameter(pretrained_model.encoder.layers[i].mlp[0].weight)
    initialised_model.Transformer[i].MLP_module[0].bias = torch.nn.parameter.Parameter(pretrained_model.encoder.layers[i].mlp[0].bias)
    initialised_model.Transformer[i].MLP_module[3].weight = torch.nn.parameter.Parameter(pretrained_model.encoder.layers[i].mlp[3].weight)
    initialised_model.Transformer[i].MLP_module[3].bias = torch.nn.parameter.Parameter(pretrained_model.encoder.layers[i].mlp[3].bias)
  initialised_model.layer_norm.weight = torch.nn.parameter.Parameter(pretrained_model.encoder.ln.weight)
  initialised_model.layer_norm.bias = torch.nn.parameter.Parameter(pretrained_model.encoder.ln.bias)

  return initialised_model

In [10]:
def interpolate_position_embedding(pretrained_embeddings,pretraining_resolution=224,finetuning_resolution=384,patch_size=16,latent_vector_size=768):
  '''
  Performs 2D interpolation of position embeddings of pretrained model
  to allow for increased resolution of images from pretraining to finetuning
  Explanation can be found in section 3.2 of Dosovitsky et al. 2021
  '''
  pretraining_num_patches = int((pretraining_resolution/patch_size)**2)
  finetuning_num_patches = int((finetuning_resolution/patch_size)**2)
  sqrt_pretraining_num_patches = int(sqrt(pretraining_num_patches))
  sqrt_finetuning_num_patches = int(sqrt(finetuning_num_patches))
  #Remove class embedding from pretrainined embeddings tensor
  old_position_embeddings = pretrained_embeddings[1:]
  #Reshape into a grid so that each patch embedding has 2D positional info
  old_position_embeddings = torch.reshape(old_position_embeddings,[sqrt_pretraining_num_patches,sqrt_pretraining_num_patches,latent_vector_size])
  #Increase grid size to allow for increased resolution of images in finetuning
  #and interpolate missing values
  scaling = sqrt_finetuning_num_patches/sqrt_pretraining_num_patches
  position_interpolated = zoom(old_position_embeddings.detach().numpy(),(scaling,scaling,1))
  position_interpolated = torch.tensor(position_interpolated)
  #Flatten the grid
  new_position_embedding = torch.reshape(position_interpolated,[int(finetuning_num_patches),latent_vector_size])
  #Prepend class embedding to new position embeddings
  new_position_embedding = torch.cat([pretrained_embeddings[0].unsqueeze(0),new_position_embedding],dim=0)

  return new_position_embedding

In [15]:
def initialise_finetuning_parameters(finetuning_model, pretrained_model, pretraining_image_resolution=224,finetuning_image_resolution=384,patch_size=16,latent_vector_size=768):
  '''
  Transfers the model weights from the pretrained model to the finetuning model
  '''
  finetuning_model.embeddings.class_embedding = pretrained_model.embeddings.class_embedding
  #interpolate the position embeddings from the pretraining to finetuning image resolution version
  interpolated = interpolate_position_embedding(pretrained_model.embeddings.position_embeddings,pretraining_image_resolution,finetuning_image_resolution,patch_size,latent_vector_size)
  finetuning_model.embeddings.position_embeddings = torch.nn.parameter.Parameter(interpolated)
  finetuning_model.embeddings.linear_projection.weight = pretrained_model.embeddings.linear_projection.weight
  finetuning_model.embeddings.linear_projection.bias = pretrained_model.embeddings.linear_projection.bias
  for i in range(12):
    finetuning_model.Transformer[i].layer_norm.weight = pretrained_model.Transformer[i].layer_norm.weight
    finetuning_model.Transformer[i].layer_norm.bias = pretrained_model.Transformer[i].layer_norm.bias
    finetuning_model.Transformer[i].MSA_module.in_proj_weight = pretrained_model.Transformer[i].MSA_module.in_proj_weight
    finetuning_model.Transformer[i].MSA_module.in_proj_bias = pretrained_model.Transformer[i].MSA_module.in_proj_bias
    finetuning_model.Transformer[i].MSA_module.out_proj.weight = pretrained_model.Transformer[i].MSA_module.out_proj.weight
    finetuning_model.Transformer[i].MSA_module.out_proj.bias = pretrained_model.Transformer[i].MSA_module.out_proj.bias
    finetuning_model.Transformer[i].layer_norm_2.weight = pretrained_model.Transformer[i].layer_norm_2.weight
    finetuning_model.Transformer[i].layer_norm_2.bias = pretrained_model.Transformer[i].layer_norm_2.bias
    finetuning_model.Transformer[i].MLP_module[0].weight = pretrained_model.Transformer[i].MLP_module[0].weight
    finetuning_model.Transformer[i].MLP_module[0].bias = pretrained_model.Transformer[i].MLP_module[0].bias
    finetuning_model.Transformer[i].MLP_module[3].weight = pretrained_model.Transformer[i].MLP_module[3].weight
    finetuning_model.Transformer[i].MLP_module[3].bias = pretrained_model.Transformer[i].MLP_module[3].bias
  finetuning_model.layer_norm.weight = pretrained_model.layer_norm.weight
  finetuning_model.layer_norm.bias = pretrained_model.layer_norm.bias

  return finetuning_model


# Tests

 Tests for model ouput shape. ouput.shape should be [1,1000] and output2.shape should be [1,37]

In [None]:
image = torch.rand([1,3,224,224])
model = Vision_Transformer_Pretraining(num_transformer_layers=12)
output = model(image)
output.shape

torch.Size([1, 1000])

In [None]:
image2 = torch.rand([1,3,384,384])
model2 = Vision_Transformer_Finetuning(num_transformer_layers=12)
output2 = model2(image2)
output2.shape

torch.Size([1, 37])

Use these to compare parameters in pretrained vitB16 model downloaded from torchvision and the model created using this code.

In [None]:
for name, param in vitB16_model.state_dict().items():
    print(name, param.size())

In [None]:
for name, param in model2.state_dict().items():
    print(name, param.size())

Use the summary function to compare parameter numbers for vitB16 model and model created using this code

In [None]:
summary(model,input_size=(1,3,224,224))

In [None]:
summary(model2,input_size=(1,3,384,384))

In [None]:
#Using vitB16 model from torchvision.models
vitB16_weights = torchvision.models.ViT_B_16_Weights.DEFAULT
vitB16_model = torchvision.models.vit_b_16(weights=vitB16_weights)

summary(vitB16_model,input_size=(1,3,224,224))

# Training

In [16]:
def train(model, train_dataloader, loss_fn, optimizer, epochs, scheduler):

  for epoch in range(epochs):
    train_loss = 0
    correct = 0
    train_accuracy = 0

    for batch, (X, y) in enumerate(train_dataloader):
      #Move data to device
      X, y = X.to(device), y.to(device)

      # Calculate the prediction error
      pred_probabilities = model(X)
      pred = pred_probabilities.argmax(1)
      loss = loss_fn(pred_probabilities, y)

      # Backpropagation
      optimizer.zero_grad()
      loss.backward()
      torch.nn.utils.clip_grad_norm_(model.parameters(), 1) #For ImageNet Pretraining and all Finetuning
      optimizer.step()
      scheduler.step()

      #Updating the number of correct predictions
      for i in range(len(pred)):
        if pred[i] == y[i]:
          correct += 1

      train_loss += loss.item()

    #Average training loss over all batches (for one epoch)
    train_loss = train_loss/len(train_dataloader)
    #Avarage Training accuracy over one iteration through the data (for one epoch)
    train_accuracy = correct/len(train_dataloader.dataset) * 100

    print(f'For epoch: {epoch} - Training Loss: {train_loss} - Training Accuracy: {train_accuracy}')

## Pretraining

Not all the hyperparameters I needed to write the pytorch code were found in the paper, so I had a look at the [original jax code repository](https://github.com/google-research/vision_transformer/blob/main/vit_jax/train.py) on Github and found the train.py and utils.py files in vision_transformer/vit_jax/

In [None]:
train_dataloader = DataLoader(train_data, batch_size=4096, shuffle=True) #Batch size found in Section 4.1 SetUp -'Training & Fine-tuning'
loss_fn = torch.nn.CrossEntropyLoss() #I found the loss used in line 48 of vision_transformer/vit_jax/train.py (from original code repo)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001,betas=(0.9,0.999),weight_decay=0.03) #Hyperparameters found in Table 3 and Section 4.1
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 300*len(train_dataloader), 1e-5) #end factor found line 61 utils.py
warmup_scheduler = WarmUpScheduler(optimizer, scheduler,
                                   len_loader=len(train_dataloader),
                                   warmup_steps=10000,
                                   warmup_start_lr=0.001/10000,
                                   warmup_mode='linear')


train(model, train_dataloader, loss_fn, optimizer, 300, warmup_scheduler)


## Finetuning

In [65]:
#Using vitB16 model from torchvision.models
vitB16_weights = torchvision.models.ViT_B_16_Weights.DEFAULT
vitB16_model = torchvision.models.vit_b_16(weights=vitB16_weights)

#Initialise a random pretrained model
model = Vision_Transformer_Pretraining()

#Transfer weights from vitB16 to randomly initialiseed model (Don't need this step
# if you train a pretrained model yourself with Vision_Transformer_Pretraining)
pretrained_model = vitB16_model_to_pretraining_model(vitB16_model,model)

#Initialise a random finetuning model
finetuning_model = Vision_Transformer_Finetuning()

#Transfer weights from pretrained model to finetuining model
finetuning_model = initialise_finetuning_parameters(finetuning_model,pretrained_model)

In [66]:
#Create training and test dataloaders. The batch size was found under Appendix B.1.1 - Finetuning (page 13)
train_dataloader = DataLoader(train_data, batch_size=8, shuffle=True)
test_batch_size = len(test_data)
test_dataloader = DataLoader(test_data, 1, shuffle=False)

In [61]:
torch.cuda.empty_cache()

In [20]:
import os
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"
finetuning_model.to(device)

In [69]:
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(finetuning_model.parameters(), lr=0.001, momentum=0.9, weight_decay=0)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 500*len(train_dataloader), 1e-5) #Table 4

batch_size = 70

train(finetuning_model, train_dataloader, loss_fn, optimizer, epochs, scheduler)

For epoch: 0 - Training Loss: 3.2910491124443384 - Training Accuracy: 48.20652173913044
For epoch: 1 - Training Loss: 2.378184854465982 - Training Accuracy: 76.05978260869564
For epoch: 2 - Training Loss: 1.6413141294665958 - Training Accuracy: 79.86413043478261
For epoch: 3 - Training Loss: 1.2130005779473678 - Training Accuracy: 82.28260869565217
For epoch: 4 - Training Loss: 1.0455683213213216 - Training Accuracy: 82.14673913043478
For epoch: 5 - Training Loss: 0.9051577595265016 - Training Accuracy: 83.61413043478261
For epoch: 6 - Training Loss: 0.8183574061030927 - Training Accuracy: 84.375
For epoch: 7 - Training Loss: 0.7535558528226355 - Training Accuracy: 85.02717391304347
For epoch: 8 - Training Loss: 0.7243552260424779 - Training Accuracy: 85.16304347826087
For epoch: 9 - Training Loss: 0.7102991459162339 - Training Accuracy: 85.02717391304347
For epoch: 10 - Training Loss: 0.6719776741836382 - Training Accuracy: 85.08152173913044
For epoch: 11 - Training Loss: 0.6185982463

# Testing

In [54]:
from tqdm import tqdm

In [55]:
def test(model, test_data, loss_fn):
  model.eval()
  model.to(device)

  with torch.inference_mode():
    test_loss = 0
    correct = 0
    test_accuracy = 0

    for i in tqdm(range(len(test_data))):
      X, y = test_data[i][0], test_data[i][1]
      X = X.unsqueeze(0)
      X, y = torch.tensor(X).to(device), torch.tensor([y]).to(device)
      # Calculate the prediction error
      pred_probabilities = model(X)
      pred = pred_probabilities.argmax(1)
      loss = loss_fn(pred_probabilities, y)

      if pred == y:
        correct += 1

      test_loss += loss.item()

    #Average test loss over batches
    test_loss = test_loss/len(test_data)
    #Test accuracy for batch
    test_accuracy = correct/len(test_data) * 100

    print(f'Testing Loss: {test_loss} - Testing Accuracy: {test_accuracy}')

In [70]:
test(finetuning_model, test_data, loss_fn)

  X, y = torch.tensor(X).to(device), torch.tensor([y]).to(device)
100%|██████████| 3669/3669 [01:15<00:00, 48.44it/s]

Testing Loss: 0.29556379368295893 - Training Accuracy: 92.28672662850913



