# Reappraisal Training on PyTorch Lightning

## Setup
- Required Python Version: 3.7+
- `cd` into the project root and install dependencies:
  - `pip install -r requirements.txt`

### Colab Setup
- When the repository is stored on Google Drive, it can be run from Google Colaboratory. The cell below mounts the drive, installs the necessary packages, and changes the working directory to the project directory.
    - Python Version: `3.7.10`
    - Be sure to set `PROJECT_NAME` to the name of your project directory.

In [None]:
# When Running on Colab
# from google.colab import drive
# drive.mount('/content/drive')

# %pip install pytorch-lightning "ray[tune]" wandb transformers datasets nltk nbdev
# ! nbdev_install_git_hooks

import nltk
nltk.download('punkt')

PROJECT_NAME = "ldh"
ROOT_DIR = f"/content/drive/MyDrive/{PROJECT_NAME}"
%cd {ROOT_DIR}

## Loading and Encoding Data

In [None]:
%load_ext autoreload
import os

import torch

# Define project root directory.
ROOT_DIR = os.path.abspath(".")
# Select the proper strategy. Valid strategy names: "obj", "far"
STRAT = 'obj'
# Define batch size.
BATCH_SIZE = 64
DEV_FLAG = True # Flag for fast runs when debugging.

# Load the DataModule and its corresponding training and evaluation data.
from reappraisalmodel.ldhdata import LDHDataModule
ldhdata = LDHDataModule(data_dir=ROOT_DIR, strat=STRAT)
ldhdata.load_train_data()
ldhdata.load_eval_data()

# Load Model
Instantiating `LightningReapp` without any arguments creates a model with uninitialized parameters (that is, a blank, untrained reappraisal model). Any valid `LightningReapp` model can instead be loaded from a checkpoint file like so:
```python
from reappraisalmodel.lightningreapp import LightningReapp

ckpt_path = "..."  # Replace with any path on local or remote storage (e.g. S3)
model = LightningReapp.load_from_checkpoint(ckpt_path)
```
If loading a model from S3, the `s3fs` package should be installed.

In [None]:
from reappraisalmodel.lightningreapp import LightningReapp
model = LightningReapp()

In [None]:
%autoreload 2
import pandas as pd  # Needed for the results DataFrame below.

from reappraisalmodel.trainers import kfold_train
# When running k-fold cross-validation, define the number of folds. 
NUM_FOLDS = 5

# Learns a model NUM_FOLDS times and records the distribution of metrics across the CV.
results = kfold_train(
    NUM_FOLDS, 
    ldhdata, 
    strat=STRAT
)
df = pd.DataFrame(results)
df['r2score'] = df['r2score'].apply(lambda x: x.item())
df['explained_var'] = df['explained_var'].apply(lambda x: x.item())
df.describe()

## Run K-Fold Training
Runs k-fold cross-validation of the training procedure and reports the distribution of metrics across the folds.
- See [Trainers.ipynb](./nbs/Trainers.ipynb) for more information. 
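
The internals of `kfold_train` are not shown here, so the following is only a minimal sketch of the general idea behind k-fold splitting, not the project's implementation. The helper name `kfold_indices` and its parameters are hypothetical. Each sample lands in exactly one validation fold, and the model would be retrained once per fold on the remaining samples.

```python
import random


def kfold_indices(num_samples, num_folds, seed=0):
    """Yield (train_idx, val_idx) index lists for k-fold cross-validation.

    Hypothetical helper for illustration; not part of reappraisalmodel.
    """
    indices = list(range(num_samples))
    random.Random(seed).shuffle(indices)
    # Spread the remainder over the first folds so sizes differ by at most one.
    fold_sizes = [
        num_samples // num_folds + (1 if i < num_samples % num_folds else 0)
        for i in range(num_folds)
    ]
    start = 0
    for size in fold_sizes:
        val_idx = indices[start:start + size]
        train_idx = indices[:start] + indices[start + size:]
        yield train_idx, val_idx
        start += size


folds = list(kfold_indices(num_samples=10, num_folds=5))
for fold_number, (train_idx, val_idx) in enumerate(folds):
    print(fold_number, len(train_idx), len(val_idx))
```

Collecting one metric value per fold, as the cell above does with `r2score` and `explained_var`, gives a distribution rather than a single estimate, which is why the results are summarized with `df.describe()`.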


# Training Process

### A Note on GPUs:
- Training large models such as this one on a CPU is generally impractical: a single training epoch (one pass through the training data) can take on the order of hours. GPUs speed this up because they are optimized for the matrix operations that dominate training. Popular services that provide GPU access with built-in Jupyter Notebook integration include:
  - Amazon Web Services
  - Kaggle
  - Google Colaboratory
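
Before starting a long run, it can help to confirm that a GPU is actually visible to PyTorch. The sketch below is one way to do that; the helper name `pick_device` is hypothetical, and the fallback to `"cpu"` when PyTorch is not installed is an assumption added so the snippet runs anywhere.

```python
import importlib.util


def pick_device():
    """Return "cuda" when PyTorch can see a GPU, otherwise "cpu"."""
    if importlib.util.find_spec("torch") is None:
        return "cpu"  # PyTorch is not installed in this environment.
    import torch

    return "cuda" if torch.cuda.is_available() else "cpu"


device = pick_device()
print(f"Training would run on: {device}")
```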

## Single Training Session
Defines the process of running a training session for `LightningReapp`.

In [None]:
%autoreload
from datetime import datetime
from tqdm import tqdm

import pandas as pd
import torch
import pytorch_lightning as lit
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from torch.utils.data import DataLoader

In [None]:
# Model saves the 3 checkpoints with the lowest validation loss throughout training
modelcheckpoint = ModelCheckpoint(
    monitor='val_loss',
    mode='min',
    save_top_k=3,
    verbose=True
)
# Model tracks the loss_distance; shows when training and validation loss begin to diverge 
modelcheckpoint_loss_dist = ModelCheckpoint(
    monitor='loss_distance',
    mode='min',
    save_top_k=3,
    verbose=True
)

# Split train and validation data.
split_data = ldhdata.train_data.train_test_split(test_size=0.2)
train_data = split_data['train'].with_format(type='torch', columns=['score', 'input_ids', 'attention_mask'])
val_data = split_data['test'].with_format(type='torch', columns=['score', 'input_ids', 'attention_mask'])
eval_data = ldhdata.eval_data.with_format(type='torch', columns=['input_ids', 'attention_mask'])

# Create dataloaders
train_dl = DataLoader(train_data, batch_size=BATCH_SIZE)
val_dl = DataLoader(val_data, batch_size=BATCH_SIZE)
eval_dl = DataLoader(eval_data, batch_size=BATCH_SIZE)

# Mark the start time of the training session. 
today = datetime.today().strftime('%Y%m%d_%H%M%S')
session_version = "_".join([STRAT,today])
tb_logger = TensorBoardLogger("lightning_logs", name="reapp_model", version=session_version)
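
The `loss_distance` metric monitored above is logged by the model itself, and its exact definition is not shown in this notebook. As a rough sketch only, assuming it is the absolute gap between validation and training loss, the ranking performed by a `mode='min'` checkpoint can be pictured as follows. The function and the loss values are hypothetical.

```python
def loss_distance(train_loss: float, val_loss: float) -> float:
    """Assumed definition: absolute gap between validation and training loss."""
    return abs(val_loss - train_loss)


# Hypothetical (train_loss, val_loss) pairs from successive validation checks.
history = [(0.90, 0.95), (0.60, 0.70), (0.40, 0.75)]
distances = [loss_distance(t, v) for t, v in history]

# With mode='min', the check with the smallest gap would rank first.
best_check = min(range(len(distances)), key=distances.__getitem__)
print(distances, best_check)
```

A gap that keeps growing while training loss falls is a common sign of overfitting, which is the divergence this metric is meant to surface.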


### Trainer
PyTorch Lightning's [Trainer](https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html) abstracts the parts of the training loop configuration that are not related to the model. This includes registering callback functions, stop conditions, GPU/CPU configuration, etc.

In [None]:
trainer = lit.Trainer(
    logger = tb_logger,
    precision=16 if torch.cuda.is_available() else 32, # Use 16-bit precision on GPU to reduce memory use and speed up training
    val_check_interval=0.25, # Check validation loss 4 times an epoch
    callbacks=[modelcheckpoint, modelcheckpoint_loss_dist], # Register callbacks with trainer.
    gpus=1 if torch.cuda.is_available() else None,
)
# Fit the model to the training data. 
results = trainer.fit(model, train_dl, val_dl)

## Predictions on Study 1 Data

The following script uses a trained model to predict reappraisal scores on the designated training/validation dataset.

Study 1 Data is structured as follows:
- `response`: The participant's written response to the study's stimuli.
- `score`: Each response was rated by multiple raters on a scale of 1-7, with a higher score corresponding to greater use of a specific reappraisal strategy (objective distancing vs. spatiotemporal distancing). The ratings are then averaged.
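
To make the averaging step concrete, here is a small sketch with made-up ratings. The response keys and rating values are hypothetical and do not come from the study data.

```python
from statistics import mean

# Hypothetical ratings: each response is scored 1-7 by several raters.
ratings = {
    "response_a": [2, 3, 3],
    "response_b": [6, 7, 5, 6],
}

# The `score` for a response is the mean of its ratings.
scores = {response: mean(values) for response, values in ratings.items()}
print(scores)
```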


## Predictions on Study 2 Data

The following script uses a trained model to predict reappraisal scores on the designated testing dataset.

Study 2 Data is structured as follows:
- `response`: The participant's written response to the study's stimuli.
- `Condition`: Describes the type of stimulus the participant reacts to.
- `addcode`: The subject identifier.
- `daycode`: The day of the study on which the response was recorded.

In [None]:
outs = []
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)
model.eval()  # Disable dropout and other training-only behavior.
with torch.no_grad():  # Gradients are not needed for prediction.
    for idx, batch in enumerate(tqdm(eval_dl)):
        if DEV_FLAG and idx >= 2:
            break
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        out = model(input_ids, attention_mask)
        outs.append(out.sum(dim=1).cpu().tolist())
# Flatten the per-batch lists into a single list of predictions.
newouts = [score for batch in outs for score in batch]

df = pd.DataFrame(ldhdata.eval_data[:len(newouts)], columns=['addcode', 'daycode', 'Condition', 'response', 'observed'])
df['observed'] = newouts
df

In [None]:
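outs = []
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)
model.eval()  # Disable dropout and other training-only behavior.
with torch.no_grad():  # Gradients are not needed for prediction.
    for idx, batch in enumerate(tqdm(train_dl)):
        if DEV_FLAG and idx >= 2:
            break
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        out = model(input_ids, attention_mask)
        outs.append(out.sum(dim=1).cpu().tolist())
# Flatten the per-batch lists into a single list of predictions.
newouts = [score for batch in outs for score in batch]
# train_dl iterates over the shuffled training split, so index into
# split_data['train'] (not ldhdata.train_data) to keep rows aligned.
df = pd.DataFrame(split_data['train'][:len(newouts)], columns=['response', 'score', 'observed'])
df['observed'] = newouts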
df