In [1]:
import torch
import torchvision

import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms

import numpy as np
import pytorch_lightning as pl

from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.utils.data import DataLoader, random_split
from sklearn.model_selection import train_test_split

In [2]:
pl.seed_everything(42)

Global seed set to 42


42

# Download Dataset 

In [3]:
CLASSES = ("plane", "car", "bird", "cat", "deer",
           "dog", "frog", "horse", "ship", "truck")
IMAGE_CHANNEL_NUM = 3
IMAGE_SIZE = 32
CLASS_NUM = 10
BATCH_NUM = 256

In [4]:
class Flatten(nn.Module):
    def forward(self, x):
        batch_size = x.shape[0]
        return x.view(batch_size, -1)

class Conv(nn.Module):
    def __init__(self, in_channels, out_channels, stride):
        super(Conv, self).__init__()
        # Depthwise conv
        self.conv1 = nn.Conv2d(
            in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=stride, padding=1,
        )
        self.bn = nn.BatchNorm2d(out_channels)
    
    def forward(self, x):
        x = self.conv1(x)
        x = self.bn(x)
        x = F.relu(x)
        
        return x
    
class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        self.conv = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, padding=1, stride=1)
        
        self.block1 = nn.Sequential(
            Conv(
                in_channels=64, out_channels=64, stride=2
            ),
            Conv(
                in_channels=64, out_channels=128, stride=1
            ),
        )
        
        self.block2 = nn.Sequential(
            Conv(
                in_channels=128, out_channels=128, stride=2
            ),
            Conv(
                in_channels=128, out_channels=256, stride=1
            ),
        )
        
        self.block3 = nn.Sequential(
            Conv(
                in_channels=256, out_channels=256, stride=1
            ),
            Conv(
                in_channels=256, out_channels=512, stride=1
            ),
        )
        
        self.block4 = nn.Sequential(
            Conv(
                in_channels=512, out_channels=512, stride=2
            ),
            Conv(
                in_channels=512, out_channels=1024, stride=1
            ),
        )
        
        self.avg_pool = nn.AvgPool2d(4, 4)
        self.fc = nn.Linear(1024, 10)
    
    def forward(self, x):
        x = self.conv(x)
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        x = self.avg_pool(x)
        x = x.view(x.shape[0], -1)
        x = self.fc(x)
        
        return x

In [5]:
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
dataset = torchvision.datasets.CIFAR10(root="./data", train=True, download=True, transform=transform_train)
train_set, val_set = torch.utils.data.random_split(dataset, [40000, 10000])
test_set = torchvision.datasets.CIFAR10(root="./data", train=False, download=True, transform=transform_train)

train_dataloader = DataLoader(train_set, batch_size=BATCH_NUM, shuffle=True, num_workers=8)
valid_dataloader = DataLoader(val_set, batch_size=BATCH_NUM, num_workers=8)
test_dataloader = DataLoader(test_set, batch_size=BATCH_NUM, num_workers=8)

Files already downloaded and verified
Files already downloaded and verified


In [6]:
class CNN(pl.LightningModule):
    def __init__(self):
        super().__init__()
        
        self.cnn = ConvNet()
        self.train_accuracy = pl.metrics.Accuracy()
        self.valid_accuracy = pl.metrics.Accuracy()
        self.test_accuracy = pl.metrics.Accuracy()

    def training_step(self, batch, batch_idx):
        inputs, targets = batch
        outputs = self.cnn(inputs)
        loss = F.cross_entropy(outputs, targets)
        self.log("train_loss", loss, on_epoch=True, prog_bar=True, logger=True)
        self.log("train_acc", self.train_accuracy(outputs, targets), on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        inputs, targets = batch
        outputs = self.cnn(inputs)
        loss = F.cross_entropy(outputs, targets)
        self.log("valid_loss", loss, on_epoch=True, prog_bar=True, logger=True)
        self.log("valid_acc", self.valid_accuracy(outputs, targets), on_epoch=True, prog_bar=True, logger=True)
        return loss
    
    def test_step(self, batch, batch_idx):
        inputs, targets = batch
        outputs = self.cnn(inputs)
        loss = F.cross_entropy(outputs, targets)
        self.log("test_loss", loss, on_epoch=True, prog_bar=True, logger=True)
        self.log("test_acc", self.test_accuracy(outputs, targets), on_epoch=True, prog_bar=True, logger=True)
        return loss
    
    def training_epoch_end(self, outs):
        self.log("train_acc_epoch", self.train_accuracy.compute(), on_epoch=True, prog_bar=True, logger=True)

    def validation_epoch_end(self, outs):
        self.log("valid_acc_epoch", self.valid_accuracy.compute(), on_epoch=True, prog_bar=True, logger=True)
        
    def test_epoch_end(self, outs):
        self.log("test_acc_epoch", self.test_accuracy.compute(), on_epoch=True, prog_bar=True, logger=True)
    
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.cnn.parameters(), lr=5e-4, betas=(0.5, 0.999))
        
        return optimizer

In [7]:
cnn = CNN()
early_stop_callback = EarlyStopping(
    monitor="valid_loss",
    min_delta=0.00,
    patience=15,
    verbose=False,
    mode="min"
)

trainer = pl.Trainer(
    deterministic=True,
    callbacks=[early_stop_callback], 
    check_val_every_n_epoch=1, 
    gpus=1,
    max_epochs=1000,
)
trainer.fit(cnn, train_dataloader, valid_dataloader)
trainer.test(test_dataloaders=test_dataloader)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name           | Type     | Params
--------------------------------------------
0 | cnn            | ConvNet  | 9.4 M 
1 | train_accuracy | Accuracy | 0     
2 | valid_accuracy | Accuracy | 0     
3 | test_accuracy  | Accuracy | 0     
--------------------------------------------
9.4 M     Trainable params
0         Non-trainable params
9.4 M     Total params


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…


--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_acc': tensor(0.8706, device='cuda:0'),
 'test_acc_epoch': tensor(0.8706, device='cuda:0'),
 'test_loss': tensor(0.5365, device='cuda:0')}
--------------------------------------------------------------------------------


[{'test_loss': 0.5365256667137146,
  'test_acc': 0.8705999851226807,
  'test_acc_epoch': 0.8705999851226807}]