In [1]:
import numpy as np
import torch
from transformers import PreTrainedModel, PreTrainedTokenizer, BatchEncoding
from transformers import AutoTokenizer, AutoModel
import sys
import os
import json
import torch.nn.functional as F
from dotenv import load_dotenv

load_dotenv()

sys.path.append("..")  # Adds the parent directory to sys path

from mailio_ai_libs.create_embeddings import Embedder

  from .autonotebook import tqdm as notebook_tqdm


model id:  intfloat/e5-small-v2


In [2]:
# The embedding model identifier is supplied through the MODEL_ID environment
# variable (populated from .env by load_dotenv in the import cell).
model_id = os.environ.get("MODEL_ID")
print(model_id)

intfloat/e5-small-v2


In [3]:
# Embedding artifacts live in a per-model subfolder named after the last
# "/"-separated component of the model id (e.g. "org/name" -> "name").
base_data_dir = "../data"
subfolder = model_id.split("/")[-1]
data_dir = f"{base_data_dir}/{subfolder}"
embeddings_path = f"{data_dir}/embeddings.npy"
index_path = f"{data_dir}/embeddings_index.npy"

# os.listdir returns entries in arbitrary order; sort them so the order in
# which the .jsonl files are processed is reproducible across runs and machines.
jsonl_files = sorted(
    name
    for name in os.listdir(base_data_dir)
    if name.endswith(".jsonl") and os.path.isfile(os.path.join(base_data_dir, name))
)

In [4]:
# Build a lookup from message_id to its full JSON record across all .jsonl files.
# Records without a "message_id" key are skipped; if the same message_id occurs
# more than once, the record read last wins.
database_dict = {}
for file in jsonl_files:
    file_path = os.path.join(base_data_dir, file)
    # Decode explicitly as UTF-8 (the JSON interchange encoding) instead of the
    # platform default, and stream line by line rather than reading the whole
    # file into memory first.
    with open(file_path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            # Skip blank / whitespace-only lines (e.g. a trailing newline).
            if not line:
                continue
            j = json.loads(line)
            if "message_id" in j:
                database_dict[j["message_id"]] = j

In [5]:
# Number of distinct message ids collected from the .jsonl files.
len(database_dict)

12432

In [6]:
# Load the precomputed embedding matrix and its companion index array from disk.
# NOTE(review): entries of `index` are later used as keys into database_dict,
# so they are presumably message ids — confirm against the script that wrote them.
embeddings = np.load(embeddings_path)
index = np.load(index_path)

In [7]:
# Every embedding row must have exactly one matching entry in the index array.
print(embeddings.shape, index.shape)
assert len(embeddings) == len(index)

(16328, 384) (16328,)


In [8]:
# Convert the NumPy matrix to a torch tensor so torch.nn.functional ops can be used on it.
# NOTE(review): this rebinds `embeddings`, making the cell order-dependent —
# re-run the load cell before re-running this one.
embeddings = torch.from_numpy(embeddings)

In [9]:
# Display the tensor's shape after conversion.
embeddings.shape

torch.Size([16328, 384])

In [10]:
# Scale each row (dim=1) by its L2 norm (p=2); the result rebinds `embeddings`.
embeddings = F.normalize(embeddings, p=2, dim=1)

In [11]:
# Optional: quantize the embeddings to half precision (currently disabled)
# embeddings = embeddings.type(torch.HalfTensor)

In [12]:
# Load the tokenizer and the model identified by model_id.
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModel.from_pretrained(model_id)

# Wrap both in the project's Embedder, which is used below to embed query text.
embedder = Embedder(model, tokenizer)

In [13]:
def search_embeddings(embedder, query, embeddings, index, limit=10):
    """Find the rows of ``embeddings`` most similar to ``query``.

    Args:
        embedder: Object exposing ``embed(list_of_str)`` that returns a NumPy
            array (presumably one row per input string — TODO confirm against
            ``Embedder``).
        query: Text to search for.
        embeddings: 2-D torch tensor with one document vector per row.
        index: Not used inside this function; kept so existing callers keep
            working. Callers map the returned row positions through it.
        limit: Maximum number of hits to return. Clamped to the number of rows
            in ``embeddings`` so an over-large limit no longer makes ``topk``
            raise; values <= 0 yield empty results.

    Returns:
        Tuple ``(indices, scores)`` of 1-D NumPy arrays: row positions into
        ``embeddings`` and their cosine similarities, ordered from most to
        least similar.
    """
    q = torch.from_numpy(embedder.embed([query]))
    q = F.normalize(q, p=2, dim=1)
    # q broadcasts against every row, giving one similarity per row.
    similarity = F.cosine_similarity(q, embeddings, dim=1)
    # topk raises if k exceeds the number of candidates, so clamp it.
    k = max(0, min(int(limit), similarity.shape[0]))
    values, indices = similarity.topk(k, dim=0)
    return indices.detach().cpu().numpy().ravel(), values.detach().cpu().numpy().ravel()

In [14]:
# Example query: ask for the 20 embedding rows closest to the query text.
query = "Zoom invitation on April 21, 2021"
indices, scores = search_embeddings(
    embedder=embedder,
    query=query,
    embeddings=embeddings,
    index=index,
    limit=20,
)

In [15]:
# Show the raw row positions of the hits and their cosine-similarity scores.
indices, scores

(array([2164, 2271, 2176, 2049, 2270, 3889, 2025, 2026, 2050, 2346, 1946,
        6102, 1804, 2344, 2244, 2233, 2317, 1350, 6113, 6085]),
 array([0.858194  , 0.85388947, 0.85373616, 0.8522421 , 0.851414  ,
        0.85116816, 0.84978604, 0.8492924 , 0.84928286, 0.8490907 ,
        0.84671575, 0.8457428 , 0.8451334 , 0.8424072 , 0.8417161 ,
        0.8412122 , 0.84103715, 0.8401063 , 0.83927107, 0.83926976],
       dtype=float32))

In [16]:
# Translate embedding row positions into database keys, then print each hit
# alongside its similarity score.
result_ids = index[indices]
for i in range(len(result_ids)):
    idx = result_ids[i]
    item = database_dict[idx.item()]
    print(f"Score: {scores[i]}, Subject: {item['subject']}, id: {item['message_id']}, sentences: {item['sentences']}")

Score: 0.8581939935684204, Subject: Updated invitation: Igor <> Hakimo @ Fri Apr 30, 2021 9am - 10am
 (PDT) (igor@mail.io), id: <000000000000d0c0df05c11fa3cc@google.com>, sentences: ['This event has been changed.Igor <> HakimoWhenChanged: Fri Apr 30, 2021 9am  10am Pacific Time - Los AngelesCalendarigor@mail.ioWhosagar@hakimo.ai - organizeranuj@hakimo.aiigor@mail.io Sagar Honnungar is inviting you to a scheduled Zoom meeting.Join Zoom MeetingMeeting ID: 913 6424 4868Passcode: 135780One tap mobile+16699006833,,91364244868#,,,,*135780# US (San Jose)+12532158782,,91364244868#,,,,*135780# US (Tacoma)Dial by your location +1 669 900 6833 US (San Jose) +1 253 215 8782 US (Tacoma) +1 346 248 7799 US (Houston) +1 929 205 6099 US (New York) +1 301 715 8592 US (Washington DC) +1 312 626 6799 US (Chicago)Meeting ID: 913 6424 4868Passcode: 135780Find your local number: Going (igor@mail.io)?', '- - Invitation from You are receiving this courtesy email at the account igor@mail.io because you are an 