# Let's now train a model on the CIFAR-10 dataset

In [None]:
# Reload all imported modules (e.g. the inria.helloworld package) before each cell runs,
# so edits to the project code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import logging

# Emit INFO-and-above log records, each formatted as "LEVEL:message".
logging.basicConfig(level=logging.INFO, format="%(levelname)s:%(message)s")

In [None]:
import os
from datetime import datetime

import wandb
import torch
from pathlib import Path
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, TQDMProgressBar

from inria.helloworld.trainer_args import TrainerArgs
from inria.helloworld.datamodules import Cifar10DataModule
from inria.helloworld.models import HelloWorldResnet

Let's first check if we have a GPU.

In [None]:
# Report whether PyTorch can see a CUDA device.
_cuda_available = torch.cuda.is_available()
print(f"GPU available: {_cuda_available}")

In [None]:
# Data and model artifacts live in sibling folders of the current working directory
# (presumably the repository root when the notebook runs from a subfolder — confirm).
_PROJECT_ROOT = Path.cwd().parent
DATA_DIR = _PROJECT_ROOT / "data"
MODELS_DIR = _PROJECT_ROOT / "models"

In [None]:
# Checkpoint destinations, both nested under MODELS_DIR.
MODEL_CHECKPOINT_DIR = MODELS_DIR.joinpath("checkpoints")  # per-epoch/"last" checkpoints for resuming
BEST_MODEL_DIR = MODELS_DIR.joinpath("best_model")  # best checkpoint by validation loss

In [None]:
# Weights & Biases project that this notebook's runs are listed from and logged to.
WANDB_PROJECT: str = "inria-helloworld-cifar"

In [None]:
# Authenticate with Weights & Biases before querying existing runs or starting a new one.
wandb.login()

If we are resuming training, we first list the runs available in WandB so we can pick the one to resume.

In [None]:
# List this project's existing runs (id, name, time of last log, state, URL) so one
# can be chosen as RESUME_RUN_ID.
# The entity is taken from the WANDB_ENTITY environment variable when it is set;
# otherwise only the project name is passed and the wandb API falls back to its
# default entity, so the cell also works when the variable is unset.
_wandb_entity = os.environ.get("WANDB_ENTITY")
_runs_path = f"{_wandb_entity}/{WANDB_PROJECT}" if _wandb_entity else WANDB_PROJECT
try:
    for run in wandb.Api().runs(path=_runs_path):
        # "_timestamp" in the run summary is a Unix timestamp; it may be missing
        # (e.g. for runs that never logged anything), in which case show "--".
        when = (
            datetime.fromtimestamp(run.summary["_timestamp"]).strftime("%m/%d/%Y, %H:%M:%S")
            if "_timestamp" in run.summary
            else "--"
        )
        print(f"Run id: {run.id} '{run.name}' {when} ({run.state}): {run.url}")
except ValueError as e:
    # NOTE(review): presumably raised by wandb when the project does not exist yet
    # (e.g. before the first run) — confirm.
    print(str(e))

In [None]:
# Id of the WandB run to resume; None lets wandb.init start a new run.
RESUME_RUN_ID: "str | None" = None

In [None]:
# RESUME_RUN_ID = '2em89whs'  # uncomment and set to the id of the run you want to resume

In [None]:
# Start a WandB run in WANDB_PROJECT, or resume the run named by RESUME_RUN_ID when
# it exists; the run's local files are written under MODELS_DIR.
wandb.init(
    project=WANDB_PROJECT,
    id=RESUME_RUN_ID,
    resume="allow",
    dir=MODELS_DIR,
)

In [None]:
# Build the model: fresh weights for a new run, or weights from the run's
# "model.ckpt" file when the WandB run was resumed.
if wandb.run.resumed:
    print(f"Resuming training from run {wandb.run.id}.")
    # wandb.restore downloads the file from the run and returns it opened for
    # reading, or None when the run has no such file; only its local path is needed,
    # so close it right away instead of leaking the handle.
    restored = wandb.restore("model.ckpt")
    if restored is None:
        raise FileNotFoundError(f"'model.ckpt' was not found in WandB run {wandb.run.id}")
    with restored:
        ckpt_file = restored.name
    # torch.load on a Lightning checkpoint returns the checkpoint dict, not a model;
    # load_from_checkpoint rebuilds the module and loads its state_dict.
    # NOTE(review): assumes model.ckpt is a Lightning checkpoint (as written by
    # ModelCheckpoint) and that it is uploaded to the run elsewhere — nothing in this
    # notebook saves it. Only weights are restored here, not optimizer/epoch state.
    model = HelloWorldResnet.load_from_checkpoint(ckpt_file)
else:
    model = HelloWorldResnet()

In [None]:
# Keep only the single best checkpoint, ranked by the lowest epoch-level validation loss.
best_models_checkpoint_callback = ModelCheckpoint(
    dirpath=BEST_MODEL_DIR,
    monitor="valid/loss_epoch",
    mode="min",
    save_top_k=1,
    verbose=False,
)
# Additionally write a "last" checkpoint at the end of every training epoch,
# so an interrupted run can be picked up again.
resume_checkpoint_callback = ModelCheckpoint(
    dirpath=MODEL_CHECKPOINT_DIR,
    save_on_train_epoch_end=True,
    save_last=True,
)

In [None]:
# Stop training once the epoch-level validation loss has not decreased by more than
# 1e-7 for 30 consecutive validation checks.
# NOTE(review): patience (30) equals max_epochs (30) in TrainerArgs, so this
# presumably can never trigger; lower the patience or raise max_epochs if early
# stopping is actually wanted.
early_stop_callback = EarlyStopping(
    monitor="valid/loss_epoch",
    mode="min",
    min_delta=1e-7,
    patience=30,
    verbose=False,
)

In [None]:
# Lightning logger that sends training metrics to Weights & Biases; local files go under MODELS_DIR.
# NOTE(review): no project is passed here — presumably it attaches to the run already
# started by wandb.init, so metrics land in WANDB_PROJECT; confirm.
wandb_logger = WandbLogger(save_dir=MODELS_DIR)

In [None]:
# Callbacks handed to the Lightning Trainer: best-model checkpointing, per-epoch
# checkpoints for resuming, early stopping, and a progress bar refreshed every 10 batches.
# Keep best_models_checkpoint_callback first: it is the first ModelCheckpoint, which
# Lightning uses to resolve ckpt_path="best".
trainer_callbacks = [
    best_models_checkpoint_callback,
    resume_checkpoint_callback,
    early_stop_callback,
    TQDMProgressBar(refresh_rate=10),
]
args = TrainerArgs(
    logger=wandb_logger,
    callbacks=trainer_callbacks,
    max_epochs=30,
    log_every_n_steps=10,
)

In [None]:
# Display the trainer arguments for inspection.
args

In [None]:
# CIFAR-10 Lightning data module rooted at DATA_DIR (presumably downloads the dataset
# there if missing — confirm in Cifar10DataModule).
cifar = Cifar10DataModule(DATA_DIR)

In [None]:
# Epoch budget applied just before building the Trainer (currently the same value
# already passed to TrainerArgs).
args.max_epochs = 30

In [None]:
# Build the Lightning Trainer from the collected arguments.
trainer = pl.Trainer(**args.to_dict())

In [None]:
# Train the model on the CIFAR-10 data module.
# NOTE(review): no ckpt_path is passed, so Lightning does not restore optimizer state
# or the epoch counter from a checkpoint; a resumed run restarts its loop from epoch 0.
trainer.fit(model, cifar)

In [None]:
# Evaluate on the data module's test data using the best checkpoint from training.
# NOTE(review): ckpt_path="best" resolves through the first ModelCheckpoint in the
# callbacks list (best_models_checkpoint_callback) — keep it first in that list.
trainer.test(datamodule=cifar, ckpt_path="best")

In [None]:
# Close the WandB run so pending data is uploaded and the run is marked finished.
wandb.finish()