# K-Fold Training and Cross-Validation

In [None]:
#default_exp trainers
#export
import datetime
import logging
import os
import tempfile

import torch
import pandas as pd
import pytorch_lightning as lit
import wandb
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger, CSVLogger

from reappraisalmodel.lightningreapp import LightningReapp

def kfold_train(k: int, ldhdata, strat, **trainer_kwargs) -> pd.DataFrame:
    """Fit a fresh LightningReapp on each of ``k`` splits and report metrics.

    For every split a new model and trainer are created, the model is fitted
    on that split's train/validation dataloaders, and the best checkpoint
    (by ``val_loss``) is recorded. After all splits have been trained, each
    best checkpoint and a CSV report of the per-split metrics are handed to
    ``upload_file``.

    Args:
        k (int): Number of splits to train on; splits are indexed
            ``0 .. k-1``. Must be at least 1.
        ldhdata: Data module providing ``get_train_dataloader(split)`` and
            ``get_val_dataloader(split)``.
            See `reappraisalmodel.ldhdata.LDHDataModule`.
        strat: Label for this run; it is interpolated into the logger
            version and into the destination names passed to
            ``upload_file``.
        **trainer_kwargs: Only ``max_epochs`` (default 20) is read.
            NOTE(review): any other keyword is currently ignored rather than
            forwarded to the trainer -- confirm whether that is intended.

    Returns:
        pandas.DataFrame: One row per split with the columns ``val_loss``,
        ``train_loss``, ``num_epochs``, ``r2score`` and ``explained_var``.

    Raises:
        ValueError: If ``k`` is smaller than 1.
        KeyError: If the model did not log one of the metrics read below
            (``val_loss``, ``train_loss``, ``r2score``, ``explained_var``).
    """
    if k < 1:
        raise ValueError(f"k must be a positive integer, got {k!r}")

    max_epochs = trainer_kwargs.pop('max_epochs', 20)

    # Timestamp taken once so every artifact of this run shares it.
    today = datetime.datetime.today().strftime('%Y%m%d_%H%M%S')
    # NOTE(review): every split logs under the same version name, so their
    # TensorBoard output shares one directory -- confirm this is wanted.
    session_version = f"reappmodel_{strat}_{today}"

    all_metrics = []

    # Scratch directory for the CSV report. Everything that touches
    # `tempdir` must happen inside this block: the directory is removed as
    # soon as the block exits.
    # NOTE(review): the checkpoint callbacks below are not pointed at
    # `tempdir`, so checkpoints are written to Lightning's default location.
    with tempfile.TemporaryDirectory() as tempdir:
        print(f'Created temporary directory: {tempdir}')

        for i in range(k):
            modelcheckpoint = ModelCheckpoint(
                monitor='val_loss',
                mode='min',
                save_top_k=3,
                verbose=True
            )
            # Also checkpoint on `loss_distance`, a metric the model is
            # expected to log to show when training and validation loss
            # begin to diverge.
            modelcheckpoint_loss_dist = ModelCheckpoint(
                monitor='loss_distance',
                mode='min',
                save_top_k=3,
                verbose=True
            )
            # Select the dataloaders for the given split.
            train_dl = ldhdata.get_train_dataloader(i)
            val_dl = ldhdata.get_val_dataloader(i)

            # A fresh model per split so no weights leak between folds.
            model = LightningReapp()
            tb_logger = TensorBoardLogger(
                "lightning_logs", name="reapp_model", version=session_version
            )
            trainer = lit.Trainer(
                logger=tb_logger,
                # 16-bit precision only when a GPU is available.
                precision=16 if torch.cuda.is_available() else 32,
                val_check_interval=0.25,  # Check validation loss 4 times an epoch
                callbacks=[modelcheckpoint, modelcheckpoint_loss_dist],
                gpus=1 if torch.cuda.is_available() else None,
                weights_summary=None,
                max_epochs=max_epochs
            )
            print(f"Training on split {i}")
            trainer.fit(model, train_dl, val_dl)
            all_metrics.append({
                'metrics': trainer.logged_metrics,
                'checkpoint': modelcheckpoint.best_model_path,
                'num_epochs': trainer.current_epoch
            })

        outputs = []
        # Enumerate so each checkpoint is uploaded under its own split index
        # (the training loop's `i` is stale here and would always be k-1).
        for fold, split in enumerate(all_metrics):
            metrics = split['metrics']
            ckpt_path = split['checkpoint']
            filename = os.path.split(ckpt_path)[-1]

            upload_result = upload_file(
                ckpt_path, 'ldhdata', f'{strat}/{fold}-{str(today)}-{filename}'
            )
            print(f"Successful {filename} to s3: {upload_result}")

            row = {
                'val_loss': metrics['val_loss'].item(),
                'train_loss': metrics['train_loss'].item(),
                'num_epochs': split['num_epochs'],
                'r2score': metrics['r2score'],
                'explained_var': metrics['explained_var']
            }
            print(row)
            outputs.append(row)

        df = pd.DataFrame(outputs)
        # These two metrics are still tensor-like here; unwrap to scalars.
        df['r2score'] = df['r2score'].apply(lambda x: x.item())
        df['explained_var'] = df['explained_var'].apply(lambda x: x.item())

        # Write and upload the report while the temporary directory exists.
        report_path = os.path.join(tempdir, f'{strat}-{str(today)}-report.csv')
        df.to_csv(report_path)

        upload_report = upload_file(
            report_path, 'ldhdata', f'{strat}/{str(today)}-report.csv'
        )
        print(f"Successful Uploading Report to s3: {upload_report}")

    print(df.describe())
    return df