##### Side note: self.training is an internal parameter, describing whether the model is in training mode or not. Used in GoogleLeNet (or Inception)

In [52]:
import torch
import torch.nn as nn
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.nn.functional as F
import einops

In [134]:
class ImageTransformer(nn.Module):
    def __init__(self,*,image_size,patch_size,num_classes,dim,depth,heads, mlp_dim, channels=3,
                dropout=0.1, emb_dropout=0.1):
        super(ImageTransformer,self).__init__()
        
        assert image_size % patch_size == 0
        num_patches=(image_size//patch_size)**2
        self.patch_size=patch_size
        
        self.pos_embedding=nn.Parameter(torch.empty(1,(num_patches+1),dim))
        nn.init.normal_(self.pos_embedding,std=0.02)
        # Patches are created by using a conv2d layer, you can write another function to do it also
        self.patch_conv = nn.Conv2d(3,dim,patch_size,stride=patch_size)
        # for each dim there is one learning parameter
        self.cls_token=nn.Parameter(torch.zeros(1,1,dim))
        self.dropout=nn.Dropout(emb_dropout)
        
        self.transformer=Transformer(dim,depth,heads,mlp_dim,dropout) # depth number of transformer blocks
        self.to_cls_token=nn.Identity()
        self.fc1=nn.Linear(dim,num_classes)
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.normal_(self.fc1.bias, std=1e-6)
        
    def forward(self,img,mask=None):
        x=self.patch_conv(img)
        x=einops.rearrange(x,'b c h w -> b (h w) c')
        cls_tokens=self.cls_token.expand(img.shape[0],-1,-1) # expand for the batch size
        x=torch.cat((cls_tokens,x),dim=1)
        # Looking at the 16x16 paper Figure-1, pos_embedding are added to the concanated flatten image with cls_token
        x=x+self.pos_embedding
        x=self.dropout(x)
        x=self.transformer(x,mask)
        x=self.to_cls_token(x[:,0]) # get rid of 2nd dimension elements now only [batch,dim]
        x=self.fc1(x)
        
        return x

In [135]:
class Transformer(nn.Module):
    def __init__(self,dim,depth,heads,mlp_dim,dropout):
        super(Transformer,self).__init__()
        self.layers=nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Residual(LayerNormalize(dim,Attention(dim,heads=heads,dropout=dropout))),
                Residual(LayerNormalize(dim,MLP_Block(dim,mlp_dim,dropout=dropout)))
            ]))
            
    def forward(self,x,mask=None):
        for attention,mlp in self.layers:
            x=attention(x,mask=mask)
            x=mlp(x)
        return x

In [148]:
class Attention(nn.Module):
    def __init__(self,dim, heads=8,dropout=0.1):
        super(Attention,self).__init__()
        self.heads=heads
        self.scale=dim** -0.5
        
        self.to_qvk=nn.Linear(dim,dim*3,bias=True)
        torch.nn.init.xavier_uniform_(self.to_qvk.weight)
        torch.nn.init.zeros_(self.to_qvk.bias)
        
        self.fc1=nn.Linear(dim,dim)
        torch.nn.init.xavier_uniform_(self.fc1.weight)
        torch.nn.init.zeros_(self.fc1.bias)
        self.drop1=nn.Dropout(dropout)
        
    def forward(self,x,mask=None):
        b,n, _,h= *x.shape,self.heads
        qkv=self.to_qvk(x)
        q,k,v=einops.rearrange(qkv,'b n (qkv h d) -> qkv b h n d',qkv=3, h=h)
        
        dots=torch.einsum('bhid,bhjd->bhij',q,k) * self.scale
        
        attn=dots.softmax(dim=-1)
        
        out=torch.einsum('bhij,bhjd->bhid',attn,v)
        out=einops.rearrange(out,'b h n d -> b n (h d)')
        out=self.fc1(out)
        out=self.drop1(out)
        return out
        

In [137]:
class LayerNormalize(nn.Module):
    def __init__(self,dim,fn):
        super(LayerNormalize,self).__init__()
        self.norm=nn.LayerNorm(dim)
        self.fn=fn
        
    def forward(self,x,**kwargs):
        return self.fn(self.norm(x),**kwargs)

In [138]:
class Residual(nn.Module):
    def __init__(self,fn):
        super(Residual,self).__init__()
        self.fn=fn
        
    def forward(self,x,**kwargs):
        return self.fn(x,**kwargs) + x

In [139]:
class MLP_Block(nn.Module):
    def __init__(self,dim,hidden_dim,dropout=0.1):
        super(MLP_Block,self).__init__()
        self.fc1=nn.Linear(dim,hidden_dim)
        torch.nn.init.xavier_uniform_(self.fc1.weight)
        torch.nn.init.normal_(self.fc1.bias,std= 1e-6)
        self.af1=nn.GELU()
        self.drop1=nn.Dropout(dropout)
        self.fc2=nn.Linear(hidden_dim,dim)
        torch.nn.init.xavier_uniform_(self.fc2.weight)
        torch.nn.init.normal_(self.fc2.bias,std=1e-6)
        self.drop2=nn.Dropout(dropout)
        
    def forward(self,x):
        x=self.fc1(x)
        x=self.af1(x)
        x=self.drop1(x)
        x=self.fc2(x)
        x=self.drop2(x)
        
        return x

In [140]:
device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [161]:
transform_config=transforms.Compose(
     [torchvision.transforms.RandomHorizontalFlip(),
     torchvision.transforms.RandomAffine(8, translate=(.15,.15)),
     torchvision.transforms.ToTensor(),
     torchvision.transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])

In [160]:
train_dataset=datasets.CIFAR10('/home/ubuntu/gpu_work',download=True,train=True,transform=transform_config)
test_dataset=datasets.CIFAR10('/home/ubuntu/gpu_work',download=True,train=False,transform=transform_config)


Files already downloaded and verified
Files already downloaded and verified


In [159]:
BATCH_SIZE=256
train_loader=torch.utils.data.DataLoader(train_dataset,batch_size=BATCH_SIZE,shuffle=True)
test_loader=torch.utils.data.DataLoader(test_dataset,batch_size=BATCH_SIZE,shuffle=False)

In [162]:
model=ImageTransformer(image_size=32,patch_size=4,num_classes=10,channels=3,dim=64,depth=6,
                      heads=4,mlp_dim=128).to(device)
optimizer=torch.optim.Adam(model.parameters(),lr=0.003)


In [157]:
def train(model,train_loader,optimizer,epoch,device):
    model.train()
    for batch_ids, (data,label) in enumerate(train_loader):
        label=label.type(torch.LongTensor)
        data,label=data.to(device),label.to(device)
        
        optimizer.zero_grad()
        model_output=F.log_softmax(model(data), dim=1)
        loss=F.nll_loss(model_output,label)
        loss.backward()
        optimizer.step()
        
        if (batch_ids +1) % 50 ==0:
                print("Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                epoch,batch_ids*len(data),len(train_loader.dataset),
                100.*batch_ids/len(train_loader),loss.item()))

In [146]:
def test(model,test_loader,device):
    model.eval()
    correct=0
    with torch.no_grad():
        for data, label in test_loader:
            data,label=data.to(device),label.to(device)
            y_hat=F.log_softmax(model(data), dim=1)
            _,y_pred=torch.max(y_hat,1)
            correct+=(y_pred==label).sum().item()
        print("\n Test Set: Average loss: xx , Accuracy:{}/{} ({:.0f}%)".format(
            correct,len(test_dataset),100.*correct/len(test_dataset)))
        print("="*50)

In [163]:
if __name__=='__main__':
    seed=42
    EPOCHS=150
    
    for epoch in range (1,EPOCHS+1):
        train(model,train_loader,optimizer,epoch,device)
        test(model,test_loader,device)


 Test Set: Average loss: xx , Accuracy:3575/10000 (36%)

 Test Set: Average loss: xx , Accuracy:4370/10000 (44%)

 Test Set: Average loss: xx , Accuracy:4792/10000 (48%)

 Test Set: Average loss: xx , Accuracy:4930/10000 (49%)

 Test Set: Average loss: xx , Accuracy:5415/10000 (54%)

 Test Set: Average loss: xx , Accuracy:5690/10000 (57%)

 Test Set: Average loss: xx , Accuracy:5743/10000 (57%)

 Test Set: Average loss: xx , Accuracy:5833/10000 (58%)

 Test Set: Average loss: xx , Accuracy:6052/10000 (61%)

 Test Set: Average loss: xx , Accuracy:6231/10000 (62%)

 Test Set: Average loss: xx , Accuracy:6202/10000 (62%)

 Test Set: Average loss: xx , Accuracy:6475/10000 (65%)

 Test Set: Average loss: xx , Accuracy:6343/10000 (63%)

 Test Set: Average loss: xx , Accuracy:6496/10000 (65%)

 Test Set: Average loss: xx , Accuracy:6573/10000 (66%)

 Test Set: Average loss: xx , Accuracy:6476/10000 (65%)

 Test Set: Average loss: xx , Accuracy:6692/10000 (67%)

 Test Set: Average loss: xx , 


 Test Set: Average loss: xx , Accuracy:7255/10000 (73%)

 Test Set: Average loss: xx , Accuracy:7166/10000 (72%)

 Test Set: Average loss: xx , Accuracy:7265/10000 (73%)

 Test Set: Average loss: xx , Accuracy:7237/10000 (72%)

 Test Set: Average loss: xx , Accuracy:7349/10000 (73%)

 Test Set: Average loss: xx , Accuracy:7299/10000 (73%)

 Test Set: Average loss: xx , Accuracy:7245/10000 (72%)

 Test Set: Average loss: xx , Accuracy:7291/10000 (73%)

 Test Set: Average loss: xx , Accuracy:7306/10000 (73%)

 Test Set: Average loss: xx , Accuracy:7372/10000 (74%)

 Test Set: Average loss: xx , Accuracy:7327/10000 (73%)

 Test Set: Average loss: xx , Accuracy:7293/10000 (73%)

 Test Set: Average loss: xx , Accuracy:7411/10000 (74%)

 Test Set: Average loss: xx , Accuracy:7312/10000 (73%)

 Test Set: Average loss: xx , Accuracy:7368/10000 (74%)

 Test Set: Average loss: xx , Accuracy:7487/10000 (75%)

 Test Set: Average loss: xx , Accuracy:7368/10000 (74%)

 Test Set: Average loss: xx , 


 Test Set: Average loss: xx , Accuracy:7495/10000 (75%)

 Test Set: Average loss: xx , Accuracy:7588/10000 (76%)

 Test Set: Average loss: xx , Accuracy:7651/10000 (77%)

 Test Set: Average loss: xx , Accuracy:7579/10000 (76%)

 Test Set: Average loss: xx , Accuracy:7660/10000 (77%)

 Test Set: Average loss: xx , Accuracy:7638/10000 (76%)

 Test Set: Average loss: xx , Accuracy:7574/10000 (76%)

 Test Set: Average loss: xx , Accuracy:7597/10000 (76%)

 Test Set: Average loss: xx , Accuracy:7615/10000 (76%)

 Test Set: Average loss: xx , Accuracy:7654/10000 (77%)

 Test Set: Average loss: xx , Accuracy:7628/10000 (76%)

 Test Set: Average loss: xx , Accuracy:7720/10000 (77%)

 Test Set: Average loss: xx , Accuracy:7592/10000 (76%)

 Test Set: Average loss: xx , Accuracy:7581/10000 (76%)

 Test Set: Average loss: xx , Accuracy:7703/10000 (77%)

 Test Set: Average loss: xx , Accuracy:7675/10000 (77%)

 Test Set: Average loss: xx , Accuracy:7516/10000 (75%)

 Test Set: Average loss: xx , 


 Test Set: Average loss: xx , Accuracy:7573/10000 (76%)

 Test Set: Average loss: xx , Accuracy:7697/10000 (77%)

 Test Set: Average loss: xx , Accuracy:7710/10000 (77%)

 Test Set: Average loss: xx , Accuracy:7744/10000 (77%)

 Test Set: Average loss: xx , Accuracy:7729/10000 (77%)

 Test Set: Average loss: xx , Accuracy:7636/10000 (76%)

 Test Set: Average loss: xx , Accuracy:7676/10000 (77%)

 Test Set: Average loss: xx , Accuracy:7781/10000 (78%)

 Test Set: Average loss: xx , Accuracy:7781/10000 (78%)

 Test Set: Average loss: xx , Accuracy:7782/10000 (78%)

 Test Set: Average loss: xx , Accuracy:7795/10000 (78%)

 Test Set: Average loss: xx , Accuracy:7713/10000 (77%)

 Test Set: Average loss: xx , Accuracy:7676/10000 (77%)

 Test Set: Average loss: xx , Accuracy:7669/10000 (77%)

 Test Set: Average loss: xx , Accuracy:7710/10000 (77%)

 Test Set: Average loss: xx , Accuracy:7646/10000 (76%)

 Test Set: Average loss: xx , Accuracy:7629/10000 (76%)

 Test Set: Average loss: xx , 


 Test Set: Average loss: xx , Accuracy:7821/10000 (78%)

 Test Set: Average loss: xx , Accuracy:7812/10000 (78%)

 Test Set: Average loss: xx , Accuracy:7748/10000 (77%)

 Test Set: Average loss: xx , Accuracy:7734/10000 (77%)

 Test Set: Average loss: xx , Accuracy:7745/10000 (77%)

 Test Set: Average loss: xx , Accuracy:7774/10000 (78%)

 Test Set: Average loss: xx , Accuracy:7797/10000 (78%)

 Test Set: Average loss: xx , Accuracy:7733/10000 (77%)

 Test Set: Average loss: xx , Accuracy:7731/10000 (77%)

 Test Set: Average loss: xx , Accuracy:7865/10000 (79%)

 Test Set: Average loss: xx , Accuracy:7756/10000 (78%)

 Test Set: Average loss: xx , Accuracy:7762/10000 (78%)

 Test Set: Average loss: xx , Accuracy:7729/10000 (77%)

 Test Set: Average loss: xx , Accuracy:7745/10000 (77%)

 Test Set: Average loss: xx , Accuracy:7788/10000 (78%)

 Test Set: Average loss: xx , Accuracy:7685/10000 (77%)

 Test Set: Average loss: xx , Accuracy:7688/10000 (77%)

 Test Set: Average loss: xx , 