In [1]:
import torch
import sys
import os
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import wandb

In [2]:
# Make the project's `src` package importable.
# The root defaults to the current working directory. Later cells already save
# checkpoints relative to the cwd, so the notebook is expected to run from the
# project root. Set TWOTOWER_PROJECT_ROOT to override the root instead of
# hard-coding a user-specific absolute path.
project_root = os.path.abspath(os.environ.get("TWOTOWER_PROJECT_ROOT", os.getcwd()))

if project_root not in sys.path:
    sys.path.append(project_root)

print("Project root:", project_root)
print("Current working directory:", os.getcwd())

Project root: C:\Users\kaleb\Code\TwoTowerSearch
Current working directory: c:\Users\kaleb\Code\TwoTowerSearch


In [3]:
from src.dev.RNN_encoder import RNNEncoder

In [4]:
# Seed PyTorch's global RNG before the first random draw in the notebook.
# Re-running the cells top to bottom then draws the same random tensors each time.
SEED = 42
torch.manual_seed(SEED)

# Random stand-in embedding table with shape (100038, 100).
# NOTE(review): not referenced by the later cells shown here; confirm before removing.
lookup_embeddings = torch.randn(100038, 100)

In [5]:
# Synthetic embedding pools used as model input.
# Each row is one 100-dimensional vector.
# One training batch pairs a single query with 10 documents: 5 relevant and 5 irrelevant.
query_embeddings = torch.randn(5000, 100)  # drawn before the docs, as the RNG order matters
doc_embeddings = torch.randn(50000, 100)

batch_size = 5 + 5  # relevant + irrelevant documents per query

In [6]:
def triplet_loss_function(query, relevant_documents, irrelevant_documents, margin):
    """Mean hinge (triplet) loss using cosine distance.

    For each row i, the loss term is
    ``max(0, d(query, relevant_i) - d(query, irrelevant_i) + margin)``,
    where ``d = 1 - cosine_similarity`` is computed along dim=1.
    The result is the mean over all rows. Relevant and irrelevant distances are
    paired element-wise, so both document tensors must yield the same number of rows.

    Args:
        query: query embedding(s), compared row-wise against each document tensor.
        relevant_documents: embeddings that should be close to the query.
        irrelevant_documents: embeddings that should be far from the query.
        margin: minimum desired gap between irrelevant and relevant distances.

    Returns:
        A scalar tensor holding the mean hinge loss.
    """

    def _cosine_distance(documents):
        # 1 - cosine similarity along the feature axis (dim=1).
        return 1 - F.cosine_similarity(query, documents, dim=1)

    positive_gap = _cosine_distance(relevant_documents)
    negative_gap = _cosine_distance(irrelevant_documents)
    hinge_terms = F.relu(positive_gap - negative_gap + margin)
    return hinge_terms.mean()

In [8]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [7]:
# Model and training hyper-parameters.
input_dim, hidden_dim = 100, 128  # embedding width in, RNN hidden size
num_layers, bidirectional = 1, True  # RNN depth / direction settings passed to RNNEncoder
num_epochs = 10
# Inputs are padded/truncated to these token lengths.
query_sequence_length, doc_sequence_length = 30, 400
# Documents per query in one batch.
num_relevant, num_irrelevant = 5, 5

In [10]:
# Start a Weights & Biases run.
# All settings that define the experiment are recorded, not only the
# data-shape settings, so runs with different model or training settings can be
# compared in the dashboard.
# Requires W&B credentials (e.g. `wandb login` or the WANDB_API_KEY environment
# variable); otherwise wandb prompts for an API key interactively.
wandb.init(
    project="TwoTowers",
    name="recovering",
    config={
        "batch_size": num_relevant + num_irrelevant,
        "query_sequence_length": query_sequence_length,
        "doc_sequence_length": doc_sequence_length,
        "embed_dims": input_dim,
        "hidden_dim": hidden_dim,
        "num_layers": num_layers,
        "bidirectional": bidirectional,
        "num_epochs": num_epochs,
        "num_relevant": num_relevant,
        "num_irrelevant": num_irrelevant,
    },
)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
wandb: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
wandb: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
wandb: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

In [9]:
# Two towers with the same architecture, built as separate RNNEncoder instances.
# The first encodes queries and the second encodes documents.
# The generator is consumed in order (query tower first, then document tower),
# so any random initialisation happens in the same sequence.
query_encoder, doc_encoder = (
    RNNEncoder(
        input_dim=input_dim,
        hidden_dim=hidden_dim,
        num_layers=num_layers,
        bidirectional=bidirectional,
    ).to(device)
    for _tower in ("query", "doc")
)

In [10]:
# One Adam optimizer updates the parameters of both towers jointly.
optimizer = optim.Adam(
    [*query_encoder.parameters(), *doc_encoder.parameters()],
    lr=1e-3,
)
margin = 1.0  # hinge margin passed to triplet_loss_function

In [None]:
# Build fixed-length model inputs from the embedding tensors.
# `query_embeddings` / `doc_embeddings` are created as 2-D tensors
# (tokens, input_dim), so they cannot be sliced with three indices.
# `take_sequences` accepts either a 2-D token pool or an already-batched
# 3-D tensor, and always returns (count, seq_len, dim).


def take_sequences(embeddings, start, count, seq_len, dim):
    """Return ``count`` sequences of exactly ``seq_len`` tokens.

    Args:
        embeddings: either
            * a 3-D tensor (num_sequences, tokens, features): rows
              ``start:start + count`` are taken, or
            * a 2-D tensor (tokens, features): treated as one flat pool of token
              vectors, cut into consecutive chunks of ``seq_len`` tokens,
              starting at chunk index ``start``.
        start: index of the first sequence (row for 3-D input, chunk for 2-D).
        count: number of sequences to take.
        seq_len: target length. Longer sequences are truncated; shorter ones
            (or a pool that runs out of tokens) are zero-padded at the end.
        dim: number of leading feature columns to keep.

    Returns:
        Tensor of shape (count, seq_len, features) for 2-D input, with the same
        dtype and device as ``embeddings``. For 3-D input the first axis has at
        most ``count`` rows.

    Raises:
        ValueError: if ``embeddings`` is neither 2-D nor 3-D.
    """
    if embeddings.dim() == 2:
        flat = embeddings[start * seq_len : (start + count) * seq_len, :dim]
        # Pad missing tokens with zeros so the pool folds evenly into sequences.
        flat = F.pad(flat, (0, 0, 0, count * seq_len - flat.shape[0]))
        return flat.reshape(count, seq_len, flat.shape[1])
    if embeddings.dim() == 3:
        batch = embeddings[start : start + count, :seq_len, :dim]
        # F.pad lists pads from the last dim backwards: (feat_l, feat_r, tok_l, tok_r).
        return F.pad(batch, (0, 0, 0, seq_len - batch.shape[1]))
    raise ValueError(
        f"expected a 2-D or 3-D embedding tensor, got {embeddings.dim()}-D"
    )


# Single query, shape (1, query_sequence_length, input_dim)
query_input = take_sequences(
    query_embeddings, 0, 1, query_sequence_length, input_dim
)
# Relevant documents, shape (num_relevant, doc_sequence_length, input_dim)
relevant_doc_input = take_sequences(
    doc_embeddings, 0, num_relevant, doc_sequence_length, input_dim
)
# Irrelevant documents follow the relevant ones,
# shape (num_irrelevant, doc_sequence_length, input_dim)
irrelevant_doc_input = take_sequences(
    doc_embeddings, num_relevant, num_irrelevant, doc_sequence_length, input_dim
)

In [None]:
# Place the three model inputs on the same device as the encoders.
query_input, relevant_doc_input, irrelevant_doc_input = (
    tensor.to(device)
    for tensor in (query_input, relevant_doc_input, irrelevant_doc_input)
)

In [None]:
# Training loop.
# The same inputs (one query, num_relevant relevant docs and num_irrelevant
# irrelevant docs) are reused every epoch, so each epoch is a single optimizer
# step. A checkpoint is written after every epoch.

# triplet_loss_function pairs relevant and irrelevant distances element-wise,
# so the two groups must be the same size.
assert num_relevant == num_irrelevant, (
    "triplet loss pairs relevant and irrelevant documents one-to-one"
)

checkpoint_dir = "src/models/twotowers"
os.makedirs(checkpoint_dir, exist_ok=True)  # torch.save does not create missing directories

for epoch in range(num_epochs):
    # Clear gradients before the forward pass. If an earlier run of this cell
    # was interrupted, stale gradients would otherwise leak into this step.
    optimizer.zero_grad()

    # Encode the query and tile it to one row per document. This assumes the
    # encoder returns one row per input sequence (TODO confirm RNNEncoder output).
    query_embedding = query_encoder(query_input).repeat(
        num_relevant + num_irrelevant, 1
    )
    relevant_doc_embeddings = doc_encoder(relevant_doc_input)
    irrelevant_doc_embeddings = doc_encoder(irrelevant_doc_input)

    # Pass only num_relevant query rows. The full tiled tensor has
    # num_relevant + num_irrelevant rows, which cannot broadcast against the
    # num_relevant-row document batches inside F.cosine_similarity.
    loss = triplet_loss_function(
        query_embedding[:num_relevant],
        relevant_doc_embeddings,
        irrelevant_doc_embeddings,
        margin,
    )

    loss.backward()
    optimizer.step()

    loss_value = loss.item()
    wandb.log({"loss": loss_value, "epoch": epoch})

    # Save both towers and the optimizer state so training can be resumed.
    save_path = f"{checkpoint_dir}/RNN_epoch_{epoch}.pth"
    torch.save(
        {
            "query_encoder_state_dict": query_encoder.state_dict(),
            "doc_encoder_state_dict": doc_encoder.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
        },
        save_path,
    )

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss_value:.4f}")

wandb.finish()
print("Training complete.")

In [None]:
# Shape check on the tensors from the final training epoch.
# Expected, with D = hidden_dim * 2 if bidirectional else hidden_dim:
#   query:      (num_relevant + num_irrelevant, D)
#   relevant:   (num_relevant, D)
#   irrelevant: (num_irrelevant, D)
print("Query Embedding Shape:", query_embedding.shape)
print("Relevant Document Embedding Shape:", relevant_doc_embeddings.shape)
print("Irrelevant Document Embedding Shape:", irrelevant_doc_embeddings.shape)