### Debugging Neural Network
https://lightning.ai/courses/deep-learning-fundamentals/unit-6-overview-essential-deep-learning-tips-tricks/6.8-debugging-deep-neural-networks/

In [14]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torch
from common_def import CustomDataModule, CustomDataset, PyTorchMLP
import lightning
import torchmetrics
import torch.nn.functional as F

In [24]:
class LightningModel(lightning.LightningModule):
    def __init__(self, num_classes=None, num_features=None, hidden_units=None, learning_rate=None):
        super().__init__()
        self.learning_rate = learning_rate
        self.torch_model = PyTorchMLP(num_features=num_features, hidden_units=hidden_units, num_classes=num_classes)

        # save hyperparameters (but skip the model parameters)
        self.save_hyperparameters(ignore=['model'])

        # for model summary; 32 is batch_size?
        self.example_input_array = torch.Tensor(32, num_features)

        self.train_acc = torchmetrics.Accuracy(task='multiclass', num_classes=num_classes)
        self.val_acc = torchmetrics.Accuracy(task='multiclass', num_classes=num_classes)

    def forward(self, x):
        return self.torch_model(x)
    
    def _process_step(self, batch, batch_idx):
        batch_features, batch_labels = batch
        logits = self.forward(batch_features)
        loss = F.cross_entropy(logits, batch_labels)
        predictions = torch.argmax(logits, dim=1)
        return loss, predictions, batch_labels
    
    def training_step(self, batch, batch_idx):
        loss, predictions, labels = self._process_step(batch, batch_idx)
        self.log("train_loss", loss)
        self.train_acc(predictions, labels)
        self.log('train_acc', self.train_acc, prog_bar=True, on_epoch=True, on_step=False)
        return loss
    
    def validation_step(self, batch, batch_idx):
        loss, predictions, labels = self._process_step(batch, batch_idx)
        self.log("val_loss", loss)
        self.val_acc(predictions, labels)
        self.log('val_acc', self.val_acc, prog_bar=True)

    def test_step(self, batch, batch_idx):
        loss, predictions, labels = self._process_step(batch, batch_idx)
        self.log("test_loss", loss)
        self.val_acc(predictions, labels)
        self.log('test_acc', self.val_acc, prog_bar=True)

    
    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr=self.learning_rate)
        # return optimizer # without scheduler

        # use Learning Rate scheduler
        # for every 10 epochs, reduce the learning rate by a factor of 0.5
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)
        return [optimizer], [scheduler]

In [16]:
dm = CustomDataModule()
dm.setup(stage='train')

#### Stop early to verify

In [25]:
torch.manual_seed(12)

lightning_model = LightningModel(num_features=100, hidden_units=[50, 25], num_classes=2, learning_rate=0.05)

trainer = lightning.Trainer(
    max_epochs=10, 
    fast_dev_run=5, # stop after 5 steps
    deterministic=True
)

trainer.fit(model=lightning_model, datamodule=dm)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Running in `fast_dev_run` mode: will run the requested loop using 5 batch(es). Logging and checkpointing is suppressed.



  | Name        | Type               | Params | In sizes  | Out sizes
---------------------------------------------------------------------------
0 | torch_model | PyTorchMLP    | 6.4 K  | [32, 100] | [32, 2]  
1 | train_acc   | MulticlassAccuracy | 0      | ?         | ?        
2 | val_acc     | MulticlassAccuracy | 0      | ?         | ?        
---------------------------------------------------------------------------
6.4 K     Trainable params
0         Non-trainable params
6.4 K     Total params
0.026     Total estimated model params size (MB)


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_steps=5` reached.


#### Model Summary

In [18]:
from lightning.pytorch.utilities.model_summary import ModelSummary

In [26]:
summary = ModelSummary(lightning_model, max_depth=-1)
print(summary)

  | Name                 | Type               | Params | In sizes  | Out sizes
------------------------------------------------------------------------------------
0 | torch_model          | PyTorchMLP    | 6.4 K  | [32, 100] | [32, 2]  
1 | torch_model.layers   | Sequential         | 6.4 K  | [32, 100] | [32, 2]  
2 | torch_model.layers.0 | Linear             | 5.0 K  | [32, 100] | [32, 50] 
3 | torch_model.layers.1 | ReLU               | 0      | [32, 50]  | [32, 50] 
4 | torch_model.layers.2 | Linear             | 1.3 K  | [32, 50]  | [32, 25] 
5 | torch_model.layers.3 | ReLU               | 0      | [32, 25]  | [32, 25] 
6 | torch_model.layers.4 | Linear             | 52     | [32, 25]  | [32, 2]  
7 | train_acc            | MulticlassAccuracy | 0      | ?         | ?        
8 | val_acc              | MulticlassAccuracy | 0      | ?         | ?        
------------------------------------------------------------------------------------
6.4 K     Trainable params
0         Non-trai