In [1]:
from wiki_dataset import WikiDataset
from vector_database import VectorDatabase
from torch.utils.data import DataLoader
from fever_dataset import FeverDataset, FeverCollator
from embedding_generation import EmbeddingGenerator

In [2]:
# Wiki dataset configured through keyword flags (presumably the reduced train
# split kept in memory — confirm against wiki_dataset.py).
wiki_dataset = WikiDataset(
    reduced=True,
    type='train',
    in_mem=True,
    num_extra_pages=0,
)

# Unshuffled loader over the wiki dataset: batches of up to 8000 items,
# fetched by 10 worker processes.
wiki_dataloader = DataLoader(
    wiki_dataset,
    batch_size=8000,
    num_workers=10,
    shuffle=False,
)

In [3]:
# Build the vector database from the wiki loader. Per the recorded output the
# collection is created during construction (about 17.5 s in that run).
vdb = VectorDatabase(
    client='docker',
    wiki_loader=wiki_dataloader,
)

Creating collection
Block 1/2 done
Block 2/2 done
Time to create collection: 17.517935037612915


In [4]:
# Load model directly
from transformers import AutoTokenizer, AutoModelForMaskedLM

# Tokenizer of the sentence-transformers MPNet checkpoint; it is handed to the
# collator below.
tokenizer = AutoTokenizer.from_pretrained(
    "sentence-transformers/all-mpnet-base-v2"
)

# FEVER examples (presumably the train split — confirm in fever_dataset.py),
# served in shuffled batches of 32 that are assembled by the collator.
fever_collator = FeverCollator(tokenizer)
fever_dataset = FeverDataset(type='train')
fever_loader = DataLoader(
    fever_dataset,
    batch_size=32,
    num_workers=10,
    shuffle=True,
    collate_fn=fever_collator,
)

# Embedding model, constructed with its default arguments.
model = EmbeddingGenerator()

In [5]:
# Inspect the first example; the recorded output shows the record fields
# (id, verifiable, label, claim, evidence).
fever_dataset[0]

{'id': 75397,
 'verifiable': 'VERIFIABLE',
 'label': 'SUPPORTS',
 'claim': 'Nikolaj Coster-Waldau worked with the Fox Broadcasting Company.',
 'evidence': [[[92206, 104971, 'Nikolaj_Coster-Waldau', 7],
   [92206, 104971, 'Fox_Broadcasting_Company', 0]]]}

In [7]:
# NOTE(review): in the recorded run this call fails with a pydantic
# ValidationError for SearchRequest: `vector` is None and `filter` is the tuple
# ('should', None). The request is presumably assembled inside
# VectorDatabase.search_ids — fix it there (or drop this cell) so the notebook
# survives Restart & Run All.
vdb.search_ids(['Nikolaj_Coster-Waldau', 'Fox_Broadcasting_Company'])

ValidationError: 4 validation errors for SearchRequest
vector.list[float]
  Input should be a valid list [type=list_type, input_value=None, input_type=NoneType]
    For further information visit https://errors.pydantic.dev/2.5/v/list_type
vector.NamedVector
  Input should be a valid dictionary or instance of NamedVector [type=model_type, input_value=None, input_type=NoneType]
    For further information visit https://errors.pydantic.dev/2.5/v/model_type
vector.NamedSparseVector
  Input should be a valid dictionary or instance of NamedSparseVector [type=model_type, input_value=None, input_type=NoneType]
    For further information visit https://errors.pydantic.dev/2.5/v/model_type
filter
  Input should be a valid dictionary or instance of Filter [type=model_type, input_value=('should', None), input_type=tuple]
    For further information visit https://errors.pydantic.dev/2.5/v/model_type

In [11]:
import torch
from torch.cuda.amp import autocast

# Prefer the GPU when one is visible to PyTorch; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

def valid_step(input_batch, model, vector_db=None, top_k=1):
    """Embed one batch of claims and print the nearest stored entries.

    Args:
        input_batch: Mapping with a ``'claims'`` entry. That entry is passed to
            ``model`` and is also iterated to print each claim.
        model: Callable that turns ``input_batch['claims']`` into embeddings.
        vector_db: Object exposing ``search_similar(embeddings, k)``. Defaults
            to the notebook-level ``vdb`` so existing two-argument calls keep
            working; pass it explicitly to avoid relying on hidden state.
        top_k: Number of neighbours requested per claim. Defaults to 1, the
            value that was previously hard-coded.

    Returns:
        The embeddings produced by ``model``.
    """
    if vector_db is None:
        vector_db = vdb  # fall back to the notebook-level database

    # Inference only: no gradient tracking, and run under CUDA autocast.
    with torch.no_grad(), autocast():
        outputs = model(input_batch['claims'])

    similar_texts = vector_db.search_similar(outputs, top_k)

    # Print each claim followed by whatever the database returned for it.
    for claim, matches in zip(input_batch['claims'], similar_texts):
        print(claim)
        print(matches)
    return outputs

In [12]:
# Smoke test: push a single batch through the validation step.
# next(iter(...)) fetches exactly one batch, replacing an enumerate/break loop
# whose index was never used; an empty loader is still skipped silently.
batch = next(iter(fever_loader), None)
if batch is not None:
    valid_step(batch, model)

Bill Cosby has been deemed innocent of drug facilitated sexual assault.
[ScoredPoint(id=50851, version=138, score=0.24123931, payload={'id': 'HIV/AIDS_in_Senegal', 'text': 'Senegal has a low prevalence of HIV , at under 1 % of the adult population . '}, vector=None, shard_key=None)]
Wentworth is a series for television.
[ScoredPoint(id=50481, version=132, score=0.37508267, payload={'id': 'Getsuku', 'text': 'is a Japanese abbreviation for . This is traditionally the time when the most popular TV dramas air in Japan . '}, vector=None, shard_key=None)]
Floyd Mayweather Jr. is an American promoter.
[ScoredPoint(id=102294, version=285, score=0.45633596, payload={'id': 'Serafim_Todorov', 'text': 'Serafim Simeonov Todorov -LRB- Серафим Симеонов Тодоров born 6 July 1969 -RRB- is a Bulgarian former amateur boxer . He won three consecutive gold medals at both the World and European Championships , and silver at the 1996 Olympics . He is the last boxer to defeat Floyd Mayweather Jr. , who later w