# Two-Tower Model
- Based on word2vec embeddings from gensim
- Uses a simple average of the word embeddings as the document embedding
- Uses a simple feedforward neural network as the encoder


In [3]:
%load_ext autoreload
%autoreload 2


In [None]:
# Import Libraries
import os
import sys

import pandas as pd

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader

import faiss
import numpy as np

import wandb
from tqdm import tqdm

sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))

from utils.collate import collate
from utils.load_data import load_word2vec
from utils.preprocess_str import str_to_tokens
from utils.checkpoint import save_checkpoint
from core import DocumentDataset, TwoTowerModel, loss_fn


In [3]:
# import inspect
# print(inspect.getsource(collate).split('\n', 1)[1].strip())

In [4]:
# import importlib
# import utils.checkpoint

# importlib.reload(utils.checkpoint)

In [24]:
# Define HYPERPARAMETERS
RANDOM_SEED = 42  # seed for the DataFrame.sample calls in this notebook
FREEZE_EMBEDDINGS = True  # passed as `freeze=` to nn.Embedding.from_pretrained (word2vec vectors not trained)
VERBOSE = True  # gates the diagnostic / plotting cells
HIDDEN_DIM = 128  # NOTE(review): not referenced by any visible cell — confirm it is still needed
NUM_LAYERS = 1  # NOTE(review): not referenced by any visible cell — confirm it is still needed
MARGIN = 0.5  # margin passed to TwoTowerModel and loss_fn
LEARNING_RATE = 0.00001  # Adam learning rate; also the base value for lr_sweep
NUM_EPOCHS = 3  # passes over the training dataloader per run
MODEL_NAME = "mlx-w2-two-tower-search"  # W&B project name
PROJECTION_DIM = 64  # projection size used outside the projection-dim sweep



In [6]:
# Load embeddings
# word_to_idx is the lookup handed to str_to_tokens when tokenizing text below.
vocab,embeddings, word_to_idx = load_word2vec()
# Wrap the pretrained matrix in an nn.Embedding; FREEZE_EMBEDDINGS decides whether it is trainable.
embedding_layer = nn.Embedding.from_pretrained(embeddings, freeze=FREEZE_EMBEDDINGS)

# from_pretrained expects a 2-D (num_embeddings, embedding_dim) matrix, so column count = vector size.
EMBEDDING_DIM = embeddings.shape[1]
VOCAB_SIZE = len(vocab)

In [7]:
# Load training data
df = pd.read_parquet('../data/training.parquet')
df_validation = pd.read_parquet('../data/validation.parquet')  # sampled + tokenized in the validation cells
df_test = pd.read_parquet('../data/test.parquet')  # NOTE(review): not used by any visible cell
# df = df.sample(n=10000, random_state=RANDOM_SEED)

In [7]:
def tokenize(df, word_to_idx):
    """Attach token-id columns for the relevant doc, irrelevant doc and query.

    Each text column is converted element-wise with ``str_to_tokens`` and the
    result is written into a new column of ``df`` itself (the frame is modified
    in place). The same frame is returned so callers can reassign it.

    Args:
        df: frame with ``doc_relevant``, ``doc_irrelevant`` and ``query`` columns.
        word_to_idx: lookup forwarded to ``str_to_tokens``.

    Returns:
        ``df`` with ``doc_rel_tokens``, ``doc_irr_tokens`` and ``query_tokens`` added.
    """
    source_to_target = (
        ('doc_relevant', 'doc_rel_tokens'),
        ('doc_irrelevant', 'doc_irr_tokens'),
        ('query', 'query_tokens'),
    )
    for text_column, token_column in source_to_target:
        df.loc[:, token_column] = df[text_column].apply(lambda text: str_to_tokens(text, word_to_idx))
    return df


In [9]:
# Preprocess data

# Add doc_rel_tokens / doc_irr_tokens / query_tokens columns to the training frame.
df = tokenize(df, word_to_idx)
# Token-length columns (currently disabled).
# NOTE(review): the VERBOSE histogram cell reads *_tkn_length columns — either
# re-enable these lines or derive the lengths in that cell.
# df['doc_rel_tkn_length'] = df['doc_rel_tokens'].apply(len)
# df['doc_irr_tkn_length'] = df['doc_irr_tokens'].apply(len)
# df['query_tkn_length'] = df['query_tokens'].apply(len)



In [11]:
# Keep the full tokenized training set; `df` is downsampled in the next cell.
df_full = df.copy()

In [12]:
# Reproducible 100k-row sample (original index labels kept). In the visible
# cells only the VERBOSE histograms read it; training uses df_full.
df = df.sample(n=100000, random_state=RANDOM_SEED)


In [None]:
if VERBOSE:
    # Token-length histograms. The *_tkn_length columns are not created by the
    # preprocessing cell (those lines are commented out), so derive the lengths
    # here from the token lists rather than indexing columns that may not exist.
    token_lengths = pd.DataFrame({
        'query_tkn_length': df['query_tokens'].apply(len),
        'doc_rel_tkn_length': df['doc_rel_tokens'].apply(len),
        'doc_irr_tkn_length': df['doc_irr_tokens'].apply(len),
    })
    token_lengths[['query_tkn_length']].hist(bins=100, layout=(2,1), figsize=(3, 3))
    token_lengths[['doc_rel_tkn_length']].hist(bins=100, layout=(2,1), figsize=(10, 3))
    token_lengths[['doc_irr_tkn_length']].hist(bins=100, layout=(2,1), figsize=(10, 3))

In [13]:

# Create dataset and dataloader
# NOTE(review): this wraps df_full (every training row), not the 100k sample
# now held in `df` — confirm which one training is meant to use.
dataset = DocumentDataset(df_full)

In [14]:

# shuffle=False: batches come out in dataset order on every epoch.
# NOTE(review): the training sweeps consume this loader — confirm the data is
# pre-shuffled, otherwise shuffle=True is the usual choice for training.
dataloader = DataLoader(dataset, batch_size=32, shuffle=False, collate_fn=collate)

In [None]:
if VERBOSE:
    # Sanity check: print the tensor shapes of the first batch only.
    for i, batch in enumerate(dataloader, start=1):
        docs_rel, docs_irr, queries, docs_rel_mask, docs_irr_mask, query_mask = batch
        print('Batch', i)
        for label, tensor in (
            ("Relevant Documents shape:", docs_rel),
            ("Irrelevant Documents shape:", docs_irr),
            ("Queries shape:", queries),
            ("Relevant Document mask shape:", docs_rel_mask),
            ("Irrelevant Document mask shape:", docs_irr_mask),
            ("Query mask shape:", query_mask),
        ):
            print(label, tensor.shape)
        break


In [27]:
# Hot-reload the project modules after editing them, then re-bind the names
# imported at the top of the notebook so later cells use the fresh code.
import importlib
import core

importlib.reload(core)
from core import TwoTowerModel, DocumentDataset, loss_fn

import utils.collate

importlib.reload(utils.collate)
from utils.collate import collate


In [49]:

# Create model
# Model/optimizer used by the VERBOSE smoke test below; each sweep cell builds its own.
model = TwoTowerModel(embedding_dim=EMBEDDING_DIM, projection_dim=PROJECTION_DIM, embedding_layer=embedding_layer, margin=MARGIN)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)



In [None]:
if VERBOSE:
    # Smoke test: push one batch through the (untrained) model and report the loss.
    model.eval()
    dataiter = iter(dataloader)
    (docs_rel, docs_irr, queries,
     docs_rel_mask, docs_irr_mask, query_mask) = next(dataiter)

    for label, tensor in (
        ("docs_rel.shape:", docs_rel),
        ("docs_irr.shape:", docs_irr),
        ("queries.shape:", queries),
        ("docs_rel_mask.shape:", docs_rel_mask),
        ("docs_irr_mask.shape:", docs_irr_mask),
        ("query_mask.shape:", query_mask),
    ):
        print(label, tensor.shape)

    with torch.no_grad():
        similarity_rel = model(docs_rel, queries, doc_mask=docs_rel_mask, query_mask=query_mask)
        similarity_irr = model(docs_irr, queries, doc_mask=docs_irr_mask, query_mask=query_mask)

    loss = loss_fn(similarity_rel, similarity_irr, MARGIN)
    print("Loss:", loss.item())

## Model Training

In [29]:
# Hyperparameter grids consumed by the sweep cells below.
projection_dim_sweep = [24, 48, 96, 192]
margin_sweep = [0.1, 0.4, 0.7, 1.0]
# NOTE(review): every multiplier is <= 1, so with LEARNING_RATE = 0.00001 (as set
# above) this sweeps 1e-9 .. 1e-5 — downward only. Confirm that range is intended.
lr_sweep = [LEARNING_RATE * i for i in [0.0001, 0.001, 0.01, 0.1, 1]]

In [18]:
# import importlib
# import wandb

# importlib.reload(wandb)
# import wandb


In [None]:
for projection_dim in projection_dim_sweep:
    # One W&B run per projection size; margin and learning rate stay at their defaults.
    run_name = f"avg_pooling_projection_dim_{projection_dim}_commit_3799989"
    model = TwoTowerModel(
        embedding_dim=EMBEDDING_DIM,
        projection_dim=projection_dim,
        embedding_layer=embedding_layer,
        margin=MARGIN,
    )
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    wandb.init(project=MODEL_NAME, name=run_name)

    for epoch in range(NUM_EPOCHS):
        print(f"Epoch {epoch + 1} of {NUM_EPOCHS}")
        for step, batch in enumerate(dataloader):
            # Progress line every 1000 batches (1-based in the message).
            if not step % 1000:
                print(f"Batch {step + 1} of {len(dataloader)}")
            (docs_rel, docs_irr, queries,
             docs_rel_mask, docs_irr_mask, query_mask) = batch

            # Score each query against its relevant and its irrelevant document.
            similarity_rel = model(docs_rel, queries, doc_mask=docs_rel_mask, query_mask=query_mask)
            similarity_irr = model(docs_irr, queries, doc_mask=docs_irr_mask, query_mask=query_mask)
            loss = loss_fn(similarity_rel, similarity_irr, MARGIN)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            wandb.log({"loss": loss.item()})

        # One checkpoint per epoch, named after the run.
        save_checkpoint(model, epoch, run_name)

    wandb.finish()


In [None]:
for margin in margin_sweep:
    # One W&B run per margin value; projection size and learning rate stay fixed.
    run_name = f"avg_pooling_margin_{margin}_commit_3799989"
    model = TwoTowerModel(embedding_dim=EMBEDDING_DIM, projection_dim=PROJECTION_DIM, embedding_layer=embedding_layer, margin=margin)
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    wandb.init(project=MODEL_NAME, name=run_name)
    for epoch in range(NUM_EPOCHS):
        print(f"Epoch {epoch + 1} of {NUM_EPOCHS}")
        for batch_idx, batch in enumerate(dataloader):
            if batch_idx % 1000 == 0:
                print(f"Batch {batch_idx + 1} of {len(dataloader)}")
            docs_rel, docs_irr, queries, docs_rel_mask, docs_irr_mask, query_mask = batch

            similarity_rel = model(docs_rel, queries, doc_mask=docs_rel_mask, query_mask=query_mask)
            similarity_irr = model(docs_irr, queries, doc_mask=docs_irr_mask, query_mask=query_mask)

            # Use the swept margin, not the global MARGIN; otherwise the loss would
            # ignore the very value this sweep is varying.
            loss = loss_fn(similarity_rel, similarity_irr, margin)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            wandb.log({"loss": loss.item()})
        # Checkpoint under the per-run name (as the other sweeps do) so runs with
        # different margins are distinguishable on disk.
        save_checkpoint(model, epoch, run_name)

    wandb.finish()


In [None]:
for lr in lr_sweep:
    # One W&B run per learning rate; projection size and margin stay at their defaults.
    run_name = f"avg_pooling_learning_rate_{lr}_commit_3799989"
    model = TwoTowerModel(
        embedding_dim=EMBEDDING_DIM,
        projection_dim=PROJECTION_DIM,
        embedding_layer=embedding_layer,
        margin=MARGIN,
    )
    optimizer = optim.Adam(model.parameters(), lr=lr)
    wandb.init(project=MODEL_NAME, name=run_name)

    for epoch in range(NUM_EPOCHS):
        print('Learning Rate:', lr)
        print(f"Epoch {epoch + 1} of {NUM_EPOCHS}")
        for step, batch in enumerate(dataloader):
            # Progress line every 5000 batches, prefixed with the 1-based epoch.
            if not step % 5000:
                print(f"E{epoch + 1}: Batch {step + 1} of {len(dataloader)}")
            (docs_rel, docs_irr, queries,
             docs_rel_mask, docs_irr_mask, query_mask) = batch

            similarity_rel = model(docs_rel, queries, doc_mask=docs_rel_mask, query_mask=query_mask)
            similarity_irr = model(docs_irr, queries, doc_mask=docs_irr_mask, query_mask=query_mask)
            loss = loss_fn(similarity_rel, similarity_irr, MARGIN)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            wandb.log({"loss": loss.item()})

        # One checkpoint per epoch, named after the run.
        save_checkpoint(model, epoch, run_name)

    wandb.finish()


In [32]:
# Fresh model with the default hyperparameters; its weights are loaded from a checkpoint in a later cell.
model = TwoTowerModel(embedding_dim=EMBEDDING_DIM, projection_dim=PROJECTION_DIM, embedding_layer=embedding_layer, margin=MARGIN)

In [None]:
# The checkpoint path loaded below is relative to this working directory.
print(os.getcwd())

In [None]:
# Load the LR-sweep checkpoint (lr = 1e-05). The file is a plain state_dict, so
# weights_only=True (restricted unpickler, no arbitrary objects) is sufficient.
model.load_state_dict(torch.load('./checkpoints/avg_pooling_learning_rate_1e-05_commit_3799989_20241024_170821_epoch_3_3799989.pth', weights_only=True))

In [36]:
def validate(model, dataloader, margin):
    """Return the mean per-batch loss of ``model`` over ``dataloader``.

    The pass runs under ``torch.no_grad()`` with the model in eval mode. The
    model's previous train/eval mode is restored afterwards, even if a batch
    raises. Without that, calling this between training steps would leave the
    model in eval mode.

    Args:
        model: two-tower model, called as
            ``model(docs, queries, doc_mask=..., query_mask=...)``.
        dataloader: yields 6-tuples ``(docs_rel, docs_irr, queries,
            docs_rel_mask, docs_irr_mask, query_mask)``.
        margin: margin forwarded to ``loss_fn``.

    Returns:
        The average of ``loss_fn(...).item()`` over batches (each batch
        weighted equally), or ``float('nan')`` when ``dataloader`` is empty.
    """
    was_training = model.training
    model.eval()
    total_loss = 0.0
    total_batches = 0

    try:
        with torch.no_grad():
            for batch in dataloader:
                docs_rel, docs_irr, queries, docs_rel_mask, docs_irr_mask, query_mask = batch

                similarity_rel = model(docs_rel, queries, doc_mask=docs_rel_mask, query_mask=query_mask)
                similarity_irr = model(docs_irr, queries, doc_mask=docs_irr_mask, query_mask=query_mask)

                total_loss += loss_fn(similarity_rel, similarity_irr, margin).item()
                total_batches += 1
    finally:
        # Hand the model back in whatever mode the caller had it in.
        model.train(was_training)

    if total_batches == 0:
        # Nothing to average over; avoid ZeroDivisionError.
        return float('nan')
    return total_loss / total_batches

In [37]:

# Tokenize a reproducible 1,000-row validation sample. reset_index(drop=True)
# gives it a 0..999 index, which the .loc[0, ...] lookups below rely on.
df_validation_sample = tokenize(df_validation.sample(n=1000, random_state=RANDOM_SEED), word_to_idx).reset_index(drop=True)


In [38]:
# Create validation dataset and dataloader
# Same batch size and collate function as training.
validation_dataset = DocumentDataset(df_validation_sample)
validation_dataloader = DataLoader(validation_dataset, batch_size=32, shuffle=False, collate_fn=collate)


In [None]:
validate(model, validation_dataloader, MARGIN)

In [None]:
# Pull the first validation triple (raw texts and token ids) to score by hand.
rel_doc = df_validation_sample.loc[0, 'doc_relevant']
irr_doc = df_validation_sample.loc[0, 'doc_irrelevant']
query = df_validation_sample.loc[0, 'query']

# Token-id tensors without a batch dimension (the next cell adds one via unsqueeze(0)).
rel_doc_tokens = torch.tensor(df_validation_sample.loc[0, 'doc_rel_tokens'])
irr_doc_tokens = torch.tensor(df_validation_sample.loc[0, 'doc_irr_tokens'])
query_tokens = torch.tensor(df_validation_sample.loc[0, 'query_tokens'])

print(rel_doc_tokens.shape)
print(irr_doc_tokens.shape)
print(query_tokens.shape)


In [None]:
# Score the single example from the previous cell: relevant vs irrelevant document.
model.eval()
with torch.no_grad():
    # unsqueeze(0) adds the batch dimension.
    # NOTE(review): no doc_mask/query_mask is passed, so this relies on the
    # model's default mask handling — confirm it matches the masked training calls.
    similarity_rel = model(rel_doc_tokens.unsqueeze(0), query_tokens.unsqueeze(0))
    similarity_irr = model(irr_doc_tokens.unsqueeze(0), query_tokens.unsqueeze(0))

similarity_rel, similarity_irr


In [42]:
# Hand-written ranking check: one query plus a small corpus mixing
# climate-related sentences with unrelated ones.
query = "What are the effects of climate change?"
documents = [
    "Climate change is causing rising sea levels and more frequent extreme weather events.",
    "The Earth orbits around the Sun in an elliptical path.",
    "Global warming is leading to the melting of polar ice caps and glaciers.",
    "Photosynthesis is the process by which plants convert sunlight into energy.",
    "Increased greenhouse gas emissions are a major contributor to global climate change.",
    "The recipe for a classic Margherita pizza includes fresh mozzarella, tomatoes, and basil.",
    "The history of the Roman Empire is marked by significant military conquests and cultural achievements.",
    "Quantum mechanics explores the behavior of particles at the atomic and subatomic levels.",
    "The rules of chess involve strategic movement of pieces like the knight, bishop, and rook.",
    "The process of photosynthesis in plants involves converting carbon dioxide and water into glucose and oxygen using sunlight."
]


In [43]:
model.eval()
with torch.no_grad():
    # Encode the query once; the mask treats token id 0 as padding.
    query_tokens = torch.tensor([str_to_tokens(query, word_to_idx)])
    query_mask = (query_tokens != 0).float()

    # One single-row batch per document, each with its own mask.
    doc_tokens = [torch.tensor([str_to_tokens(text, word_to_idx)]) for text in documents]
    doc_masks = [(tokens != 0).float() for tokens in doc_tokens]

    similarities = [
        model(tokens, query_tokens, doc_mask=tok_mask, query_mask=query_mask).item()
        for tokens, tok_mask in zip(doc_tokens, doc_masks)
    ]

    # Highest similarity first.
    ranked_docs = sorted(zip(documents, similarities), key=lambda pair: pair[1], reverse=True)



In [None]:
# Show the ranking as a table: one row per document, highest similarity first.
df_ranked_docs = pd.DataFrame(ranked_docs, columns=['Document', 'Similarity'])
df_ranked_docs['Query'] = query
df_ranked_docs = df_ranked_docs[['Query', 'Document', 'Similarity']]  # put Query first
pd.set_option('display.max_colwidth', None)  # don't truncate long documents

# Per-column width styles; the dict keys are column labels.
styled_df = df_ranked_docs.style.set_table_styles(
    {
        'Query': [{'selector': '', 'props': [('width', '150px')]}],
        'Document': [{'selector': '', 'props': [('width', '600px')]}]
    }
)

styled_df



# Build the document embedding matrix
- For each document, compute its projection
- Stack the document projections into a matrix
- Keep the document ids in a list, in the same order

In [36]:
# Reload the training data with precomputed token columns.
# NOTE: the recorded NameError below shows this ran before `pd` was imported in
# the current kernel — run the import cell first.
df = pd.read_parquet('../data/training-with-tokens.parquet')


NameError: name 'pd' is not defined

In [111]:
# Keep only the query, the relevant document text and its URL.
df = df[['query', 'doc_relevant', 'url_relevant']]


In [112]:
# Not needed: the tokenized parquet file already exists — run this only to rebuild it.
# The output path must match the one read elsewhere ('training-with-tokens.parquet').
# tqdm.pandas()
# df['doc_rel_tokens'] = df['doc_relevant'].progress_apply(lambda x: str_to_tokens(x, word_to_idx))
# df.to_parquet('../data/training-with-tokens.parquet')



In [86]:
# NOTE: superseded by model.doc_encode / model.doc_project (used below); kept for
# reference only. As written it ignores `doc_tokens` and always reads row 0 of df.
# def get_doc_projection(model, doc_tokens):
#     doc_tensor = torch.tensor(df.loc[0, 'doc_rel_tokens']).unsqueeze(0)
#     doc_mask = (doc_tensor != 0).float()
#     doc_embeddings = model.embedding(doc_tensor)
#     doc_encoding = doc_embeddings.mean(dim=1).unsqueeze(1)
#     doc_projection = model.doc_project(doc_encoding).squeeze()
#     return doc_projection



In [117]:
# Reload core and import the document-only dataset / collate function used
# below to project every document.
import importlib
import core

importlib.reload(core)
from core import DocDataset, collate_docdataset


In [None]:
# NOTE(review): presumably DocDataset yields one item per df row, in row order — confirm in core.
doc_dataset = DocDataset(df, word_to_idx)
# Shuffle MUST be set to false to preserve the order of the documents
# (row i of the projection matrix / FAISS id i is later mapped back to df row i).
doc_dataloader = DataLoader(doc_dataset, batch_size=32, shuffle=False, collate_fn=collate_docdataset)
# Peek at the first batch to check tensor shapes.
for tokens, mask, indices in doc_dataloader:
    print(tokens.shape)
    print(mask.shape)
    print(indices.shape)
    break

In [None]:
model.eval()

# Run every document through the document tower. The loader is not shuffled, so
# (assuming DocDataset follows df row order) row i of the concatenated
# projections corresponds to row i of df.
doc_projections = []
doc_indices = []  # batch_indices of each batch; must exist before the append below

with torch.no_grad():
    for batch_tokens, batch_mask, batch_indices in tqdm(doc_dataloader):

        doc_encodings = model.doc_encode(batch_tokens, batch_mask)
        batch_projections = model.doc_project(doc_encodings)

        doc_projections.append(batch_projections)
        doc_indices.append(batch_indices)


In [None]:
# Stack the per-batch projections into one (num_docs, projection_dim) tensor.
doc_projections = torch.cat(doc_projections, dim=0)


In [130]:

# Wrap the projections in a frozen nn.Embedding so a document vector can be looked up by row id.
doc_projection_dim = doc_projections.shape[1] # same as PROJECTION_DIM
num_docs = doc_projections.shape[0] # same as len(df), len(doc_indices)
doc_embedding_matrix = nn.Embedding.from_pretrained(doc_projections, freeze=True)


In [136]:
# Persist the raw weight tensor; the FAISS section reloads it with weights_only=True.
torch.save(doc_embedding_matrix.weight.data, '../data/doc-embedding-matrix-64.pth')

# Build the FAISS index

In [26]:
import os
import sys
import torch
from torch import nn
import faiss
import numpy as np
from tqdm import tqdm

sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))

from utils.preprocess_str import str_to_tokens
from utils.load_data import load_word2vec
from core import TwoTowerModel

# Rebuild the embedding pieces needed after a kernel restart.
vocab,embeddings, word_to_idx = load_word2vec()
embedding_layer = nn.Embedding.from_pretrained(embeddings, freeze=True)

# These must match the checkpoint loaded in the test cells (PROJECTION_DIM = 64
# is also the size of the saved document matrix).
EMBEDDING_DIM = embeddings.shape[1]
PROJECTION_DIM = 64
MARGIN = 0.5



In [2]:
# Reload the saved (num_docs, projection_dim) matrix and wrap it as a frozen lookup table.
doc_projections = torch.load('../data/doc-embedding-matrix-64.pth', weights_only=True)
doc_projections_matrix = nn.Embedding.from_pretrained(doc_projections, freeze=True)

In [3]:
def get_doc_projection(doc_id, doc_projections_matrix=doc_projections_matrix):
    """Return the stored document projection(s) for ``doc_id``.

    ``doc_projections_matrix`` defaults to the lookup table that exists when
    this cell runs; default values are bound once, at definition time.
    """
    id_tensor = torch.tensor(doc_id)
    return doc_projections_matrix(id_tensor)



In [4]:
get_doc_projection(0)

tensor([ 0.1573, -0.1160,  0.0340,  0.0244, -0.1314,  0.1808,  0.1675,  0.0514,
         0.0780,  0.0397,  0.0029,  0.1056,  0.0607, -0.0872,  0.1253, -0.0288,
        -0.2060,  0.0794,  0.0403, -0.1794, -0.1079,  0.1292,  0.0056,  0.0942,
         0.2491, -0.0205, -0.1048, -0.2030,  0.0147, -0.0356, -0.0752,  0.0238,
         0.0261,  0.0346,  0.0281,  0.0402,  0.0281, -0.0985, -0.1416,  0.1403,
        -0.0901, -0.0294,  0.1185,  0.1044,  0.0388,  0.0470, -0.0932, -0.0366,
         0.0066,  0.0458, -0.0994, -0.0824, -0.1157,  0.0107,  0.0670, -0.0250,
         0.1338, -0.0573,  0.1258, -0.0362,  0.0197,  0.0346, -0.0270,  0.1348])

In [5]:
# NumPy array (sharing memory with the tensor) for FAISS; `dimension` is the vector size.
doc_projections_np = doc_projections_matrix.weight.data.numpy()
dimension = doc_projections_np.shape[1]


In [6]:
doc_projections_np[0]

array([ 0.15728916, -0.11602148,  0.03402309,  0.0243528 , -0.13137114,
        0.18077607,  0.16750537,  0.05138404,  0.07804341,  0.03973445,
        0.00292916,  0.10556114,  0.06071091, -0.08719876,  0.12527376,
       -0.02882197, -0.20598128,  0.07935658,  0.04032397, -0.17938608,
       -0.10789324,  0.1292444 ,  0.00563716,  0.09421676,  0.24905732,
       -0.02047169, -0.1047775 , -0.2030437 ,  0.01469753, -0.03559981,
       -0.07521243,  0.0238389 ,  0.02609443,  0.03464166,  0.02805131,
        0.04018955,  0.02807336, -0.09852565, -0.14158927,  0.14026608,
       -0.0900612 , -0.02940441,  0.1184871 ,  0.10437067,  0.0388245 ,
        0.04695689, -0.09324557, -0.03660403,  0.00662894,  0.04576072,
       -0.09939746, -0.0824083 , -0.11573933,  0.01073795,  0.06702912,
       -0.02503499,  0.13383216, -0.05731182,  0.12581532, -0.03617848,
        0.01969791,  0.03459676, -0.02697205,  0.13477667], dtype=float32)

In [7]:
# Function to normalize vectors in batches
def normalize_vectors(vectors, batch_size=10000):
    """Return an L2-row-normalised float32 copy of ``vectors``.

    ``faiss.normalize_L2`` works in place and requires a C-contiguous float32
    array. The input is therefore first copied into that layout, and the caller's
    array is never modified. Rows are processed ``batch_size`` at a time behind
    a tqdm progress bar.

    Args:
        vectors: 2-D array-like of shape (n, d).
        batch_size: number of rows normalised per faiss call.

    Returns:
        np.ndarray of shape (n, d) and dtype float32 with unit-norm rows.
    """
    normalized_vectors = np.array(vectors, dtype=np.float32, order='C', copy=True)
    for start in tqdm(range(0, len(normalized_vectors), batch_size)):
        # A row slice of a C-contiguous array is a contiguous view, so faiss
        # normalises these rows inside normalized_vectors; no write-back needed.
        faiss.normalize_L2(normalized_vectors[start:start + batch_size])
    return normalized_vectors

In [8]:
normalized_doc_projections_np = normalize_vectors(doc_projections_np)
normalized_doc_projections_np[0]

100%|██████████| 68/68 [00:00<00:00, 1619.67it/s]


array([ 0.2017562 , -0.14882177,  0.04364172,  0.03123755, -0.16851091,
        0.23188305,  0.2148606 ,  0.06591076,  0.10010696,  0.05096773,
        0.00375726,  0.1354042 ,  0.07787441, -0.11185062,  0.16068976,
       -0.03697019, -0.264214  ,  0.10179138,  0.0517239 , -0.2301001 ,
       -0.1383956 ,  0.16578293,  0.00723084,  0.12085267,  0.31946802,
       -0.02625922, -0.13439901, -0.26044592,  0.01885265, -0.04566418,
       -0.09647565,  0.03057837,  0.03347155,  0.04443516,  0.03598166,
        0.05155149,  0.03600994, -0.12637971, -0.1816178 ,  0.17992052,
       -0.1155223 , -0.0377173 ,  0.15198445,  0.13387717,  0.04980053,
        0.06023201, -0.11960691, -0.04695231,  0.008503  ,  0.05869768,
       -0.12749799, -0.10570585, -0.14845985,  0.01377366,  0.08597884,
       -0.03211259,  0.17166768, -0.07351437,  0.1613844 , -0.04640646,
        0.02526668,  0.04437757, -0.03459728,  0.17287922], dtype=float32)

In [9]:
# Exact (brute-force) inner-product index; on the L2-normalised rows above,
# inner product equals cosine similarity.
index = faiss.IndexFlatIP(dimension)

batch_size = 1000
n_rows = len(normalized_doc_projections_np)
for start in tqdm(range(0, n_rows, batch_size)):
    stop = start + batch_size
    index.add(normalized_doc_projections_np[start:stop])


100%|██████████| 677/677 [00:00<00:00, 2988.85it/s]


In [42]:
# Persist the index so it can be reloaded later with faiss.read_index.
faiss.write_index(index, '../data/doc-index-64.faiss')

# Test it out!

In [37]:
import pandas as pd

sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))

from core import TwoTowerModel

In [38]:
# Document texts/URLs for the search results; presumably in the same row order
# the documents were indexed in (FAISS ids are mapped back to rows of this frame).
df = pd.read_parquet('../data/training-with-tokens.parquet')
# The reload below is disabled, so the in-memory `index` built above is used.
# index = faiss.read_index('../data/doc-index-64.faiss')

In [27]:
model = TwoTowerModel(embedding_dim=EMBEDDING_DIM, projection_dim=PROJECTION_DIM, embedding_layer=embedding_layer, margin=MARGIN)
# The checkpoint is a plain state_dict, so the restricted unpickler
# (weights_only=True, no arbitrary objects) is sufficient and safer.
model.load_state_dict(torch.load('./checkpoints/avg_pooling_learning_rate_1e-05_commit_3799989_20241024_170821_epoch_3_3799989.pth', weights_only=True))


  model.load_state_dict(torch.load(f'./checkpoints/avg_pooling_learning_rate_1e-05_commit_3799989_20241024_170821_epoch_3_3799989.pth'))


<All keys matched successfully>

In [70]:
# Function to get nearest neighbors
def get_nearest_neighbors(query, model, df, k=5):
    """Return the ``k`` documents whose stored projections best match ``query``.

    The query is tokenised with the global ``word_to_idx``, projected by the
    query tower, L2-normalised and searched against the global FAISS ``index``.
    That index holds inner products over normalised vectors, i.e. cosine
    similarity.

    Args:
        query: raw query string.
        model: model providing ``query_encode`` / ``query_project``.
        df: frame whose row *positions* line up with the FAISS ids (documents
            were added to the index in row order).
        k: number of neighbours to request.

    Returns:
        (documents, urls, distances): ``documents`` / ``urls`` are Series of the
        matching ``doc_relevant`` / ``url_relevant`` values, best match first
        (always Series, even for k == 1); ``distances`` is the raw (1, k) score
        array returned by ``index.search``.
    """
    # Inference only: no autograd graph and eval mode for the forward pass,
    # restoring the caller's train/eval mode afterwards.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            query_tokens = torch.tensor([str_to_tokens(query, word_to_idx)])
            query_mask = (query_tokens != 0).float()
            query_encoding = model.query_encode(query_tokens, query_mask)
            query_projection = model.query_project(query_encoding)
    finally:
        model.train(was_training)

    # faiss needs C-contiguous float32; normalise so inner product == cosine.
    query_vector = np.ascontiguousarray(query_projection.numpy(), dtype=np.float32)
    faiss.normalize_L2(query_vector)
    distances, indices = index.search(query_vector, k)

    # indices[0] stays 1-D even for k == 1 (squeeze() would yield a 0-d array
    # and turn the results into scalars). FAISS pads missing hits with -1.
    ids = indices[0]
    ids = ids[ids >= 0]
    # FAISS ids are insertion positions, so index the frame positionally.
    hits = df.iloc[ids]

    return hits['doc_relevant'], hits['url_relevant'], distances


In [46]:
q = "What is the capital of France?"


In [77]:
documents, urls, distances = get_nearest_neighbors(q, model, df)

In [78]:
documents

8081      In Countries, States, and Cities. The currency...
129459    Rome is the capital of Italy and of the Lazio ...
66010     • Embassy is the office of the ambassador whil...
271635    1 Prague: The Capital of the Czech Republic Pr...
66005     Embassy and consulate refer to government repr...
Name: doc_relevant, dtype: object

In [79]:
urls

8081      http://www.answers.com/Q/What_is_the_currency_...
129459                   https://en.wikipedia.org/wiki/Rome
66010     http://www.differencebetween.com/difference-be...
271635    http://www.answers.com/Q/What_is_someone_from_...
66005     http://www.differencebetween.net/business/diff...
Name: url_relevant, dtype: object

In [80]:
distances

array([[0.90682334, 0.9055239 , 0.9052953 , 0.90517753, 0.90047634]],
      dtype=float32)

In [68]:
# `indices` only exists inside get_nearest_neighbors, so read the retrieved row
# labels from the returned Series (equal to the FAISS ids when df has a default
# 0..n-1 index).
documents.index.to_numpy()

array([  8081, 129459,  66010, 271635,  66005])

In [30]:
# Manually compute the query-tower projection for `q` (tokenise, mask, encode,
# project), without normalisation or search.
q_tokens = torch.tensor([str_to_tokens(q, word_to_idx)])
q_mask = (q_tokens != 0).float()  # id 0 treated as padding

q_encoding = model.query_encode(q_tokens, q_mask)
q_projection = model.query_project(q_encoding)

In [33]:
q_projection_np = q_projection.detach().numpy()

In [40]:
# Search the FAISS index directly with the projection computed above.
# get_nearest_neighbors takes a raw query string (and returns three values),
# so call index.search here instead. Normalise a copy (normalize_L2 is in place)
# so the inner-product scores are cosine similarities, as for the stored docs.
q_vector = np.array(q_projection_np, dtype=np.float32, order='C', copy=True)
faiss.normalize_L2(q_vector)
_, indices = index.search(q_vector, 5)
# FAISS ids are insertion positions -> positional lookup.
df.iloc[indices[0]]['doc_relevant']


8081      In Countries, States, and Cities. The currency...
129459    Rome is the capital of Italy and of the Lazio ...
66010     • Embassy is the office of the ambassador whil...
271635    1 Prague: The Capital of the Czech Republic Pr...
66005     Embassy and consulate refer to government repr...
Name: doc_relevant, dtype: object

# Compute the cosine similarity between the query and the document embeddings

In [None]:
# Brute-force cross-check of the FAISS results: cosine similarity between the
# test query `q` and every stored document projection. The query tokens are
# rebuilt here so no stale `query_tokens`/`query_mask` from earlier cells is used.
with torch.no_grad():
    query_tokens = torch.tensor([str_to_tokens(q, word_to_idx)])
    query_mask = (query_tokens != 0).float()
    query_encoding = model.query_encode(query_tokens, query_mask)
    query_projection = model.query_project(query_encoding)

# The single-row query broadcasts against the (num_docs, d) matrix, giving one
# score per document. torch.nn.functional is spelled out because `F` is not
# among the imports of the FAISS-section setup cell.
cosine_similarities = torch.nn.functional.cosine_similarity(query_projection, doc_projections, dim=1)
cosine_similarities