In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision
from torchvision import transforms, datasets

import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, GPUStatsMonitor, ModelCheckpoint, LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger

from pl_bolts.datamodules.cifar10_datamodule import CIFAR10DataModule
from pl_bolts.datamodules.imagenet_datamodule import ImagenetDataModule
from pl_bolts.transforms.dataset_normalizations import cifar10_normalization, imagenet_normalization
from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2

In [7]:
class Transform:
    """Adapter that lets an albumentations pipeline be called on a single image.

    The wrapped pipeline is invoked with the image passed under the ``image``
    keyword, after converting it to a NumPy array.
    """

    def __init__(self, transform: A.Compose):
        # The composed pipeline that every call is forwarded to.
        self.transform = transform

    def __call__(self, img, *args, **kwargs):
        """Run the wrapped pipeline on ``img`` and return its result unchanged.

        Extra positional and keyword arguments are forwarded as-is.
        NOTE(review): callers in this notebook index the result with
        ``['image']``, so the pipeline presumably returns a dict — confirm.
        """
        image_array = np.array(img)
        pipeline = self.transform
        return pipeline(*args, image=image_array, **kwargs)


# Seed first so that everything constructed below starts from the same RNG state.
pl.seed_everything(42)

# Run-level configuration used by this cell and by the training cell.
BATCH_SIZE = 128
NUM_EPOCHS = 30
data_folder = "/home/dima/datasets/imagenet/"


def _resize_and_normalize():
    """Return fresh Resize + Normalize ops; both pipelines start with these."""
    return [
        A.Resize(224, 224),
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ]


dm = ImagenetDataModule(data_folder, batch_size=BATCH_SIZE, shuffle=True)

# Training pipeline: shared preprocessing, then random flips / rotations /
# shift-scale-rotate, then conversion to a tensor.
dm.train_transforms = Transform(
    A.Compose(
        [
            *_resize_and_normalize(),
            A.VerticalFlip(p=0.5),
            A.HorizontalFlip(p=0.5),
            A.RandomRotate90(p=0.5),
            A.ShiftScaleRotate(p=0.5),
            ToTensorV2(),
        ]
    )
)

# Validation pipeline: shared preprocessing and tensor conversion only.
dm.val_transforms = Transform(
    A.Compose(
        [
            *_resize_and_normalize(),
            ToTensorV2(),
        ]
    )
)

dm.setup()

# Length of the training dataloader; the training cell uses it to express the
# warm-up and total schedule length in optimizer steps.
STEPS_PER_EPOCH = len(dm.train_dataloader())

Global seed set to 42


In [16]:
class PatchEmbedding(nn.Module):
    """2D image to patch embedding, with a class token and position embedding.

    The image is cut into non-overlapping patches by a strided convolution
    (kernel size == stride == patch size); each patch becomes one token of
    width ``emb_dim``. A learnable class token is prepended and a learnable
    position embedding is added to every token before dropout.

    Args:
        img_size: Expected input size. Either an int (square image) or an
            ``(H, W)`` pair.
        patch_size: Patch size. Either an int (square patch) or an
            ``(h, w)`` pair.
        emb_dim: Width of each output token.
        num_channels: Number of channels of the input image.
        norm_layer: Optional callable ``norm_layer(emb_dim)`` building a
            normalisation module applied to the patch tokens; ``None`` means
            no normalisation.
        dropout: Dropout probability applied after adding the position
            embedding.
    """

    def __init__(self, img_size=224, patch_size=16, emb_dim=768, num_channels=3, norm_layer=None, dropout=0.):
        super().__init__()

        # Accept both ints and (H, W) pairs so non-square inputs are supported.
        self.img_size = img_size = self._pair(img_size)
        self.patch_size = patch_size = self._pair(patch_size)
        self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]

        self.proj = nn.Conv2d(num_channels, emb_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = norm_layer(emb_dim) if norm_layer else nn.Identity()

        # Both learnable tensors start at zero.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, emb_dim))
        self.pos_embedding = nn.Parameter(torch.zeros(1, 1 + self.num_patches, emb_dim))  # + cls_token
        self.dropout = nn.Dropout(p=dropout)

    @staticmethod
    def _pair(value):
        """Return ``value`` as a 2-tuple: pairs pass through, scalars are duplicated."""
        if isinstance(value, (tuple, list)):
            if len(value) != 2:
                raise ValueError(f"Expected an int or a pair of ints, got {value!r}.")
            return tuple(value)
        return (value, value)

    def forward(self, x):
        """Embed a batch of images.

        Args:
            x: Tensor of shape ``(B, C, H, W)`` where ``(H, W)`` equals
                ``self.img_size``.

        Returns:
            Tensor of shape ``(B, 1 + num_patches, emb_dim)``; index 0 along
            the token axis is the class token.

        Raises:
            AssertionError: If the spatial size does not match ``img_size``.
        """
        _, _, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x)
        # (B, C, h, w) -> (B, h*w, C): row-major patch order, channels last.
        x = x.flatten(2).transpose(1, 2)
        x = self.norm(x)
        # One copy of the class token per sample (expand creates a view, not a copy).
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_token, x), dim=1)
        x = self.dropout(x + self.pos_embedding)
        return x

In [17]:
class ViT(pl.LightningModule):
    """Vision Transformer image classifier as a LightningModule.

    Pipeline: ``PatchEmbedding`` -> ``nn.TransformerEncoder`` -> ``LayerNorm``
    -> linear classifier applied to the first (class) token.

    Args:
        img_size: Input image size forwarded to ``PatchEmbedding``.
        patch_size: Patch size forwarded to ``PatchEmbedding``.
        emb_dim: Token width used throughout the model.
        depth: Number of transformer encoder layers.
        num_heads: Number of attention heads per encoder layer.
        num_channels: Number of input image channels.
        num_classes: Size of the classifier output.
        dropout: Dropout probability inside the encoder layers.
        emb_dropout: Dropout probability applied by the patch embedding.
        lr: Base learning rate of the optimizer.
        weight_decay: Weight-decay coefficient passed to the optimizer.
        warmup: Warm-up length passed to the LR scheduler. The scheduler is
            stepped once per optimizer step, so this counts steps.
        max_iters: Total schedule length passed to the LR scheduler, in the
            same unit as ``warmup``.
    """

    def __init__(self, img_size=224, patch_size=16, emb_dim=512, depth=6, num_heads=8,
                 num_channels=3, num_classes=10, dropout=0.1, emb_dropout=0.1,
                 lr=1e-3, weight_decay=0, warmup=0, max_iters=2000):
        super().__init__()

        # Makes every constructor argument available as self.hparams.<name>.
        self.save_hyperparameters()

        self.embedding = PatchEmbedding(img_size, patch_size, emb_dim, num_channels, dropout=emb_dropout)
        encoder_layer = nn.TransformerEncoderLayer(emb_dim, num_heads, dropout=dropout,
                                                   activation='gelu', batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, depth)
        self.norm = nn.LayerNorm(emb_dim, eps=1e-6)
        self.classifier = nn.Linear(emb_dim, num_classes)

        # One metric instance per stage so their running state never mixes.
        # NOTE(review): `pl.metrics` may be unavailable in newer Lightning
        # releases (superseded by `torchmetrics`) — confirm against the
        # installed version.
        self.train_acc = pl.metrics.Accuracy()
        self.val_acc = pl.metrics.Accuracy()
        self.test_acc = pl.metrics.Accuracy()

    def forward(self, x):
        """Return class logits for a batch of images."""
        x = self.embedding(x)
        x = self.transformer(x)
        x = self.norm(x)
        # Token 0 is the class token prepended by the patch embedding.
        x = self.classifier(x[:, 0])
        return x

    def _shared_step(self, batch):
        """Unpack a batch, run the model and return ``(logits, targets, loss)``."""
        x, y = batch
        # The data pipeline in this notebook yields {'image': tensor}; a bare
        # tensor is accepted too so the module works with plain datasets.
        images = x['image'] if isinstance(x, dict) else x
        logits = self(images)
        return logits, y, F.cross_entropy(logits, y)

    def training_step(self, batch, batch_idx):
        """Compute the training loss; log loss (step + epoch) and step accuracy."""
        y_pred, y, loss = self._shared_step(batch)
        self.log("train_loss", loss, on_step=True, on_epoch=True)
        self.log('train_acc', self.train_acc(y_pred, y), on_step=True, on_epoch=False, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute the validation loss; log loss and accuracy once per epoch."""
        y_pred, y, loss = self._shared_step(batch)
        self.log("val_loss", loss, on_step=False, on_epoch=True)
        # Previously logged with on_step=False and on_epoch=False, which never
        # records anything. Logging the metric object lets Lightning compute
        # the accuracy over the whole epoch and reset it afterwards.
        self.val_acc(y_pred, y)
        self.log('val_acc', self.val_acc, on_step=False, on_epoch=True, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        """Compute the test loss; log loss and accuracy once per epoch."""
        y_pred, y, loss = self._shared_step(batch)
        self.log("test_loss", loss, on_step=False, on_epoch=True)
        self.test_acc(y_pred, y)
        self.log('test_acc', self.test_acc, on_step=False, on_epoch=True)
        return loss

    def configure_optimizers(self):
        """Build the optimizer and a per-step warm-up + cosine LR schedule."""
        # AdamW applies weight decay directly to the weights instead of adding
        # an L2 term to the gradient (which Adam would then rescale). With
        # weight_decay == 0 it is identical to Adam.
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.hparams.lr,
            weight_decay=self.hparams.weight_decay,
        )
        # NOTE(review): confirm how the scheduler behaves when warmup == 0
        # (the constructor default).
        scheduler = LinearWarmupCosineAnnealingLR(optimizer, self.hparams.warmup, self.hparams.max_iters)
        # 'interval': 'step' makes Lightning step the scheduler every optimizer step.
        scheduler_dict = {'scheduler': scheduler, 'interval': 'step'}
        return {'optimizer': optimizer, 'lr_scheduler': scheduler_dict}

In [None]:
# Model: built first so weight initialisation happens right after seeding.
# Warm-up and total schedule length are expressed in optimizer steps.
model = ViT(
    img_size=224,
    patch_size=16,
    emb_dim=768,
    depth=8,
    num_heads=12,
    num_classes=1000,
    lr=1e-3,
    weight_decay=0.1,
    warmup=2 * STEPS_PER_EPOCH,
    max_iters=NUM_EPOCHS * STEPS_PER_EPOCH,
)

# Callbacks: log the learning rate every step, checkpoint on the lowest
# validation loss, and record GPU statistics.
lr_monitor = LearningRateMonitor(logging_interval='step')
model_checkpoint = ModelCheckpoint(
    dirpath='/home/dima/ViTransformer/checkpoints/',
    filename='imagenet-{epoch}-{val_loss:.3f}',
    monitor='val_loss',
    mode='min',
)
gpu_monitor = GPUStatsMonitor()

trainer = pl.Trainer(
    gpus=1,
    max_epochs=NUM_EPOCHS,
    gradient_clip_val=1,
    callbacks=[lr_monitor, model_checkpoint, gpu_monitor],
    logger=TensorBoardLogger('/home/dima/lightning_logs/', name='vit_imagenet'),
)

trainer.fit(model, dm)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name        | Type               | Params
---------------------------------------------------
0 | embedding   | PatchEmbedding     | 742 K 
1 | transformer | TransformerEncoder | 44.1 M
2 | norm        | LayerNorm          | 1.5 K 
3 | classifier  | Linear             | 769 K 
4 | train_acc   | Accuracy           | 0     
5 | val_acc     | Accuracy           | 0     
6 | test_acc    | Accuracy           | 0     
---------------------------------------------------
45.6 M    Trainable params
0         Non-trainable params
45.6 M    Total params
182.500   Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Global seed set to 42


Training: -1it [00:00, ?it/s]