In [1]:
import json
import torch
from transformers import AutoTokenizer, AutoModel

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Hugging Face Hub identifier of the checkpoint used for both tokenizer and encoder.
model_name = "facebook/contriever"

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

tokenizer = AutoTokenizer.from_pretrained(model_name)
# nn.Module.to() returns the module itself, so loading and moving can be chained.
model = AutoModel.from_pretrained(model_name).to(device)

In [3]:
def embed_texts(texts):
    """Embed a batch of texts with the module-level tokenizer and model.

    Each text's embedding is the mean of its token hidden states, with padding
    positions excluded via the attention mask. This is the pooling the
    ``facebook/contriever`` checkpoint is documented to be used with; taking
    only the first ([CLS]) token's hidden state does not yield the intended
    sentence embedding for this model.

    Args:
        texts: A list of strings, passed to the tokenizer as one batch.
            Inputs are padded to the longest item and truncated to the
            tokenizer's maximum length.

    Returns:
        A CPU tensor with one embedding row per input text.
    """
    inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)
        # Broadcastable boolean mask: True for real tokens, False for padding.
        mask = inputs["attention_mask"].unsqueeze(-1).bool()
        # Zero out padded positions so they do not contribute to the sum.
        token_embeddings = outputs.last_hidden_state.masked_fill(~mask, 0.0)
        # Divide by the number of real tokens; clamp guards against a zero count.
        embeddings = token_embeddings.sum(dim=1) / mask.sum(dim=1).clamp(min=1)
    return embeddings.cpu()

In [4]:
# Maximum number of documents to read from the corpus file.
N = 10000
corpus_path = "./nq/corpus.jsonl"

# Parallel lists: corpus_ids[k] is the id of the document whose text is corpus_texts[k].
corpus_texts = []
corpus_ids = []

# Decode explicitly as UTF-8 so the result does not depend on the platform's
# default locale encoding.
with open(corpus_path, 'r', encoding="utf-8") as f:
    for line in f:
        # Stop once N documents have been collected.
        if len(corpus_ids) >= N:
            break
        # Tolerate blank lines (e.g. a trailing newline at end of file), which
        # would otherwise make json.loads raise.
        if not line.strip():
            continue
        doc = json.loads(line)
        # NOTE(review): only the "text" field is embedded; if records also carry
        # a title, confirm whether it should be included.
        corpus_ids.append(doc["_id"])
        corpus_texts.append(doc["text"])

In [5]:
batch_size = 64
all_embeddings = []

# Total number of batches (ceiling division), computed once for the progress message.
total_batches = (len(corpus_texts) - 1) // batch_size + 1

for batch_number, start in enumerate(range(0, len(corpus_texts), batch_size), start=1):
    batch = corpus_texts[start:start + batch_size]
    all_embeddings.append(embed_texts(batch))
    print(f"Embedded batch {batch_number} / {total_batches}")

# Concatenate the per-batch outputs along the first dimension into one tensor.
corpus_embeddings = torch.cat(all_embeddings, dim=0)

Embedded batch 1 / 157
Embedded batch 2 / 157
Embedded batch 3 / 157
Embedded batch 4 / 157
Embedded batch 5 / 157
Embedded batch 6 / 157
Embedded batch 7 / 157
Embedded batch 8 / 157
Embedded batch 9 / 157
Embedded batch 10 / 157
Embedded batch 11 / 157
Embedded batch 12 / 157
Embedded batch 13 / 157
Embedded batch 14 / 157
Embedded batch 15 / 157
Embedded batch 16 / 157
Embedded batch 17 / 157
Embedded batch 18 / 157
Embedded batch 19 / 157
Embedded batch 20 / 157
Embedded batch 21 / 157
Embedded batch 22 / 157
Embedded batch 23 / 157
Embedded batch 24 / 157
Embedded batch 25 / 157
Embedded batch 26 / 157
Embedded batch 27 / 157
Embedded batch 28 / 157
Embedded batch 29 / 157
Embedded batch 30 / 157
Embedded batch 31 / 157
Embedded batch 32 / 157
Embedded batch 33 / 157
Embedded batch 34 / 157
Embedded batch 35 / 157
Embedded batch 36 / 157
Embedded batch 37 / 157
Embedded batch 38 / 157
Embedded batch 39 / 157
Embedded batch 40 / 157
Embedded batch 41 / 157
Embedded batch 42 / 157
E

In [6]:
# Both output filenames carry the number of embedded documents.
num_embedded = len(corpus_embeddings)
embedding_file = f"corpus_embeddings_{num_embedded}.pt"
id_file = f"corpus_ids_{num_embedded}.json"

# Persist the embedding tensor.
torch.save(corpus_embeddings, embedding_file)

# Persist the document ids as a JSON list.
with open(id_file, 'w') as id_handle:
    json.dump(corpus_ids, id_handle)