In [1]:
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from PIL import Image
from omegaconf import DictConfig
from torch.utils.data import DataLoader

import torch
import torchmetrics
import torch.nn.functional as F
import pytorch_lightning as pl
import numpy as np

from dataset.cub import CUB200
from model.vit import VisionTransformer

In [2]:
config = DictConfig({
    "patch_size": 16,
    "split": "non-overlap",
    "slide_step": 12,
    "hidden_size": 768,
    "dropout": 0.1,
    "max_len": 100,
    "classifier": "token",
    "transformer": {
        "mlp_dim": 3072,
        "num_heads": 12,
        "num_layers": 12,
        "attention_dropout_rate": 0.0,
    },
    "num_classes": 200,
    "batch_size": 16,
    "num_workers": 8,
    "image_size": 448,
    "lr": 3e-2,
    "seed": 42,
    "momentum": 0.9,
    "epoch": 10,
    "gpus": [0],
    "logger": False,
    "pretrained_dir": "./pretrained/vit/imagenet21k_ViT-B_16.npz",
})

In [3]:
class LitViT(pl.LightningModule):
    def __init__(self, config):
        super().__init__()
        self.model = VisionTransformer(config)
        self.model.load_from(np.load(config.pretrained_dir))
        self.config = config

        self.init_dataset()

        self.train_accuracy = torchmetrics.Accuracy()
        self.val_accuracy = torchmetrics.Accuracy()
        self.test_accuracy = torchmetrics.Accuracy()

    def training_step(self, batch, batch_idx):
        inputs, targets = batch

        outputs = self.model(inputs)
        loss = F.cross_entropy(outputs.view(-1, config.num_classes), targets.view(-1))
        train_acc = self.train_accuracy(torch.argmax(outputs, dim=-1), targets)

        self.log("train_loss", loss, on_step=False, on_epoch=True, sync_dist=True)
        self.log("train_acc", train_acc, on_step=False, on_epoch=True,
                sync_dist=True)

        return loss

    def training_epoch_end(self, outs):
        self.log("train_acc_epoch", self.train_accuracy.compute(),
                prog_bar=True, logger=True, sync_dist=True)

    def validation_step(self, batch, batch_idx):
        inputs, targets = batch
        
        outputs = self.model(inputs)

        loss = F.cross_entropy(outputs.view(-1, config.num_classes), targets.view(-1))
        val_acc = self.val_accuracy(torch.argmax(outputs, dim=-1), targets)

        self.log("val_acc", val_acc, on_step=False, on_epoch=True, sync_dist=True)
        self.log("val_loss", loss, on_step=False, on_epoch=True, sync_dist=True)

        return loss


    def validation_epoch_end(self, outs):
        self.log("val_acc_epoch", self.val_accuracy.compute(),
                prog_bar=True, logger=True, sync_dist=True)
    
    def test_step(self, batch, batch_idx):
        inputs, targets = batch
        outputs = self.model(inputs)

        loss = F.cross_entropy(outputs.view(-1, config.num_classes), targets.view(-1))
        test_acc = self.test_accuracy(torch.argmax(outputs, dim=-1), targets)

        self.log("test_acc", test_acc, on_step=False, on_epoch=True, logger=True,
                sync_dist=True)
        self.log("test_loss", loss, on_step=False, on_epoch=True, logger=True,
                sync_dist=True)

        return loss

    def test_epoch_end(self, outs):
        test_acc = self.test_accuracy.compute()
        self.log("test_acc_epoch", test_acc, logger=True, sync_dist=True)

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.model.parameters(), lr=self.config.lr, momentum=self.config.momentum)
        return optimizer

    def train_dataloader(self):
        return DataLoader(self.train_set, batch_size=self.config.batch_size,
                        shuffle=True, pin_memory=True, num_workers=self.config.num_workers)

    def val_dataloader(self):
        return DataLoader(self.test_set, batch_size=self.config.batch_size,
                        pin_memory=True, num_workers=self.config.num_workers)

    def test_dataloader(self):
        return DataLoader(self.test_set, batch_size=self.config.batch_size,
                        pin_memory=True, num_workers=self.config.num_workers)

    def init_dataset(self):
        train_transform=transforms.Compose([
            transforms.Resize((600, 600), InterpolationMode.BILINEAR),
            transforms.RandomCrop((448, 448)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
        test_transform=transforms.Compose([
            transforms.Resize((600, 600), InterpolationMode.BILINEAR),
            transforms.CenterCrop((448, 448)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
        
        self.train_set = CUB200(root="./data", train=True, transform=train_transform)
        self.test_set = CUB200(root="./data", train=False, transform=test_transform)
    

In [4]:
if config.logger:
    from pytorch_lightning.loggers import WandbLogger
    logger = WandbLogger(
        project="xfg",
        name=f"vit"
    )
else:
    logger = pl.loggers.TestTubeLogger(
        "output", name=f"vit")
    logger.log_hyperparams(config)

pl.seed_everything(config.seed)
trainer = pl.Trainer(
    precision=16,
    deterministic=True,
    check_val_every_n_epoch=1,
    gpus=config.gpus,
    logger=logger,
    max_epochs=config.epoch,
    weights_summary="top",
    # accelerator='ddp',
)

model = LitViT(config)
trainer.fit(model)
trainer.test()

Global seed set to 42
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
Using native 16bit precision.
load_pretrained: grid-size from 14 to 28
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name           | Type              | Params
-----------------------------------------------------
0 | model          | VisionTransformer | 86.4 M
1 | train_accuracy | Accuracy          | 0     
2 | val_accuracy   | Accuracy          | 0     
3 | test_accuracy  | Accuracy          | 0     
-----------------------------------------------------
86.4 M    Trainable params
0         Non-trainable params
86.4 M    Total params
345.616   Total estimated model params size (MB)
  value = torch.tensor(value, device=device, dtype=torch.float)
                                                                      Global seed set to 42
Epoch 0:   0%|          | 0/738 [00:00<?, ?it/s] 

RuntimeError: CUDA out of memory. Tried to allocate 452.00 MiB (GPU 0; 10.76 GiB total capacity; 9.09 GiB already allocated; 439.94 MiB free; 9.11 GiB reserved in total by PyTorch)