In [None]:
import orjson

# Load the corpus: "tdt.jsonl" holds one JSON object per line and only the
# "text" field of each record is kept, in file order.
sentences = []
with open("tdt.jsonl", "rb") as fin:
    for line in fin:
        # A blank line (for example a stray empty line at the end of the
        # file) is not valid JSON and would make orjson.loads raise, so it
        # is skipped instead of aborting the whole load.
        if not line.strip():
            continue
        record = orjson.loads(line)
        sentences.append(record["text"])

In [None]:
import faiss
from tqdm import tqdm
from sentence_transformers import SentenceTransformer
from torch.utils.data import DataLoader, Dataset


class TextDataset(Dataset):
    """Map-style dataset that exposes a sequence of sentences by position."""

    def __init__(self, sentences):
        # Keep a reference to the caller's sequence; nothing is copied.
        self.sentences = sentences

    def __getitem__(self, idx):
        """Return the sentence stored at position ``idx``."""
        item = self.sentences[idx]
        return item

    def __len__(self):
        """Return how many sentences the dataset holds."""
        count = len(self.sentences)
        return count


# Serve the loaded sentences in fixed-size batches. shuffle=False keeps the
# batches in the same order as `sentences`, so vectors are added to the index
# in corpus order.
dataset = TextDataset(sentences)
dataloader = DataLoader(dataset, batch_size=1536, shuffle=False)

# You can use any other model
model = SentenceTransformer("all-MiniLM-L6-v2", device="cuda")
model.half()  # run the encoder in half precision

# Take the embedding width from the model instead of hard-coding 384, so that
# swapping the model name above cannot leave the index with the wrong dimension.
d = model.get_sentence_embedding_dimension()
assert d is not None, "model does not report a sentence embedding dimension"
# Inner-product index; the encoding step requests normalized embeddings.
index = faiss.IndexFlatIP(d)


# Encode the corpus batch by batch and append each batch of vectors to the
# index, advancing the progress bar by the number of sentences processed.
total = len(sentences)
with tqdm(total=total, desc="Processing Progress", unit="items") as pbar:
    for batch in dataloader:
        batch_embeddings = model.encode(
            batch,
            batch_size=len(batch),  # encode the whole DataLoader batch in one pass
            show_progress_bar=False,
            convert_to_numpy=True,
            normalize_embeddings=True,
        )
        # The model was switched to half precision, so the returned array may
        # be float16; FAISS indexes store float32 vectors, so cast explicitly
        # (copy=False makes this a no-op when the array is already float32).
        index.add(batch_embeddings.astype("float32", copy=False))
        pbar.update(len(batch))

In [None]:
# Persist the populated index to disk.
# NOTE(review): the file name says "mpnet", but the index above is built with
# "all-MiniLM-L6-v2" — confirm the intended name before anything depends on it.
faiss.write_index(index, "news_mpnet.index")