In [None]:
from google.colab import drive

drive.mount('/content/drive', force_remount=True)
repository = 'evaluating_factuality_word_definitions'

%cd /content/drive/My Drive/{repository}

In [None]:
!pip install datasets
!pip install peft
!pip install rank_bm25

In [1]:
from datasets import Dataset
from dataset.def_dataset import Fact
from config import DB_URL
from transformers import AutoTokenizer
from models.evidence_selection_model import EvidenceSelectionModel
from peft import AutoPeftModelForFeatureExtraction
import torch
from pipeline.pipeline import TestPipeline
from pipeline.pipeline import WikiPipeline
from utils import convert_document_id_to_word
from sklearn.metrics import classification_report
from tqdm import tqdm

In [2]:
dataset_query = """
select dd.id, docs.document_id, dd.claim, dd.label
from def_dataset dd
    join documents docs on docs.document_id = dd.evidence_wiki_url
    join atomic_facts af on af.claim_id = dd.id
where set_type='{set_type}' -- and length(claim) < 50 and length(docs.text) < 400
group by dd.id, evidence_annotation_id, evidence_wiki_url
limit 20
"""

dataset = Dataset.from_sql(dataset_query.format(set_type='dev'), con=DB_URL)

In [3]:
device = "cuda" if torch.cuda.is_available() else "cpu"

model_name = 'google/bigbird-roberta-large'
model = AutoPeftModelForFeatureExtraction.from_pretrained('selection_model_intermediate_04-30_09-40')

selection_model = EvidenceSelectionModel(model).to(device)
selection_model_tokenizer = AutoTokenizer.from_pretrained(model_name)

# still using base
verification_model=None
verification_model_tokenizer=None

In [4]:
test_pipeline = TestPipeline(selection_model=selection_model, selection_model_tokenizer=selection_model_tokenizer)

pr_labels = []
gt_labels = []
for entry in tqdm(dataset):
    factuality = test_pipeline.verify(entry['document_id'], entry['claim'])
    pr_labels.extend([fact.to_factuality() for fact in factuality])
    gt_labels += [Fact[entry['label']].to_factuality()] * len(factuality)

print(classification_report(gt_labels, pr_labels, zero_division=0))

  0%|          | 0/20 [00:00<?, ?it/s]Attention type 'block_sparse' is not possible if sequence_length: 10 <= num global tokens: 2 * config.block_size + min. num sliding tokens: 3 * config.block_size + config.num_random_blocks * config.block_size + additional buffer: config.num_random_blocks * config.block_size = 704 with config.block_size = 64, config.num_random_blocks = 3. Changing attention type to 'original_full'...
100%|██████████| 20/20 [01:05<00:00,  3.30s/it]

              precision    recall  f1-score   support

          -1       0.00      0.00      0.00         0
           0       1.00      0.20      0.33        15
           1       0.68      0.83      0.75        18

    accuracy                           0.55        33
   macro avg       0.56      0.34      0.36        33
weighted avg       0.83      0.55      0.56        33






In [5]:
print([int(label) for label in pr_labels])
print(gt_labels)

[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 0, 1, -1, 1, 0, -1, -1, -1, -1]
[1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0]


In [14]:
pipeline = WikiPipeline(selection_model=selection_model, selection_model_tokenizer=selection_model_tokenizer)

pr_labels = []
gt_labels = []
for entry in tqdm(dataset):
    word = convert_document_id_to_word(entry['document_id'])
    
    factuality = pipeline.verify(word, entry['claim'])
    pr_labels.extend([fact.to_factuality() for fact in factuality])
    gt_labels += [Fact[entry['label']].to_factuality()] * len(factuality)

print(classification_report(gt_labels, pr_labels, zero_division=0))

100%|██████████| 20/20 [01:52<00:00,  5.63s/it]

              precision    recall  f1-score   support

        -1.0       0.00      0.00      0.00         0
         0.0       0.75      0.86      0.80         7
         1.0       1.00      0.62      0.76        13

    accuracy                           0.70        20
   macro avg       0.58      0.49      0.52        20
weighted avg       0.91      0.70      0.78        20




  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [None]:
print([int(label) for label in pr_labels])
print(gt_labels)