In [1]:
%load_ext autoreload
%autoreload 2

# Init

In [107]:
import json
import sys
import re
import pickle as pkl
from pathlib import Path
from functools import partial
from collections import defaultdict, namedtuple
sys.path.append("../src")

import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from joblib import Parallel, delayed
from tqdm import tqdm
from sklearn.metrics import classification_report, precision_recall_fscore_support, accuracy_score, confusion_matrix

import constants
from scoring import scorer
from gen.util import read_data, write_jsonl
from rte import aggregate

In [49]:
data_sent_micro_p = Path("/users/k21190024/study/fact-check-transfer-learning/scratch/dumps/bert-data-sent-evidence")

pred_p = Path("/users/k21190024/study/fact-check-transfer-learning/scratch/thesis/predictions")
result_p = Path("/users/k21190024/study/fact-check-transfer-learning/scratch/thesis/results/metrics")

fever_actual_p = Path("/users/k21190024/study/fact-check-transfer-learning/scratch/dumps/fever-nei-sampled")
cfever_actual_p = Path("/users/k21190024/study/fact-check-transfer-learning/scratch/dumps/climatefever-neg-sampled")
cfeverpure_actual_p = Path("/users/k21190024/study/fact-check-transfer-learning/scratch/dumps/climatefeverpure-neg-sampled")
sf_actual_p = Path("/users/k21190024/study/fact-check-transfer-learning/scratch/dumps/scifact-nei-sampled")

actual_pls = (
    list(fever_actual_p.glob("*.n5.nei.jsonl")) 
    + list(cfever_actual_p.glob("*.n5.nei.jsonl")) 
    + list(cfeverpure_actual_p.glob("*.n5.nei.jsonl")) 
    + list(sf_actual_p.glob("*.n5.nei.jsonl"))
)
actual_pls = [p for p in actual_pls if "train" not in p.stem]

# Concatenate Evidences

In [52]:
concat_evi_res = {
    "fever": defaultdict(list), 
    "climatefeverpure": defaultdict(list), 
    "climatefever": defaultdict(list),
    "scifact": defaultdict(list),
    "scifactpipeline": defaultdict(list),
}

for actual_dp in tqdm(actual_pls):
    dataset = actual_dp.parent.stem.split("-")[0]
    split = actual_dp.stem.split(".")[0]
    actual_data = read_data(actual_dp)
    if dataset == "climatefeverpure":
        score_obj = scorer.ClimateFEVERScorer
    elif dataset == "climatefever":
        score_obj = scorer.ClimateFEVERScorer
    else:
        score_obj = scorer.FEVERScorer
        
    for pp in pred_p.joinpath("doc", dataset).glob(f"*.{split}.jsonl"):
        preds = read_data(pp)
        concat_evi_res[dataset][split].append(score_obj(
            actual_data=actual_data, 
            prediction_data=preds, 
            score_name=pp.stem, 
            oracle_rte=False, 
            oracle_ir=True, 
            max_evidence=99999
        ))
    if dataset == "scifact":
        dataset = "scifactpipeline"
        for pp in pred_p.joinpath("doc", dataset).glob(f"*.{split}.jsonl"):
            preds = read_data(pp)
            concat_evi_res[dataset][split].append(score_obj(
                actual_data=actual_data, 
                prediction_data=preds, 
                score_name=pp.stem, 
                oracle_rte=False, 
                oracle_ir=False, 
                max_evidence=99999
            ))

100%|██████████| 9/9 [00:29<00:00,  3.32s/it]


In [56]:
with result_p.joinpath("concatenate_evidences_metrics.pkl").open("wb") as fn:
    pkl.dump(concat_evi_res, fn)

In [7]:
for score in concat_evi_res["fever"]["dev"]:
    print(score._score_name)
    print(score.classification_report)
    print("===")

fever-climatefeverpure-bert-base-uncased.dev
                 precision    recall  f1-score   support

NOT ENOUGH INFO     0.9994    0.9961    0.9977      3333
        REFUTES     0.9643    0.9163    0.9397      3333
       SUPPORTS     0.9194    0.9682    0.9432      3333

       accuracy                         0.9602      9999
      macro avg     0.9610    0.9602    0.9602      9999
   weighted avg     0.9610    0.9602    0.9602      9999

===
fever-da.dev
                 precision    recall  f1-score   support

NOT ENOUGH INFO     0.9368    0.9421    0.9394      3333
        REFUTES     0.8523    0.5299    0.6535      3333
       SUPPORTS     0.6501    0.8923    0.7521      3333

       accuracy                         0.7881      9999
      macro avg     0.8130    0.7881    0.7817      9999
   weighted avg     0.8130    0.7881    0.7817      9999

===
climatefeverpure-xlnet-base-cased.dev
                 precision    recall  f1-score   support

NOT ENOUGH INFO     0.8520    0.93

In [8]:
for score in concat_evi_res["climatefeverpure"]["dev"]:
    print(score._score_name)
    print(score.classification_report)
    print("===")

fever-climatefeverpure-bert-base-uncased.dev
                 precision    recall  f1-score   support

NOT ENOUGH INFO     0.8767    0.6737    0.7619        95
        REFUTES     0.4638    0.6275    0.5333        51
       SUPPORTS     0.7500    0.7727    0.7612       132

       accuracy                         0.7122       278
      macro avg     0.6968    0.6913    0.6855       278
   weighted avg     0.7408    0.7122    0.7196       278

===
fever-da.dev
                 precision    recall  f1-score   support

NOT ENOUGH INFO     0.4171    0.9263    0.5752        95
        REFUTES     0.2593    0.1373    0.1795        51
       SUPPORTS     0.6750    0.2045    0.3140       132

       accuracy                         0.4388       278
      macro avg     0.4504    0.4227    0.3562       278
   weighted avg     0.5106    0.4388    0.3785       278

===
climatefeverpure-xlnet-base-cased.dev
                 precision    recall  f1-score   support

NOT ENOUGH INFO     0.8750    0.88

In [6]:
for score in concat_evi_res["climatefever"]["dev"]:
    print(score._score_name)
    print(score.classification_report)
    print("===")

fever-climatefeverpure-bert-base-uncased.dev
                 precision    recall  f1-score   support

NOT ENOUGH INFO     0.4414    0.6737    0.5333        95
        REFUTES     0.4651    0.3922    0.4255        51
       SUPPORTS     0.7444    0.5076    0.6036       132

       accuracy                         0.5432       278
      macro avg     0.5503    0.5245    0.5208       278
   weighted avg     0.5896    0.5432    0.5469       278

===
fever-da.dev
                 precision    recall  f1-score   support

NOT ENOUGH INFO     0.3385    0.9263    0.4958        95
        REFUTES     0.3333    0.0196    0.0370        51
       SUPPORTS     0.5333    0.0606    0.1088       132

       accuracy                         0.3489       278
      macro avg     0.4017    0.3355    0.2139       278
   weighted avg     0.4300    0.3489    0.2279       278

===
climatefeverpure-xlnet-base-cased.dev
                 precision    recall  f1-score   support

NOT ENOUGH INFO     0.4264    0.88

In [9]:
for score in concat_evi_res["scifact"]["all"]:
    print(score._score_name)
    print(score.classification_report)
    print("===")

climatefeverpure-bert-base-uncased.all
                 precision    recall  f1-score   support

NOT ENOUGH INFO     0.8782    0.9880    0.9299       416
        REFUTES     0.6053    0.0970    0.1673       237
       SUPPORTS     0.6733    0.8904    0.7668       456

       accuracy                         0.7574      1109
      macro avg     0.7189    0.6585    0.6213      1109
   weighted avg     0.7356    0.7574    0.6998      1109

===
fever-climatefeverpure-da.all
                 precision    recall  f1-score   support

NOT ENOUGH INFO     0.7030    0.3413    0.4595       416
        REFUTES     0.2935    0.2278    0.2565       237
       SUPPORTS     0.4938    0.7829    0.6056       456

       accuracy                         0.4986      1109
      macro avg     0.4967    0.4507    0.4406      1109
   weighted avg     0.5294    0.4986    0.4762      1109

===
fever-bert-base-uncased.all
                 precision    recall  f1-score   support

NOT ENOUGH INFO     0.8832    1.0

In [55]:
for score in concat_evi_res["scifact"]["dev"]:
    print(score._score_name)
    print(score.classification_report)
    print("===")

fever-climatefeverpure-bert-base-uncased.dev
                 precision    recall  f1-score   support

NOT ENOUGH INFO     0.9333    1.0000    0.9655       112
        REFUTES     0.7805    0.5000    0.6095        64
       SUPPORTS     0.7914    0.8871    0.8365       124

       accuracy                         0.8467       300
      macro avg     0.8351    0.7957    0.8038       300
   weighted avg     0.8420    0.8467    0.8362       300

===
fever-da.dev
                 precision    recall  f1-score   support

NOT ENOUGH INFO     0.4592    0.9554    0.6203       112
        REFUTES     0.5882    0.1562    0.2469        64
       SUPPORTS     0.7600    0.3065    0.4368       124

       accuracy                         0.5167       300
      macro avg     0.6025    0.4727    0.4347       300
   weighted avg     0.6111    0.5167    0.4648       300

===
climatefeverpure-xlnet-base-cased.dev
                 precision    recall  f1-score   support

NOT ENOUGH INFO     0.9091    0.98

In [53]:
for score in concat_evi_res["scifactpipeline"]["all"]:
    print(score._score_name)
    print(score.classification_report)
    print("===")

climatefeverpure-bert-base-uncased.all
                 precision    recall  f1-score   support

NOT ENOUGH INFO     0.5051    0.5986    0.5479       416
        REFUTES     0.5833    0.0295    0.0562       237
       SUPPORTS     0.5116    0.6776    0.5830       456

       accuracy                         0.5095      1109
      macro avg     0.5333    0.4352    0.3957      1109
   weighted avg     0.5245    0.5095    0.4572      1109

===
fever-climatefeverpure-da.all
                 precision    recall  f1-score   support

NOT ENOUGH INFO     0.3716    0.1322    0.1950       416
        REFUTES     0.2966    0.1477    0.1972       237
       SUPPORTS     0.4484    0.8289    0.5820       456

       accuracy                         0.4220      1109
      macro avg     0.3722    0.3696    0.3247      1109
   weighted avg     0.3872    0.4220    0.3546      1109

===
fever-bert-base-uncased.all
                 precision    recall  f1-score   support

NOT ENOUGH INFO     0.6093    0.2

In [54]:
for score in concat_evi_res["scifactpipeline"]["dev"]:
    print(score._score_name)
    print(score.classification_report)
    print("===")

fever-climatefeverpure-bert-base-uncased.dev
                 precision    recall  f1-score   support

NOT ENOUGH INFO     0.5273    0.2589    0.3473       112
        REFUTES     0.3770    0.3594    0.3680        64
       SUPPORTS     0.5380    0.7984    0.6429       124

       accuracy                         0.5033       300
      macro avg     0.4808    0.4722    0.4527       300
   weighted avg     0.4997    0.5033    0.4739       300

===
fever-da.dev
                 precision    recall  f1-score   support

NOT ENOUGH INFO     0.3715    0.9554    0.5350       112
        REFUTES     0.3333    0.0156    0.0299        64
       SUPPORTS     0.5556    0.0403    0.0752       124

       accuracy                         0.3767       300
      macro avg     0.4201    0.3371    0.2133       300
   weighted avg     0.4394    0.3767    0.2372       300

===
climatefeverpure-xlnet-base-cased.dev
                 precision    recall  f1-score   support

NOT ENOUGH INFO     0.4293    0.73

# Sentence

In [91]:
sent_pls = list(data_sent_micro_p.glob("*.dev*")) + list(data_sent_micro_p.glob("*.test*"))
sent_pls = [p for p in sent_pls if "fever-climatefever" not in p.stem and "scifact_test" not in p.stem]

## Micro verdict

In [110]:
sent_micro_res = {
    "fever": defaultdict(list), 
    "climatefeverpure": defaultdict(list), 
    "scifact": defaultdict(list),
    "climatefever": defaultdict(list),
    "scifactpipeline": defaultdict(list),
    "scifactneiignore": defaultdict(list)
}

for actual_dp in tqdm(sent_pls):
    dataset = actual_dp.stem.split(".")[0]
    split = actual_dp.stem.split(".")[1]
    actual_data = read_data(actual_dp)
    if dataset == "scifactpipeline":
        labelled_data = pd.DataFrame(read_data(Path("/users/k21190024/study/fact-check-transfer-learning/scratch/dumps/feverised-scifact/scifact_all.jsonl")))
        labelled_data["claim_id"] = labelled_data["id"].map(lambda x: f"scifact|{x}")
        labelled_data["labels"] = labelled_data["label"].map(constants.LABEL2ID)
        actual_data = pd.DataFrame(actual_data).drop(columns="labels")
        actual_data = actual_data.merge(labelled_data[["claim_id", "labels"]], on="claim_id", how="left")
        actual_data = actual_data.to_dict("records")
        del labelled_data
    elif dataset == "scifact_test":
        continue
    
    for pp in pred_p.joinpath("sent", dataset).glob(f"*.{split}.jsonl"):
        preds = read_data(pp)
        
        assert len(actual_data) == len(preds), f"{actual_dp} ({len(actual_data)}) != {pp} ({len(preds)})"
        
        actual_labels = [constants.ID2LABEL[i["labels"]] for i in actual_data]
        predicted_labels = [i["predicted_label"] for i in preds]

        ma_p, ma_r, ma_f, _ = precision_recall_fscore_support(y_true=actual_labels, y_pred=predicted_labels, average="macro", beta=1.0)
        mi_p, mi_r, mi_f, _ = precision_recall_fscore_support(y_true=actual_labels, y_pred=predicted_labels, average="micro", beta=1.0)
        rte_metrics = {
            "accuracy": accuracy_score(y_true=actual_labels, y_pred=predicted_labels),
            "macro_precision": ma_p,
            "macro_recall": ma_r,
            "macro_f1": ma_f,
            "micro_precision": mi_p,
            "micro_recall": mi_r,
            "micro_f1": mi_f
        }
        
        sent_micro_res[dataset][split].append(scorer.SentenceMicroScorer(
            pp.stem, 
            classification_report(y_true=actual_labels, y_pred=predicted_labels, digits=4),
            classification_report(y_true=actual_labels, y_pred=predicted_labels, output_dict=True),
            rte_metrics,
            confusion_matrix(y_true=actual_labels, y_pred=predicted_labels, labels=["SUPPORTS", "NOT ENOUGH INFO", "REFUTES"])
        ))
        
        if dataset == "scifact":
            predicted_neitrue_labels = [actual if actual == constants.LOOKUP["label"]["nei"] else pred for actual, pred in zip(actual_labels, predicted_labels)]
            ma_p, ma_r, ma_f, _ = precision_recall_fscore_support(y_true=actual_labels, y_pred=predicted_neitrue_labels, average="macro", beta=1.0)
            mi_p, mi_r, mi_f, _ = precision_recall_fscore_support(y_true=actual_labels, y_pred=predicted_neitrue_labels, average="micro", beta=1.0)
            rte_metrics = {
                "accuracy": accuracy_score(y_true=actual_labels, y_pred=predicted_neitrue_labels),
                "macro_precision": ma_p,
                "macro_recall": ma_r,
                "macro_f1": ma_f,
                "micro_precision": mi_p,
                "micro_recall": mi_r,
                "micro_f1": mi_f
            }

            sent_micro_res["scifactneiignore"][split].append(scorer.SentenceMicroScorer(
                pp.stem, 
                classification_report(y_true=actual_labels, y_pred=predicted_neitrue_labels, digits=4),
                classification_report(y_true=actual_labels, y_pred=predicted_neitrue_labels, output_dict=True),
                rte_metrics,
                confusion_matrix(y_true=actual_labels, y_pred=predicted_labels, labels=["SUPPORTS", "NOT ENOUGH INFO", "REFUTES"])
            ))

100%|██████████| 10/10 [00:29<00:00,  2.97s/it]


In [111]:
with result_p.joinpath("sent_micro_verdict_metrics.pkl").open("wb") as fn:
    pkl.dump(sent_micro_res, fn)

In [14]:
for score in sent_micro_res["fever"]["dev"]:
    print(score._score_name)
    print(score.classification_report)
    print("===")

fever-climatefeverpure-bert-base-uncased.dev
                 precision    recall  f1-score   support

NOT ENOUGH INFO     0.9716    0.9916    0.9815      6666
        REFUTES     0.9651    0.9180    0.9410      6235
       SUPPORTS     0.9402    0.9655    0.9527      6207

       accuracy                         0.9591     19108
      macro avg     0.9590    0.9584    0.9584     19108
   weighted avg     0.9593    0.9591    0.9589     19108

===
fever-climatefever-bert-base-uncased.dev
                 precision    recall  f1-score   support

NOT ENOUGH INFO     0.9669    0.9896    0.9781      6666
        REFUTES     0.9605    0.9174    0.9385      6235
       SUPPORTS     0.9422    0.9609    0.9514      6207

       accuracy                         0.9567     19108
      macro avg     0.9565    0.9560    0.9560     19108
   weighted avg     0.9568    0.9567    0.9565     19108

===
climatefeverpure-xlnet-base-cased.dev
                 precision    recall  f1-score   support

NOT EN

In [15]:
for score in sent_micro_res["climatefeverpure"]["dev"]:
    print(score._score_name)
    print(score.classification_report)
    print("===")

fever-climatefeverpure-bert-base-uncased.dev
                 precision    recall  f1-score   support

NOT ENOUGH INFO     0.6762    0.7432    0.7081       475
        REFUTES     0.4357    0.4621    0.4485       132
       SUPPORTS     0.6906    0.5719    0.6256       320

       accuracy                         0.6440       927
      macro avg     0.6008    0.5924    0.5941       927
   weighted avg     0.6469    0.6440    0.6427       927

===
fever-climatefever-bert-base-uncased.dev
                 precision    recall  f1-score   support

NOT ENOUGH INFO     0.5924    0.9516    0.7302       475
        REFUTES     0.6341    0.1970    0.3006       132
       SUPPORTS     0.8374    0.3219    0.4650       320

       accuracy                         0.6268       927
      macro avg     0.6880    0.4901    0.4986       927
   weighted avg     0.6829    0.6268    0.5775       927

===
climatefeverpure-xlnet-base-cased.dev
                 precision    recall  f1-score   support

NOT EN

In [16]:
for score in sent_micro_res["climatefever"]["dev"]:
    print(score._score_name)
    print(score.classification_report)
    print("===")

fever-climatefeverpure-bert-base-uncased.dev
                 precision    recall  f1-score   support

NOT ENOUGH INFO     0.7874    0.6674    0.7224       938
        REFUTES     0.3280    0.4621    0.3836       132
       SUPPORTS     0.4474    0.5719    0.5021       320

       accuracy                         0.6259      1390
      macro avg     0.5209    0.5671    0.5361      1390
   weighted avg     0.6655    0.6259    0.6395      1390

===
fever-climatefever-bert-base-uncased.dev
                 precision    recall  f1-score   support

NOT ENOUGH INFO     0.7344    0.9168    0.8156       938
        REFUTES     0.4727    0.1970    0.2781       132
       SUPPORTS     0.6280    0.3219    0.4256       320

       accuracy                         0.7115      1390
      macro avg     0.6117    0.4786    0.5064      1390
   weighted avg     0.6851    0.7115    0.6747      1390

===
climatefeverpure-xlnet-base-cased.dev
                 precision    recall  f1-score   support

NOT EN

In [17]:
for score in sent_micro_res["scifact"]["all"]:
    print(score._score_name)
    print(score.classification_report)
    print("===")

fever-climatefever-xlnet-base-cased.all
                 precision    recall  f1-score   support

NOT ENOUGH INFO     0.5108    0.4555    0.4816       832
        REFUTES     0.4519    0.3952    0.4217       463
       SUPPORTS     0.5490    0.6466    0.5938       832

       accuracy                         0.5172      2127
      macro avg     0.5039    0.4991    0.4990      2127
   weighted avg     0.5129    0.5172    0.5124      2127

===
climatefeverpure-bert-base-uncased.all
                 precision    recall  f1-score   support

NOT ENOUGH INFO     0.6579    0.1502    0.2446       832
        REFUTES     0.4049    0.1793    0.2485       463
       SUPPORTS     0.4273    0.8894    0.5772       832

       accuracy                         0.4457      2127
      macro avg     0.4967    0.4063    0.3568      2127
   weighted avg     0.5126    0.4457    0.3756      2127

===
climatefever-xlnet-base-cased.all
                 precision    recall  f1-score   support

NOT ENOUGH INFO  

In [82]:
for score in sent_micro_res["scifact"]["dev"]:
    print(score._score_name)
    print(score.classification_report)
    print("===")

fever-climatefeverpure-bert-base-uncased.dev
                 precision    recall  f1-score   support

NOT ENOUGH INFO     0.4156    0.2857    0.3386       224
        REFUTES     0.3459    0.3770    0.3608       122
       SUPPORTS     0.5018    0.6389    0.5621       216

       accuracy                         0.4413       562
      macro avg     0.4211    0.4339    0.4205       562
   weighted avg     0.4336    0.4413    0.4293       562

===
fever-climatefever-bert-base-uncased.dev
                 precision    recall  f1-score   support

NOT ENOUGH INFO     0.4542    0.4866    0.4698       224
        REFUTES     0.3871    0.1967    0.2609       122
       SUPPORTS     0.5154    0.6204    0.5630       216

       accuracy                         0.4751       562
      macro avg     0.4522    0.4346    0.4312       562
   weighted avg     0.4631    0.4751    0.4603       562

===
climatefeverpure-xlnet-base-cased.dev
                 precision    recall  f1-score   support

NOT EN

In [18]:
for score in sent_micro_res["scifactpipeline"]["all"]:
    print(score._score_name)
    print(score.classification_report)
    print("===")

fever-climatefever-xlnet-base-cased.all
                 precision    recall  f1-score   support

NOT ENOUGH INFO     0.5273    0.5476    0.5373      2080
        REFUTES     0.4130    0.3544    0.3815      1185
       SUPPORTS     0.5604    0.5820    0.5710      2280

       accuracy                         0.5205      5545
      macro avg     0.5002    0.4947    0.4966      5545
   weighted avg     0.5165    0.5205    0.5178      5545

===
climatefeverpure-bert-base-uncased.all
                 precision    recall  f1-score   support

NOT ENOUGH INFO     0.6208    0.1606    0.2552      2080
        REFUTES     0.3279    0.1350    0.1913      1185
       SUPPORTS     0.4448    0.8816    0.5913      2280

       accuracy                         0.4516      5545
      macro avg     0.4645    0.3924    0.3459      5545
   weighted avg     0.4858    0.4516    0.3797      5545

===
climatefever-xlnet-base-cased.all
                 precision    recall  f1-score   support

NOT ENOUGH INFO  

In [93]:
for score in sent_micro_res["scifactpipeline"]["dev"]:
    print(score._score_name)
    print(score.classification_report)
    print("===")

fever-climatefeverpure-bert-base-uncased.dev
                 precision    recall  f1-score   support

NOT ENOUGH INFO     0.5177    0.3661    0.4289       560
        REFUTES     0.3400    0.4281    0.3790       320
       SUPPORTS     0.5521    0.6242    0.5859       620

       accuracy                         0.4860      1500
      macro avg     0.4699    0.4728    0.4646      1500
   weighted avg     0.4940    0.4860    0.4831      1500

===
fever-climatefever-bert-base-uncased.dev
                 precision    recall  f1-score   support

NOT ENOUGH INFO     0.4725    0.5679    0.5158       560
        REFUTES     0.3506    0.1906    0.2470       320
       SUPPORTS     0.5574    0.5871    0.5719       620

       accuracy                         0.4953      1500
      macro avg     0.4602    0.4485    0.4449      1500
   weighted avg     0.4816    0.4953    0.4816      1500

===
climatefeverpure-xlnet-base-cased.dev
                 precision    recall  f1-score   support

NOT EN

## Macro verdict

### Majority

In [99]:
majority_agg_res = {
    "fever": defaultdict(list), 
    "climatefeverpure": defaultdict(list), 
    "scifact": defaultdict(list),
    "climatefever": defaultdict(list),
    "scifactpipeline": defaultdict(list)
}

for actual_dp in tqdm(actual_pls):
    dataset = actual_dp.parent.stem.split("-")[0]
    split = actual_dp.stem.split(".")[0]
    actual_data = read_data(actual_dp)
    if "climatefever" in dataset:
        score_obj = scorer.ClimateFEVERScorer
    else:
        score_obj = scorer.FEVERScorer
        
    for pp in pred_p.joinpath("sent", dataset).glob(f"*.{split}.jsonl"):
        preds = read_data(pp)
        
        # aggregate
        preds = pd.DataFrame(preds)
        preds["predicted_label"] = preds["predicted_label"].map(constants.LABEL2ID)
        preds = preds.groupby("claim_id", sort=False, as_index=False).agg({"predicted_label": aggregate.agg_predict})
        preds["predicted_label"] = preds["predicted_label"].map(constants.ID2LABEL)
        preds = preds.to_dict("records")
        
        assert all([int(a["id"]) == int(p["claim_id"].split("|")[1]) for a, p in zip(actual_data, preds)])
        
        majority_agg_res[dataset][split].append(score_obj(
            actual_data=actual_data, 
            prediction_data=preds, 
            oracle_rte=False, 
            oracle_ir=True, 
            max_evidence=99999, 
            score_name=pp.stem
        ))
    if dataset == "scifact":
        dataset = "scifactpipeline"
        for pp in pred_p.joinpath("sent", dataset).glob(f"*.{split}.jsonl"):
            preds = read_data(pp)

            # aggregate
            preds = pd.DataFrame(preds)
            preds["predicted_label"] = preds["predicted_label"].map(constants.LABEL2ID)
            preds = preds.groupby("claim_id", sort=False, as_index=False).agg({
                "predicted_label": aggregate.agg_predict,
                "predicted_evidence": lambda ls: [i for i in ls]
            })
            preds["predicted_label"] = preds["predicted_label"].map(constants.ID2LABEL)
            preds = preds.to_dict("records")
            
            assert all([int(a["id"]) == int(p["claim_id"].split("|")[1]) for a, p in zip(actual_data, preds)])

            majority_agg_res[dataset][split].append(score_obj(
                actual_data=actual_data, 
                prediction_data=preds, 
                oracle_rte=False, 
                oracle_ir=False, 
                max_evidence=99999, 
                score_name=pp.stem
            ))

100%|██████████| 9/9 [00:33<00:00,  3.68s/it]


In [102]:
with result_p.joinpath("sent_macro_verdict_majority_metrics.pkl").open("wb") as fn:
    pkl.dump(majority_agg_res, fn)

In [21]:
for score in majority_agg_res["fever"]["dev"]:
    print(score._score_name)
    print(score.classification_report)
    print("===")

fever-climatefeverpure-bert-base-uncased.dev
                 precision    recall  f1-score   support

NOT ENOUGH INFO     0.9883    0.9859    0.9871      3333
        REFUTES     0.9591    0.9157    0.9369      3333
       SUPPORTS     0.9198    0.9637    0.9412      3333

       accuracy                         0.9551      9999
      macro avg     0.9557    0.9551    0.9551      9999
   weighted avg     0.9557    0.9551    0.9551      9999

===
fever-climatefever-bert-base-uncased.dev
                 precision    recall  f1-score   support

NOT ENOUGH INFO     0.9873    0.9823    0.9848      3333
        REFUTES     0.9546    0.9148    0.9343      3333
       SUPPORTS     0.9200    0.9631    0.9411      3333

       accuracy                         0.9534      9999
      macro avg     0.9540    0.9534    0.9534      9999
   weighted avg     0.9540    0.9534    0.9534      9999

===
climatefeverpure-xlnet-base-cased.dev
                 precision    recall  f1-score   support

NOT EN

In [22]:
for score in majority_agg_res["climatefeverpure"]["dev"]:
    print(score._score_name)
    print(score.classification_report)
    print("===")

fever-climatefeverpure-bert-base-uncased.dev
                 precision    recall  f1-score   support

NOT ENOUGH INFO     0.4245    0.4737    0.4478        95
        REFUTES     0.4000    0.4314    0.4151        51
       SUPPORTS     0.7009    0.6212    0.6586       132

       accuracy                         0.5360       278
      macro avg     0.5085    0.5088    0.5072       278
   weighted avg     0.5512    0.5360    0.5419       278

===
fever-climatefever-bert-base-uncased.dev
                 precision    recall  f1-score   support

NOT ENOUGH INFO     0.4175    0.8526    0.5606        95
        REFUTES     0.6111    0.2157    0.3188        51
       SUPPORTS     0.8182    0.4091    0.5455       132

       accuracy                         0.5252       278
      macro avg     0.6156    0.4925    0.4749       278
   weighted avg     0.6433    0.5252    0.5090       278

===
climatefeverpure-xlnet-base-cased.dev
                 precision    recall  f1-score   support

NOT EN

In [23]:
for score in majority_agg_res["climatefever"]["dev"]:
    print(score._score_name)
    print(score.classification_report)
    print("===")

fever-climatefeverpure-bert-base-uncased.dev
                 precision    recall  f1-score   support

NOT ENOUGH INFO     0.4787    0.4737    0.4762        95
        REFUTES     0.3966    0.4510    0.4220        51
       SUPPORTS     0.6905    0.6591    0.6744       132

       accuracy                         0.5576       278
      macro avg     0.5219    0.5279    0.5242       278
   weighted avg     0.5642    0.5576    0.5604       278

===
fever-climatefever-bert-base-uncased.dev
                 precision    recall  f1-score   support

NOT ENOUGH INFO     0.4402    0.8526    0.5806        95
        REFUTES     0.6364    0.2745    0.3836        51
       SUPPORTS     0.8194    0.4470    0.5784       132

       accuracy                         0.5540       278
      macro avg     0.6320    0.5247    0.5142       278
   weighted avg     0.6563    0.5540    0.5434       278

===
climatefeverpure-xlnet-base-cased.dev
                 precision    recall  f1-score   support

NOT EN

In [24]:
for score in majority_agg_res["scifact"]["all"]:
    print(score._score_name)
    print(score.classification_report)
    print("===")

fever-climatefever-xlnet-base-cased.all
                 precision    recall  f1-score   support

NOT ENOUGH INFO     0.5433    0.3918    0.4553       416
        REFUTES     0.4860    0.4388    0.4612       237
       SUPPORTS     0.5613    0.7325    0.6356       456

       accuracy                         0.5419      1109
      macro avg     0.5302    0.5210    0.5174      1109
   weighted avg     0.5385    0.5419    0.5307      1109

===
climatefeverpure-bert-base-uncased.all
                 precision    recall  f1-score   support

NOT ENOUGH INFO     0.6125    0.1178    0.1976       416
        REFUTES     0.3462    0.1519    0.2111       237
       SUPPORTS     0.4411    0.8947    0.5909       456

       accuracy                         0.4445      1109
      macro avg     0.4666    0.3881    0.3332      1109
   weighted avg     0.4851    0.4445    0.3622      1109

===
climatefever-xlnet-base-cased.all
                 precision    recall  f1-score   support

NOT ENOUGH INFO  

In [100]:
for score in majority_agg_res["scifact"]["dev"]:
    print(score._score_name)
    print(score.classification_report)
    print("===")

fever-climatefeverpure-bert-base-uncased.dev
                 precision    recall  f1-score   support

NOT ENOUGH INFO     0.4776    0.2857    0.3575       112
        REFUTES     0.4098    0.3906    0.4000        64
       SUPPORTS     0.5581    0.7742    0.6486       124

       accuracy                         0.5100       300
      macro avg     0.4819    0.4835    0.4687       300
   weighted avg     0.4964    0.5100    0.4869       300

===
fever-climatefever-bert-base-uncased.dev
                 precision    recall  f1-score   support

NOT ENOUGH INFO     0.5169    0.4107    0.4577       112
        REFUTES     0.4167    0.2344    0.3000        64
       SUPPORTS     0.5543    0.7823    0.6488       124

       accuracy                         0.5267       300
      macro avg     0.4959    0.4758    0.4688       300
   weighted avg     0.5110    0.5267    0.5031       300

===
climatefeverpure-xlnet-base-cased.dev
                 precision    recall  f1-score   support

NOT EN

In [25]:
for score in majority_agg_res["scifactpipeline"]["all"]:
    print(score._score_name)
    print(score.fever_score)
    print(score.classification_report)
    print("===")

fever-climatefever-xlnet-base-cased.all
0.4003606853020739
                 precision    recall  f1-score   support

NOT ENOUGH INFO     0.6149    0.2572    0.3627       416
        REFUTES     0.4350    0.4515    0.4431       237
       SUPPORTS     0.5210    0.7873    0.6271       456

       accuracy                         0.5167      1109
      macro avg     0.5236    0.4987    0.4776      1109
   weighted avg     0.5379    0.5167    0.4886      1109

===
climatefeverpure-bert-base-uncased.all
0.3309287646528404
                 precision    recall  f1-score   support

NOT ENOUGH INFO     0.8537    0.0841    0.1532       416
        REFUTES     0.3423    0.1603    0.2184       237
       SUPPORTS     0.4368    0.9167    0.5916       456

       accuracy                         0.4427      1109
      macro avg     0.5443    0.3870    0.3211      1109
   weighted avg     0.5730    0.4427    0.3474      1109

===
climatefever-xlnet-base-cased.all
0.3985572587917042
                 p

In [101]:
for score in majority_agg_res["scifactpipeline"]["dev"]:
    print(score._score_name)
    print(score.fever_score)
    print(score.classification_report)
    print("===")

fever-climatefeverpure-bert-base-uncased.dev
0.35
                 precision    recall  f1-score   support

NOT ENOUGH INFO     0.6000    0.1607    0.2535       112
        REFUTES     0.3478    0.5000    0.4103        64
       SUPPORTS     0.5281    0.7581    0.6225       124

       accuracy                         0.4800       300
      macro avg     0.4920    0.4729    0.4288       300
   weighted avg     0.5165    0.4800    0.4395       300

===
fever-climatefever-bert-base-uncased.dev
0.38
                 precision    recall  f1-score   support

NOT ENOUGH INFO     0.5763    0.3036    0.3977       112
        REFUTES     0.3261    0.2344    0.2727        64
       SUPPORTS     0.5026    0.7903    0.6144       124

       accuracy                         0.4900       300
      macro avg     0.4683    0.4428    0.4283       300
   weighted avg     0.4924    0.4900    0.4606       300

===
climatefeverpure-xlnet-base-cased.dev
0.35333333333333333
                 precision    reca

### Mean probability

In [103]:
meanproba_agg_res = {
    "fever": defaultdict(list), 
    "climatefeverpure": defaultdict(list), 
    "scifact": defaultdict(list),
    "climatefever": defaultdict(list),
    "scifactpipeline": defaultdict(list)
}

for actual_dp in tqdm(actual_pls):
    dataset = actual_dp.parent.stem.split("-")[0]
    split = actual_dp.stem.split(".")[0]
    actual_data = read_data(actual_dp)
    if "climatefever" in dataset:
        score_obj = scorer.ClimateFEVERScorer
    else:
        score_obj = scorer.FEVERScorer
        
    for pp in pred_p.joinpath("sent", dataset).glob(f"*.{split}.jsonl"):
        preds = read_data(pp)
        
        # aggregate
        preds = pd.DataFrame(preds)
        preds = preds.groupby("claim_id", sort=False, as_index=False).agg({"predicted_proba": aggregate.agg_predict_proba})
        preds = preds.rename(columns={"predicted_proba": "predicted_label"})
        preds["predicted_label"] = preds["predicted_label"].map(constants.ID2LABEL)
        preds = preds.to_dict("records")
        
        assert all([int(a["id"]) == int(p["claim_id"].split("|")[1]) for a, p in zip(actual_data, preds)])
        
        meanproba_agg_res[dataset][split].append(score_obj(
            actual_data=actual_data, 
            prediction_data=preds, 
            oracle_rte=False, 
            oracle_ir=True, 
            max_evidence=99999, 
            score_name=pp.stem
        ))
    if dataset == "scifact":
        dataset = "scifactpipeline"
        for pp in pred_p.joinpath("sent", dataset).glob(f"*.{split}.jsonl"):
            preds = read_data(pp)

            # aggregate
            preds = pd.DataFrame(preds)
            preds = preds.groupby("claim_id", sort=False, as_index=False).agg({
                "predicted_proba": aggregate.agg_predict_proba,
                "predicted_evidence": lambda ls: [i for i in ls]
            })
            preds = preds.rename(columns={"predicted_proba": "predicted_label"})
            preds["predicted_label"] = preds["predicted_label"].map(constants.ID2LABEL)
            preds = preds.to_dict("records")
            
            assert all([int(a["id"]) == int(p["claim_id"].split("|")[1]) for a, p in zip(actual_data, preds)])
            
            meanproba_agg_res[dataset][split].append(score_obj(
                actual_data=actual_data, 
                prediction_data=preds, 
                oracle_rte=False, 
                oracle_ir=False, 
                max_evidence=99999, 
                score_name=pp.stem
            ))

100%|██████████| 9/9 [00:34<00:00,  3.78s/it]


In [106]:
with result_p.joinpath("sent_macro_verdict_meanproba_metrics.pkl").open("wb") as fn:
    pkl.dump(meanproba_agg_res, fn)

In [28]:
for score in meanproba_agg_res["fever"]["dev"]:
    print(score._score_name)
    print(score.classification_report)
    print("===")

fever-climatefeverpure-bert-base-uncased.dev
                 precision    recall  f1-score   support

NOT ENOUGH INFO     0.9884    0.9961    0.9922      3333
        REFUTES     0.9627    0.9130    0.9372      3333
       SUPPORTS     0.9235    0.9640    0.9433      3333

       accuracy                         0.9577      9999
      macro avg     0.9582    0.9577    0.9576      9999
   weighted avg     0.9582    0.9577    0.9576      9999

===
fever-climatefever-bert-base-uncased.dev
                 precision    recall  f1-score   support

NOT ENOUGH INFO     0.9863    0.9952    0.9907      3333
        REFUTES     0.9602    0.9121    0.9355      3333
       SUPPORTS     0.9256    0.9637    0.9443      3333

       accuracy                         0.9570      9999
      macro avg     0.9574    0.9570    0.9569      9999
   weighted avg     0.9574    0.9570    0.9569      9999

===
climatefeverpure-xlnet-base-cased.dev
                 precision    recall  f1-score   support

NOT EN

In [29]:
for score in meanproba_agg_res["climatefeverpure"]["dev"]:
    print(score._score_name)
    print(score.classification_report)
    print("===")

fever-climatefeverpure-bert-base-uncased.dev
                 precision    recall  f1-score   support

NOT ENOUGH INFO     0.4870    0.7895    0.6024        95
        REFUTES     0.4750    0.3725    0.4176        51
       SUPPORTS     0.8571    0.5455    0.6667       132

       accuracy                         0.5971       278
      macro avg     0.6064    0.5692    0.5622       278
   weighted avg     0.6606    0.5971    0.5990       278

===
fever-climatefever-bert-base-uncased.dev
                 precision    recall  f1-score   support

NOT ENOUGH INFO     0.4115    0.9789    0.5794        95
        REFUTES     0.8000    0.1569    0.2623        51
       SUPPORTS     0.9286    0.2955    0.4483       132

       accuracy                         0.5036       278
      macro avg     0.7134    0.4771    0.4300       278
   weighted avg     0.7283    0.5036    0.4590       278

===
climatefeverpure-xlnet-base-cased.dev
                 precision    recall  f1-score   support

NOT EN

In [30]:
for score in meanproba_agg_res["climatefever"]["dev"]:
    print(score._score_name)
    print(score.classification_report)
    print("===")

fever-climatefeverpure-bert-base-uncased.dev
                 precision    recall  f1-score   support

NOT ENOUGH INFO     0.4491    0.7895    0.5725        95
        REFUTES     0.4595    0.3333    0.3864        51
       SUPPORTS     0.8378    0.4697    0.6019       132

       accuracy                         0.5540       278
      macro avg     0.5821    0.5308    0.5203       278
   weighted avg     0.6356    0.5540    0.5523       278

===
fever-climatefever-bert-base-uncased.dev
                 precision    recall  f1-score   support

NOT ENOUGH INFO     0.3843    0.9789    0.5519        95
        REFUTES     0.6250    0.0980    0.1695        51
       SUPPORTS     0.8929    0.1894    0.3125       132

       accuracy                         0.4424       278
      macro avg     0.6341    0.4221    0.3446       278
   weighted avg     0.6699    0.4424    0.3681       278

===
climatefeverpure-xlnet-base-cased.dev
                 precision    recall  f1-score   support

NOT EN

In [31]:
for score in meanproba_agg_res["scifact"]["all"]:
    print(score._score_name)
    print(score.classification_report)
    print("===")

fever-climatefever-xlnet-base-cased.all
                 precision    recall  f1-score   support

NOT ENOUGH INFO     0.5029    0.4183    0.4567       416
        REFUTES     0.4680    0.4008    0.4318       237
       SUPPORTS     0.5661    0.6952    0.6240       456

       accuracy                         0.5284      1109
      macro avg     0.5123    0.5048    0.5042      1109
   weighted avg     0.5214    0.5284    0.5202      1109

===
climatefeverpure-bert-base-uncased.all
                 precision    recall  f1-score   support

NOT ENOUGH INFO     0.6703    0.1466    0.2406       416
        REFUTES     0.3571    0.1477    0.2090       237
       SUPPORTS     0.4424    0.8925    0.5916       456

       accuracy                         0.4536      1109
      macro avg     0.4900    0.3956    0.3471      1109
   weighted avg     0.5097    0.4536    0.3782      1109

===
climatefever-xlnet-base-cased.all
                 precision    recall  f1-score   support

NOT ENOUGH INFO  

In [104]:
for score in meanproba_agg_res["scifact"]["dev"]:
    print(score._score_name)
    print(score.classification_report)
    print("===")

fever-climatefeverpure-bert-base-uncased.dev
                 precision    recall  f1-score   support

NOT ENOUGH INFO     0.4533    0.3036    0.3636       112
        REFUTES     0.3824    0.4062    0.3939        64
       SUPPORTS     0.5732    0.7258    0.6406       124

       accuracy                         0.5000       300
      macro avg     0.4696    0.4785    0.4660       300
   weighted avg     0.4878    0.5000    0.4846       300

===
fever-climatefever-bert-base-uncased.dev
                 precision    recall  f1-score   support

NOT ENOUGH INFO     0.4911    0.4911    0.4911       112
        REFUTES     0.4138    0.1875    0.2581        64
       SUPPORTS     0.5597    0.7177    0.6290       124

       accuracy                         0.5200       300
      macro avg     0.4882    0.4654    0.4594       300
   weighted avg     0.5030    0.5200    0.4984       300

===
climatefeverpure-xlnet-base-cased.dev
                 precision    recall  f1-score   support

NOT EN

In [32]:
for score in meanproba_agg_res["scifactpipeline"]["all"]:
    print(score._score_name)
    print(score.classification_report)
    print("===")

fever-climatefever-xlnet-base-cased.all
                 precision    recall  f1-score   support

NOT ENOUGH INFO     0.5611    0.5962    0.5781       416
        REFUTES     0.4728    0.3671    0.4133       237
       SUPPORTS     0.5942    0.6294    0.6113       456

       accuracy                         0.5609      1109
      macro avg     0.5427    0.5309    0.5342      1109
   weighted avg     0.5558    0.5609    0.5565      1109

===
climatefeverpure-bert-base-uncased.all
                 precision    recall  f1-score   support

NOT ENOUGH INFO     0.6630    0.1466    0.2402       416
        REFUTES     0.3548    0.1392    0.2000       237
       SUPPORTS     0.4437    0.8991    0.5942       456

       accuracy                         0.4545      1109
      macro avg     0.4872    0.3950    0.3448      1109
   weighted avg     0.5070    0.4545    0.3772      1109

===
climatefever-xlnet-base-cased.all
                 precision    recall  f1-score   support

NOT ENOUGH INFO  

In [105]:
for score in meanproba_agg_res["scifactpipeline"]["dev"]:
    print(score._score_name)
    print(score.classification_report)
    print("===")

fever-climatefeverpure-bert-base-uncased.dev
                 precision    recall  f1-score   support

NOT ENOUGH INFO     0.4943    0.3839    0.4322       112
        REFUTES     0.3200    0.3750    0.3453        64
       SUPPORTS     0.5580    0.6210    0.5878       124

       accuracy                         0.4800       300
      macro avg     0.4574    0.4600    0.4551       300
   weighted avg     0.4834    0.4800    0.4780       300

===
fever-climatefever-bert-base-uncased.dev
                 precision    recall  f1-score   support

NOT ENOUGH INFO     0.4815    0.5804    0.5263       112
        REFUTES     0.3548    0.1719    0.2316        64
       SUPPORTS     0.5672    0.6129    0.5891       124

       accuracy                         0.5067       300
      macro avg     0.4678    0.4550    0.4490       300
   weighted avg     0.4899    0.5067    0.4894       300

===
climatefeverpure-xlnet-base-cased.dev
                 precision    recall  f1-score   support

NOT EN