### import modules

In [None]:
import logging
import os
import time

from pytorch_lightning.callbacks import ModelCheckpoint, BatchSizeFinder
from pytorch_lightning.loggers import NeptuneLogger, TensorBoardLogger
import neptune.new as neptune
import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import RichProgressBar
import yaml

from src.data import dataset
from src.utils import callbacks
from src.wakegan import WakeGAN

# Set the process-wide float32 matmul precision mode to "medium".
# NOTE(review): presumably chosen to trade some numeric precision for GPU
# throughput — confirm against the PyTorch docs that this is acceptable for
# the reported validation RMSE.
torch.set_float32_matmul_precision("medium")

### import config file

In [None]:
# Load the run configuration from config.yaml in the working directory.
# The encoding is pinned to UTF-8 so parsing does not depend on the
# platform's default locale encoding.
with open("config.yaml", encoding="utf-8") as file:
    config = yaml.safe_load(file)

### initialize neptune client

In [None]:
# Build the Neptune logger only when it is enabled in the config; otherwise
# `neptune_logger` stays None and later cells must check for that.
#
# SECURITY: the API token is read from the NEPTUNE_API_TOKEN environment
# variable instead of being hard-coded in the notebook. Any token that was
# previously committed to this file should be revoked and regenerated.
neptune_logger = None
if config["ops"]["neptune_logger"]:
    neptune_logger = NeptuneLogger(
        project="idatha/wakegan",
        api_key=os.environ.get("NEPTUNE_API_TOKEN"),
        log_model_checkpoints=False,
    )

### add config content as hyperparameters in neptune

In [None]:
# When Neptune logging is enabled, attach the whole configuration dictionary
# to the run as its hyperparameters.
if config["ops"]["neptune_logger"]:
    neptune_logger.log_hyperparams(config)

### configure custom loggers

In [None]:
# Make sure the log directory exists before anything writes into it.
# `exist_ok=True` replaces the listdir-then-mkdir check, which was race-prone
# and failed if the directory appeared between the check and the creation.
os.makedirs("logs", exist_ok=True)

# TensorBoard logger writing under logs/.
tb_logger = TensorBoardLogger(save_dir="logs/")

# Plain-text training log: message-only format, overwritten on every run.
logging.basicConfig(
    format="%(message)s",
    filename=os.path.join("logs", "train.log"),
    level=logging.INFO,
    filemode="w",
)
logger = logging.getLogger("train")

# Loggers handed to the trainer: TensorBoard always, Neptune only if enabled.
loggers = [tb_logger]
if config["ops"]["neptune_logger"]:
    loggers.append(neptune_logger)

### configure checkpoint model saving

In [None]:
# Checkpointing policy: keep only the single best checkpoint, where "best"
# means the lowest value of the `rmse_val_epoch` metric.
checkpoint_callback = ModelCheckpoint(
    monitor="rmse_val_epoch",
    mode="min",
    save_top_k=1,
    # The file name embeds the epoch and the monitored metric value.
    filename="wakegan-{epoch}-{rmse_val_epoch:.2f}",
    # No explicit directory is given here.
    dirpath=None,
)

### initialize training dataset

In [None]:
# Training dataset built from the tracked, preprocessed training split.
# Normalization parameters are saved only when model saving is enabled.
dataset_train = dataset.WakeGANDataset(
    data_dir=os.path.join("data", "preprocessed", "tracked", "train"),
    config=config["data"],
    dataset_type="train",
    save_norm_params=bool(config["models"]["save"]),
)

# Data module constructed from the full configuration; passed to the trainer.
datamodule = dataset.WakeGANDataModule(config)

### initialize model

In [None]:
# Instantiate the WakeGAN model from the full configuration and the
# `norm_params` attribute exposed by the training dataset.
model = WakeGAN(config, dataset_train.norm_params)

### initialize trainer

In [None]:
# Lightning trainer: one GPU, deterministic mode, metrics logged every step,
# running for the number of epochs given in the config.
trainer = pl.Trainer(
    # hardware
    accelerator="gpu",
    devices=1,
    deterministic=True,
    # schedule
    max_epochs=config["train"]["num_epochs"],
    # logging / output
    default_root_dir="logs",
    logger=loggers,
    log_every_n_steps=1,
    # callbacks, in the same order as before
    callbacks=[
        callbacks.LoggingCallback(logger),
        callbacks.PlottingCallback(enable_logger=config["ops"]["neptune_logger"]),
        checkpoint_callback,
        RichProgressBar(),
    ],
)

### fit model

In [None]:
# Run training; the data module supplies the dataloaders.
trainer.fit(model, datamodule=datamodule)

### save new model version (best checkpoint) to neptune

In [None]:
# Register the best checkpoint as a new Neptune model version. This runs only
# when both Neptune logging and model saving are enabled in the config.
if config["ops"]["neptune_logger"] and config["models"]["save"]:
    logger.info("Saving model in neptune")

    # SECURITY: the API token comes from the NEPTUNE_API_TOKEN environment
    # variable; it must never be hard-coded in the notebook. Any token that
    # was previously committed here should be revoked and regenerated.
    model_version = neptune.init_model_version(
        model="WAK-MOD",
        project="idatha/wakegan",
        api_token=os.environ.get("NEPTUNE_API_TOKEN"),
    )
    try:
        # Upload the checkpoint selected by the checkpoint callback.
        path_to_model = trainer.checkpoint_callback.best_model_path
        model_version["model/ckpt"].upload(path_to_model)

        # Track the "ux" folder of each data split under its own field.
        for field, split in (
            ("training", "train"),
            ("validation", "val"),
            ("testing", "test"),
        ):
            model_version[f"model/dataset/{field}"].track_files(
                os.path.join("data", "preprocessed", "tracked", split, "ux")
            )

        # Link the model version to the training run, then mark it as staging.
        model_version["model/run"] = neptune_logger.run["sys/id"].fetch()
        model_version.change_stage("staging")
    finally:
        # Always close the model-version handle, even if an upload fails,
        # so the connection is not left open.
        model_version.stop()

### stop run

In [None]:
# Stop the Neptune run. `neptune_logger` is None when Neptune logging is
# disabled in the config, so guard against calling `.run` on None.
if neptune_logger is not None:
    neptune_logger.run.stop()