In [2]:
from dataset.def_dataset import DefinitionDataset, Fact
from transformers import AutoTokenizer
from models.evidence_selection_model import EvidenceSelectionModel
from models.claim_verification_model import ClaimVerificationModel
import torch
from transformers import AutoModel, AutoModelForSequenceClassification
from general_utils.fever_scorer import fever_score
from pipeline.pipeline import TestPipeline, WikiPipeline
from general_utils.utils import build_fever_instance
from general_utils.utils import convert_document_id_to_word
from sklearn.metrics import classification_report
from tqdm import tqdm

In [3]:
# Use the GPU when torch can see one; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Checkpoints used by the two pipeline stages.
# Alternatives that are currently disabled:
#   selection:    'Snowflake/snowflake-arctic-embed-m-long'
#   verification: 'MoritzLaurer/mDeBERTa-v3-base-xnli-multilingual-nli-2mil7'
selection_model_name = 'lukasellinger/evidence_selection_model-v2'
verification_model_name = 'lukasellinger/claim_verification_model-v1'

# Each tokenizer is loaded from the same repository as its model.
selection_model_tokenizer = AutoTokenizer.from_pretrained(selection_model_name)
verification_model_tokenizer = AutoTokenizer.from_pretrained(verification_model_name)

# Evidence selection: bare encoder wrapped in the project's selection model.
# NOTE: trust_remote_code=True runs model code shipped with the checkpoint repo.
selection_model_raw = AutoModel.from_pretrained(
    selection_model_name,
    trust_remote_code=True,
    add_pooling_layer=False,
    safe_serialization=True,
)
selection_model = EvidenceSelectionModel(selection_model_raw).to(device)

# Claim verification: sequence-classification head wrapped in the project's verification model.
verification_model_raw = AutoModelForSequenceClassification.from_pretrained(
    verification_model_name
)
verification_model = ClaimVerificationModel(verification_model_raw).to(device)

<All keys matched successfully>


In [17]:
from datasets import load_dataset

# Evaluate the selection + verification pipeline on the 'dev' split of the
# FEVER evidence-selection dataset and report label metrics plus the FEVER score.
raw_dataset = load_dataset("lukasellinger/fever_evidence_selection-v1", cache_dir=None).get('dev')
print(raw_dataset.features)

test_pipeline = TestPipeline(
    selection_model=selection_model,
    selection_model_tokenizer=selection_model_tokenizer,
    verification_model=verification_model,
    verification_model_tokenizer=verification_model_tokenizer,
)

pr_labels = []        # predicted labels, one per dataset entry
gt_labels = []        # gold labels, aligned index-by-index with pr_labels
fever_instances = []  # per-entry records handed to fever_score
for entry in tqdm(raw_dataset):
    document_id = entry.get('document_id')
    fallback = convert_document_id_to_word(document_id)

    # TODO add atomic fact support (claims are verified as a whole: split_facts=False)
    result = test_pipeline.verify(document_id, entry['short_claim'], fallback, split_facts=False)

    # Only a factuality of exactly 1 counts as supported; anything else is not supported.
    predicted = Fact.SUPPORTED if result.get('factuality') == 1 else Fact.NOT_SUPPORTED
    pr_labels.append(predicted.to_factuality())

    # Every gold label other than 'SUPPORTS' is mapped to NOT_SUPPORTED.
    gold = Fact.SUPPORTED if entry['label'] == 'SUPPORTS' else Fact.NOT_SUPPORTED
    gt_labels.append(gold.to_factuality())

    gold_evidence = entry['evidence_lines'].split(';')
    # Each predicted evidence item is a 3-tuple; keep only its first two fields.
    predicted_evidence = [(first, second) for (first, second, _) in result.get('evidences')]
    # NOTE(review): the gold label is passed as a string (.name) while the predicted
    # label is passed as a Fact member -- confirm build_fever_instance expects this.
    fever_instances.append(
        build_fever_instance(gold.name, gold_evidence, entry['document_id'], predicted, predicted_evidence)
    )

print(classification_report(gt_labels, pr_labels, zero_division=0))
strict_score, label_accuracy, precision, recall, f1 = fever_score(fever_instances)

# Printed in order: strict score, label accuracy, evidence precision, recall, F1.
# Precision (TP / (TP + FP)) is not too important here -- at least one TP matters more
# than none; recall is the more important number.
for metric in (strict_score, label_accuracy, precision, recall, f1):
    print(metric)

{'id': Value(dtype='int64', id=None), 'claim': Value(dtype='string', id=None), 'short_claim': Value(dtype='string', id=None), 'label': Value(dtype='string', id=None), 'document_id': Value(dtype='string', id=None), 'text': Value(dtype='string', id=None), 'lines': Value(dtype='string', id=None), 'evidence_lines': Value(dtype='string', id=None), 'atomic_facts': Value(dtype='string', id=None)}


  0%|          | 0/1978 [00:37<?, ?it/s]

KeyboardInterrupt



In [4]:
from datasets import load_dataset

# Evaluate the Wikipedia-backed pipeline on the 'train' split of the German DPR
# claim-verification dataset, using an offline wiki dump as the evidence source.
dataset = load_dataset("lukasellinger/german_dpr_claim_verification_dissim-v1").get('train')
offline_wiki = 'lukasellinger/wiki_dump_2024-07-08'
print(dataset.features)

pipeline = WikiPipeline(
    selection_model=selection_model,
    selection_model_tokenizer=selection_model_tokenizer,
    word_lang='de',
    use_offline_wiki=offline_wiki,
)

pr_labels = []     # predicted labels for entries that received a verdict
gt_labels = []     # gold labels, aligned index-by-index with pr_labels
factualities = []  # raw pipeline output for every entry, including skipped ones
not_in_wiki = 0    # entries for which the pipeline returned factuality == -1
for entry in tqdm(dataset):
    word = entry.get('word')
    english_word = entry.get('english_word', word)
    search_word = entry.get('document_search_word')
    claim = entry.get('english_claim', entry['claim'])
    # The entry's 'atomic_facts' column is intentionally not parsed here: claims are
    # verified as a whole (split_facts=False), so the parsed list was never used.

    factuality = pipeline.verify(word, claim, english_word, only_intro=True,
                                 split_facts=False, search_word=search_word)
    factualities.append(factuality)

    verdict = factuality.get('factuality')
    if verdict == -1:
        # No verdict for this entry: count it and leave it out of both label lists
        # so the classification report only covers entries that were scored.
        not_in_wiki += 1
        continue

    # Only a verdict of exactly 1 counts as supported.
    predicted = Fact.SUPPORTED if verdict == 1 else Fact.NOT_SUPPORTED
    pr_labels.append(predicted.to_factuality())
    gt_labels.append(Fact[entry['label']].to_factuality())

print(f'Not in wiki {not_in_wiki}')
print(classification_report(gt_labels, pr_labels, zero_division=0))

{'id': Value(dtype='int64', id=None), 'question': Value(dtype='string', id=None), 'claim': Value(dtype='string', id=None), 'english_claim': Value(dtype='string', id=None), 'fact': Value(dtype='string', id=None), 'word': Value(dtype='string', id=None), 'english_word': Value(dtype='string', id=None), 'context': Value(dtype='string', id=None), 'label': Value(dtype='string', id=None), 'atomic_facts': Value(dtype='string', id=None), 'document_search_word': Value(dtype='string', id=None)}


100%|██████████| 168/168 [04:59<00:00,  1.78s/it]

Not in wiki 29
              precision    recall  f1-score   support

           0       0.70      0.91      0.79        69
           1       0.88      0.61      0.72        70

    accuracy                           0.76       139
   macro avg       0.79      0.76      0.76       139
weighted avg       0.79      0.76      0.76       139




