In [1]:
import torch
import os
import sys

import torch.optim as optim
import torch.nn.functional as F
import wandb
import pandas as pd
import pickle
from torch.utils.data import DataLoader, TensorDataset

In [2]:
from src.dev.RNN_encoder import RNNEncoder

In [3]:
# project_root = r"C:\Users\kaleb\Code\TwoTowerSearch"

# if project_root not in sys.path:
#     sys.path.append(project_root)

# print("Project root:", project_root)
# print("Current working directory:", os.getcwd())

In [3]:
# Run on the GPU whenever torch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cpu')

In [6]:
# # To load the data:
# file_path = os.path.join("data", "processed", "list_of_indexes.pt")

# loaded_data = torch.load(file_path)
# query_indexes = loaded_data["query_"]
# rel_doc_indexes = loaded_data["re"]
# irrel_doc_embeddings = loaded_data["irrel_doc_embeddings"]

In [4]:
from torch.nn.utils.rnn import pad_sequence

batch_size = 10
SAMPLE_FRAC = 0.1  # fraction of rows kept for local training; use 1.0 for a full run
SEED = 42  # fixes both the row subsample and the DataLoader shuffle order

# Each row holds three variable-length token-index sequences: a query, a relevant
# document and an irrelevant document.
indexes = pd.read_parquet(
    os.path.join("data", "processed", "list_of_indexes.parquet")
).sample(frac=SAMPLE_FRAC, random_state=SEED)


def _pad_index_column(column):
    """Turn a column of variable-length index sequences into one right-padded 2-D tensor."""
    # NOTE(review): padding_value=0 is presumably also the index of vocab[0] in the
    # word2vec lookup (word_to_idx enumerates from 0), so padding may be
    # indistinguishable from a real token unless RNNEncoder masks by length — TODO confirm.
    return pad_sequence(
        [torch.tensor(seq) for seq in column], batch_first=True, padding_value=0
    )


query_idx = _pad_index_column(indexes["query_indexes"])
rel_doc_idx = _pad_index_column(indexes["relevant_doc_indexes"])
irrel_doc_idx = _pad_index_column(indexes["irrelevant_doc_indexes"])

dataset = TensorDataset(query_idx, rel_doc_idx, irrel_doc_idx)
# A seeded generator makes the shuffle order repeatable across Restart & Run All.
training_data = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
)

In [5]:
# Sanity-check the DataLoader: print the tensor shapes of the first few batches.
from itertools import islice

N_PREVIEW_BATCHES = 5
for query_batch, rel_doc_batch, irrel_doc_batch in islice(training_data, N_PREVIEW_BATCHES):
    print(query_batch.shape, rel_doc_batch.shape, irrel_doc_batch.shape)

torch.Size([10, 12]) torch.Size([10, 102]) torch.Size([10, 103])
torch.Size([10, 12]) torch.Size([10, 102]) torch.Size([10, 103])
torch.Size([10, 12]) torch.Size([10, 102]) torch.Size([10, 103])
torch.Size([10, 12]) torch.Size([10, 102]) torch.Size([10, 103])
torch.Size([10, 12]) torch.Size([10, 102]) torch.Size([10, 103])


In [6]:
# Load the pretrained word2vec model and expose its vocabulary, a word -> index
# map, and its embedding matrix as a torch tensor for the encoders.
import gensim

w2v = gensim.models.Word2Vec.load(
    "src/models/word2vec-gensim-text8-custom-preprocess.model"
)

keyed_vectors = w2v.wv
vocab = keyed_vectors.index_to_key
word_to_idx = dict(zip(vocab, range(len(vocab))))
embedding_lookup = torch.tensor(keyed_vectors.vectors)

In [7]:
# Encoder hyper-parameters. The embedding width is read from the pretrained lookup
# table instead of being hard-coded, so it cannot drift from the word2vec model.
embedding_dim = embedding_lookup.shape[1]
hidden_dim = 128
num_layers = 1
embedding_lookup = embedding_lookup.to(device)

In [8]:
# The two towers are built from identical arguments but are separate module instances.
# Whether they end up sharing the embedding weights depends on how RNNEncoder
# wraps embedding_lookup — TODO confirm.
encoder_kwargs = {
    "embedding_lookup": embedding_lookup,
    "hidden_dim": hidden_dim,
    "num_layers": num_layers,
}
query_encoder = RNNEncoder(**encoder_kwargs).to(device)
doc_encoder = RNNEncoder(**encoder_kwargs).to(device)

In [9]:
# A single Adam optimizer over both towers, so they are updated jointly each step.
learning_rate = 0.001
trainable_params = [*query_encoder.parameters(), *doc_encoder.parameters()]
optimizer = optim.Adam(trainable_params, lr=learning_rate)

In [10]:
def compute_triplet_loss(query_encoded, relevant_encoded, irrelevant_encoded, margin):
    """Mean hinge triplet loss using cosine distance (1 - cosine similarity).

    Each row contributes ``max(0, d(q, rel) - d(q, irrel) + margin)``. Similarity is
    taken along dim 1, and the per-row losses are averaged into a scalar tensor.
    """

    def cosine_distance(other):
        return 1 - F.cosine_similarity(query_encoded, other, dim=1)

    hinge = cosine_distance(relevant_encoded) - cosine_distance(irrelevant_encoded) + margin
    return F.relu(hinge).mean()


def save_model(epoch, save_path):
    """Write a training checkpoint for both towers and the optimizer.

    Reads the notebook-level ``query_encoder``, ``doc_encoder`` and ``optimizer``.

    Args:
        epoch: Index of the epoch that just finished. It is stored under the
            ``"epoch"`` key so a resumed run knows where to continue.
        save_path: Destination path passed to ``torch.save``.
    """
    torch.save(
        {
            "epoch": epoch,
            "query_encoder_state_dict": query_encoder.state_dict(),
            "doc_encoder_state_dict": doc_encoder.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
        },
        save_path,
    )

In [11]:
# Training hyper-parameters: number of passes over the data, and the hinge margin
# of the cosine-distance triplet loss.
num_epochs = 3
margin = 1.0

In [12]:
# Start the wandb run. The config records every knob needed to reproduce it.
wandb.init(
    project="TwoTowers",
    name="recovering",
    config={
        "vocab_size": embedding_lookup.shape[0],
        "embedding_dim": embedding_lookup.shape[1],
        "hidden_dim": hidden_dim,
        "num_layers": num_layers,
        "margin": margin,
        # Optimisation settings; before this change they were missing from the run config.
        "num_epochs": num_epochs,
        "batch_size": batch_size,
        "learning_rate": optimizer.param_groups[0]["lr"],
    },
)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
wandb: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
wandb: Currently logged in as: kaleb-sofer (kaleb-sofer-mlx). Use `wandb login --relogin` to force relogin


In [13]:
# Train both towers jointly with the triplet loss, checkpointing after every epoch.
# wandb.finish() sits in `finally` so the run is closed even if training raises
# or is interrupted.
try:
    for epoch in range(num_epochs):
        query_encoder.train()
        doc_encoder.train()

        total_loss = 0.0
        num_batches = 0

        for query, relevant, irrelevant in training_data:
            query = query.to(device)
            relevant = relevant.to(device)
            irrelevant = irrelevant.to(device)

            query_encoded = query_encoder(query)
            relevant_encoded = doc_encoder(relevant)
            irrelevant_encoded = doc_encoder(irrelevant)

            loss = compute_triplet_loss(
                query_encoded, relevant_encoded, irrelevant_encoded, margin
            )

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # .item() forces a host/device sync; call it once per step and reuse the value.
            batch_loss = loss.item()
            total_loss += batch_loss
            num_batches += 1
            wandb.log({"loss": batch_loss})

        # Checkpoint names use the 0-based epoch index; the printout below is 1-based.
        save_path = f"RNN_epoch_{epoch}.pth"
        save_model(epoch, save_path)

        # max(..., 1) avoids ZeroDivisionError when the DataLoader yields no batches.
        avg_loss = total_loss / max(num_batches, 1)
        wandb.log({"epoch": epoch, "avg_loss": avg_loss})
        print(f"Epoch [{epoch+1}/{num_epochs}], Avg Loss: {avg_loss:.4f}")
finally:
    wandb.finish()

print("Training complete.")