In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from tqdm import tqdm

In [4]:
torch.cuda.is_available()

True

<img src = "https://editor.analyticsvidhya.com/uploads/35004Vit.png">

In [5]:
class PatchEmbedding(nn.Module):
    def __init__(self, image_size, patch_size, input_channels = 3, embedding_dims = 768):
        super(PatchEmbedding, self).__init__()
        '''
            image_size: the size of tha image assuming that the image is square aka height = width
            patch_size: Size of the batch assuming that it is square
            input_channel: 1 for grey_scale, 3 for RGB Channels
            embedding_dims: the dimension of the embedding layer
        '''
        self.image_size = image_size
        self.patch_size = patch_size
        self.input_channels = input_channels
        self.embedding_dims = embedding_dims
        self.num_patches = (self.image_size // self.patch_size) ** 2
        self.projection = nn.Conv2d(self.input_channels, self.embedding_dims, kernel_size=self.patch_size, stride=self.patch_size)

    def forward(self, x):
        # x shape: (n_samples, input_channels, image_size, image_size) -> both image_size for height and width
        projection = self.projection(x) # shape (n_samples, embedding_dim, sqrt(n_patches), sqrt(n_patches))
        projection = projection.flatten(2) # shape (n_samples, embedding_dim, n_patches)
        projection = projection.transpose(1, 2) # shape (n_samples, n_patches, embedding_dim)
        return projection



In [6]:
class AttentionModel(nn.Module):
    def __init__(self, dim, num_heads, include_bias, attention_dropout = 0.5, projection_dropout = 0.5):
        super(AttentionModel, self).__init__()
        '''
            dim: Input/Output dimensions
            num_heads: number of heads of the attention
            include_bias: bool variable to include bias or not for query, key, and value of the attention
            attention_dropout: probability of dropout for the attention
            projection_dropout: robability of dropout for the projection (Patch Embedding Layer)
        '''
        self.dim = dim
        self.num_heads = num_heads
        self.include_bias = include_bias
        self.attention_dropout = attention_dropout
        self.projection_dropout = projection_dropout

        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5

        self.linear_layer = nn.Linear(dim, dim * 3, bias = include_bias) # Linear Mapping take in token embedding and generate query, key and a value (reason for dim * 3)
        self.projection = nn.Linear(dim, dim)

        self.attention_drop = nn.Dropout(self.attention_dropout)
        self.projection_drop = nn.Dropout(self.projection_dropout)

    def forward(self, x):
        # x shape: (n_samples, n_patches + 1, dim) [num_patches +1 for the 0 class token (from the paper)]

        # Extract the dimensions:
        n_samples, n_tokens, dim = x.shape
        linear = self.linear_layer(x) # shape: (n_samples, n_patches + 1, dim * 3)
        linear = linear.reshape(n_samples, n_tokens, 3, self.num_heads, self.head_dim) # shape: (n_samples, n_tokens, 3, num_heads, head_dim)
        linear = linear.permute(2, 0, 3, 1, 4) # shape: (3, n_samples, num_heads,  n_patches + 1, head_dim) # To Extract query, key, value
        query = linear[0]
        key = linear[1]
        value = linear[2]

        key_transpose = key.transpose(-2, -1) # Shape (num_samples, num_heads, head_dim, n_patches + 1)
        query_key = (query @ key_transpose) * self.scale # From Attention all you Need [Transformers]
        attention = query_key.softmax(dim = -1) # (n_samples, n_heads, n_patches + 1, ) To Generate a discrete probability distribution that sums up to one for [weighted average]
        attention = self.attention_drop(attention)
        weighted_average = attention @ value
        weighted_average_transpose = weighted_average.transpose(1, 2)
        weighted_average_flat = weighted_average_transpose.flatten(2) # To Flat the last 2 dimensions [For concatination] shape:(n_samples, n_patches + 1, head_dim)
        output = self.projection(weighted_average_flat) # shape: (n_samples, n_patches+1, dim)
        output = self.projection_drop(output)

        return output


In [7]:
class MLP(nn.Module):
    def __init__(self, in_features, hidden_features, out_features, dropout_p = 0.5):
        super(MLP, self).__init__()

        self.in_features = in_features
        self.hidden_features = hidden_features
        self.out_features = out_features

        # Neural Network
        self.layer1 = nn.Linear(in_features, self.hidden_features)
        self.gelu = nn.GELU()
        self.linear2 = nn.Linear(self.hidden_features, out_features)
        self.drop = nn.Dropout(dropout_p)

    def forward(self, x):
        # x shape: (n_samples, n_patches + 1, in_features)
        linear1 = self.layer1(x)
        gelu = self.gelu(linear1)
        gelu = self.drop(gelu)
        linear2 = self.linear2(gelu)
        output = self.drop(linear2)
        return output

In [8]:
class Dual_Residual_Block(nn.Module):
    def __init__(self,dim,norm):
        super(Dual_Residual_Block,self).__init__()
        self.norm = norm
        self.dim = dim
    def forward(self,x,x_d,f):
        x_f = f(x)
        x = x+x_f
        x_d = x_d+x_f
        x_a = self.norm(x)
        x_d = self.norm(x_d)
        y = x_a+x_d
        return y,x_d
        

**Layer Normalization**
<br>
For each layer $ {h_i} $
<br>
$h_i := \frac{g}{\sigma} (h_i - \mu)$   $\ \ [\mu = 0, \sigma = 1]$

In [9]:
class BuildingBlock(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio = 4.0, include_bias = True, dropout_p = 0.5, attention_p = 0.5):
        super(BuildingBlock, self).__init__()
        self.norm= nn.LayerNorm(dim, eps=1e-6)
        self.attention = AttentionModel(dim, num_heads, include_bias, attention_p, dropout_p)
        self.hidden_features = int(dim * mlp_ratio)
        self.FFN = nn.Sequential(
            nn.Linear(dim, dim*3),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_p),
            nn.Linear(dim*3,dim),
            nn.Dropout(dropout_p)
        )
        self.mlp = MLP(dim, self.hidden_features, dim)
        self.dim = dim
        self.Residual = Dual_Residual_Block(dim,self.norm)
    def forward(self, x):
        #Block 1
        x_ln1,x_d = self.Residual(x,x,self.attention)
        x_ln2,x_d = self.Residual(x_ln1,x_d,self.FFN)

        #Block 2
        x_ln3,x_d = self.Residual(x_ln2,x_d,self.attention)
        x_ln4,x_d = self.Residual(x_ln3,x_d,self.FFN)

        #Block 3
        x_ln5,x_d = self.Residual(x_ln4,x_d,self.attention)
        x_ln6,x_d = self.Residual(x_ln5,x_d,self.FFN)

        #Block 4:
        x_ln7,x_d = self.Residual(x_ln6,x_d,self.attention)
        x_ln8,x_d = self.Residual(x_ln7,x_d,self.FFN)
        
        y = self.mlp(x_ln8)
        return y
        

In [10]:
class VisionTransformer(nn.Module):
    def __init__(self, image_size=384, patch_size=16, input_channels=3, num_classes=100, embedding_dims=768, depth=12, num_heads=12, mlp_ratio=4.0, include_bias = True, dropout_p = 0.5, attention_p = 0.5):
        super(VisionTransformer, self).__init__()
        self.patch_embedding = PatchEmbedding(image_size, patch_size, input_channels, embedding_dims) # instance of patch embedding model
        self.cls = nn.Parameter(torch.zeros(1, 1, embedding_dims))
        self.positional_embeddings = nn.Parameter(torch.zeros(1, 1 + self.patch_embedding.num_patches, embedding_dims)) # to get the exact position of a given patch in the image
        self.pos_drop = nn.Dropout(dropout_p)

        self.blocks = nn.ModuleList([ BuildingBlock(embedding_dims, num_heads, mlp_ratio, include_bias, dropout_p, attention_p) for transformer in range(depth) ])

        self.norm = nn.LayerNorm(embedding_dims, eps=1e-6)
        self.head = nn.Linear(embedding_dims, num_classes)

    def forward(self, x):
        # x shape: (n_samples, in_channels, img_size, img_size)
        n_samples = x.shape[0]
        x = self.patch_embedding(x)
        cls = self.cls.expand(n_samples, -1, -1) # shape: (n_samples, 1, embedding_dims)
        x = torch.cat((cls, x), dim = 1) # Concatination -> shape(n_samples, 1 + n_patches, embedding_dims)
        x = x + self.positional_embeddings
        x = self.pos_drop(x)
        print(x.shape)
        for block in self.blocks:
            x = block(x)
        cls_final = x[:, 0]

        x = self.head(cls_final)

        return x

In [11]:
import torch.nn.functional as F
class FocalLoss(nn.Module):
    def __init__(self, alpha=1, gamma=2, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)
        focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss

        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss


In [17]:
def train(vision_transformer_model, epochs, learning_rate, dataset):
    optimizer = optim.Adagrad(vision_transformer_model.parameters(), lr=learning_rate)
    critation = FocalLoss()
    losses = []
    t = len(dataset)
#     print(dataset.dataset.targets)
    embedding_dim = 100
    for epoch in tqdm(range(epochs)):
        loss_sum = 0
        i =  0
        for x, y in zip(dataset.dataset.data, dataset.dataset.targets):
            i+=1
            x = torch.tensor(x).cuda()
            x = x.view(1, x.shape[0], x.shape[1], x.shape[2])
            x = x.permute(0, 3, 1, 2)
            optimizer.zero_grad()
            y_hat= vision_transformer_model(x.float())
            #y = torch.tensor([y])
            y_label = torch.zeros(embedding_dim).cuda()
            y_label[y] = 1
            y_label = y_label.unsqueeze(0)
            loss = critation(y_hat, y_label)
            if i % 10 == 0:
                print("Loss ",i,"/",t,":",float(loss))
            losses.append(loss)
            loss_sum += loss
            loss.backward()
            optimizer.step()
        print("Total loss epoch ",epoch,":",loss_sum/t)
    return losses

In [18]:
vision_transformer = VisionTransformer(image_size = 32, num_classes = 100)

In [19]:
!mkdir cifar
cifar_data = torchvision.datasets.CIFAR100('data/cifar', download = True)
data_loader = torch.utils.data.DataLoader(cifar_data,
                                          batch_size=4,
                                          shuffle=True)
print(len(data_loader.dataset.data))


mkdir: cannot create directory ‘cifar’: File exists


Files already downloaded and verified
50000


In [15]:
EPOCHS = 10
LEARNING_RATE = 0.001

In [20]:
losses = train(vision_transformer.cuda(), EPOCHS, LEARNING_RATE, data_loader)

  0%|          | 0/10 [00:00<?, ?it/s]

torch.Size([1, 5, 768])
torch.Size([1, 5, 768])
torch.Size([1, 5, 768])
torch.Size([1, 5, 768])
torch.Size([1, 5, 768])
torch.Size([1, 5, 768])
torch.Size([1, 5, 768])
torch.Size([1, 5, 768])
torch.Size([1, 5, 768])
torch.Size([1, 5, 768])
Loss  10 / 12500 : 5.534091949462891
torch.Size([1, 5, 768])
torch.Size([1, 5, 768])
torch.Size([1, 5, 768])
torch.Size([1, 5, 768])
torch.Size([1, 5, 768])
torch.Size([1, 5, 768])


  0%|          | 0/10 [00:14<?, ?it/s]


KeyboardInterrupt: 