# Model 03a

Evidence retrieval using a Siamese BERT classification model.
This is similar to Model 01, however, it only uses official pre-trained models from hugging face.

Ref:
- [Hugging face pre-trained models](https://huggingface.co/transformers/v3.3.1/pretrained_models.html)
- [Hugging face guide to fine-tuning](https://huggingface.co/transformers/v3.3.1/custom_datasets.html)
- [Hugging face guide to fine-tuning easy](https://huggingface.co/docs/transformers/training)
- [SO Guide](https://stackoverflow.com/a/64156912)

## Setup

### Working Directory

In [1]:
# Change the working directory to project root
from pathlib import Path
import os
ROOT_DIR = Path.cwd()
while not ROOT_DIR.joinpath("src").exists():
    ROOT_DIR = ROOT_DIR.parent
os.chdir(ROOT_DIR)

### File paths

In [2]:
MODEL_PATH = ROOT_DIR.joinpath("./result/models/*")
DATA_PATH = ROOT_DIR.joinpath("./data/*")
NER_PATH = ROOT_DIR.joinpath("./result/ner/*")

### Dependencies

In [3]:
# Imports and dependencies
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn import Module, CosineEmbeddingLoss
from transformers import BertModel, BertTokenizer
from torch.optim import Adam
from torch.optim.lr_scheduler import LinearLR
from torcheval.metrics import BinaryAccuracy, BinaryF1Score

from src.torch_utils import get_torch_device
import json
from dataclasses import dataclass
from typing import List, Union, Tuple
from tqdm import tqdm
import random
import numpy as np
from datetime import datetime
from math import exp

TORCH_DEVICE = get_torch_device()

  from .autonotebook import tqdm as notebook_tqdm


Torch device is 'mps'


## Dataset

In [4]:
@dataclass
class ClaimEvidencePair:
    claim_id:str
    evidence_id:str
    label:int = 0

In [5]:
class SiameseEvalDataset(Dataset):
    def __init__(
        self,
        dev_claims_path:Path,
        evidence_path:Path,
        device = None,
        verbose:bool=True
    ) -> None:
        super(SiameseEvalDataset, self).__init__()
        self.verbose = verbose
        self.device = device
        
        # Load claims data from json
        with open(dev_claims_path, mode="r") as f:
            self.claims = (json.load(fp=f))

        # Load evidence library
        self.evidence = dict()
        with open(evidence_path, mode="r") as f:
            self.evidence.update(json.load(fp=f))
        
        # Get a list of all evidences within the dev set
        self.related_evidences = sorted({
            evidence_id
            for claim in self.claims.values()
            for evidence_id in claim["evidences"]
        })
        
        # Generate the data
        self.data = self.__generate_data()
        return
        
    def __generate_data(self):
        data = []
        for claim_id, claim in tqdm(
            iterable=self.claims.items(),
            desc="claims",
            disable=not self.verbose
        ):
            evidence_ids = claim["evidences"]
            
            # Get the positives
            for evidence_id in evidence_ids:
                data.append(ClaimEvidencePair(
                    claim_id=claim_id,
                    evidence_id=evidence_id,
                    label=1
                ))
            
            # Get some negatives
            n_neg = 0
            for rel_evidence_id in self.related_evidences:
                if n_neg >= 10:
                    break
                if rel_evidence_id in evidence_ids:
                    continue
                data.append(ClaimEvidencePair(
                    claim_id=claim_id,
                    evidence_id=rel_evidence_id,
                    label=-1
                ))
                n_neg += 1
        return data
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx) -> Tuple[Union[str, torch.Tensor]]:
        # Fetch the required data rows
        data = self.data[idx]
        
        # Get the label
        label = torch.tensor(data.label, device=self.device)
        
        # Get text ids
        claim_id = data.claim_id
        evidence_id = data.evidence_id
        
        # Get text
        claim_text = self.claims[claim_id]["claim_text"]
        evidence_text = self.evidence[evidence_id]

        return (claim_text, evidence_text, label)

In [6]:
class SiameseDataset(Dataset):
    
    def __init__(
        self,
        claims_paths:List[Path],
        claims_shortlist_paths:List[Path],
        evidence_path:Path,
        evidence_shortlists:List[Path] = None,
        device = None,
        n_neg_shortlist:int = 10,
        n_neg_general:int = 10,
        verbose:bool=True
    ) -> None:
        super(SiameseDataset, self).__init__()
        self.verbose = verbose
        self.device = device
        self.n_neg_shortlist = n_neg_shortlist
        self.n_neg_general = n_neg_general

        # Load claims data from json, this is a list as we could use
        # multiple json files in the same dataset
        self.claims = dict()
        for json_file in claims_paths:
            with open(json_file, mode="r") as f:
                self.claims.update(json.load(fp=f))
                # print(f"loaded claims: {json_file}")
        
        # Load the pre-retrieved shortlist of evidences by claim
        self.claims_shortlist = dict()
        for json_file in claims_shortlist_paths:
            with open(json_file, mode="r") as f:
                self.claims_shortlist.update(json.load(fp=f))
                # print(f"loaded claims_shortlist: {json_file}")
        
        # Load evidence library
        self.evidence = dict()
        with open(evidence_path, mode="r") as f:
            self.evidence.update(json.load(fp=f))
            # print(f"loaded evidences: {json_file}")
        
        # Load the evidence shortlists if available
        # Reduce the overall evidence list to the shortlist
        if evidence_shortlists is not None:
            self.evidence_shortlist = set()
            for json_file in evidence_shortlists:
                with open(json_file, mode="r") as f:
                    self.evidence_shortlist.update(json.load(fp=f))
                    # print(f"loaded evidence shortlist: {json_file}")
        
        # print(f"n_evidences: {len(self.evidence)}")
        
        # Generate the data
        self.data = self.__generate_data()
        return

    def __generate_data(self):
        print("Generate siamese dataset")
        
        data = []
        for claim_id, claim in tqdm(
            iterable=self.claims.items(),
            desc="claims",
            disable=not self.verbose
        ):
            # Check if we have evidences supplied, this will inform
            # whether this is for training
            is_training = "evidences" in claim.keys()
            pos_evidence_ids = set()
            
            # Get positive samples from evidences with label=1
            if is_training:
                pos_evidence_ids.update(claim["evidences"])

                for evidence_id in pos_evidence_ids:
                    data.append(ClaimEvidencePair(
                        claim_id=claim_id,
                        evidence_id=evidence_id,
                        label=1
                    ))
                    
            # Get negative samples from pre-retrieved evidences
            # for each claim with label=-1
            retrieved_evidence_ids = self.claims_shortlist.get(claim_id, [])
            if len(retrieved_evidence_ids) > 0:
                retrieved_neg_evidence_ids = random.sample(
                    population=retrieved_evidence_ids,
                    k=min(self.n_neg_shortlist, len(retrieved_evidence_ids))
                )
                
                # Generate claim and shortlisted negative evidence pairs
                for evidence_id in retrieved_neg_evidence_ids:
                    data.append(ClaimEvidencePair(
                        claim_id=claim_id,
                        evidence_id=evidence_id,
                        label=-1
                    ))
            
            # Get negative samples from shortlisted evidences list with label=0
            if len(self.evidence_shortlist) > 0:
                shortlist_neg_evidence_ids = random.sample(
                    population=self.evidence_shortlist,
                    k=min(self.n_neg_general, len(self.evidence_shortlist))
                )
                
                # Generate claim and shortlisted negative evidence pairs
                for evidence_id in shortlist_neg_evidence_ids:
                    data.append(ClaimEvidencePair(
                        claim_id=claim_id,
                        evidence_id=evidence_id,
                        label=-1
                    ))
            
            continue
        
        print(f"Generated data n={len(data)}")
        
        return data

    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx) -> Tuple[Union[str, torch.Tensor]]:
        # Fetch the required data rows
        data = self.data[idx]
        
        # Get the label
        label = torch.tensor(data.label, device=self.device)
        
        # Get text ids
        claim_id = data.claim_id
        evidence_id = data.evidence_id
        
        # Get text
        claim_text = self.claims[claim_id]["claim_text"]
        evidence_text = self.evidence[evidence_id]

        return (claim_text, evidence_text, label)

In [7]:
# WE WILL GENERATE THE DATASET PER EPOCH SO TO RANDOMISE THE NEGATIVE SAMPLES

# train_data = SiameseDataset(
#     claims_paths=[DATA_PATH.with_name("train-claims.json")],
#     claims_shortlist_paths=[NER_PATH.with_name("train_claim_evidence_retrieved.json")],
#     evidence_shortlists=[NER_PATH.with_name("shortlist_train_claim_evidence_retrieved.json")],
#     evidence_path=DATA_PATH.with_name("evidence.json"),
#     device=TORCH_DEVICE,
#     n_neg_shortlist=100,
#     n_neg_general=100
# )

## Build model

In [8]:
class SiameseEmbedderBert(Module):
    
    def __init__(
            self,
            pretrained_name:str,
            device,
            **kwargs
        ) -> None:
        super(SiameseEmbedderBert, self).__init__(**kwargs)
        self.device = device
        
        # Use a pretrained tokenizer
        self.tokenizer = BertTokenizer.from_pretrained(pretrained_name)
        
        # Use a pretrained model
        self.bert = BertModel.from_pretrained(pretrained_name)
        self.bert.to(device=device)
        return
        
    def forward(self, claim_texts, evidence_texts, eval_mode:bool=False) -> Tuple[torch.Tensor]:
        
        # Run the tokenizer
        t_kwargs = {
            "return_tensors": "pt",
            "padding": True,
            "truncation": True,
            "max_length": 100,
            "add_special_tokens":True
        }
        claim_x = self.tokenizer(claim_texts, **t_kwargs)
        evidence_x = self.tokenizer(evidence_texts, **t_kwargs)
        
        claim_x = claim_x["input_ids"].to(device=self.device)
        evidence_x = evidence_x["input_ids"].to(device=self.device)
        
        # Run Bert
        claim_x = self.bert(claim_x, return_dict=True).pooler_output
        evidence_x = self.bert(evidence_x, return_dict=True).pooler_output
        # dim=768
        
        # Cosine similarity
        if eval_mode:
            cos_sim = torch.cosine_similarity(x1=claim_x, x2=evidence_x)
            return claim_x, evidence_x, cos_sim
        
        return claim_x, evidence_x

## Training and evaluation loop

In [9]:
model = SiameseEmbedderBert(
    pretrained_name="bert-base-cased",
    device=TORCH_DEVICE
)

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [10]:
loss_fn = CosineEmbeddingLoss()
optimizer = Adam(
    params=model.parameters(),
    lr=0.00000128
) #! Hyperparams

In [11]:
run_time = datetime.now().strftime('%Y_%m_%d_%H_%M')
MODEL_NAME = f"model_03a_bert_base_cos_sim_{run_time}.pth"
N_EPOCHS = 100
BATCH_SIZE = 64


In [12]:
dev_data = SiameseEvalDataset(
    dev_claims_path=DATA_PATH.with_name("dev-claims.json"),
    evidence_path=DATA_PATH.with_name("evidence.json"),
    device=TORCH_DEVICE
)

dev_dataloader = DataLoader(
    dataset=dev_data,
    shuffle=False,
    batch_size=BATCH_SIZE
)

claims: 100%|██████████| 154/154 [00:00<00:00, 199297.38it/s]


In [13]:
import warnings
warnings.filterwarnings('ignore')

In [14]:
# Run evaluation before training to establish baseline
model.eval()

dev_batches = tqdm(dev_dataloader, desc="dev batches")
epoch_pos_cos_sim = []
epoch_neg_cos_sim = []
for batch in dev_batches:
    claim_texts, evidence_texts, labels = batch
    
    # Forward
    claim_emb, evidence_emb, cos_sim = model(claim_texts, evidence_texts, eval_mode=True)
    
    # Cosine similarity
    labelled_cos_sim = cos_sim * labels
    pos_cos_sim = labelled_cos_sim[torch.where(labelled_cos_sim > 0)]
    neg_cos_sim = labelled_cos_sim[torch.where(labelled_cos_sim < 0)]
    
    batch_pos_cos_sim = torch.mean(pos_cos_sim).cpu().item()
    batch_neg_cos_sim = torch.mean(neg_cos_sim).cpu().item() * -1
    
    epoch_pos_cos_sim.append(batch_pos_cos_sim)
    epoch_neg_cos_sim.append(batch_neg_cos_sim)
    
    dev_batches.postfix = f"pos cos_sim: {batch_pos_cos_sim:.3f}" + \
        f" neg cos_sim: {batch_neg_cos_sim:.3f}"
    
    continue

dev batches: 100%|██████████| 32/32 [00:08<00:00,  3.63it/s, pos cos_sim: 0.972 neg cos_sim: 0.868]


In [15]:
print(f"Average cos sim (pos, neg): {np.mean(epoch_pos_cos_sim):3f}, {np.mean(epoch_neg_cos_sim):3f}")

Average cos sim (pos, neg): 0.863419, 0.838396


In [16]:
metric_accuracy = BinaryAccuracy()
metric_f1 = BinaryF1Score()
metric_recall = BinaryF1Score()

scheduler = LinearLR(
    optimizer=optimizer,
    start_factor=1,
    end_factor=1,
    total_iters=2,
    verbose=True
)
best_epoch_loss = 999
for epoch in range(N_EPOCHS):
    
    print(f"Epoch: {epoch} of {N_EPOCHS}\n")
    
    # Run training
    model.train()
    
    train_data = SiameseDataset(
        claims_paths=[DATA_PATH.with_name("train-claims.json")],
        claims_shortlist_paths=[NER_PATH.with_name("train_claim_evidence_retrieved.json")],
        evidence_shortlists=[NER_PATH.with_name("shortlist_train_claim_evidence_retrieved.json")],
        evidence_path=DATA_PATH.with_name("evidence.json"),
        device=TORCH_DEVICE,
        n_neg_shortlist=2,
        n_neg_general=1
    )
    
    train_dataloader = DataLoader(
        dataset=train_data,
        shuffle=True,
        batch_size=BATCH_SIZE
    )
    
    train_batches = tqdm(train_dataloader, desc="train batches")
    running_losses = []
    for batch in train_batches:
        claim_texts, evidence_texts, labels = batch
        
        # Reset optimizer
        optimizer.zero_grad()
        
        # Forward + loss
        claim_emb, evidence_emb = model(claim_texts, evidence_texts)
        loss = loss_fn(input1=claim_emb, input2=evidence_emb, target=labels)
        
        # Backward + optimiser
        loss.backward()
        optimizer.step()
        
        # Update running loss
        batch_loss = loss.item() * len(batch)
        running_losses.append(batch_loss)
        
        train_batches.postfix = f"loss: {batch_loss:.3f}"
        
        continue
    
    scheduler.step()
    
    epoch_loss = np.average(running_losses)
    print(f"Average epoch loss: {epoch_loss}")
    
    # Save model
    if epoch_loss <= best_epoch_loss:
        best_epoch_loss = epoch_loss
        torch.save(model, MODEL_PATH.with_name(MODEL_NAME))
        print(f"Saved model to: {MODEL_PATH.with_name(MODEL_NAME)}")
    
    # Evaluate every 5 epochs
    # if epoch % 5 != 0:
    #     continue
    
    # Run evaluation before training to establish baseline
    model.eval()

    dev_batches = tqdm(dev_dataloader, desc="dev batches")
    epoch_pos_cos_sim = []
    epoch_neg_cos_sim = []
    for batch in dev_batches:
        claim_texts, evidence_texts, labels = batch

        # Forward
        claim_emb, evidence_emb, cos_sim = model(claim_texts, evidence_texts, eval_mode=True)

        # Cosine similarity
        labelled_cos_sim = cos_sim * labels
        pos_cos_sim = labelled_cos_sim[torch.where(labelled_cos_sim > 0)]
        neg_cos_sim = labelled_cos_sim[torch.where(labelled_cos_sim < 0)]

        batch_pos_cos_sim = torch.mean(pos_cos_sim).cpu().item()
        batch_neg_cos_sim = torch.mean(neg_cos_sim).cpu().item() * -1

        epoch_pos_cos_sim.append(batch_pos_cos_sim)
        epoch_neg_cos_sim.append(batch_neg_cos_sim)

        dev_batches.postfix = f"pos cos_sim: {batch_pos_cos_sim:.3f}" + \
            f" neg cos_sim: {batch_neg_cos_sim:.3f}"

        continue
    
    print(f"Average cos sim (pos, neg): {np.mean(epoch_pos_cos_sim):3f}, {np.mean(epoch_neg_cos_sim):3f}")

print("Done!")

Adjusting learning rate of group 0 to 1.2800e-06.
Epoch: 0 of 100

Generate siamese dataset


claims: 100%|██████████| 1228/1228 [00:07<00:00, 175.38it/s]


Generated data n=7693


train batches: 100%|██████████| 121/121 [01:46<00:00,  1.14it/s, loss: 1.296]


Adjusting learning rate of group 0 to 1.2800e-06.
Average epoch loss: 1.3890747568331474
Saved model to: /Users/johnsonzhou/git/comp90042-project/result/models/model_03a_bert_base_cos_sim_2023_05_03_21_10.pth


dev batches: 100%|██████████| 32/32 [00:07<00:00,  4.36it/s, pos cos_sim: 0.894 neg cos_sim: 0.791]


Average cos sim (pos, neg): 0.828207, 0.789408
Epoch: 1 of 100

Generate siamese dataset


claims: 100%|██████████| 1228/1228 [00:07<00:00, 168.43it/s]


Generated data n=7693


train batches: 100%|██████████| 121/121 [01:45<00:00,  1.15it/s, loss: 0.994]


Adjusting learning rate of group 0 to 1.2800e-06.
Average epoch loss: 1.0412948363699204
Saved model to: /Users/johnsonzhou/git/comp90042-project/result/models/model_03a_bert_base_cos_sim_2023_05_03_21_10.pth


dev batches: 100%|██████████| 32/32 [00:07<00:00,  4.32it/s, pos cos_sim: 0.854 neg cos_sim: 0.846]


Average cos sim (pos, neg): 0.786210, 0.797201
Epoch: 2 of 100

Generate siamese dataset


claims: 100%|██████████| 1228/1228 [00:06<00:00, 175.61it/s]


Generated data n=7693


train batches: 100%|██████████| 121/121 [01:44<00:00,  1.16it/s, loss: 0.306]


Adjusting learning rate of group 0 to 1.2800e-06.
Average epoch loss: 0.6969264437837049
Saved model to: /Users/johnsonzhou/git/comp90042-project/result/models/model_03a_bert_base_cos_sim_2023_05_03_21_10.pth


dev batches: 100%|██████████| 32/32 [00:07<00:00,  4.35it/s, pos cos_sim: 0.730 neg cos_sim: 0.880]


Average cos sim (pos, neg): 0.786910, 0.803039
Epoch: 3 of 100

Generate siamese dataset


claims: 100%|██████████| 1228/1228 [00:06<00:00, 175.95it/s]


Generated data n=7693


train batches: 100%|██████████| 121/121 [01:44<00:00,  1.16it/s, loss: 0.427]


Adjusting learning rate of group 0 to 1.2800e-06.
Average epoch loss: 0.6255223235069227
Saved model to: /Users/johnsonzhou/git/comp90042-project/result/models/model_03a_bert_base_cos_sim_2023_05_03_21_10.pth


dev batches: 100%|██████████| 32/32 [00:07<00:00,  4.36it/s, pos cos_sim: 0.718 neg cos_sim: 0.872]


Average cos sim (pos, neg): 0.799902, 0.823209
Epoch: 4 of 100

Generate siamese dataset


claims: 100%|██████████| 1228/1228 [00:06<00:00, 178.55it/s]


Generated data n=7693


train batches: 100%|██████████| 121/121 [01:43<00:00,  1.17it/s, loss: 0.560]


Adjusting learning rate of group 0 to 1.2800e-06.
Average epoch loss: 0.6126381267200817
Saved model to: /Users/johnsonzhou/git/comp90042-project/result/models/model_03a_bert_base_cos_sim_2023_05_03_21_10.pth


dev batches: 100%|██████████| 32/32 [00:07<00:00,  4.36it/s, pos cos_sim: 0.655 neg cos_sim: 0.872]


Average cos sim (pos, neg): 0.798563, 0.819310
Epoch: 5 of 100

Generate siamese dataset


claims: 100%|██████████| 1228/1228 [00:06<00:00, 178.24it/s]


Generated data n=7693


train batches: 100%|██████████| 121/121 [01:43<00:00,  1.17it/s, loss: 1.209]


Adjusting learning rate of group 0 to 1.2800e-06.
Average epoch loss: 0.5860244390393091
Saved model to: /Users/johnsonzhou/git/comp90042-project/result/models/model_03a_bert_base_cos_sim_2023_05_03_21_10.pth


dev batches: 100%|██████████| 32/32 [00:07<00:00,  4.42it/s, pos cos_sim: 0.635 neg cos_sim: 0.859]


Average cos sim (pos, neg): 0.795910, 0.793795
Epoch: 6 of 100

Generate siamese dataset


claims: 100%|██████████| 1228/1228 [00:06<00:00, 186.28it/s]


Generated data n=7693


train batches: 100%|██████████| 121/121 [01:42<00:00,  1.17it/s, loss: 0.605]


Adjusting learning rate of group 0 to 1.2800e-06.
Average epoch loss: 0.5642308885893546
Saved model to: /Users/johnsonzhou/git/comp90042-project/result/models/model_03a_bert_base_cos_sim_2023_05_03_21_10.pth


dev batches: 100%|██████████| 32/32 [00:07<00:00,  4.44it/s, pos cos_sim: 0.624 neg cos_sim: 0.877]


Average cos sim (pos, neg): 0.790249, 0.790002
Epoch: 7 of 100

Generate siamese dataset


claims: 100%|██████████| 1228/1228 [00:06<00:00, 183.56it/s]


Generated data n=7693


train batches: 100%|██████████| 121/121 [01:43<00:00,  1.17it/s, loss: 0.266]


Adjusting learning rate of group 0 to 1.2800e-06.
Average epoch loss: 0.5422538236891927
Saved model to: /Users/johnsonzhou/git/comp90042-project/result/models/model_03a_bert_base_cos_sim_2023_05_03_21_10.pth


dev batches: 100%|██████████| 32/32 [00:07<00:00,  4.44it/s, pos cos_sim: 0.662 neg cos_sim: 0.861]


Average cos sim (pos, neg): 0.739562, 0.798187
Epoch: 8 of 100

Generate siamese dataset


claims: 100%|██████████| 1228/1228 [00:06<00:00, 186.53it/s]


Generated data n=7693


train batches: 100%|██████████| 121/121 [01:42<00:00,  1.18it/s, loss: 0.372]


Adjusting learning rate of group 0 to 1.2800e-06.
Average epoch loss: 0.5334123548520499
Saved model to: /Users/johnsonzhou/git/comp90042-project/result/models/model_03a_bert_base_cos_sim_2023_05_03_21_10.pth


dev batches: 100%|██████████| 32/32 [00:07<00:00,  4.33it/s, pos cos_sim: 0.687 neg cos_sim: 0.869]


Average cos sim (pos, neg): 0.773839, 0.813255
Epoch: 9 of 100

Generate siamese dataset


claims: 100%|██████████| 1228/1228 [00:07<00:00, 174.73it/s]


Generated data n=7693


train batches:  32%|███▏      | 39/121 [00:34<01:12,  1.13it/s, loss: 0.578]


KeyboardInterrupt: 