# Model 01 inference

Evidence retrieval using a Siamese BERT classification model.

Ref:
- [STS continue training guide](https://github.com/UKPLab/sentence-transformers/blob/master/examples/training/sts/training_stsbenchmark_continue_training.py)

## Setup

### Working Directory

In [1]:
# Change the working directory to project root
import pathlib
import os
ROOT_DIR = pathlib.Path.cwd()
while not ROOT_DIR.joinpath("src").exists():
    ROOT_DIR = ROOT_DIR.parent
os.chdir(ROOT_DIR)
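
Note that the loop above never terminates if no ancestor directory contains `src`, because the parent of the filesystem root is the root itself. A minimal sketch of a bounded alternative follows; `find_project_root` and its `marker` parameter are hypothetical names introduced here for illustration and are not part of this project.

```python
import pathlib


def find_project_root(start: pathlib.Path, marker: str = "src") -> pathlib.Path:
    """Return the nearest ancestor of `start` (inclusive) that contains `marker`.

    Hypothetical helper: raises instead of looping forever when nothing matches.
    """
    start = start.resolve()
    # `parents` ends at the filesystem root, so the search is always finite.
    for candidate in (start, *start.parents):
        if candidate.joinpath(marker).exists():
            return candidate
    raise FileNotFoundError(f"No ancestor of {start} contains '{marker}'")


# Possible usage in this notebook (left commented out to avoid side effects):
# ROOT_DIR = find_project_root(pathlib.Path.cwd())
# os.chdir(ROOT_DIR)
```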

### File paths

In [2]:
# The trailing "*" is a placeholder; `with_name` replaces it with the model name below
MODEL_PATH = ROOT_DIR.joinpath("./result/models/*")
OUTPUT_PATH = ROOT_DIR.joinpath("./result/inference")

### Dependencies

In [3]:
# Imports and dependencies
import torch
from sentence_transformers import SentenceTransformer, LoggingHandler, util
from src.torch_utils import get_torch_device
from src.data import load_from_json
from src.model_01 import run_inference
import logging
import random
random.seed(a=42)

torch_device = get_torch_device()

Torch device is 'mps'


### Names

In [4]:
model_save_path = MODEL_PATH.with_name("model_01_base_e5_equal_neg")
inference_output_path = OUTPUT_PATH.joinpath(model_save_path.name)

### Logging

In [5]:
logging.basicConfig(format='%(asctime)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    level=logging.INFO,
    handlers=[LoggingHandler()]
)

## Dataset

In [6]:
data_names = ["train-claims", "dev-claims", "test-claims-unlabelled", "evidence"]
train_claims, dev_claims, test_claims, all_evidence = load_from_json(data_names)

Loaded train-claims
Loaded dev-claims
Loaded test-claims-unlabelled
Loaded evidence


In [7]:
print(len(test_claims))
print(len(dev_claims))
print(len(all_evidence))

153
154
1208827


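Because `all_evidence` contains 1,208,827 entries and exceeds the maximum size limit for `tensor.save`, we test with a reduced set for now: every evidence passage referenced by a train or dev claim, plus a random sample of 5,000 passages.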

In [8]:
# Extract a set of named evidence ids
related_evidence_ids = set()
for dataset in [train_claims, dev_claims]:
    for claim in dataset.values():
        related_evidence_ids.update(set(claim["evidences"]))
len(related_evidence_ids)

3443

In [9]:
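# random.sample requires a sequence (sets are rejected from Python 3.11),
# so sort the ids first; sorting also makes the seeded sample reproducible.
random_evidence_ids = random.sample(
    population=sorted(all_evidence.keys()),
    k=5000
)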
len(random_evidence_ids)

5000

In [10]:
evidence_lib_ids = related_evidence_ids.union(random_evidence_ids)
len(evidence_lib_ids)

8431

In [11]:
reduced_evidence = {k:v for k, v in all_evidence.items() if k in evidence_lib_ids}

## Load model from file

In [12]:
model = SentenceTransformer(
    model_name_or_path=model_save_path,
    device=torch_device
)
model

2023-04-27 22:34:23 - Load pretrained SentenceTransformer: /Users/johnsonzhou/git/comp90042-project/result/models/model_01_base_e5_equal_neg


SentenceTransformer(
  (0): Transformer({'max_seq_length': 512, 'do_lower_case': False}) with Transformer model: BertModel 
  (1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False})
)
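
The `Pooling` module printed above has `pooling_mode_mean_tokens` set to `True`, so each sentence embedding is the average of its token embeddings, with padding positions excluded. The sketch below illustrates that idea in plain Python on toy data; `mean_pool` is a hypothetical name, and the real module operates on batched tensors rather than lists.

```python
from typing import List, Sequence


def mean_pool(
    token_embeddings: Sequence[Sequence[float]],
    attention_mask: Sequence[int],
) -> List[float]:
    """Average the token vectors whose attention mask entry is 1."""
    dim = len(token_embeddings[0])
    totals = [0.0] * dim
    count = 0
    for vector, keep in zip(token_embeddings, attention_mask):
        if keep:
            count += 1
            for i, value in enumerate(vector):
                totals[i] += value
    # Guard against an all-zero mask to avoid dividing by zero.
    count = max(count, 1)
    return [total / count for total in totals]


# Toy example: two real tokens and one padding token that must be ignored.
tokens = [[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]
mask = [1, 1, 0]
sentence_embedding = mean_pool(tokens, mask)  # [2.0, 3.0]
```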

## Run inference

In [13]:
run_inference(
    name="dev",
    model=model,
    claims=dev_claims,
    evidence=reduced_evidence,
    scorer=util.dot_score,
    threshold=195.8921,
    output_path=inference_output_path,
    batch_size=64,
    device=torch_device,
    verbose=True
)

Generate claim embeddings n=154
Loaded claim embeddings from file
Generate evidence embeddings n=8431


Batches: 100%|██████████| 132/132 [00:14<00:00,  8.85it/s]


Saved evidence embeddings to file
Calculate scores
Retrieve top scoring evidences


claims: 154it [00:00, 978.58it/s]

Average retrievals = 511.227273
Done!
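
The source of `run_inference` is not shown in this notebook. As a rough sketch of the retrieval step, assuming it scores every claim against every evidence passage with a dot product and keeps passages whose score is at least `threshold`, the logic might look like the following. The function `retrieve_by_threshold` and the toy vectors are hypothetical and only illustrate the idea.

```python
from typing import Dict, List, Sequence


def dot(u: Sequence[float], v: Sequence[float]) -> float:
    """Dot product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))


def retrieve_by_threshold(
    claim_embeddings: Dict[str, Sequence[float]],
    evidence_embeddings: Dict[str, Sequence[float]],
    threshold: float,
) -> Dict[str, List[str]]:
    """For each claim, return evidence ids with score >= threshold, best first.

    Assumption: the threshold is an inclusive lower bound on the dot score.
    """
    retrieved = {}
    for claim_id, claim_vec in claim_embeddings.items():
        scored = [
            (evidence_id, dot(claim_vec, evidence_vec))
            for evidence_id, evidence_vec in evidence_embeddings.items()
        ]
        kept = [pair for pair in scored if pair[1] >= threshold]
        kept.sort(key=lambda pair: pair[1], reverse=True)
        retrieved[claim_id] = [evidence_id for evidence_id, _ in kept]
    return retrieved


# Toy data standing in for real embeddings.
claims = {"claim-0": [1.0, 2.0], "claim-1": [0.0, 1.0]}
evidence = {
    "evidence-0": [2.0, 1.0],
    "evidence-1": [0.5, 0.5],
    "evidence-2": [3.0, 3.0],
}
result = retrieve_by_threshold(claims, evidence, threshold=2.0)
average_retrievals = sum(len(ids) for ids in result.values()) / len(result)
```

Under this reading, the `Average retrievals` line in the log above would correspond to a statistic like `average_retrievals`: the mean number of evidence passages kept per claim.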



