# Starter notebook

Here we train a simple multilayer perceptron (MLP) classifier on MNIST with PyTorch Lightning, logging metrics to Weights & Biases.

In [None]:
# Notebook setup (IPython magics)
# autoreload: re-import edited modules (e.g. `packagename`) before each cell runs,
# so changes to the package show up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline (no cell shown here plots, so this is optional).
%matplotlib inline

In [None]:
# Smoke test: if this import and call succeed, `packagename` is importable from this kernel.
# (print_version presumably prints the installed package version — see packagename.)
from packagename import print_version
print_version()

## Parameters

In [None]:
import random

import torch

# --- Training hyperparameters ---
batch_size = 64
num_epochs = 5          # number of training epochs
learning_rate = 1e-3

# --- MLP architecture ---
hidden_dim = 256
num_classes = 10        # MNIST has 10 digit classes (0-9)
n_layers = 3
input_dim = 28 * 28     # 28x28-pixel MNIST images, fed to the MLP as a flat vector

# --- Reproducibility ---
# Weight initialisation and DataLoader shuffling draw from torch's RNG; seed it (and
# Python's `random`) so a Restart & Run All reproduces the same run.
seed = 42
random.seed(seed)
torch.manual_seed(seed)

## Build the Neural Network

In [None]:
from packagename.model import MLP 

In [None]:
# Classifier network mapping input_dim features to num_classes outputs
# (presumably n_layers layers of width hidden_dim — see packagename.model.MLP).
net = MLP(
    input_dim=input_dim,
    hidden_dim=hidden_dim,
    output_dim=num_classes,
    n_layers=n_layers,
    # NOTE(review): if LightningClassifier uses a cross-entropy loss on raw logits,
    # a softmax here would be applied twice — confirm against packagename.lightning.
    use_softmax=True,
)

## Prepare Data Loader

In [None]:
from torch.utils.data import DataLoader
from packagename.dataset import load_mnist
# load data: train / validation / test splits
train, val, test = load_mnist()

# Shuffle only the training set. With shuffle=True the DataLoader draws a fresh random
# order every epoch, so mini-batches are not identical across epochs. Validation and
# test order does not change aggregate metrics, so those keep a fixed, reproducible order.
train_loader = DataLoader(train, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val, batch_size=batch_size)
test_loader = DataLoader(test, batch_size=batch_size)


## Lightning Module

In [None]:
from packagename.lightning import LightningClassifier

# Wrap the bare network in the Lightning module that the Trainer below fits;
# the learning rate is handed over through the `lr_rate` keyword.
model = LightningClassifier(
    net,
    lr_rate=learning_rate,
)

## Train the model

In [None]:
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor, ModelCheckpoint
from pathlib import Path
from pytorch_lightning.loggers import WandbLogger
import pytorch_lightning as pl
# output directory
from packagename.conf import OUTPUTDIR


name = 'mnist-mlp'

# Every artifact of this run (checkpoints, Trainer default logs) lives in run_dir.
# Path(...) also accepts OUTPUTDIR given as a plain string. Create it before any
# callback or the Trainer refers to it.
run_dir = Path(OUTPUTDIR) / name
run_dir.mkdir(parents=True, exist_ok=True)

# 1. Wandb Logger
wandb_logger = WandbLogger()  # add project='projectname' to log to a specific project

# 2. Learning Rate Logger
lr_logger = LearningRateMonitor()

# 3. Set Early Stopping: stop once 'val_loss' has not improved for `patience` epochs.
# NOTE: with num_epochs=5 and patience=5 this can effectively never trigger;
# lower the patience or raise num_epochs if early stopping is wanted.
early_stopping = EarlyStopping('val_loss', mode='min', patience=5)

# 4. Keep the 5 checkpoints with the lowest 'val_loss' in run_dir.
# ModelCheckpoint fills '{...}' placeholders only from logged metrics ('{epoch}',
# '{val_loss:.2f}'), so the run name must be concatenated in. A literal '{name}'
# would be treated as a (missing) metric rather than the Python variable.
checkpoint_callback = ModelCheckpoint(
    dirpath=run_dir,
    filename=name + '_{epoch}-{val_loss:.2f}',
    monitor='val_loss',
    mode='min',
    save_top_k=5,
)

# Define Trainer (epoch budget comes from the parameters cell).
# To train on a GPU add accelerator='gpu', devices=1 (older Lightning versions: gpus=1).
trainer = pl.Trainer(
    max_epochs=num_epochs,
    logger=wandb_logger,
    callbacks=[lr_logger, early_stopping, checkpoint_callback],
    default_root_dir=run_dir,
)

In [None]:
# Fit on the training loader and validate on val_loader; the callbacks monitor 'val_loss'
# (assumes LightningClassifier logs a metric under that name — confirm).
trainer.fit(
    model,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)

In [None]:
# Persist the end-of-training state as <OUTPUTDIR>/<name>/last.ckpt, in the same
# directory ModelCheckpoint writes its top-k files to, and report where it went.
path_last = Path(OUTPUTDIR, name, 'last.ckpt')
trainer.save_checkpoint(path_last)
print(path_last)

In [None]:
# Path of the checkpoint with the lowest 'val_loss' recorded by checkpoint_callback.
# best_model_path is an empty string when no checkpoint was saved (e.g. fit not run, or
# 'val_loss' never logged). Path('') silently becomes Path('.'), so fail loudly here
# instead of with a confusing error inside load_model later.
if not checkpoint_callback.best_model_path:
    raise RuntimeError(
        "ModelCheckpoint has no best checkpoint; run trainer.fit first and make sure 'val_loss' is logged."
    )
path_best = Path(checkpoint_callback.best_model_path)
print(path_best)

In [None]:
from packagename.utils import load_model

# Load both checkpoints under distinct names. Previously the best model was loaded
# and then immediately overwritten by the last one, so it was never usable.
model_best = load_model(LightningClassifier, path_best)  # lowest val_loss seen during training
model_last = load_model(LightningClassifier, path_last)  # state at the end of training
model = model_last  # `model` stays bound to the last model, as before