# Updated: two-tower document retrieval demo (FAISS nearest-neighbour search)

In [1]:
from utils.load_data import load_word2vec
from utils.preprocess_str import str_to_tokens
import torch.nn as nn
from models.core import DocumentDataset, TwoTowerModel, loss_fn
import pandas as pd
import faiss
import torch
from models.HYPERPARAMETERS import FREEZE_EMBEDDINGS, PROJECTION_DIM, MARGIN

[nltk_data] Downloading package stopwords to
[nltk_data]     /Users/jigishap/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package punkt to /Users/jigishap/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [2]:
# Load the pretrained word2vec artefacts. load_word2vec() yields three values:
# the vocabulary, the embedding matrix, and the word -> index mapping that
# str_to_tokens() consumes when a query is tokenised further down.
vocab, embeddings, word_to_idx = load_word2vec()

# Sizes derived from the loaded artefacts (axis 1 of the matrix is the vector width).
VOCAB_SIZE = len(vocab)
EMBEDDING_DIM = embeddings.shape[1]

# Embedding layer initialised from the pretrained matrix; whether its weights
# stay frozen is controlled by the FREEZE_EMBEDDINGS hyperparameter.
embedding_layer = nn.Embedding.from_pretrained(
    embeddings=embeddings,
    freeze=FREEZE_EMBEDDINGS,
)

In [3]:
pwd  # working directory check: the relative 'data/...' and 'models/...' paths in later cells resolve against it

'/Users/jigishap/Desktop/MLX-Week-2/app'

In [4]:
# Training table, read from a path relative to the working directory.
# get_nearest_neighbors() later selects rows of this frame by index label and
# reads its 'doc_relevant' and 'url_relevant' columns.
df = pd.read_parquet(path='data/training-with-tokens.parquet')


In [5]:
# Read the prebuilt FAISS index from disk (path is relative to the working
# directory). get_nearest_neighbors() searches it through this global name.
# NOTE(review): presumably holds 64-dimensional document vectors, going by the
# file name — confirm against the code that built the index.
index = faiss.read_index('data/doc-index-64.faiss')

In [6]:
# Rebuild the two-tower model with the same hyperparameters used elsewhere in
# the project, then restore its trained weights from the saved state dict.
model = TwoTowerModel(
    embedding_dim=EMBEDDING_DIM,
    projection_dim=PROJECTION_DIM,
    embedding_layer=embedding_layer,
    margin=MARGIN,
)

# This notebook only runs inference, so put training-only layers (dropout,
# batch norm, if the model has any) into evaluation mode.
model.eval()

# weights_only=True restricts unpickling to tensors and plain containers, which
# avoids executing arbitrary pickled code and silences torch.load's
# FutureWarning about the default. map_location='cpu' lets a checkpoint saved
# on another device load here; the query vector is later converted with
# .numpy(), which requires CPU tensors anyway.
# Kept as the cell's last expression so the key-matching report is displayed.
model.load_state_dict(
    torch.load('models/two_tower_state_dict.pth', map_location='cpu', weights_only=True)
)

  model.load_state_dict(torch.load(f'models/two_tower_state_dict.pth'))


<All keys matched successfully>

In [7]:
# Function to get nearest neighbors
def get_nearest_neighbors(query, model, df, k=5):
    """Return the documents and URLs of the ``k`` index entries closest to ``query``.

    The query string is tokenised with ``str_to_tokens`` and the notebook-level
    ``word_to_idx`` mapping, encoded and projected by the model's query tower,
    L2-normalised, and searched against the notebook-level FAISS ``index``.

    Args:
        query: Query text to search for.
        model: Object exposing ``query_encode(tokens, mask)`` and
            ``query_project(encoding)``.
        df: DataFrame with ``doc_relevant`` and ``url_relevant`` columns whose
            index labels correspond to the ids stored in the FAISS index.
        k: Number of neighbours to request (must be >= 1).

    Returns:
        A ``(documents, urls, distances)`` tuple. ``documents`` and ``urls`` are
        pandas Series (also when ``k == 1``) indexed by the matched row labels;
        ``distances`` is the score array returned by ``index.search`` with one
        row for the single query, restricted to the hits that were found.

    Raises:
        ValueError: If ``k`` is smaller than 1 or the query yields no tokens.
    """
    if k < 1:
        raise ValueError(f"k must be a positive integer, got {k!r}")

    token_ids = str_to_tokens(query, word_to_idx)
    if len(token_ids) == 0:
        raise ValueError("query produced no tokens; nothing to encode")

    # Batch containing the single query; token id 0 is masked out.
    query_tokens = torch.tensor([token_ids])
    query_mask = (query_tokens != 0).float()

    # Pure inference: no autograd graph is needed for the forward pass.
    with torch.no_grad():
        query_encoding = model.query_encode(query_tokens, query_mask)
        query_projection = model.query_project(query_encoding)

    query_vector = query_projection.detach().numpy()
    faiss.normalize_L2(query_vector)  # normalises the array in place
    distances, indices = index.search(query_vector, k)

    # Take the row for our single query explicitly: squeeze() would collapse to
    # a 0-d array for k == 1 and make the lookups below return scalars.
    hit_ids = indices[0]

    # FAISS pads with -1 when it finds fewer than k neighbours; drop those slots
    # (and their scores) instead of looking up a non-existent row.
    found = hit_ids >= 0
    hit_ids = hit_ids[found]
    distances = distances[:, found]

    # NOTE(review): ids are used as index *labels* (.loc), which assumes df's
    # index labels equal the ids stored in the FAISS index — confirm; if the
    # ids are row positions and df has a non-default index, .iloc is needed.
    hits = df.loc[hit_ids, ['doc_relevant', 'url_relevant']]

    return hits['doc_relevant'], hits['url_relevant'], distances


In [8]:
# Example query passed to get_nearest_neighbors() in the next cell.
q = "What is the capital of France?"

In [9]:
# Run the retrieval for the example query with the default number of neighbours.
documents, urls, distances = get_nearest_neighbors(query=q, model=model, df=df)

In [10]:
# 'doc_relevant' text of the retrieved rows, indexed by their df row labels.
documents

8081      In Countries, States, and Cities. The currency...
129459    Rome is the capital of Italy and of the Lazio ...
66010     • Embassy is the office of the ambassador whil...
271635    1 Prague: The Capital of the Czech Republic Pr...
66005     Embassy and consulate refer to government repr...
Name: doc_relevant, dtype: object

In [11]:
# 'url_relevant' values of the same retrieved rows.
urls

8081      http://www.answers.com/Q/What_is_the_currency_...
129459                   https://en.wikipedia.org/wiki/Rome
66010     http://www.differencebetween.com/difference-be...
271635    http://www.answers.com/Q/What_is_someone_from_...
66005     http://www.differencebetween.net/business/diff...
Name: url_relevant, dtype: object

In [12]:
# Raw scores returned by index.search for the single query (one row).
# NOTE(review): the query vector is L2-normalised before the search, so these
# presumably read as cosine similarities if the index stores normalised vectors
# under an inner-product metric — confirm against how the index was built.
distances

array([[0.90682334, 0.9055239 , 0.9052953 , 0.90517753, 0.90047634]],
      dtype=float32)