# Load embeddings

In [3]:
import tables
import numpy as np
import json
import pickle
import os


DATA_PATH = "/data/zhangzhuocheng/Lab/Python/LLM/datasets/RAG/wikipedia/wiki_2021"

# Directory holding the pickled embedding shards ("passages*") and the flat
# cache file built from them.
EMBED_DIR = os.path.join(DATA_PATH, "contriever_embeddings")
# NOTE: despite the ".npy" suffix this file is a raw np.memmap dump without a
# NumPy header, so its shape and dtype have to be supplied when opening it.
EMBED_CACHE = os.path.join(EMBED_DIR, "embeds.npy")
EMBED_SHAPE = (33176581, 768)
EMBED_DTYPE = np.float16


# open embeddings
print("Loading embeddings")
if not os.path.exists(EMBED_CACHE):
    shard_names = sorted(
        name for name in os.listdir(EMBED_DIR) if name.startswith("passages")
    )
    print(shard_names)

    # Build under a temporary name and rename once complete: a "w+" memmap is
    # created at full size immediately, so an interrupted build would otherwise
    # leave a plausible-looking but partially filled cache behind.
    tmp_cache = EMBED_CACHE + ".tmp"
    cache = np.memmap(tmp_cache, mode="w+", dtype=EMBED_DTYPE, shape=EMBED_SHAPE)

    # Copy shard by shard so only one shard is held in RAM at a time.
    n_written = 0
    for name in shard_names:
        # NOTE(security): pickle.load can execute arbitrary code; only use
        # shard files from a trusted source.
        with open(os.path.join(EMBED_DIR, name), "rb") as shard_file:
            # Element 1 of each pickle is the embedding matrix of that shard.
            shard = np.atleast_2d(np.asarray(pickle.load(shard_file)[1]))
        end = n_written + shard.shape[0]
        if end > EMBED_SHAPE[0] or shard.shape[1] != EMBED_SHAPE[1]:
            raise ValueError(
                f"shard {name!r} with shape {shard.shape} does not fit into "
                f"{EMBED_SHAPE} at row offset {n_written}"
            )
        cache[n_written:end] = shard  # cast to EMBED_DTYPE on assignment
        n_written = end

    if n_written != EMBED_SHAPE[0]:
        raise ValueError(
            f"expected {EMBED_SHAPE[0]} embeddings, shards contain {n_written}"
        )

    cache.flush()
    del cache  # release the writable mapping before publishing the file
    os.replace(tmp_cache, EMBED_CACHE)

# Always hand out the read-only memmap so a cold (build) run and a warm
# (cached) run give later cells the same object type and dtype.
embeds = np.memmap(EMBED_CACHE, mode="r", dtype=EMBED_DTYPE, shape=EMBED_SHAPE)
print(embeds.shape)

Loading embeddings
['passages_00', 'passages_01', 'passages_02', 'passages_03', 'passages_04', 'passages_05', 'passages_06', 'passages_07']
(33176581, 768)


# Load passages

In [4]:
# The passage file holds one JSON document per line. Open it in a `with` block
# so the handle is closed, and decode explicitly as UTF-8 instead of relying on
# the locale's default encoding.
passages_path = os.path.join(DATA_PATH, "text-list-100-sec.jsonl")
with open(passages_path, "r", encoding="utf-8") as passages_file:
    passages = [json.loads(line) for line in passages_file]
print(len(passages))

33176581


# Build HDF5 passage database

In [5]:
from tqdm import tqdm

class RetrieveTable(tables.IsDescription):
    """Row schema of the passage table: three text fields plus one embedding.

    The ``pos`` arguments fix the column order as title, section, text,
    embedding.
    """

    # Fixed-width byte-string columns of 1024 bytes each.
    title = tables.StringCol(itemsize=1024, pos=1)
    section = tables.StringCol(itemsize=1024, pos=2)
    text = tables.StringCol(itemsize=1024, pos=3)
    # One float16 vector of length 768 per row.
    embedding = tables.Float16Col(shape=768, pos=4)

# Write every (embedding, passage) pair into one PyTables table at
# /passages/data of database.h5. Any existing file is overwritten (mode="w").

# zip() would silently stop at the shorter input and produce a truncated
# database, so insist on a one-to-one pairing up front.
if len(embeds) != len(passages):
    raise ValueError(
        f"{len(embeds)} embeddings but {len(passages)} passages; "
        "the two inputs must line up one-to-one"
    )

db_path = os.path.join(DATA_PATH, "contriever_embeddings_4", "database.h5")
os.makedirs(os.path.dirname(db_path), exist_ok=True)

text_fields = ("title", "section", "text")
n_trimmed = 0

# The context manager closes the file even if the loop raises part-way.
with tables.open_file(db_path, mode="w", title="Retriever Database") as h5file:
    group = h5file.create_group("/", "passages", "Passages", filters=None)
    table = h5file.create_table(group, "data", RetrieveTable, "data")

    # Byte capacity of each fixed-width string column, read from the table so
    # it stays in sync with RetrieveTable.
    field_widths = {name: table.coldtypes[name].itemsize for name in text_fields}

    row = table.row  # fetch the row accessor once instead of per iteration
    for emb, p in zip(tqdm(embeds), passages):
        # A bare string passage is treated as text with no title/section.
        p = {"text": p} if isinstance(p, str) else p
        for name in text_fields:
            raw = p.get(name, "").encode()
            if len(raw) > field_widths[name]:
                # The column cannot hold more than its width. Trim on a
                # character boundary so the stored bytes remain valid UTF-8
                # rather than ending in a split multi-byte sequence.
                raw = (
                    raw[: field_widths[name]]
                    .decode("utf-8", errors="ignore")
                    .encode()
                )
                n_trimmed += 1
            row[name] = raw
        row["embedding"] = emb
        row.append()
    table.flush()

if n_trimmed:
    print(f"{n_trimmed} field values exceeded their column width and were trimmed")

100%|██████████| 33176581/33176581 [09:35<00:00, 57685.57it/s]


# Build Index

In [None]:
from index import FaissIndex


# Constructor settings for the project's FaissIndex wrapper, grouped in one
# place so they are easy to compare with the retriever settings used later.
# NOTE(review): the meaning of each option is defined in index.py, not here.
index_settings = dict(
    index_path=os.path.join(DATA_PATH, "contriever_embeddings_4/index.faiss"),
    vector_sz=768,
    n_subquantizers=64,
    n_bits=8,
    n_list=4096,
    n_probe=36,
    device_id=0,
    train_num=1000000,
    log_interval=100,
)
index = FaissIndex(**index_settings)

# Step 1: train the index on the embedding matrix.
print("Training index")
index.train_index(embeds)

# Step 2: add all embeddings in batches of 10000.
print("Adding index")
index.add_embeddings(embeds, batch_size=10000)

# Step 3: serialize the index (presumably to index_path - confirm in index.py).
index.serialize()

# Test indices

In [None]:
from argparse import Namespace
from retriever import DenseRetriever


# Settings handed to DenseRetriever. Built directly as a Namespace so `args`
# never changes type within the cell.
# NOTE(review): n_probe is 32 here but the index-build cell uses 36; confirm
# which value is intended.
args = Namespace(
    retriever_tokenizer="/data/zhangzhuocheng/Lab/Python/LLM/models/RAG/contriever-msmarco",
    batch_size=512,
    max_query_length=512,
    max_passage_length=512,
    no_title=False,
    lowercase=False,
    normalize_text=False,
    log_interval=100,
    query_encoder="",
    read_only=True,
    compress_database=False,
    database_path="/data/zhangzhuocheng/Lab/Python/LLM/datasets/RAG/wikipedia/wiki_2021/contriever_embeddings_4",
    n_subquantizers=64,
    n_list=4096,
    n_bits=8,
    n_probe=32,
    device_id=0,
    train_num=1000000,
)


# Interactive smoke test: type a query to see what is retrieved for it, or
# "quit" (or Ctrl-D / Ctrl-C) to stop.
retriever = DenseRetriever(args)
try:
    while True:
        try:
            query = input("Query: ")
        except (EOFError, KeyboardInterrupt):
            # End of input or an interrupt ends the session like "quit".
            break
        if query.strip() == "quit":
            break
        if not query.strip():
            continue  # nothing to search for
        # search() takes a batch of queries; take the result for our single one.
        r = retriever.search([query])[0]
        print(r)
    # retriever.test_acc(top_k=5)
finally:
    # Release the retriever's resources even if a search raised.
    retriever.close()