# Model 03

Evidence retrieval using a Siamese BERT classification model.
This is similar to Model 01, however, it only uses official pre-trained models from hugging face.

This extends model 03 by continuing the training with greater proportion of negative samples.

Ref:
- [Hugging face pre-trained models](https://huggingface.co/transformers/v3.3.1/pretrained_models.html)
- [Hugging face guide to fine-tuning](https://huggingface.co/transformers/v3.3.1/custom_datasets.html)
- [Hugging face guide to fine-tuning easy](https://huggingface.co/docs/transformers/training)
- [SO Guide](https://stackoverflow.com/a/64156912)

## Setup

### Working Directory

In [1]:
# Change the working directory to project root
from pathlib import Path
import os
ROOT_DIR = Path.cwd()
while not ROOT_DIR.joinpath("src").exists():
    ROOT_DIR = ROOT_DIR.parent
os.chdir(ROOT_DIR)

### File paths

In [2]:
MODEL_PATH = ROOT_DIR.joinpath("./result/models/*")
DATA_PATH = ROOT_DIR.joinpath("./data/*")
NER_PATH = ROOT_DIR.joinpath("./result/ner/*")

### Dependencies

In [3]:
# Imports and dependencies
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn import Linear, Module, CrossEntropyLoss, Dropout
from transformers import BertModel, BertTokenizer
from torch.optim import Adam
from torch.optim.lr_scheduler import LinearLR
from torch.nn.functional import relu, softmax
from torcheval.metrics import BinaryAccuracy, BinaryF1Score, BinaryRecall

from src.torch_utils import get_torch_device
import json
from dataclasses import dataclass
from typing import List, Union, Tuple
from tqdm import tqdm
import random
import numpy as np
from datetime import datetime
from math import exp

TORCH_DEVICE = get_torch_device()

  from .autonotebook import tqdm as notebook_tqdm


Torch device is 'mps'


## Dataset

In [4]:
@dataclass
class ClaimEvidencePair:
    claim_id:str
    evidence_id:str
    label:int = 0

In [5]:
class SiameseDataset(Dataset):
    
    def __init__(
        self,
        claims_paths:List[Path],
        claims_shortlist_paths:List[Path],
        evidence_path:Path,
        evidence_shortlists:List[Path] = None,
        device = None,
        n_neg_shortlist:int = 10,
        n_neg_general:int = 10,
        verbose:bool=True
    ) -> None:
        super(SiameseDataset, self).__init__()
        self.verbose = verbose
        self.device = device
        self.n_neg_shortlist = n_neg_shortlist
        self.n_neg_general = n_neg_general

        # Load claims data from json, this is a list as we could use
        # multiple json files in the same dataset
        self.claims = dict()
        for json_file in claims_paths:
            with open(json_file, mode="r") as f:
                self.claims.update(json.load(fp=f))
                # print(f"loaded claims: {json_file}")
        
        # Load the pre-retrieved shortlist of evidences by claim
        self.claims_shortlist = dict()
        for json_file in claims_shortlist_paths:
            with open(json_file, mode="r") as f:
                self.claims_shortlist.update(json.load(fp=f))
                # print(f"loaded claims_shortlist: {json_file}")
        
        # Load evidence library
        self.evidence = dict()
        with open(evidence_path, mode="r") as f:
            self.evidence.update(json.load(fp=f))
            # print(f"loaded evidences: {json_file}")
        
        # Load the evidence shortlists if available
        # Reduce the overall evidence list to the shortlist
        if evidence_shortlists is not None:
            self.evidence_shortlist = set()
            for json_file in evidence_shortlists:
                with open(json_file, mode="r") as f:
                    self.evidence_shortlist.update(json.load(fp=f))
                    # print(f"loaded evidence shortlist: {json_file}")
        
        # print(f"n_evidences: {len(self.evidence)}")
        
        # Generate the data
        self.data = self.__generate_data()
        return

    def __generate_data(self):
        print("Generate siamese dataset")
        
        data = []
        for claim_id, claim in tqdm(
            iterable=self.claims.items(),
            desc="claims",
            disable=not self.verbose
        ):
            # Check if we have evidences supplied, this will inform
            # whether this is for training
            is_training = "evidences" in claim.keys()
            pos_evidence_ids = set()
            
            # Get positive samples from evidences with label=1
            if is_training:
                pos_evidence_ids.update(claim["evidences"])

                for evidence_id in pos_evidence_ids:
                    data.append(ClaimEvidencePair(
                        claim_id=claim_id,
                        evidence_id=evidence_id,
                        label=1
                    ))
                    
            # Get negative samples from pre-retrieved evidences
            # for each claim with label=0
            retrieved_evidence_ids = self.claims_shortlist.get(claim_id, [])
            if len(retrieved_evidence_ids) > 0:
                retrieved_neg_evidence_ids = random.sample(
                    population=retrieved_evidence_ids,
                    k=min(self.n_neg_shortlist, len(retrieved_evidence_ids))
                )
                
                # Generate claim and shortlisted negative evidence pairs
                for evidence_id in retrieved_neg_evidence_ids:
                    data.append(ClaimEvidencePair(
                        claim_id=claim_id,
                        evidence_id=evidence_id,
                        label=0
                    ))
            
            # Get negative samples from shortlisted evidences list with label=0
            if len(self.evidence_shortlist) > 0:
                shortlist_neg_evidence_ids = random.sample(
                    population=self.evidence_shortlist,
                    k=min(self.n_neg_general, len(self.evidence_shortlist))
                )
                
                # Generate claim and shortlisted negative evidence pairs
                for evidence_id in shortlist_neg_evidence_ids:
                    data.append(ClaimEvidencePair(
                        claim_id=claim_id,
                        evidence_id=evidence_id,
                        label=0
                    ))
            
            continue
        
        print(f"Generated data n={len(data)}")
        
        return data

    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx) -> Tuple[Union[str, torch.Tensor]]:
        # Fetch the required data rows
        data = self.data[idx]
        
        # Get the label
        label = torch.tensor(data.label, device=self.device)
        
        # Get text ids
        claim_id = data.claim_id
        evidence_id = data.evidence_id
        
        # Get text
        claim_text = self.claims[claim_id]["claim_text"]
        evidence_text = self.evidence[evidence_id]

        return (claim_text, evidence_text, label)

In [6]:
# WE WILL GENERATE THE DATASET PER EPOCH SO TO RANDOMISE THE NEGATIVE SAMPLES

# train_data = SiameseDataset(
#     claims_paths=[DATA_PATH.with_name("train-claims.json")],
#     claims_shortlist_paths=[NER_PATH.with_name("train_claim_evidence_retrieved.json")],
#     evidence_shortlists=[NER_PATH.with_name("shortlist_train_claim_evidence_retrieved.json")],
#     evidence_path=DATA_PATH.with_name("evidence.json"),
#     device=TORCH_DEVICE,
#     n_neg_shortlist=100,
#     n_neg_general=100
# )

## Build model

In [7]:
class SiameseClassifierBert(Module):
    
    def __init__(
            self,
            pretrained_name:str,
            device,
            **kwargs
        ) -> None:
        super(SiameseClassifierBert, self).__init__(**kwargs)
        self.device = device
        
        # Use a pretrained tokenizer
        self.tokenizer = BertTokenizer.from_pretrained(pretrained_name)
        
        # Use a pretrained model
        self.bert = BertModel.from_pretrained(pretrained_name)
        self.bert.to(device=device)
        
        # Classification layers
        self.linear1 = Linear(2304, 1024, bias=True, device=device)
        self.linear2 = Linear(1024, 512, bias=True, device=device)
        self.linear3 = Linear(512, 2, bias=True, device=device)
        self.relu = relu
        self.softmax = softmax
        self.dropout_in = Dropout(p=0.2)
        self.dropout_out = Dropout(p=0.5)
        
        # print(self.tokenizer)
        # print(self.bert)
        # print(self.linear1)
        # print(self.linear2)
        # print(self.activation)
        # print(self.softmax)
        return
        
    def forward(self, claim_texts, evidence_texts) -> Tuple[torch.Tensor]:
        
        # Run the tokenizer
        t_kwargs = {
            "return_tensors": "pt",
            "padding": True,
            "truncation": True,
            "max_length": 100,
            "add_special_tokens":True
        }
        claim_x = self.tokenizer(claim_texts, **t_kwargs)
        evidence_x = self.tokenizer(evidence_texts, **t_kwargs)
        
        claim_x = claim_x["input_ids"].to(device=self.device)
        evidence_x = evidence_x["input_ids"].to(device=self.device)
        
        # Run Bert
        claim_x = self.bert(claim_x, return_dict=True).pooler_output
        evidence_x = self.bert(evidence_x, return_dict=True).pooler_output
        # dim=768
        
        # Concatenate the two embeddings
        x = torch.cat((claim_x, evidence_x, claim_x - evidence_x), dim=1)
        # dim=2304
        
        # Run classification layers
        x = self.dropout_in(x)
        x = self.linear1(x)
        x = self.relu(x)
        x = self.dropout_out(x)
        x = self.linear2(x)
        x = self.relu(x)
        x = self.dropout_out(x)
        x = self.linear3(x)
        
        # Create the predictions
        y = self.softmax(x, dim=-1)
        
        return (y, claim_x, evidence_x)

## Training and evaluation loop

We are continuing to train the model.

In [8]:
# model = SiameseClassifierBert(
#     pretrained_name="bert-base-uncased",
#     device=TORCH_DEVICE
# )

In [9]:
run_time = datetime.now().strftime('%Y_%m_%d_%H_%M')
MODEL_NAME = f"model_03_base_run_01_continue.pth"
N_EPOCHS = 4
BATCH_SIZE = 64


In [10]:
with open(MODEL_PATH.with_name(MODEL_NAME), mode="rb") as f:
    model = torch.load(f, map_location=TORCH_DEVICE)

In [11]:
loss_fn = CrossEntropyLoss()
optimizer = Adam(
    params=model.parameters(),
    lr=0.000002
) #! Hyperparams

In [12]:
dev_data = SiameseDataset(
    claims_paths=[DATA_PATH.with_name("dev-claims.json")],
    claims_shortlist_paths=[NER_PATH.with_name("dev_claim_evidence_retrieved.json")],
    evidence_shortlists=[NER_PATH.with_name("shortlist_dev_claim_evidence_retrieved.json")],
    evidence_path=DATA_PATH.with_name("evidence.json"),
    device=TORCH_DEVICE,
    n_neg_shortlist=10,
    n_neg_general=10
)

dev_dataloader = DataLoader(
    dataset=dev_data,
    shuffle=False,
    batch_size=BATCH_SIZE
)

Generate siamese dataset


claims: 100%|██████████| 154/154 [00:00<00:00, 469.17it/s]

Generated data n=3532





In [13]:
import warnings
warnings.filterwarnings('ignore')

In [14]:
metric_accuracy = BinaryAccuracy()
metric_f1 = BinaryF1Score()
metric_recall = BinaryF1Score()

scheduler = LinearLR(
    optimizer=optimizer,
    start_factor=1,
    end_factor=1,
    total_iters=int(N_EPOCHS/10),
    verbose=True
)
last_epoch_loss = 999
for epoch in range(N_EPOCHS):
    
    print(f"Epoch: {epoch} of {N_EPOCHS}\n")
    
    # Run training
    model.train()
    
    train_data = SiameseDataset(
        claims_paths=[DATA_PATH.with_name("train-claims.json")],
        claims_shortlist_paths=[NER_PATH.with_name("train_claim_evidence_retrieved.json")],
        evidence_shortlists=[NER_PATH.with_name("shortlist_train_claim_evidence_retrieved.json")],
        evidence_path=DATA_PATH.with_name("evidence.json"),
        device=TORCH_DEVICE,
        n_neg_shortlist=10,
        n_neg_general=10
    )
    
    train_dataloader = DataLoader(
        dataset=train_data,
        shuffle=True,
        batch_size=BATCH_SIZE
    )
    
    train_batches = tqdm(train_dataloader, desc="train batches")
    running_losses = []
    for batch in train_batches:
        claim_texts, evidence_texts, labels = batch
        
        # Reset optimizer
        optimizer.zero_grad()
        
        # Forward + loss
        predictions, *_ = model(claim_texts, evidence_texts)
        loss = loss_fn(predictions, labels)
        
        # Backward + optimiser
        loss.backward()
        optimizer.step()
        
        # Update running loss
        batch_loss = loss.item() * len(batch)
        running_losses.append(batch_loss)
        
        train_batches.postfix = f"loss: {batch_loss:.3f}"
        
        continue
    
    scheduler.step()
    
    epoch_loss = np.average(running_losses)
    print(f"Average epoch loss: {epoch_loss}")
    
    # Save model
    if epoch_loss <= last_epoch_loss:
        torch.save(model, MODEL_PATH.with_name(MODEL_NAME))
        print(f"Saved model to: {MODEL_PATH.with_name(MODEL_NAME)}")
    last_epoch_loss = epoch_loss
    
    # Evaluate every 5 epochs
    if epoch % 2 != 0:
        continue
    
    # Run evaluation
    model.eval()
    
    dev_batches = tqdm(dev_dataloader, desc="dev batches")
    dev_acc = []
    dev_f1 = []
    dev_rec = []
    for batch in dev_batches:
        claim_texts, evidence_texts, labels = batch
        
        # Forward
        predictions, *_ = model(claim_texts, evidence_texts)
        
        # Prediction
        _, predicted = torch.max(predictions, dim=-1)
        
        # Metrics
        metric_accuracy.update(predicted.cpu(), labels.cpu())
        metric_f1.update(predicted.cpu(), labels.cpu())
        metric_recall.update(predicted.cpu(), labels.cpu())
        
        acc = metric_accuracy.compute()
        f1 = metric_f1.compute()
        rec = metric_recall.compute()
        
        dev_acc.append(acc)
        dev_f1.append(f1)
        dev_rec.append(rec)
        
        dev_batches.postfix = \
            f" acc: {acc:.3f}" \
            + f" f1: {f1:.3f}" \
            + f" rec: {rec:.3f}"
        
        continue
    
    val_acc = np.mean(dev_acc)
    val_f1 = np.mean(dev_f1)
    val_rec = np.mean(dev_rec)
    
    print(f"Epoch accuracy on dev: {val_acc:.3f}")
    print(f"Epoch f1 on dev: {val_f1:.3f}")
    print(f"Epoch recall on dev: {val_rec:.3f}\n")

print("Done!")

Adjusting learning rate of group 0 to 2.0000e-06.
Epoch: 0 of 4

Generate siamese dataset


claims: 100%|██████████| 1228/1228 [00:07<00:00, 172.66it/s]


Generated data n=28097


train batches: 100%|██████████| 440/440 [06:15<00:00,  1.17it/s, loss: 1.091]


Adjusting learning rate of group 0 to 2.0000e-06.
Average epoch loss: 1.0903123059056021
Saved model to: /Users/johnsonzhou/git/comp90042-project/result/models/model_03_base_run_01_continue.pth


dev batches: 100%|██████████| 56/56 [00:16<00:00,  3.31it/s,  acc: 0.882 f1: 0.604 rec: 0.604]


Epoch accuracy on dev: 0.877
Epoch f1 on dev: 0.602
Epoch recall on dev: 0.602

Epoch: 1 of 4

Generate siamese dataset


claims: 100%|██████████| 1228/1228 [00:07<00:00, 174.75it/s]


Generated data n=28097


train batches: 100%|██████████| 440/440 [06:15<00:00,  1.17it/s, loss: 0.940]


Adjusting learning rate of group 0 to 2.0000e-06.
Average epoch loss: 1.0925235448235815
Epoch: 2 of 4

Generate siamese dataset


claims: 100%|██████████| 1228/1228 [00:06<00:00, 182.41it/s]


Generated data n=28097


train batches:  54%|█████▍    | 239/440 [03:26<02:53,  1.16it/s, loss: 1.091]


KeyboardInterrupt: 