In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from tqdm.notebook import tqdm
from torchvision.models import resnet18
from torch.utils.data import random_split, DataLoader
import pytorch_lightning as pl
from typing import Type, Any, Callable, Union, List, Optional
from pytorch_lightning import Trainer, seed_everything

seed_everything(42, workers=True)


Global seed set to 42


42

In [2]:
class CifarDataModule(pl.LightningDataModule):
    """LightningDataModule that serves CIFAR-10 train/validation/test loaders.

    The torchvision CIFAR-10 training set is split 80/20 into a training and
    a validation subset; the official CIFAR-10 test set is used for testing.

    Args:
        data_dir: Directory where torchvision stores / looks for the dataset.
        transform: Transform applied to every image. When ``None``, images are
            converted to tensors and normalised with mean 0.5 and std 0.5 on
            each of the three channels.
        batch_size: Batch size of the training and test loaders.
        num_workers: Number of worker processes used by every DataLoader.
    """

    def __init__(self, data_dir: str = './', transform = None, batch_size = 4,
                 num_workers: int = 8):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers

        if transform is None:
            self.transform = transforms.Compose(
                    [transforms.ToTensor(),
                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        else:
            self.transform = transform

        # self.dims is not given a default here; setup() assigns it from the
        # shape of the first sample of whichever dataset it builds.

    def prepare_data(self):
        """Download both CIFAR-10 splits into ``data_dir``."""
        torchvision.datasets.CIFAR10(root=self.data_dir, train=True,
                                        download=True)
        torchvision.datasets.CIFAR10(root=self.data_dir, train=False,
                                       download=True)

    def setup(self, stage: Optional[str] = None):
        """Build the datasets needed for ``stage``.

        Args:
            stage: ``'fit'`` builds the train/validation subsets, ``'test'``
                builds the test set, ``None`` builds both.
        """
        # Assign train/val datasets for use in dataloaders
        if stage == 'fit' or stage is None:
            trainset = torchvision.datasets.CIFAR10(
                root=self.data_dir, train=True, transform=self.transform)
            # Derive the validation size as the remainder so the two lengths
            # always sum to len(trainset); two independently truncated
            # products can fall short, which makes random_split raise.
            num_train = int(len(trainset) * 0.8)
            num_val = len(trainset) - num_train
            self.train_subset, self.val_subset = \
                    random_split(trainset, [num_train, num_val])

            self.dims = tuple(self.train_subset[0][0].shape)

        # Assign test dataset for use in dataloader(s)
        if stage == 'test' or stage is None:
            self.testset = torchvision.datasets.CIFAR10(
                root=self.data_dir, train=False, transform=self.transform)

            self.dims = tuple(self.testset[0][0].shape)

    def train_dataloader(self):
        """Return the training loader; samples are reshuffled every epoch."""
        return DataLoader(self.train_subset, batch_size=self.batch_size,
                          shuffle=True, num_workers=self.num_workers)

    def val_dataloader(self):
        """Return the validation loader."""
        # NOTE(review): the whole validation subset is served as one batch,
        # so memory use grows with the subset size - confirm this is intended.
        return DataLoader(self.val_subset, batch_size=len(self.val_subset),
                          num_workers=self.num_workers)

    def test_dataloader(self):
        """Return the test loader."""
        return DataLoader(self.testset, batch_size=self.batch_size,
                          num_workers=self.num_workers)


## example
# dm = CifarDataModule()
# dm.prepare_data()
# dm.setup(stage='fit')

# model = LitNet(resnet18(pretrained=False, num_classes=10), nn.CrossEntropyLoss())
# trainer.fit(model, dm)

# dm.setup(stage='test')
# trainer.test(datamodule=dm)


In [10]:
# Lightning
class LitNet(pl.LightningModule):
    """LightningModule that trains a classifier with a supplied loss.

    Args:
        model: Network mapping a batch of inputs to per-class scores.
        criterion: Callable ``criterion(outputs, labels)`` returning the loss.
    """

    def __init__(self, model, criterion):
        super().__init__()
        self.model = model
        self.criterion = criterion

    def forward(self, x):
        """Return the wrapped model's predictions for ``x``."""
        # in lightning, forward defines the prediction/inference actions
        return self.model(x)

    def _shared_step(self, batch, stage):
        """Compute loss and accuracy for one batch and log them as ``<stage>_*``."""
        inputs, labels = batch

        outputs = self(inputs)
        # Use the criterion given to the constructor rather than a notebook
        # global, so the module does not depend on cell execution order.
        loss = self.criterion(outputs, labels)
        # Fraction of samples whose highest-scoring class equals the label.
        # Kept as a tensor so no host synchronisation is forced per step.
        acc = (outputs.argmax(dim=1) == labels).float().mean()

        # Logging to TensorBoard by default
        self.log(f'{stage}_loss', loss)
        self.log(f'{stage}_acc', acc)
        return loss

    def training_step(self, batch, batch_idx):
        """Log ``train_loss`` / ``train_acc`` and return the loss to optimise."""
        return self._shared_step(batch, 'train')

    def validation_step(self, batch, batch_idx):
        """Log ``val_loss`` / ``val_acc`` for one validation batch."""
        self._shared_step(batch, 'val')

    def test_step(self, batch, batch_idx):
        """Log ``test_loss`` / ``test_acc`` for one test batch."""
        self._shared_step(batch, 'test')

    def configure_optimizers(self):
        """Return an Adam optimiser over all parameters with lr=1e-3."""
        return torch.optim.Adam(self.parameters(), lr=1e-3)


In [7]:
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks import ModelCheckpoint

# Stop training when 'val_loss' has not improved for 20 validation checks.
early_stop_callback = EarlyStopping(
    monitor='val_loss',
    min_delta=0.00,
    patience=20,
    verbose=False,
    mode='min',
)
# Keep the three checkpoints with the lowest 'val_loss'.
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    dirpath='my/path/',
    # Name checkpoints after the dataset actually trained on (CIFAR-10).
    filename='cifar10-resnet18-{epoch:02d}-{val_loss:.2f}',
    save_top_k=3,
    mode='min',
)
criterion = nn.CrossEntropyLoss()

# Use one GPU with 16-bit precision when CUDA is available; otherwise fall
# back to CPU with 32-bit precision so the cell still runs without a GPU.
USE_GPU = torch.cuda.is_available()
trainer = Trainer(
    min_epochs=1,
    max_epochs=100,
    precision=16 if USE_GPU else 32,
    callbacks=[early_stop_callback, checkpoint_callback],
    gpus=1 if USE_GPU else None,
)



GPU available: True, used: True
TPU available: False, using: 0 TPU cores
Using native 16bit precision.


In [8]:
# ResNet-18 without pretrained weights, with a 10-class output layer.
net = resnet18(num_classes=10, pretrained=False)

In [9]:


# Build the data module and the Lightning wrapper around the network.
dm = CifarDataModule(batch_size=64)
model = LitNet(net, criterion)

# Train, then evaluate on the test loader of the same data module.
trainer.fit(model, datamodule=dm)
trainer.test(datamodule=dm)


Files already downloaded and verified
Files already downloaded and verified


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type             | Params
-----------------------------------------------
0 | model     | ResNet           | 11.2 M
1 | criterion | CrossEntropyLoss | 0     
-----------------------------------------------
11.2 M    Trainable params
0         Non-trainable params
11.2 M    Total params
44.727    Total estimated model params size (MB)


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

Global seed set to 42


HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…




LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


([], [])