## Using PyTorch Lightning

After you've written a dozen PyTorch models, you'll discover that there's a lot of common structure and a huge amount of boilerplate. It's good to understand what's going on under the hood, but when moving to production use cases you'll want to opt for more reliable, reproducible code. PyTorch Lightning and Ignite are great libraries that abstract away these core bits.

#### Install Stuff

Installing `pytorch-lightning` requires tensorboard, which can have terrible conflicts. Good luck; here's the recipe that worked for me:

    pip uninstall tensorboard
    conda install tensorboard -y
    conda install pytorch-lightning -y -c conda-forge

Now install `wandb` which we'll use for logging & dashboarding of our models:

    pip install wandb

In a terminal, also log in to wandb by running `wandb login`. You'll have to set up an account if you haven't before.

#### Create a synthetic dataset

In [17]:
from sklearn.model_selection import train_test_split
import numpy as np

n = 50000
# X is just a 9D normally distributed dataset
X = np.random.normal(size=(n, 9)).astype(np.float32)
# The prediction is a linear transformation on X
# from 9D to 4D plus additive noise
Y = np.random.normal(size=(n, 4)) * 1e-2 + np.dot(X, np.random.normal(size=(9, 4)))
Y = Y.astype(np.float32)

X_train, X_test, Y_train, Y_test = train_test_split(X, Y)
# Now split train into train & validation
X_train, X_val, Y_train, Y_val = train_test_split(X_train, Y_train)
X_train.shape, X_val.shape, X_test.shape, Y_train.shape, Y_val.shape, Y_test.shape

((28125, 9), (9375, 9), (12500, 9), (28125, 4), (9375, 4), (12500, 4))
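
The shapes printed above follow from `train_test_split` holding out 25% of the rows by default when no `test_size` is given, applied twice. Here is a quick arithmetic check, as a sketch; `split_sizes` is a hypothetical helper written for this note, not part of scikit-learn:

```python
import math


def split_sizes(n_rows, held_out_fraction=0.25):
    # Hypothetical helper: the held-out part is rounded up,
    # and the remaining rows stay in the training part.
    n_held_out = math.ceil(n_rows * held_out_fraction)
    return n_rows - n_held_out, n_held_out


n_train_val, n_test = split_sizes(50000)    # first split: train+val vs. test
n_train, n_val = split_sizes(n_train_val)   # second split: train vs. val
print(n_train, n_val, n_test)  # prints 28125 9375 12500
```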

#### Write an abstract model that defines training loops

Let's write our initial model as a Lightning module.

Don't be afraid of how much extra code this injects. Although it initially looks like a ton of little class functions, it's all about being organized, deliberate, standardized and repeatable. It's not about science, it's about having good lab hygiene.

- We'll move some of the iteration code into `training_step`, `test_step`, and `test_epoch_end`.
- Add in a `configure_optimizers` function.
- Separate out the train and test loaders.
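
For comparison, here is roughly the hand-written loop that `training_step` and `configure_optimizers` stand in for. This is a minimal sketch on a tiny random dataset, not what Lightning runs internally; `toy_x`, `toy_y`, `toy_loader`, `toy_model` and the sizes are made up for illustration.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)
# Toy stand-ins, for illustration only
toy_x = torch.randn(256, 9)
toy_y = toy_x @ torch.randn(9, 4)
toy_loader = DataLoader(TensorDataset(toy_x, toy_y), batch_size=32, shuffle=True)
toy_model = nn.Linear(9, 4)

# This is the object configure_optimizers would return
optimizer = torch.optim.Adam(toy_model.parameters(), lr=1e-3)

epoch_losses = []
for epoch in range(3):
    running = 0.0
    for inpt, target in toy_loader:
        # The part that lives in training_step: forward pass and loss
        prediction = toy_model(inpt)
        loss = ((prediction - target) ** 2).sum()
        # The bookkeeping that Lightning takes off our hands
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running += loss.item()
    epoch_losses.append(running)
print(epoch_losses)
```

Everything outside the two commented sections is the boilerplate that tends to get copied from notebook to notebook.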

In [18]:
import torch
from torch import from_numpy
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset
from torch.utils.data import BatchSampler
from torch.utils.data import RandomSampler


class MinimalAbstractModel(pl.LightningModule):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        
    def training_step(self, batch, batch_nb):
        inpt, target = batch
        prediction = self.forward(inpt)
        loss = self.loss(prediction, target) + self.reg()
        log = {'train_loss': loss}
        return {'train_loss': loss, 'loss': loss, 'log': log}
    
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

#### Write our specific model

And then we'll keep our `Bottleneck` model, but now it will inherit from our `MinimalAbstractModel`. Over the next few notebooks we'll keep using the `AbstractModel` class and just stick to focusing our changes within the subclasses.

In [19]:
from torch import nn
from pytorch_lightning.loggers.wandb import WandbLogger


class Bottleneck(MinimalAbstractModel):
    def __init__(self, n_in_cols, n_out_cols, n_hidden=3, lam1=1e-3, lam2=1e-3):
        super().__init__()
        self.lin1 = nn.Linear(n_in_cols, n_hidden)
        self.lin2 = nn.Linear(n_hidden, n_out_cols)
        # Regularization coefficients
        self.lam1 = lam1
        self.lam2 = lam2
        # This saves the hyperparameters so they show up on W&B:
        # useful for when you inevitably forget what they were,
        # and then useful again for hyperparameter tuning!
        self.save_hyperparameters()
    
    def forward(self, x):
        # x is a minibatch of rows of our features
        hidden = self.lin1(x)
        # y is a minibatch of our predictions
        y = self.lin2(hidden)
        return y

    def loss(self, prediction, target):
        # This is the squared error summed over the minibatch
        # (a sum, not a mean)
        return ((prediction - target)**2.0).sum()
    
    def reg(self):
        # This computes the squared Frobenius norm of both weight
        # matrices, each scaled by its regularization coefficient.
        # Note that we can access the Linear model's variables
        # directly if we'd like. No tricks here!
        loss_reg_m1 = (self.lin1.weight**2.0 * self.lam1).sum()
        loss_reg_m2 = (self.lin2.weight**2.0 * self.lam2).sum()
        return loss_reg_m1 + loss_reg_m2


model = Bottleneck(9, 4, 3)

# add a logger
logger = WandbLogger(name="00_intro", log_model=True, project="simple_mf")
# logger = TensorBoardLogger("tb_logs", name="bottleneck_model")

# We could have turned on multiple GPUs here, for example
# trainer = pl.Trainer(gpus=8, precision=16)    
trainer = pl.Trainer(max_epochs=3, progress_bar_refresh_rate=10, logger=logger)    

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
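
To make the objective concrete: `loss` sums the squared errors over a minibatch (it is a sum, not a mean), and `reg` adds each weight matrix's squared Frobenius norm scaled by its coefficient. Here is a NumPy sketch with made-up values; the arrays are random placeholders shaped like those in `Bottleneck(9, 4, 3)`, not the trained weights.

```python
import numpy as np

rng = np.random.default_rng(0)
# Made-up minibatch and weights, for illustration only
prediction = rng.normal(size=(32, 4))
target = rng.normal(size=(32, 4))
w1 = rng.normal(size=(3, 9))   # same shape as lin1.weight
w2 = rng.normal(size=(4, 3))   # same shape as lin2.weight
lam1, lam2 = 1e-3, 1e-3

# Squared error summed over every row and output column
sse = ((prediction - target) ** 2).sum()
# Squared Frobenius norms, scaled by the regularization coefficients
penalty = lam1 * (w1 ** 2).sum() + lam2 * (w2 ** 2).sum()
total = sse + penalty
print(sse, penalty, total)
```

Because the data term is a sum, its magnitude grows with the batch size while the penalty does not, which is worth keeping in mind when choosing `lam1` and `lam2`.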


#### Train the model

Now let's fit our model and then check the test loss again. 

In [20]:
batch_size = 32
dataset = TensorDataset(from_numpy(X_train), from_numpy(Y_train))
bs = BatchSampler(RandomSampler(dataset), 
                   batch_size=batch_size, drop_last=False)
train = DataLoader(dataset, batch_sampler=bs, num_workers=8)
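
To see what the `BatchSampler(RandomSampler(...))` pair produces, here is a small sketch on a toy range of indices instead of the real dataset. With `drop_last=False` the final short batch is kept, so 28125 training rows at a batch size of 32 give 879 batches per epoch. Three epochs of that is 2637 steps, which appears to match the `global_step` reported in the W&B run summaries in this notebook.

```python
from torch.utils.data import BatchSampler, RandomSampler

# A toy "dataset" of 10 row indices; the samplers only need its length
toy_sampler = BatchSampler(RandomSampler(range(10)), batch_size=4, drop_last=False)
batches = list(toy_sampler)
print([len(b) for b in batches])  # the last batch is short, not dropped

# The same arithmetic for the 28125 training rows
n_batches = len(BatchSampler(RandomSampler(range(28125)), batch_size=32, drop_last=False))
print(n_batches, 3 * n_batches)
```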

In [21]:
trainer.fit(model, train) 




wandb: Waiting for W&B process to finish, PID 35807
wandb: Program ended successfully.
wandb:                                                                                
wandb: Find user logs for this run at: wandb/run-20200918_083455-2he0evdd/logs/debug.log
wandb: Find internal logs for this run at: wandb/run-20200918_083455-2he0evdd/logs/debug-internal.log
wandb: Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)
wandb: 
wandb: Synced 00_intro_optimize: https://app.wandb.ai/sf-moody/simple_mf/runs/2he0evdd
wandb: Tracking run with wandb version 0.10.1
wandb: Run data is saved locally in wandb/run-20200918_083633-12dluwai
wandb: Syncing run 00_intro






  | Name | Type   | Params
--------------------------------
0 | lin1 | Linear | 30    
1 | lin2 | Linear | 16    



Saving latest checkpoint..





1

Voila! The test loss (~100) is much lower than it was before (~4000).

### Visualize the model 

Check out the link on wandb to see training progress. For me, that link looks like the one below (you'll get your own link; this one shouldn't work for you):

Run page: https://app.wandb.ai/chrisemoody/simple_mf-notebooks/runs/2o5ofsn4


### Add in test & validation steps

The above `MinimalAbstractModel` is good for learning, but in practice we need to train on training data, continuously monitor the validation dataset and stop early if it diverges, and then test the final score on test data. The abstract model class gets less legible, but conceptually there's not a lot that's new going on here.
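
The "stop early" idea can be sketched in a few lines of plain Python. This is a simplified, patience-based rule written for illustration, not Lightning's own implementation; `should_stop`, `patience` and `min_delta` are names chosen here, and the loss history is made up.

```python
def should_stop(val_losses, patience=3, min_delta=0.0):
    """Return True once none of the last `patience` epochs improved
    on the best validation loss seen before them."""
    if len(val_losses) <= patience:
        return False
    best_before = min(val_losses[:-patience])
    recent_best = min(val_losses[-patience:])
    return recent_best >= best_before - min_delta


# Made-up validation losses: improving at first, then drifting upward
history = [4.0, 2.0, 1.5, 1.6, 1.7, 1.8]
for epoch in range(1, len(history) + 1):
    if should_stop(history[:epoch]):
        print(f"stop after epoch {epoch}")  # prints "stop after epoch 6"
        break
```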

In [33]:
# %load abstract_model.py
import torch
import numpy as np
from random import shuffle
from torch import from_numpy
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset
from torch.utils.data import BatchSampler
from torch.utils.data import RandomSampler


class AbstractModel(pl.LightningModule):
    def step(self, batch, batch_nb, prefix='train', add_reg=True):
        input, target = batch
        prediction = self.forward(input)
        loss, log = self.loss(prediction, target)
        
        if add_reg:
            loss_reg, log_ = self.reg()
            loss = loss + loss_reg
            log.update(log_)
        log[f'{prefix}_loss'] = loss
        return {f'{prefix}_loss': loss, 'loss':loss, 'log': log}

    def training_step(self, batch, batch_nb):
        return self.step(batch, batch_nb, 'train')
    
    def test_step(self, batch, batch_nb):
        # Note that we do *not* include the regularization loss
        # at test time
        return self.step(batch, batch_nb, 'test', add_reg=False)    
    
    def validation_step(self, batch, batch_nb):
        return self.step(batch, batch_nb, 'val', add_reg=False)    
    
    def test_epoch_end(self, outputs):
        loss_mean = torch.stack([x['test_loss'] for x in outputs]).mean()
        log = {'test_loss': loss_mean}
        return {'avg_test_loss': loss_mean, 'log': log}

    def validation_epoch_end(self, outputs):
        loss_mean = torch.stack([x['val_loss'] for x in outputs]).mean()
        log = {'val_loss': loss_mean}
        return {'avg_val_loss': loss_mean, 'log': log}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)
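
The two `*_epoch_end` hooks receive one dict per batch and average the per-batch losses. Here is a small sketch of that pattern with made-up loss values:

```python
import torch

# Made-up per-batch outputs, shaped like what validation_step returns
outputs = [
    {'val_loss': torch.tensor(4.0)},
    {'val_loss': torch.tensor(2.0)},
    {'val_loss': torch.tensor(3.0)},
]
# Stack the 0-d tensors into a single 1-d tensor, then average it
loss_mean = torch.stack([x['val_loss'] for x in outputs]).mean()
print(loss_mean)  # tensor(3.)
```

Note that every batch counts equally in this average, so a short final batch (whose summed loss is typically smaller) nudges the result down a little.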

Now let's redefine the `Bottleneck` class, but this time inheriting from `AbstractModel`:

In [42]:
from torch import nn
from pytorch_lightning.loggers.wandb import WandbLogger


class Bottleneck(AbstractModel):
    def __init__(self, n_in_cols, n_out_cols, n_hidden=3, lam1=1e-3, lam2=1e-3):
        super().__init__()
        self.lin1 = nn.Linear(n_in_cols, n_hidden)
        self.lin2 = nn.Linear(n_hidden, n_out_cols)
        # Regularization coefficients
        self.lam1 = lam1
        self.lam2 = lam2
        # This saves the hyperparameters so they show up on W&B:
        # useful for when you inevitably forget what they were,
        # and then useful again for hyperparameter tuning!
        self.save_hyperparameters()
    
    def forward(self, x):
        # x is a minibatch of rows of our features
        hidden = self.lin1(x)
        # y is a minibatch of our predictions
        y = self.lin2(hidden)
        return y

    def loss(self, prediction, target):
        # This is the squared error summed over the minibatch
        # (a sum, not a mean); the 'mse' log key is kept as-is
        mse = ((prediction - target)**2.0).sum()
        log = {'mse': mse}
        return mse, log
    
    def reg(self):
        # This computes the squared Frobenius norm of both weight
        # matrices, each scaled by its regularization coefficient.
        # Note that we can access the Linear model's variables
        # directly if we'd like. No tricks here!
        loss_reg_m1 = (self.lin1.weight**2.0 * self.lam1).sum()
        loss_reg_m2 = (self.lin2.weight**2.0 * self.lam2).sum()
        log = {'loss_reg_m1': loss_reg_m1, 'loss_reg_m2': loss_reg_m2}
        return loss_reg_m1 + loss_reg_m2, log

Before we train the model, the parameters and weights will all be initialized randomly. So when we evaluate the test loss, it'll be pretty bad.

### Tune hyperparameters with Optuna and Weights & Biases

You may have to install optuna:
    
    pip install optuna

In [43]:
def dataloader(*arrs, batch_size=32):
    dataset = TensorDataset(*arrs)
    bs = BatchSampler(RandomSampler(dataset), 
                      batch_size=batch_size, drop_last=False)
    return DataLoader(dataset, batch_sampler=bs, num_workers=8)
 
train = dataloader(from_numpy(X_train), from_numpy(Y_train))
test = dataloader(from_numpy(X_test), from_numpy(Y_test))
val = dataloader(from_numpy(X_val), from_numpy(Y_val))

In [47]:
import optuna


def objective(trial):
    # Sample parameters -- without declaring them in advance!
    n_hid = trial.suggest_int('n_hid', 1, 10)
    lam1 = trial.suggest_loguniform('lam1', 1e-8, 1e-3)
    lam2 = trial.suggest_loguniform('lam2', 1e-8, 1e-3)
    
    model = Bottleneck(9, 4, n_hid, lam1=lam1, lam2=lam2)
    
    logger = WandbLogger(name="00_intro_optimize", log_model=True, project="simple_mf")
    logger.log_hyperparams(model.hparams)

    # Note that we added early stopping
    trainer = pl.Trainer(max_epochs=3,
                         early_stop_callback=True,
                         logger=logger)    
    trainer.fit(model, train, val)
    test_results = trainer.test(model, test_dataloaders=[test])
    test_loss = test_results['test_loss']
    return test_loss

In [None]:
study = optuna.create_study()
study.optimize(objective, n_trials=10)

[I 2020-09-18 08:41:42,931] A new study created in memory with name: no-name-4b5586ff-dcbb-46d2-8416-88d80710da04





wandb: Waiting for W&B process to finish, PID 37619
wandb: Program ended successfully.
wandb:                                                                                
wandb: Find user logs for this run at: wandb/run-20200918_084002-2snostrd/logs/debug.log
wandb: Find internal logs for this run at: wandb/run-20200918_084002-2snostrd/logs/debug-internal.log
wandb: Run summary:
wandb:   global_step 2637
wandb:           mse 0.014328286983072758
wandb:   loss_reg_m1 8.590106517658569e-06
wandb:   loss_reg_m2 0.0054049198515713215
wandb:    train_loss 0.01974179595708847
wandb:         epoch 2
wandb:         _step 54
wandb:      _runtime 10
wandb:    _timestamp 1600443616
wandb:      val_loss 0.012885155156254768
wandb:     test_loss 0.01287017296999693
wandb: Run history:
wandb:   global_step ▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇████
wandb:           mse ██▆▆▅▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
wandb:   loss_reg_m1 ▁▁▂▂▄▄▅▆▇▇▇█████████████████████████████
wandb:   loss_reg_m2 ▁▁▂▂▄▄▅




GPU available: False, used: False
TPU available: False, using: 0 TPU cores

  | Name | Type   | Params
--------------------------------
0 | lin1 | Linear | 50    
1 | lin2 | Linear | 24    



Saving latest checkpoint..






--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'avg_test_loss': tensor(0.0129), 'test_loss': tensor(0.0129)}
--------------------------------------------------------------------------------



[I 2020-09-18 08:41:56,631] Trial 0 finished with value: 0.012920357286930084 and parameters: {'n_hid': 5, 'lam1': 7.092391053893507e-05, 'lam2': 1.629230674510399e-07}. Best is trial 0 with value: 0.012920357286930084.





wandb: Waiting for W&B process to finish, PID 38225
wandb: Program ended successfully.
wandb:                                                                                
wandb: Find user logs for this run at: wandb/run-20200918_084142-2sze3snq/logs/debug.log
wandb: Find internal logs for this run at: wandb/run-20200918_084142-2sze3snq/logs/debug-internal.log
wandb: Run summary:
wandb:   global_step 2637
wandb:           mse 0.011638514697551727
wandb:   loss_reg_m1 0.0008876376668922603
wandb:   loss_reg_m2 1.804168618946278e-06
wandb:    train_loss 0.012527956627309322
wandb:         epoch 2
wandb:         _step 54
wandb:      _runtime 10
wandb:    _timestamp 1600443716
wandb:      val_loss 0.012905266135931015
wandb:     test_loss 0.012920357286930084
wandb: Run history:
wandb:   global_step ▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇████
wandb:           mse ▇█▇▇▅▅▄▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
wandb:   loss_reg_m1 ▁▁▁▁▂▃▃▄▅▅▅▆▆▆▇▇▇▇▇▇▇▇██████████████████
wandb:   loss_reg_m2 ▁▁▁▂▂




GPU available: False, used: False
TPU available: False, using: 0 TPU cores

  | Name | Type   | Params
--------------------------------
0 | lin1 | Linear | 10    
1 | lin2 | Linear | 8     



Saving latest checkpoint..






--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'avg_test_loss': tensor(594.6829), 'test_loss': tensor(594.6829)}
--------------------------------------------------------------------------------



[I 2020-09-18 08:42:08,085] Trial 1 finished with value: 594.6829223632812 and parameters: {'n_hid': 1, 'lam1': 1.4346912298673874e-06, 'lam2': 0.000301701653594169}. Best is trial 0 with value: 0.012920357286930084.





wandb: Waiting for W&B process to finish, PID 38405
wandb: Program ended successfully.
wandb: ERROR Control-C detected -- Run data was not synced
