In [1]:
from wiki_dataset import WikiDataset
from vector_database import VectorDatabase
from torch.utils.data import DataLoader
from fever_dataset import FeverDataset, FeverCollator
from embedding_generation import EmbeddingGenerator

In [2]:
# Wiki pages for the dev split (reduced variant, held in memory, no extra pages).
wiki_dataset = WikiDataset(
    reduced=True,
    type='dev',
    in_mem=True,
    num_extra_pages=0,
)

# Sequential (unshuffled) loader over the wiki pages, read in large batches.
wiki_dataloader = DataLoader(
    wiki_dataset,
    batch_size=8000,
    num_workers=10,
    shuffle=False,
)

In [3]:
# Vector database built from the wiki loader; valid_step queries it through
# vdb.search_similar.
vdb = VectorDatabase(
    client='docker',
    wiki_loader=wiki_dataloader,
)

Creating collection
Block 1/1 done
Time to create collection: 2.5375070571899414


In [5]:
# FEVER dev claims, batched with the project's collator.
fever_collator = FeverCollator()
fever_dataset = FeverDataset(type='dev')
fever_loader = DataLoader(
    fever_dataset,
    batch_size=32,
    num_workers=10,
    shuffle=True,  # batches are drawn in a different order on every pass
    collate_fn=fever_collator,
)

# Model that is called on the claims to produce their embeddings.
model = EmbeddingGenerator()

In [9]:
import torch
from torch.cuda.amp import autocast
import numpy as np

# Select the GPU when PyTorch can see one; otherwise stay on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# performs a single validation step
def valid_step(input_batch, model, top_k=10):
    """Embed a batch of claims, retrieve similar pages and report evidence hits.

    For every claim the ``top_k`` most similar pages are fetched from the
    module-level ``vdb``. A claim counts as a hit when at least one of its
    evidence sets is fully contained in the ids of the retrieved pages.
    The per-claim hit flags and their total are printed.

    Args:
        input_batch: Mapping with
            ``'claims'``   - passed as-is to ``model``;
            ``'evidence'`` - iterable of records, one per claim, each with a
            ``'unique_evidence'`` iterable of sets of page ids.
        model: Callable that maps the claims to their embeddings.
        top_k: Number of neighbours requested from ``vdb.search_similar``
            for each claim (defaults to 10).

    Returns:
        The embeddings produced by ``model`` for the claims.
    """
    # get embeddings of the claims (no gradients, autocast enabled)
    with torch.no_grad():
        with autocast():
            outputs = model(input_batch['claims'])

    # search for similar pages and keep the retrieved page ids per claim
    similar_pages = vdb.search_similar(outputs, top_k)
    retrieved_ids = [{hit.payload['id'] for hit in hits} for hits in similar_pages]

    # check if the evidence is enough: any() yields False for a claim without
    # evidence sets, so no verdict leaks over from the previous claim
    targets = [
        any(evidence_set.issubset(retrieved_ids[i])
            for evidence_set in record['unique_evidence'])
        for i, record in enumerate(input_batch['evidence'])
    ]

    print (targets)
    print(np.sum(targets))
    return outputs

In [10]:
# Smoke test: push the first batch from fever_loader through valid_step.
for i, batch in enumerate(fever_loader):
    valid_step(batch, model)
    # only one batch is evaluated; the loop exits after the first iteration
    break

[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False]
0
