In [None]:
# -*- coding: utf-8 -*-
"""
Created on Fri Oct 16 11:37:52 2020

@author: mthossain
"""
import torch
import PIL
import time
import torchvision
import torch.nn.functional as F
!pip install einops
from einops import rearrange
from torch import nn

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

class Residual(nn.Module):
    """Skip-connection wrapper.

    ``forward(x, **kwargs)`` returns ``fn(x, **kwargs) + x``; extra keyword
    arguments are forwarded to ``fn`` unchanged.
    """

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        branch_out = self.fn(x, **kwargs)
        shortcut = x
        return branch_out + shortcut

class LayerNormalize(nn.Module):
    """Pre-norm wrapper: normalise the input with ``nn.LayerNorm(dim)``, then call ``fn``.

    Keyword arguments given to ``forward`` are passed through to ``fn``.
    """

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        normalised = self.norm(x)
        return self.fn(normalised, **kwargs)

class MLP_Block(nn.Module):
    """Feed-forward block: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Both linear layers get Xavier-uniform weights and near-zero
    (std 1e-6) normally distributed biases.

    Args:
        dim: input and output feature size.
        hidden_dim: width of the hidden layer.
        dropout: drop probability used after the activation and after the
            output projection.
    """

    @staticmethod
    def _make_linear(in_features, out_features):
        # Build a Linear layer and re-initialise it (weights, then bias).
        layer = nn.Linear(in_features, out_features)
        torch.nn.init.xavier_uniform_(layer.weight)
        torch.nn.init.normal_(layer.bias, std = 1e-6)
        return layer

    def __init__(self, dim, hidden_dim, dropout = 0.1):
        super().__init__()
        self.nn1 = self._make_linear(dim, hidden_dim)
        self.af1 = nn.GELU()
        self.do1 = nn.Dropout(dropout)
        self.nn2 = self._make_linear(hidden_dim, dim)
        self.do2 = nn.Dropout(dropout)

    def forward(self, x):
        for stage in (self.nn1, self.af1, self.do1, self.nn2, self.do2):
            x = stage(x)
        return x

class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention with a fused QKV projection.

    Args:
        dim: token embedding size; split evenly across ``heads``.
        heads: number of attention heads.
        dropout: dropout probability applied after the output projection.

    Raises:
        ValueError: if ``dim`` is not divisible by ``heads``.

    ``forward(x, mask=None)`` takes ``x`` of shape (batch, tokens, dim) and
    returns a tensor of the same shape. ``mask``, if given, is flattened to
    (batch, tokens - 1) and says which tokens after the first one may take
    part in attention; the first token (the cls token) is always allowed.
    """

    def __init__(self, dim, heads = 8, dropout = 0.1):
        super().__init__()
        if dim % heads != 0:
            raise ValueError(f'dim ({dim}) must be divisible by heads ({heads})')
        self.heads = heads
        # NOTE(review): scores are scaled by 1/sqrt(dim); the Transformer paper
        # scales by 1/sqrt(dim // heads). Kept as-is so existing results and
        # checkpoints reproduce exactly.
        self.scale = dim ** -0.5

        # One projection produces Q, K and V for all heads at once (hence * 3).
        self.to_qkv = nn.Linear(dim, dim * 3, bias = True)
        torch.nn.init.xavier_uniform_(self.to_qkv.weight)
        torch.nn.init.zeros_(self.to_qkv.bias)

        # Output projection applied after the heads are concatenated.
        self.nn1 = nn.Linear(dim, dim)
        torch.nn.init.xavier_uniform_(self.nn1.weight)
        torch.nn.init.zeros_(self.nn1.bias)
        self.do1 = nn.Dropout(dropout)

    def forward(self, x, mask = None):
        b, n, _ = x.shape
        h = self.heads

        # (b, n, 3*h*d) -> (3, b, h, n, d): same layout as the einops pattern
        # 'b n (qkv h d) -> qkv b h n d' (qkv outermost, then head, then d).
        qkv = self.to_qkv(x)
        q, k, v = qkv.reshape(b, n, 3, h, -1).permute(2, 0, 3, 1, 4)

        dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale

        if mask is not None:
            flat = mask.flatten(1).bool()
            # Prepend an always-True entry for the cls token.
            flat = torch.cat((torch.ones_like(flat[:, :1]), flat), dim = 1)
            assert flat.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'
            # Pairwise (query, key) mask shaped (b, 1, n, n) so it broadcasts
            # over the head axis of dots (b, h, n, n) for any batch size.
            pair_mask = flat[:, None, :, None] & flat[:, None, None, :]
            # A finite fill keeps fully-masked rows from turning into NaN after
            # softmax, while still giving masked keys zero weight elsewhere.
            dots.masked_fill_(~pair_mask, -torch.finfo(dots.dtype).max)
            del pair_mask

        attn = dots.softmax(dim = -1)

        # Weighted sum of values, then concatenate heads back to (b, n, h*d).
        out = torch.einsum('bhij,bhjd->bhid', attn, v)
        out = out.transpose(1, 2).reshape(b, n, -1)
        out = self.nn1(out)
        out = self.do1(out)
        return out

class Transformer(nn.Module):
    """Stack of ``depth`` pre-norm encoder layers.

    Each layer is a residual, layer-normalised ``Attention`` sub-layer
    followed by a residual, layer-normalised ``MLP_Block`` sub-layer.
    ``mask`` is forwarded to every attention sub-layer.
    """

    def __init__(self, dim, depth, heads, mlp_dim, dropout):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.ModuleList([
                Residual(LayerNormalize(dim, Attention(dim, heads = heads, dropout = dropout))),
                Residual(LayerNormalize(dim, MLP_Block(dim, mlp_dim, dropout = dropout))),
            ])
            for _ in range(depth)
        ])

    def forward(self, x, mask = None):
        for attn_sublayer, ff_sublayer in self.layers:
            x = ff_sublayer(attn_sublayer(x, mask = mask))
        return x

class ImageTransformer(nn.Module):
    """Vision Transformer classifier with random patch dropping during training.

    The (square) image is cut into non-overlapping ``patch_size`` x
    ``patch_size`` patches by a strided convolution (the linear patch
    projection). A learnable cls token is prepended and learnable positional
    embeddings are added. The sequence then runs through ``depth`` encoder
    layers, and the cls-token output is mapped to ``num_classes`` logits by a
    single linear layer.

    Patch dropping: while the module is in training mode, on every batch whose
    counter ``foc`` is not a multiple of 4, the patch tokens are shuffled and
    only ``int(num_patches * patch_keep_ratio)`` of them are kept. ``foc`` is a
    module-level counter maintained by ``train()``, so it must be defined
    whenever this runs in training mode. The cls token is always kept at
    position 0. In eval mode nothing is dropped.

    Args:
        image_size: side length of the square input images.
        patch_size: side length of each square patch; must divide image_size.
        num_classes: number of output logits.
        dim: token embedding size.
        depth: number of encoder layers.
        heads: attention heads per layer.
        mlp_dim: hidden width of each encoder MLP.
        channels: number of input image channels.
        dropout: dropout inside the encoder.
        emb_dropout: dropout applied to the token sequence before the encoder.
        patch_keep_ratio: fraction of patch tokens kept when dropping:
            0.75 drops 25%, 0.5 drops 50%, 0.875 drops 12.5%.
    """

    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dropout = 0.1, emb_dropout = 0.1, patch_keep_ratio = 0.75):
        super().__init__()
        assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size'
        self.patch_size = patch_size
        self.num_patches = (image_size // patch_size) ** 2
        self.patch_keep_ratio = patch_keep_ratio

        # One positional embedding per patch plus one for the cls token.
        self.pos_embedding = nn.Parameter(torch.empty(1, self.num_patches + 1, dim))
        torch.nn.init.normal_(self.pos_embedding, std = .02)
        # A stride-patch_size convolution equals flattening each patch and
        # multiplying by an embedding matrix E (the linear patch projection).
        self.patch_conv = nn.Conv2d(channels, dim, patch_size, stride = patch_size)

        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, mlp_dim, dropout)

        self.to_cls_token = nn.Identity()

        # Single linear head (fine-tuning setup). An extra hidden layer
        # (Linear -> GELU -> Dropout -> Linear) is an option when training
        # on large datasets.
        self.nn1 = nn.Linear(dim, num_classes)
        torch.nn.init.xavier_uniform_(self.nn1.weight)
        torch.nn.init.normal_(self.nn1.bias, std = 1e-6)

    def _drop_patches(self, x):
        """Shuffle the patch tokens of ``x`` (cls stays first) and keep a prefix."""
        # Random order of patch positions 1..num_patches; position 0 (cls) stays put.
        perm = torch.randperm(self.num_patches, device = x.device) + 1
        order = torch.cat((perm.new_zeros(1), perm))
        # Number of tokens kept, cls token included.
        keep = int(self.num_patches * self.patch_keep_ratio) + 1
        return x[:, order[:keep]]

    def forward(self, img, mask = None):
        # Follow the model's own device rather than a global setting.
        img = img.to(self.patch_conv.weight.device)
        x = self.patch_conv(img)
        # (b, dim, h', w') -> (b, h'*w', dim): one row per patch.
        x = rearrange(x, 'b c h w -> b (h w) c')

        cls_tokens = self.cls_token.expand(img.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim = 1)
        x += self.pos_embedding

        # Drop patches only while training, and leave one batch in every four
        # (counted by the global `foc`) unmasked.
        if self.training and foc % 4 != 0:
            x = self._drop_patches(x)

        x = self.dropout(x)

        # NOTE(review): when patches are dropped, a non-None `mask` no longer
        # matches the sequence length; `mask` is never passed in this script.
        x = self.transformer(x, mask)

        x = self.to_cls_token(x[:, 0])
        return self.nn1(x)


# Mini-batch sizes and the side length images are resized to.
BATCH_SIZE_TRAIN = 128
BATCH_SIZE_TEST = 100
sizei = 32

# Per-channel normalisation statistics applied to CIFAR-10 inputs.
# NOTE(review): these std values are the widely copied (0.2023, 0.1994, 0.2010);
# confirm they are the statistics intended for this dataset.
_CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
_CIFAR10_STD = (0.2023, 0.1994, 0.2010)

# Training: random 32x32 crop of the image padded by 4 pixels, resize,
# random horizontal flip, then tensor conversion and normalisation.
train_transform = torchvision.transforms.Compose([
    torchvision.transforms.RandomCrop(32, padding=4),
    torchvision.transforms.Resize(sizei),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(_CIFAR10_MEAN, _CIFAR10_STD),
])

# Evaluation: deterministic resize, tensor conversion and normalisation.
test_transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize(sizei),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(_CIFAR10_MEAN, _CIFAR10_STD),
])

# Train split is built (and downloaded if needed) first, then the test split.
train_dataset, test_dataset = (
    torchvision.datasets.CIFAR10(root='./data', train=is_train,
                                 download=True, transform=tfm)
    for is_train, tfm in ((True, train_transform), (False, test_transform))
)

# Only the training loader reshuffles every epoch.
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=BATCH_SIZE_TRAIN, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=BATCH_SIZE_TEST, shuffle=False)

def train(modelc10md, optimizer, data_loader, loss_history):
    """Run one training epoch of ``modelc10md`` over ``data_loader``.

    The loss is the NLL of the log-softmaxed logits (i.e. cross-entropy).
    Every 100th batch prints the batch loss and appends it to
    ``loss_history``. The module-level batch counter ``foc`` is incremented
    once per batch; it must already exist. Prints top-1 training accuracy
    at the end.
    """
    global foc
    size = len(data_loader.dataset)
    modelc10md.train()
    correct_samples = 0

    for batch_idx, (inputs, labels) in enumerate(data_loader):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()

        log_probs = F.log_softmax(modelc10md(inputs), dim=1)
        batch_loss = F.nll_loss(log_probs, labels)
        _, predictions = torch.max(log_probs, dim=1)
        correct_samples += predictions.eq(labels).sum()

        batch_loss.backward()
        optimizer.step()
        foc += 1

        if batch_idx % 100 == 0:
            loss_value = batch_loss.item()
            seen = batch_idx * len(inputs)
            print(f"loss: {loss_value:>7f}  [{seen:>5d}/{size:>5d}]")
            loss_history.append(loss_value)

    accuracy = correct_samples / size
    print('Accuracy top 1%:{:5}/{:5} ({:4.2f}%)\n'.format(
        correct_samples, size, 100.0 * accuracy))
            
def evaluate(modelc10md, data_loader, loss_history):
    """Evaluate ``modelc10md`` on ``data_loader`` and checkpoint the best model.

    Appends the per-sample average NLL loss (log-softmax + NLL, i.e.
    cross-entropy) to ``loss_history`` and prints it together with top-1
    accuracy. When the accuracy (in percent) is higher than the module-level
    ``best_accuracy``, that global is updated and the model's state dict is
    saved to ``VitC10md.pth``.
    """
    global best_accuracy
    modelc10md.eval()

    size = len(data_loader.dataset)
    correct_samples = 0
    test_loss = 0.0

    with torch.no_grad():
        for data, target in data_loader:
            data, target = data.to(device), target.to(device)
            output = F.log_softmax(modelc10md(data), dim=1)
            # Sum per-sample losses so the average below is exact even when
            # the last batch is smaller than the others.
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            _, pred = torch.max(output, dim=1)
            # .item() keeps the running count a plain Python int.
            correct_samples += pred.eq(target).sum().item()

    avg_loss = test_loss / size
    loss_history.append(avg_loss)
    correct = correct_samples / size
    print('\nAverage test loss: ' + '{:.4f}'.format(avg_loss) +
          '  Accuracy top 1%:' + '{:5}'.format(correct_samples) + '/' +
          '{:5}'.format(size) + ' (' +
          '{:4.2f}'.format(100.0 * correct) + '%)\n')
    current_accuracy = 100.0 * correct
    if current_accuracy > best_accuracy:
        best_accuracy = current_accuracy
        torch.save(modelc10md.state_dict(), 'VitC10md.pth')

N_EPOCHS = 50 #TO BE CHANGED
# Early stopping is only considered once the epoch number exceeds this value.
# NOTE(review): with N_EPOCHS = 50 this never triggers.
EARLY_STOP_MIN_EPOCH = 100
# Stop when the latest test loss is higher than the loss of each of this many
# preceding epochs.
EARLY_STOP_WINDOW = 17
best_accuracy = 0
modelc10md = ImageTransformer(image_size=32, patch_size=4, num_classes=10, channels=3, dim=128, depth=6, heads=8, mlp_dim=512).to(device)
optimizer = torch.optim.Adam(modelc10md.parameters(), lr=0.001)
#scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, N_EPOCHS)

train_loss_history, test_loss_history = [], []
for epoch in range(1, N_EPOCHS + 1):
    print('Epoch:', epoch)
    start_time = time.time()
    # Module-level flags: foc counts training batches within the epoch
    # (incremented by train()); tot is 0 while training, 1 while evaluating.
    foc = 0
    tot = 0
    train(modelc10md, optimizer, train_loader, train_loss_history)
    print('Execution time:', '{:5.2f}'.format(time.time() - start_time), 'seconds')
    tot = 1
    evaluate(modelc10md, test_loader, test_loss_history)
    scheduler.step()
    if epoch > EARLY_STOP_MIN_EPOCH:
        # evaluate() appends exactly one loss per epoch, so [-1] is this epoch.
        latest = test_loss_history[-1]
        previous = test_loss_history[-(EARLY_STOP_WINDOW + 1):-1]
        if all(latest > past for past in previous):
            break

print('Training is over')