
# Model 05: BERT Cross-Encoder Classification for Retrieval

**This is the final implementation for the retrieval stage**

Prerequisites:
- Before proceeding, ensure that evidence shortlists have been created for both the `train` and `dev` sets using the [Model 02c notebook](./model_02c_fast_shortlisting.ipynb).


## Setup

### Working Directory

In [None]:
# Change the working directory to the project root: the nearest ancestor of
# the current directory that contains a `src` folder (so `src.*` imports work).
# Stop at the filesystem root instead of looping forever: there
# `Path.parent` returns the path itself, so the original loop never ended.
from pathlib import Path
import os
ROOT_DIR = Path.cwd()
while not ROOT_DIR.joinpath("src").exists():
    if ROOT_DIR.parent == ROOT_DIR:
        raise FileNotFoundError(
            f"Could not find a directory containing 'src' at or above {Path.cwd()}"
        )
    ROOT_DIR = ROOT_DIR.parent
os.chdir(ROOT_DIR)

### Dependencies

In [None]:
# Imports and dependencies
import torch
from torch.utils.data import DataLoader
from torch.nn import CrossEntropyLoss
from torch.optim import AdamW
from torch.optim.lr_scheduler import LinearLR
from torcheval.metrics import BinaryAccuracy, BinaryF1Score

from src.torch_utils import get_torch_device
import json
from dataclasses import dataclass
from typing import List, Union, Tuple
from tqdm import tqdm
import random
import numpy as np
from datetime import datetime
from sklearn.model_selection import ParameterGrid

from src.model_05 import BertCrossEncoderClassifier
from src.data import RetrievalWithShortlistDataset, RetrievalDevEvalDataset
from src.logger import SimpleLogger

# Device selected by the project helper `get_torch_device` (presumably the best
# available accelerator, else CPU — confirm in src/torch_utils); it is passed as
# `device=` / `map_location=` when the model is created or loaded.
TORCH_DEVICE = get_torch_device()

### File paths

In [None]:
# Placeholder paths: the trailing "*" component is meant to be swapped for a
# concrete file name via `Path.with_name(...)`, e.g.
# MODEL_PATH.with_name("model.pth") -> <root>/result/model.pth.
MODEL_PATH = ROOT_DIR / "result" / "*"
DATA_PATH = ROOT_DIR / "data" / "*"
LOG_PATH = ROOT_DIR / "result" / "*"
SHORTLIST_PATH = ROOT_DIR / "result" / "*"

# Timestamp (minute resolution) used to make saved artefact names unique.
run_time = f"{datetime.now():%Y_%m_%d_%H_%M}"

## Training Loop

In [None]:
def _forward_batch(model, batch, normalize_text, max_length, dropout):
    """Run ``model`` on one collated batch of (claim, evidence) pairs.

    The batch is unpacked as ``(claim_texts, evidence_texts, labels,
    claim_ids, evidence_ids)``; the ids are unused here.

    Returns:
        ``(output, logits, labels)`` where ``output`` and ``logits`` are the
        first two values returned by the model and ``labels`` has been moved
        to the same device as ``logits`` (required by the loss function).
    """
    claim_texts, evidence_texts, labels, _claim_ids, _evidence_ids = batch
    texts = list(zip(claim_texts, evidence_texts))
    output, logits, _seq = model(
        texts=texts,
        normalize_text=normalize_text,
        max_length=max_length,
        dropout=dropout,
    )
    return output, logits, labels.to(logits.device)


def _train_one_epoch(model, dataloader, loss_fn, optimizer, scheduler,
                     normalize_text, max_length, dropout) -> float:
    """Run one optimisation pass over ``dataloader``; return the mean per-example loss."""
    model.train()
    batches = tqdm(dataloader, desc="train batches")
    total_loss = 0.0
    n_examples = 0
    for batch in batches:
        optimizer.zero_grad()
        _, logits, labels = _forward_batch(
            model, batch, normalize_text, max_length, dropout
        )
        loss = loss_fn(logits, labels)
        loss.backward()
        optimizer.step()
        # The scheduler is stepped once per batch, so warm-up is measured in batches.
        scheduler.step()

        batch_loss = loss.item()
        # CrossEntropyLoss() averages over the batch; weight by the number of
        # examples so a smaller final batch does not skew the epoch mean.
        total_loss += batch_loss * len(labels)
        n_examples += len(labels)
        batches.set_postfix_str(f"loss: {batch_loss:.3f}", refresh=False)
    return total_loss / n_examples if n_examples else float("nan")


def _evaluate(model, dataloader, normalize_text, max_length, dropout) -> Tuple[float, float]:
    """Return ``(accuracy, f1)`` of argmax predictions over all of ``dataloader``."""
    model.eval()
    # Fresh metric objects on every call so results do not accumulate across epochs.
    accuracy_fn = BinaryAccuracy()
    f1_fn = BinaryF1Score()
    batches = tqdm(dataloader, desc="dev batches")
    # No gradients are needed for evaluation; this avoids storing activations.
    with torch.no_grad():
        for batch in batches:
            # NOTE(review): `dropout` is still forwarded to the model here, as in
            # the original; confirm the model ignores it in eval() mode.
            output, _, labels = _forward_batch(
                model, batch, normalize_text, max_length, dropout
            )
            _, predicted = torch.max(output, dim=-1)
            accuracy_fn.update(predicted.cpu(), labels.cpu())
            f1_fn.update(predicted.cpu(), labels.cpu())
            # Running (cumulative) values, shown for progress only.
            batches.set_postfix_str(
                f" acc: {float(accuracy_fn.compute()):.3f}, f1: {float(f1_fn.compute()):.3f}",
                refresh=False,
            )
    # Metrics over the whole dev set (not an average of running values).
    return float(accuracy_fn.compute()), float(f1_fn.compute())


def training_loop(
    model,
    claims_paths:List[Path],
    claims_shortlist_paths:List[Path],
    save_path:Union[Path, None]=None,
    n_neg_samples:int=5,
    warmup:float=0.1,
    lr:float=0.00005, # 5e-5
    weight_decay:float=0.01,
    normalize_text:bool=True,
    max_length:int=128,
    dropout:float=None,
    n_epochs:int=5,
    batch_size:int=64,
    dev_n_neg_samples:int=3,
) -> Tuple[float, float, int]:
    """Fine-tune ``model`` as a binary (claim, evidence) classifier and evaluate it each epoch.

    Args:
        model: Cross-encoder called as ``model(texts=..., normalize_text=...,
            max_length=..., dropout=...)`` and returning ``(output, logits, seq)``.
        claims_paths: Claim files passed to ``RetrievalWithShortlistDataset``.
        claims_shortlist_paths: Shortlist files passed to the same dataset.
        save_path: If given, the whole model is saved here (``torch.save``)
            whenever the dev F1 improves.
        n_neg_samples: Negatives per claim for the training dataset.
        warmup: Fraction of one epoch's batches over which ``LinearLR`` ramps
            the learning rate from ``lr / 3`` up to ``lr``; 0 disables warm-up.
        lr, weight_decay: AdamW hyperparameters.
        normalize_text, max_length, dropout: Forwarded to every model call.
        n_epochs: Number of training epochs.
        batch_size: Batch size for both training and dev dataloaders.
        dev_n_neg_samples: Negatives per claim for ``RetrievalDevEvalDataset``.

    Returns:
        ``(best_acc, best_f1, best_epoch)``. ``best_acc`` and ``best_f1`` are
        the maxima over epochs, taken independently. ``best_epoch`` is the
        1-based epoch with the best F1, which is the checkpoint written to
        ``save_path``.
    """
    # Data ---------------------------------------------------------------------
    train_data = RetrievalWithShortlistDataset(
        claims_paths=claims_paths,
        claims_shortlist_paths=claims_shortlist_paths,
        n_neg_samples=n_neg_samples,
        pos_label=1,
        neg_label=0
    )
    train_dataloader = DataLoader(
        dataset=train_data,
        shuffle=True,
        batch_size=batch_size
    )

    dev_data = RetrievalDevEvalDataset(
        n_neg_samples=dev_n_neg_samples,
        pos_label=1,
        neg_label=0,
    )
    dev_dataloader = DataLoader(
        dataset=dev_data,
        shuffle=False,
        batch_size=batch_size
    )

    # Optimisation -------------------------------------------------------------
    loss_fn = CrossEntropyLoss()
    optimizer = AdamW(
        params=model.parameters(),
        lr=lr,
        weight_decay=weight_decay
    )

    # LinearLR needs an integer number of steps; a fractional value never
    # reaches the full LR, and 0 would freeze the LR at lr/3 forever.
    warmup_steps = int(round(warmup * len(train_dataloader)))
    scheduler = LinearLR(
        optimizer=optimizer,
        start_factor=1.0 / 3 if warmup_steps > 0 else 1.0,
        total_iters=max(warmup_steps, 1),
    )

    # Epochs -------------------------------------------------------------------
    best_epoch_acc = -1.0
    best_epoch_f1 = -1.0
    best_epoch = 0
    for epoch in range(1, n_epochs + 1):
        print(f"Epoch: {epoch} of {n_epochs}\n")

        epoch_loss = _train_one_epoch(
            model, train_dataloader, loss_fn, optimizer, scheduler,
            normalize_text, max_length, dropout,
        )
        print(f"Average epoch loss: {epoch_loss:.3f}")

        epoch_acc, epoch_f1 = _evaluate(
            model, dev_dataloader, normalize_text, max_length, dropout
        )
        print(f"Epoch dev accuracy: {epoch_acc:.3f}")
        print(f"Epoch dev f1: {epoch_f1:.3f}")

        best_epoch_acc = max(best_epoch_acc, epoch_acc)

        # Save only on a strict F1 improvement so the saved checkpoint always
        # matches `best_epoch`.
        if epoch_f1 > best_epoch_f1:
            best_epoch_f1 = epoch_f1
            best_epoch = epoch
            if save_path:
                torch.save(model, save_path)
                print(f"Saved model to: {save_path}")

    print("Done!")
    return best_epoch_acc, best_epoch_f1, best_epoch

## Load model

Either create a new classifier from the pre-trained `bert-base-uncased` checkpoint:

In [None]:
# New two-class cross-encoder built on the pre-trained `bert-base-uncased`
# weights and placed on TORCH_DEVICE.
model = BertCrossEncoderClassifier(
    device=TORCH_DEVICE,
    pretrained_name="bert-base-uncased",
    n_classes=2,
)

Or load a previously trained model:

In [None]:
# MODEL_SAVE_PATH = MODEL_PATH.with_name("model_05_bert_cross_encoder_retrieval_2023_05_08_17_06.pth")
# with open(MODEL_SAVE_PATH, mode="rb") as f:
#     model = torch.load(f, map_location=TORCH_DEVICE)

## Training and evaluation loop

In [None]:
# Fine-tune on the train + dev claims with their evidence shortlists, saving
# the best-F1 checkpoint under a timestamped name.
# NOTE(review): dev-claims.json is included in the training data, while the
# per-epoch evaluation uses RetrievalDevEvalDataset (presumably built from the
# dev split — confirm). If so, the reported dev metrics, and the checkpoint
# selected from them, are optimistic. That is acceptable for a final fit, but
# not for model selection.
training_loop(
    model=model,
    claims_paths=[
        DATA_PATH.with_name("train-claims.json"),
        DATA_PATH.with_name("dev-claims.json"),
    ],
    # Resolve the shortlists from the project root, like every other path,
    # rather than relative to the current working directory.
    claims_shortlist_paths=[
        SHORTLIST_PATH.with_name("train_shortlist_evidences_max_500.json"),
        SHORTLIST_PATH.with_name("dev_shortlist_evidences_max_500.json"),
    ],
    save_path=MODEL_PATH.with_name(f"model_05_bert_cross_encoder_retrieval_{run_time}.pth"),
    n_neg_samples=100,
    warmup=0.1,
    lr=0.000005, # 5e-6
    weight_decay=0.02,
    normalize_text=True,
    max_length=512,
    dropout=0.1,
    n_epochs=1,
    batch_size=16,
)

## Tune hyperparameters

In [None]:
# hyperparams = ParameterGrid(param_grid={
#     "claims_paths": [[
#         DATA_PATH.with_name("train-claims.json")
#     ]],
#     "claims_shortlist_paths": [[
#         Path("./result/pipeline/shortlisting_v2/train_retrieved_evidences_max_500_idf_no_rel.json"),
#     ]],
#     "n_neg_samples": [3, 5, 10],
#     "warmup": [0.1],
#     "lr": [0.00005, 0.0005],
#     "weight_decay": [0.01, 0.02],
#     "normalize_text": [True, False],
#     "max_length": [512],
#     "dropout": [None, 0.1],
#     "n_epochs": [5, 10],
#     "batch_size": [24]
# })

In [None]:
# with SimpleLogger("model_05_cross_encoder_retrieval") as logger:
#     logger.set_stream_handler()
#     logger.set_file_handler(
#         log_path=LOG_PATH,
#         filename="model_05_hyperparam_tuning.txt"
#     )
#     best_f1 = -1
#     best_params = {}
#     for hyperparam in hyperparams:
#         model = BertCrossEncoderClassifier(
#             pretrained_name="bert-base-uncased",
#             n_classes=2,
#             device=TORCH_DEVICE
#         )
#         logger.info("== RUN")
#         logger.info(hyperparam)
        
#         accuracy, f1, epoch = training_loop(model=model, **hyperparam)
        
#         logger.info(f"run_best_epoch: {epoch}, run_best_acc: {accuracy}, run_best_f1: {f1}")
        
#         if f1 > best_f1:
#             best_f1 = f1
#             best_params = hyperparam
        
#         logger.info(f"== CURRENT BEST F1: {best_f1}")
#         logger.info(best_params)