# Example of how to use torchtrainer with *Cross-Validation*

In [1]:
from typing import List, Tuple
import pathlib
import random

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, Subset, DataLoader
from sklearn.model_selection import train_test_split, KFold

# trainer
from torchtrainer import Trainer
# hooks
from torchtrainer import EarlyStopping, NaNStopping, CSVHook

In [2]:
DEVICE = torch.device("cpu")
BATCH = 10
NSPLIT = 5

SEED = 42

def torch_fix_seed(seed=SEED):
    # Python random
    random.seed(seed)
    # Numpy
    np.random.seed(seed)
    # Pytorch
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.use_deterministic_algorithms = True

torch_fix_seed()

## Prepare dataset

In [3]:
# dummy dataset
x = torch.rand(100, 32)
y = torch.rand(100, 1)

dataset = TensorDataset(x, y)

## Prepare model

In [4]:
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(32, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, 1)
        )
    def forward(self, x):
        return self.layers(x)
model = Model()

## Define Training

Model training can be easily performed in **only two steps**.

    1. Define one training step. (if necessary, also define one validation step.)
    2. Define Trainer object in cross validation roop.

### 1. Define one training step
Define a function representing one training step.  
The function should take one `batch` and `model`, and return a list of `loss` and a list of `prediction` in that order.  
One `batch` can handle any data format (ex `dict`, `tuple`).

In [5]:
# define loss function
loss_fn = torch.nn.L1Loss(reduction="sum")

def train_step(train_batch, model) -> Tuple[List[torch.Tensor]]:
    # get input and output from one batch
    # if necessary, using `to()` method
    x = train_batch[0].to(DEVICE)
    y = train_batch[1].to(DEVICE)
    # prediction
    pred = model(x)
    # calculate loss
    loss = loss_fn(pred, y)
    
    # return list of loss and list of prediction
    return [loss], [pred]

In [6]:
# Optional: metrics for CSVhook
# metrics get one batch and result list that you define on train_step
# and return List of Tensor
def rmse(batch, result_list) -> List[torch.Tensor]:
    grand_truth = batch[1]
    pred = result_list[0]
    return [torch.sqrt(torch.mean(torch.pow(grand_truth - pred, 2)))]

### 2. Define Trainer object in crooss validation roop

In [7]:
kf = KFold(n_splits=NSPLIT, random_state=SEED, shuffle=True)
save_path = pathlib.Path("./")

scores = []
for fold, (train_index, test_index) in enumerate(kf.split(dataset)):
    print(f"Fold {fold}")
    print("-" * 50)
    
    # data loader
    train_loader = DataLoader(Subset(dataset, train_index), batch_size=BATCH, shuffle=True)
    val_loader = DataLoader(Subset(dataset, test_index), batch_size=BATCH, shuffle=True)
    
    # get model object
    model = Model()

    # define optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    # Optional: define LR sceduler
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=5)
    
    # you should set save path for each fold
    save_fold_path = save_path.joinpath(f"trainer_{fold}")
    if not save_fold_path.exists():
        save_fold_path.mkdir()
    
    # trainer object
    trainer = Trainer(
        # The best model and the `state_dict` of the last step are automatically saved.
        # last step is saved in checkpoints directory.
        model_path=save_fold_path,

        # PyTorch model
        model=model,

        # training setting
        n_epoch=100,
        device=DEVICE,

        # set loader
        train_loader=train_loader,
        val_loader=val_loader,

        # set List of optimizer
        optimizer_list=[optimizer],

        # Optional: set List of scheduler
        scheduler_list=[scheduler],

        # set List of hooks
        hooks=[
            # early stoppting hooks
            EarlyStopping(10),
            # catch nan of loss
            NaNStopping(),
            # log to csv
            CSVHook(
                log_path=save_fold_path,
                # Optional: you can set any metrics you like other than loss.
                metrics=[
                    rmse
                ],
                # set name of metrics which write on csv file
                metrics_names=["rmse"]
            )
        ]
    )
    
    # Training
    # Just call the train method!
    best = trainer.train(
        batch=BATCH,
        train_step=train_step,
        verbose=True,
    )
    
    # best loss is detached torch.tensor
    scores.append(best)

Fold 0
--------------------------------------------------
epoch 1 start
Training loss:
	 0: 0.263
validation loss:
	 0: 0.299
----------------------------------------

----------------------------------------
Training is stopped by EarlyStopping.
Training stopped with epoch 1.
----------------------------------------

Fold 1
--------------------------------------------------
epoch 1 start
Training loss:
	 0: 0.256
validation loss:
	 0: 0.329
----------------------------------------

----------------------------------------
Training is stopped by EarlyStopping.
Training stopped with epoch 1.
----------------------------------------

Fold 2
--------------------------------------------------
epoch 1 start
Training loss:
	 0: 0.261
validation loss:
	 0: 0.282
model is saved in epoch 1
----------------------------------------
epoch 2 start
Training loss:
	 0: 0.260
validation loss:
	 0: 0.283
----------------------------------------
epoch 3 start
Training loss:
	 0: 0.259
validation loss:
	

In [8]:
# check best scores
scores

[tensor(0.2743),
 tensor(0.2491),
 tensor(0.2824),
 tensor(0.2166),
 tensor(0.2248)]

In [9]:
# check log file
import pandas as pd
df = pd.read_csv("./trainer_0/log.csv")
df.head()

Unnamed: 0,Time,LearningRate_1,TrainLoss_1,ValidationLoss_1,rmse_1
0,0.03971,9e-05,0.547021,0.434241,0.052936
1,0.053729,6.5e-05,0.491073,0.395146,0.049229
2,0.067084,3.5e-05,0.447982,0.368862,0.046312
3,0.080471,1e-05,0.423469,0.355275,0.044848
4,0.094623,0.0,0.413667,0.35161,0.044567
