# Model 06 inference


## Setup

### Working directory

In [1]:
# Change the working directory to the project root: walk up from the directory
# the kernel started in until one containing `src/` is found, so that the
# `src.*` imports and the relative data/result paths below resolve.
from pathlib import Path
import os
ROOT_DIR = Path.cwd()
while not ROOT_DIR.joinpath("src").exists():
    if ROOT_DIR.parent == ROOT_DIR:
        # Reached the filesystem root (its parent is itself). Without this guard
        # the loop never terminates when the notebook is opened outside the repo.
        raise FileNotFoundError(
            f"No directory containing 'src/' found at or above {Path.cwd()}"
        )
    ROOT_DIR = ROOT_DIR.parent
os.chdir(ROOT_DIR)

### Dependencies

In [2]:
# Imports and dependencies
import torch
from torch.utils.data import DataLoader
from torch.nn import CrossEntropyLoss
from torch.optim import AdamW
from torch.optim.lr_scheduler import LinearLR
from torcheval.metrics import BinaryAccuracy, BinaryF1Score

from src.torch_utils import get_torch_device
import json
from dataclasses import dataclass
from typing import List, Union, Tuple, Dict
from tqdm import tqdm
import random
import numpy as np
from datetime import datetime
from sklearn.model_selection import ParameterGrid

from src.model_05 import BertCrossEncoderClassifier
from src.data import RetrievalWithShortlistDataset, RetrievalDevEvalDataset, \
    InferenceClaims, RetrievedInferenceClaims, LabelClassificationDataset
from src.logger import SimpleLogger

# Resolve the compute device once via the project helper (on the recorded run it
# reported 'mps'); the model checkpoint below is loaded onto this device.
TORCH_DEVICE = get_torch_device()

  from .autonotebook import tqdm as notebook_tqdm


Torch device is 'mps'


### File paths

In [3]:
# Directory "templates": every constant ends in a "*" placeholder component, so
# `<X>_PATH.with_name("<file>")` replaces the placeholder with a real file name
# inside that directory (e.g. DATA_PATH.with_name("evidence.json")).
MODEL_PATH = ROOT_DIR / "result" / "models" / "*"
DATA_PATH = ROOT_DIR / "data" / "*"
LOG_PATH = ROOT_DIR / "result" / "logs" / "*"
RETRIEVAL_PATH = ROOT_DIR / "result" / "pipeline" / "retrieval_classif" / "*"
OUTPUT_PATH = ROOT_DIR / "result" / "pipeline" / "final_classif" / "*"

# Minute-resolution timestamp string (YYYY_MM_DD_HH_MM), presumably for tagging
# artifacts written by this run; not used further in this part of the notebook.
run_time = f"{datetime.now():%Y_%m_%d_%H_%M}"

## Load model

Either use a blank pre-trained BERT cross-encoder (not fine-tuned):

In [4]:
# model = BertCrossEncoderClassifier(
#     pretrained_name="bert-base-uncased",
#     n_classes=3,
#     device=TORCH_DEVICE
# )

Or load the fine-tuned model 06 checkpoint:

In [5]:
# Load the fine-tuned model 06 checkpoint. The file holds the whole pickled
# model object (it is called directly below), not a state_dict, so
# `weights_only=False` is required: since PyTorch 2.6 `torch.load` defaults to
# `weights_only=True`, which refuses to unpickle arbitrary objects.
# SECURITY: unpickling can execute code stored in the file — only load
# checkpoints you produced yourself.
MODEL_SAVE_PATH = MODEL_PATH.with_name("model_06_bert_base_uncased_cross_encoder_label_2023_05_08_15_13.pth")
with open(MODEL_SAVE_PATH, mode="rb") as f:
    model = torch.load(f, map_location=TORCH_DEVICE, weights_only=False)
# Inference only: disable dropout / use eval-mode behaviour from the start.
model.eval()

## Inference run code

In [6]:
def _aggregate_claim_label(predicted_ids, label_map):
    """Collapse per-evidence class ids into a single claim-level label.

    Rules, applied in priority order (the default is NOT_ENOUGH_INFO):
      * some SUPPORTS (2) and some REFUTES (0) -> DISPUTED (3)
      * any SUPPORTS (2)                        -> SUPPORTS
      * any REFUTES (0)                         -> REFUTES
      * otherwise, including no evidence at all -> NOT_ENOUGH_INFO (1)

    Args:
        predicted_ids: iterable of plain int class ids, one per evidence passage.
        label_map: mapping from class id (0-3) to label string.

    Returns:
        The label string for the claim.
    """
    seen = set(predicted_ids)
    if 2 in seen and 0 in seen:
        return label_map[3]  # DISPUTED
    if 2 in seen:
        return label_map[2]  # SUPPORTS
    if 0 in seen:
        return label_map[0]  # REFUTES
    return label_map[1]  # NOT_ENOUGH_INFO


def inference_run(
    claims_retrieval_path:Path,
    evidence_path:Path,
    save_path:Path = None,
    batch_size:int = 64,
    classifier = None,
):
    """Predict a claim label for every claim in a retrieval output file.

    For each claim, every (claim, retrieved evidence) pair is scored by the
    cross-encoder, then ALL per-pair class ids for that claim are collapsed into
    one label by `_aggregate_claim_label`. The input JSON is reused as the output
    structure; only each claim's "claim_label" field is overwritten.

    Args:
        claims_retrieval_path: JSON produced by the retrieval stage, keyed by
            claim id.
        evidence_path: evidence corpus JSON, forwarded to
            LabelClassificationDataset.
        save_path: where to write the labelled claims JSON; missing parent
            directories are created. If None, nothing is written to disk.
        batch_size: number of claim-evidence pairs per forward pass.
        classifier: model to run; defaults to the notebook-global `model`.

    Returns:
        None when `save_path` is given; otherwise the labelled claims dict.
    """
    # Class ids 0-2 come from the classifier's argmax; DISPUTED (3) is only
    # assigned by the aggregation rules.
    LABEL_MAP = {
        0: "REFUTES",
        1: "NOT_ENOUGH_INFO",
        2: "SUPPORTS",
        3: "DISPUTED"
    }

    # `model` (notebook global) is only looked up when no classifier is passed.
    net = model if classifier is None else classifier
    net.eval()

    # Generate claim iterations
    inference_claims = RetrievedInferenceClaims(claims_path=claims_retrieval_path)

    # Accumulator: start from the retrieval output so other fields are carried through
    with open(claims_retrieval_path, mode="r") as f:
        claim_predictions = json.load(f)

    for claim_id, _claim_text, _evidence_ids in tqdm(inference_claims, desc="claims"):

        # Generate claim-evidence inference iterations
        infer_data = LabelClassificationDataset(
            claim_id=claim_id,
            claims_paths=[claims_retrieval_path],
            evidence_path=evidence_path,
            training=False,
            verbose=False
        )
        infer_dataloader = DataLoader(
            dataset=infer_data,
            shuffle=False,
            batch_size=batch_size
        )

        # tqdm.write keeps log lines from breaking the progress bar
        tqdm.write(f"running inference for {claim_id} n={len(infer_data)}")

        # Collect predictions for ALL of this claim's evidence before labelling.
        # Labelling inside the batch loop let the last batch overwrite earlier
        # ones whenever a claim had more than `batch_size` evidence passages.
        batch_predictions = []
        with torch.no_grad():  # inference only: skip building autograd graphs
            for batch in infer_dataloader:
                claim_texts, evidence_texts, _labels, _batch_claim_ids, _batch_evidence_ids = batch
                output, _logits, _seq = net(
                    texts=list(zip(claim_texts, evidence_texts)),
                    normalize_text=True,
                    max_length=512,
                    dropout=None
                )
                batch_predictions.append(torch.argmax(output, dim=1).cpu())

        if batch_predictions:
            predicted = torch.cat(batch_predictions)
        else:
            predicted = torch.empty(0, dtype=torch.long)

        # .tolist() gives hashable Python ints (a set of 0-d tensors would hash by identity)
        predicted_class = _aggregate_claim_label(predicted.tolist(), LABEL_MAP)
        claim_predictions[claim_id]["claim_label"] = predicted_class
        tqdm.write(f"class={predicted_class}, labels={predicted}")

    if save_path is None:
        return claim_predictions

    save_path = Path(save_path)
    save_path.parent.mkdir(parents=True, exist_ok=True)
    with open(save_path, mode="w") as f:
        json.dump(obj=claim_predictions, fp=f)
    print(f"saved to: {save_path}")

    return None

## Dev inference

In [7]:
# Label the dev-set shortlist produced by the model 05 retrieval run. The output
# file reuses the retrieval run's name with a "_final" suffix; keeping the name
# in one variable stops the input and output names drifting apart.
dev_run_name = "model_05_bert_base_2023_05_08_17_06_dev_shortlist_max_500_no_rel"
inference_run(
    claims_retrieval_path=RETRIEVAL_PATH.with_name(f"{dev_run_name}.json"),
    evidence_path=DATA_PATH.with_name("evidence.json"),
    save_path=OUTPUT_PATH.with_name(f"{dev_run_name}_final.json"),
    batch_size=24
)

loaded inference claims n=154


claims:   0%|          | 0/154 [00:00<?, ?it/s]

Torch device is 'mps'
generated dataset n=5
running inference for claim-752 n=5


claims:   1%|          | 1/154 [00:00<02:31,  1.01it/s]

class=SUPPORTS, labels=tensor([2, 2, 1, 1, 1])
Torch device is 'mps'


claims:   1%|▏         | 2/154 [00:01<02:11,  1.15it/s]

generated dataset n=5
running inference for claim-375 n=5
class=SUPPORTS, labels=tensor([1, 2, 1, 2, 2])
Torch device is 'mps'


claims:   2%|▏         | 3/154 [00:02<01:59,  1.26it/s]

generated dataset n=5
running inference for claim-1266 n=5
class=SUPPORTS, labels=tensor([2, 1, 2, 2, 2])
Torch device is 'mps'


claims:   3%|▎         | 4/154 [00:03<01:55,  1.30it/s]

generated dataset n=5
running inference for claim-871 n=5
class=SUPPORTS, labels=tensor([2, 2, 1, 1, 1])
Torch device is 'mps'


claims:   3%|▎         | 5/154 [00:04<01:55,  1.29it/s]

generated dataset n=5
running inference for claim-2164 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 2, 2])
Torch device is 'mps'


claims:   4%|▍         | 6/154 [00:04<01:52,  1.32it/s]

generated dataset n=5
running inference for claim-1607 n=5
class=SUPPORTS, labels=tensor([2, 1, 1, 1, 2])
Torch device is 'mps'


claims:   5%|▍         | 7/154 [00:05<01:49,  1.34it/s]

generated dataset n=5
running inference for claim-761 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 1, 2])
Torch device is 'mps'


claims:   5%|▌         | 8/154 [00:06<01:47,  1.36it/s]

generated dataset n=5
running inference for claim-1718 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 1, 1])
Torch device is 'mps'


claims:   6%|▌         | 9/154 [00:06<01:45,  1.38it/s]

generated dataset n=5
running inference for claim-1273 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 2, 2])
Torch device is 'mps'


claims:   6%|▋         | 10/154 [00:07<01:43,  1.39it/s]

generated dataset n=5
running inference for claim-1786 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 2, 2])
Torch device is 'mps'


claims:   7%|▋         | 11/154 [00:08<01:42,  1.40it/s]

generated dataset n=5
running inference for claim-2796 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 1, 2])
Torch device is 'mps'


claims:   8%|▊         | 12/154 [00:08<01:41,  1.39it/s]

generated dataset n=5
running inference for claim-2580 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 2, 2])
Torch device is 'mps'


claims:   8%|▊         | 13/154 [00:09<01:40,  1.41it/s]

generated dataset n=5
running inference for claim-1219 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 2, 2])
Torch device is 'mps'


claims:   9%|▉         | 14/154 [00:10<01:40,  1.39it/s]

generated dataset n=5
running inference for claim-75 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 2, 2])
Torch device is 'mps'


claims:  10%|▉         | 15/154 [00:11<01:39,  1.40it/s]

generated dataset n=5
running inference for claim-2813 n=5
class=SUPPORTS, labels=tensor([1, 1, 1, 2, 1])
Torch device is 'mps'


claims:  10%|█         | 16/154 [00:11<01:38,  1.40it/s]

generated dataset n=5
running inference for claim-2335 n=5
class=SUPPORTS, labels=tensor([1, 2, 1, 1, 2])
Torch device is 'mps'


claims:  11%|█         | 17/154 [00:12<01:37,  1.40it/s]

generated dataset n=5
running inference for claim-161 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 2, 2])
Torch device is 'mps'


claims:  12%|█▏        | 18/154 [00:13<01:37,  1.40it/s]

generated dataset n=5
running inference for claim-2243 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 2, 2])
Torch device is 'mps'


claims:  12%|█▏        | 19/154 [00:13<01:35,  1.41it/s]

generated dataset n=5
running inference for claim-1256 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 2, 2])
Torch device is 'mps'


claims:  13%|█▎        | 20/154 [00:14<01:35,  1.40it/s]

generated dataset n=5
running inference for claim-506 n=5
class=SUPPORTS, labels=tensor([2, 2, 1, 2, 2])
Torch device is 'mps'


claims:  14%|█▎        | 21/154 [00:15<01:34,  1.40it/s]

generated dataset n=5
running inference for claim-369 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 2, 2])
Torch device is 'mps'


claims:  14%|█▍        | 22/154 [00:16<01:34,  1.39it/s]

generated dataset n=5
running inference for claim-2184 n=5
class=SUPPORTS, labels=tensor([2, 1, 1, 1, 1])
Torch device is 'mps'


claims:  15%|█▍        | 23/154 [00:16<01:32,  1.41it/s]

generated dataset n=5
running inference for claim-1057 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 2, 2])
Torch device is 'mps'


claims:  16%|█▌        | 24/154 [00:17<01:32,  1.40it/s]

generated dataset n=5
running inference for claim-104 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 2, 2])
Torch device is 'mps'


claims:  16%|█▌        | 25/154 [00:18<01:31,  1.41it/s]

generated dataset n=5
running inference for claim-1975 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 2, 2])
Torch device is 'mps'


claims:  17%|█▋        | 26/154 [00:18<01:31,  1.40it/s]

generated dataset n=5
running inference for claim-139 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 2, 1])
Torch device is 'mps'


claims:  18%|█▊        | 27/154 [00:19<01:30,  1.41it/s]

generated dataset n=5
running inference for claim-2062 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 2, 2])
Torch device is 'mps'


claims:  18%|█▊        | 28/154 [00:20<01:29,  1.40it/s]

generated dataset n=5
running inference for claim-1160 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 2, 2])
Torch device is 'mps'


claims:  19%|█▉        | 29/154 [00:21<01:28,  1.41it/s]

generated dataset n=5
running inference for claim-2679 n=5
class=SUPPORTS, labels=tensor([2, 1, 1, 2, 1])
Torch device is 'mps'


claims:  19%|█▉        | 30/154 [00:21<01:28,  1.41it/s]

generated dataset n=5
running inference for claim-2662 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 2, 2])
Torch device is 'mps'


claims:  20%|██        | 31/154 [00:22<01:26,  1.42it/s]

generated dataset n=5
running inference for claim-1490 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 2, 2])
Torch device is 'mps'


claims:  21%|██        | 32/154 [00:23<01:26,  1.40it/s]

generated dataset n=5
running inference for claim-2768 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 2, 2])
Torch device is 'mps'


claims:  21%|██▏       | 33/154 [00:23<01:25,  1.41it/s]

generated dataset n=5
running inference for claim-2168 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 1, 1])
Torch device is 'mps'


claims:  22%|██▏       | 34/154 [00:24<01:25,  1.40it/s]

generated dataset n=5
running inference for claim-785 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 2, 2])
Torch device is 'mps'


claims:  23%|██▎       | 35/154 [00:25<01:24,  1.41it/s]

generated dataset n=5
running inference for claim-2426 n=5
class=SUPPORTS, labels=tensor([2, 1, 1, 2, 2])
Torch device is 'mps'


claims:  23%|██▎       | 36/154 [00:26<01:24,  1.39it/s]

generated dataset n=5
running inference for claim-1292 n=5
class=SUPPORTS, labels=tensor([1, 2, 1, 1, 1])
Torch device is 'mps'


claims:  24%|██▍       | 37/154 [00:26<01:23,  1.41it/s]

generated dataset n=5
running inference for claim-993 n=5
class=SUPPORTS, labels=tensor([1, 2, 2, 1, 1])
Torch device is 'mps'


claims:  25%|██▍       | 38/154 [00:27<01:23,  1.39it/s]

generated dataset n=5
running inference for claim-2593 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 2, 2])
Torch device is 'mps'


claims:  25%|██▌       | 39/154 [00:28<01:21,  1.41it/s]

generated dataset n=5
running inference for claim-1567 n=5
class=SUPPORTS, labels=tensor([1, 1, 1, 1, 2])
Torch device is 'mps'


claims:  26%|██▌       | 40/154 [00:28<01:20,  1.41it/s]

generated dataset n=5
running inference for claim-1834 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 2, 2])
Torch device is 'mps'


claims:  27%|██▋       | 41/154 [00:29<01:19,  1.41it/s]

generated dataset n=5
running inference for claim-856 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 2, 1])
Torch device is 'mps'


claims:  27%|██▋       | 42/154 [00:30<01:19,  1.40it/s]

generated dataset n=5
running inference for claim-540 n=5
class=SUPPORTS, labels=tensor([2, 1, 1, 1, 1])
Torch device is 'mps'


claims:  28%|██▊       | 43/154 [00:31<01:18,  1.41it/s]

generated dataset n=5
running inference for claim-757 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 2, 2])
Torch device is 'mps'


claims:  29%|██▊       | 44/154 [00:31<01:18,  1.40it/s]

generated dataset n=5
running inference for claim-1407 n=5
class=NOT_ENOUGH_INFO, labels=tensor([1, 1, 1, 1, 1])
Torch device is 'mps'


claims:  29%|██▉       | 45/154 [00:32<01:17,  1.41it/s]

generated dataset n=5
running inference for claim-3070 n=5
class=NOT_ENOUGH_INFO, labels=tensor([1, 1, 1, 1, 1])
Torch device is 'mps'


claims:  30%|██▉       | 46/154 [00:33<01:16,  1.41it/s]

generated dataset n=5
running inference for claim-1745 n=5
class=NOT_ENOUGH_INFO, labels=tensor([1, 1, 1, 1, 1])
Torch device is 'mps'


claims:  31%|███       | 47/154 [00:33<01:16,  1.41it/s]

generated dataset n=5
running inference for claim-1515 n=5
class=SUPPORTS, labels=tensor([1, 1, 1, 2, 1])
Torch device is 'mps'


claims:  31%|███       | 48/154 [00:34<01:16,  1.39it/s]

generated dataset n=5
running inference for claim-1519 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 2, 2])
Torch device is 'mps'


claims:  32%|███▏      | 49/154 [00:35<01:14,  1.40it/s]

generated dataset n=5
running inference for claim-3069 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 2, 2])
Torch device is 'mps'


claims:  32%|███▏      | 50/154 [00:36<01:14,  1.40it/s]

generated dataset n=5
running inference for claim-677 n=5
class=SUPPORTS, labels=tensor([1, 1, 2, 1, 1])
Torch device is 'mps'


claims:  33%|███▎      | 51/154 [00:36<01:13,  1.40it/s]

generated dataset n=5
running inference for claim-765 n=5
class=SUPPORTS, labels=tensor([1, 2, 2, 2, 2])
Torch device is 'mps'


claims:  34%|███▍      | 52/154 [00:37<01:12,  1.40it/s]

generated dataset n=5
running inference for claim-2275 n=5
class=SUPPORTS, labels=tensor([1, 2, 1, 1, 1])
Torch device is 'mps'


claims:  34%|███▍      | 53/154 [00:38<01:11,  1.42it/s]

generated dataset n=5
running inference for claim-1113 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 2, 1])
Torch device is 'mps'


claims:  35%|███▌      | 54/154 [00:38<01:10,  1.41it/s]

generated dataset n=5
running inference for claim-2611 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 2, 2])
Torch device is 'mps'


claims:  36%|███▌      | 55/154 [00:39<01:09,  1.42it/s]

generated dataset n=5
running inference for claim-2060 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 2, 2])
Torch device is 'mps'


claims:  36%|███▋      | 56/154 [00:40<01:09,  1.41it/s]

generated dataset n=5
running inference for claim-2326 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 2, 2])
Torch device is 'mps'


claims:  37%|███▋      | 57/154 [00:40<01:08,  1.42it/s]

generated dataset n=5
running inference for claim-1087 n=5
class=SUPPORTS, labels=tensor([1, 2, 1, 1, 1])
Torch device is 'mps'


claims:  38%|███▊      | 58/154 [00:41<01:07,  1.42it/s]

generated dataset n=5
running inference for claim-2867 n=5
class=SUPPORTS, labels=tensor([2, 1, 2, 1, 2])
Torch device is 'mps'


claims:  38%|███▊      | 59/154 [00:42<01:06,  1.42it/s]

generated dataset n=5
running inference for claim-2300 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 1, 1])
Torch device is 'mps'


claims:  39%|███▉      | 60/154 [00:43<01:06,  1.41it/s]

generated dataset n=5
running inference for claim-2250 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 1, 2])
Torch device is 'mps'


claims:  40%|███▉      | 61/154 [00:43<01:06,  1.40it/s]

generated dataset n=5
running inference for claim-2429 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 2, 2])
Torch device is 'mps'


claims:  40%|████      | 62/154 [00:44<01:05,  1.40it/s]

generated dataset n=5
running inference for claim-3051 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 2, 2])
Torch device is 'mps'


claims:  41%|████      | 63/154 [00:45<01:04,  1.41it/s]

generated dataset n=5
running inference for claim-1549 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 2, 2])
Torch device is 'mps'


claims:  42%|████▏     | 64/154 [00:45<01:04,  1.39it/s]

generated dataset n=5
running inference for claim-261 n=5
class=SUPPORTS, labels=tensor([2, 2, 1, 2, 1])
Torch device is 'mps'


claims:  42%|████▏     | 65/154 [00:46<01:03,  1.40it/s]

generated dataset n=5
running inference for claim-2230 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 2, 2])
Torch device is 'mps'


claims:  43%|████▎     | 66/154 [00:47<01:03,  1.39it/s]

generated dataset n=5
running inference for claim-2579 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 2, 2])
Torch device is 'mps'


claims:  44%|████▎     | 67/154 [00:48<01:02,  1.40it/s]

generated dataset n=5
running inference for claim-1416 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 2, 2])
Torch device is 'mps'


claims:  44%|████▍     | 68/154 [00:48<01:02,  1.38it/s]

generated dataset n=5
running inference for claim-2497 n=5
class=SUPPORTS, labels=tensor([1, 2, 1, 1, 1])
Torch device is 'mps'


claims:  45%|████▍     | 69/154 [00:49<01:00,  1.40it/s]

generated dataset n=5
running inference for claim-811 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 2, 2])
Torch device is 'mps'


claims:  45%|████▌     | 70/154 [00:50<01:00,  1.39it/s]

generated dataset n=5
running inference for claim-1896 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 2, 1])
Torch device is 'mps'


claims:  46%|████▌     | 71/154 [00:50<00:59,  1.40it/s]

generated dataset n=5
running inference for claim-2819 n=5
class=SUPPORTS, labels=tensor([1, 2, 1, 2, 1])
Torch device is 'mps'


claims:  47%|████▋     | 72/154 [00:51<00:58,  1.40it/s]

generated dataset n=5
running inference for claim-2643 n=5
class=SUPPORTS, labels=tensor([1, 2, 2, 2, 2])
Torch device is 'mps'


claims:  47%|████▋     | 73/154 [00:52<00:57,  1.41it/s]

generated dataset n=5
running inference for claim-1775 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 2, 1])
Torch device is 'mps'


claims:  48%|████▊     | 74/154 [00:53<00:57,  1.40it/s]

generated dataset n=5
running inference for claim-316 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 2, 2])
Torch device is 'mps'


claims:  49%|████▊     | 75/154 [00:53<00:56,  1.41it/s]

generated dataset n=5
running inference for claim-896 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 2, 2])
Torch device is 'mps'


claims:  49%|████▉     | 76/154 [00:54<00:55,  1.40it/s]

generated dataset n=5
running inference for claim-331 n=5
class=SUPPORTS, labels=tensor([1, 1, 1, 2, 2])
Torch device is 'mps'


claims:  50%|█████     | 77/154 [00:55<00:54,  1.40it/s]

generated dataset n=5
running inference for claim-2574 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 2, 2])
Torch device is 'mps'


claims:  51%|█████     | 78/154 [00:56<00:54,  1.39it/s]

generated dataset n=5
running inference for claim-342 n=5
class=SUPPORTS, labels=tensor([1, 1, 2, 2, 1])
Torch device is 'mps'


claims:  51%|█████▏    | 79/154 [00:56<00:53,  1.40it/s]

generated dataset n=5
running inference for claim-2034 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 2, 2])
Torch device is 'mps'


claims:  52%|█████▏    | 80/154 [00:57<00:52,  1.40it/s]

generated dataset n=5
running inference for claim-578 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 1, 1])
Torch device is 'mps'


claims:  53%|█████▎    | 81/154 [00:58<00:51,  1.41it/s]

generated dataset n=5
running inference for claim-976 n=5
class=SUPPORTS, labels=tensor([2, 1, 1, 1, 1])
Torch device is 'mps'


claims:  53%|█████▎    | 82/154 [00:58<00:51,  1.40it/s]

generated dataset n=5
running inference for claim-1097 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 2, 2])
Torch device is 'mps'


claims:  54%|█████▍    | 83/154 [00:59<00:50,  1.40it/s]

generated dataset n=5
running inference for claim-609 n=5
class=SUPPORTS, labels=tensor([2, 2, 1, 2, 2])
Torch device is 'mps'


claims:  55%|█████▍    | 84/154 [01:00<00:50,  1.40it/s]

generated dataset n=5
running inference for claim-173 n=5
class=SUPPORTS, labels=tensor([2, 1, 1, 1, 2])
Torch device is 'mps'


claims:  55%|█████▌    | 85/154 [01:00<00:49,  1.40it/s]

generated dataset n=5
running inference for claim-1222 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 2, 2])
Torch device is 'mps'


claims:  56%|█████▌    | 86/154 [01:01<00:48,  1.40it/s]

generated dataset n=5
running inference for claim-2441 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 2, 1])
Torch device is 'mps'


claims:  56%|█████▋    | 87/154 [01:02<00:47,  1.41it/s]

generated dataset n=5
running inference for claim-756 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 2, 2])
Torch device is 'mps'


claims:  57%|█████▋    | 88/154 [01:03<00:47,  1.40it/s]

generated dataset n=5
running inference for claim-2577 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 2, 2])
Torch device is 'mps'


claims:  58%|█████▊    | 89/154 [01:03<00:46,  1.40it/s]

generated dataset n=5
running inference for claim-2890 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 2, 2])
Torch device is 'mps'


claims:  58%|█████▊    | 90/154 [01:04<00:45,  1.39it/s]

generated dataset n=5
running inference for claim-2478 n=5
class=SUPPORTS, labels=tensor([1, 1, 2, 1, 1])
Torch device is 'mps'


claims:  59%|█████▉    | 91/154 [01:05<00:45,  1.39it/s]

generated dataset n=5
running inference for claim-2399 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 2, 2])
Torch device is 'mps'


claims:  60%|█████▉    | 92/154 [01:06<00:44,  1.38it/s]

generated dataset n=5
running inference for claim-3091 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 2, 2])
Torch device is 'mps'


claims:  60%|██████    | 93/154 [01:06<00:43,  1.39it/s]

generated dataset n=5
running inference for claim-141 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 2, 2])
Torch device is 'mps'


claims:  61%|██████    | 94/154 [01:07<00:43,  1.38it/s]

generated dataset n=5
running inference for claim-1933 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 2, 2])
Torch device is 'mps'


claims:  62%|██████▏   | 95/154 [01:08<00:42,  1.39it/s]

generated dataset n=5
running inference for claim-1689 n=5
class=SUPPORTS, labels=tensor([2, 2, 1, 2, 2])
Torch device is 'mps'


claims:  62%|██████▏   | 96/154 [01:08<00:42,  1.37it/s]

generated dataset n=5
running inference for claim-443 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 2, 1])
Torch device is 'mps'


claims:  63%|██████▎   | 97/154 [01:09<00:41,  1.38it/s]

generated dataset n=5
running inference for claim-2037 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 2, 1])
Torch device is 'mps'


claims:  64%|██████▎   | 98/154 [01:10<00:40,  1.37it/s]

generated dataset n=5
running inference for claim-1734 n=5
class=SUPPORTS, labels=tensor([1, 1, 1, 2, 1])
Torch device is 'mps'


claims:  64%|██████▍   | 99/154 [01:11<00:39,  1.39it/s]

generated dataset n=5
running inference for claim-2093 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 2, 2])
Torch device is 'mps'


claims:  65%|██████▍   | 100/154 [01:11<00:39,  1.38it/s]

generated dataset n=5
running inference for claim-1400 n=5
class=SUPPORTS, labels=tensor([1, 1, 1, 1, 2])
Torch device is 'mps'


claims:  66%|██████▌   | 101/154 [01:12<00:38,  1.39it/s]

generated dataset n=5
running inference for claim-1638 n=5
class=SUPPORTS, labels=tensor([2, 1, 2, 1, 2])
Torch device is 'mps'


claims:  66%|██████▌   | 102/154 [01:13<00:37,  1.39it/s]

generated dataset n=5
running inference for claim-3075 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 2, 2])
Torch device is 'mps'


claims:  67%|██████▋   | 103/154 [01:13<00:36,  1.39it/s]

generated dataset n=5
running inference for claim-38 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 2, 2])
Torch device is 'mps'


claims:  68%|██████▊   | 104/154 [01:14<00:36,  1.38it/s]

generated dataset n=5
running inference for claim-1643 n=5
class=SUPPORTS, labels=tensor([2, 1, 1, 1, 2])
Torch device is 'mps'


claims:  68%|██████▊   | 105/154 [01:15<00:35,  1.39it/s]

generated dataset n=5
running inference for claim-1259 n=5
class=SUPPORTS, labels=tensor([1, 2, 2, 2, 2])
Torch device is 'mps'


claims:  69%|██████▉   | 106/154 [01:16<00:34,  1.38it/s]

generated dataset n=5
running inference for claim-1605 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 2, 2])
Torch device is 'mps'


claims:  69%|██████▉   | 107/154 [01:16<00:33,  1.38it/s]

generated dataset n=5
running inference for claim-1711 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 2, 2])
Torch device is 'mps'


claims:  70%|███████   | 108/154 [01:17<00:34,  1.33it/s]

generated dataset n=5
running inference for claim-2236 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 2, 2])
Torch device is 'mps'


claims:  71%|███████   | 109/154 [01:18<00:34,  1.30it/s]

generated dataset n=5
running inference for claim-1040 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 2, 1])
Torch device is 'mps'


claims:  71%|███████▏  | 110/154 [01:19<00:34,  1.26it/s]

generated dataset n=5
running inference for claim-392 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 2, 2])
Torch device is 'mps'


claims:  72%|███████▏  | 111/154 [01:20<00:33,  1.30it/s]

generated dataset n=5
running inference for claim-368 n=5
class=SUPPORTS, labels=tensor([2, 1, 2, 2, 2])
Torch device is 'mps'


claims:  73%|███████▎  | 112/154 [01:20<00:31,  1.32it/s]

generated dataset n=5
running inference for claim-559 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 2, 2])
Torch device is 'mps'


claims:  73%|███████▎  | 113/154 [01:21<00:30,  1.34it/s]

generated dataset n=5
running inference for claim-2583 n=5
class=SUPPORTS, labels=tensor([2, 1, 2, 2, 2])
Torch device is 'mps'


claims:  74%|███████▍  | 114/154 [01:22<00:29,  1.35it/s]

generated dataset n=5
running inference for claim-2609 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 2, 2])
Torch device is 'mps'


claims:  75%|███████▍  | 115/154 [01:22<00:28,  1.37it/s]

generated dataset n=5
running inference for claim-492 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 2, 2])
Torch device is 'mps'


claims:  75%|███████▌  | 116/154 [01:23<00:27,  1.36it/s]

generated dataset n=5
running inference for claim-1420 n=5
class=SUPPORTS, labels=tensor([2, 1, 2, 2, 1])
Torch device is 'mps'


claims:  76%|███████▌  | 117/154 [01:24<00:26,  1.38it/s]

generated dataset n=5
running inference for claim-1089 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 2, 1])
Torch device is 'mps'


claims:  77%|███████▋  | 118/154 [01:25<00:26,  1.37it/s]

generated dataset n=5
running inference for claim-1467 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 2, 2])
Torch device is 'mps'


claims:  77%|███████▋  | 119/154 [01:25<00:26,  1.33it/s]

generated dataset n=5
running inference for claim-444 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 2, 2])
Torch device is 'mps'


claims:  78%|███████▊  | 120/154 [01:26<00:26,  1.29it/s]

generated dataset n=5
running inference for claim-803 n=5
class=SUPPORTS, labels=tensor([2, 1, 1, 1, 2])
Torch device is 'mps'


claims:  79%|███████▊  | 121/154 [01:27<00:25,  1.31it/s]

generated dataset n=5
running inference for claim-1668 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 2, 2])
Torch device is 'mps'


claims:  79%|███████▉  | 122/154 [01:28<00:24,  1.30it/s]

generated dataset n=5
running inference for claim-742 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 2, 2])
Torch device is 'mps'


claims:  80%|███████▉  | 123/154 [01:29<00:24,  1.27it/s]

generated dataset n=5
running inference for claim-846 n=5
class=SUPPORTS, labels=tensor([2, 2, 1, 2, 2])
Torch device is 'mps'


claims:  81%|████████  | 124/154 [01:29<00:23,  1.27it/s]

generated dataset n=5
running inference for claim-2119 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 1, 2])
Torch device is 'mps'


claims:  81%|████████  | 125/154 [01:30<00:22,  1.30it/s]

generated dataset n=5
running inference for claim-1167 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 1, 2])
Torch device is 'mps'


claims:  82%|████████▏ | 126/154 [01:31<00:21,  1.31it/s]

generated dataset n=5
running inference for claim-623 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 2, 2])
Torch device is 'mps'


claims:  82%|████████▏ | 127/154 [01:32<00:20,  1.33it/s]

generated dataset n=5
running inference for claim-2882 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 2, 1])
Torch device is 'mps'


claims:  83%|████████▎ | 128/154 [01:32<00:19,  1.33it/s]

generated dataset n=5
running inference for claim-1698 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 2, 2])
Torch device is 'mps'


claims:  84%|████████▍ | 129/154 [01:33<00:18,  1.35it/s]

generated dataset n=5
running inference for claim-181 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 2, 2])
Torch device is 'mps'


claims:  84%|████████▍ | 130/154 [01:34<00:17,  1.35it/s]

generated dataset n=5
running inference for claim-281 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 2, 2])
Torch device is 'mps'


claims:  85%|████████▌ | 131/154 [01:35<00:16,  1.36it/s]

generated dataset n=5
running inference for claim-2809 n=5
class=SUPPORTS, labels=tensor([1, 2, 2, 1, 1])
Torch device is 'mps'


claims:  86%|████████▌ | 132/154 [01:35<00:16,  1.35it/s]

generated dataset n=5
running inference for claim-1928 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 2, 2])
Torch device is 'mps'


claims:  86%|████████▋ | 133/154 [01:36<00:15,  1.36it/s]

generated dataset n=5
running inference for claim-2787 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 2, 2])
Torch device is 'mps'


claims:  87%|████████▋ | 134/154 [01:37<00:14,  1.36it/s]

generated dataset n=5
running inference for claim-478 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 2, 2])
Torch device is 'mps'


claims:  88%|████████▊ | 135/154 [01:38<00:13,  1.36it/s]

generated dataset n=5
running inference for claim-988 n=5
class=SUPPORTS, labels=tensor([1, 1, 2, 2, 2])
Torch device is 'mps'


claims:  88%|████████▊ | 136/154 [01:38<00:13,  1.35it/s]

generated dataset n=5
running inference for claim-266 n=5
class=SUPPORTS, labels=tensor([2, 1, 1, 2, 2])
Torch device is 'mps'


claims:  89%|████████▉ | 137/154 [01:39<00:12,  1.36it/s]

generated dataset n=5
running inference for claim-2282 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 2, 2])
Torch device is 'mps'


claims:  90%|████████▉ | 138/154 [01:40<00:11,  1.36it/s]

generated dataset n=5
running inference for claim-2895 n=5
class=SUPPORTS, labels=tensor([2, 1, 2, 1, 1])
Torch device is 'mps'


claims:  90%|█████████ | 139/154 [01:40<00:11,  1.36it/s]

generated dataset n=5
running inference for claim-349 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 2, 2])
Torch device is 'mps'


claims:  91%|█████████ | 140/154 [01:41<00:10,  1.35it/s]

generated dataset n=5
running inference for claim-2101 n=5
class=SUPPORTS, labels=tensor([1, 2, 2, 2, 2])
Torch device is 'mps'


claims:  92%|█████████▏| 141/154 [01:42<00:09,  1.36it/s]

generated dataset n=5
running inference for claim-897 n=5
class=SUPPORTS, labels=tensor([2, 1, 2, 2, 2])
Torch device is 'mps'


claims:  92%|█████████▏| 142/154 [01:43<00:08,  1.36it/s]

generated dataset n=5
running inference for claim-3063 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 2, 2])
Torch device is 'mps'


claims:  93%|█████████▎| 143/154 [01:43<00:08,  1.37it/s]

generated dataset n=5
running inference for claim-386 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 2, 2])
Torch device is 'mps'


claims:  94%|█████████▎| 144/154 [01:44<00:07,  1.35it/s]

generated dataset n=5
running inference for claim-2691 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 2, 2])
Torch device is 'mps'


claims:  94%|█████████▍| 145/154 [01:45<00:06,  1.36it/s]

generated dataset n=5
running inference for claim-530 n=5
class=SUPPORTS, labels=tensor([2, 1, 1, 1, 1])
Torch device is 'mps'


claims:  95%|█████████▍| 146/154 [01:46<00:05,  1.36it/s]

generated dataset n=5
running inference for claim-2979 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 2, 2])
Torch device is 'mps'


claims:  95%|█████████▌| 147/154 [01:46<00:05,  1.36it/s]

generated dataset n=5
running inference for claim-665 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 2, 2])
Torch device is 'mps'


claims:  96%|█████████▌| 148/154 [01:47<00:04,  1.36it/s]

generated dataset n=5
running inference for claim-199 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 2, 2])
Torch device is 'mps'


claims:  97%|█████████▋| 149/154 [01:48<00:03,  1.37it/s]

generated dataset n=5
running inference for claim-490 n=5
class=SUPPORTS, labels=tensor([1, 2, 2, 2, 1])
Torch device is 'mps'


claims:  97%|█████████▋| 150/154 [01:49<00:02,  1.36it/s]

generated dataset n=5
running inference for claim-2400 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 2, 2])
Torch device is 'mps'


claims:  98%|█████████▊| 151/154 [01:49<00:02,  1.37it/s]

generated dataset n=5
running inference for claim-204 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 2, 2])
Torch device is 'mps'


claims:  99%|█████████▊| 152/154 [01:50<00:01,  1.36it/s]

generated dataset n=5
running inference for claim-1426 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 2, 2])
Torch device is 'mps'


claims:  99%|█████████▉| 153/154 [01:51<00:00,  1.36it/s]

generated dataset n=5
running inference for claim-698 n=5
class=SUPPORTS, labels=tensor([2, 1, 1, 2, 2])
Torch device is 'mps'


claims: 100%|██████████| 154/154 [01:51<00:00,  1.38it/s]

generated dataset n=5
running inference for claim-1021 n=5
class=SUPPORTS, labels=tensor([2, 2, 2, 2, 2])
saved to: /Users/johnsonzhou/git/comp90042-project/result/pipeline/final_classif/model_05_bert_base_2023_05_08_17_06_dev_shortlist_max_500_no_rel_final.json



