### Dependencies

In [262]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
# import torch.nn.functional as F
import pytorch_lightning
import wandb
import os
import torchvision

## Model

![Vision Transformer](vit.png)

In [195]:
class Encoder(nn.Module):
    
    def __init__(self, embedding_dim, d_model=256, nhead=4, num_layers=2):
        super(Encoder, self).__init__()
        "encoder layer is single encoder block that is shown in above pic  (right)"
        self.encoder = nn.TransformerEncoder(
            encoder_layer=nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead),
            num_layers=num_layers,
            norm = nn.LayerNorm(normalized_shape=embedding_dim)
        )
    def forward(self, x):
        out = self.encoder(x)
        return out


In [247]:
class VisionTransformer(nn.Module):
    """
        transformer module to encoder the image patches
        output of this module will be flatten encoded patches
    """
    
    def __init__(self, image_size, channels, patch_size, stride, embedding_dim, nhead, num_layers, num_classes, fc_dim=256):
        super(VisionTransformer, self).__init__()
        
        self.patch_size = patch_size
        self.num_patches = (image_size // patch_size) ** 2
        self.patch_dim = channels * patch_size ** 2
        self.stride = stride

        # patch_pos embedding and patch projection layer
        self.pos_embedding = nn.Parameter(torch.randn(1, self.num_patches + 1, embedding_dim))
        self.patch_projection = nn.Linear(in_features=self.patch_dim, out_features=embedding_dim, bias=False)
        self.cls_token = nn.Parameter(torch.randn(1, 1, embedding_dim)) 
        
        # transformer module to encoder the image patches output of this module will be flatten encoded patches
        self.transformer = nn.TransformerEncoder(
            encoder_layer=nn.TransformerEncoderLayer(d_model=embedding_dim, nhead=nhead),
            num_layers=num_layers,
            norm=nn.LayerNorm(normalized_shape=embedding_dim)
        )
        
        # to take cls token
        self.to_cls_token = nn.Identity()
        # classifier or mlp head to classify the data
        self.fc = nn.Sequential(
            nn.Linear(in_features=embedding_dim, out_features=fc_dim),
            nn.ReLU(),
            nn.Linear(in_features=fc_dim, out_features=num_classes)
        )
        
        
    def forward(self, x):
        # x.shape = [batch, w, h, channel]
        # patchifyt the image
        batch_size = x.shape[0]
        
        x = self.patchify(x, self.patch_size, self.patch_size)
        x = self.patch_projection(x)
        
        # concat cls token into projected patch
        cls_token = self.cls_token.expand(batch_size, -1, -1) 
        x = torch.cat((cls_token, x), dim=1)
        # add positional embedding + projected patches 
        x = x + self.pos_embedding
        
        # encoded the input and take the cls token and then feed it to mlp
        x = self.transformer(x)
        x = self.to_cls_token(x[:, 0])        
        outputs = self.fc(x)
        return outputs
        
        
    def patchify(self, images, patch_size, stride):
        # get all image windows of size (patch_size, patch_size) and stride (stride, stride)
        patches = images.unfold(2, patch_size, stride).unfold(3, patch_size, stride)
        patches = patches.permute(0, 2, 3, 4, 5, 1).contiguous()
        # patches.shape -> [batch, ... .... ... ..., ]
        # the size of flatten vector
        bs, pr, pc, h, w, ch = patches.shape[0], patches.shape[1], patches.shape[2], patches.shape[3], patches.shape[4], patches.shape[5]
        # bs->batch_size, rp->patches_row, pc->patches_col, h->patch_height, w->patch_width, w->patch_widht, ch->channels

        # dissolve it 
        patches = patches.view(bs, pr*pc, h*w*ch)

        return patches
        
    
        
        

In [248]:
vit = VisionTransformer(
    image_size=256,
    channels=3,
    patch_size=16,
    stride=16,
    embedding_dim=512,
    nhead=8,
    num_layers=2,
    num_classes=100
)

In [249]:
x = torch.rand(10, 3, 256, 256)
outputs = vit(x)

In [250]:
outputs.shape

torch.Size([10, 100])

## Training with PyTorch-Lightning and WandB

In [257]:
import pytorch_lightning as pl
import pytorch_lightning.loggers as loggers
import pytorch_lightning.metrics as metrics

In [282]:
args = {
    # model args
    "image_size":28,
    "channels":1,
    "patch_size":7,
    "stride":7,
    "embedding_dim":512,
    "nhead":8,
    "num_layers":4,
    "num_classes":10,
    "fc_dim":256,
    # data
    "batch_size":32,
    "num_workers":4,
    "transforms":torchvision.transforms.Compose([torchvision.transforms.ToTensor(), torchvision.transforms.Normalize((0.1307,), (0.3081,))]),
    #training
    "lr":0.01,
}

In [266]:
# nn.CrossEntropyLoss?

In [283]:
class  LightningViT(pl.LightningModule):
    
    def __init__(self, args):
        super(LightningViT, self).__init__()
        self.criterion = nn.CrossEntropyLoss()
        self.accuracy = metrics.Accuracy(num_classes=args["num_classes"])
        self.args = args
        
        self.model = VisionTransformer(
            image_size=args["image_size"],
            channels=args["channels"],
            patch_size=args["patch_size"],
            stride=args["stride"],
            embedding_dim=args["embedding_dim"],
            nhead=args["nhead"],
            num_layers=args["num_layers"],
            num_classes=args["num_classes"],
            fc_dim=args["fc_dim"]
        )
    
    def forward(self, x):
        outputs = self.model(x)
        return outputs
    
    def configure_optimizers(self):
        return torch.optim.Adam(params=self.parameters(), lr=self.args["lr"]),
    
    def train_dataloader(self):
        dataset = torchvision.datasets.MNIST(root=os.getcwd(), train=True, download=True, transform=self.args["transforms"])
        train_loader = DataLoader(dataset=dataset, batch_size=self.args["batch_size"], shuffle=True, num_workers=self.args["num_workers"])
        return train_loader
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        outputs = self(x)
        loss = self.criterion(outputs, y)
        acc = self.accuracy(outputs.argmax(dim=1), y)
        logs = {"loss":loss, "accuracy":acc}
        return {"loss":loss, "accuracy":acc, "log":logs}
    
    def val_dataloader(self):
        dataset = torchvision.datasets.MNIST(root=os.getcwd(), train=False, download=True, transform=self.args["transforms"])
        val_loader = DataLoader(dataset=dataset, batch_size=self.args["batch_size"], shuffle=False, num_workers=self.args["num_workers"])
        return val_loader
    
        
    def validation_step(self, batch, batch_idx):
        x, y = batch
        outputs = self(x)
        loss = self.criterion(outputs, y)
        acc = self.accuracy(outputs.argmax(dim=1), y)
        return {"val_loss":loss, "val_accuracy":acc}
    
    def validation_end(self, outputs):
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        avg_acc = torch.stack([x['val_accuracy'] for x in outputs]).mean()
        logs = {'val_loss': avg_loss, 'val_log':avg_acc}
        return  {'val_loss': avg_loss, 'val_log':avg_acc, 'log': logs}
        
        


In [285]:
model =  LightningViT(args=args)

In [291]:
# callbacks
early_stopping = pl.callbacks.EarlyStopping(
    monitor="val_acc",
    min_delta=0.05,
)

model_checkpoint = pl.callbacks.ModelCheckpoint(
    filepath="model.pth",
    monitor="val_acc",
)


In [293]:
trainer = pl.Trainer(
    checkpoint_callback=model_checkpoint,
    early_stop_callback=early_stopping,
    gpus=[0],
    max_epochs=100,
    precision=16,
    deterministic=True,
    show_progress_bar=True
)

MisconfigurationException: 
                You requested GPUs: [0]
                But your machine only has: []
            

In [None]:
trainer = pl.Trainer