# Model 06 Bert Cross Entropy Classification for Label Prediction

Prediction of claim labels based on the matched evidence.

## Setup

### Working Directory

In [1]:
# Change the working directory to project root
from pathlib import Path
import os
ROOT_DIR = Path.cwd()
while not ROOT_DIR.joinpath("src").exists():
    ROOT_DIR = ROOT_DIR.parent
os.chdir(ROOT_DIR)

### Dependencies

In [2]:
# Imports and dependencies
import torch
from torch.utils.data import DataLoader
from torch.nn import CrossEntropyLoss
from torch.optim import AdamW
from torch.optim.lr_scheduler import LinearLR
from torcheval.metrics import MulticlassAccuracy, MulticlassF1Score

from src.logger import SimpleLogger
from src.model_05 import BertCrossEncoderClassifier
from src.data import LabelClassificationDataset
from src.torch_utils import get_torch_device
import json
from dataclasses import dataclass
from typing import List, Union, Tuple
from tqdm import tqdm
import random
import numpy as np
from datetime import datetime
from math import exp
from sklearn.model_selection import ParameterGrid

TORCH_DEVICE = get_torch_device()

  from .autonotebook import tqdm as notebook_tqdm


Torch device is 'mps'


### File paths

In [3]:
MODEL_PATH = ROOT_DIR.joinpath("./result/models/*")
DATA_PATH = ROOT_DIR.joinpath("./data/*")
LOG_PATH = ROOT_DIR.joinpath("./result/logs/*")
SHORTLIST_PATH = ROOT_DIR.joinpath("./result/pipeline/shortlisting_v2/*")

run_time = datetime.now().strftime('%Y_%m_%d_%H_%M')

## Training Loop

In [4]:
def training_loop(
    model,
    claims_paths:List[Path],
    save_path:Path=None,
    label_weight:list=None,
    label_smoothing:float=None,
    warmup:float=0.1,
    lr:float=0.00005, # 5e-5
    weight_decay:float=0.01,
    normalize_text:bool=True,
    max_length:int=128,
    dropout:float=None,
    n_epochs:int=5,
    batch_size:int=64,
):
    # Generate training dataset
    train_data = LabelClassificationDataset(
        claims_paths=claims_paths,
        training=True,
    )
    train_dataloader = DataLoader(
        dataset=train_data,
        shuffle=True,
        batch_size=batch_size
    )
    
    # Generate evaluation dataset
    dev_data = LabelClassificationDataset(
        claims_paths=[Path("./data/dev-claims.json")],
        training=True,
    )
    dev_dataloader = DataLoader(
        dataset=dev_data,
        shuffle=False,
        batch_size=batch_size
    )
    
    # Loss function
    loss_fn = CrossEntropyLoss(
        weight=torch.tensor(label_weight, device=TORCH_DEVICE),
        label_smoothing=label_smoothing
    )
    
    # Optimizer
    optimizer = AdamW(
        params=model.parameters(),
        lr=lr,
        weight_decay=weight_decay
    )
    
    # Scheduler
    scheduler = LinearLR(
        optimizer=optimizer,
        total_iters=warmup * len(train_dataloader),
        verbose=False
    )
    
    # Metrics
    accuracy_fn = MulticlassAccuracy()
    f1_fn = MulticlassF1Score()
    
    # Training epochs --------------------------------------------------------
    
    best_epoch_loss = 999
    best_epoch_f1 = -1
    best_epoch_acc = -1
    best_epoch = 0
    for epoch in range(n_epochs):
        
        print(f"\nEpoch: {epoch + 1} of {n_epochs}\n")
        
        # Run training -------------------------------------------------------
        model.train()
        
        train_batches = tqdm(train_dataloader, desc="train batches")
        running_losses = []
        for batch in train_batches:
            claim_texts, evidence_texts, labels, claim_ids, evidence_ids = batch
            texts = list(zip(claim_texts, evidence_texts))
            
            # Reset optimizer
            optimizer.zero_grad()
            
            # Forward + loss
            output, logits, seq = model(
                texts=texts,
                normalize_text=normalize_text,
                max_length=max_length,
                dropout=dropout
            )
            loss = loss_fn(logits, labels)
            
            # Backward + optimizer
            loss.backward()
            optimizer.step()
            
            # Update running loss
            batch_loss = loss.item() * len(batch)
            running_losses.append(batch_loss)
            
            train_batches.postfix = f"loss: {batch_loss:.3f}"
            
            # Update scheduler
            scheduler.step()
            
            continue
        
        # Epoch loss
        epoch_loss = np.average(running_losses)
        print(f"Average epoch loss: {epoch_loss:.3f}")
    
        # Run evaluation ------------------------------------------------------
        model.eval()

        dev_batches = tqdm(dev_dataloader, desc="dev batches")
        dev_acc = []
        dev_f1 = []
        for batch in dev_batches:
            claim_texts, evidence_texts, labels, claim_ids, evidence_ids = batch
            texts = list(zip(claim_texts, evidence_texts))

            # Forward
            output, logits, seq = model(
                texts=texts,
                normalize_text=normalize_text,
                max_length=max_length,
                dropout=dropout
            )
            
            # Prediction
            predicted = torch.argmax(output, dim=1)

            # Metrics
            accuracy_fn.update(predicted.cpu(), labels.cpu())
            f1_fn.update(predicted.cpu(), labels.cpu())
            
            acc = accuracy_fn.compute()
            f1 = f1_fn.compute()
            
            dev_acc.append(acc)
            dev_f1.append(f1)
            
            dev_batches.postfix = f" acc: {acc:.3f}, f1: {f1:.3f}"

            continue
        
        # Consider metrics
        epoch_acc = np.average(dev_acc)
        print(f"Average epoch accuracy: {epoch_acc:.3f}")
        
        epoch_f1 = np.average(dev_f1)
        print(f"Average epoch f1: {epoch_f1:.3f}")
        
        if epoch_acc > best_epoch_acc:
            best_epoch_acc = epoch_acc
        
        if epoch_f1 > best_epoch_f1:
            best_epoch_f1 = epoch_f1
            best_epoch = epoch + 1
        
        # Save model ----------------------------------------------------------
        
        # Save the model with the best f1 score
        if save_path and epoch_f1 >= best_epoch_f1:
            torch.save(model, save_path)
            print(f"Saved model to: {save_path}")
        
    print("Done!")
    return best_epoch_acc, best_epoch_f1, best_epoch

## Load model

Use a blank pre-trained

In [5]:
model = BertCrossEncoderClassifier(
    pretrained_name="bert-base-uncased",
    n_classes=3,
    device=TORCH_DEVICE
)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Or load one previously trained

In [6]:
# MODEL_SAVE_PATH = MODEL_PATH.with_name("")
# with open(MODEL_PATH.with_name(MODEL_SAVE_PATH), mode="rb") as f:
#     model = torch.load(f, map_location=TORCH_DEVICE)

## Training and evaluation loop

In [7]:
training_loop(
    model=model,
    claims_paths=[
        DATA_PATH.with_name("train-claims.json")
    ],
    save_path=MODEL_PATH.with_name(f"model_06_bert_base_uncased_cross_encoder_label_{run_time}_high_weights.pth"),
    warmup=0.1,
    lr=0.000005, # 5e-6
    weight_decay=0.02,
    normalize_text=True,
    max_length=512,
    dropout=0.1,
    n_epochs=1,
    label_weight=[2, 1.2, 1],
    # label_weight=[1, 0.6, 0.4],
    label_smoothing=0.0,
    batch_size=24,
)

Torch device is 'mps'


claims: 100%|██████████| 1228/1228 [00:00<00:00, 407968.74it/s]


generated dataset n=3730
Torch device is 'mps'


claims: 100%|██████████| 154/154 [00:00<00:00, 581388.67it/s]


generated dataset n=433

Epoch: 1 of 1



train batches: 100%|██████████| 156/156 [03:23<00:00,  1.31s/it, loss: 5.070]


Average epoch loss: 5.108


dev batches: 100%|██████████| 19/19 [00:07<00:00,  2.61it/s,  acc: 0.513, f1: 0.513]


Average epoch accuracy: 0.539
Average epoch f1: 0.539
Saved model to: /Users/johnsonzhou/git/comp90042-project/result/models/model_06_bert_base_uncased_cross_encoder_label_2023_05_09_20_05_high_weights.pth
Done!


(0.53880805, 0.53880805, 1)

## Tune hyperparameters

In [None]:
# hyperparams = ParameterGrid(param_grid={
#     "claims_paths": [[
#         DATA_PATH.with_name("train-claims.json")
#     ]],
#     "warmup": [0.1],
#     "lr": [0.000005],
#     "weight_decay": [0.02],
#     "normalize_text": [True],
#     "max_length": [512],
#     "dropout": [0.1],
#     "n_epochs": [10],
#     "batch_size": [24],
#     "freeze_bert": [False],
#     "label_weight":[
#         # [2, 1.2, 1],
#         [1, 0.6, 0.4],
#     ],
#     "label_smoothing": [0.0]
# })

In [None]:
import warnings
warnings.filterwarnings('ignore')

In [None]:
with SimpleLogger("model_06_cross_encoder_retrieval") as logger:
    logger.set_stream_handler()
    logger.set_file_handler(
        log_path=LOG_PATH,
        filename="model_06_hyperparam_tuning.txt"
    )
    best_f1 = -1
    best_params = {}
    for hyperparam in hyperparams:
        model = BertCrossEncoderClassifier(
            pretrained_name="bert-base-uncased",
            n_classes=3,
            device=TORCH_DEVICE
        )
        
        model_param = hyperparam.copy()
        
        # Freeze bert parameters if desired
        if "freeze_bert" in model_param.keys():
            if hyperparam["freeze_bert"] is True:
                for param in model.bert.parameters():
                    param.requires_grad = False
            del model_param["freeze_bert"]
        
        logger.info("\n== RUN")
        logger.info(hyperparam)
        
        accuracy, f1, epoch = training_loop(model=model, **model_param)
        
        logger.info(f"run_best_epoch: {epoch}, run_best_acc: {accuracy}, run_best_f1: {f1}")
        
        if f1 > best_f1:
            best_f1 = f1
            best_params = hyperparam
        
        logger.info(f"\n== CURRENT BEST F1: {best_f1}")
        logger.info(best_params)