In [1]:
import torch
import os
from tqdm import tqdm
import gensim
import torch.optim as optim
import torch.nn.functional as F
import wandb
import pandas as pd
from tqdm import tqdm

# from RNN_encoder import RNNEncoder
from src.dev.RNN_encoder import RNNEncoder
from src.utils.training_functions import (
    prepare_dataloader,
    compute_triplet_loss,
    save_model,
)

In [2]:
# Select the compute device: a CUDA GPU if PyTorch can see one, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

device(type='cpu')

In [3]:
# Number of (query, relevant, irrelevant) triplets per DataLoader batch.
batch_size = 1

# Fraction of the index table used for local (CPU) runs; set to 1.0 for a full run.
sample_frac = 0.1
# Fixed seed so the same subsample is drawn on every Restart & Run All.
sample_seed = 42

indexes = pd.read_parquet(os.path.join("data", "processed", "list_of_indexes.parquet"))
indexes = indexes.sample(frac=sample_frac, random_state=sample_seed).reset_index(drop=True)

In [4]:
# # get first row
# indexes = indexes.iloc[0:10]
# indexes.head(1)

Unnamed: 0,query_indexes,relevant_doc_indexes,irrelevant_doc_indexes
0,"[3, 9192, 2208]","[11825, 708, 3, 9192, 995, 9192, 2292, 896, 91...","[145, 41578, 19026, 16509, 632, 1539, 864, 203..."


In [5]:
# # duplicate first 2 rows 100 times
# indexes = pd.concat([indexes] * 1000, ignore_index=True)
# len(indexes)

10000

In [4]:
# Use the batch_size configured in the data-loading cell rather than a hard-coded 1,
# so changing that setting actually changes the DataLoader.
training_data = prepare_dataloader(indexes, batch_size=batch_size, padding_value=0.0)

In [5]:
# creating the embedding lookup
# Load the pretrained Word2Vec model (a text8 model with custom preprocessing, per its filename).
w2v = gensim.models.Word2Vec.load("src/models/word2vec-gensim-text8-custom-preprocess.model")

# Vocabulary in the model's own ordering; word_to_idx maps each word to its position in it.
vocab = w2v.wv.index_to_key
word_to_idx = dict(zip(vocab, range(len(vocab))))

# Copy the model's vector matrix into a torch tensor for use as the encoder's lookup table.
embedding_lookup = torch.tensor(w2v.wv.vectors)

In [6]:
# Embedding width is read from the loaded lookup table instead of being hard-coded,
# so it cannot drift out of sync with the Word2Vec model on disk.
embedding_dim = embedding_lookup.shape[1]
hidden_dim = 128
num_layers = 3
embedding_lookup = embedding_lookup.to(device)

margin = 1  # margin passed to compute_triplet_loss
num_epochs = 2

In [7]:
# Seed PyTorch's global RNG before building the model so weight initialisation is
# reproducible across Restart & Run All.
torch.manual_seed(42)

learning_rate = 0.001

encoder = RNNEncoder(
    embedding_lookup=embedding_lookup,
    hidden_dim=hidden_dim,
    num_layers=num_layers,
).to(device)

# Adam accepts the parameter iterator directly; no need to materialise a list.
optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)

In [8]:
from datetime import datetime

# Hyperparameters attached to the W&B run so it can be compared/reproduced later.
# The learning rate is read back from the optimizer so it always matches what is used.
run_config = {
    "batch_size": batch_size,
    "hidden_dim": hidden_dim,
    "num_layers": num_layers,
    "margin": margin,
    "num_epochs": num_epochs,
    "lr": optimizer.param_groups[0]["lr"],
}

wandb.init(
    project="RNN training",
    name=f"run_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
    config=run_config,
    reinit=True,
)

steps_per_epoch = len(training_data)
total_steps = num_epochs * steps_per_epoch

progress_bar = tqdm(total=total_steps, desc="Training")

# try/finally so an interrupted run (KeyboardInterrupt, OOM, ...) still closes the
# progress bar and finalises the W&B run instead of leaving them dangling.
try:
    for epoch in range(num_epochs):
        encoder.train()

        total_loss = 0.0

        for batch in training_data:
            query, query_lengths, relevant, rel_lengths, irrelevant, irrel_lengths = batch

            query = query.to(device)
            relevant = relevant.to(device)
            irrelevant = irrelevant.to(device)

            # NOTE(review): if RNNEncoder hands these lengths to
            # nn.utils.rnn.pack_padded_sequence, recent PyTorch requires them on the
            # CPU, so moving them to CUDA would fail there — confirm against RNNEncoder.
            query_lengths = query_lengths.to(device)
            rel_lengths = rel_lengths.to(device)
            irrel_lengths = irrel_lengths.to(device)

            query_encoded = encoder(query, query_lengths)
            relevant_encoded = encoder(relevant, rel_lengths)
            irrelevant_encoded = encoder(irrelevant, irrel_lengths)

            loss = compute_triplet_loss(
                query_encoded, relevant_encoded, irrelevant_encoded, margin
            )

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # .item() copies the scalar to the host; do it once per step.
            loss_value = loss.item()
            total_loss += loss_value
            wandb.log({"loss": loss_value})

            progress_bar.update(1)

        # Checkpoint the encoder weights (and optimizer state, for resuming) every epoch.
        # The encoder and optimizer must be passed explicitly — an epoch number alone
        # carries no weights.
        save_path = f"RNN_epoch_{epoch}.pth"
        torch.save(
            {
                "epoch": epoch,
                "model_state_dict": encoder.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
            },
            save_path,
        )

        # max(..., 1) guards against an empty DataLoader (e.g. an over-aggressive sample).
        avg_loss = total_loss / max(steps_per_epoch, 1)
        wandb.log({"epoch": epoch + 1, "avg_loss": avg_loss})
        progress_bar.set_postfix({"Epoch": epoch + 1, "Avg Loss": f"{avg_loss:.4f}"})
finally:
    # Close the progress bar and the W&B run even if training was interrupted.
    progress_bar.close()
    wandb.finish()

print("Training complete.")

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
wandb: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
wandb: Currently logged in as: kaleb-sofer (kaleb-sofer-mlx). Use `wandb login --relogin` to force relogin


Training:   3%|▎         | 3489/130924 [08:18<5:23:06,  6.57it/s]

: 