In [1]:
%load_ext autoreload
%autoreload 2
from wiki_dataset import WikiDataset
from vector_database import VectorDatabase
from torch.utils.data import DataLoader
from fever_dataset import FeverDataset, FeverCollator
from embedding_generation import EmbeddingGenerator, NLI

In [2]:
wiki_dataset = WikiDataset(reduced=True, type='dev', in_mem=True, num_extra_pages=0, seed=0)
wiki_dataloader = DataLoader(wiki_dataset, batch_size=8192, num_workers=8, shuffle=False)

In [3]:
vdb = VectorDatabase(client='docker', wiki_loader=wiki_dataloader)

Creating collection
Block 1/1 done
Time to create collection: 9.688613176345825


In [19]:
fever_collator = FeverCollator()
fever_dataset = FeverDataset(type='dev')
fever_loader = DataLoader(fever_dataset, batch_size=32, num_workers=10, shuffle=False, collate_fn=fever_collator)

emb_gen = EmbeddingGenerator()
nli = NLI()

In [20]:
fever_dataset[1]

{'id': 194462,
 'verifiable': 'NOT VERIFIABLE',
 'label': 'NOT ENOUGH INFO',
 'claim': 'Tilda Swinton is a vegan.',
 'evidence': {'all_evidence': [None], 'unique_evidence': [set()]}}

In [32]:
l = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
l = [item for sublist in l for item in sublist]

[3, 2, 1]

In [59]:
import torch
from train_eval_utils import get_target_changes, get_negative_examples

PAGES_RETRIEVED = 50
PAGES_FOR_EVIDENCE = 10

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# get embeddings of the claims

for input_batch in fever_loader:
    batch_size = len(input_batch['claims'])
    with torch.no_grad():
        outputs = emb_gen(input_batch['claims'])

    # search for similar pages
    similar_pages = vdb.search_similar(outputs, PAGES_RETRIEVED)
    #print(similar_pages)
    similar_texts = [[t.payload['text'] for t in s] for s in similar_pages]
    similar_ids = [[t.payload['id'] for t in s] for s in similar_pages]
    similar_embeds = [[t.vector for t in s] for s in similar_pages]

    target_changes, precentage_retrieved = get_target_changes(input_batch, similar_ids, PAGES_FOR_EVIDENCE)
    targets = [v == 'VERIFIABLE' for v in input_batch['verifiable']]

    # dinamically change the target
    #---------------------------------- only training
    targets = [t and tc for t, tc in zip(targets, target_changes)]
    #---------------------------------- only training

    all_evidence = [r['all_evidence'] if r['all_evidence'] != [None] else [] for r in input_batch['evidence']]
    evidence_pages = [vdb.search_ids(all_evidence[i]) for i in range(batch_size)]
    evidence_texts = [[t.payload['text'] for t in s] for s in evidence_pages]
    # pick as negative examples the texts of the last len(evidence_texts) of the 50 retrieved pages
    negative_examples = get_negative_examples(similar_texts, similar_ids, all_evidence)
    # combine the positive and negative examples
    combined_texts = [s + n for s, n in zip(evidence_texts, negative_examples)]
    
    # encode the combined texts in batches
    for i in range(0, len(combined_texts), batch_size):
        with torch.no_grad():
            combined_embeddings = emb_gen(combined_texts[i:i+batch_size])
        if i == 0:
            combined_embeds = combined_embeddings
        else:
            combined_embeds = torch.cat([combined_embeds, combined_embeddings], dim=0)
    
    break

IndexError: list index out of range