# Reappraisal Training with PyTorch Lightning

## Setup
- When running on Google Colab, mount Google Drive so the project scripts and data directory are accessible.
- `cd` into the project root, download the NLTK `punkt` tokenizer, and install the remaining dependencies.

In [None]:
# Colab setup: load autoreload, mount Google Drive, and move into the project
# folder on it. Order matters: the Drive must be mounted before `%cd`.
# (Comments are kept on their own lines: text after a line magic is passed to
# the magic as arguments.)
%load_ext autoreload
from google.colab import drive
drive.mount('/content/drive')
# Project root on the mounted Drive; later cells use it as the data directory
# and as the parent of the `output` folder.
root_dir = "/content/drive/MyDrive/ldh"
%cd {root_dir}

# Download NLTK's 'punkt' tokenizer models (presumably needed by the
# reappraisalmodel preprocessing — confirm).
import nltk
nltk.download('punkt')

# NOTE(review): packages are installed unpinned, but the training cell uses
# pytorch-lightning 1.x-era Trainer/ModelCheckpoint arguments; pin a compatible
# pytorch-lightning version.
%pip install transformers datasets pytorch-lightning "ray[tune]"
# Local (non-Colab) alternative for root_dir:
#root_dir = "/Users/danielpham/Google Drive/ldh"

In [None]:
from datetime import datetime
from pathlib import Path

import torch
import pandas as pd 
import pytorch_lightning as lit 
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, Callback

# Cross-validation and data-loading settings shared by the cells below.
num_folds, batch_size = 3, 16
# Strategy key handed to LDHDataModule as `strat`.
strat = 'far'
# root_dir itself is defined in the setup cell.

In [None]:
from reappraisalmodel.ldhdata import LDHDataModule

# Hyperparameters used to build each LightningReapp model.
default_config = dict(lr=1e-3, hidden_layer_size=50)

# Training output (checkpoints) is written under <root_dir>/output.
save_dir = Path(root_dir) / 'output'

# Stop training once val_loss has not improved by at least 1e-4 for 5
# consecutive validation checks.
# NOTE(review): EarlyStopping keeps state (best score, wait counter) between
# fit() calls, so one instance should not be shared across independent runs.
callback_earlystopping = EarlyStopping(
    monitor='val_loss',
    min_delta=0.0001,
    patience=5,
    verbose=True,
)

# K-fold data module providing per-split train/val dataloaders.
ldhdata = LDHDataModule(
    data_dir=root_dir,
    strat=strat,
    kfolds=num_folds,
    batch_size=batch_size,
)

In [None]:
### K-FOLD CV
# (Re)load IPython's autoreload extension and reload modules now, presumably so
# edits to the reappraisalmodel package imported below take effect without a
# kernel restart.
%load_ext autoreload
%autoreload
from datetime import datetime

import torch

from reappraisalmodel.lightningreapp import LightningReapp

def kfold_train(num_folds):
    """Train one fresh LightningReapp model per cross-validation fold.

    For each fold index ``0 .. num_folds - 1`` this takes that split's train and
    validation dataloaders from the module-level ``ldhdata``, fits a new
    ``LightningReapp(default_config)`` with checkpointing and early stopping,
    and records the trainer's logged metrics after training.

    Args:
        num_folds: Number of folds to train. Presumably should not exceed the
            ``kfolds`` value ``ldhdata`` was constructed with — TODO confirm.

    Returns:
        pandas.DataFrame with one row per fold (row ``i`` is fold ``i``) and one
        column per metric in ``trainer.logged_metrics``. Scalar tensor values
        are converted to plain Python numbers.
    """
    import copy

    all_metrics = []
    # Same for every fold: names the checkpoint folder and filename prefix.
    strat = ldhdata.strat
    for fold in range(num_folds):
        train_dl = ldhdata.get_train_dataloader(fold)
        val_dl = ldhdata.get_val_dataloader(fold)

        # Keep only the single best (lowest val_loss) checkpoint for this fold.
        # NOTE(review): `period`/`prefix` here and several Trainer flags below
        # (`gpus`, `progress_bar_refresh_rate`, `terminate_on_nan`,
        # `weights_summary`, `weights_save_path`) are pytorch-lightning 1.x
        # arguments that later releases removed; the setup cell installs
        # pytorch-lightning unpinned.
        callback_checkpoint = ModelCheckpoint(
            dirpath=save_dir / strat,
            filename='{epoch}-{val_loss:.2f}',
            monitor='val_loss', verbose=False,
            save_last=False, save_top_k=1, save_weights_only=False,
            mode='min', period=1, prefix=strat)
        # EarlyStopping is stateful (best score, patience counter). Reusing the
        # shared module-level instance would carry one fold's state into the
        # next and could stop later folds prematurely, so give each fold a
        # fresh copy of the configured callback.
        fold_earlystopping = copy.deepcopy(callback_earlystopping)

        model = LightningReapp(default_config)
        trainer = lit.Trainer(
            gpus=1 if torch.cuda.is_available() else None,
            gradient_clip_val=1.0,
            progress_bar_refresh_rate=30,
            terminate_on_nan=True,
            weights_summary=None,
            weights_save_path=save_dir / strat,
            callbacks=[callback_checkpoint, fold_earlystopping])
        print(f"Training on split {fold + 1}")
        trainer.fit(model, train_dl, val_dl)
        # Snapshot the metrics into a fresh dict of plain numbers instead of
        # keeping a reference to the trainer's own tensor-valued mapping.
        all_metrics.append({
            name: (value.item()
                   if torch.is_tensor(value) and value.numel() == 1
                   else value)
            for name, value in trainer.logged_metrics.items()
        })
    return pd.DataFrame(all_metrics)

In [None]:
# Run k-fold training; as the cell's last expression, the per-fold metrics
# DataFrame is displayed by the notebook.
kfold_train(num_folds=num_folds)
