In [2]:
!pip install einops
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

from torch import nn
from torch import Tensor
from PIL import Image
from torchvision.transforms import Compose, Resize, ToTensor
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce
from torchsummary import summary

Collecting einops
  Downloading einops-0.4.1-py3-none-any.whl (28 kB)
Installing collected packages: einops
Successfully installed einops-0.4.1


In [3]:
class PatchEmbedding(nn.Module):
    def __init__(self, in_channels: int = 3, patch_size: int = 16, emb_size: int = 768, img_size: int = 224):
        self.patch_size = patch_size
        super().__init__()
        self.projection = nn.Sequential(
            # using a conv layer instead of a linear one -> performance gains
            nn.Conv2d(in_channels, emb_size, kernel_size=patch_size, stride=patch_size),
            Rearrange('b e (h) (w) -> b (h w) e'),
        )
        self.cls_token = nn.Parameter(torch.randn(1,1, emb_size))
        self.positions = nn.Parameter(torch.randn((img_size // patch_size) **2 + 1, emb_size))

        
    def forward(self, x: Tensor) -> Tensor:
        b, _, _, _ = x.shape
        x = self.projection(x)
        cls_tokens = repeat(self.cls_token, '() n e -> b n e', b=b)
        # prepend the cls token to the input
        x = torch.cat([cls_tokens, x], dim=1)
        # add position embedding
        x += self.positions
        return x

class MultiHeadAttention(nn.Module):
    def __init__(self, emb_size: int = 768, num_heads: int = 8, dropout: float = 0):
        super().__init__()
        self.emb_size = emb_size
        self.num_heads = num_heads
        # fuse the queries, keys and values in one matrix
        self.qkv = nn.Linear(emb_size, emb_size * 3)
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_size, emb_size)
        
    def forward(self, x : Tensor, mask: Tensor = None) -> Tensor:
        # split keys, queries and values in num_heads
        qkv = rearrange(self.qkv(x), "b n (h d qkv) -> (qkv) b h n d", h=self.num_heads, qkv=3)
        queries, keys, values = qkv[0], qkv[1], qkv[2]
        # sum up over the last axis
        energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys) # batch, num_heads, query_len, key_len
        if mask is not None:
            fill_value = torch.finfo(torch.float32).min
            energy.mask_fill(~mask, fill_value)
            
        scaling = self.emb_size ** (1/2)
        att = F.softmax(energy, dim=-1) / scaling
        att = self.att_drop(att)
        # sum up over the third axis
        out = torch.einsum('bhal, bhlv -> bhav ', att, values)
        out = rearrange(out, "b h n d -> b n (h d)")
        out = self.projection(out)
        return out

class ResidualAdd(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn
        
    def forward(self, x, **kwargs):
        res = x
        x = self.fn(x, **kwargs)
        x += res
        return x

class FeedForwardBlock(nn.Sequential):
    def __init__(self, emb_size: int, expansion: int = 4, drop_p: float = 0.):
        super().__init__(
            nn.Linear(emb_size, expansion * emb_size),
            nn.GELU(),
            nn.Dropout(drop_p),
            nn.Linear(expansion * emb_size, emb_size),
        )

class TransformerEncoderBlock(nn.Sequential):
    def __init__(self,
                 emb_size: int = 768,
                 drop_p: float = 0.,
                 forward_expansion: int = 4,
                 forward_drop_p: float = 0.,
                 ** kwargs):
        super().__init__(
            ResidualAdd(nn.Sequential(
                nn.LayerNorm(emb_size),
                MultiHeadAttention(emb_size, **kwargs),
                nn.Dropout(drop_p)
            )),
            ResidualAdd(nn.Sequential(
                nn.LayerNorm(emb_size),
                FeedForwardBlock(
                    emb_size, expansion=forward_expansion, drop_p=forward_drop_p),
                nn.Dropout(drop_p)
            )
            ))

class TransformerEncoder(nn.Sequential):
    def __init__(self, depth: int = 12, **kwargs):
        super().__init__(*[TransformerEncoderBlock(**kwargs) for _ in range(depth)])

class ClassificationHead(nn.Sequential):
    def __init__(self, emb_size: int = 768, n_classes: int = 1000):
        super().__init__(
            Reduce('b n e -> b e', reduction='mean'),
            nn.LayerNorm(emb_size), 
            nn.Linear(emb_size, n_classes))

class ViT(nn.Sequential):
    def __init__(self,     
                in_channels: int = 3,
                patch_size: int = 16,
                emb_size: int = 768,
                img_size: int = 224,
                depth: int = 12,
                n_classes: int = 1000,
                **kwargs):
        super().__init__(
            PatchEmbedding(in_channels, patch_size, emb_size, img_size),
            TransformerEncoder(depth, emb_size=emb_size, **kwargs),
            ClassificationHead(emb_size, n_classes)
        )

In [8]:
summary(ViT(1,16,768,32,5,2), (1, 32, 32), device='cpu')

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1            [-1, 768, 2, 2]         197,376
         Rearrange-2               [-1, 4, 768]               0
    PatchEmbedding-3               [-1, 5, 768]               0
         LayerNorm-4               [-1, 5, 768]           1,536
            Linear-5              [-1, 5, 2304]       1,771,776
           Dropout-6              [-1, 8, 5, 5]               0
            Linear-7               [-1, 5, 768]         590,592
MultiHeadAttention-8               [-1, 5, 768]               0
           Dropout-9               [-1, 5, 768]               0
      ResidualAdd-10               [-1, 5, 768]               0
        LayerNorm-11               [-1, 5, 768]           1,536
           Linear-12              [-1, 5, 3072]       2,362,368
             GELU-13              [-1, 5, 3072]               0
          Dropout-14              [-1, 

In [11]:
import random
from torch.utils.data import random_split,Dataset
from torchvision.datasets import MNIST,CIFAR10
from torchvision.transforms import Compose,ToTensor,Normalize,RandomHorizontalFlip,RandomRotation, RandomVerticalFlip, RandomApply


class TinyCifar(Dataset):
    def __init__(self,args,train):
        self.train = train
        random.seed(args.seed)
        if args.normalise=='standard':
            mean,std = [0.485, 0.456, 0.406],[0.229, 0.224, 0.225]
        elif args.normalise=='constant':
            mean,std = (0,),(255.0,)
        else:
            raise ValueError
        
        self.test_transforms = Compose([ToTensor(),Normalize(mean,std)])
        self.train_transforms =  Compose([ToTensor(),Normalize(mean,std)])
        self.augment_transforms = Compose([RandomHorizontalFlip(),RandomVerticalFlip(),RandomRotation(-10,10)])

        data = CIFAR10(root="cifar/",train=True,download=True,transform=self.train_transforms)
        tiny_data = {}
        self.new_data = []
        if self.train:
            for label in range(0,10):
                tiny_data[label] = [x for x in data if x[1]==label]
            
            print(len(tiny_data[0]))
            for label in range(0,10):
                self.new_data.extend(random.sample(tiny_data[label],500))

            random.shuffle(self.new_data)
            assert len(self.new_data)==5000
        else:
            self.test = CIFAR10(root="cifar/",train=False,transform=self.test_transforms)


    def __getitem__(self, index):
        if self.train:
            img,y = self.new_data[index]
            img = self.augment_transforms(img)
            return (img,y)
        else:
            img,y = self.test[index]
            return (img,y)
    
    def __len__(self):
        if self.train:
            return 5000
        else:
            return 10000
        


In [12]:
def load_mnist(args):
    if args.normalise=='standard':
        mean,std = (0.1307,), (0.3081,)
    elif args.normalise=='constant':
        mean,std = (0,),(255.0,)
    else:
        raise ValueError


    transforms = Compose([ToTensor(),Normalize(mean,std)])
    data = MNIST(root="mnist/",train=True,download=True,transform=transforms)
    test = MNIST(root="mnist/",train=False,transform=transforms)
    train,val = random_split(data,[50000,10000],generator=torch.Generator().manual_seed(args.seed))

    print("Train : {}, Validation : {}, Test : {} ".format(len(train),len(val),len(test)))

    train = torch.utils.data.DataLoader(train,batch_size=args.batch,shuffle=True)    
    val = torch.utils.data.DataLoader(val,batch_size=args.batch,shuffle=True)    
    test = torch.utils.data.DataLoader(test,batch_size=32,shuffle=True)  
    for x,y in train:
        print(type(x),type(y),x.shape,y.shape)
        break

    return train,val,test

def load_cifar(args):
    if args.normalise=='standard':
        mean,std = [0.485, 0.456, 0.406],[0.229, 0.224, 0.225]
    elif args.normalise=='constant':
        mean,std = (0,),(255.0,)
    else:
        raise ValueError


    test_transforms = Compose([ToTensor(),Normalize(mean,std)])
    train_transforms =  Compose([ToTensor(),Normalize(mean,std)])

    data = CIFAR10(root="cifar/",train=True,download=True,transform=train_transforms)
    test = CIFAR10(root="cifar/",train=False,transform=test_transforms)
    train,val = random_split(data,[42000,8000],generator=torch.Generator().manual_seed(args.seed))

    print("Train : {}, Validation : {}, Test : {} ".format(len(train),len(val),len(test)))

    train = torch.utils.data.DataLoader(train,batch_size=args.batch,shuffle=True)    
    val = torch.utils.data.DataLoader(val,batch_size=args.batch,shuffle=True)    
    test = torch.utils.data.DataLoader(test,batch_size=32,shuffle=True)  
    for x,y in train:
        print(type(x),type(y),x.shape,y.shape)
        break

    return train,val,test

    pass

def load_tinycifar(args):
    data = TinyCifar(args,train=True)
    test = TinyCifar(args,train=False)
    train,val = random_split(data,[4200,800],generator=torch.Generator().manual_seed(args.seed))

    print("Train : {}, Validation : {}, Test : {} ".format(len(train),len(val),len(test)))

    train = torch.utils.data.DataLoader(train,batch_size=args.batch,shuffle=True)    
    val = torch.utils.data.DataLoader(val,batch_size=args.batch,shuffle=True)    
    test = torch.utils.data.DataLoader(test,batch_size=32,shuffle=True)  
    for x,y in train:
        print(type(x),type(y),x.shape,y.shape)
        break

    return train,val,test

   

In [22]:
class temp:
        def __init__(self):
            self.normalise = 'standard'
            self.seed = 10
            self.layers = 2
            self.expt = 2
            self.batch = 128
            self.epochs= 75
            self.early = 10
            self.dropout = False
            self.reg = False
            self.dropout_rate = 0.0
            self.activation = 'relu'
            self.lr = 0.01
            self.hid = 128
            self.scratch = True
            self.name = None 
            self.device = None
    
args = temp() 
args.name = 'Expt{}_l{}_ep{}_early{}_reg{}_dr{}_rate_{}_act{}_hid{}_lr{}_sc{}_bch{}__norm{}_seed{}'.format(args.expt,args.layers,args.epochs,args.early,args.reg,args.dropout,args.dropout_rate,args.activation,args.hid,args.lr,args.scratch,args.batch,args.normalise,args.seed)
name = args.name
print(name)
args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
args.save = False


def train(model,data_train,data_valid,optimizer,scheduler,args):
    losses = []
    train_accuracy = []
    patience = args.early if args.early > 0 else -1
    no_improvement = 0
    best = 0
    for epoch in range(0,args.epochs):
        epoch_loss = 0
        correct = 0
        total = 0
        for x,y in data_train:
            x = x.to(args.device)
            y = y.to(args.device)
            optimizer.zero_grad()
            pred = model(x).type(torch.float32)
            loss = torch.nn.functional.nll_loss(pred,y)
            loss.backward()
            epoch_loss += loss.item()*len(x)
            optimizer.step()
            _,output = torch.max(pred,dim=1)
            correct += (output == y).detach().float().sum().item()
            total = total + x.shape[0]

        val_loss,val_acc = test(model,data_valid,args)
        train_loss = epoch_loss/total
        train_acc = correct/total

        print("Epoch : {} ,Train Acc : {}, Train Loss : {}, Val Acc : {}, Val Loss : {}".format(epoch,train_acc,train_loss,val_acc,val_loss))
        
        with open("Results/{}/log.txt".format(args.name),'a') as f:
           f.write("Epoch : {} ,Train Acc : {}, Train Loss : {}, Val Acc : {}, Val Loss : {}\n".format(epoch,train_acc,train_loss,val_acc,val_loss))
        
        if val_acc > best + 0.01:
            best = val_acc
            no_improvement = 0
            if args.save:
                torch.save(model.state_dict(),"Results/{}/model_best.pth".format(args.name))

        if args.save:
            torch.save(model.state_dict(),"Results/{}/model_{}.pth".format(args.name,epoch))

        no_improvement += 1

        if patience > 0 and no_improvement == patience:
            break

        scheduler.step(val_acc)
        train_accuracy.append(train_acc)
        losses.append(train_loss)
        
    return (losses,train_accuracy)

def test(model,data,args):
    test_loss = 0
    test_accuracy = 0
    with torch.no_grad():
        correct = 0
        total = 0
        for x,y in data:
            x = x.to(args.device)
            y = y.to(args.device)
            pred = model(x).type(torch.float32)
            loss = torch.nn.functional.nll_loss(pred,y)
            _,output = torch.max(pred,dim=1)
            correct += (output == y).detach().float().sum().item()
            total = total + x.shape[0]

        test_accuracy = correct/total
        test_loss = loss/total
        
    
    return test_loss, test_accuracy

Expt2_l2_ep75_early10_regFalse_drFalse_rate_0.0_actrelu_hid128_lr0.01_scTrue_bch128__normstandard_seed10


In [14]:
data_train,data_valid,data_test = load_cifar(args)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to cifar/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting cifar/cifar-10-python.tar.gz to cifar/
Train : 42000, Validation : 8000, Test : 10000 
<class 'torch.Tensor'> <class 'torch.Tensor'> torch.Size([128, 3, 32, 32]) torch.Size([128])


In [19]:
for x,y in data_train:
    print(x.shape,y.shape)
    out = model(x)
    print(out.shape)
    break

torch.Size([128, 3, 32, 32]) torch.Size([128])
torch.Size([128, 10])


In [27]:
import torch.optim as optim
import os

model = ViT(3,32,16,32,2,10)
args.lr = 0.00001
optimizer = optim.Adam(params=model.parameters(),lr=args.lr,weight_decay=1e-3)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,patience=5,min_lr=1e-5)

try:
    os.mkdir("Results")
except:
    pass

try:
    os.mkdir("Results/"+name)
except:
    pass


train(model,data_train,data_valid,optimizer,scheduler,args)

Epoch : 0 ,Train Acc : 0.17792857142857144, Train Loss : -0.5552880728131249, Val Acc : 0.218375, Val Loss : -9.293136099586263e-05
Epoch : 1 ,Train Acc : 0.23107142857142857, Train Loss : -0.8239356306166876, Val Acc : 0.236625, Val Loss : -0.00013070134446024895
Epoch : 2 ,Train Acc : 0.23742857142857143, Train Loss : -0.9630529928888594, Val Acc : 0.23275, Val Loss : -0.0001320558221777901
Epoch : 3 ,Train Acc : 0.23792857142857143, Train Loss : -1.0624631857190814, Val Acc : 0.231625, Val Loss : -0.00015102683391887695
Epoch : 4 ,Train Acc : 0.23947619047619048, Train Loss : -1.139815902414776, Val Acc : 0.231625, Val Loss : -0.0001331451494479552
Epoch : 5 ,Train Acc : 0.24183333333333334, Train Loss : -1.20409587483179, Val Acc : 0.23475, Val Loss : -0.00014327761891763657


KeyboardInterrupt: ignored