<a href="https://colab.research.google.com/github/cyyeh/time-series-playground/blob/main/pytorch-101/pytorch_lightning_101.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# PyTorch Lightning 101

This notebook adapts the [PyTorch quickstart example](https://github.com/cyyeh/time-series-playground/blob/main/pytorch-101/quickstart.ipynb) to PyTorch Lightning.

In [1]:
#@title Install Required Dependencies { vertical-output: true }

!pip install pytorch_lightning


In [2]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision.datasets import FashionMNIST
from torchvision.transforms import ToTensor
import pytorch_lightning as pl

## Working with data

In [3]:
# Download training data from open datasets.
training_data = FashionMNIST(
    root='data',
    train=True,
    download=True,
    transform=ToTensor(),
)

# Download test data from open datasets.
test_data = FashionMNIST(
    root='data',
    train=False,
    download=True,
    transform=ToTensor(),
)

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz
Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz
Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz
Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw

Processing...
Done!


In [4]:
batch_size = 64

# Create data loaders.
train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

for X, y in test_dataloader:
  print(f'Shape of X[N, C, H, W]: {X.shape}')
  print(f'Shape of y: {y.shape} {y.dtype}')
  break

Shape of X[N, C, H, W]: torch.Size([64, 1, 28, 28])
Shape of y: torch.Size([64]) torch.int64
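
With `batch_size = 64`, the number of batches per epoch follows directly from the dataset sizes. The sketch below works that out, assuming the standard FashionMNIST split of 60,000 training and 10,000 test images and the `DataLoader` default `drop_last=False`. The helper names `batches_per_epoch` and `last_batch_size` are hypothetical and not part of PyTorch.

```python
import math


def batches_per_epoch(num_samples, batch_size, drop_last=False):
    """Number of batches a loader yields for one pass over the data."""
    if drop_last:
        # An incomplete final batch is discarded.
        return num_samples // batch_size
    # An incomplete final batch is kept, so round up.
    return math.ceil(num_samples / batch_size)


def last_batch_size(num_samples, batch_size):
    """Size of the final batch when drop_last=False."""
    remainder = num_samples % batch_size
    return remainder if remainder else batch_size


batch_size = 64
train_size, test_size = 60_000, 10_000  # assumed FashionMNIST split sizes

print(batches_per_epoch(train_size, batch_size))  # 938
print(batches_per_epoch(test_size, batch_size))   # 157
print(last_batch_size(train_size, batch_size))    # 32
print(last_batch_size(test_size, batch_size))     # 16
```

The final batch of each split is therefore smaller than 64, which matters for any code that assumes a fixed batch dimension.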


## Creating Models

In [19]:
class NeuralNetwork(pl.LightningModule):
  def __init__(self):
    super().__init__()
    self.flatten = nn.Flatten()
    self.linear_relu_stack = nn.Sequential(
        nn.Linear(28*28, 512),
        nn.ReLU(),
        nn.Linear(512, 512),
        nn.ReLU(),
        # No ReLU after the last layer: CrossEntropyLoss expects raw logits.
        nn.Linear(512, 10)
    )

  def forward(self, x):
    x = self.flatten(x)
    logits = self.linear_relu_stack(x)
    return logits

  def configure_optimizers(self):
    return torch.optim.SGD(self.parameters(), lr=1e-3)

  def training_step(self, train_batch, batch_idx):
    X, y = train_batch
    pred = self(X)
    loss_fn = nn.CrossEntropyLoss()
    loss = loss_fn(pred, y)
    # Log the training loss so it shows up in the logger.
    self.log('train_loss', loss)
    return loss

  def validation_step(self, val_batch, batch_idx):
    X, y = val_batch
    pred = self(X)
    loss_fn = nn.CrossEntropyLoss()
    loss = loss_fn(pred, y)
    # Without logging or returning the loss, validation reports nothing.
    self.log('val_loss', loss, prog_bar=True)
    return loss
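
For comparison, here is a rough sketch of the plain PyTorch loop that `training_step`, `validation_step`, and `configure_optimizers` stand in for. It is not Lightning's actual implementation. To run without a download it uses hypothetical random data shaped like FashionMNIST batches and a smaller stand-in model; the helpers `run_epoch` and `evaluate` are illustrative names only.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical stand-in data: 256 random "images" with random labels,
# split into four batches of 64 (same shapes as the FashionMNIST batches).
X_all = torch.rand(256, 1, 28, 28)
y_all = torch.randint(0, 10, (256,))
batches = [(X_all[i:i + 64], y_all[i:i + 64]) for i in range(0, 256, 64)]

# Small stand-in model (an assumption, not the notebook's network).
model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28 * 28, 32),
    nn.ReLU(),
    nn.Linear(32, 10),
)

# Plays the role of configure_optimizers().
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()


def run_epoch(batches):
    """One training pass: the part a trainer wraps around training_step."""
    model.train()
    losses = []
    for X, y in batches:
        loss = loss_fn(model(X), y)  # the body of training_step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    return sum(losses) / len(losses)


def evaluate(batches):
    """One evaluation pass: the part wrapped around validation_step."""
    model.eval()
    with torch.no_grad():
        losses = [loss_fn(model(X), y).item() for X, y in batches]
    return sum(losses) / len(losses)


before = evaluate(batches)
for epoch in range(5):
    train_loss = run_epoch(batches)
after = evaluate(batches)
print(f"mean loss before: {before:.4f}, after: {after:.4f}")
```

In the Lightning version, the gradient bookkeeping, device placement, and the train/eval mode switches are handled by the trainer, so the module only has to say how a single batch turns into a loss.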

In [20]:
model = NeuralNetwork()

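# `gpus=1` is the Lightning 1.x spelling; Lightning 2.x uses
# accelerator='gpu', devices=1 instead.
trainer = pl.Trainer(gpus=1, max_epochs=5, check_val_every_n_epoch=1)
# The test split doubles as the validation set here.
trainer.fit(model, train_dataloader, test_dataloader)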

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name              | Type       | Params
-------------------------------------------------
0 | flatten           | Flatten    | 0     
1 | linear_relu_stack | Sequential | 669 K 
-------------------------------------------------
669 K     Trainable params
0         Non-trainable params
669 K     Total params
2.679     Total estimated model params size (MB)
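
The figures in the summary table can be checked by hand. The sketch below counts the weights and biases of the three `nn.Linear` layers and converts the total to a size, assuming 32-bit floats (4 bytes per parameter) and 1 MB = 10^6 bytes. The helper `linear_params` is a hypothetical name.

```python
def linear_params(in_features, out_features):
    """Weights plus biases of one fully connected layer."""
    return in_features * out_features + out_features


# (in_features, out_features) of the three Linear layers in the model.
layer_shapes = [(28 * 28, 512), (512, 512), (512, 10)]

per_layer = [linear_params(i, o) for i, o in layer_shapes]
total_params = sum(per_layer)

# Assumption: float32 parameters, 4 bytes each, 1 MB = 10**6 bytes.
size_mb = total_params * 4 / 1e6

print(per_layer)          # [401920, 262656, 5130]
print(total_params)       # 669706
print(round(size_mb, 3))  # 2.679
```

`Flatten` and `ReLU` have no parameters, so the whole count comes from the linear layers, which matches the `669 K` and `2.679` shown above.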




