# A simple notebook to verify that the hello-world project is properly configured

Goals:
- Train a simple neural network on the MNIST dataset.
- Log the training progress to Weights & Biases (W&B).

In [None]:
# Automatically reload edited project modules (e.g. inria.helloworld) before each cell runs,
# so library changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os
from datetime import datetime

import wandb
import torch
from pathlib import Path
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

from inria.helloworld.models import HelloWorldMlp
from inria.helloworld.datamodules import MnistDataModule
from inria.helloworld.trainer_args import TrainerArgs

Let's first check if we have a GPU.

In [None]:
# Report whether PyTorch can see a CUDA device.
gpu_available = torch.cuda.is_available()
print(f"GPU available: {gpu_available}")

In [None]:
# Project folders are resolved relative to the parent of the current working directory,
# so the notebook is expected to be launched from a sub-folder of the project (assumption).
PROJECT_ROOT = Path.cwd().parent
DATA_DIR = PROJECT_ROOT / "data"
MODELS_DIR = PROJECT_ROOT / "models"

In [None]:
# Sub-folders of MODELS_DIR: rolling checkpoints used to resume training, and the best model so far.
MODEL_CHECKPOINT_DIR = MODELS_DIR.joinpath("checkpoints")
BEST_MODEL_DIR = MODELS_DIR.joinpath("best_model")

In [None]:
# Build the MNIST datamodule rooted at DATA_DIR: prepare_data() presumably downloads the raw
# data if missing, and setup() builds the dataset splits used by the dataloaders.
mnist = MnistDataModule(DATA_DIR)
mnist.prepare_data()
mnist.setup()

# Grab one validation batch to log predictions on.
# NOTE(review): `samples` is not used anywhere else in this notebook.
val_batches = iter(mnist.val_dataloader())
samples = next(val_batches)

In [None]:
## Optional: use a particular wandb entity (user or team name) instead of your default one.
## Uncomment and edit the line below; the run-listing cell further down reads WANDB_ENTITY.
# os.environ['WANDB_ENTITY'] = "other-entity"

In [None]:
# W&B project that this notebook's runs are logged to (wandb.init) and listed from (wandb.Api).
WANDB_PROJECT: str = "inria-helloworld-mnist"

If you have followed the instructions in `README.md`, logging in to wandb should not require any extra setup.

In [None]:
# Authenticate with W&B. With the README setup this should not prompt (presumably the API key
# is already stored locally or provided via the environment — see README.md).
wandb.login()

If you are resuming training, first list the runs available in W&B so you can pick the one to resume.

In [None]:
# List this project's existing W&B runs (id, name, last timestamp, state, URL) so that one of
# them can be chosen as RESUME_RUN_ID below.
# WANDB_ENTITY is optional: the cell that sets it is commented out by default. Indexing
# os.environ directly would raise KeyError, which the ValueError handler does not catch,
# so fall back to the API's default entity instead.
try:
    api = wandb.Api()
    entity = os.environ.get("WANDB_ENTITY") or api.default_entity
    # `past_run` (not `run`) to avoid confusion with the active `wandb.run` used later.
    for past_run in api.runs(path=f"{entity}/{WANDB_PROJECT}"):
        when = (
            datetime.fromtimestamp(past_run.summary["_timestamp"]).strftime("%m/%d/%Y, %H:%M:%S")
            if "_timestamp" in past_run.summary
            else "--"
        )
        print(f"Run id: {past_run.id} '{past_run.name}' {when} ({past_run.state}): {past_run.url}")
except ValueError as e:
    # Presumably raised when the project does not exist yet (e.g. very first run); report it
    # and keep the notebook going.
    print(str(e))

In [None]:
# Id of the W&B run to resume. None starts a brand-new run; set it (e.g. in the next cell)
# to one of the run ids printed above to continue that run instead.
RESUME_RUN_ID = None

In [None]:
# Uncomment and replace with one of the run ids listed above to resume that run.
# RESUME_RUN_ID = '2em89whs'  # write here the run that you want to continue

In [None]:
# Start a new W&B run in WANDB_PROJECT. With resume="allow" and a RESUME_RUN_ID that matches
# an existing run, that run is continued instead. Local run files go under MODELS_DIR.
# Create MODELS_DIR first: on a fresh checkout it may not exist, and (depending on the wandb
# version) a non-writable/missing `dir` makes wandb warn and fall back to a temp directory.
MODELS_DIR.mkdir(parents=True, exist_ok=True)
# `str(...)` keeps older wandb versions, whose `dir` expects a string, happy.
wandb.init(dir=str(MODELS_DIR), project=WANDB_PROJECT, resume="allow", id=RESUME_RUN_ID)

In [None]:
# Always build a fresh module. When resuming, Lightning restores its weights, optimizer state
# and epoch counter from the checkpoint passed as `ckpt_path` to `trainer.fit`.
# A module is required there, and `torch.load` on a Lightning checkpoint yields a plain dict,
# not a LightningModule. This notebook also never uploads a "model.ckpt" to W&B, so restoring
# one from the run cannot work.
if wandb.run.resumed:
    print("Resuming W&B run; model state will be restored from the last local checkpoint if available.")
model = HelloWorldMlp(in_dims=(1, 28, 28))

In [None]:
# Keep only the single checkpoint with the lowest `valid/loss_epoch`, under BEST_MODEL_DIR.
best_models_checkpoint_callback = ModelCheckpoint(
    dirpath=BEST_MODEL_DIR,
    save_top_k=1,
    verbose=False,
    monitor="valid/loss_epoch",
    mode="min",
)

# Write `last.ckpt` under MODEL_CHECKPOINT_DIR at the end of each training epoch, so an
# interrupted run can later be resumed from it.
resume_checkpoint_callback = ModelCheckpoint(
    dirpath=MODEL_CHECKPOINT_DIR,
    save_last=True,
    save_on_train_epoch_end=True,
)

In [None]:
# Stop training once `valid/loss_epoch` has not decreased by at least 0.01
# for 3 consecutive validation checks.
early_stop_callback = EarlyStopping(
    monitor="valid/loss_epoch",
    min_delta=0.01,
    patience=3,
    verbose=False,
    mode="min",
)

In [None]:
# Lightning logger for W&B, writing local files under MODELS_DIR. Because wandb.init() has
# already been called, it presumably attaches to that active run rather than starting a new
# one — confirm against the installed Lightning version.
wandb_logger = WandbLogger(save_dir=MODELS_DIR)

In [None]:
# Trainer configuration: a generous epoch cap (early stopping is expected to end training
# first), logging every 10 steps to W&B, and the checkpoint / early-stopping callbacks above.
training_callbacks = [
    best_models_checkpoint_callback,  # first ModelCheckpoint: tracks the best validation loss
    resume_checkpoint_callback,
    early_stop_callback,
]
args = TrainerArgs(
    max_epochs=1000,
    log_every_n_steps=10,
    logger=wandb_logger,
    callbacks=training_callbacks,
)

In [None]:
# Display the trainer configuration as a quick sanity check before training.
args

In [None]:
# Build the Lightning Trainer from the TrainerArgs fields (assumes `to_dict()` yields valid
# `pl.Trainer` keyword arguments — defined in inria.helloworld.trainer_args).
trainer = pl.Trainer(**args.to_dict())  # passing training args

In [None]:
# Resume from the rolling `last.ckpt` when continuing a W&B run that has one locally;
# otherwise train from scratch. `Trainer.fit` always requires the module (and data): `ckpt_path`
# only restores their state (weights, optimizer, epoch counter), so calling
# `trainer.fit(ckpt_path=...)` alone fails with a TypeError for the missing `model`.
last_checkpoint = MODEL_CHECKPOINT_DIR / "last.ckpt"
if wandb.run.resumed and last_checkpoint.exists():
    print("Resuming training from last checkpoint.")
    resume_ckpt_path = str(last_checkpoint)
else:
    print("Starting training from scratch.")
    resume_ckpt_path = None  # Lightning's default: nothing to restore
trainer.fit(model, datamodule=mnist, ckpt_path=resume_ckpt_path)

In [None]:
# Evaluate on the test split using the best checkpoint. With several ModelCheckpoint callbacks,
# Lightning resolves ckpt_path="best" from the first one (best_models_checkpoint_callback here)
# and warns about it — confirm for the installed version.
trainer.test(datamodule=mnist, ckpt_path="best")

In [None]:
# Mark the W&B run as finished so any pending logs/files are flushed before the kernel moves on.
wandb.finish()