# Model 06 inference

**This is the final implementation for the label classification stage**

Load a model trained with [Model 06](../04_model/model_06_bert_cross_encoder_classification.ipynb) and use it to predict claim labels from the evidence retrievals created by [Infer 05](./infer_05_bert_cross_encoder_retrieval_classifier.ipynb).

Prerequisites:
1. Create shortlists for `train` and `dev` using [Model 02c](./model_02c_fast_shortlisting.ipynb).
2. Train retrieval model using [Model 05](./model_05_bert_cross_encoder_retrieval_classifier.ipynb).
3. Create retrieval predictions using [Infer 05](../05_inference/infer_05_bert_cross_encoder_retrieval_classifier.ipynb).
4. Train label prediction model using [Model 06](../04_model/model_06_bert_cross_encoder_classification.ipynb).

## Setup

### Working directory

In [None]:
# Change the working directory to project root
from pathlib import Path
import os

# The project root is the closest directory, starting at the current working
# directory and walking upwards, that contains a "src" entry.
_start_dir = Path.cwd()
ROOT_DIR = next(
    (
        candidate
        for candidate in (_start_dir, *_start_dir.parents)
        if candidate.joinpath("src").exists()
    ),
    None,
)
if ROOT_DIR is None:
    # The parent of the filesystem root is the root itself, so an unbounded
    # upward walk would spin forever; fail with an explicit error instead.
    raise FileNotFoundError(
        f"no 'src' directory found in {_start_dir} or any of its parents"
    )
os.chdir(ROOT_DIR)

### Dependencies

In [None]:
# Imports and dependencies
import torch
from torch.utils.data import DataLoader
from torch.nn import CrossEntropyLoss
from torch.optim import AdamW
from torch.optim.lr_scheduler import LinearLR
from torcheval.metrics import BinaryAccuracy, BinaryF1Score

from src.torch_utils import get_torch_device
import json
from dataclasses import dataclass
from typing import List, Union, Tuple, Dict
from tqdm import tqdm
import random
import numpy as np
from datetime import datetime
from sklearn.model_selection import ParameterGrid

from src.model_05 import BertCrossEncoderClassifier, RobertaLargeCrossEncoderClassifier
from src.data import RetrievalWithShortlistDataset, RetrievalDevEvalDataset, \
    InferenceClaims, RetrievedInferenceClaims, LabelClassificationDataset
from src.logger import SimpleLogger

# Device returned by the project helper `get_torch_device`; the fine-tuned
# model is mapped onto it when the checkpoint is loaded.
TORCH_DEVICE = get_torch_device()

### File paths

In [None]:
# Placeholder paths: the trailing "*" component is never used as-is, it is
# swapped for a concrete file name through Path.with_name() at the call sites.
MODEL_PATH = ROOT_DIR / "./result/*"
DATA_PATH = ROOT_DIR / "./data/*"
LOG_PATH = ROOT_DIR / "./result/*"
RETRIEVAL_PATH = ROOT_DIR / "./result/*"
OUTPUT_PATH = ROOT_DIR / "./result/*"

# Minute-resolution timestamp identifying this notebook run.
run_time = f"{datetime.now():%Y_%m_%d_%H_%M}"

## Load model

Use a blank pre-trained model (not fine-tuned):

In [None]:
# model = BertCrossEncoderClassifier(
#     pretrained_name="bert-base-uncased",
#     n_classes=3,
#     device=TORCH_DEVICE
# )

Or load a fine-tuned model:

In [None]:
MODEL_SAVE_PATH = MODEL_PATH.with_name("model_06_bert_base_uncased_cross_encoder_label_2023_05_09_19_01.pth")
# MODEL_SAVE_PATH = MODEL_PATH.with_name("model_06a_roberta_mnli_cross_encoder_label_2023_05_11_17_47.pth")
# MODEL_SAVE_PATH = MODEL_PATH.with_name("model_06a_roberta_mnli_cross_encoder_label_2023_05_12_11_36.pth")
# MODEL_SAVE_PATH = MODEL_PATH.with_name("model_05_bert_cross_encoder_retrieval_2023_05_12_15_08.pth")

# The checkpoint holds a whole pickled model object (the loaded value is used
# directly as a callable model), not a state_dict. From PyTorch 2.6 on,
# torch.load defaults to weights_only=True, which refuses to unpickle custom
# classes, so the full-object load has to be requested explicitly.
# SECURITY: unpickling can execute arbitrary code - only load checkpoints that
# were produced by this project or come from a trusted source.
with open(MODEL_SAVE_PATH, mode="rb") as f:
    model = torch.load(f, map_location=TORCH_DEVICE, weights_only=False)

## Inference run code

In [None]:
def inference_run(
    claims_retrieval_path: Path,
    evidence_path: Path,
    save_path: Path = None,
    batch_size: int = 64
):
    """Predict one label per retrieved claim and save the result as JSON.

    Every claim yielded by ``RetrievedInferenceClaims`` is paired with its
    retrieved evidence through ``LabelClassificationDataset``. Each
    claim-evidence pair is classified by the module-level ``model`` and the
    pair-level classes of ALL batches of the claim are merged into a single
    claim label:

    * class 2 and class 0 both predicted -> "DISPUTED"
    * class 2 predicted                  -> "SUPPORTS"
    * class 0 predicted                  -> "REFUTES"
    * anything else, or no pairs at all  -> "NOT_ENOUGH_INFO"

    The label is stored under ``"claim_label"`` in the entry of the claim in
    the JSON content loaded from ``claims_retrieval_path``, and the updated
    content is written to ``save_path``.

    Args:
        claims_retrieval_path: JSON file with the claims and their retrieved
            evidence; also the base of the saved output.
        evidence_path: Evidence file handed to ``LabelClassificationDataset``.
        save_path: Destination JSON file. Required even though it defaults to
            ``None`` for signature compatibility.
        batch_size: Number of claim-evidence pairs per forward pass.

    Raises:
        ValueError: If ``save_path`` is ``None``.
    """
    if save_path is None:
        # The predictions are only ever written to file, so fail before the
        # expensive inference rather than in open() after it has finished.
        raise ValueError(
            "save_path is required: predictions are only written to file"
        )

    # Label map
    LABEL_MAP = {
        0: "REFUTES",
        1: "NOT_ENOUGH_INFO",
        2: "SUPPORTS",
        3: "DISPUTED"
    }

    # Generate claims iterations
    inference_claims = RetrievedInferenceClaims(claims_path=claims_retrieval_path)

    # Cumulator: the retrieval file content, extended with the predicted labels
    with open(claims_retrieval_path, mode="r", encoding="utf-8") as f:
        claim_predictions = json.load(f)

    # Set model mode to evaluation once; it does not change between claims
    model.eval()

    for claim_id, _claim_text, _evidence_ids in tqdm(inference_claims, desc="claims"):

        # Generate claims-evidence inference iterations
        infer_data = LabelClassificationDataset(
            claim_id=claim_id,
            claims_paths=[claims_retrieval_path],
            evidence_path=evidence_path,
            training=False,
            verbose=False
        )
        infer_dataloader = DataLoader(
            dataset=infer_data,
            shuffle=False,
            batch_size=batch_size
        )

        print(f"running inference for {claim_id} n={len(infer_data)}")

        # Collect the pair-level classes of every batch before deciding, so
        # the claim label does not depend on how the pairs fall into batches.
        batch_predictions = []
        with torch.no_grad():  # inference only, no autograd graph needed
            for batch in infer_dataloader:
                claim_texts, evidence_texts, _labels, _batch_claim_ids, _batch_evidence_ids = batch
                texts = list(zip(claim_texts, evidence_texts))

                # Forward
                output, _logits, _seq = model(
                    texts=texts,
                    normalize_text=True,
                    max_length=512,
                    dropout=None
                )

                # Prediction
                batch_predictions.append(torch.argmax(output, dim=1).cpu())

        if batch_predictions:
            predicted = torch.cat(batch_predictions)
        else:
            # No claim-evidence pairs at all: falls through to the NEI default
            predicted = torch.empty(0, dtype=torch.long)

        # Apply label prediction rules, default is NEI
        predicted_ids = set(predicted.tolist())
        has_support = 2 in predicted_ids
        has_refute = 0 in predicted_ids
        if has_support and has_refute:
            predicted_class = LABEL_MAP[3]  # DISPUTED
        elif has_support:
            predicted_class = LABEL_MAP[2]  # SUPPORTS
        elif has_refute:
            predicted_class = LABEL_MAP[0]  # REFUTES
        else:
            predicted_class = LABEL_MAP[1]  # NOT_ENOUGH_INFO

        claim_predictions[claim_id]["claim_label"] = predicted_class
        print(f"class={predicted_class}, labels={predicted}")

    with open(save_path, mode="w", encoding="utf-8") as f:
        json.dump(obj=claim_predictions, fp=f)
    print(f"saved to: {save_path}")

## Test inference

In [None]:
# Input: retrieval predictions for the claims; output: the same claims with a
# predicted "claim_label" added, written next to the other results.
retrieval_file = RETRIEVAL_PATH.with_name(
    "model_05_bert_cross_encoder_retrieval_2023_05_11_08_52_test_shortlist_max_1000.json"
)
evidence_file = DATA_PATH.with_name("evidence.json")
predictions_file = OUTPUT_PATH.with_name("test-claims-predictions.json")

inference_run(
    claims_retrieval_path=retrieval_file,
    evidence_path=evidence_file,
    save_path=predictions_file,
    batch_size=16,
)