# Model 03 inference

Evidence retrieval using a siamese BERT classification model. This is similar to Model 01 except this does not use any community based models from sentence transformer.

## Setup

### Working directory

In [2]:
# Change the working directory to project root
from pathlib import Path
import os
ROOT_DIR = Path.cwd()
while not ROOT_DIR.joinpath("src").exists():
    ROOT_DIR = ROOT_DIR.parent
os.chdir(ROOT_DIR)

### File paths

In [3]:
MODEL_PATH = ROOT_DIR.joinpath("./result/models/*")
OUTPUT_PATH = ROOT_DIR.joinpath("./result/inference/*")
DATA_PATH = ROOT_DIR.joinpath("./data/*")
NER_PATH = ROOT_DIR.joinpath("./result/ner/*")

### Dependencies

In [28]:
# Imports and dependencies
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn import Linear, Module, CrossEntropyLoss, Dropout, CosineSimilarity
from transformers import BertModel, BertTokenizer
from torch.optim import Adam
from torch.optim.lr_scheduler import LinearLR
from torch.nn.functional import relu, softmax
from torcheval.metrics import BinaryAccuracy, BinaryF1Score, BinaryRecall
from sentence_transformers import util

from src.torch_utils import get_torch_device
from src.data import create_claim_output
import json
from dataclasses import dataclass
from typing import List, Union, Tuple, Dict
from tqdm import tqdm
import random
import numpy as np
from datetime import datetime
from math import exp
from collections import defaultdict

TORCH_DEVICE = get_torch_device()

Torch device is 'mps'


## Dataset

In [5]:
@dataclass
class ClaimEvidencePair:
    claim_id:str
    evidence_id:str
    label:int = 0

In [6]:
class InferenceClaims(Dataset):
    
    def __init__(self, claims_path:Path) -> None:
        super(InferenceClaims, self).__init__()
        with open(claims_path, mode="r") as f:
            self.claims = json.load(fp=f)
            self.claim_ids = list(self.claims.keys())
            print(f"loaded inference claims n={len(self.claim_ids)}")
        return
        
    def __len__(self):
        return len(self.claim_ids)

    def __getitem__(self, idx) -> Tuple[str]:
        claim_id = self.claim_ids[idx]
        claim_text = self.claims[claim_id]["claim_text"]
        return claim_id, claim_text

In [7]:
class SeenOnlyDataset(Dataset):
    
    def __init__(
        self,
        claim_id:str,
        claims_path:Path,
        train_claims_paths:List[Path],
        evidence_path:Path,
        verbose:bool=True
    ) -> None:
        super(SeenOnlyDataset, self).__init__()
        self.verbose = verbose
        self.claim_id = claim_id
        
        # Load claims text from json
        with open(claims_path, mode="r") as f:
            claims = json.load(fp=f)
            self.claim_text = claims[self.claim_id]["claim_text"]
            
        # Load evidence library
        with open(evidence_path, mode="r") as f:
            self.evidence = json.load(fp=f)
        
        # Load training claims
        train_claims = dict()
        for path in train_claims_paths:
            with open(path, mode="r") as f:
                train_claims.update(json.load(f))
        
        # Load the evidence shortlist
        evidence_shortlist = set()
        for claim in train_claims.values():
            evidence_shortlist.update(claim["evidences"])
        self.evidence_shortlist = sorted(evidence_shortlist)
        return
        
    def __len__(self):
        return len(self.evidence_shortlist)
    
    def __getitem__(self, idx) -> Tuple[str]:
        # Get text ids
        claim_id = self.claim_id
        evidence_id = self.evidence_shortlist[idx]
        
        # Get text
        claim_text = self.claim_text
        evidence_text = self.evidence[evidence_id]

        return (claim_text, evidence_text, claim_id, evidence_id)

In [8]:
class SiameseDatasetInference(Dataset):
    
    def __init__(
        self,
        claim_id:str,
        claims_path:Path,
        evidence_path:Path,
        claims_shortlist:Path = None,
        evidence_shortlist:Path = None,
        verbose:bool=True
    ) -> None:
        super(SiameseDatasetInference, self).__init__()
        self.verbose = verbose
        self.claim_id = claim_id
        
        # Load claims text from json
        with open(claims_path, mode="r") as f:
            claims = json.load(fp=f)
            self.claim_text = claims[self.claim_id]["claim_text"]
            
        # Load evidence library
        with open(evidence_path, mode="r") as f:
            self.evidence = json.load(fp=f)
        
        # Load the pre-retrieved shortlist of evidences
        # From either a pre-retrieved list of evidences specific for the
        # claim_id or from a pre-collated evidence shortlist
        # Both of which were determined from the fast shortlisting process
        if claims_shortlist is not None:
            with open(claims_shortlist, mode="r") as f:
                claims_shortlist_ = json.load(fp=f)
                self.evidence_shortlist = list(set(
                    claims_shortlist_.get(self.claim_id, [])
                ))
                print(f"loaded claims_shortlist: {claims_shortlist}")
        elif evidence_shortlist is not None:
            with open(evidence_shortlist, mode="r") as f:
                self.evidence_shortlist = list(set(json.load(fp=f)))
                print(f"loaded evidence_shortlist: {evidence_shortlist}")
        else:
            raise Exception(
                "Provide either a claims_shortlist or evidence_shortlist"
            )

        return

    def __len__(self):
        return len(self.evidence_shortlist)
    
    def __getitem__(self, idx) -> Tuple[str]:
        # Get text ids
        claim_id = self.claim_id
        evidence_id = self.evidence_shortlist[idx]
        
        # Get text
        claim_text = self.claim_text
        evidence_text = self.evidence[evidence_id]

        return (claim_text, evidence_text, claim_id, evidence_id)

## Load model

In [9]:
class SiameseTripletEmbedderBert(Module):
    
    def __init__(
            self,
            pretrained_name:str,
            device,
            **kwargs
        ) -> None:
        super(SiameseTripletEmbedderBert, self).__init__(**kwargs)
        self.device = device
        
        # Use a pretrained tokenizer
        self.tokenizer = BertTokenizer.from_pretrained(pretrained_name)
        
        # Use a pretrained model
        self.bert = BertModel.from_pretrained(pretrained_name)
        self.bert.to(device=device)
        return
        
    def forward(self, anchor_texts, pos_texts, neg_texts=None) -> Tuple[torch.Tensor]:
        
        # Run the tokenizer
        t_kwargs = {
            "return_tensors": "pt",
            "padding": True,
            "truncation": True,
            "max_length": 100,
            "add_special_tokens":True
        }
        anchor_x = self.tokenizer(anchor_texts, **t_kwargs)
        pos_x = self.tokenizer(pos_texts, **t_kwargs)
        if neg_texts:
            neg_x = self.tokenizer(neg_texts, **t_kwargs)
        
        anchor_x = anchor_x["input_ids"].to(device=self.device)
        pos_x = pos_x["input_ids"].to(device=self.device)
        if neg_texts:
            neg_x = neg_x["input_ids"].to(device=self.device)
        
        # Run Bert
        anchor_x = self.bert(anchor_x, return_dict=True).pooler_output
        pos_x = self.bert(pos_x, return_dict=True).pooler_output
        if neg_texts:
            neg_x = self.bert(neg_x, return_dict=True).pooler_output
        # dim=768
        
        if neg_texts:
            return anchor_x, pos_x, neg_x
        else:
            return anchor_x, pos_x

In [10]:
MODEL_NAME = f"model_03a_bert_base_triplet_margin_2_neg_5_10epochs.pth"
BATCH_SIZE = 64

In [11]:
with open(MODEL_PATH.with_name(MODEL_NAME), mode="rb") as f:
    model = torch.load(f, map_location=TORCH_DEVICE)

## Inference run code

In [12]:
@dataclass
class EvidenceScore:
    evidence_id:str
    score:float=0
    cos_sim:float=0
    dot:float=0
    
    def to_list(self) -> List[str]:
        return [self.evidence_id, str(self.score)]
    
    def to_dict(self) -> Dict[str, str]:
        return {"evidence_id": self.evidence_id, "score": str(self.score)}

In [64]:
# Generate claims-evidence inference interations
infer_data = SeenOnlyDataset(
    claim_id="claim-1937",
    claims_path=DATA_PATH.with_name("train-claims.json"),
    evidence_path=DATA_PATH.with_name("evidence.json"),
    train_claims_paths=[
        DATA_PATH.with_name("dev-claims.json"),
        DATA_PATH.with_name("train-claims.json")
    ],
)
infer_dataloader = DataLoader(
    dataset=infer_data,
    shuffle=False,
    batch_size=64
)

In [65]:
# Set model mode to evaluation
model.eval()

# Cumulator
claim_predictions = []
infer_batches = tqdm(infer_dataloader, desc="infer batches")
for batch in infer_batches:
    claim_texts, evidence_texts, batch_claim_ids, batch_evidence_ids = batch
    
    # # Forward
    claim_emb, evidence_emb = model(claim_texts, evidence_texts)
    
    for e_id, c, e in zip(
        batch_evidence_ids,
        claim_emb,
        evidence_emb
    ):
        # Score embeddings with cosine similarity
        cos_sim_func = CosineSimilarity(dim=-1)
        cos_sim = cos_sim_func(c, e).detach().item()
        
        # Dot product
        dot = util.dot_score(c, e).squeeze().detach().item()
        
        claim_predictions.append(EvidenceScore(
            evidence_id=e_id,
            cos_sim=cos_sim,
            dot=dot
        ))
    
    continue

infer batches: 100%|██████████| 54/54 [00:16<00:00,  3.33it/s]


In [71]:
for i, predicted_evidence in enumerate(sorted(claim_predictions, key=lambda e: e.cos_sim, reverse=True)):
    if predicted_evidence.evidence_id == "evidence-442946":
        print(i, predicted_evidence)
        break
    # print(predicted_evidence)

1455 EvidenceScore(evidence_id='evidence-442946', score=0, cos_sim=0.941272497177124, dot=339.094482421875)


In [72]:
for i, predicted_evidence in enumerate(sorted(claim_predictions, key=lambda e: e.cos_sim, reverse=True)):
    if i >= 20:
        break
    print(i, predicted_evidence)

0 EvidenceScore(evidence_id='evidence-141844', score=0, cos_sim=0.9891319274902344, dot=383.96307373046875)
1 EvidenceScore(evidence_id='evidence-1180647', score=0, cos_sim=0.9867583513259888, dot=392.5845947265625)
2 EvidenceScore(evidence_id='evidence-140540', score=0, cos_sim=0.985182523727417, dot=374.7574157714844)
3 EvidenceScore(evidence_id='evidence-433622', score=0, cos_sim=0.9834713935852051, dot=368.0030822753906)
4 EvidenceScore(evidence_id='evidence-457889', score=0, cos_sim=0.9834628701210022, dot=376.37432861328125)
5 EvidenceScore(evidence_id='evidence-892616', score=0, cos_sim=0.9826794862747192, dot=371.2298889160156)
6 EvidenceScore(evidence_id='evidence-914173', score=0, cos_sim=0.9817409515380859, dot=376.58978271484375)
7 EvidenceScore(evidence_id='evidence-1069909', score=0, cos_sim=0.9815636277198792, dot=381.1744384765625)
8 EvidenceScore(evidence_id='evidence-754568', score=0, cos_sim=0.9813380837440491, dot=371.28375244140625)
9 EvidenceScore(evidence_id='evi