# A simple Multi-GPU example using PyTorch Lightning

In [None]:
# Notebook setup (IPython magics, not plain Python).
# autoreload in mode 2 re-imports modules before each cell executes, so edits
# to local packages are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the cell output.
%matplotlib inline

In [None]:
from demopkg.dataset import load_cifar100
from torch.utils.data import DataLoader
from demopkg.model import CNN2D
import matplotlib.pyplot as plt

import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor, ModelCheckpoint
from pathlib import Path
from pytorch_lightning.loggers import WandbLogger
# output directory
from demopkg.conf import OUTPUTDIR
from demopkg.lightning import LightningClassifier


In [None]:
# --- Optimisation settings ---
batch_size = 64          # samples per mini-batch fed to the DataLoaders
num_epochs = 5           # intended number of training epochs
learning_rate = 1e-3     # learning rate handed to the Lightning module

# --- Dataset ---
num_classes = 100        # CIFAR-100 has 100 target classes

# --- CNN configuration (arguments passed to CNN2D) ---
input_channel = 3
# One entry per conv layer; the last entry equals the number of classes.
convs = [32, 64, 128, num_classes]
n_convs = len(convs)
# Every layer uses the same kernel size and the same stride.
kernel_sizes = [5 for _ in range(n_convs)]
strides = [2 for _ in range(n_convs)]

In [None]:
# Build the convolutional network from the configuration defined above.
net = CNN2D(
    input_channel,
    convs,
    kernel_sizes,
    strides,
)

In [None]:
# Load the train / validation / test splits returned by load_cifar100().
train, val, test = load_cifar100()

# Reshuffle the training set at every epoch so mini-batches differ between
# epochs; without shuffle=True the model sees the exact same batches in the
# exact same order each epoch.
train_loader = DataLoader(train, batch_size=batch_size, shuffle=True)
# Evaluation loaders keep a fixed order so metrics are comparable across runs.
val_loader = DataLoader(val, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test, batch_size=batch_size, shuffle=False)


#### Some tests

In [None]:
# Sanity check on a single training example: element 0 is the image tensor,
# element 1 is its label.
sample = train[0]
inputs, label = sample[0], sample[1]
# Note: the image is normalized to have mean of 0.5 and std of 0.5 for each
# channel, so plotting it directly, e.g.
#     plt.imshow(inputs.permute(1, 2, 0).numpy())
# will not show the image correctly and matplotlib will complain about values
# being out of range.
# Display the label together with the input shape and the network output shape.
label, inputs.shape, net(inputs).shape

In [None]:
# Pull one mini-batch from the training loader and check that the network
# accepts it: show input, label and output shapes side by side.
batch_iter = iter(train_loader)
batch = next(batch_iter)
inputs, labels = batch
inputs.shape, labels.shape, net(inputs).shape

#### Training

In [None]:

# Wrap the bare network in the project's Lightning module, which receives the
# learning rate via `lr_rate`.
model = LightningClassifier(
    net,
    lr_rate=learning_rate,
)

In [None]:
# Run name: used for the output sub-directory and as checkpoint filename prefix.
name = 'cifar100-cnn2d'

# All artifacts of this run (checkpoints, trainer files) go to OUTPUTDIR/<name>.
# Computed once and reused so the checkpoint dir and trainer root cannot drift.
default_root_dir = OUTPUTDIR / Path(name)
default_root_dir.mkdir(parents=True, exist_ok=True)

# 1. Wandb Logger (offline: nothing is uploaded during the run)
wandb_logger = WandbLogger(offline=True)  # add project='projectname' to log to a specific project

# 2. Learning Rate Logger
lr_logger = LearningRateMonitor()

# 3. Early stopping: stop when 'val_loss' has not improved for `patience` epochs.
#    NOTE(review): patience=5 can never trigger within a 5-epoch run — confirm intent.
early_stopping = EarlyStopping('val_loss', mode='min', patience=5)

# 4. Checkpointing: keep the 5 checkpoints with the lowest 'val_loss'.
#    The filename must be an f-string with *escaped* braces around the Lightning
#    placeholders: `{name}` is substituted here with the run name, while
#    `{{epoch}}` / `{{val_loss:.2f}}` are left as `{epoch}` / `{val_loss:.2f}`
#    for ModelCheckpoint to fill in. With a plain string, Lightning treats
#    `{name}` as a (non-existent) metric and the run name never appears.
checkpoint_callback = ModelCheckpoint(
    dirpath=default_root_dir,
    filename=f'{name}_{{epoch}}-{{val_loss:.2f}}',
    monitor='val_loss',
    mode='min',
    save_top_k=5,
)

callbacks = [lr_logger, early_stopping, checkpoint_callback]


In [None]:
# Define Trainer
# Use the `num_epochs` hyperparameter instead of a hard-coded literal so that
# changing it in the configuration cell actually changes the training length.
trainer = pl.Trainer(
    max_epochs=num_epochs,
    logger=wandb_logger,
    callbacks=callbacks,
    default_root_dir=default_root_dir,
)  # gpus=1
# Fit on the training loader, validating on the validation loader.
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)