In [45]:
import torch.nn.functional as F
from torch import Tensor
from transformers import AutoTokenizer, AutoModel
import yaml
import os
import json
import pandas as pd

from src.exp_logger import logger

import pyterrier as pt  # type: ignore

from src.load_index import setup_system, tag
import torch 
import faiss
from tqdm import tqdm


# Load experiment settings. A settings file only needs plain YAML types (maps, lists,
# scalars), so safe_load is sufficient and refuses to construct arbitrary Python objects.
with open("../settings.yml", "r", encoding="utf-8") as yamlfile:
    config = yaml.safe_load(yamlfile)

In [None]:
# Collection directory (JSON copy of the TREC docs) and where the e5 index parts live.
doc_path = f"../{config['WT']['docs'].replace('Trec', 'Json')}"
index_dir = "../data/index/e5"

In [50]:
def average_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
    """Mean-pool token states over the sequence axis, ignoring padded positions.

    Positions where ``attention_mask`` is 0 are zeroed before summing over dim 1, and the
    sum is divided by the number of unmasked tokens per row.
    """
    keep = attention_mask.unsqueeze(-1).bool()
    zeroed = last_hidden_states.masked_fill(~keep, 0.0)
    token_counts = attention_mask.sum(dim=1).unsqueeze(-1)
    return zeroed.sum(dim=1) / token_counts

# Pretrained E5 (small) encoder plus the tokenizer that belongs to the same checkpoint.
model = AutoModel.from_pretrained('intfloat/e5-small')
tokenizer = AutoTokenizer.from_pretrained('intfloat/e5-small')

In [114]:
# Put the encoder on the GPU when CUDA is available; otherwise it stays on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
_ = model.to(device)  # bound to `_` so the notebook does not echo the model repr

In [91]:
@torch.no_grad()
def calc_embeddings(texts, mode='passage'):
    """Encode texts with the E5 model and mean-pool them into one vector per text.

    Each text is prefixed with ``f"{mode}: "`` (E5 distinguishes "query" and "passage"
    inputs), tokenized with truncation to 512 tokens, run through ``model`` and pooled
    with ``average_pool``.

    Args:
        texts: iterable of strings (a list or a pandas Series both work).
        mode: task prefix; use 'query' for queries and 'passage' for documents.

    Returns:
        CPU tensor with one (unnormalized) embedding row per input text.
    """
    input_texts = [f"{mode}: {text}" for text in texts]
    batch_dict = tokenizer(input_texts, max_length=512, padding=True, truncation=True, return_tensors='pt')
    # Send inputs to wherever the model currently lives instead of hard-coding `.cuda()`,
    # so the function also works when the device cell fell back to the CPU.
    model_device = next(model.parameters()).device
    batch_dict = {key: val.to(model_device, non_blocking=True) for key, val in batch_dict.items()}

    outputs = model(**batch_dict)
    embeddings = average_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
    return embeddings.detach().cpu()


In [49]:
def gen_docs(batch_size=2, doc_path=None):
    """Generate batches of document texts from the JSON version of the WT collection.

    Every file in ``doc_path`` is read line by line; each non-blank line is parsed as a
    JSON list of documents with ``"id"`` and ``"contents"`` keys. Files are visited in
    sorted-name order so repeated runs produce the same document order.

    Side effects (module globals used later in the notebook):
        ids: dict mapping each document's 0-based position in the yielded stream to its
            document id. This matches the FAISS row id when embeddings are indexed in
            the same order they were generated.
        c: total number of documents read.

    Args:
        batch_size: maximum number of texts per yielded list (must be >= 1).
        doc_path: directory containing the JSON files; when omitted, the notebook-level
            ``doc_path`` variable is used.

    Yields:
        Lists of ``doc["contents"]`` strings; all have ``batch_size`` items except
        possibly the last one.

    Raises:
        ValueError: on a non-positive ``batch_size`` or when no ``doc_path`` is available.
    """
    global c, ids
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")
    if doc_path is None:
        # The parameter shadows the notebook variable of the same name, so look it up explicitly.
        doc_path = globals().get("doc_path")
        if doc_path is None:
            raise ValueError("doc_path must be given or defined at notebook level")

    c = 0
    ids = {}
    batch = []
    for filename in sorted(os.listdir(doc_path)):
        with open(os.path.join(doc_path, filename), "r", encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    continue  # json.loads("") would raise on blank lines
                for doc in json.loads(line):
                    # 0-based so ids[i] names the i-th embedded document (FAISS rows start at 0).
                    ids[c] = doc["id"]
                    c += 1
                    batch.append(doc["contents"])
                    if len(batch) == batch_size:
                        yield batch
                        batch = []
    # Emit the trailing partial batch so the last documents are not silently dropped.
    if batch:
        yield batch
                    

In [119]:
def encode(batch_size, num_docs, save_every, doc_path=None, out_dir="../data/index/e5"):
    """Create embeddings for the collection in batches and save them in numbered parts.

    Every ``save_every`` batches, the accumulated embeddings are concatenated and written
    with ``torch.save`` to ``<out_dir>/e5_embeddings_<part>.pt`` (part = 0, 1, 2, ...).
    Any leftover batches form one final, smaller part.

    Args:
        batch_size: documents per encoder call (forwarded to ``gen_docs``).
        num_docs: expected number of documents; only used for the progress-bar total.
        save_every: number of batches per saved part (must be >= 1).
        doc_path: directory of JSON document files, forwarded to ``gen_docs``.
        out_dir: directory the parts are written to (created if missing).
    """
    if save_every < 1:
        raise ValueError("save_every must be >= 1")
    os.makedirs(out_dir, exist_ok=True)

    part = 0
    n_saved = 0

    def save_embs(chunk):
        """Concatenate a list of embedding tensors and write it as the next part."""
        nonlocal part, n_saved
        stacked = torch.cat(chunk)
        # torch.Tensor has no .save method; torch.save is the serialization API.
        torch.save(stacked, os.path.join(out_dir, f"e5_embeddings_{part}.pt"))
        n_saved += stacked.shape[0]
        part += 1
        logger.info(f"Saved embeddings for {n_saved} documents")

    pending = []
    n_batches = (num_docs + batch_size - 1) // batch_size  # ceil: the last batch may be partial
    for batch in tqdm(gen_docs(batch_size=batch_size, doc_path=doc_path), total=n_batches):
        pending.append(calc_embeddings(batch))
        if len(pending) >= save_every:
            save_embs(pending)
            pending = []

    # Flush the remainder; skipping an empty list avoids torch.cat([]) raising.
    if pending:
        save_embs(pending)
    logger.info(f"Done with encoding")

In [117]:
# build / load the FAISS index
def create_index(index_dir, size=384):
    """Create a flat L2 FAISS index from the embedding parts and write it to disk.

    Parts named ``e5_embeddings_<n>.pt`` (as written by ``encode``) are added in ascending
    numeric order of ``<n>``. Row ``i`` of the index is therefore the ``i``-th embedding
    across all parts, in the order they were encoded.

    Args:
        index_dir: directory holding the parts; the index is written to
            ``<index_dir>/e5.index``.
        size: embedding dimensionality (384 for e5-small).

    Returns:
        The populated ``faiss.IndexFlatL2``.

    Raises:
        FileNotFoundError: if no embedding parts are found.
    """
    prefix, suffix = "e5_embeddings_", ".pt"
    parts = [
        name for name in os.listdir(index_dir)
        if name.startswith(prefix) and name.endswith(suffix)
        and name[len(prefix):-len(suffix)].isdigit()
    ]
    if not parts:
        raise FileNotFoundError(f"no {prefix}<n>{suffix} files found in {index_dir}")
    # os.listdir order is arbitrary (and a string sort puts "10" before "2"); sort by part
    # number so FAISS row ids stay aligned with document order.
    parts.sort(key=lambda name: int(name[len(prefix):-len(suffix)]))

    index = faiss.IndexFlatL2(size)  # exact search, needs no training
    for name in parts:
        # FAISS expects a C-contiguous float32 array.
        embeddings = torch.load(os.path.join(index_dir, name)).float().contiguous().numpy()
        index.add(embeddings)
    # FAISS indexes have no .save method; write_index is the counterpart of read_index.
    faiss.write_index(index, os.path.join(index_dir, "e5.index"))
    return index


def load_index(index_dir):
    """Read and return the FAISS index stored as ``e5.index`` inside ``index_dir``."""
    index_file = index_dir + "/e5.index"
    return faiss.read_index(index_file)

In [118]:
# write results
def write_trec(topics, I, D, ids, out_path="../results/trec/e5.WT", run_tag="IRC-e5"):
    """Write FAISS search results as a TREC run file.

    Each output line is ``qid Q0 docno rank score run_tag`` with ranks starting at 1.

    Args:
        topics: DataFrame with a ``qid`` column, in the same row order as the queries
            that were searched.
        I: per-query FAISS row indices (from ``index.search``); entries of -1 (FAISS
            padding when fewer than k hits exist) are skipped.
        D: per-query L2 distances aligned with ``I``.
        ids: mapping from FAISS row index to document id.
        out_path: destination file; its parent directory is created if missing.
        run_tag: run name written in the last column.
    """
    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(out_path, "w", encoding="utf-8") as f:
        for qid, row_ids, row_dists in zip(topics["qid"].to_list(), I, D):
            rank = 0
            for row_id, distance in zip(row_ids, row_dists):
                if row_id < 0:
                    continue
                rank += 1
                # Smaller distance = better match; 10 - distance turns that into a
                # descending score, which is what trec_eval ranks by.
                f.write("{} Q0 {} {} {} {}\n".format(qid, ids[int(row_id)], rank, 10 - distance, run_tag))

In [56]:
# Topics: encode every query and retrieve the top-k documents from the FAISS index.
N_QUERIES = None  # set to a small int (e.g. 2) for a quick smoke test; None searches all topics
TOP_K = 1000      # documents retrieved per query

topics = pt.io.read_topics("../" + config["WT"]["train"]["topics"])  # load topics
query_embedding = calc_embeddings(topics["query"], mode='query')     # encode query embeddings

index = load_index(index_dir)                                        # load index
# FAISS takes a numpy array; D holds L2 distances and I the matching FAISS row ids.
D, I = index.search(query_embedding[:N_QUERIES].numpy(), k=TOP_K)
