In [3]:
import import_ipynb
import Model

import lightning as L
import torchmetrics
import torch
import torch.nn.functional as F

# Lightning Library
- It simplifies deep learning workflow with PyTorch
- Less code & fewer bugs & better maintainability
- Excellent modularity
- LightningModule & Trainer are the only 2 APIs; the rest is organized for you
### LightningModule
- Easy to use any PyTorch model including LLM & Transformers etc.
- Easy to use any sized datasets
### Trainer
- Scale models to trillions of parameters across 1000s of GPUs
- Compress models with no accuracy loss using precision techniques
- Easy to define max_time & max_steps to save time & money
- Easy to reproduce

<img src="https://www.assemblyai.com/blog/content/images/2021/12/breakdown.png" width="1000" height="400">

# Examples

### GPU enabled training
- PyTorch
  ```python 
  model.to("cuda")
  features = features.to("cuda")
  labels = labels.to("cuda")
  ```
- Lightning
  ```python
  # Train with 1 gpu
  trainer = Trainer(accelerator="gpu", devices=1, ...)
  # Train with 4 gpu
  trainer = Trainer(accelerator="gpu", devices=4, ...)
  # Train with 1st&3rd gpu
  trainer = Trainer(accelerator="gpu", devices=[0, 2], ...)
  ```
### Training Strategies
- Data Parallel (multiple gpus, 1 machine)
    ```python
    Trainer(strategy='dp')
    ```
- Distributed Data Parallel (multiple gpus, many machines)
    ```python
    Trainer(strategy='ddp')
    ```

### Debugging
- Control running without training
  ```python
    # Runs only 1 training and 1 validation batch and the program ends
    Trainer(fast_dev_run=True) 
    # Runs 7 training and 7 validation batches and the program ends
    Trainer(fast_dev_run=7) 
    ```
- Use only this much of the data
  ```python
    # Use only 1% of the train & val set
    Trainer(overfit_batches=0.01)
    # Use 10 of the same batches
    Trainer(overfit_batches=10) 
    ```

### I want to use Lightning Multi-GPU & Mixed Precision Training but do not want to restructure my code into a LightningModule
- Lightning Fabric
    ```python
    # Define fabric
    fabric = Fabric(accelerator="gpu", devices=64, strategy="ddp")
    # Set up model & optimizer & dataloaders
    model, optimizer = fabric.setup(model, optimizer)
    dataloader = fabric.setup_dataloaders(dataloader)
    model.train()
    for epoch in range(num_epochs):
        for batch in dataloader:
            input, target = batch
            optimizer.zero_grad()
            output = model(input)
            loss = loss_fn(output, target)
            # loss.backward() -> fabric.backward(loss)
            fabric.backward(loss)
            optimizer.step()
    ```

# Speed Comparison of Training DistilBERT
<img src="https://sebastianraschka.com/images/blog/2023/pytorch-faster/benchmark-last.png" width="1000" height="400">

# Lightning Model

In [4]:
class LightningModel(L.LightningModule):
    """LightningModule wrapper that trains a classifier with cross-entropy and SGD.

    The wrapped ``model`` is called on the images of each batch; the loss is
    ``F.cross_entropy`` on its output and accuracy is tracked separately for
    the train, validation and test stages.

    Args:
        model: The network to train. It is called as ``model(images)``.
            Assumed to return per-class logits with classes on dim 1
            (``argmax(dim=1)`` is taken) -- TODO confirm against the model.
        lr: Learning rate passed to ``torch.optim.SGD``.
        num_classes: Number of target classes used to configure the
            multiclass accuracy metrics. Defaults to 10, the value that was
            previously hard-coded, so existing callers are unaffected.
    """

    def __init__(self, model, lr, num_classes=10):
        super().__init__()
        self.model = model
        self.lr = lr
        self.num_classes = num_classes

        # Record the constructor arguments as hyperparameters, except the
        # model object itself.
        self.save_hyperparameters(ignore=["model"])

        # One metric instance per stage so their accumulated state never mixes.
        self.train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)
        self.val_acc = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)
        self.test_acc = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)

    def forward(self, x):
        """Delegate the forward pass to the wrapped model."""
        return self.model(x)

    def _common_step(self, batch):
        """Compute loss and predictions shared by the train/val/test steps.

        Args:
            batch: A pair ``(images, true_labels)``.

        Returns:
            A tuple ``(loss, predicted_labels, true_labels)`` where ``loss`` is
            the cross-entropy of the model output against ``true_labels`` and
            ``predicted_labels`` is the argmax of the output over dim 1.
        """
        images, true_labels = batch
        logits = self(images)
        loss = F.cross_entropy(logits, true_labels)
        predicted_labels = torch.argmax(logits, dim=1)
        return loss, predicted_labels, true_labels

    def training_step(self, batch, batch_idx):
        """Run one training batch, log epoch-level loss/accuracy, return the loss."""
        loss, predicted_labels, true_labels = self._common_step(batch)
        self.log("train_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
        # Update the metric with this batch, then log the metric object itself.
        self.train_acc(predicted_labels, true_labels)
        self.log("train_acc", self.train_acc, prog_bar=True, on_step=False, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Run one validation batch and log epoch-level loss/accuracy."""
        loss, predicted_labels, true_labels = self._common_step(batch)
        self.log("val_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
        self.val_acc(predicted_labels, true_labels)
        self.log("val_acc", self.val_acc, prog_bar=True, on_step=False, on_epoch=True)

    def test_step(self, batch, batch_idx):
        """Run one test batch and log epoch-level loss/accuracy."""
        loss, predicted_labels, true_labels = self._common_step(batch)
        self.log("test_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
        self.test_acc(predicted_labels, true_labels)
        self.log("test_acc", self.test_acc, prog_bar=True, on_step=False, on_epoch=True)

    def configure_optimizers(self):
        """Return a plain SGD optimizer over all parameters using ``self.lr``."""
        optimizer = torch.optim.SGD(self.parameters(), lr=self.lr)
        return optimizer
    