## Using PyTorch Lightning

### by Michael Ruddy

To get PyTorch Lightning:

`conda install -c conda-forge pytorch-lightning`

In [None]:
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm

# PyTorch stuff
import torch, torchvision
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms

# PyTorch Lightning
import pytorch_lightning as pl

Let's use the MNIST dataset to test out these features.

In [None]:
# Load the MNIST dataset.
# ToTensor converts each image to a float tensor; Normalize then shifts/scales it
# with mean 0.5 and std 0.5.
trnsfm = transforms.Compose([
    transforms.ToTensor(),
    # MNIST is single-channel, so mean and std are explicit 1-tuples.
    # (Note that `(.5)` is just the float 0.5, not a tuple.)
    transforms.Normalize((0.5,), (0.5,)),
])

# download=True fetches the files into ./data if they are not already there.
ds_train = torchvision.datasets.MNIST(root='./data', train=True,
                                      download=True, transform=trnsfm)
ds_val = torchvision.datasets.MNIST(root='./data', train=False,
                                    download=True, transform=trnsfm)

batch_size = 4

# I'm going to do more than one "run" in this notebook, so the settings shared
# between runs are collected in a single dictionary.
global_hyperparam = {'N_train': len(ds_train),
                     'N_val': len(ds_val),
                     'batch_size': batch_size}

# Dataloaders: only the training split is shuffled.
dl_train = torch.utils.data.DataLoader(ds_train, batch_size=batch_size,
                                       shuffle=True, num_workers=2)
dl_val = torch.utils.data.DataLoader(ds_val, batch_size=batch_size,
                                     shuffle=False, num_workers=2)

### PyTorch Set-Up + Training

In [None]:
# model
class small_CNN(nn.Module):
    """Small convolutional classifier producing 10 raw scores per input.

    Three 3x3 convolutions (padding 1, so spatial size is preserved) are
    interleaved with two 2x2 max-pools, followed by two linear layers.
    ``linear1`` expects 64*7*7 flattened features, which matches a
    single-channel 28x28 input after the two poolings (28 -> 14 -> 7).
    """

    def __init__(self):
        super().__init__()

        # Convolutional feature extractor: 1 -> 16 -> 32 -> 64 channels.
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, padding=1)

        # Classifier head.
        self.linear1 = nn.Linear(64 * 7 * 7, 100)
        self.linear2 = nn.Linear(100, 10)

        # Parameter-free layers, shared between the stages of forward().
        self.pool = nn.MaxPool2d(kernel_size=2)
        self.relu = nn.ReLU()
        self.unroll = nn.Flatten()

    def forward(self, x):
        """Return the un-normalised class scores for the batch ``x``."""
        # Stages 1 and 2: conv -> ReLU -> pool.
        feats = self.pool(self.relu(self.conv1(x)))
        feats = self.pool(self.relu(self.conv2(feats)))

        # Stage 3: conv -> ReLU, with no pooling.
        feats = self.relu(self.conv3(feats))

        # Flatten, then the two-layer head; no activation on the output.
        hidden = self.relu(self.linear1(self.unroll(feats)))
        return self.linear2(hidden)

Note that these next functions must be altered if I change the task or the number of inputs to the model's forward pass, or if I want to switch to regression.

In [None]:
# one pass through the dataloader, keyword for whether to backprop or not
def one_pass(model, dataloader, optimizer, scheduler, lossFun, backwards=True, print_loss=False):
    """Run one pass over ``dataloader`` and return the mean per-batch loss.

    Args:
        model: module called as ``model(x)`` on each batch input.
        dataloader: iterable of ``(x, y)`` batches; must support ``len()``.
        optimizer: stepped after every batch when ``backwards`` is true;
            unused otherwise.
        scheduler: stepped after every optimizer step when ``backwards`` is
            true; may be ``None`` to skip learning-rate scheduling.
        lossFun: callable ``lossFun(y_pred, y)`` returning a scalar tensor.
        backwards: if true, put the model in train mode and backprop/update
            on every batch; otherwise put it in eval mode and only measure.
        print_loss: if true, print the average loss before returning it.

    Returns:
        The sum of the per-batch losses divided by the number of batches.
    """
    if backwards:
        model.train()
    else:
        model.eval()

    total_loss = 0.0
    # Autograd is only needed when we are going to backprop; disabling it for
    # the evaluation pass avoids building a graph that is never used.
    with torch.set_grad_enabled(bool(backwards)):
        for x, y in tqdm(dataloader):

            y_pred = model(x)
            loss = lossFun(y_pred, y)
            total_loss += loss.item()

            if backwards:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                # The scheduler is advanced per batch, not per epoch.
                if scheduler is not None:
                    scheduler.step()

    avg_loss = total_loss / len(dataloader)

    if print_loss:
        print(avg_loss)

    return avg_loss

# one pass to gather metrics
def one_pass_acc(model, dataloader, num_points):
    """Return the classification accuracy of ``model`` over ``dataloader``.

    Args:
        model: module whose output has one score per class along dim 1.
        dataloader: iterable of ``(x, y)`` batches, ``y`` holding class indices.
        num_points: total number of samples, used as the denominator.

    Returns:
        ``1 - incorrect / num_points`` as a Python float.
    """
    model.eval()
    total_incorrect = 0

    # This pass only gathers metrics, so no autograd graph is needed.
    with torch.no_grad():
        for x, y in dataloader:
            # argmax over the raw scores: applying log-softmax first is
            # monotonic per row and would not change the predicted class.
            y_pred = torch.argmax(model(x), dim=1)
            total_incorrect += (y_pred != y).sum().item()

    acc = 1 - (total_incorrect / num_points)

    return acc

The training loop

In [None]:
num_epochs = 2
model = small_CNN()
lossFun = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
# one_pass steps the scheduler once per training batch, so OneCycleLR is told
# both the number of epochs and the number of batches per epoch.
lr_scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.002, epochs=num_epochs, steps_per_epoch=len(dl_train))

for epoch in tqdm(range(num_epochs)):

    # One training pass (with backprop), then one evaluation pass (without).
    train_loss = one_pass(model, dl_train, optimizer, lr_scheduler, lossFun)
    valid_loss = one_pass(model, dl_val, optimizer, lr_scheduler, lossFun, backwards=False)

    print(f"Train loss, Epoch {epoch}:", train_loss)
    print(f"Val loss, Epoch {epoch}:", valid_loss)

    train_acc = one_pass_acc(model, dl_train, len(ds_train))
    valid_acc = one_pass_acc(model, dl_val, len(ds_val))

    # Report the accuracies too; each costs a full pass over its dataloader,
    # so computing them without showing them would be wasted work.
    print(f"Train acc, Epoch {epoch}:", train_acc)
    print(f"Val acc, Epoch {epoch}:", valid_acc)

Now let's do the same model and training in PyTorch Lightning.

In [None]:
class lightning_small_CNN(pl.LightningModule):
    """The same small CNN, with loss, metrics and optimizer bundled in for Lightning.

    Args:
        hparams: mapping with the keys ``'num_epochs'`` and ``'num_batches'``,
            read in ``configure_optimizers`` to size the one-cycle schedule.
    """

    # Similarly need to set up the weights and the forward pass
    def __init__(self, hparams):
        super().__init__()

        self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)

        self.linear1 = nn.Linear(64*7*7, 100)
        self.linear2 = nn.Linear(100, 10)

        self.pool = nn.MaxPool2d(kernel_size=2)
        self.relu = nn.ReLU()
        self.unroll = nn.Flatten()

        # going to attach the loss function to the module
        self.CELoss = nn.CrossEntropyLoss()
        # NOTE: not used by any method of this class; kept so the attribute
        # remains available to outside code.
        self.softmax = nn.LogSoftmax(dim=1)

        # Needed by the scheduler. Stored as `hp` because the attribute
        # can't be named `self.hparams`.
        self.hp = hparams

    def forward(self, x):
        """Return the raw (un-normalised) class scores for the batch ``x``."""
        x = self.conv1(x)
        x = self.relu(x)
        x = self.pool(x)

        x = self.conv2(x)
        x = self.relu(x)
        x = self.pool(x)

        x = self.conv3(x)
        x = self.relu(x)

        x = self.linear1(self.unroll(x))
        x = self.relu(x)
        x = self.linear2(x)

        return x

    # method for computing the loss
    def lossFun(self, y_pred, y):
        """Return the cross-entropy loss between scores ``y_pred`` and labels ``y``."""
        return self.CELoss(y_pred, y)

    # we can define our metric functions below
    def acc(self, y_pred, y):
        """Return the fraction of rows of ``y_pred`` whose argmax equals ``y``."""
        y_pred = torch.argmax(y_pred, dim=1)
        total_incorrect = torch.count_nonzero(y - y_pred).item()

        return 1 - (total_incorrect / torch.numel(y))

    # this method must be named training_step
    def training_step(self, train_batch, batch_idx):
        """Compute and log loss and accuracy for one training batch; return the loss."""
        x, y = train_batch

        # Call the module itself rather than self.forward so that any
        # registered forward hooks run as well.
        y_pred = self(x)
        loss = self.lossFun(y_pred, y)
        self.log('train_loss', loss, on_epoch=True)

        # compute metrics
        acc = self.acc(y_pred, y)
        self.log('train_acc', acc, on_step=False, on_epoch=True)

        return loss

    # instead of an on/off switch for the backward pass, we simply define a separate step for validation
    # must be named validation_step
    def validation_step(self, val_batch, batch_idx):
        """Compute and log loss and accuracy for one validation batch."""
        x, y = val_batch
        y_pred = self(x)
        loss = self.lossFun(y_pred, y)
        self.log('val_loss', loss)

        acc = self.acc(y_pred, y)
        self.log('val_acc', acc)

    # here we configure the optimizer
    def configure_optimizers(self):
        """Return the Adam optimizer and its per-step one-cycle LR schedule."""
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)

        # we can even pass the scheduler!
        # because the annealing scheduler needs to know the number of batches and epochs,
        # a hparam dictionary is passed to the model at construction time
        lr_scheduler = {
                        'scheduler': optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.002,
                                                     epochs=self.hp['num_epochs'],
                                                     steps_per_epoch=self.hp['num_batches']),
                        'interval': 'step'  # forces updates after each training step, instead of per epoch
                        }
        # we pass lists here, because lightning supports multiple optimizers!
        return [optimizer], [lr_scheduler]

Now we have the same training loop

In [None]:
# The module's scheduler reads these two values from the dictionary it is given.
num_epochs = 2
hparams = dict(num_epochs=num_epochs, num_batches=len(dl_train))

model = lightning_small_CNN(hparams)

# The Trainer is given the model together with the train and validation dataloaders.
trainer = pl.Trainer(max_epochs=num_epochs)
trainer.fit(model, dl_train, dl_val)

There are a few nice bells and whistles here.
- Automatic progress bar!
- Makes logging the train and validation loss easy (logs are stored in `lightning_logs`)
- The Trainer first makes sure the forward loop runs on the validation set
- Most of the training loop can be abstracted into the Module, which makes training from scripts very easy

In [None]:
# load up tensorboard to view the logs!
# These two lines are IPython magics, so they only work inside a notebook/IPython
# session: the first loads the tensorboard extension, the second starts it
# pointed at the lightning_logs directory.
%load_ext tensorboard
%tensorboard --logdir lightning_logs

We can even pair this with neptune.ai; you need to run the following first:

`pip install neptune-pytorch-lightning`

In [None]:
from neptune.new.integrations.pytorch_lightning import NeptuneLogger

# Frustratingly enough, note that what is called api_token elsewhere is called api_key here!
run = NeptuneLogger(project="your_project_name", api_key="your_api_key", name="Lightning_Test")

# Same set-up as the previous Lightning run ...
num_epochs = 2
hparams = dict(num_epochs=num_epochs, num_batches=len(dl_train))
model = lightning_small_CNN(hparams)

# ... except that the Trainer is handed the Neptune logger.
trainer = pl.Trainer(max_epochs=num_epochs, logger=run)
trainer.fit(model, dl_train, dl_val)