In [1]:
import gensim
import torch
import numpy as np
import torch.nn as nn
import pandas as pd
import os

In [2]:
from src.utils.preprocess_str import preprocess_str

[nltk_data] Downloading package stopwords to
[nltk_data]     C:\Users\kaleb\AppData\Roaming\nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\kaleb\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [3]:
# read data from data/processed/query_rel_doc.parquet
file_path = os.path.join("data", "processed", "query_rel_doc.parquet")
df = pd.read_parquet(file_path)

In [4]:
df.head(10)

Unnamed: 0,query,query_id,relevant_document,irrelevant_document,is_selected
0,what is rba,19699,"Since 2007, the RBA's outstanding reputation h...",Cover the grill again and allow to cook for an...,0
1,what is rba,19699,The Reserve Bank of Australia (RBA) came into ...,Submit. · just now. Report Abuse. cloud is vib...,0
2,what is rba,19699,RBA Recognized with the 2014 Microsoft US Regi...,Determining Flag Size. The length of the flag ...,0
3,what is rba,19699,The inner workings of a rebuildable atomizer a...,"1 On average, a chair lift can cost anywhere f...",0
4,what is rba,19699,Results-Based Accountability® (also known as R...,n. A seismic wave that travels through the ear...,0
5,what is rba,19699,Results-Based Accountability® (also known as R...,Brief History of Maytag & Washing Machine Inno...,1
6,what is rba,19699,"RBA uses a data-driven, decision-making proces...",Function. The cardiac skeleton has four major ...,0
7,what is rba,19699,vs. NetIQ Identity Manager. Risk-based authent...,Types of counseling most often used to treat s...,0
8,what is rba,19699,"A rebuildable atomizer (RBA), often referred t...","Most homeowners report spending between $3,675...",0
9,what is rba,19699,Get To Know Us. RBA is a digital and technolog...,Definition. Patellar tendinitis is an injury t...,0


In [24]:
dfa = df[["query_id", "query", "relevant_document", "irrelevant_document"]].copy()
dfa.reset_index(inplace=True)

In [25]:
# convert the strings, from each column, into list of tokens
dfa.loc[:, "query_tokens"] = dfa["query"].apply(preprocess_str)
dfa.loc[:, "relevant_document_tokens"] = dfa["relevant_document"].apply(preprocess_str)
dfa.loc[:, "irrelevant_document_tokens"] = dfa["irrelevant_document"].apply(
    preprocess_str
)

In [26]:
dfa.head(5)

Unnamed: 0,index,query_id,query,relevant_document,irrelevant_document,query_tokens,relevant_document_tokens,irrelevant_document_tokens
0,0,19699,what is rba,"Since 2007, the RBA's outstanding reputation h...",Cover the grill again and allow to cook for an...,[rba],"[sinc, rba, outstand, reput, affect, secur, np...","[cover, grill, allow, cook, anoth, minut, repe..."
1,1,19699,what is rba,The Reserve Bank of Australia (RBA) came into ...,Submit. · just now. Report Abuse. cloud is vib...,[rba],"[reserv, bank, australia, rba, came, januari, ...","[submit, report, abus, cloud, vibrat, waterit,..."
2,2,19699,what is rba,RBA Recognized with the 2014 Microsoft US Regi...,Determining Flag Size. The length of the flag ...,[rba],"[rba, recogn, microsoft, us, region, partner, ...","[determin, flag, size, length, flag, least, on..."
3,3,19699,what is rba,The inner workings of a rebuildable atomizer a...,"1 On average, a chair lift can cost anywhere f...",[rba],"[inner, work, rebuild, atom, surprisingli, sim...","[averag, chair, lift, cost, anywher, much, wou..."
4,4,19699,what is rba,Results-Based Accountability® (also known as R...,n. A seismic wave that travels through the ear...,[rba],"[resultsbas, account, also, known, rba, discip...","[seismic, wave, travel, earth, rather, across,..."


In [None]:
# save to data/processed/query_rel_doc_tokens.parquet
file_path = os.path.join("data", "processed", "query_rel_doc_tokens.parquet")
dfa.to_parquet(file_path)

In [27]:
# get embeddings for each token in the dataframe
w2v = gensim.models.Word2Vec.load(
    "src/models/word2vec-gensim-text8-custom-preprocess.model"
)

vocab = w2v.wv.index_to_key
word_to_idx = {word: i for i, word in enumerate(vocab)}
embeddings_array = np.array([w2v.wv[word] for word in vocab])
embeddings = torch.tensor(embeddings_array, dtype=torch.float32)
print(embeddings.shape)

embedding_layer = nn.Embedding.from_pretrained(embeddings, freeze=True)
# word_index = torch.tensor([word_to_idx['north']], dtype=torch.long)

torch.Size([74792, 100])


In [28]:
word_index = torch.tensor([word_to_idx["rba"]], dtype=torch.long)
# look at an example embedding

embedding_layer(word_index)

tensor([[-0.0706,  0.0265,  0.0331,  0.0277,  0.0544, -0.0174,  0.0143,  0.0584,
         -0.0007, -0.0526,  0.0089, -0.0053, -0.0213, -0.0008,  0.0012,  0.0245,
          0.0008, -0.0297, -0.0015, -0.0272, -0.0457, -0.0396,  0.0480, -0.0143,
          0.0252, -0.0490, -0.0792,  0.0075, -0.0342,  0.0052, -0.0287,  0.0002,
         -0.0270, -0.0292, -0.0008,  0.0186, -0.0223, -0.0193,  0.0210, -0.0210,
         -0.0023, -0.0083, -0.0468, -0.0091,  0.0039,  0.0252, -0.0277,  0.0018,
         -0.0053,  0.0064,  0.0097, -0.0421, -0.0847,  0.0159, -0.0606, -0.0179,
          0.0288,  0.0450, -0.0305, -0.0109, -0.0007,  0.0525,  0.0515, -0.0364,
         -0.0474,  0.0215,  0.0165,  0.0318, -0.0316, -0.0149,  0.0081,  0.0473,
         -0.0088,  0.0070,  0.0129, -0.0212, -0.0057, -0.0211, -0.0427, -0.0130,
          0.0040,  0.0112,  0.0151,  0.0323,  0.0235,  0.0425,  0.0254,  0.0153,
          0.0172,  0.0416,  0.0062,  0.0265,  0.0135,  0.0591,  0.0194,  0.0090,
         -0.0269,  0.0012, -

In [29]:
len(word_to_idx)

74792

In [30]:
def embed_tokens(tokens: list[str], unknown_tokens: set):
    valid_tokens = [token for token in tokens if token in word_to_idx]
    unknown_tokens.update(set(tokens) - set(valid_tokens))
    if valid_tokens:
        return (
            embedding_layer(
                torch.tensor(
                    [word_to_idx[token] for token in valid_tokens], dtype=torch.long
                )
            ),
            unknown_tokens,
        )
    return torch.tensor([])

In [31]:
unknown_tokens = set()

dfa.loc[:, "query_embedding"] = dfa["query_tokens"].apply(
    embed_tokens, args=(unknown_tokens,)
)
dfa.loc[:, "relevant_document_embedding"] = dfa["relevant_document_tokens"].apply(
    embed_tokens, args=(unknown_tokens,)
)
dfa.loc[:, "irrelevant_document_embedding"] = dfa["irrelevant_document_tokens"].apply(
    embed_tokens, args=(unknown_tokens,)
)

In [34]:
# Remove rows with empty embeddings
dfb = dfa[
    dfa["query_embedding"].apply(
        lambda x: len(x[0]) > 0 if isinstance(x, tuple) else len(x) > 0
    )
    & dfa["relevant_document_embedding"].apply(
        lambda x: len(x[0]) > 0 if isinstance(x, tuple) else len(x) > 0
    )
    & dfa["irrelevant_document_embedding"].apply(
        lambda x: len(x[0]) > 0 if isinstance(x, tuple) else len(x) > 0
    )
]

In [37]:
# print the number of unknown tokens
print(f"Number of unknown tokens: {len(unknown_tokens)}")
# print the first 10 unknown tokens
print(f"Unknown tokens: {list(unknown_tokens)[:10]}")

Number of unknown tokens: 2017
Unknown tokens: ['newsyou', 'ntact', 'bodyinvoluntari', 'baristaqu', 'ual', 'ouldin', 'granitelik', 'enterica', 'elastin', 'sma']


In [38]:
dfb[
    [
        "query_id",
        "query_embedding",
        "relevant_document_embedding",
        "irrelevant_document_embedding",
    ]
].head(1)

Unnamed: 0,query_id,query_embedding,relevant_document_embedding,irrelevant_document_embedding
0,19699,"([[tensor(-0.0706), tensor(0.0265), tensor(0.0...","([[tensor(0.4154), tensor(-0.2829), tensor(-1....","([[tensor(1.0163), tensor(0.9083), tensor(0.62..."
