In [1]:
%load_ext autoreload
%autoreload 2
from wiki_dataset import WikiDataset
from vector_database import VectorDatabase
from torch.utils.data import DataLoader
from fever_dataset import FeverDataset, FeverCollator
from embedding_generation import EmbeddingGenerator, NLI

In [2]:
wiki_dataset = WikiDataset(reduced=True, type='dev', in_mem=True, num_extra_pages=0, seed=0)
wiki_dataloader = DataLoader(wiki_dataset, batch_size=8192, num_workers=8, shuffle=False)

In [3]:
vdb = VectorDatabase(client='docker', wiki_loader=wiki_dataloader)

Creating collection
Block 1/1 done
Time to create collection: 9.688613176345825


In [19]:
fever_collator = FeverCollator()
fever_dataset = FeverDataset(type='dev')
fever_loader = DataLoader(fever_dataset, batch_size=32, num_workers=10, shuffle=False, collate_fn=fever_collator)

emb_gen = EmbeddingGenerator()
nli = NLI()

In [20]:
fever_dataset[1]

{'id': 194462,
 'verifiable': 'NOT VERIFIABLE',
 'label': 'NOT ENOUGH INFO',
 'claim': 'Tilda Swinton is a vegan.',
 'evidence': {'all_evidence': [None], 'unique_evidence': [set()]}}

In [32]:
l = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
l = [item for sublist in l for item in sublist]

[3, 2, 1]

In [99]:
import torch
from train_eval_utils import get_target_changes, get_negative_examples
import numpy as np
from sklearn.metrics import f1_score

PAGES_RETRIEVED = 50
PAGES_FOR_EVIDENCE = 10

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
loss_fn1 = torch.nn.CosineEmbeddingLoss()
loss_fn2 = torch.nn.BCELoss()

# get embeddings of the claims

for input_batch in fever_loader:
    batch_size = len(input_batch['claims'])
    with torch.no_grad():
        outputs = emb_gen(input_batch['claims'])

    # search for similar pages
    similar_pages = vdb.search_similar(outputs, PAGES_RETRIEVED)
    #print(similar_pages)
    similar_texts = [[t.payload['text'] for t in s] for s in similar_pages]
    similar_ids = [[t.payload['id'] for t in s] for s in similar_pages]
    similar_embeds = [[t.vector for t in s[:PAGES_FOR_EVIDENCE]] for s in similar_pages]

    target_changes, precentage_retrieved = get_target_changes(input_batch, similar_ids, PAGES_FOR_EVIDENCE)
    nli_targets = [v == 'VERIFIABLE' for v in input_batch['verifiable']]

    # dinamically change the target
    #---------------------------------- only training
    nli_targets = [int(t and tc) for t, tc in zip(nli_targets, target_changes)]
    #---------------------------------- only training

    all_evidence = [r['all_evidence'] if r['all_evidence'] != [None] else [] for r in input_batch['evidence']]
    evidence_pages = [vdb.search_ids(all_evidence[i]) for i in range(batch_size)]
    evidence_texts = [[t.payload['text'] for t in s] for s in evidence_pages]
    # pick as negative examples the texts of the last len(evidence_texts) of the 50 retrieved pages
    negative_examples = get_negative_examples(similar_texts, similar_ids, all_evidence)

    # combine all the batches
    unfolded_outputs = []
    unfolded_combined_texts = []
    unfolded_labels = []
    for i in range(batch_size):
        unfolded_outputs.extend([outputs[i]] * len(evidence_texts[i] + negative_examples[i]))
        unfolded_combined_texts.extend(evidence_texts[i] + negative_examples[i])
        unfolded_labels.extend([1] * len(evidence_texts[i]) + [0] * len(negative_examples[i]))
    
    # encode the combined texts in batches
    combined_embeddings = []
    for i in range(0, len(unfolded_combined_texts), batch_size):
        with torch.no_grad():
            combined_embeddings.extend(emb_gen(unfolded_combined_texts[i:i+batch_size]))

    combined_embeddings = torch.tensor(combined_embeddings).to(device)
    unfolded_outputs = torch.tensor(unfolded_outputs).to(device)
    unfolded_labels = torch.tensor(unfolded_labels).to(device)

    loss1 = loss_fn1(combined_embeddings, unfolded_outputs, unfolded_labels)
    print(loss1)


    # input for the NLI model
    outputs = torch.tensor(outputs).unsqueeze(1).to(device)
    similar_embeds = torch.tensor(similar_embeds).to(device)

    # concat the output of the embedding generator
    nli_inputs = torch.cat([outputs, similar_embeds], dim=1)

    with torch.no_grad():
        nli_outputs = nli(nli_inputs.half())
    
    preds = torch.argmax(nli_outputs, dim=1)
    targets = torch.tensor(nli_targets).half().to(device)

    # nli_outputs is 32,2 keep only the 1
    nli_outputs = nli_outputs[:, 1]

    # Convert lists of tensors to tensors
    loss2 = loss_fn2(nli_outputs, targets)
    print(loss2)

    # calculate the f1 score, we apply a sigmoid to the output if the loss function is BCEWithLogitsLoss
    if loss_fn2.__class__.__name__ == 'BCEWithLogitsLoss':
        nli_outputs = torch.sigmoid(nli_outputs)
    nli_outputs = nli_outputs.detach().cpu().numpy()
    targets = targets.detach().cpu().numpy()
    f1 = f1_score(targets, (nli_outputs > 0.5).astype(int), average='macro')

    result = {'preds': preds, 'targets': targets, 'target_changes': target_changes, 'precentage_retrieved': precentage_retrieved}

    print(result)
    break

tensor(0.1944, device='cuda:0')
tensor(0.7065, device='cuda:0', dtype=torch.float16)
{'preds': tensor([0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0], device='cuda:0'), 'targets': array([0., 0., 1., 0., 1., 1., 1., 1., 0., 1., 0., 0., 1., 1., 1., 1., 1.,
       0., 0., 1., 1., 0., 1., 1., 1., 1., 0., 1., 0., 0., 1., 1.],
      dtype=float16), 'target_changes': [True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, False, True, True, True, True, True], 'precentage_retrieved': [0.0, 0.0, 1.0, 0.0, 0.5, 1.0, 1.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0]}
