In [8]:
import torch
import os
import sys

import torch.optim as optim
import torch.nn.functional as F
import wandb
import pandas as pd
import pickle
from torch.utils.data import DataLoader, TensorDataset

In [2]:
from src.dev.RNN_encoder import RNNEncoder

In [3]:
# project_root = r"C:\Users\kaleb\Code\TwoTowerSearch"

# if project_root not in sys.path:
#     sys.path.append(project_root)

# print("Project root:", project_root)
# print("Current working directory:", os.getcwd())

In [4]:
"""
For this script to work, incoming embeddings need to be prepared as follows:
[query_embedding, relevant_embedding, irrelevant_embedding]

e.g. for 1 query with 3 relevant and 3 irrelevant docs:

[
[[query_embedding_1], [relevant_embedding_1], [irrelevant_embedding_1]],
[[query_embedding_1], [relevant_embedding_2], [irrelevant_embedding_2]],
[[query_embedding_1], [relevant_embedding_3], [irrelevant_embedding_3]]
]
"""

'\nFor this script to work, incoming embeddings need to be prepared as follows:\n[query_embedding, relevant_embedding, irrelevant_embedding]\n\ne.g. for 1 query with 3 relevant and 3 irrelevant docs:\n\n[\n[[query_embedding_1], [relevant_embedding_1], [irrelevant_embedding_1]],\n[[query_embedding_1], [relevant_embedding_2], [irrelevant_embedding_2]],\n[[query_embedding_1], [relevant_embedding_3], [irrelevant_embedding_3]]\n]\n'

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cpu')

In [6]:
# # To load the data:
# file_path = os.path.join("data", "processed", "list_of_indexes.pt")

# loaded_data = torch.load(file_path)
# query_indexes = loaded_data["query_"]
# rel_doc_indexes = loaded_data["re"]
# irrel_doc_embeddings = loaded_data["irrel_doc_embeddings"]

In [18]:
from torch.nn.utils.rnn import pad_sequence

batch_size = 10


indexes = pd.read_parquet(os.path.join("data", "processed", "list_of_indexes.parquet"))
indexes = indexes.sample(frac=0.1)  # sample for local training

query_idx = indexes["query_indexes"]
rel_doc_idx = indexes["relevant_doc_indexes"]
irrel_doc_idx = indexes["irrelevant_doc_indexes"]

query_idx = [torch.tensor(seq) for seq in query_idx]
rel_doc_idx = [torch.tensor(seq) for seq in rel_doc_idx]
irrel_doc_idx = [torch.tensor(seq) for seq in irrel_doc_idx]

query_idx = pad_sequence(query_idx, batch_first=True, padding_value=0)
rel_doc_idx = pad_sequence(rel_doc_idx, batch_first=True, padding_value=0)
irrel_doc_idx = pad_sequence(irrel_doc_idx, batch_first=True, padding_value=0)

dataset = TensorDataset(query_idx, rel_doc_idx, irrel_doc_idx)
training_data = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [19]:
# just print first 10 batches
for i, batch in enumerate(training_data):
    if i >= 10:
        break
    query_batch, rel_doc_batch, irrel_doc_batch = batch
    print(query_batch.shape, rel_doc_batch.shape, irrel_doc_batch.shape)

torch.Size([10, 12]) torch.Size([10, 105]) torch.Size([10, 114])
torch.Size([10, 12]) torch.Size([10, 105]) torch.Size([10, 114])
torch.Size([10, 12]) torch.Size([10, 105]) torch.Size([10, 114])
torch.Size([10, 12]) torch.Size([10, 105]) torch.Size([10, 114])
torch.Size([10, 12]) torch.Size([10, 105]) torch.Size([10, 114])
torch.Size([10, 12]) torch.Size([10, 105]) torch.Size([10, 114])
torch.Size([10, 12]) torch.Size([10, 105]) torch.Size([10, 114])
torch.Size([10, 12]) torch.Size([10, 105]) torch.Size([10, 114])
torch.Size([10, 12]) torch.Size([10, 105]) torch.Size([10, 114])
torch.Size([10, 12]) torch.Size([10, 105]) torch.Size([10, 114])


In [10]:
input_dim = 100
hidden_dim = 128
num_layers = 1
bidirectional = False
num_epochs = 3
margin = 1.0

In [11]:
# Create an instance of the RNNEncoder for queries and documents
query_encoder = RNNEncoder(
    input_dim=input_dim,
    hidden_dim=hidden_dim,
    num_layers=num_layers,
    bidirectional=bidirectional,
).to(device)

doc_encoder = RNNEncoder(
    input_dim=input_dim,
    hidden_dim=hidden_dim,
    num_layers=num_layers,
    bidirectional=bidirectional,
).to(device)

In [12]:
# Define optimizer
optimizer = optim.Adam(
    list(query_encoder.parameters()) + list(doc_encoder.parameters()), lr=0.001
)

In [16]:
def compute_triplet_loss(query_encoded, relevant_encoded, irrelevant_encoded, margin):
    relevant_distances = 1 - F.cosine_similarity(query_encoded, relevant_encoded, dim=1)
    irrelevant_distances = 1 - F.cosine_similarity(
        query_encoded, irrelevant_encoded, dim=1
    )

    triplet_loss = F.relu(relevant_distances - irrelevant_distances + margin).mean()
    return triplet_loss


def save_model(epoch, save_path):

    torch.save(
        {
            "query_encoder_state_dict": query_encoder.state_dict(),
            "doc_encoder_state_dict": doc_encoder.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
        },
        save_path,
    )

In [14]:
wandb.init(
    project="TwoTowers",
    name="recovering",
    config={
        "input_dim": input_dim,
        "hidden_dim": hidden_dim,
        "num_layers": num_layers,
        "bidirectional": bidirectional,
        "margin": margin,
    },
)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
wandb: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
wandb: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
wandb: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:wandb: Appending key for api.wandb.ai to your netrc file: C:\Users\kaleb\_netrc


In [19]:
margin = 1.0
num_epochs = 3


for epoch in range(num_epochs):
    query_encoder.train()
    doc_encoder.train()

    total_loss = 0.0

    for batch in training_data:
        query, relevant, irrelevant = batch

        query_encoded = query_encoder(query)
        relevant_encoded = doc_encoder(relevant)
        irrelevant_encoded = doc_encoder(irrelevant)

        # Compute loss

        loss = compute_triplet_loss(
            query_encoded, relevant_encoded, irrelevant_encoded, margin
        )

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        wandb.log({"loss": loss.item()})

    save_path = f"RNN_epoch_{epoch}.pth"
    save_model(epoch, save_path)

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}")


wandb.finish()

print("Training complete.")

Epoch [1/3], Loss: 1.4495
Epoch [2/3], Loss: 0.8880
Epoch [3/3], Loss: 1.0000


0,1
loss,▅▅▅▅▅▅▅▅▅▄▃▅▅▆▄▅▄▅▄▆▂▅▁█▇▆▇▃▅▇▅▆▆▅█▇▃▃▅▅

0,1
loss,1


Training complete.
