# Model 05 inference

**This is the final implementation for the retrieval stage**

Load a model trained with [Model 05](../04_model/model_05_bert_cross_encoder_retrieval_classifier.ipynb) and use it to predict evidence retrieval from the shortlists created by [Model 02c](../04_model/model_02c_fast_shortlisting.ipynb).

Prerequisites:
1. A shortlist for `test` created using [Model 02c](../04_model/model_02c_fast_shortlisting.ipynb).
2. A retrieval model trained using [Model 05](../04_model/model_05_bert_cross_encoder_retrieval_classifier.ipynb).


## Setup

### Working directory

In [None]:
# Change the working directory to project root
from pathlib import Path
import os

# Walk upwards from the current directory until a directory containing `src`
# is found; that directory is treated as the project root.
ROOT_DIR = Path.cwd()
while not ROOT_DIR.joinpath("src").exists():
    # The parent of the filesystem root is the root itself, so without this
    # guard the loop would spin forever when no ancestor contains `src`.
    if ROOT_DIR.parent == ROOT_DIR:
        raise FileNotFoundError(
            f"no 'src' directory found in {Path.cwd()} or any of its parents"
        )
    ROOT_DIR = ROOT_DIR.parent
os.chdir(ROOT_DIR)

### Dependencies

In [None]:
# Imports and dependencies
import torch
from torch.utils.data import DataLoader
from torch.nn import CrossEntropyLoss
from torch.optim import AdamW
from torch.optim.lr_scheduler import LinearLR
from torcheval.metrics import BinaryAccuracy, BinaryF1Score

from src.torch_utils import get_torch_device
import json
from dataclasses import dataclass
from typing import List, Union, Tuple, Dict
from tqdm import tqdm
import random
import numpy as np
from datetime import datetime
from sklearn.model_selection import ParameterGrid

from src.model_05 import BertCrossEncoderClassifier, RobertaLargeCrossEncoderClassifier
from src.data import RetrievalWithShortlistDataset, RetrievalDevEvalDataset, \
    InferenceClaims
from src.logger import SimpleLogger

# Torch device returned by the project helper; it is passed as `map_location`
# when the saved model is loaded further down.
TORCH_DEVICE = get_torch_device()

### File paths

In [None]:
# Each constant is a placeholder path ending in "*"; the "*" component is
# replaced with a concrete file name via `.with_name(...)` where it is used.
MODEL_PATH = Path(ROOT_DIR, "./result/*")
DATA_PATH = Path(ROOT_DIR, "./data/*")
LOG_PATH = Path(ROOT_DIR, "./result/*")
SHORTLIST_PATH = Path(ROOT_DIR, "./result/*")
OUTPUT_PATH = Path(ROOT_DIR, "./result/*")

# Timestamp of this session, to the minute.
run_time = format(datetime.now(), '%Y_%m_%d_%H_%M')

## Load model

Either use a blank pre-trained model (no fine-tuning):

In [None]:
# model = BertCrossEncoderClassifier(
#     pretrained_name="bert-large-uncased",
#     n_classes=2,
#     device=TORCH_DEVICE
# )

Or load a fine-tuned model:

In [None]:
# Checkpoint to load. The commented-out lines are alternative saved runs;
# change which assignment is active to use a different checkpoint.
MODEL_SAVE_PATH = MODEL_PATH.with_name("model_05_bert_cross_encoder_retrieval_2023_05_11_08_52.pth")
# MODEL_SAVE_PATH = MODEL_PATH.with_name("model_05_bert_cross_encoder_retrieval_2023_05_12_15_08.pth")
# MODEL_SAVE_PATH = MODEL_PATH.with_name("model_05_bert_cross_encoder_retrieval_2023_05_12_16_03.pth")
# MODEL_SAVE_PATH = MODEL_PATH.with_name("model_05_bert_cross_encoder_retrieval_2023_05_13_18_19.pth")
# NOTE(review): the loaded object is used directly as `model` (it is called and
# switched to eval mode later), so the file presumably holds a whole pickled
# model rather than a state_dict -- confirm. torch.load unpickles the file, so
# only load checkpoints that come from a trusted source.
with open(MODEL_SAVE_PATH, mode="rb") as f:
    model = torch.load(f, map_location=TORCH_DEVICE)

## Inference run code

In [None]:
@dataclass
class EvidenceScore:
    """An evidence id together with the score assigned to it."""
    evidence_id:str
    score:float

    def to_list(self) -> List[str]:
        """Return `[evidence_id, score]` with the score converted by `str`."""
        score_text = str(self.score)
        return [self.evidence_id, score_text]

    def to_dict(self) -> Dict[str, str]:
        """Return both fields keyed "evidence_id" and "score" (score as `str`)."""
        return dict(evidence_id=self.evidence_id, score=str(self.score))

In [None]:
def _top_evidence_ids(evidence_scores, top_k):
    """Return the ids of the `top_k` highest-scoring evidences, best first."""
    ranked = sorted(evidence_scores, key=lambda es: es.score, reverse=True)
    return [es.evidence_id for es in ranked[:top_k]]


def inference_run(
    claims_path:Path,
    evidence_path:Path,
    claims_shortlist_path:Path = None,
    save_path:Path = None,
    batch_size:int = 64,
    top_k:int = 5
):
    """Score shortlisted evidences for every claim and save the best ones.

    Uses the module-level `model`. For each claim yielded by
    `InferenceClaims`, every claim/evidence pair produced by
    `RetrievalWithShortlistDataset` is scored, and the ids of the `top_k`
    highest-scoring evidences are written to `save_path` as JSON.

    The output file is rewritten after every claim, and claims already
    present in an existing `save_path` are skipped, so an interrupted run
    can be resumed by calling the function again with the same arguments.

    Args:
        claims_path: JSON file with the claims to run inference on.
        evidence_path: JSON file with the evidence passages.
        claims_shortlist_path: Shortlist file handed to the dataset.
        save_path: Output JSON path. When `None`, nothing is loaded or
            saved.
        batch_size: Number of claim/evidence pairs per forward pass.
        top_k: Maximum number of evidence ids kept per claim.

    Returns:
        None. Results are only written to `save_path`.
    """
    # Generate claims iterations
    inference_claims = InferenceClaims(claims_path=claims_path)

    # Accumulator: claim id -> {"claim_text": ..., "evidences": [EvidenceScore]}
    claim_predictions = dict()

    # Load from save to continue inference if exists.
    # `save_path` may legitimately be None, so guard before touching it.
    saved_predictions = dict()
    if save_path is not None and save_path.exists():
        with open(save_path, mode="r") as f:
            saved_predictions.update(json.load(fp=f))

    # Evaluation mode does not change between claims, so set it once.
    model.eval()

    for claim_id, claim_text in inference_claims:

        # Skip if a save entry exists
        if claim_id in saved_predictions:
            print(f"skipping {claim_id}, done previously")
            continue

        # Create a save entry
        claim_predictions[claim_id] = {
            "claim_text": claim_text,
            "evidences": []
        }

        # Generate claims-evidence inference iterations
        infer_data = RetrievalWithShortlistDataset(
            claim_id=claim_id,
            claims_paths=[claims_path],
            claims_shortlist_paths=[claims_shortlist_path],
            evidence_path=evidence_path,
            pos_label=0, # Won't matter what label we use for inference
            neg_label=0,
            n_neg_samples=9999999999,
            inference=True,
            shuffle=False,
            seed=42,
            verbose=True
        )
        infer_dataloader = DataLoader(
            dataset=infer_data,
            shuffle=False,
            batch_size=batch_size
        )

        print(f"running inference for {claim_id} n={len(infer_data)}")

        infer_batches = tqdm(infer_dataloader, desc="infer batches")
        # No gradients are needed for inference; skipping autograd avoids
        # building a graph for every batch.
        with torch.no_grad():
            for batch in infer_batches:
                # The third item of each batch (the labels) is unused here.
                claim_texts, evidence_texts, _, batch_claim_ids, batch_evidence_ids = batch
                texts = list(zip(claim_texts, evidence_texts))

                # Forward; only the first of the three returned values is used
                output, _, _ = model(
                    texts=texts,
                    normalize_text=True,
                    max_length=512,
                    dropout=None
                )

                # Column 1 of the output is used as the score (presumably the
                # probability of label 1 -- confirm against the model class)
                batch_scores = output[:, 1]

                for c_id, e_id, score in zip(
                    batch_claim_ids, batch_evidence_ids,
                    batch_scores.cpu().detach().numpy()
                ):
                    claim_predictions[c_id]["evidences"].append(EvidenceScore(
                        evidence_id=e_id,
                        score=score
                    ))

        # Save on every claim_id
        if save_path is not None:
            # Keep at most `top_k` evidence ids per claim, best score first
            claim_predictions_output = {
                claim_id_: {
                    **claim_,
                    "evidences": _top_evidence_ids(claim_["evidences"], top_k)
                }
                for claim_id_, claim_ in claim_predictions.items()
            }

            # Previously saved results first, then the results of this run
            save_dict = {**saved_predictions, **claim_predictions_output}

            with open(save_path, mode="w") as f:
                json.dump(obj=save_dict, fp=f)
                print(f"saved to: {save_path}")

## Run inference

In [None]:
# Score the unlabelled test claims against their shortlisted evidences and
# write the top evidence ids per claim to the output JSON. Claims already
# present in that file are skipped, so re-running this cell resumes the run.
# NOTE(review): the shortlist file is named "max_2000" while the output file is
# named "max_1000" -- confirm the output name reflects the shortlist used.
inference_run(
    claims_path=DATA_PATH.with_name("test-claims-unlabelled.json"),
    evidence_path=DATA_PATH.with_name("evidence.json"),
    claims_shortlist_path=SHORTLIST_PATH.with_name("test_shortlist_evidences_max_2000.json"),
    save_path=OUTPUT_PATH.with_name("model_05_bert_cross_encoder_retrieval_2023_05_11_08_52_test_shortlist_max_1000.json"),
    batch_size=16
)