In [None]:
import os
from collections import defaultdict

import pandas as pd
import spacy
import torch
import torch.nn as nn
import numpy as np
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import matplotlib.pyplot as plt
from torchvision.transforms import transforms
from torch.utils.tensorboard import SummaryWriter
import nltk
from nltk.translate.bleu_score import corpus_bleu
from nltk.translate.meteor_score import meteor_score
from tqdm  import tqdm, trange

### Model Implementation

In [6]:
class PatchEmbedding(nn.Module):
    '''
    Split an image into non-overlapping square patches and project each patch to
    a learnable embedding vector, producing a sequence of patch tokens.

    The projection is a Conv2d whose kernel size and stride both equal
    ``patch_size`` (no padding), so every output position covers exactly one patch.

    Args:
        in_channels (int): Number of color channels of the input images. Defaults to 3.
        patch_size (int): Side length of each square patch. Defaults to 16.
        embedding_dim (int): Size of the embedding produced for each patch. Defaults to 768.

    Shapes:
        input:  (batch_size, in_channels, height, width)
        output: (batch_size, (height // patch_size) * (width // patch_size), embedding_dim)
    '''
    def __init__(self, in_channels = 3, patch_size=16, embedding_dim=768):
        super().__init__()
        conv_kwargs = dict(kernel_size=patch_size, stride=patch_size, padding=0)
        self.patcher = nn.Conv2d(in_channels, embedding_dim, **conv_kwargs)
        # Collapse the two spatial axes of the conv output into a single patch axis.
        self.flatten = nn.Flatten(start_dim=2)

    def forward(self, x):
        feature_map = self.patcher(x)             # (B, embedding_dim, H // p, W // p)
        per_channel = self.flatten(feature_map)   # (B, embedding_dim, num_patches)
        return per_channel.transpose(1, 2)        # (B, num_patches, embedding_dim)

In [None]:
class EncoderBlcok(nn.Module):
    '''
    Encoder block that returns a representation of the image patches.
    
    Args:
        embedding_dim (int): size of the embedding for each image patch. Defaults 768.
        num_heads (int): Number of head in the attention layer. Defaults 12.
        mlp_size (int): Size for the feed forward portion of the encoder. Defaults 3072.
        dropout (float): Amount of dropout in attention and mlp layer. Default 0.1
    '''
    def __init__(self, embeding_dim = 768, num_heads = 12, mlp_size= 3072, dropout=0.1):
        super().__init__()
        
        self.layer_norm1 = nn.LayerNorm(normalized_shape=embeding_dim)
        self.attention = nn.MultiheadAttention(embed_dim=embeding_dim, num_heads=num_heads, dropout=dropout)
        
        self.layer_norm2 = nn.LayerNorm(normalized_shape=embeding_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embeding_dim, mlp_size),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_size, embeding_dim),
            nn.Dropout(dropout)   
        )
    
    def forward(self, x):
        norm_x = self.layer_norm1(x)
        x = x + self.attention(norm_x, norm_x, norm_x, need_weights=False)[0]
        
        norm_x = self.layer_norm2(x)
        return x + self.mlp(norm_x)

In [None]:
class ViT(nn.Module):
    '''
    Vision Transformer backbone using ViT-Base hyperparameters by default.

    Unlike the classification ViT there is no [CLS] token and no classification
    head: the forward pass returns one embedding per image patch, suitable as the
    key/value memory for a decoder's cross-attention.

    Args:
        img_size (int): Side length of the square input image; must be divisible by
            patch_size. Defaults to 224.
        in_channels (int): Number of color channels. Defaults to 3.
        patch_size (int): Side length of each square patch. Defaults to 16.
        num_blocks (int): Number of stacked encoder blocks. Defaults to 12.
        embed_dim (int): Size of each patch embedding. Defaults to 768.
        mlp_size (int): Hidden size of each encoder block's MLP. Defaults to 3072.
        num_heads (int): Attention heads per encoder block. Defaults to 12.
        dropout (float): Dropout probability. Defaults to 0.1.
    '''
    def __init__(self, img_size= 224, in_channels=3, patch_size=16, num_blocks = 12,
                 embed_dim = 768, mlp_size = 3072, num_heads = 12, dropout=0.1):
        super().__init__()
        assert img_size % patch_size == 0, f"Image size must be divisible by patch size, image size: {img_size}, patch size: {patch_size}."
        patches_per_side = img_size // patch_size
        self.num_patches = patches_per_side * patches_per_side

        # One learned position vector per patch; the leading 1 broadcasts over the batch.
        self.position_embedding = nn.Parameter(torch.randn(1, self.num_patches, embed_dim))
        self.embed_dropout = nn.Dropout(dropout)

        self.patch_embedding = PatchEmbedding(in_channels=in_channels, patch_size=patch_size, embedding_dim=embed_dim)

        encoder_blocks = []
        for _ in range(num_blocks):
            encoder_blocks.append(EncoderBlcok(embeding_dim=embed_dim, num_heads=num_heads,
                                               mlp_size=mlp_size, dropout=dropout))
        self.transformer_encoder = nn.Sequential(*encoder_blocks)

    def forward(self, x):
        # x: (batch_size, in_channels, height, width); the patch count must equal
        # num_patches for the positional embedding to line up.
        tokens = self.patch_embedding(x) + self.position_embedding  # (batch_size, num_patches, embed_dim)
        tokens = self.embed_dropout(tokens)
        return self.transformer_encoder(tokens)  # (batch_size, num_patches, embed_dim)
        
        

#### Decoder model (GPT-2-like decoder; each block adds cross-attention over the encoder output)

In [None]:
class DecoderBlock(nn.Module):
    '''
    Post-norm Transformer decoder block with three sub-layers, each wrapped as
    ``LayerNorm(x + sublayer(x))``:
      1. causal (masked) self-attention over the target tokens,
      2. cross-attention from the tokens to the encoder output,
      3. a position-wise MLP.

    Args:
        embed_dim (int): Token embedding size. Defaults to 768.
        num_heads (int): Number of attention heads. Defaults to 12.
        dropout (float): Dropout used in attention and MLP. Defaults to 0.1.
        mlp_size (int): Hidden size of the MLP. Defaults to 3072.
    '''

    def __init__(self, embed_dim=768, num_heads=12, dropout=0.1, mlp_size=3072):
        super().__init__()
        self.mask_attention = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.layer_norm1 = nn.LayerNorm(embed_dim)

        self.cross_attention = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.layer_norm2 = nn.LayerNorm(embed_dim)

        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, mlp_size),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_size, embed_dim),
            nn.Dropout(dropout)
        )
        # Normalises the MLP residual, matching the two attention sub-layers above.
        self.layer_norm3 = nn.LayerNorm(embed_dim)

    def forward(self, x, encoder_out, padding_mask=None):
        '''
        Args:
            x: decoder input, (batch_size, seq_len, embed_dim).
            encoder_out: keys/values for cross-attention,
                (batch_size, num_src_tokens, embed_dim).
            padding_mask: optional bool tensor (batch_size, seq_len); True marks
                positions of ``x`` that must not be attended to as keys.

        Returns:
            Tensor of shape (batch_size, seq_len, embed_dim).

        NOTE(review): if a query position ends up with every key masked (e.g. left
        padding where the first token is padding), softmax over all -inf may yield
        NaN; right-padded sequences avoid this.
        '''
        seq_len = x.shape[1]
        # True strictly above the diagonal: position i may not attend to positions j > i.
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device), diagonal=1
        )
        att_x, _ = self.mask_attention(x, x, x, attn_mask=causal_mask,
                                       key_padding_mask=padding_mask, need_weights=False)
        x = self.layer_norm1(x + att_x)

        att_x, _ = self.cross_attention(x, encoder_out, encoder_out, need_weights=False)
        x = self.layer_norm2(x + att_x)

        # Residual around the MLP: without it the block would throw away the
        # attention path and emit only the MLP's output.
        return self.layer_norm3(x + self.mlp(x))
        

In [None]:
class GPT(nn.Module):
    '''
    GPT-style autoregressive decoder whose blocks also cross-attend to an encoder
    output (e.g. the patch embeddings produced by an image encoder).

    Args:
        vocab_size (int): Vocabulary size (token-embedding rows and logits width).
        embed_dim (int): Token/position embedding size. Defaults to 768.
        mlp_size (int): Hidden size of each decoder block's MLP. Defaults to 3072.
        max_seq_len (int): Number of learned positions; longer inputs are rejected.
            Defaults to 350.
        num_layers (int): Number of decoder blocks. Defaults to 12.
        num_heads (int): Attention heads per block. Defaults to 12.
        dropout (float): Dropout probability inside the blocks. Defaults to 0.1.
        padding_idx (int): Token id used for padding. Its positions are masked as
            self-attention keys and its embedding row stays fixed at zero. Defaults to 0.
    '''

    def __init__(self, vocab_size, embed_dim = 768, mlp_size = 3072, max_seq_len=350, num_layers=12, num_heads=12, dropout=0.1, padding_idx=0):
        super().__init__()
        # padding_idx keeps the padding token's embedding at zero and excludes it from updates.
        self.token_embedding = nn.Embedding(vocab_size, embedding_dim=embed_dim, padding_idx=padding_idx)
        self.positional_embedding = nn.Embedding(max_seq_len, embedding_dim=embed_dim)

        # ModuleList takes a single iterable of modules; unpacking the list with *
        # passes one argument per block and raises TypeError for num_layers > 1.
        self.transformer_decoder = nn.ModuleList([
            DecoderBlock(embed_dim=embed_dim, num_heads=num_heads, dropout=dropout, mlp_size=mlp_size)
            for _ in range(num_layers)
        ])
        self.layer_norm = nn.LayerNorm(embed_dim)
        self.fc_out = nn.Linear(embed_dim, vocab_size)
        self.padding_index = padding_idx

    def forward(self, x, encoder_out):
        '''
        Args:
            x: token ids, (batch_size, seq_len) with seq_len <= max_seq_len.
            encoder_out: encoder output used as cross-attention keys/values,
                (batch_size, num_src_tokens, embed_dim).

        Returns:
            Logits of shape (batch_size, seq_len, vocab_size).

        Raises:
            ValueError: if seq_len exceeds the number of learned positions.
        '''
        _, seq_len = x.shape
        max_len = self.positional_embedding.num_embeddings
        if seq_len > max_len:
            # Fail clearly instead of an embedding index error / CUDA device assert.
            raise ValueError(f"Sequence length {seq_len} exceeds max_seq_len {max_len}.")

        padding_key = (x == self.padding_index)  # True where x is padding
        pos = torch.arange(0, seq_len, dtype=torch.long, device=x.device)
        pos_embed = self.positional_embedding(pos)  # (seq_len, embed_dim), broadcast over batch
        token_embed = self.token_embedding(x)       # (batch_size, seq_len, embed_dim)
        x = pos_embed + token_embed

        for decoder in self.transformer_decoder:
            x = decoder(x, encoder_out, padding_key)

        return self.fc_out(self.layer_norm(x))