In [159]:
import numpy as np
import torch
from transformers import PreTrainedModel, PreTrainedTokenizer, BatchEncoding
from transformers import AutoTokenizer, AutoModel
import sys
import os
import json
import torch.nn.functional as F

sys.path.append("..")  # Adds the parent directory to sys path

from mailio_ai_libs.create_embeddings import Embedder

In [160]:
# Input locations.
# NOTE(review): the directory name says DistilBERT, but the stored vectors are
# 384-d and queries are embedded with all-MiniLM-L6-v2 further down — confirm
# which model actually produced embeddings.npy.
base_data_dir = "../data"
data_dir = f"{base_data_dir}/embeddings_distilbert_base_uncased_mean_pooling"
embeddings_path = f"{data_dir}/embeddings.npy"
index_path = f"{data_dir}/embeddings_index.npy"
# Sorted so the load order (and hence which record wins on a duplicate
# message_id when building the lookup) is deterministic; os.listdir order is
# arbitrary and platform-dependent.
jsonl_files = sorted(
    f for f in os.listdir(base_data_dir)
    if f.endswith(".jsonl") and os.path.isfile(os.path.join(base_data_dir, f))
)

In [161]:
# Build a lookup of records keyed by "message_id" from every JSONL file.
# Records without a "message_id" are skipped; if the same id appears more than
# once, the last one read wins.
database_dict = {}
for file in jsonl_files:
    file_path = os.path.join(base_data_dir, file)
    # Explicit UTF-8: the platform default (e.g. cp1252 on Windows) can fail
    # on non-ASCII email text.
    with open(file_path, "r", encoding="utf-8") as f:
        # Stream line by line instead of reading the whole file into memory.
        for line in f:
            line = line.strip()
            # Skip blank / whitespace-only lines, which json.loads would reject.
            if not line:
                continue
            j = json.loads(line)
            if "message_id" in j:
                database_dict[j["message_id"]] = j

In [162]:
# Number of distinct message ids loaded.
n_messages = len(database_dict)
n_messages

12432

In [163]:
# Load the embedding matrix and its per-row index array from disk.
embeddings, index = np.load(embeddings_path), np.load(index_path)

In [164]:
# sanity check the shapes
print(embeddings.shape, index.shape)
assert embeddings.shape[0] == index.shape[0]

(16328, 384) (16328,)


In [165]:
# Convert to a torch tensor. torch.as_tensor accepts both the original ndarray
# and an already-converted tensor, so re-running this cell is a no-op instead of
# the TypeError torch.from_numpy raises on a tensor. For a CPU ndarray it shares
# memory just like torch.from_numpy.
embeddings = torch.as_tensor(embeddings)

In [166]:
# Display the tensor dimensions (rows, features).
embeddings.size()

torch.Size([16328, 384])

In [167]:
# Scale every row to unit L2 length (normalising along the feature axis).
embeddings = F.normalize(embeddings, dim=1, p=2)

In [168]:
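# Optional: cast the embeddings to float16 to halve memory. This is a dtype
# cast, not true quantization. Disabled; uncomment to try.
# embeddings = embeddings.type(torch.HalfTensor)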

In [169]:
# Query-side encoder: tokenizer and model come from one shared checkpoint name
# so they cannot drift apart.
MODEL_NAME = 'sentence-transformers/all-MiniLM-L6-v2'
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME)

embedder = Embedder(model, tokenizer)

In [170]:
def search_embeddings(embedder, query, embeddings, index, limit=10):
    """Return the corpus rows most similar to ``query`` by cosine similarity.

    Args:
        embedder: object whose ``embed(list_of_str)`` returns an array-like of
            query vectors (assumed shape ``(1, dim)`` for one query — TODO
            confirm against ``Embedder``).
        query: the search text.
        embeddings: 2-D tensor of corpus vectors, one row per entry.
        index: row -> id mapping. Not used here; kept so existing callers keep
            working. Callers map results themselves via ``index[indices]``.
        limit: maximum number of hits. Clamped to the number of corpus rows,
            because ``topk`` raises when k exceeds the dimension size.

    Returns:
        ``(indices, scores)``: 1-D NumPy arrays, best match first.
    """
    # Pure inference: no autograd graph needed.
    with torch.no_grad():
        q = torch.as_tensor(embedder.embed([query]))
        # Match the corpus dtype/device; cosine_similarity rejects mixed dtypes
        # (e.g. a float32 query against a float64 or float16 corpus).
        q = q.to(dtype=embeddings.dtype, device=embeddings.device)
        q = F.normalize(q, p=2, dim=1)
        # (1, dim) broadcasts against (n_rows, dim) -> (n_rows,) scores.
        similarity = F.cosine_similarity(q, embeddings, dim=1)
        k = min(limit, similarity.shape[0])
        values, indices = similarity.topk(k, dim=0)
    return indices.cpu().numpy().ravel(), values.cpu().numpy().ravel()

In [201]:
# Example search: the 20 nearest corpus rows for a meeting-invite query.
query = "Zoom invitation on April 21, 2021"
indices, scores = search_embeddings(
    embedder,
    query,
    embeddings,
    index,
    limit=20,
)

In [202]:
# Raw hit rows and their similarity scores.
(indices, scores)

(array([3889,   55, 2164, 2176,  910,   53, 2244, 2227,  950, 2050,  906,
        2084,   46, 4297, 1786, 2271,  914, 4296, 5640, 2232]),
 array([0.6415538 , 0.5701928 , 0.5663995 , 0.5405959 , 0.5203955 ,
        0.50475836, 0.5047425 , 0.4949383 , 0.49443275, 0.49313182,
        0.49093646, 0.49060303, 0.4889248 , 0.4875074 , 0.48664755,
        0.4839882 , 0.48347563, 0.48092467, 0.47975484, 0.4748347 ],
       dtype=float32))

In [203]:
# Map hit rows back to message ids and print the matching records.
# NOTE(review): this prints email addresses and message bodies; clear the
# output before sharing or committing the notebook.
result_ids = index[indices]
for score, idx in zip(scores, result_ids):
    message_id = idx.item()
    item = database_dict.get(message_id)
    if item is None:
        # The index can reference ids with no record among the loaded JSONL
        # files (database_dict only holds records that carry a message_id).
        print(f"Score: {score}, id: {message_id}, (no record found)")
        continue
    print(f"Score: {score}, Subject: {item.get('subject')}, id: {item.get('message_id')}, sentences: {item.get('sentences')}")

Score: 0.6415538191795349, Subject: None, id: <CAGbyn1QFEkTU3a0XH9CpR2Bd78CQeQ-ZSoOFir8ZecRbRwA-Sg@mail.gmail.com>, sentences: ['Ashley Smith is inviting you to a scheduled Zoom meeting.Topic: Exposing Married People During Shelter-in-Place /Group Therapy on Zoom lolTime: Apr 19, 2020 07:00 PM Pacific Time (US and Canada)Join Zoom MeetingMeeting ID: 986 8883 8642']
Score: 0.570192813873291, Subject: Catchup, id: <CALzVFDRDiZ=MXkfTugxcnaxAZv_0-nxSJiMpsej52=YD+3Aw5w@mail.gmail.com>, sentences: ["Hi Igor,It's been a long time.", 'I hope you are doing well!', 'I was wondering if you would be interested in catching up sometime over zoom.Best,Ryan']
Score: 0.5663995146751404, Subject: Updated invitation: Igor <> Hakimo @ Fri Apr 30, 2021 9am - 10am
 (PDT) (igor@mail.io), id: <000000000000d0c0df05c11fa3cc@google.com>, sentences: ['This event has been changed.Igor <> HakimoWhenChanged: Fri Apr 30, 2021 9am  10am Pacific Time - Los AngelesCalendarigor@mail.ioWhosagar@hakimo.ai - organizeranuj@h