# Starting notebook

Here we train a simple MLP on MNIST.

In [1]:
# Some useful modules for notebooks
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
# Test if the package is installed correctly
from packagename import print_version
print_version()

Parameters...

In [4]:
seed = 0 # for reproducibility
batch_size = 64
num_epochs = 5
learning_rate = 1e-3
hidden_dim = 256
num_classes = 10
n_layers = 3
input_dim = 28*28

## Build the Neural Network

In [5]:
from packagename.model import MLP 

In [6]:
net = MLP(input_dim=input_dim, hidden_dim=hidden_dim, output_dim=num_classes, n_layers=n_layers, use_softmax=True)

## Prepare Data Loader

In [7]:
from torch.utils.data import DataLoader
from packagename.dataset import load_mnist
# load data
train, val, test = load_mnist()

train_loader, val_loader, test_loader = DataLoader(train, batch_size=batch_size), DataLoader(val, batch_size=batch_size), DataLoader(test, batch_size=batch_size)


## Lightening Module class

In [8]:
from packagename.lightning import LightningClassifier

model = LightningClassifier(net, lr_rate=learning_rate)

2023-05-31 13:20:07 - torch.distributed.nn.jit.instantiator - INFO - Created a temporary directory at /var/folders/4m/pvkyyf611tz1t4j2ryl7ttv40000gn/T/tmpjtesngxf
2023-05-31 13:20:07 - torch.distributed.nn.jit.instantiator - INFO - Writing /var/folders/4m/pvkyyf611tz1t4j2ryl7ttv40000gn/T/tmpjtesngxf/_remote_module_non_scriptable.py
  rank_zero_warn(


## Train the model

In [11]:
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor, ModelCheckpoint
from pathlib import Path
from pytorch_lightning.loggers import WandbLogger
# output directory
from packagename.conf import OUTPUTDIR


name = 'mnist-mlp'

# 1. Wandb Logger
wandb_logger = WandbLogger() # add project='projectname' to log to a specific project

# 2. Learning Rate Logger
lr_logger = LearningRateMonitor()
# 3. Set Early Stopping
early_stopping = EarlyStopping('val_loss', mode='min', patience=5)
# 4. saves checkpoints to 'model_path' whenever 'val_loss' has a new min
checkpoint_callback = ModelCheckpoint(dirpath=OUTPUTDIR / Path(name), filename='{name}_{epoch}-{val_loss:.2f}',
                                      monitor='val_loss', mode='min', save_top_k=5)

(OUTPUTDIR/Path(name)).mkdir(parents=True, exist_ok=True)
# Define Trainer
trainer = pl.Trainer(max_epochs=5, logger=wandb_logger, callbacks=[lr_logger, early_stopping, checkpoint_callback], 
                     default_root_dir=OUTPUTDIR/Path(name)) #gpus=1

2023-05-31 13:20:07 - pytorch_lightning.utilities.rank_zero - INFO - GPU available: True (mps), used: True
2023-05-31 13:20:07 - pytorch_lightning.utilities.rank_zero - INFO - TPU available: False, using: 0 TPU cores
2023-05-31 13:20:07 - pytorch_lightning.utilities.rank_zero - INFO - IPU available: False, using: 0 IPUs
2023-05-31 13:20:07 - pytorch_lightning.utilities.rank_zero - INFO - HPU available: False, using: 0 HPUs


In [12]:
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

2023-05-31 13:20:07 - pytorch_lightning.callbacks.model_summary - INFO - 
  | Name | Type | Params
------------------------------
0 | net  | MLP  | 335 K 
------------------------------
335 K     Trainable params
0         Non-trainable params
335 K     Total params
1.340     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

2023-05-31 13:24:39 - pytorch_lightning.utilities.rank_zero - INFO - `Trainer.fit` stopped: `max_epochs=5` reached.


In [16]:
path_last = OUTPUTDIR/Path(name)/'last.ckpt'
trainer.save_checkpoint(path_last)
print(path_last)

/Users/nati/SDSC/codes/code-template-lightning/outputs/mnist-mlp/last.ckpt


In [17]:
path_best = Path(checkpoint_callback.best_model_path)
print(path_best)

/Users/nati/SDSC/codes/code-template-lightning/outputs/mnist-mlp/name=0_epoch=4-val_loss=-0.96.ckpt


In [19]:
from packagename.utils import load_model
model = load_model(LightningClassifier, path_best) # Load best model
model = load_model(LightningClassifier, path_last) # Load last model