In [1]:
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from PIL import Image
from omegaconf import DictConfig
from torch.utils.data import DataLoader

import torch
import torchmetrics
import torch.nn.functional as F
import pytorch_lightning as pl
import numpy as np

from dataset.cub import CUB200
from model.resnet import ResNet

In [2]:
# Experiment configuration consumed by LitResNet and the Trainer cell below.
config = DictConfig(dict(
    # Model options. NOTE(review): `dropout` is not read by any cell shown in
    # this notebook — confirm whether it is still needed.
    dropout=0.1,
    num_classes=200,
    # Data loading.
    batch_size=16,
    num_workers=8,
    # NOTE(review): `image_size` is not read by the visible code; the crop
    # size inside LitResNet.init_dataset is written out literally.
    image_size=448,
    # Optimisation and reproducibility.
    lr=1e-3,
    seed=42,
    momentum=0.9,
    epoch=1000,
    # Trainer / logging. `logger` selects the W&B branch when truthy.
    gpus=[0],
    logger=False,
    # NOTE(review): path names a ViT checkpoint and is not read by the
    # visible cells — presumably left over from another experiment.
    pretrained_dir="./pretrained/vit/imagenet21k_ViT-B_32.npz",
))

In [3]:
class LitResNet(pl.LightningModule):
    """LightningModule wrapping a ``ResNet`` classifier trained on ``CUB200``.

    The module owns its datasets (built in :meth:`init_dataset`), its
    dataloaders, an SGD optimizer and one ``torchmetrics.Accuracy`` instance
    per stage (train / val / test).

    Args:
        config: Attribute-style configuration. The fields read by this class
            are ``num_classes``, ``lr``, ``momentum``, ``batch_size`` and
            ``num_workers``.
    """

    def __init__(self, config):
        super().__init__()
        # Only the two sub-modules of the project ResNet are kept on the
        # LightningModule; the wrapper object itself is discarded.
        model = ResNet(classes=config.num_classes, pretrained=True)
        self.feature_extractor = model.feature_extractor
        self.classifier = model.classifier
        self.config = config

        self.init_dataset()

        # Separate metric objects per stage so their accumulated state
        # never mixes.
        self.train_accuracy = torchmetrics.Accuracy()
        self.val_accuracy = torchmetrics.Accuracy()
        self.test_accuracy = torchmetrics.Accuracy()

    def forward(self, inputs):
        """Return the classifier output for a batch of ``inputs``."""
        return self.classifier(self.feature_extractor(inputs))

    def _shared_step(self, batch, metric):
        """Run the forward pass for one batch and score it.

        Args:
            batch: ``(inputs, targets)`` pair produced by a dataloader.
            metric: The stage-specific accuracy metric to update.

        Returns:
            ``(loss, acc)`` where ``loss`` is the cross-entropy of the
            outputs against ``targets`` and ``acc`` is the value returned by
            calling ``metric`` on the arg-max predictions of this batch.
        """
        inputs, targets = batch
        outputs = self(inputs)
        loss = F.cross_entropy(outputs, targets)
        acc = metric(torch.argmax(outputs, dim=-1), targets)
        return loss, acc

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss/accuracy; return the loss."""
        loss, train_acc = self._shared_step(batch, self.train_accuracy)

        self.log("train_loss", loss, on_step=False, on_epoch=True, sync_dist=True)
        self.log("train_acc", train_acc, on_step=False, on_epoch=True,
                sync_dist=True)

        return loss

    def training_epoch_end(self, outs):
        """Log the accuracy accumulated over the epoch, then clear it."""
        self.log("train_acc_epoch", self.train_accuracy.compute(),
                prog_bar=True, logger=True, sync_dist=True)
        # The metric object is never handed to ``self.log`` (only tensors
        # are), so it must be reset by hand; otherwise every epoch would
        # report a running value over all epochs so far.
        self.train_accuracy.reset()

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss/accuracy; return the loss."""
        loss, val_acc = self._shared_step(batch, self.val_accuracy)

        self.log("val_acc", val_acc, on_step=False, on_epoch=True, sync_dist=True)
        self.log("val_loss", loss, on_step=False, on_epoch=True, sync_dist=True)

        return loss

    def validation_epoch_end(self, outs):
        """Log the accuracy accumulated over the validation run, then clear it."""
        self.log("val_acc_epoch", self.val_accuracy.compute(),
                prog_bar=True, logger=True, sync_dist=True)
        # Manual reset: keeps one validation run (including the sanity-check
        # batches run before training) from leaking into the next.
        self.val_accuracy.reset()

    def test_step(self, batch, batch_idx):
        """Compute and log the test loss/accuracy; return the loss."""
        loss, test_acc = self._shared_step(batch, self.test_accuracy)

        self.log("test_acc", test_acc, on_step=False, on_epoch=True, logger=True,
                sync_dist=True)
        self.log("test_loss", loss, on_step=False, on_epoch=True, logger=True,
                sync_dist=True)

        return loss

    def test_epoch_end(self, outs):
        """Log the accuracy accumulated over the test run, then clear it."""
        test_acc = self.test_accuracy.compute()
        self.log("test_acc_epoch", test_acc, logger=True, sync_dist=True)
        # Manual reset so a repeated ``trainer.test`` call starts from zero.
        self.test_accuracy.reset()

    def configure_optimizers(self):
        """Return an SGD optimizer over both sub-modules.

        The two parameter groups share ``config.lr`` and ``config.momentum``;
        they are kept separate so per-group settings can be added later.
        """
        model_params = [
            {"params": self.feature_extractor.parameters()},
            {"params": self.classifier.parameters()}
        ]
        optimizer = torch.optim.SGD(model_params, lr=self.config.lr, momentum=self.config.momentum)
        return optimizer

    def train_dataloader(self):
        """Shuffled loader over the training set."""
        return DataLoader(self.train_set, batch_size=self.config.batch_size,
                        shuffle=True, pin_memory=True, num_workers=self.config.num_workers)

    def val_dataloader(self):
        """Loader used for validation.

        NOTE(review): this iterates ``self.test_set``, i.e. the same data as
        :meth:`test_dataloader`; no separate validation split is built here.
        """
        return DataLoader(self.test_set, batch_size=self.config.batch_size,
                        pin_memory=True, num_workers=self.config.num_workers)

    def test_dataloader(self):
        """Unshuffled loader over the test set."""
        return DataLoader(self.test_set, batch_size=self.config.batch_size,
                        pin_memory=True, num_workers=self.config.num_workers)

    def init_dataset(self):
        """Build ``self.train_set`` and ``self.test_set`` from ``./data``.

        Both pipelines resize to 600x600 (bilinear) and crop to 448x448.
        Training uses a random crop plus a random horizontal flip, evaluation
        a centre crop. The crop size is written out here and does not
        consult ``config.image_size``.
        """
        # Presumably the ImageNet channel statistics matching the pretrained
        # backbone — TODO confirm.
        mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]

        train_transform=transforms.Compose([
            transforms.Resize((600, 600), InterpolationMode.BILINEAR),
            transforms.RandomCrop((448, 448)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])
        test_transform=transforms.Compose([
            transforms.Resize((600, 600), InterpolationMode.BILINEAR),
            transforms.CenterCrop((448, 448)),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])

        self.train_set = CUB200(root="./data", train=True, transform=train_transform)
        self.test_set = CUB200(root="./data", train=False, transform=test_transform)
    

In [4]:
# Pick the experiment logger from the config flag: a local TestTube logger
# writing under "output" when ``config.logger`` is falsy, Weights & Biases
# otherwise.
# NOTE(review): both branches name the run "vit" although this notebook
# trains LitResNet — confirm the run name is intended.
if not config.logger:
    logger = pl.loggers.TestTubeLogger("output", name="vit")
    # Only the local logger gets the hyperparameters recorded explicitly.
    logger.log_hyperparams(config)
else:
    # Imported only on the branch that uses it.
    from pytorch_lightning.loggers import WandbLogger

    logger = WandbLogger(name="vit", project="xfg")

# Seed first, so the trainer and model below are built under a fixed seed.
pl.seed_everything(config.seed)

trainer = pl.Trainer(
    # Hardware and numerics.
    gpus=config.gpus,
    precision=16,
    deterministic=True,
    # Schedule: validate after every epoch, up to `config.epoch` epochs.
    max_epochs=config.epoch,
    check_val_every_n_epoch=1,
    # Reporting.
    logger=logger,
    weights_summary="top",
    # accelerator='ddp',
)

# The module supplies its own dataloaders, so only the model is passed to
# `fit`; `test` is then called without arguments on the same trainer.
model = LitResNet(config)
trainer.fit(model)
trainer.test()

Global seed set to 42
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
Using native 16bit precision.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name              | Type       | Params
-------------------------------------------------
0 | feature_extractor | Sequential | 23.5 M
1 | classifier        | Sequential | 409 K 
2 | train_accuracy    | Accuracy   | 0     
3 | val_accuracy      | Accuracy   | 0     
4 | test_accuracy     | Accuracy   | 0     
-------------------------------------------------
23.9 M    Trainable params
0         Non-trainable params
23.9 M    Total params
95.671    Total estimated model params size (MB)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

  value = torch.tensor(value, device=device, dtype=torch.float)
Global seed set to 42


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…



KeyboardInterrupt: 