In [1]:
import os

import torch
import torch.nn as nn
import torch.nn.functional as F

from torchvision import transforms
from torchvision.datasets import CIFAR10
from torchmetrics.functional.classification.accuracy import accuracy

from torch.utils.data import DataLoader

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

import augmentations

In [2]:
# Training hyper-parameters used by the model, data loaders and trainer below.
MAX_EPOCHS = 100       # trainer epoch budget; also positions the LR-decay milestones
LEARNING_RATE = 1e-3   # learning rate handed to the Adam optimizer
BATCH_SIZE = 1024      # batch size for both the training and the test loader

# Filesystem locations. The defaults are the original machine-specific paths;
# set CIFAR10_DATA_ROOT / CIFAR10_SAVE_PATH in the environment to run the
# notebook elsewhere without editing this cell.
DATA_ROOT = os.environ.get('CIFAR10_DATA_ROOT', '/home/woosung/pytorch/data')
SAVE_PATH = os.environ.get('CIFAR10_SAVE_PATH', '/home/woosung/pytorch/model')

In [3]:
def conv3x3(in_planes, out_planes, stride=1):
    """Build a bias-free 3x3 ``nn.Conv2d`` with a padding of 1.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: convolution stride (defaults to 1).

    Returns:
        The newly constructed ``nn.Conv2d`` module.
    """
    conv_kwargs = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)

In [4]:
class ResNetBasicBlock(nn.Module):
    """Basic ResNet block: two 3x3 conv + batch-norm stages and a shortcut.

    The main branch is ``conv1 -> bn1 -> relu1 -> conv2 -> bn2``; its result is
    added to the shortcut branch and passed through ``relu2``.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        in_stride: stride of the first convolution (and of the shortcut
            projection, when one is needed).
    """

    def __init__(self, in_planes, out_planes, in_stride=1):
        super(ResNetBasicBlock, self).__init__()
        self.conv1 = conv3x3(in_planes, out_planes, in_stride)
        self.bn1 = nn.BatchNorm2d(out_planes)
        self.relu1 = nn.ReLU()

        self.conv2 = conv3x3(out_planes, out_planes)
        self.bn2 = nn.BatchNorm2d(out_planes)
        self.relu2 = nn.ReLU()

        # The main branch alters the tensor shape when it changes the channel
        # count OR when it downsamples. In either case the shortcut must be
        # projected the same way, otherwise `x + residual` in forward() fails
        # on mismatched shapes. Only when the shape is untouched can the input
        # be passed through unchanged.
        if in_stride != 1 or in_planes != out_planes:
            self.res = conv3x3(in_planes, out_planes, in_stride)
        else:
            self.res = nn.Identity()

    def forward(self, x):
        """Apply the block to ``x`` and return the activated sum of both branches."""
        # Compute the shortcut first, before `x` is rebound to the main branch.
        residual = self.res(x)

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)

        x = self.conv2(x)
        x = self.bn2(x)

        x = x + residual
        x = self.relu2(x)
        return x

In [5]:
class ResNet18(pl.LightningModule):
    """ResNet with depth 18 for 3-channel images, packaged as a LightningModule.

    Layout: a 3x3 conv stem (16 channels), four stages of two
    ``ResNetBasicBlock`` modules each (16, 32, 64 and 128 channels; stages 2-4
    use stride 2), global average pooling and a linear classifier.

    Args:
        num_classes: number of output logits.
        learning_rate: learning rate for Adam. ``None`` (the default) falls
            back to the module-level ``LEARNING_RATE`` when the optimizer is
            built.
        max_epochs: planned number of training epochs, used only to place the
            learning-rate decay milestones at 50% and 75% of training.
            ``None`` (the default) falls back to the module-level
            ``MAX_EPOCHS``.
    """

    def __init__(self, num_classes=10, learning_rate=None, max_epochs=None):
        super(ResNet18, self).__init__()
        # Optimizer settings; resolved lazily in configure_optimizers() so the
        # default behavior still follows the notebook-level constants.
        self.learning_rate = learning_rate
        self.max_epochs = max_epochs

        # The preprocessing block before entering the first basic block.
        self.conv = conv3x3(3, 16)
        self.bn = nn.BatchNorm2d(16)
        self.relu = nn.ReLU()

        # ResNet building blocks
        self.block1 = self._make_block(16, 16, 1)
        self.block2 = self._make_block(16, 32, 2)
        self.block3 = self._make_block(32, 64, 2)
        self.block4 = self._make_block(64, 128, 2)

        # Postprocessing part.
        # Global average pooling: averages over whatever spatial extent is
        # left, so the classifier always receives 128 features. For 32x32
        # inputs this covers the same 4x4 window as AvgPool2d(kernel_size=4).
        self.ap = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(128, num_classes)

    def _make_block(self, in_planes, out_planes, in_stride):
        """Return one stage: a (possibly strided) basic block followed by a plain one."""
        return nn.Sequential(
            ResNetBasicBlock(in_planes, out_planes, in_stride),
            ResNetBasicBlock(out_planes, out_planes))

    def forward(self, x):
        """Map an image batch to class logits.

        The shape comments assume 32x32 inputs (X = 16 channels).
        """
        # x: T[B, 3, 32, 32]
        x = self.conv(x)  # T[B, X, 32, 32]
        x = self.bn(x)
        x = self.relu(x)

        x = self.block1(x)  # T[B, X, 32, 32]
        x = self.block2(x)  # T[B, 2*X, 16, 16]
        x = self.block3(x)  # T[B, 4*X, 8, 8]
        x = self.block4(x)  # T[B, 8*X, 4, 4]

        x = self.ap(x)  # T[B, 8*X, 1, 1]
        x = x.flatten(start_dim=1)  # T[B, 8*X]
        x = self.fc(x)  # T[B, num_classes]
        return x

    def training_step(self, batch, batch_idx):
        """Compute, log and return the cross-entropy loss for one training batch."""
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        # Log the training loss so it can be compared with val_loss.
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log ``val_acc`` / ``val_loss`` for one validation batch."""
        loss, acc = self._shared_eval_state(batch, batch_idx)
        metrics = {"val_acc": acc, "val_loss": loss}
        self.log_dict(metrics, prog_bar=True)
        return metrics

    def test_step(self, batch, batch_idx):
        """Log ``test_acc`` / ``test_loss`` for one test batch."""
        loss, acc = self._shared_eval_state(batch, batch_idx)
        metrics = {"test_acc": acc, "test_loss": loss}
        self.log_dict(metrics)
        return metrics

    def _shared_eval_state(self, batch, batch_idx):
        """Return ``(loss, accuracy)`` for one batch; shared by validation and test."""
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        # NOTE(review): newer torchmetrics releases reportedly require a
        # `task=` argument for `accuracy` -- confirm against the installed version.
        acc = accuracy(logits, y)
        return loss, acc

    def configure_optimizers(self):
        """Build Adam plus a MultiStepLR schedule.

        The learning rate is multiplied by 0.1 at 50% and again at 75% of the
        planned number of epochs.
        """
        lr = LEARNING_RATE if self.learning_rate is None else self.learning_rate
        total_epochs = MAX_EPOCHS if self.max_epochs is None else self.max_epochs

        optimizer = torch.optim.Adam(
            self.parameters(),
            lr=lr)
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=[int(0.5 * total_epochs), int(0.75 * total_epochs)],
            gamma=0.1)
        return [optimizer], [scheduler]

In [6]:
# CIFAR-10 class names; presumably listed in the dataset's label-index order
# (confirm against the dataset metadata). Not referenced by the cells shown here.
labels = ("airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck")

# Per-channel mean and standard deviation (three values each) that are passed to
# transforms.Normalize in the train and test pipelines.
cifar10_mean, cifar10_std = [0.4913, 0.4821, 0.4465], [0.2470, 0.2434, 0.2615]

In [7]:
# Final two steps shared by both pipelines: convert to a tensor, then normalize
# each channel with the CIFAR-10 statistics defined above.
_to_normalized_tensor = [
    transforms.ToTensor(),
    transforms.Normalize(mean=cifar10_mean, std=cifar10_std),
]

# Training pipeline: padded random crop, random horizontal flip and RandAugment
# (arguments presumably are number of ops and magnitude -- confirm against the
# local `augmentations` module), followed by the shared steps.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    augmentations.RandAugment(2, 27),
    *_to_normalized_tensor,
])
# Test pipeline: no augmentation, only the shared steps.
transform_test = transforms.Compose(list(_to_normalized_tensor))

# Both splits live under DATA_ROOT and are downloaded on demand.
_cifar_kwargs = dict(root=DATA_ROOT, download=True)
dataset_train = CIFAR10(train=True, transform=transform_train, **_cifar_kwargs)
dataset_test = CIFAR10(train=False, transform=transform_test, **_cifar_kwargs)

# Same batch size and worker count for both loaders; only training is shuffled.
_loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=4)
dataloader_train = DataLoader(dataset_train, shuffle=True, **_loader_kwargs)
dataloader_test = DataLoader(dataset_test, shuffle=False, **_loader_kwargs)

Files already downloaded and verified
Files already downloaded and verified


In [8]:
# Checkpoint callback keyed on the `val_loss` metric that
# ResNet18.validation_step logs; files are written under SAVE_PATH.
checkpoint_callback = ModelCheckpoint(
    dirpath=SAVE_PATH,
    filename='resnet18-{epoch:02d}-{val_loss:.3f}',
    monitor='val_loss',
)

# Single-GPU trainer running for MAX_EPOCHS epochs with the checkpoint callback.
# NOTE(review): `gpus=` is deprecated in newer pytorch_lightning releases --
# confirm against the installed version before upgrading.
trainer = pl.Trainer(
    callbacks=[checkpoint_callback],
    max_epochs=MAX_EPOCHS,
    gpus=1,
)

  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [9]:
# Instantiate the network with its default arguments and move it to the GPU.
# NOTE(review): the Trainer is configured with gpus=1 and presumably handles
# device placement itself, so the explicit .cuda() may be redundant -- confirm.
model = ResNet18().cuda()

In [10]:
# Train on the training loader; the second loader is used for validation.
# NOTE(review): the validation loader here is the CIFAR-10 *test* split, and the
# checkpoint callback monitors `val_loss`, so checkpoint selection sees test
# data. The test metrics reported afterwards are therefore not a clean held-out
# estimate; consider carving a validation split out of the training set.
trainer.fit(model, dataloader_train, dataloader_test)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name   | Type        | Params
---------------------------------------
0 | conv   | Conv2d      | 432   
1 | bn     | BatchNorm2d | 32    
2 | relu   | ReLU        | 0     
3 | block1 | Sequential  | 9.3 K 
4 | block2 | Sequential  | 37.1 K
5 | block3 | Sequential  | 147 K 
6 | block4 | Sequential  | 590 K 
7 | ap     | AvgPool2d   | 0     
8 | fc     | Linear      | 1.3 K 
---------------------------------------
787 K     Trainable params
0         Non-trainable params
787 K     Total params
3.148     Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

  rank_zero_warn(


Training: -1it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

In [11]:
# Evaluate the in-memory `model` on the test loader; reports test_acc / test_loss.
# NOTE(review): this presumably uses the weights as left by trainer.fit, not the
# checkpoint written by checkpoint_callback -- confirm if the best checkpoint
# was intended.
trainer.test(model, dataloader_test)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: 0it [00:00, ?it/s]

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_acc': 0.8837000131607056, 'test_loss': 0.34246334433555603}
--------------------------------------------------------------------------------


[{'test_acc': 0.8837000131607056, 'test_loss': 0.34246334433555603}]