In [17]:
import sys
sys.path.append("..")
from sif_src.loader import download_ms_marco, load_data, save_processed_data
from sif_src.preprocess_data import preprocess_text
from sif_src.utils import load_glove, load_glove_vectors
from sif_src.sif import compute_word_frequencies, compute_sif_weights, compute_sif_embeddings, remove_pc
from pathlib import Path
import pandas as pd

In [2]:
# Restore the train/validation DataFrames from local pickle backups
# (paths are relative to the notebook's directory).
# SECURITY: unpickling can execute arbitrary code — only load pickle files
# that were produced by this project and are trusted.
train_df = pd.read_pickle("../pickle_backups/marco_train_df2024-06-04T17.38.1717490321.pickle")
valid_df = pd.read_pickle("../pickle_backups/marco_valid_df2024-06-04T17.38.1717490321.pickle")

In [3]:
# Load the GloVe 6B 300d vectors from the local text file via the project helper.
# The recorded run took about a minute for 400,000 entries.
glove_vectors = load_glove_vectors('../wv/glove.6B.300d.txt')

Loading GloVe Vectors: 100%|██████████| 400000/400000 [01:03<00:00, 6292.96it/s]


In [4]:
# Pull the 'passage_text' entry out of each row's 'passages' mapping into its own
# column. This adds a column to train_df in place; valid_df is not touched here.
train_df['passage_text'] = train_df['passages'].apply(lambda x: x['passage_text'])


In [29]:
# Inspect the new column: each cell holds a list of passage strings for that row.
train_df['passage_text']

0         [Since 2007, the RBA's outstanding reputation ...
1         [In his younger years, Ronald Reagan was a mem...
2         [Sydney, New South Wales, Australia is located...
3         [In regards to tile installation costs, consum...
4         [Conclusions: In adult body CT, dose to an org...
                                ...                        
891052    [Complete a Direct Dispute Form if you have in...
891053    [Most of the city of Hitchcock is served by th...
891054    [And with no waiting period, the treatment is ...
891055    [Isaac Bell is the son of Ebenezer Bell and gr...
891056    [ANIMALS. ANIMALS. The marine biome covers thr...
Name: passage_text, Length: 891057, dtype: object

In [30]:
# Flatten the per-row passage lists into one flat list of passage strings.
passage_texts = [
    passage
    for passages_for_query in train_df['passage_text']
    for passage in passages_for_query
]
# Training corpus = every query followed by every individual passage.
train_corpus = [*train_df['query'].tolist(), *passage_texts]

In [13]:
# Compute word statistics over the whole training corpus (queries + passages),
# then derive the per-word SIF weights from those statistics.
word_freq = compute_word_frequencies(train_corpus)
sif_weights = compute_sif_weights(word_freq)

In [24]:
# Sanity check: the weight table and the frequency table should have the same size.
len(sif_weights), len(word_freq)

(6519006, 6519006)

In [None]:
# Embed the training queries and passages with SIF-weighted GloVe vectors.
#
# Fixes relative to the previous version of this cell:
#   * The column created earlier is 'passage_text' (singular); indexing
#     'passage_texts' raises KeyError.
#   * Each 'passage_text' cell is a list of passages, not a string. The passages
#     of a row are joined into one text so the helper receives plain strings,
#     exactly like the query call, and so there is one passage embedding per
#     query row (needed later when the two matrices are concatenated column-wise).
# NOTE(review): assumes compute_sif_embeddings takes a list of raw strings and that
# joining a row's passages is the intended pairing — confirm against sif_src.sif.
train_queries_sif = compute_sif_embeddings(train_df['query'].tolist(), glove_vectors, sif_weights)
train_passage_docs = train_df['passage_text'].apply(lambda texts: " ".join(texts)).tolist()
train_passages_sif = compute_sif_embeddings(train_passage_docs, glove_vectors, sif_weights)

ValueError: Length of weights not compatible with specified axis.

In [None]:
# Embed the validation queries and passages with the SIF weights learned on train.
#
# Fixes relative to the previous version of this cell:
#   * `passage_texts` is a plain Python list, so `.tolist()` raises AttributeError.
#   * `passage_texts` holds the *training* passages; validation passages must come
#     from valid_df, otherwise the validation features are built from training data.
# The passages of each row are joined into one text so there is one passage
# embedding per validation query row.
# NOTE(review): assumes valid_df['passages'] has the same {'passage_text': [...]}
# layout as train_df and that compute_sif_embeddings takes raw strings — confirm.
valid_queries_sif = compute_sif_embeddings(valid_df['query'].tolist(), glove_vectors, sif_weights)
valid_passage_docs = valid_df['passages'].apply(
    lambda passages: " ".join(passages['passage_text'])
).tolist()
valid_passages_sif = compute_sif_embeddings(valid_passage_docs, glove_vectors, sif_weights)

In [None]:
# Post-process every embedding matrix with remove_pc.
# NOTE(review): presumably this is the common-component removal step of SIF and is
# fitted independently on each matrix — confirm in sif_src.sif.
# WARNING: each variable is overwritten with its own processed version, so this cell
# is not idempotent: re-running it applies remove_pc to already-processed embeddings.
# Re-run the embedding cells first if this cell has to be executed again.
train_queries_sif = remove_pc(train_queries_sif)
train_passages_sif = remove_pc(train_passages_sif)
valid_queries_sif = remove_pc(valid_queries_sif)
valid_passages_sif = remove_pc(valid_passages_sif)

In [None]:
import jax
import jax.numpy as jnp
from jax import grad, jit

# Combine query and passage embeddings
# The two matrices are concatenated column-wise (axis=1), so the query and passage
# matrices must have the same number of rows (one passage embedding per query).
X_train_sif = jnp.concatenate([train_queries_sif, train_passages_sif], axis=1)
X_valid_sif = jnp.concatenate([valid_queries_sif, valid_passages_sif], axis=1)
# Regression targets.
# NOTE(review): assumes both frames carry a 'label' column; it is not created in the
# cells shown here — confirm it exists in the pickled frames.
y_train = jnp.array(train_df['label'].values)
y_valid = jnp.array(valid_df['label'].values)

# Define linear regression functions
def predict(params, inputs):
    """Return the linear model's output: the dot product of `inputs` with `params`.

    There is no bias term; `params` is the weight vector applied to each input row.
    """
    outputs = jnp.dot(inputs, params)
    return outputs

def loss_fn(params, inputs, targets):
    """Mean squared error between the model's predictions and `targets`."""
    residuals = predict(params, inputs) - targets
    return jnp.mean(residuals ** 2)

@jit
def update(params, inputs, targets, learning_rate=0.01):
    """Perform one full-batch gradient-descent step on the MSE loss.

    Args:
        params: current weight vector.
        inputs: feature matrix the loss is evaluated on.
        targets: regression targets aligned with the rows of `inputs`.
        learning_rate: step size applied to the gradient.

    Returns:
        The updated weight vector, `params - learning_rate * d(loss)/d(params)`.

    The function is compiled with `jit`, so the gradient computation is traced once
    per input shape/dtype and reused on later calls instead of being rebuilt and run
    op-by-op on every epoch. Because the trace is cached, re-run this cell after
    redefining `loss_fn` so the compiled step picks up the new definition.
    """
    grads = grad(loss_fn)(params, inputs, targets)
    return params - learning_rate * grads

# Train the model
def train_model(X_train, y_train, num_epochs=100, learning_rate=0.01):
    """Fit the linear model with full-batch gradient descent.

    Weights start at zero (one per column of `X_train`) and `update` is applied
    `num_epochs` times on the whole training set. Returns the final weight vector.
    """
    n_features = X_train.shape[1]
    weights = jnp.zeros(n_features)
    for _ in range(num_epochs):
        weights = update(weights, X_train, y_train, learning_rate)
    return weights




In [None]:
# Fit the linear model on the SIF features using train_model's default
# settings (100 epochs, learning rate 0.01).
params_sif = train_model(X_train_sif, y_train)

In [None]:
# Evaluate the model
def evaluate_model(params, X_valid, y_valid):
    """Return the mean squared error of the model's predictions on the validation set."""
    errors = predict(params, X_valid) - y_valid
    return jnp.mean(errors ** 2)

# Score the fitted weights on the held-out validation features and report the MSE.
mse_sif = evaluate_model(params_sif, X_valid_sif, y_valid)
print(f"SIF MSE: {mse_sif}")