In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR

import torchvision
from torchvision import transforms, datasets

import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, GPUStatsMonitor, ModelCheckpoint, LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger

from torchmetrics import Accuracy

from pl_bolts.datamodules.cifar10_datamodule import CIFAR10DataModule
from pl_bolts.datamodules.imagenet_datamodule import ImagenetDataModule
from pl_bolts.transforms.dataset_normalizations import cifar10_normalization, imagenet_normalization
from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2

In [2]:
class Transform:
    """Callable adapter around an albumentations pipeline.

    The wrapped pipeline is invoked with the input converted to a NumPy array
    and passed as the ``image`` keyword; the ``'image'`` entry of the returned
    mapping is handed back to the caller.
    """

    def __init__(self, transform: A.Compose):
        # Pipeline to run on every call.
        self.transform = transform

    def __call__(self, img, *args, **kwargs):
        # Convert first (e.g. from a PIL image), then forward any extra arguments untouched.
        pixels = np.array(img)
        augmented = self.transform(*args, image=pixels, **kwargs)
        return augmented['image']


# Seed the global RNGs before anything stochastic is created.
pl.seed_everything(42)

BATCH_SIZE = 128
data_folder = "/home/dima/datasets/imagenet/"

dm = ImagenetDataModule(data_folder, batch_size=BATCH_SIZE, shuffle=True, num_workers=8)

# Alternative albumentations-based pipelines (disabled, kept for reference):
# dm.train_transforms = Transform(A.Compose([
#     A.Resize(224, 224),
#     A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
#     A.HorizontalFlip(p=0.5),
#     A.RandomResizedCrop(224, 224, scale=(0.8, 1.0), ratio=(0.9, 1.1)),
#     ToTensorV2(),
# ]))
# dm.val_transforms = Transform(A.Compose([
#     A.Resize(224, 224),
#     A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
#     ToTensorV2(),
# ]))

# Training pipeline: resize -> random flip -> mild random-resized crop -> tensor -> normalize.
train_ops = [
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomResizedCrop((224, 224), scale=(0.8, 1.0), ratio=(0.9, 1.1)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
]
dm.train_transforms = transforms.Compose(train_ops)

# Validation pipeline: deterministic resize -> tensor -> normalize (no augmentation).
val_ops = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
]
dm.val_transforms = transforms.Compose(val_ops)

dm.setup()

# Number of training batches per epoch; used later to size the LR schedule.
STEPS_PER_EPOCH = len(dm.train_dataloader())

Global seed set to 42


In [5]:
class PatchEmbedding(nn.Module):
    """2D image to patch-token embedding with a class token and learned positions.

    A convolution whose kernel and stride both equal the patch size cuts the
    image into non-overlapping patches and projects each one to ``emb_dim``
    channels. A learnable class token is prepended to the patch sequence and a
    learnable positional embedding is added before dropout.

    Args:
        img_size: Expected input resolution. An int ``s`` means ``(s, s)``;
            a 2-element tuple/list is interpreted as ``(height, width)``.
        patch_size: Patch resolution, int or ``(height, width)``.
        emb_dim: Number of channels of every output token.
        num_channels: Number of channels of the input image.
        norm_layer: Optional callable invoked as ``norm_layer(emb_dim)`` to build
            a normalization applied to the patch tokens; ``None`` means identity.
        dropout: Dropout probability applied after adding the positional embedding.
    """

    def __init__(self, img_size=224, patch_size=16, emb_dim=768, num_channels=3, norm_layer=None, dropout=0.):
        super().__init__()

        self.img_size = img_size = self._to_pair(img_size)
        self.patch_size = patch_size = self._to_pair(patch_size)
        self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]

        self.proj = nn.Conv2d(num_channels, emb_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = norm_layer(emb_dim) if norm_layer else nn.Identity()

        self.cls_token = nn.Parameter(torch.zeros(1, 1, emb_dim))
        self.pos_embedding = nn.Parameter(torch.zeros(1, 1 + self.num_patches, emb_dim))  # + cls_token
        self.dropout = nn.Dropout(p=dropout)

    @staticmethod
    def _to_pair(value):
        """Return ``value`` as a ``(height, width)`` tuple; a scalar is duplicated."""
        if isinstance(value, (tuple, list)):
            if len(value) != 2:
                raise ValueError(f"Expected an int or a (height, width) pair, got {value!r}.")
            return tuple(value)
        return (value, value)

    def forward(self, x):
        """Embed a batch of images.

        Args:
            x: Tensor of shape ``(batch, num_channels, H, W)`` where ``(H, W)``
                must equal ``self.img_size``.

        Returns:
            Tensor of shape ``(batch, 1 + num_patches, emb_dim)``; index 0 along
            dimension 1 is the class token.
        """
        _, _, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x)                      # (b, emb_dim, grid_h, grid_w)
        x = x.flatten(2).transpose(1, 2)      # "b c h w -> b (h w) c"
        x = self.norm(x)
        # One copy of the class token per sample (expand creates a view, no allocation).
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_token, x), dim=1)
        x = self.dropout(x + self.pos_embedding)
        return x
    
class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    During training each sample of the batch is zeroed with probability
    ``drop_prob``; surviving samples are rescaled by ``1 / (1 - drop_prob)``.
    In eval mode, or when ``drop_prob`` is 0, the input is returned unchanged.

    Args:
        drop_prob: Probability of dropping a sample, in ``[0, 1]``.

    Raises:
        ValueError: If ``drop_prob`` lies outside ``[0, 1]``.
    """
    def __init__(self, drop_prob=0):
        super(DropPath, self).__init__()
        if not 0. <= drop_prob <= 1.:
            raise ValueError(f"drop_prob has to be between 0 and 1, but got {drop_prob}")
        self.drop_prob = drop_prob

    def forward(self, x):
        if self.drop_prob == 0. or not self.training:
            return x
        keep_prob = 1 - self.drop_prob
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
        random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
        random_tensor.floor_()  # binarize: 1 keeps the sample, 0 drops it
        if keep_prob > 0.:
            return x.div(keep_prob) * random_tensor
        # drop_prob == 1: every sample is dropped. Skip the rescale, otherwise
        # x / 0 * 0 would turn the whole output into NaN.
        return x * random_tensor

    def extra_repr(self):
        return f"drop_prob={self.drop_prob}"
    
class AttentionBlock(nn.Module):
    """Pre-norm Transformer encoder block.

    Applies LayerNorm + multi-head self-attention and then LayerNorm + MLP,
    each wrapped in a residual connection whose branch passes through
    ``drop_path``.
    """

    def __init__(self, emb_dim, num_heads, hidden_dim=2048, dropout=0., drop_path=0.):
        """
        Inputs:
            emb_dim - Dimensionality of input and attention feature vectors
            num_heads - Number of heads to use in the Multi-Head Attention block
            hidden_dim - Dimensionality of hidden layer in feed-forward network
                         (usually 2-4x larger than emb_dim)
            dropout - Amount of dropout to apply in the feed-forward network
            drop_path - Probability passed to DropPath on both residual branches;
                        values <= 0 disable it (an identity is used instead)
        """
        super().__init__()

        self.layer_norm_1 = nn.LayerNorm(emb_dim)
        self.attn = nn.MultiheadAttention(emb_dim, num_heads, batch_first=True)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.layer_norm_2 = nn.LayerNorm(emb_dim)
        self.mlp = nn.Sequential(
            nn.Linear(emb_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, emb_dim),
            nn.Dropout(dropout)
        )


    def forward(self, x):
        """Run the block on ``x`` of shape (batch, tokens, emb_dim); the shape is preserved."""
        inp_x = self.layer_norm_1(x)
        # Only the attended values are used, so skip computing the attention-weight
        # matrix that the default need_weights=True would also return.
        attn_out, _ = self.attn(inp_x, inp_x, inp_x, need_weights=False)
        x = x + self.drop_path(attn_out)
        x = x + self.drop_path(self.mlp(self.layer_norm_2(x)))
        return x
    
class VisionTransformer(nn.Module):
    """Vision Transformer classifier.

    Pipeline: ``PatchEmbedding`` -> ``depth`` stacked ``AttentionBlock``s ->
    LayerNorm -> linear head applied to the token at sequence index 0.

    Args:
        img_size, patch_size, emb_dim, num_channels: Forwarded to ``PatchEmbedding``.
        mlp_dim: Hidden width of the feed-forward part of every block.
        depth: Number of attention blocks.
        num_heads: Attention heads per block.
        num_classes: Size of the output logit vector.
        dropout: Dropout used inside the blocks.
        emb_dropout: Dropout used by the patch embedding.
        drop_path: Drop-path rate of the last block; earlier blocks get
            linearly spaced rates starting at 0.
    """

    def __init__(self, img_size=224, patch_size=16, emb_dim=512, mlp_dim=2048, depth=6, num_heads=8, 
                 num_channels=3, num_classes=10, dropout=0.1, emb_dropout=0.1, drop_path=0.):
        super().__init__()

        self.embedding = PatchEmbedding(img_size, patch_size, emb_dim, num_channels, dropout=emb_dropout)

        # One drop-path rate per block, linearly spaced over [0, drop_path].
        drop_rates = torch.linspace(0, drop_path, depth).tolist()
        blocks = []
        for rate in drop_rates:
            blocks.append(AttentionBlock(emb_dim, num_heads, mlp_dim, dropout=dropout, drop_path=rate))
        self.transformer = nn.Sequential(*blocks)

        self.norm = nn.LayerNorm(emb_dim, eps=1e-6)
        self.classifier = nn.Linear(emb_dim, num_classes)

    def forward(self, x):
        """Map a batch of images to class logits."""
        tokens = self.embedding(x)
        tokens = self.transformer(tokens)
        tokens = self.norm(tokens)
        # Classify from the first token of the sequence only.
        first_token = tokens[:, 0]
        return self.classifier(first_token)

In [6]:
class ViT(pl.LightningModule):
    """Lightning training wrapper around a classification model.

    Computes cross-entropy loss and top-1 / top-5 accuracy in the train,
    validation and test steps and optimizes with Adam under a per-step
    cosine-annealed learning rate.

    Args:
        model: Network mapping an input batch to class logits.
        lr: Initial learning rate; the schedule decays to ``lr * 1e-2``.
        weight_decay: Weight decay handed to Adam.
        max_iters: Length of the cosine schedule, counted in scheduler steps.
    """

    def __init__(self, model, lr=1e-3, weight_decay=0, max_iters=2000):
        super().__init__()

        # `model` itself is excluded from the saved hyper-parameters.
        self.save_hyperparameters(ignore='model')
        self.model = model

        self.train_acc_top1 = Accuracy()
        self.val_acc_top1 = Accuracy()
        self.test_acc_top1 = Accuracy()

        self.train_acc_top5 = Accuracy(top_k=5)
        self.val_acc_top5 = Accuracy(top_k=5)
        self.test_acc_top5 = Accuracy(top_k=5)

    def forward(self, x):
        return self.model(x)

    def _predict_with_loss(self, batch):
        """Unpack ``(inputs, targets)``, run the model and return (logits, targets, loss)."""
        inputs, targets = batch
        logits = self(inputs)
        return logits, targets, F.cross_entropy(logits, targets)

    def training_step(self, batch, batch_idx):
        logits, targets, loss = self._predict_with_loss(batch)
        # Loss is logged per step and per epoch; accuracies per step only.
        self.log("train_loss", loss, on_step=True, on_epoch=True)
        self.log('train_acc_top1', self.train_acc_top1(logits, targets), on_step=True, on_epoch=False, prog_bar=True)
        self.log('train_acc_top5', self.train_acc_top5(logits, targets), on_step=True, on_epoch=False, prog_bar=False)
        return loss

    def validation_step(self, batch, batch_idx):
        logits, targets, loss = self._predict_with_loss(batch)
        # Everything is logged at epoch level only.
        self.log("val_loss", loss, on_step=False, on_epoch=True)
        self.log('val_acc_top1', self.val_acc_top1(logits, targets), on_step=False, on_epoch=True, prog_bar=True)
        self.log('val_acc_top5', self.val_acc_top5(logits, targets), on_step=False, on_epoch=True, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        logits, targets, loss = self._predict_with_loss(batch)
        self.log("test_loss", loss, on_step=False, on_epoch=True)
        self.log('test_acc_top1', self.test_acc_top1(logits, targets), on_step=False, on_epoch=True)
        self.log('test_acc_top5', self.test_acc_top5(logits, targets), on_step=False, on_epoch=True)
        return loss

    def configure_optimizers(self):
        hp = self.hparams
        optimizer = torch.optim.Adam(self.parameters(), lr=hp.lr, weight_decay=hp.weight_decay)
        # 'interval': 'step' asks for the scheduler to advance every optimizer step.
        scheduler = CosineAnnealingLR(optimizer, hp.max_iters, eta_min=hp.lr*1e-2)
        return {
            'optimizer': optimizer,
            'lr_scheduler': {'scheduler': scheduler, 'interval': 'step'},
        }

In [None]:
NUM_EPOCHS = 30

model = VisionTransformer(
    img_size=224, patch_size=16, emb_dim=768,
    depth=8, num_heads=12, num_classes=1000, drop_path=0.3,
)
# The cosine schedule spans NUM_EPOCHS * STEPS_PER_EPOCH scheduler steps.
vit = ViT(model, lr=1e-4, weight_decay=1e-5, max_iters=NUM_EPOCHS*STEPS_PER_EPOCH)

# Callbacks: LR curve per step, GPU statistics, and checkpointing on validation loss.
lr_monitor = LearningRateMonitor(logging_interval='step')
gpu_monitor = GPUStatsMonitor()
model_checkpoint = ModelCheckpoint(
    monitor='val_loss', 
    filename='imagenetv3-{epoch}-{val_loss:.3f}', 
    dirpath='/home/dima/ViTransformer/checkpoints/', 
    mode='min',
)

tb_logger = TensorBoardLogger('/home/dima/lightning_logs/', name='vit_imagenet')

trainer = pl.Trainer(
    max_epochs=NUM_EPOCHS,
    gpus=1,
    logger=tb_logger,
#     default_root_dir='/home/dima/vitransformer/checkpoints/',
    callbacks=[lr_monitor, model_checkpoint, gpu_monitor],
    gradient_clip_val=1,
)

trainer.fit(vit, dm)

  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
  rank_zero_deprecation(
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name           | Type              | Params
-----------------------------------------------------
0 | model          | VisionTransformer | 45.6 M
1 | train_acc_top1 | Accuracy          | 0     
2 | val_acc_top1   | Accuracy          | 0     
3 | test_acc_top1  | Accuracy          | 0     
4 | train_acc_top5 | Accuracy          | 0     
5 | val_acc_top5   | Accuracy          | 0     
6 | test_acc_top5  | Accuracy          | 0     
-----------------------------------------------------
45.6 M    Trainable params
0         Non-trainable params
45.6 M    Total params
182.500   Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Global seed set to 42


Training: -1it [00:00, ?it/s]

In [None]:
# Second fit on the already-constructed `vit`, reusing the callbacks created earlier.
# NOTE(review): `vit` was built with max_iters=NUM_EPOCHS*STEPS_PER_EPOCH (30 epochs)
# while this run is capped at 20 epochs -- confirm the LR schedule length is intended.
trainer_kwargs = dict(
    max_epochs=20,
    gpus=1,
    logger=TensorBoardLogger('/home/dima/lightning_logs/', name='vit_imagenet'),
#     default_root_dir='/home/dima/vitransformer/checkpoints/',
    callbacks=[lr_monitor, model_checkpoint, gpu_monitor],
    gradient_clip_val=1,
)
trainer = pl.Trainer(**trainer_kwargs)

trainer.fit(vit, dm)