In [None]:
import torch
import torchvision 
import torch.nn as nn
import math

import numpy as np
from einops.layers.torch import Rearrange

import matplotlib.pyplot as plt

# Homework description
 
In this homework, you have to build an image classification model that treats each image as a sequence of sub-images (patches). This is very similar to the recent [Vision Transformer](https://arxiv.org/pdf/2010.11929.pdf) architecture.

In [None]:
# Here we define our dataset and resize each image to (224, 224).
# Consider adding your own data augmentations, like flip, blur, etc.

def _to_rgb(image):
    ''' Convert a PIL image to 3-channel RGB.

        The rest of the notebook (PatchEmbeddings with channels=3) assumes every
        sample has exactly three channels. Not every file in the dataset is
        guaranteed to be stored as RGB (e.g. grayscale images), and a single
        1-channel sample is enough to break batching in a DataLoader.
        A named function (rather than a lambda) keeps the transform picklable
        for DataLoader workers.
    '''
    return image.convert('RGB')


transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.Lambda(_to_rgb),
        torchvision.transforms.Resize((224, 224)),
        torchvision.transforms.ToTensor(),
    ]
)

dataset = torchvision.datasets.Caltech256('data/', download=True, transform=transform)

#################

# YOUR CODE HERE
# 1) Split the original dataset into the train, validation, and test datasets. 
# 2) Explain your choice of dataset sizes
# 3) Create a dataloader for each dataset

#################

In [None]:
# You can check how the PatchEmbeddings class splits an image into patches using its image_to_patches method

class PatchEmbeddings(torch.nn.Module):
    def __init__(self, patch_size, d_model, channels=3):
        ''' Patch Embedding class. 
                Takes image as an input, splits it into a number of patches of the pre-determined size (patch_size) 
                and applies a linear transformation to convert the patch into a vector representation.

            Arguments: 
                patch_size: int
                    size of the patch to cut from an image; the image height and width
                    are expected to be multiples of this value
                d_model: int
                    linear representation size of the patch
                channels: int
                    number of channels in the input image
        '''
        super().__init__()
        self.patch_size = patch_size
        self.channels = channels
        self.d_model = d_model
        # Length of one flattened patch: patch_size * patch_size pixels, `channels` values each.
        self.patch_dim = channels * patch_size * patch_size

        # (b, c, H, W) -> (b, number_of_patches, patch_dim); patches are ordered row by row.
        self.patch_embeddings = Rearrange(
            'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', 
            p1 = patch_size, p2 = patch_size
        )
        self.embedding = nn.Linear(
            in_features = self.patch_dim,
            out_features = d_model
        )
    
    def image_to_patches(self, img, plot_figure=False):
        ''' Split a single image into patches.

            Arguments:
                img: Tensor, shape [channels, height, width]
                    one image (no batch dimension)
                plot_figure: bool
                    if True, draw the patches on a grid instead of returning them

            Returns:
                Tensor of shape [number_of_patches, patch_size, patch_size, channels]
                when plot_figure is False; None when plot_figure is True.
        '''
        x = self.patch_embeddings(img.unsqueeze(0))
        batch_size, num_patches, _ = x.shape
        # Undo the (p1 p2 c) flattening so each patch is an H x W x C image again.
        x = x.reshape(
            batch_size, num_patches, self.patch_size, self.patch_size, self.channels
        ).squeeze(0)

        if not plot_figure:
            return x

        # Take the grid layout from the image itself. Using sqrt(num_patches) for
        # both sides is only valid for square images and requests subplot indices
        # outside the grid otherwise.
        grid_rows = img.shape[-2] // self.patch_size
        grid_cols = img.shape[-1] // self.patch_size

        fig = plt.figure(figsize=(8, 8))
        for i in range(num_patches):
            ax = fig.add_subplot(grid_rows, grid_cols, i + 1)
            ax.axes.get_xaxis().set_visible(False)
            ax.axes.get_yaxis().set_visible(False)
            # detach/cpu so plotting also works for tensors on an accelerator
            # or tensors that are part of an autograd graph.
            ax.imshow(x[i].detach().cpu())

    def forward(self, x):
        '''
        Args:
            x: Tensor, shape [batch_size, channels, height, width]
        Returns:
            Tensor, shape [batch_size, number_of_patches, d_model]
        '''
        x = self.patch_embeddings(x)
        x = self.embedding(x)
        return x

In [None]:
# Refer to these links for more information about positional encoding 
# [1] https://kazemnejad.com/blog/transformer_architecture_positional_encoding/
# [2] https://machinelearningmastery.com/a-gentle-introduction-to-positional-encoding-in-transformer-models-part-1/ (code from here most likely won't work in our setting)
# [3] https://towardsdatascience.com/master-positional-encoding-part-i-63c05d90a0c3

class PositionalEncoding(torch.nn.Module):
    def __init__(self, d_model: int, max_seq_len: int = 256, dropout: float = 0.1):
        '''
          Positional encoding class, that gives the model a notion of a position of a given patch representation
          Args:
            d_model: int 
              Embedding dims of the model 
            max_seq_len: int 
              maximum length
            dropout: float
              dropout rate
        '''
        super().__init__()

        # 1) positions vector, shape [max_seq_len, 1], so it broadcasts against the frequencies
        positions = torch.arange(max_seq_len, dtype=torch.float32).unsqueeze(1)

        # 2) division factor 1 / 10000^(2i / d_model), one value per (sin, cos) pair;
        #    computed through exp/log for numerical stability
        div_factor = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float32) * (-math.log(10000.0) / d_model)
        )

        # 3) sin on even embedding indices, cos on odd ones
        encoding = torch.zeros(max_seq_len, d_model)
        encoding[:, 0::2] = torch.sin(positions * div_factor)
        # For odd d_model there is one cos column fewer than sin columns.
        encoding[:, 1::2] = torch.cos(positions * div_factor[: d_model // 2])

        # Two-dimensional [time, value] table. Registered as a buffer so that it
        # follows the module across .to(device) / .cuda() calls; it is not a
        # learnable parameter and is cheap to recompute, hence non-persistent.
        self.register_buffer('positional_encoding', encoding, persistent=False)

        self.dropout = nn.Dropout(p=dropout)

    def visualize_positional_encoding(self):
        ''' Show the positional encoding table as a heat map (rows: positions, columns: embedding dims). '''
        plt.figure(figsize=(8, 6))
        plt.imshow(self.positional_encoding.cpu().numpy(), cmap='hot', interpolation='nearest')
        plt.show()

    def forward(self, x):
        '''
        Args: 
            x: Tensor, shape [batch_size, seq_len, embedding_dim]
        Returns:
            Tensor of the same shape: x plus the positional encoding, after dropout
        Raises:
            ValueError: if seq_len exceeds the max_seq_len the table was built for
        '''
        seq_len = x.shape[1]
        max_seq_len = self.positional_encoding.shape[0]
        if seq_len > max_seq_len:
            raise ValueError(
                f'Sequence length {seq_len} exceeds max_seq_len {max_seq_len} of the positional encoding'
            )
        x = x + self.positional_encoding[:seq_len]
        return self.dropout(x)


In [None]:
# Useful links
# [1] https://stanford.edu/~shervine/teaching/cs-230/cheatsheet-recurrent-neural-networks
# [2] https://codeburst.io/recurrent-neural-network-4ca9fd4f242

class RNN(nn.Module):
    ''' Single-layer Elman RNN: h_t = tanh(x_t @ W_xh + h_{t-1} @ W_hh + b). '''

    def __init__(self, input_dim, hidden_dim):
        '''
        Args:
            input_dim: int
                size of each input vector x_t
            hidden_dim: int
                size of the hidden state h_t
        '''
        super().__init__()
        self.hidden_dim = hidden_dim

        # Weights of the input tensor, hidden state tensor and bias.
        # Uniform init in [-1/sqrt(hidden_dim), 1/sqrt(hidden_dim)] keeps the
        # pre-activations in the non-saturated range of tanh at the start of training.
        bound = 1.0 / math.sqrt(hidden_dim) if hidden_dim > 0 else 0.0
        self.W_xh = nn.Parameter(torch.empty(input_dim, hidden_dim).uniform_(-bound, bound))
        self.W_hh = nn.Parameter(torch.empty(hidden_dim, hidden_dim).uniform_(-bound, bound))
        self.b = nn.Parameter(torch.zeros(hidden_dim))

    def step_forward(self, x, state):
        '''
        Args:
            x: Tensor, shape [batch_size, input_dim] (input at one time step)
            state: Tensor, shape [batch_size, hidden_dim] (previous hidden state)
        Returns:
            Tensor, shape [batch_size, hidden_dim] (next hidden state)
        '''
        next_state = torch.tanh(x @ self.W_xh + state @ self.W_hh + self.b)
        return next_state
    
    def forward(self, inputs, state=None):
        '''
        Args:
            inputs: Tensor, shape [batch_size, seq_len, input_dim]
            state: optional Tensor, shape [batch_size, hidden_dim]
                initial hidden state; zeros are used when omitted
        Returns:
            Tensor, shape [batch_size, seq_len, hidden_dim] with the hidden state at every step
        '''
        B, T, _ = inputs.shape
        if state is None:
            # new_zeros keeps the initial state on the same device/dtype as the inputs.
            state = inputs.new_zeros((B, self.hidden_dim))

        # Collect the per-step states in a list and stack them at the end.
        # Writing each step into a preallocated tensor in place would modify
        # memory that autograd saved for the previous step and make backward() fail.
        outputs = []
        for idx in range(T):
            state = self.step_forward(inputs[:, idx, :], state)
            outputs.append(state)

        if not outputs:
            return inputs.new_zeros((B, 0, self.hidden_dim))
        return torch.stack(outputs, dim=1)

In [None]:
class Linear(nn.Module):
    ''' Fully connected layer: y = x @ W (+ b). '''

    def __init__(self, in_features, out_features, add_bias=True):
        '''
        Args:
            in_features: int
                size of the last dimension of the input
            out_features: int
                size of the last dimension of the output
            add_bias: bool
                whether to learn an additive bias
        '''
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.add_bias = add_bias

        # Weights of the input tensor and bias.
        # Uniform init scaled by the fan-in keeps the output variance comparable to the input's.
        bound = 1.0 / math.sqrt(in_features) if in_features > 0 else 0.0
        self.W = nn.Parameter(torch.empty(in_features, out_features).uniform_(-bound, bound))
        if add_bias:
            self.b = nn.Parameter(torch.zeros(out_features))
        else:
            # Registered as None so `self.b` exists either way and forward can test it.
            self.register_parameter('b', None)

    def forward(self, x):
        '''
        Args:
            x: Tensor, shape [..., in_features]
        Returns:
            Tensor, shape [..., out_features]
        '''
        x = x @ self.W
        if self.b is not None:
            x = x + self.b
        return x

In [None]:
# Define your model here. 
# Initialize the RNN layer with hidden state dimension size 
# Initialize linear classification head. Make sure your output size matches the number of classes in the dataset

class MyModel(nn.Module):
    ''' Image classifier that reads an image as a sequence of patches:
        patch embedding -> positional encoding -> RNN -> linear classification head.
    '''

    def __init__(self, patch_size, embedding_dim, hidden_dim=None, num_classes=257, max_seq_len=256):
        '''
        Args:
            patch_size: int
                side length of the square patches cut from the image
            embedding_dim: int
                size of each patch embedding
            hidden_dim: int or None
                hidden state size of the RNN; defaults to embedding_dim
            num_classes: int
                number of output logits. The default of 257 assumes Caltech-256's
                256 object categories plus its clutter category -- confirm against
                len(dataset.categories) for the dataset actually used.
            max_seq_len: int
                maximum number of patches per image supported by the positional
                encoding; must be at least (height // patch_size) * (width // patch_size)
        '''
        super().__init__()
        if hidden_dim is None:
            hidden_dim = embedding_dim

        self.patch_embedding = PatchEmbeddings(
            patch_size=patch_size, d_model = embedding_dim
        )
        self.positional_encoding = PositionalEncoding(
            d_model = embedding_dim, max_seq_len = max_seq_len
        )

        self.rnn = RNN(embedding_dim, hidden_dim)

        self.classifier = Linear(hidden_dim, num_classes)

    def forward(self, x):
        '''
        Args:
            x: Tensor, shape [batch_size, channels, height, width]
        Returns:
            Tensor, shape [batch_size, num_classes] with unnormalized class scores
        '''
        x = self.patch_embedding(x)
        x = self.positional_encoding(x)
        
        # The hidden state after the last patch has seen the whole sequence,
        # so it is used as the image representation for classification.
        hidden_states = self.rnn(x)
        logits = self.classifier(hidden_states[:, -1, :])

        return logits 


# Train loops and evaluation

Here you will need to implement training and evaluation methods, as in the previous homework

1) Define methods train, evaluation, and train_epoch

2) Define your model instance, loss function, optimizer, and metric function

3) Train model

4) Evaluate model

5) Do a small hyperparameter search and visualize your results