In [1]:
from wiki_dataset import WikiDataset
from vector_database import VectorDatabase
from torch.utils.data import DataLoader
from fever_dataset import FeverDataset, FeverCollator
from embedding_generation import EmbeddingGenerator

In [2]:
wiki_dataset = WikiDataset(reduced=True, type='train', in_mem=True, num_extra_pages=0, seed=0)
wiki_dataloader = DataLoader(wiki_dataset, batch_size=8192, num_workers=10, shuffle=False)

In [3]:
vdb = VectorDatabase(client='docker', wiki_loader=wiki_dataloader)

Creating collection
Block 1/2 done
Block 2/2 done
Time to create collection: 30.077796459197998


In [4]:
fever_collator = FeverCollator()
fever_dataset = FeverDataset(type='dev')
fever_loader = DataLoader(fever_dataset, batch_size=32, num_workers=10, shuffle=False, collate_fn=fever_collator)

model = EmbeddingGenerator()

In [7]:
fever_dataset[1]

{'id': 194462,
 'verifiable': 'NOT VERIFIABLE',
 'label': 'NOT ENOUGH INFO',
 'claim': 'Tilda Swinton is a vegan.',
 'evidence': {'all_evidence': [None], 'unique_evidence': [set()]}}

In [36]:
import torch
from torch.cuda.amp import autocast
import numpy as np

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# performs a single validation step
def valid_step(input_batch, model):
    # get embeddings of the claims
    with torch.no_grad():
        outputs = model(input_batch['claims'])

    # search for similar pages
    similar_pages = vdb.search_similar(outputs, 10)
    #print(similar_pages)
    similar_texts = [[t.payload['text'] for t in s] for s in similar_pages]
    similar_ids = [[t.payload['id'] for t in s] for s in similar_pages]
    similar_embeds = [[t.vector for t in s] for s in similar_pages]


    # check if the evidence is enough
    #---------------------------------- only training
    target_changes = []
    for i, r in enumerate(input_batch['evidence']):
        enough_evidence = False
        for evidence_set in r['unique_evidence']:
            if evidence_set is None or evidence_set.issubset(set(similar_ids[i])):
                enough_evidence = True
                break
        target_changes.append(enough_evidence)

    print(target_changes)
    print(np.sum(target_changes))
    #---------------------------------- only training

    targets = [v == 'VERIFIABLE' for v in input_batch['verifiable']]
    # dinamically change the target
    targets = [t and tc for t, tc in zip(targets, target_changes)]

    # input for the NLI model
    inputs = torch.tensor(similar_embeds).to(device)
    # concat the output of the embedding generator
    outputs = torch.cat([inputs, outputs], dim=1)

    
    print(similar_embeds.shape)

    # for s, r in zip(similar_texts, input_batch['claims']):
    #     print(r)
    #     print(s)
    return outputs#, targets

In [37]:
for i, batch in enumerate(fever_loader):
    valid_step(batch, model)
    break

[True, True, False, True, False, True, False, True, True, True, True, True, False, True, False, True, False, True, True, False, False, True, True, False, False, False, False, False, True, True, True, False]
18


TypeError: expected Tensor as element 1 in argument 0, but got numpy.ndarray