# Attentional Networks in Computer Vision
Prepared by comp411 Teaching Unit (TA Can Küçüksözen) in the context of Computer Vision with Deep Learning Course. Do not hesitate to ask in case you have any questions, contact me at: ckucuksozen19@ku.edu.tr

Up until this point, we have worked with deep fully-connected networks, convolutional networks and recurrent networks using them to explore different optimization strategies and network architectures. Fully-connected networks are a good testbed for experimentation because they are very computationally efficient, on the other hand, most successful image processing methods use convolutional networks. However recent state-of-the-art results on computer vision realm are acquired using Attentional layers and Transformer architectures.

First you will implement several layer types that are used in fully attentional networks. You will then use these layers to train an Attentional Image Classification network, specifically a smaller version of Vision Transformer (VIT) on the CIFAR-10 dataset. The original paper can be accessed via the following link: https://arxiv.org/pdf/2010.11929.pdf

# Part I. Preparation

First, we load the CIFAR-10 dataset. This might take a couple minutes the first time you do it, but the files should stay cached after that.

In previous parts of the assignment we had to write our own code to download the CIFAR-10 dataset, preprocess it, and iterate through it in minibatches; PyTorch provides convenient tools to automate this process for us.

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import sampler
import torch.nn.functional as F
from torch.autograd import Variable

import torchvision.datasets as dset
import torchvision.transforms as T

import numpy as np

In [2]:
PYTORCH_ENABLE_MPS_FALLBACK=1 

In [3]:
NUM_TRAIN = 49000

# The torchvision.transforms package provides tools for preprocessing data
# and for performing data augmentation; here we set up a transform to
# preprocess the data by subtracting the mean RGB value and dividing by the
# standard deviation of each RGB value; we've hardcoded the mean and std.
transform = T.Compose([
                T.ToTensor(),
                T.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
            ])

# We set up a Dataset object for each split (train / val / test); Datasets load
# training examples one at a time, so we wrap each Dataset in a DataLoader which
# iterates through the Dataset and forms minibatches. We divide the CIFAR-10
# training set into train and val sets by passing a Sampler object to the
# DataLoader telling how it should sample from the underlying Dataset.
cifar10_train = dset.CIFAR10('./comp411/datasets', train=True, download=False,
                             transform=transform)
loader_train = DataLoader(cifar10_train, batch_size=64, 
                          sampler=sampler.SubsetRandomSampler(range(NUM_TRAIN)))

cifar10_val = dset.CIFAR10('./comp411/datasets', train=True, download=False,
                           transform=transform)
loader_val = DataLoader(cifar10_val, batch_size=64, 
                        sampler=sampler.SubsetRandomSampler(range(NUM_TRAIN, 50000)))

cifar10_test = dset.CIFAR10('./comp411/datasets', train=False, download=False, 
                            transform=transform)
loader_test = DataLoader(cifar10_test, batch_size=64)

You have an option to **use GPU by setting the flag to True below**. It is not necessary to use GPU for this assignment. Note that if your computer does not have CUDA enabled, `torch.cuda.is_available()` will return False and this notebook will fallback to CPU mode.

The global variables `dtype` and `device` will control the data types throughout this assignment. 

In [4]:
USE_GPU = True

dtype = torch.float32 # we will be using float throughout this tutorial

if USE_GPU and torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Constant to control how frequently we print train loss
print_every = 100

print('using device:', device)

using device: cpu


# Part II. Barebones Transformers: Self-Attentional Layer

Here you will complete the implementation of the Pytorch nn.module `SelfAttention`, which will perform the forward pass of a self-attentional layer. Our implementation of the SelfAttentional layer will include three distinct fully connected layers which will be responsible of:

1. A fully connected layer, `W_Q`, which will be used to project our input into `queries`
2. A fully connected layer, `W_K`, which will be used to project our input into `keys`
3. A fully connected layer, `W_V`, which will be used to project our input into `values`

After defining such three fully connected layers, and obtain our `queries, keys, and values` variables at the beginning of our forward pass, the following operations should be carried out in order to complete the attentional layer implementation.

1. Seperate each of `query, key, and value` projections into their respective heads. In other words, split the feature vector dimension of each matrix into necessarry number of chunks.

2. Compute the `Attention Scores` between each pair of sequence elements via conducting a scaled dot product operation between every pair of `queries` and `keys`. Note that `Attention Scores` matrix should have the size of `[# of queries , # of keys]`

3. Calculate the `Attention Weights` of each query by applying the non-linear `Softmax` normalization accross the `keys` dimension of the `Attention Scores` matrix.

4. Obtain the output combination of `values` by matrix multiplying `Attention Weights` with `values`

5. Reassemble heads into one flat vector and return the output.

**HINT**: For a more detailed explanation of the self attentional layer, examine the Appendix A of the original ViT manuscript here:  https://arxiv.org/pdf/2010.11929.pdf 

In [5]:
class SelfAttention(nn.Module):
    
    def __init__(self, input_dims, head_dims=128, num_heads=2,  bias=False):
        super(SelfAttention, self).__init__()
        
        ## initialize module's instance variables
        self.input_dims = input_dims
        self.head_dims = head_dims
        self.num_heads = num_heads
        self.proj_dims = head_dims * num_heads
        
        ## Declare module's parameters
        self.W_Q = nn.Linear(input_dims, self.proj_dims,bias=bias)
        self.W_K = nn.Linear(input_dims, self.proj_dims,bias=bias)
        self.W_V = nn.Linear(input_dims, self.proj_dims,bias=bias)

        self.init_weights()
        
    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight.data)
                if m.bias is not None:
                    m.bias.data.fill_(0.1)

    def forward(self, x):
        ## Input of shape, [B, N, D] where:
        ## - B denotes the batch size
        ## - N denotes number of sequence elements. I.e. the number of patches + the class token 
        ## - D corresponds to model dimensionality
        b,n,d = x.shape
        
        ## Construct queries,keys,values
        q_ = self.W_Q(x)
        k_ = self.W_K(x)
        v_ = self.W_V(x)
        
        ## Seperate q,k,v into their corresponding heads,
        ## After this operation each q,k,v will have the shape: [B,H,N,D//H] where
        ## - B denotes the batch size
        ## - H denotes number of heads
        ## - N denotes number of sequence elements. I.e. the number of patches + the class token 
        ## - D//H corresponds to per head dimensionality
        q, k, v = map(lambda z: torch.reshape(z, (b,n,self.num_heads,self.head_dims)).permute(0,2,1,3), [q_,k_,v_])
       
        #########################################################################################
        # TODO: Complete the forward pass of the SelfAttention layer, follow the comments below #
        #########################################################################################
        # *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
        
        ## Compute attention logits. Note that this operation is conducted as a
        ## batched matrix multiplication between q and k, the output is scaled by 1/(D//H)^(1/2)
        ## inputs are queries and keys that are both of size [B,H,N,D//H]
        ## Output Attention logits should have the size: [B,H,N,N]
        
        attention_logits =  torch.matmul(q, k.transpose(-2, -1)) * (1 / ((d//self.num_heads)**(1/2)))
    
        ## Compute attention Weights. Note that this operation is conducted as a
        ## Softmax Normalization across the keys dimension. 
        ## Hint: You can apply the Softmax operation across the final dimension

        attention_weights = torch.softmax(attention_logits, dim = -1)
        
        ## Compute output values. Note that this operation is conducted as a 
        ## batched matrix multiplication between the Attention Weights matrix and 
        ## the values tensor. After computing output values, the output should be reshaped
        ## Inputs are Attention Weights with size [B, H, N, N], values with size [B, H, N, D//H]
        ## Output should be of size [B, N, D]
        ## Hint: you should use torch.matmul, torch.permute, torch.reshape in that order
        
        attn_out = torch.matmul(attention_weights, v).permute(0, 2, 1, 3).reshape(-1, n, self.proj_dims)
        
        
        # *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
        ################################################################################
        #                                 END OF YOUR CODE                             
        ################################################################################
    
        return attn_out

After defining the forward pass of the Self-Attentional Layer above, run the following cell to test your implementation.

When you run this function, output should have shape (64, 16, 64).

In [6]:
def test_self_attn_layer():
    x = torch.zeros((64, 16, 32), dtype=dtype)  # minibatch size 64, sequence elements size 16, feature channels size 32
    layer = SelfAttention(32,64,4)
    out = layer(x)
    print(out.size())  # you should see [64,16,256]
test_self_attn_layer()

torch.Size([64, 16, 256])


# Part III. Barebones Transformers: Transformer Encoder Block

Here you will complete the implementation of the Pytorch nn.module `TransformerBlock`, which will perform the forward pass of a Transfomer Encoder Block. You can refer to Figure 1 of the original manuscript of ViT from this link: https://arxiv.org/pdf/2010.11929.pdf in order to get yourself familiar with the architecture.



In [7]:
## Implementation of a two layer GELU activated Fully Connected Network is provided for you below:

class MLP(nn.Module):
    def __init__(self, input_dims, hidden_dims, output_dims, bias=True):
        super().__init__()
        
        self.fc_1 = nn.Linear(input_dims, hidden_dims, bias=bias)
        self.fc_2 = nn.Linear(hidden_dims, output_dims, bias=bias)
        
        self.init_weights()
        
    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight.data)
                if m.bias is not None:
                    m.bias.data.fill_(0.1)
        
    def forward(self, x):
        o = F.relu(self.fc_1(x))
        o = self.fc_2(o)
        return o

In [8]:
## Build from scratch a TransformerBlock Module. Note that the architecture of this
## module follows a simple computational pipeline:
## input --> layernorm --> SelfAttention --> skip connection 
##       --> layernorm --> MLP ---> skip connection ---> output
## Note that the TransformerBlock module works on a single hidden dimension hidden_dims,
## in order to faciliate skip connections with ease. Be careful about the input arguments
## to the SelfAttention block.


class TransformerBlock(nn.Module):
    def __init__(self, hidden_dims, num_heads=4, bias=False):
        super(TransformerBlock, self).__init__()
        
###############################################################
# TODO: Complete the consturctor of  TransformerBlock module  #
###############################################################
# *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)****
        self.norm_layer = nn.LayerNorm(hidden_dims)
        self.attention = SelfAttention(hidden_dims, hidden_dims//num_heads, num_heads, bias=bias)
        self.norm_layer2 = nn.LayerNorm(hidden_dims)
        self.mlp = MLP(hidden_dims, hidden_dims, hidden_dims, bias=bias)
        
        
# *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
###################################################################
#                                 END OF YOUR CODE                #             
###################################################################
        
    def forward(self, x):
        
##############################################################
# TODO: Complete the forward of TransformerBlock module      #
##############################################################
# *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)****

        norm_layer = self.norm_layer(x)
        attention  = self.attention(norm_layer)
        norm_layer2 = self.norm_layer2(attention + x)
        mlp = self.mlp(norm_layer2)
        output = mlp + attention + x
        return output
        
 # *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
###################################################################
#                                 END OF YOUR CODE                #             
###################################################################

After defining the forward pass of the Transformer Block Layer above, run the following cell to test your implementation.

When you run this function, output should have shape (64, 16, 128).

In [9]:
def test_transfomerblock_layer():
    x = torch.zeros((64, 16, 128), dtype=dtype)  # minibatch size 64, sequence elements size 16, feature channels size 128
    layer = TransformerBlock(128,4) # hidden dims size 128, heads size 4
    out = layer(x)
    print(out.size()) 
test_transfomerblock_layer()

torch.Size([64, 16, 128])


# Part IV The Vision Transformer (ViT)

The final implementation for the Pytorch nn.module `ViT` is given to you below, which will perform the forward pass of the Vision Transformer. Study it and get yourself familiar with the API.


In [10]:
class ViT(nn.Module):
    def __init__(self, hidden_dims, input_dims=3, output_dims=10, num_trans_layers = 4, num_heads=4, image_k=32, patch_k=4, bias=False):
        super(ViT, self).__init__()
                
        ## initialize module's instance variables
        self.hidden_dims = hidden_dims
        self.input_dims = input_dims
        self.output_dims = output_dims
        self.num_trans_layers = num_trans_layers
        self.num_heads = num_heads
        self.image_k = image_k
        self.patch_k = patch_k
        
        self.image_height = self.image_width = image_k
        self.patch_height = self.patch_width = patch_k
        
        assert self.image_height % self.patch_height == 0 and self.image_width % self.patch_width == 0,\
                'Image size must be divisible by the patch size.'

        self.num_patches = (self.image_height // self.patch_height) * (self.image_width // self.patch_width)
        self.patch_flat_len = self.patch_height * self.patch_width
        
        ## Declare module's parameters
        
        ## ViT's flattened patch embedding projection:
        self.linear_embed = nn.Linear(self.input_dims*self.patch_flat_len, self.hidden_dims)
        
        ## Learnable positional embeddings, an embedding is learned for each patch location and the class token
        self.pos_embedding = nn.Parameter(torch.randn(1, self.num_patches + 1, self.hidden_dims))
        
        ## Learnable classt token and its index among attention sequence elements.
        self.cls_token = nn.Parameter(torch.randn(1,1,self.hidden_dims))
        self.cls_index = torch.LongTensor([0])
        
        ## Declare cascaded Transformer blocks:
        transformer_encoder_list = []
        for _ in range(self.num_trans_layers):
            transformer_encoder_list.append(TransformerBlock(self.hidden_dims, self.num_heads, bias))
        self.transformer_encoder = nn.Sequential(*transformer_encoder_list)
        
        ## Declare the output mlp:
        self.out_mlp = MLP(self.hidden_dims, self.hidden_dims, self.output_dims)
         
    def unfold(self, x, f = 7, st = 4, p = 0):
        ## Create sliding window pathes using nn.Functional.unfold
        ## Input dimensions: [B,D,H,W] where
        ## --B : input batch size
        ## --D : input channels
        ## --H, W: input height and width
        ## Output dimensions: [B,N,H*W,D]
        ## --N : number of patches, decided according to sliding window kernel size (f),
        ##      sliding window stride and padding.
        b,d,h,w = x.shape
        x_unf = F.unfold(x, (f,f), stride=st, padding=p)    
        x_unf = torch.reshape(x_unf.permute(0,2,1), (b,-1,d,f*f)).transpose(-1,-2)
        n = x_unf.size(1)
        return x_unf,n
    
    def forward(self, x):
        b = x.size(0)
        ## create sliding window patches from the input image
        x_patches,n = self.unfold(x, self.patch_height, self.patch_height, 0)
        ## flatten each patch into a 1d vector: i.e. 3x4x4 image patch turned into 1x1x48
        x_patch_flat = torch.reshape(x_patches, (b,n,-1))
        ## linearly embed each flattened patch
        x_embed = self.linear_embed(x_patch_flat)
        
        ## retrieve class token 
        cls_tokens = self.cls_token.repeat(b,1,1)
        ## concatanate class token to input patches
        xcls_embed = torch.cat([cls_tokens, x_embed], dim=-2)
        
        ## add positional embedding to input patches + class token 
        xcls_pos_embed = xcls_embed + self.pos_embedding
        
        ## pass through the transformer encoder
        trans_out = self.transformer_encoder(xcls_pos_embed)
        
        ## select the class token 
        out_cls_token = torch.index_select(trans_out, -2, self.cls_index.to(trans_out.device))
        
        ## create output
        out = self.out_mlp(out_cls_token)
        
        return out.squeeze(-2)

After defining the forward pass of the ViT above, run the following cell to test your implementation.

When you run this function, output should have shape (64, 10).

In [11]:
def test_vit():
    x = torch.zeros((64, 3, 32, 32), dtype=dtype)  # minibatch size 64, image size 3,32,32
    model = ViT(hidden_dims=128, input_dims=3, output_dims=10, num_trans_layers = 4, num_heads=4, image_k=32, patch_k=4)
    out = model(x)
    print(out.size()) 
test_vit()

torch.Size([64, 10])


# Part V. Train the ViT

### Check Accuracy
Given any minibatch of input data and desired targets, we can check the classification accuracy of a neural network. 

The check_batch_accuracy function is provided for you below:

In [12]:
def check_batch_accuracy(out, target,eps=1e-7):
    b, c = out.shape
    with torch.no_grad():
        _, pred = out.max(-1) 
        correct = np.sum(np.equal(pred.cpu().numpy(), target.cpu().numpy()))
    return correct, float(correct) / (b)

### Training Loop
As we have already seen in the Second Assignment, in our PyTorch based training loops, we use an Optimizer object from the `torch.optim` package, which abstract the notion of an optimization algorithm and provides implementations of most of the algorithms commonly used to optimize neural networks.

In [13]:
def train(network, optimizer, trainloader):
    """
    Train a model on CIFAR-10 using the PyTorch Module API for a single epoch
    
    Inputs:
    - network: A PyTorch Module giving the model to train.
    - optimizer: An Optimizer object we will use to train the model
    - trainloader: Iterable DataLoader object that fetches the minibatches
    
    Returns: overall training accuracy for the epoch
    """
    print('\nEpoch: %d' % epoch)
    network.train()  # put model to training mode
    network = network.to(device=device)  # move the model parameters to CPU/GPU
    train_loss = 0
    correct = 0
    total = 0
    
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = Variable(inputs.to(device)), targets.to(device)  # move to device, e.g. GPU
            
        outputs = network(inputs)
        loss =  F.cross_entropy(outputs, targets)
            
        # Zero out all of the gradients for the variables which the optimizer
        # will update.
        optimizer.zero_grad() 

        # This is the backwards pass: compute the gradient of the loss with
        # respect to each  parameter of the model.
        loss.backward()
            
        # Actually update the parameters of the model using the gradients
        # computed by the backwards pass.
        optimizer.step()
            
        loss = loss.detach()
        train_loss += loss.item()
        correct_p, _ = check_batch_accuracy(outputs, targets) 
        correct += correct_p
        total += targets.size(0)

        print('Loss: %.3f | Acc: %.3f%% (%d/%d)'
        % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))
        
    return 100.*correct/total

### Evaluation Loop
We have also prepared a Evaluation loop in order to determine our networks capabilities in terms of classification accuracy on a given dataset, either the training, or the validation split

In [14]:
def evaluate(network, evalloader):
    """
    Evaluate a model on CIFAR-10 using the PyTorch Module API for a single epoch
    
    Inputs:
    - network: A PyTorch Module giving the model to train.
    - evalloader: Iterable DataLoader object that fetches the minibatches
    
    Returns: overall evaluation accuracy for the epoch
    """
    network.eval() # put model to evaluation mode
    network = network.to(device=device)  # move the model parameters to CPU/GPU
    eval_loss = 0
    correct = 0
    total = 0
    print('\n---- Evaluation in process ----')
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(evalloader):
            inputs, targets = inputs.to(device), targets.to(device) # move to device, e.g. GPU
            outputs = network(inputs)
            loss = F.cross_entropy(outputs, targets)
            
            eval_loss += loss.item()
            correct_p, _ = check_batch_accuracy(outputs, targets)
            correct += correct_p
            total += targets.size(0)
            print('Loss: %.3f | Acc: %.3f%% (%d/%d)'
                % (eval_loss/(batch_idx+1), 100.*correct/total, correct, total))
    return 100.*correct/total

### Overfit a ViT
Now we are ready to run the training loop. A nice trick is to train your model with just a few training samples in order to see if your implementation is actually bug free. 

Simply pass the input size, hidden layer size, and number of classes (i.e. output size) to the constructor of `ViT`. 

You also need to define an optimizer that tracks all the learnable parameters inside `ViT`. We prefer to use `Adam` optimizer for this part.

You should be able to overfit small datasets, which will result in very high training accuracy and comparatively low validation accuracy.

In [15]:
sample_idx_tr = torch.randperm(len(cifar10_train))[:100]
sample_idx_val = torch.randperm(len(cifar10_train))[-100:]

trainset_sub = torch.utils.data.Subset(cifar10_train, sample_idx_tr)
valset_sub = torch.utils.data.Subset(cifar10_train, sample_idx_val)

print("For overfitting experiments, the subset of the dataset that is used has {} sample images".format(len(trainset_sub)))

batch_size_sub = 25
trainloader_sub = torch.utils.data.DataLoader(trainset_sub, batch_size=batch_size_sub, shuffle=True)
valloader_sub = torch.utils.data.DataLoader(valset_sub, batch_size=batch_size_sub, shuffle=False)

print('==> Data ready, batchsize = {}'.format(batch_size_sub))

For overfitting experiments, the subset of the dataset that is used has 100 sample images
==> Data ready, batchsize = 25


In [16]:
learning_rate = 0.002
regularization_val = 1e-6
input_dims = 3
hidden_dims = 128
output_dims=10
num_trans_layers = 4
num_heads=4
image_k=32
patch_k=4

model = None
optimizer = None

################################################################################
# TODO: Instantiate your ViT model and a corresponding optimizer #
################################################################################
# *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****

network = ViT(hidden_dims, 
              input_dims, 
              output_dims, 
              num_trans_layers, 
              num_heads, 
              image_k, 
              patch_k)

optimizer = torch.optim.Adam(network.parameters(), lr = learning_rate)

# *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
################################################################################
#                                 END OF YOUR CODE                             
################################################################################

tr_accs=[]
eval_accs=[]
for epoch in range(15):
    tr_acc = train(network, optimizer, trainloader_sub)
    print('Epoch {} of training is completed, Training accuracy for this epoch is {}'\
              .format(epoch, tr_acc))  
    
    eval_acc = evaluate(network, valloader_sub)
    print('Evaluation of Epoch {} is completed, Validation accuracy for this epoch is {}'\
              .format(epoch, eval_acc))  
    tr_accs.append(tr_acc)
    eval_accs.append(eval_acc)
    
print("\nFinal train set accuracy is {}".format(tr_accs[-1]))
print("Final val set accuracy is {}".format(eval_accs[-1]))


Epoch: 0
Loss: 3.734 | Acc: 16.000% (4/25)
Loss: 4.913 | Acc: 12.000% (6/50)
Loss: 4.549 | Acc: 10.667% (8/75)
Loss: 4.202 | Acc: 12.000% (12/100)
Epoch 0 of training is completed, Training accuracy for this epoch is 12.0

---- Evaluation in process ----
Loss: 3.228 | Acc: 0.000% (0/25)
Loss: 3.087 | Acc: 12.000% (6/50)
Loss: 3.135 | Acc: 13.333% (10/75)
Loss: 3.122 | Acc: 14.000% (14/100)
Evaluation of Epoch 0 is completed, Validation accuracy for this epoch is 14.0

Epoch: 1
Loss: 3.560 | Acc: 28.000% (7/25)
Loss: 2.860 | Acc: 30.000% (15/50)
Loss: 2.836 | Acc: 24.000% (18/75)
Loss: 2.758 | Acc: 21.000% (21/100)
Epoch 1 of training is completed, Training accuracy for this epoch is 21.0

---- Evaluation in process ----
Loss: 2.408 | Acc: 24.000% (6/25)
Loss: 2.624 | Acc: 16.000% (8/50)
Loss: 2.704 | Acc: 14.667% (11/75)
Loss: 2.768 | Acc: 14.000% (14/100)
Evaluation of Epoch 1 is completed, Validation accuracy for this epoch is 14.0

Epoch: 2
Loss: 1.588 | Acc: 48.000% (12/25)
Loss: 

## Train the net
By training the four-layer ViT network for three epochs, with untuned hyperparameters that are initialized as below,  you should achieve greater than 50% accuracy both on the training set and the test set:

In [17]:
learning_rate = 0.002
regularization_val = 1e-6
input_dims = 3
hidden_dims = 128
output_dims=10
num_trans_layers = 4
num_heads=4
image_k=32
patch_k=4

network = None
optimizer = None

################################################################################
# TODO: Instantiate your ViT model and a corresponding optimizer #
################################################################################
# *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****

network = ViT(hidden_dims, 
            input_dims, 
            output_dims, 
            num_trans_layers, 
            num_heads, 
            image_k, 
            patch_k)

optimizer = torch.optim.Adam(network.parameters(), lr = learning_rate)

# *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
################################################################################
#                                 END OF YOUR CODE                             
################################################################################

tr_accs=[]
test_accs=[]
for epoch in range(3):
    tr_acc = train(network, optimizer, loader_train)
    print('Epoch {} of training is completed, Training accuracy for this epoch is {}'\
              .format(epoch, tr_acc))  
    
    test_acc = evaluate(network, loader_test)
    print('Evaluation of Epoch {} is completed, Test accuracy for this epoch is {}'\
              .format(epoch, test_acc))  
    
    tr_accs.append(tr_acc)
    test_accs.append(test_acc)
    
print("\nFinal train set accuracy is {}".format(tr_accs[-1]))
print("Final test set accuracy is {}".format(test_accs[-1]))


Epoch: 0
Loss: 3.637 | Acc: 6.250% (4/64)
Loss: 4.643 | Acc: 8.594% (11/128)
Loss: 4.635 | Acc: 9.375% (18/192)
Loss: 4.477 | Acc: 9.766% (25/256)
Loss: 4.268 | Acc: 11.875% (38/320)
Loss: 4.235 | Acc: 12.760% (49/384)
Loss: 4.031 | Acc: 14.955% (67/448)
Loss: 3.853 | Acc: 15.430% (79/512)
Loss: 3.717 | Acc: 16.146% (93/576)
Loss: 3.592 | Acc: 16.875% (108/640)
Loss: 3.508 | Acc: 16.903% (119/704)
Loss: 3.397 | Acc: 17.708% (136/768)
Loss: 3.304 | Acc: 17.788% (148/832)
Loss: 3.225 | Acc: 17.746% (159/896)
Loss: 3.167 | Acc: 18.333% (176/960)
Loss: 3.125 | Acc: 18.750% (192/1024)
Loss: 3.050 | Acc: 19.210% (209/1088)
Loss: 2.995 | Acc: 19.358% (223/1152)
Loss: 2.958 | Acc: 19.326% (235/1216)
Loss: 2.921 | Acc: 19.766% (253/1280)
Loss: 2.909 | Acc: 19.420% (261/1344)
Loss: 2.878 | Acc: 19.673% (277/1408)
Loss: 2.841 | Acc: 19.837% (292/1472)
Loss: 2.823 | Acc: 19.596% (301/1536)
Loss: 2.800 | Acc: 19.438% (311/1600)
Loss: 2.781 | Acc: 19.231% (320/1664)
Loss: 2.755 | Acc: 19.387% (335/

Loss: 1.949 | Acc: 31.552% (4281/13568)
Loss: 1.947 | Acc: 31.609% (4309/13632)
Loss: 1.946 | Acc: 31.600% (4328/13696)
Loss: 1.945 | Acc: 31.584% (4346/13760)
Loss: 1.944 | Acc: 31.619% (4371/13824)
Loss: 1.942 | Acc: 31.675% (4399/13888)
Loss: 1.941 | Acc: 31.702% (4423/13952)
Loss: 1.941 | Acc: 31.685% (4441/14016)
Loss: 1.939 | Acc: 31.697% (4463/14080)
Loss: 1.938 | Acc: 31.745% (4490/14144)
Loss: 1.938 | Acc: 31.750% (4511/14208)
Loss: 1.937 | Acc: 31.740% (4530/14272)
Loss: 1.937 | Acc: 31.759% (4553/14336)
Loss: 1.936 | Acc: 31.764% (4574/14400)
Loss: 1.934 | Acc: 31.775% (4596/14464)
Loss: 1.933 | Acc: 31.849% (4627/14528)
Loss: 1.931 | Acc: 31.860% (4649/14592)
Loss: 1.930 | Acc: 31.871% (4671/14656)
Loss: 1.929 | Acc: 31.902% (4696/14720)
Loss: 1.929 | Acc: 31.893% (4715/14784)
Loss: 1.927 | Acc: 31.903% (4737/14848)
Loss: 1.926 | Acc: 31.934% (4762/14912)
Loss: 1.925 | Acc: 31.978% (4789/14976)
Loss: 1.924 | Acc: 32.035% (4818/15040)
Loss: 1.923 | Acc: 32.044% (4840/15104)


Loss: 1.800 | Acc: 35.795% (9553/26688)
Loss: 1.800 | Acc: 35.810% (9580/26752)
Loss: 1.800 | Acc: 35.833% (9609/26816)
Loss: 1.799 | Acc: 35.848% (9636/26880)
Loss: 1.798 | Acc: 35.878% (9667/26944)
Loss: 1.798 | Acc: 35.889% (9693/27008)
Loss: 1.797 | Acc: 35.915% (9723/27072)
Loss: 1.796 | Acc: 35.952% (9756/27136)
Loss: 1.795 | Acc: 35.978% (9786/27200)
Loss: 1.795 | Acc: 35.989% (9812/27264)
Loss: 1.795 | Acc: 36.000% (9838/27328)
Loss: 1.794 | Acc: 35.996% (9860/27392)
Loss: 1.794 | Acc: 36.007% (9886/27456)
Loss: 1.793 | Acc: 36.017% (9912/27520)
Loss: 1.793 | Acc: 36.032% (9939/27584)
Loss: 1.792 | Acc: 36.060% (9970/27648)
Loss: 1.791 | Acc: 36.089% (10001/27712)
Loss: 1.791 | Acc: 36.100% (10027/27776)
Loss: 1.790 | Acc: 36.106% (10052/27840)
Loss: 1.790 | Acc: 36.127% (10081/27904)
Loss: 1.789 | Acc: 36.152% (10111/27968)
Loss: 1.788 | Acc: 36.180% (10142/28032)
Loss: 1.788 | Acc: 36.201% (10171/28096)
Loss: 1.787 | Acc: 36.229% (10202/28160)
Loss: 1.787 | Acc: 36.246% (1023

Loss: 1.712 | Acc: 38.590% (15263/39552)
Loss: 1.711 | Acc: 38.593% (15289/39616)
Loss: 1.711 | Acc: 38.606% (15319/39680)
Loss: 1.711 | Acc: 38.622% (15350/39744)
Loss: 1.710 | Acc: 38.638% (15381/39808)
Loss: 1.710 | Acc: 38.651% (15411/39872)
Loss: 1.709 | Acc: 38.672% (15444/39936)
Loss: 1.709 | Acc: 38.690% (15476/40000)
Loss: 1.708 | Acc: 38.696% (15503/40064)
Loss: 1.708 | Acc: 38.696% (15528/40128)
Loss: 1.708 | Acc: 38.702% (15555/40192)
Loss: 1.707 | Acc: 38.722% (15588/40256)
Loss: 1.707 | Acc: 38.743% (15621/40320)
Loss: 1.706 | Acc: 38.768% (15656/40384)
Loss: 1.706 | Acc: 38.783% (15687/40448)
Loss: 1.706 | Acc: 38.796% (15717/40512)
Loss: 1.705 | Acc: 38.809% (15747/40576)
Loss: 1.705 | Acc: 38.814% (15774/40640)
Loss: 1.704 | Acc: 38.834% (15807/40704)
Loss: 1.704 | Acc: 38.844% (15836/40768)
Loss: 1.703 | Acc: 38.869% (15871/40832)
Loss: 1.703 | Acc: 38.872% (15897/40896)
Loss: 1.703 | Acc: 38.872% (15922/40960)
Loss: 1.702 | Acc: 38.882% (15951/41024)
Loss: 1.702 | Ac

Loss: 1.460 | Acc: 46.528% (1608/3456)
Loss: 1.462 | Acc: 46.392% (1633/3520)
Loss: 1.456 | Acc: 46.680% (1673/3584)
Loss: 1.455 | Acc: 46.711% (1704/3648)
Loss: 1.453 | Acc: 46.902% (1741/3712)
Loss: 1.455 | Acc: 46.743% (1765/3776)
Loss: 1.452 | Acc: 46.849% (1799/3840)
Loss: 1.450 | Acc: 46.926% (1832/3904)
Loss: 1.448 | Acc: 47.077% (1868/3968)
Loss: 1.447 | Acc: 47.222% (1904/4032)
Loss: 1.450 | Acc: 47.192% (1933/4096)
Loss: 1.452 | Acc: 47.212% (1964/4160)
Loss: 1.453 | Acc: 47.183% (1993/4224)
Loss: 1.455 | Acc: 47.085% (2019/4288)
Loss: 1.453 | Acc: 47.197% (2054/4352)
Loss: 1.449 | Acc: 47.351% (2091/4416)
Loss: 1.448 | Acc: 47.321% (2120/4480)
Loss: 1.447 | Acc: 47.381% (2153/4544)
Loss: 1.448 | Acc: 47.331% (2181/4608)
Loss: 1.446 | Acc: 47.389% (2214/4672)
Loss: 1.443 | Acc: 47.551% (2252/4736)
Loss: 1.444 | Acc: 47.604% (2285/4800)
Loss: 1.442 | Acc: 47.677% (2319/4864)
Loss: 1.440 | Acc: 47.707% (2351/4928)
Loss: 1.441 | Acc: 47.596% (2376/4992)
Loss: 1.440 | Acc: 47.607

Loss: 1.412 | Acc: 48.821% (3312/6784)
Loss: 1.411 | Acc: 48.817% (3343/6848)
Loss: 1.409 | Acc: 48.886% (3379/6912)
Loss: 1.410 | Acc: 48.853% (3408/6976)
Loss: 1.409 | Acc: 48.878% (3441/7040)
Loss: 1.408 | Acc: 48.958% (3478/7104)
Loss: 1.409 | Acc: 48.954% (3509/7168)
Loss: 1.409 | Acc: 48.963% (3541/7232)
Loss: 1.407 | Acc: 49.013% (3576/7296)
Loss: 1.407 | Acc: 49.022% (3608/7360)
Loss: 1.407 | Acc: 49.003% (3638/7424)
Loss: 1.408 | Acc: 49.038% (3672/7488)
Loss: 1.408 | Acc: 49.033% (3703/7552)
Loss: 1.408 | Acc: 49.041% (3735/7616)
Loss: 1.408 | Acc: 49.023% (3765/7680)
Loss: 1.408 | Acc: 49.083% (3801/7744)
Loss: 1.408 | Acc: 49.065% (3831/7808)
Loss: 1.407 | Acc: 49.085% (3864/7872)
Loss: 1.407 | Acc: 49.168% (3902/7936)
Loss: 1.407 | Acc: 49.138% (3931/8000)
Loss: 1.409 | Acc: 49.082% (3958/8064)
Loss: 1.409 | Acc: 49.102% (3991/8128)
Loss: 1.408 | Acc: 49.109% (4023/8192)
Loss: 1.410 | Acc: 49.019% (4047/8256)
Loss: 1.409 | Acc: 49.038% (4080/8320)
Loss: 1.408 | Acc: 49.070

Loss: 1.406 | Acc: 49.314% (9847/19968)
Loss: 1.406 | Acc: 49.306% (9877/20032)
Loss: 1.406 | Acc: 49.293% (9906/20096)
Loss: 1.406 | Acc: 49.315% (9942/20160)
Loss: 1.406 | Acc: 49.308% (9972/20224)
Loss: 1.406 | Acc: 49.300% (10002/20288)
Loss: 1.406 | Acc: 49.278% (10029/20352)
Loss: 1.406 | Acc: 49.300% (10065/20416)
Loss: 1.407 | Acc: 49.253% (10087/20480)
Loss: 1.406 | Acc: 49.275% (10123/20544)
Loss: 1.406 | Acc: 49.272% (10154/20608)
Loss: 1.405 | Acc: 49.284% (10188/20672)
Loss: 1.405 | Acc: 49.257% (10214/20736)
Loss: 1.406 | Acc: 49.250% (10244/20800)
Loss: 1.406 | Acc: 49.243% (10274/20864)
Loss: 1.406 | Acc: 49.235% (10304/20928)
Loss: 1.406 | Acc: 49.214% (10331/20992)
Loss: 1.406 | Acc: 49.221% (10364/21056)
Loss: 1.406 | Acc: 49.242% (10400/21120)
Loss: 1.406 | Acc: 49.231% (10429/21184)
Loss: 1.406 | Acc: 49.242% (10463/21248)
Loss: 1.406 | Acc: 49.212% (10488/21312)
Loss: 1.406 | Acc: 49.205% (10518/21376)
Loss: 1.406 | Acc: 49.216% (10552/21440)
Loss: 1.406 | Acc: 49

Loss: 1.388 | Acc: 49.661% (16273/32768)
Loss: 1.388 | Acc: 49.662% (16305/32832)
Loss: 1.388 | Acc: 49.641% (16330/32896)
Loss: 1.388 | Acc: 49.642% (16362/32960)
Loss: 1.389 | Acc: 49.640% (16393/33024)
Loss: 1.389 | Acc: 49.649% (16428/33088)
Loss: 1.389 | Acc: 49.647% (16459/33152)
Loss: 1.390 | Acc: 49.606% (16477/33216)
Loss: 1.390 | Acc: 49.594% (16505/33280)
Loss: 1.390 | Acc: 49.571% (16529/33344)
Loss: 1.390 | Acc: 49.557% (16556/33408)
Loss: 1.390 | Acc: 49.567% (16591/33472)
Loss: 1.390 | Acc: 49.574% (16625/33536)
Loss: 1.390 | Acc: 49.568% (16655/33600)
Loss: 1.390 | Acc: 49.566% (16686/33664)
Loss: 1.390 | Acc: 49.573% (16720/33728)
Loss: 1.390 | Acc: 49.580% (16754/33792)
Loss: 1.390 | Acc: 49.575% (16784/33856)
Loss: 1.390 | Acc: 49.573% (16815/33920)
Loss: 1.390 | Acc: 49.570% (16846/33984)
Loss: 1.390 | Acc: 49.559% (16874/34048)
Loss: 1.390 | Acc: 49.554% (16904/34112)
Loss: 1.390 | Acc: 49.567% (16940/34176)
Loss: 1.390 | Acc: 49.565% (16971/34240)
Loss: 1.390 | Ac

Loss: 1.384 | Acc: 49.776% (22682/45568)
Loss: 1.384 | Acc: 49.776% (22714/45632)
Loss: 1.384 | Acc: 49.786% (22750/45696)
Loss: 1.384 | Acc: 49.792% (22785/45760)
Loss: 1.384 | Acc: 49.795% (22818/45824)
Loss: 1.384 | Acc: 49.806% (22855/45888)
Loss: 1.384 | Acc: 49.800% (22884/45952)
Loss: 1.384 | Acc: 49.802% (22917/46016)
Loss: 1.384 | Acc: 49.807% (22951/46080)
Loss: 1.383 | Acc: 49.814% (22986/46144)
Loss: 1.383 | Acc: 49.818% (23020/46208)
Loss: 1.383 | Acc: 49.816% (23051/46272)
Loss: 1.383 | Acc: 49.810% (23080/46336)
Loss: 1.383 | Acc: 49.819% (23116/46400)
Loss: 1.383 | Acc: 49.832% (23154/46464)
Loss: 1.383 | Acc: 49.835% (23187/46528)
Loss: 1.383 | Acc: 49.835% (23219/46592)
Loss: 1.383 | Acc: 49.818% (23243/46656)
Loss: 1.382 | Acc: 49.839% (23285/46720)
Loss: 1.382 | Acc: 49.846% (23320/46784)
Loss: 1.382 | Acc: 49.853% (23355/46848)
Loss: 1.382 | Acc: 49.844% (23383/46912)
Loss: 1.382 | Acc: 49.855% (23420/46976)
Loss: 1.382 | Acc: 49.868% (23458/47040)
Loss: 1.382 | Ac

Loss: 1.403 | Acc: 49.172% (4752/9664)
Loss: 1.403 | Acc: 49.157% (4782/9728)
Loss: 1.404 | Acc: 49.132% (4811/9792)
Loss: 1.404 | Acc: 49.087% (4838/9856)
Loss: 1.405 | Acc: 49.052% (4866/9920)
Loss: 1.406 | Acc: 48.998% (4892/9984)
Loss: 1.405 | Acc: 49.000% (4900/10000)
Evaluation of Epoch 1 is completed, Test accuracy for this epoch is 49.0

Epoch: 2
Loss: 1.491 | Acc: 45.312% (29/64)
Loss: 1.357 | Acc: 49.219% (63/128)
Loss: 1.278 | Acc: 50.521% (97/192)
Loss: 1.265 | Acc: 53.516% (137/256)
Loss: 1.266 | Acc: 53.438% (171/320)
Loss: 1.228 | Acc: 54.167% (208/384)
Loss: 1.282 | Acc: 52.902% (237/448)
Loss: 1.324 | Acc: 51.953% (266/512)
Loss: 1.340 | Acc: 51.389% (296/576)
Loss: 1.335 | Acc: 51.719% (331/640)
Loss: 1.342 | Acc: 50.994% (359/704)
Loss: 1.334 | Acc: 51.042% (392/768)
Loss: 1.327 | Acc: 51.683% (430/832)
Loss: 1.326 | Acc: 51.786% (464/896)
Loss: 1.336 | Acc: 51.458% (494/960)
Loss: 1.330 | Acc: 51.367% (526/1024)
Loss: 1.325 | Acc: 51.654% (562/1088)
Loss: 1.319 | Ac

Loss: 1.267 | Acc: 54.308% (7021/12928)
Loss: 1.266 | Acc: 54.326% (7058/12992)
Loss: 1.266 | Acc: 54.335% (7094/13056)
Loss: 1.266 | Acc: 54.345% (7130/13120)
Loss: 1.266 | Acc: 54.339% (7164/13184)
Loss: 1.266 | Acc: 54.333% (7198/13248)
Loss: 1.266 | Acc: 54.312% (7230/13312)
Loss: 1.268 | Acc: 54.246% (7256/13376)
Loss: 1.267 | Acc: 54.219% (7287/13440)
Loss: 1.267 | Acc: 54.221% (7322/13504)
Loss: 1.267 | Acc: 54.201% (7354/13568)
Loss: 1.266 | Acc: 54.225% (7392/13632)
Loss: 1.266 | Acc: 54.228% (7427/13696)
Loss: 1.266 | Acc: 54.237% (7463/13760)
Loss: 1.265 | Acc: 54.196% (7492/13824)
Loss: 1.266 | Acc: 54.169% (7523/13888)
Loss: 1.265 | Acc: 54.193% (7561/13952)
Loss: 1.264 | Acc: 54.195% (7596/14016)
Loss: 1.264 | Acc: 54.190% (7630/14080)
Loss: 1.263 | Acc: 54.242% (7672/14144)
Loss: 1.263 | Acc: 54.265% (7710/14208)
Loss: 1.264 | Acc: 54.253% (7743/14272)
Loss: 1.264 | Acc: 54.262% (7779/14336)
Loss: 1.265 | Acc: 54.243% (7811/14400)
Loss: 1.265 | Acc: 54.204% (7840/14464)


Loss: 1.268 | Acc: 54.107% (13990/25856)
Loss: 1.268 | Acc: 54.078% (14017/25920)
Loss: 1.268 | Acc: 54.083% (14053/25984)
Loss: 1.269 | Acc: 54.077% (14086/26048)
Loss: 1.269 | Acc: 54.094% (14125/26112)
Loss: 1.269 | Acc: 54.107% (14163/26176)
Loss: 1.269 | Acc: 54.108% (14198/26240)
Loss: 1.269 | Acc: 54.094% (14229/26304)
Loss: 1.269 | Acc: 54.092% (14263/26368)
Loss: 1.269 | Acc: 54.086% (14296/26432)
Loss: 1.270 | Acc: 54.069% (14326/26496)
Loss: 1.269 | Acc: 54.085% (14365/26560)
Loss: 1.269 | Acc: 54.098% (14403/26624)
Loss: 1.269 | Acc: 54.107% (14440/26688)
Loss: 1.268 | Acc: 54.119% (14478/26752)
Loss: 1.269 | Acc: 54.113% (14511/26816)
Loss: 1.269 | Acc: 54.115% (14546/26880)
Loss: 1.269 | Acc: 54.120% (14582/26944)
Loss: 1.269 | Acc: 54.102% (14612/27008)
Loss: 1.268 | Acc: 54.108% (14648/27072)
Loss: 1.268 | Acc: 54.113% (14684/27136)
Loss: 1.269 | Acc: 54.099% (14715/27200)
Loss: 1.269 | Acc: 54.093% (14748/27264)
Loss: 1.269 | Acc: 54.080% (14779/27328)
Loss: 1.270 | Ac

Loss: 1.274 | Acc: 53.998% (20839/38592)
Loss: 1.274 | Acc: 54.002% (20875/38656)
Loss: 1.274 | Acc: 53.998% (20908/38720)
Loss: 1.274 | Acc: 54.007% (20946/38784)
Loss: 1.274 | Acc: 54.005% (20980/38848)
Loss: 1.274 | Acc: 53.994% (21010/38912)
Loss: 1.274 | Acc: 54.000% (21047/38976)
Loss: 1.274 | Acc: 53.993% (21079/39040)
Loss: 1.274 | Acc: 54.002% (21117/39104)
Loss: 1.274 | Acc: 54.011% (21155/39168)
Loss: 1.274 | Acc: 54.009% (21189/39232)
Loss: 1.274 | Acc: 54.003% (21221/39296)
Loss: 1.274 | Acc: 54.014% (21260/39360)
Loss: 1.274 | Acc: 54.008% (21292/39424)
Loss: 1.274 | Acc: 54.009% (21327/39488)
Loss: 1.274 | Acc: 54.020% (21366/39552)
Loss: 1.274 | Acc: 54.016% (21399/39616)
Loss: 1.274 | Acc: 54.015% (21433/39680)
Loss: 1.274 | Acc: 54.031% (21474/39744)
Loss: 1.274 | Acc: 54.029% (21508/39808)
Loss: 1.275 | Acc: 54.028% (21542/39872)
Loss: 1.274 | Acc: 54.046% (21584/39936)
Loss: 1.275 | Acc: 54.040% (21616/40000)
Loss: 1.275 | Acc: 54.049% (21654/40064)
Loss: 1.274 | Ac

Loss: 1.254 | Acc: 55.659% (1318/2368)
Loss: 1.252 | Acc: 55.798% (1357/2432)
Loss: 1.248 | Acc: 56.050% (1399/2496)
Loss: 1.259 | Acc: 55.742% (1427/2560)
Loss: 1.260 | Acc: 55.640% (1460/2624)
Loss: 1.262 | Acc: 55.543% (1493/2688)
Loss: 1.259 | Acc: 55.596% (1530/2752)
Loss: 1.256 | Acc: 55.717% (1569/2816)
Loss: 1.254 | Acc: 55.694% (1604/2880)
Loss: 1.253 | Acc: 55.774% (1642/2944)
Loss: 1.251 | Acc: 55.818% (1679/3008)
Loss: 1.251 | Acc: 55.729% (1712/3072)
Loss: 1.252 | Acc: 55.676% (1746/3136)
Loss: 1.252 | Acc: 55.594% (1779/3200)
Loss: 1.254 | Acc: 55.453% (1810/3264)
Loss: 1.254 | Acc: 55.529% (1848/3328)
Loss: 1.253 | Acc: 55.542% (1884/3392)
Loss: 1.254 | Acc: 55.498% (1918/3456)
Loss: 1.256 | Acc: 55.398% (1950/3520)
Loss: 1.256 | Acc: 55.413% (1986/3584)
Loss: 1.258 | Acc: 55.400% (2021/3648)
Loss: 1.254 | Acc: 55.550% (2062/3712)
Loss: 1.256 | Acc: 55.376% (2091/3776)
Loss: 1.252 | Acc: 55.547% (2133/3840)
Loss: 1.253 | Acc: 55.430% (2164/3904)
Loss: 1.254 | Acc: 55.368