# Rapid prototyping with PyTorch Lightning

- toc: true 
- hide: false
- branch: master
- search_exclude: false
- badges: true
- comments: true
- categories: [pytorch, wandb, pytorch-lightning]

Building a quick prototype of a PyTorch model can be tedious. Remembering all the `.to(device)` calls, wrapping evaluation in `with torch.no_grad():`, and switching between `model.eval()` and `model.train()` gets in the way when the real focus is on writing an algorithm that works out of the box, with the correct tensor dimensions. Fortunately, `pytorch-lightning` makes this [a lot easier](https://pytorch-lightning.readthedocs.io/en/stable/rapid_prototyping_templates.html). I have modified its basic prototyping template to include:
 - `wandb` logging: my personal favorite logging tool at the moment
 - `einops`: makes tensor operations easy to read and execute

More on the wandb integration can be found [here](https://wandb.ai/site/articles/pytorch-lightning-with-weights-biases) and [here](https://wandb.ai/cayush/pytorchlightning/reports/Use-Pytorch-Lightning-with-Weights-Biases--Vmlldzo2NjQ1Mw).

```python
%%capture
# hide
! pip install pytorch-lightning
! pip install pytorch-lightning-bolts
! pip install einops
! pip install wandb
```

```python
# hide
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import transforms

import pytorch_lightning as pl
from pytorch_lightning.metrics.functional import accuracy
from pl_bolts.datasets import DummyDataset

# from einops import rearrange, reduce, repeat
import einops
from einops.layers.torch import Rearrange, Reduce
```

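```python
# Initialize wandb. Use the same project name as the WandbLogger below:
# if a run already exists, the logger reuses it instead of starting a new one.
import wandb
wandb.init(project='rapidprotyping-torchlightning')
```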

```python
from pytorch_lightning.loggers import WandbLogger
wandb_logger = WandbLogger(project='rapidprotyping-torchlightning')
```

---
## Data

```python
data = DummyDataset((1, 28, 28), (1,), num_samples=1000)
data.shapes, data.num_samples
```

```
(((1, 28, 28), (1,)), 1000)
```

```python
dl = DataLoader(data, batch_size=10)
x, y = next(iter(dl))
x.size(), y.size()
```

```
(torch.Size([10, 1, 28, 28]), torch.Size([10, 1]))
```
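If `pl_bolts` is not available, a small stand-in dataset is easy to write. The class below is a hypothetical `NoiseDataset`, not part of any library: it mimics how `DummyDataset` is used here (one random tensor per requested shape) and assumes uniform noise from `torch.rand`, which may differ from what `DummyDataset` actually samples.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class NoiseDataset(Dataset):
    """Hypothetical stand-in for DummyDataset: each item is a list of random
    tensors, one per requested shape (uniform on [0, 1) by assumption)."""

    def __init__(self, *shapes, num_samples=10000):
        self.shapes = shapes
        self.num_samples = num_samples

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        return [torch.rand(*shape) for shape in self.shapes]


data = NoiseDataset((1, 28, 28), (1,), num_samples=1000)
# The default collate function stacks each position of the list separately,
# so a batch unpacks into (inputs, targets) just like above
x, y = next(iter(DataLoader(data, batch_size=10)))
print(x.size(), y.size())  # torch.Size([10, 1, 28, 28]) torch.Size([10, 1])
```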

```python
x.view(x.size(0), -1).shape, einops.rearrange(x, 'b c h w -> b (c h w)').shape, x.shape
```

```
(torch.Size([10, 784]), torch.Size([10, 784]), torch.Size([10, 1, 28, 28]))
```

```python
train = DummyDataset((1, 28, 28), (1,))
train = DataLoader(train, batch_size=32)

val = DummyDataset((1, 28, 28), (1,))
val = DataLoader(val, batch_size=32)

test = DummyDataset((1, 28, 28), (1,))
test = DataLoader(test, batch_size=32)
```
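`random_split` is imported above but never used: the three loaders come from three independent dummy datasets. With real data, one dataset would usually be split instead. A sketch, using a `TensorDataset` of uniform noise as a stand-in and an arbitrarily chosen 80/10/10 split:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split

# Stand-in for DummyDataset((1, 28, 28), (1,)): 1000 noise inputs and targets
full = TensorDataset(torch.rand(1000, 1, 28, 28), torch.rand(1000, 1))

# A seeded generator makes the split reproducible across runs
train_set, val_set, test_set = random_split(
    full, [800, 100, 100], generator=torch.Generator().manual_seed(42)
)

train = DataLoader(train_set, batch_size=32, shuffle=True)  # shuffle only the training data
val = DataLoader(val_set, batch_size=32)
test = DataLoader(test_set, batch_size=32)
```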

---

## Model

```python
class LitAutoEncoder(pl.LightningModule):

    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, 128),
            nn.ReLU(),
            nn.Linear(128, 3)
            )
        self.decoder = nn.Sequential(
            nn.Linear(3, 128),
            nn.ReLU(),
            nn.Linear(128, 28 * 28)
            )

    def training_step(self, batch, batch_idx):
        # To add the model graph to TensorBoard (requires a TensorBoardLogger,
        # not the WandbLogger used in this post):
        # tensorboard = self.logger.experiment
        x, y = batch
        # Equivalent to: x = x.view(x.size(0), -1)
        x = einops.rearrange(x, 'b c h w -> b (c h w)')
        # tensorboard.add_graph(self.encoder, x[0])
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        x = einops.rearrange(x, 'b c h w -> b (c h w)')
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        self.log('val_loss', loss)

    def test_step(self, batch, batch_idx):
        x, y = batch
        x = einops.rearrange(x, 'b c h w -> b (c h w)')
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        self.log('test_loss', loss)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer
```

---
## Train
NOTE: in Colab, set `progress_bar_refresh_rate` to a fairly high value (20 below); otherwise the rapid tqdm updates can freeze the notebook.

```python
# init model
ae = LitAutoEncoder()

# Initialize a trainer
trainer = pl.Trainer(max_epochs=3,
                     progress_bar_refresh_rate=20,
                     # early_stop_callback=True,
                     logger=wandb_logger)

# Train the model ⚡
trainer.fit(ae, train, val)
```

```
GPU available: False, used: False
TPU available: None, using: 0 TPU cores

  | Name    | Type       | Params
---------------------------------------
0 | encoder | Sequential | 100 K
1 | decoder | Sequential | 101 K
---------------------------------------
202 K     Trainable params
0         Non-trainable params
202 K     Total params
```

---
## Test

```python
trainer.test(test_dataloaders=test)
```

```
--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_loss': tensor(0.0834)}
--------------------------------------------------------------------------------

[{'test_loss': 0.08343494683504105}]
```
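The test loss of about 0.0834 is close to 1/12 ≈ 0.0833, the variance of a uniform distribution on [0, 1). Assuming `DummyDataset` fills its tensors with uniform noise (consistent with this number, but worth confirming in the pl_bolts source), the inputs have no structure a 3-dimensional bottleneck could capture, so the best the autoencoder can do under MSE is predict the mean, 0.5, for every pixel. The result therefore suggests the pipeline runs end to end, not that the model learned a useful encoding. A quick numerical check:

```python
import torch

torch.manual_seed(0)
x = torch.rand(1000, 28 * 28)  # uniform noise on [0, 1), like the assumed dummy inputs

# Under MSE the best constant prediction is the mean (0.5); the remaining loss is the variance
mse_of_mean = torch.mean((x - 0.5) ** 2).item()
print(mse_of_mean, 1 / 12)  # both close to 0.083
```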

---
## Save checkpoint

```python
trainer.save_checkpoint('model_prototype.pth')
wandb.save('model_prototype.pth')

wandb.finish()
```

---
## Visualize

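```python
# Start TensorBoard. lightning_logs/ is only populated by a TensorBoardLogger
# (Lightning's default); with the WandbLogger above, metrics go to the wandb dashboard.
%load_ext tensorboard
%tensorboard --logdir lightning_logs/
```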

---
## Observations
Do your analysis and notes here!