### Model Checkpointing and Early Stopping
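Reference: Lightning AI *Deep Learning Fundamentals*, Unit 6.1 (Model Checkpointing and Early Stopping) — https://lightning.ai/courses/deep-learning-fundamentals/unit-6-overview-essential-deep-learning-tips-tricks/unit-6.1-model-checkpointing-and-early-stopping/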

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torch
from common_def import CustomDataModule, CustomDataset

In [None]:
# Instantiate the project data module and build its dataset splits so they
# can be inspected before training.
# NOTE(review): Lightning's standard stage names are 'fit', 'validate',
# 'test' and 'predict'; 'train' only works if CustomDataModule.setup ignores
# the stage or handles this value itself — confirm against common_def.
dm = CustomDataModule()
dm.setup(stage="train")

In [None]:
# Show the dtype of the training features (the last expression is the cell's
# displayed output). The label dtype can be checked the same way via
# dm.train_dataset.labels.dtype.
features_dtype = dm.train_dataset.features.dtype
features_dtype

#### Examine dataset

In [None]:
# Report how many examples each split holds.
print(f"""
Train size: {len(dm.train_dataset)}
Val size: {len(dm.val_dataset)}
Test size: {len(dm.test_dataset)}
""")

# Per-split label frequencies (counts per label value); train_labels_dist is
# reused by the zero-rule baseline cell.
train_labels_dist, val_labels_dist, test_labels_dist = (
    pd.Series(split.labels).value_counts()
    for split in (dm.train_dataset, dm.val_dataset, dm.test_dataset)
)

print(f'Train labels distribution\n{train_labels_dist}')
print(f'\nVal labels distribution\n{val_labels_dist}')
print(f'\nTest labels distribution\n{test_labels_dist}')

#### Zero-rule baseline

In [None]:
# Zero-rule baseline: accuracy (in percent) obtained by always predicting the
# most frequent label in the training split.
majority_count = max(train_labels_dist)
total_count = sum(train_labels_dist)
train_baseline_acc = 100 * majority_count / total_count
print(f'Train baseline accuracy: {train_baseline_acc:.2f}%')

#### Model Checkpointing

In [None]:
from lightning.pytorch.callbacks import ModelCheckpoint
from common_def import LightningModel, PyTorchMLP
import lightning
from lightning.pytorch.loggers import CSVLogger

In [None]:
# Checkpointing policy: keep only the single checkpoint with the highest
# validation accuracy (save_top_k=1, mode='max'), and additionally keep a
# "last" checkpoint from the most recent save (save_last=True).
# NOTE(review): assumes LightningModel logs a metric named 'val_acc' — confirm.
callbacks = [
    ModelCheckpoint(
        monitor='val_acc',
        mode='max',
        save_top_k=1,
        save_last=True,
    ),
]

In [None]:
# Seed torch before building the network so weight initialisation is
# reproducible across runs.
torch.manual_seed(12)

torch_model = PyTorchMLP(num_features=100, num_classes=2)
lightning_model = LightningModel(
    torch_model=torch_model,
    learning_rate=0.05,
    num_classes=2,
)

# Metrics go to lightning_logs/LightningModel/version_N/metrics.csv.
csv_logger = CSVLogger('lightning_logs', name='LightningModel')
trainer = lightning.Trainer(
    max_epochs=10,
    callbacks=callbacks,
    logger=csv_logger,
    deterministic=True,
)

trainer.fit(model=lightning_model, datamodule=dm)

#### Visualize metrics

In [None]:
from common_def import plot_csv_logger

# Plot the metrics of the run that was just trained. CSVLogger creates a new
# version_N directory for every run, so a hard-coded "version_5" path would
# point at a stale or nonexistent run; derive the directory from the
# trainer's logger instead.
plot_csv_logger(f'{trainer.logger.log_dir}/metrics.csv')

#### Evaluate the best and last checkpoints

In [None]:
# Test-split evaluation using the weights of the checkpoint that
# ModelCheckpoint ranked best (highest val_acc).
trainer.test(datamodule=dm, model=lightning_model, ckpt_path='best')

In [None]:
# Test-split evaluation using the weights from the "last" checkpoint
# (available because the callback was configured with save_last=True).
trainer.test(datamodule=dm, model=lightning_model, ckpt_path='last')

#### Load the best model checkpoint

In [None]:
# Reload the best checkpoint (highest val_acc) from disk into a fresh
# LightningModel wrapping a newly constructed PyTorchMLP.
best_ckpt_path = trainer.checkpoint_callback.best_model_path
print(f'Best checkpoint: {best_ckpt_path}')

model = PyTorchMLP(num_features=100, num_classes=2)
# load_from_checkpoint forwards extra keyword arguments to
# LightningModel.__init__, so they must use its parameter names: the training
# cell constructs it with torch_model=, not model=. learning_rate/num_classes
# repeat the training configuration so loading works whether or not they
# were stored as hyperparameters in the checkpoint.
best_model = LightningModel.load_from_checkpoint(
    checkpoint_path=best_ckpt_path,
    torch_model=model,
    learning_rate=0.05,
    num_classes=2,
)

best_model

In [None]:
# Evaluate the reloaded model on the test split using its in-memory weights
# (no ckpt_path is given).
trainer.test(best_model, datamodule=dm)