In [12]:
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from torchtext.data.utils import get_tokenizer
from datasets import load_dataset
from collections import defaultdict
import pickle
import json
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# # -----------------------------
# # 1. Load MS MARCO V1.1 training dataset
# # -----------------------------

# # This will stream the data, you don't have to download the full file
# dataset = load_dataset("microsoft/ms_marco", "v1.1", split="train")  # or "validation"

In [None]:
# # -----------------------------
# # 2. Create triples of (query, relevant_docs, irrelevant_docs)
# # -----------------------------

# # 1. Build passage pool and index mapping for fast sampling and seed for reproducibility
# passage_to_idx = dict()
# idx_to_passage = []
# for row in tqdm(dataset, desc="Building passage pool..."):
#     for p in row['passages']['passage_text']:
#         if p not in passage_to_idx:
#             passage_to_idx[p] = len(idx_to_passage)
#             idx_to_passage.append(p)
# num_passages = len(idx_to_passage)

# # 2. For each query, map relevant passage indices
# triples = []
# for row in tqdm(dataset, desc="Creating triples..."):
#     query = row['query']
#     relevant_passages = row['passages']['passage_text'][:10]
#     relevant_indices = [passage_to_idx[p] for p in relevant_passages]
    
#     # For fast sampling: mask out relevant indices
#     mask = np.ones(num_passages, dtype=bool)
#     mask[relevant_indices] = False
#     irrelevant_indices = np.random.choice(np.where(mask)[0], 10, replace=False)
#     irrelevant_passages = [idx_to_passage[i] for i in irrelevant_indices]

#     triples.append((query, relevant_passages, irrelevant_passages))

# with open("triples_full.pkl", "wb") as f:
#     pickle.dump(triples, f)

# print(triples[0][0])  # query
# print(triples[0][1])  # 10 relevant docs
# print(triples[0][2])  # 10 irrelevant docs

Building passage pool...: 100%|██████████| 82326/82326 [00:23<00:00, 3508.72it/s]
Creating triples...: 100%|██████████| 82326/82326 [25:29<00:00, 53.82it/s]  


what is rba
["Since 2007, the RBA's outstanding reputation has been affected by the 'Securency' or NPA scandal. These RBA subsidiaries were involved in bribing overseas officials so that Australia might win lucrative note-printing contracts. The assets of the bank include the gold and foreign exchange reserves of Australia, which is estimated to have a net worth of A$101 billion. Nearly 94% of the RBA's employees work at its headquarters in Sydney, New South Wales and at the Business Resumption Site.", "The Reserve Bank of Australia (RBA) came into being on 14 January 1960 as Australia 's central bank and banknote issuing authority, when the Reserve Bank Act 1959 removed the central banking functions from the Commonwealth Bank. The assets of the bank include the gold and foreign exchange reserves of Australia, which is estimated to have a net worth of A$101 billion. Nearly 94% of the RBA's employees work at its headquarters in Sydney, New South Wales and at the Business Resumption Site

In [4]:
# -------------------------------
# 2b. Load triples from file
# -------------------------------

# Load .pkl
# Restore the list of triples from "triples_full.pkl" (the file name the
# commented-out section 2 cell dumps its `triples` list to).
# NOTE(review): pickle.load can execute arbitrary code while deserializing, so
# only load a triples_full.pkl that was produced locally or by a trusted source.
with open("triples_full.pkl", "rb") as f:
    triples = pickle.load(f)


In [None]:
# -------------------------------
# 3. Embed queries and documents using a pre-trained SentenceTransformer model
# -------------------------------

# Run the encoder on the GPU when one is visible, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Number of relevant/irrelevant documents to consider per query.
num_docs = 10

# Tokenizer and encoder are loaded from the same pre-trained checkpoint so the
# token ids produced by one are valid inputs for the other.
checkpoint_name = 'sentence-transformers/msmarco-MiniLM-L6-v3'
tokenizer = AutoTokenizer.from_pretrained(checkpoint_name)
model = AutoModel.from_pretrained(checkpoint_name)
model = model.to(device)

# Width of the per-token vectors the encoder emits (original note: 384 for MiniLM).
embed_dim = model.config.hidden_size

def tokenize_and_embed(texts, batch_size=64, show_progress=True):
    """Embed each text as a sequence of per-token vectors, without padding rows.

    Relies on the module-level ``tokenizer``, ``model`` and ``device``.

    Args:
        texts: list of strings to embed.
        batch_size: number of texts tokenized and encoded per forward pass.
        show_progress: wrap the batch loop in a tqdm progress bar.

    Returns:
        (token_embeddings, lengths) where ``token_embeddings[i]`` is a CPU
        tensor of shape ``(lengths[i], hidden)`` holding only the positions the
        tokenizer's attention mask marks as real tokens, and ``lengths[i]`` is
        that token count.
    """
    token_embeddings = []
    lengths = []
    total = len(texts)
    starts = range(0, total, batch_size)
    if show_progress:
        starts = tqdm(starts, desc="Embedding", total=(total + batch_size - 1) // batch_size)
    for start in starts:
        batch_texts = texts[start:start + batch_size]
        enc = tokenizer(batch_texts, padding=True, truncation=True, return_tensors="pt").to(device)
        with torch.no_grad():
            output = model(**enc).last_hidden_state  # (batch, seq_len, embed_dim)
        # Move the batch to CPU once, then cut every sequence down to its real
        # tokens. Downstream code treats shape[0] as the sequence length, so
        # keeping the batch padding would feed padding vectors to the RNN.
        hidden = output.cpu()
        real_token_mask = enc['attention_mask'].cpu().bool()
        for row, row_mask in zip(hidden, real_token_mask):
            # Boolean indexing copies, so the padded batch tensor can be freed.
            token_embeddings.append(row[row_mask])
            lengths.append(int(row_mask.sum()))
    return token_embeddings, lengths

def _regroup_by_query(flat_embeds, group_sizes, docs_per_query):
    """Split a flat embedding list back into one list per query.

    Each group is padded with ``torch.zeros(1, embed_dim)`` placeholders up to
    ``docs_per_query`` entries, so every query exposes the same number of
    document slots. All-zero documents are what the training loop filters out.
    """
    grouped = []
    cursor = 0
    for size in group_sizes:
        docs = list(flat_embeds[cursor:cursor + size])
        cursor += size
        while len(docs) < docs_per_query:
            docs.append(torch.zeros(1, embed_dim))
        grouped.append(docs)
    return grouped


# Split every (query, relevant_docs, irrelevant_docs) triple into parallel
# lists, keeping at most num_docs documents per query.
query_texts = [query for query, _, _ in triples]
rel_doc_texts = [rels[:num_docs] for _, rels, _ in triples]
irrel_doc_texts = [irrels[:num_docs] for _, _, irrels in triples]

# Flatten the lists of lists to a single list of strings so documents can be
# embedded in large batches.
rel_doc_texts_flat = [doc for docs in rel_doc_texts for doc in docs]
irrel_doc_texts_flat = [doc for docs in irrel_doc_texts for doc in docs]

# NOTE(review): all token-level embeddings are held in memory at once; confirm
# that is feasible for the full triples file before running on all of it.
query_embeds, query_lens = tokenize_and_embed(query_texts, batch_size=64)
rel_doc_embeds_flat, rel_doc_lens = tokenize_and_embed(rel_doc_texts_flat, batch_size=64)
irrel_doc_embeds_flat, irrel_doc_lens = tokenize_and_embed(irrel_doc_texts_flat, batch_size=64)

# TripleDataset indexes documents as X_rels[triple_idx][doc_idx], so restore
# the per-query grouping that flattening removed.
rel_doc_embeds = _regroup_by_query(
    rel_doc_embeds_flat, [len(docs) for docs in rel_doc_texts], num_docs
)
irrel_doc_embeds = _regroup_by_query(
    irrel_doc_embeds_flat, [len(docs) for docs in irrel_doc_texts], num_docs
)


Embedding:   0%|          | 26/10566 [01:03<7:09:15,  2.44s/it] 


KeyboardInterrupt: 

In [None]:
# -----------------------------
# 7. Embed Queries, Relevant Documents and Irrelevant Documents (randomly sampled)
# -----------------------------

# NOTE(review): this cell reads `tokenized_triples`, `word_to_ix` and
# `cbow_model`, none of which are defined in the visible cells; it appears to
# belong to an earlier CBOW-based pipeline. Confirm before running.

def _embed_tokens(tokens):
    """Return CBOW vectors for the in-vocabulary tokens, shape (n_tokens, embed_dim).

    Falls back to a single all-zero row when no token is in the vocabulary.
    """
    ids = [word_to_ix[t] for t in tokens if t in word_to_ix]
    if not ids:
        return torch.zeros(1, embed_dim)
    with torch.no_grad():
        return cbow_model.embeddings(torch.tensor(ids))


def _embed_doc_list(tokenized_docs, n_docs):
    """Embed the first n_docs documents and pad the list with zero rows to n_docs."""
    embs = [_embed_tokens(doc_tokens) for doc_tokens in tokenized_docs[:n_docs]]
    while len(embs) < n_docs:
        embs.append(torch.zeros(1, embed_dim))
    return embs


query_embeddings = []
relevant_doc_embeddings = []
irrelevant_doc_embeddings = []

num_docs = 10  # Number of relevant/irrelevant documents to consider

for i, (tokenized_query, tokenized_rels, tokenized_irrels) in enumerate(
    tqdm(tokenized_triples, desc="Embedding all triples")
):
    try:
        q_vecs = _embed_tokens(tokenized_query)
        rel_embs = _embed_doc_list(tokenized_rels, num_docs)
        irrel_embs = _embed_doc_list(tokenized_irrels, num_docs)
    except Exception as e:
        print(f"Error at index {i}: {e}")
        break  # or continue, depending on preference
    # Append only after the whole triple embedded successfully, so the three
    # lists stay index-aligned even when the loop stops on an error.
    query_embeddings.append(q_vecs)
    relevant_doc_embeddings.append(rel_embs)
    irrelevant_doc_embeddings.append(irrel_embs)


print(len(query_embeddings))        # Should be N
print(len(relevant_doc_embeddings)) # Should be N
print(len(irrelevant_doc_embeddings)) # Should be N
print(query_embeddings[0].shape)    # (seq_len, embed_dim)
print(relevant_doc_embeddings[0][0].shape) # (doc_seq_len, embed_dim)

Embedding all triples:  62%|██████▏   | 51289/82326 [10:41<03:10, 162.51it/s]  

In [20]:
# -----------------------------
# 4. Define distance functions & Triplet loss function
# -----------------------------

# Cosine similarity for calculation of cosine distance
def cosine_similarity(x, y):
    """Cosine similarity between x and y, reduced over dim=1."""
    similarity = F.cosine_similarity(x, y, dim=1)
    return similarity

# Cosine distance - smallest value means most similar
def cosine_distance(x, y):
    """One minus the cosine similarity of x and y (reduced over dim=1)."""
    similarity = cosine_similarity(x, y)
    return 1 - similarity

# Euclidean (L2) - Smallest value means most similar
def euclidean_distance(x, y):
    """L2 norm of (x - y), reduced over dim=1.

    Uses torch.linalg.vector_norm because torch.norm is deprecated.
    """
    return torch.linalg.vector_norm(x - y, ord=2, dim=1)

# Squared Euclidean - Smallest value means most similar
def squared_euclidean_distance(x, y):
    """Sum of squared element-wise differences between x and y over dim=1."""
    difference = x - y
    return difference.pow(2).sum(dim=1)

# Manhattan (L1) - Smallest value means most similar
def manhattan_distance(x, y):
    """L1 norm of (x - y), reduced over dim=1.

    Uses torch.linalg.vector_norm because torch.norm is deprecated.
    """
    return torch.linalg.vector_norm(x - y, ord=1, dim=1)

# Chebyshev (L-infinity) - Smallest value means most similar
def chebyshev_distance(x, y):
    """Largest absolute element-wise difference between x and y over dim=1."""
    abs_difference = (x - y).abs()
    return abs_difference.max(dim=1).values

# Minkowski - Smallest value means most similar
def minkowski_distance(x, y, p=3):
    """p-norm of (x - y), reduced over dim=1 (p defaults to 3).

    Uses torch.linalg.vector_norm because torch.norm is deprecated.
    """
    return torch.linalg.vector_norm(x - y, ord=p, dim=1)

# Triplet loss function - will compute the loss for a batch of triplets
def triplet_loss_function(query, relevant_doc, irrelevant_doc, distance_function, margin):
    """Mean hinge loss over a batch of (query, relevant, irrelevant) triplets.

    For each triplet the loss is relu(d(query, relevant) - d(query, irrelevant)
    + margin), where d is ``distance_function``; the result is averaged over
    the batch.
    """
    positive_distance = distance_function(query, relevant_doc)
    negative_distance = distance_function(query, irrelevant_doc)
    # Zero loss once the relevant doc is closer than the irrelevant one by at least margin.
    hinge = (positive_distance - negative_distance + margin).relu()
    return hinge.mean()

In [21]:
# -----------------------------
# 5. Define Two Tower Model (QueryTower and DocTower)
# -----------------------------

class _RNNTower(nn.Module):
    """Shared implementation of the query and document towers.

    Encodes a padded batch of token-embedding sequences with a GRU or LSTM and
    returns the final hidden state of the last layer for each sequence.
    """

    def __init__(self, embed_dim, hidden_dim, num_layers=1, rnn_type='gru'):
        """Build the recurrent encoder.

        Args:
            embed_dim: size of each input token vector.
            hidden_dim: size of the recurrent hidden state (the output size).
            num_layers: number of stacked recurrent layers.
            rnn_type: 'gru' or 'lstm'.

        Raises:
            ValueError: if rnn_type is neither 'gru' nor 'lstm'.
        """
        super().__init__()
        if rnn_type == 'gru':
            self.rnn = nn.GRU(embed_dim, hidden_dim, num_layers, batch_first=True)
        elif rnn_type == 'lstm':
            self.rnn = nn.LSTM(embed_dim, hidden_dim, num_layers, batch_first=True)
        else:
            raise ValueError("Unknown rnn_type: choose 'gru' or 'lstm'")

    def forward(self, x, lengths):
        """Encode a batch of sequences.

        Args:
            x: (batch, seq_len, embed_dim) padded input.
            lengths: per-sequence true lengths, as a list of ints or a 1-D tensor.

        Returns:
            (batch, hidden_dim) final hidden state of the last layer.
        """
        # pack_padded_sequence only accepts a CPU int64 length tensor, even when
        # x lives on the GPU, so normalize tensor lengths here.
        if torch.is_tensor(lengths):
            lengths = lengths.detach().to(device='cpu', dtype=torch.int64)
        packed = nn.utils.rnn.pack_padded_sequence(
            x, lengths, batch_first=True, enforce_sorted=False
        )
        _, h = self.rnn(packed)
        if isinstance(h, tuple):  # LSTM returns (h_n, c_n)
            h = h[0]
        return h[-1]  # last layer: (batch, hidden_dim)


class QueryTower(_RNNTower):
    """Recurrent encoder applied to query token sequences."""


class DocTower(_RNNTower):
    """Recurrent encoder applied to document token sequences."""


In [22]:
# -----------------------------
# 6. Prepare dataset for DataLoader
# -----------------------------

class TripleDataset(Dataset):
    """Flattened view over (query, relevant doc, irrelevant doc) triplets.

    Every query contributes ``docs_per_query`` items; item ``idx`` pairs query
    ``idx // docs_per_query`` with the relevant and irrelevant documents at
    position ``idx % docs_per_query``.
    """

    def __init__(self, X_queries, X_rels, X_irrels, docs_per_query=10):
        """Store the embedding lists.

        Args:
            X_queries: list of N tensors (seq_len, embed_dim).
            X_rels: list of N lists of docs_per_query tensors (seq_len, embed_dim).
            X_irrels: list of N lists of docs_per_query tensors (seq_len, embed_dim).
            docs_per_query: number of document slots per query (default 10).

        Raises:
            ValueError: if the three lists do not have the same length, which
                usually means the document lists were passed flattened.
        """
        if not (len(X_queries) == len(X_rels) == len(X_irrels)):
            raise ValueError(
                "X_queries, X_rels and X_irrels must have one entry per query "
                f"(got {len(X_queries)}, {len(X_rels)}, {len(X_irrels)})"
            )
        self.X_queries = X_queries
        self.X_rels = X_rels
        self.X_irrels = X_irrels
        self.docs_per_query = docs_per_query
        self.n = len(X_queries)

    def __len__(self):
        """Total number of (query, relevant, irrelevant) items."""
        return self.n * self.docs_per_query

    def __getitem__(self, idx):
        """Return (query, relevant_doc, irrelevant_doc) embeddings for item idx."""
        triple_idx, doc_idx = divmod(idx, self.docs_per_query)

        qry = self.X_queries[triple_idx]              # (q_seq_len, embed_dim)
        rel = self.X_rels[triple_idx][doc_idx]        # (rel_seq_len, embed_dim)
        irrel = self.X_irrels[triple_idx][doc_idx]    # (irrel_seq_len, embed_dim)

        return qry, rel, irrel
    
def collate_fn(batch):
    """Pad a list of (query, relevant, irrelevant) sequence triples into batch tensors.

    Returns a 6-tuple: the three zero-padded, batch-first tensors (queries,
    relevant docs, irrelevant docs) followed by three lists holding each
    sequence's length (its shape[0]) before padding.
    """
    groups = tuple(zip(*batch))
    queries, rel_docs, irrel_docs = groups
    lengths = [[seq.shape[0] for seq in group] for group in (queries, rel_docs, irrel_docs)]
    padded = [pad_sequence(group, batch_first=True) for group in (queries, rel_docs, irrel_docs)]
    return (*padded, *lengths)


# Create the dataset and dataloader for batching
# Wrap the embeddings in a dataset and serve them in shuffled mini-batches,
# padding each batch with collate_fn.
batch_size = 32
dataset = TripleDataset(
    X_queries=query_embeds,
    X_rels=rel_doc_embeds,
    X_irrels=irrel_doc_embeds,
)
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn,
)


NameError: name 'rel_doc_embeds' is not defined

In [23]:
# -----------------------------
# 7. Train the model (change hyperparameters as needed)
# -----------------------------

# embed_dim defined in section 3.

# Define hyperparameters
hidden_dim = 128  # Dimension of the hidden state in RNNs (GRU/LSTM - can be adjusted)
margin = 0.2  # Margin for triplet loss (can be adjusted)
distance_function = cosine_distance  # Choose distance function (cosine_distance, euclidean_distance, manhattan_distance, squared_euclidean_distance, chebyshev_distance, minkowski_distance)

# Train on the same device the section 3 encoder uses; batches are moved there
# inside the loop.
qry_tower = QueryTower(embed_dim, hidden_dim).to(device)
doc_tower = DocTower(embed_dim, hidden_dim).to(device)

optimizer = torch.optim.Adam(list(qry_tower.parameters()) + list(doc_tower.parameters()), lr=1e-3)
num_epochs = 5  # Set as needed

for epoch in range(num_epochs):
    epoch_loss = 0.0
    batches_used = 0  # batches that actually contributed a loss value
    for batch in tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        qry_embeds, rel_embeds, irrel_embeds, q_lens, r_lens, i_lens = batch
        # qry_embeds: (batch, q_seq_len, embed_dim)
        # rel_embeds: (batch, r_seq_len, embed_dim)
        # irrel_embeds: (batch, i_seq_len, embed_dim)

        # Keep items whose relevant doc is *not* all zeros (all-zero docs are
        # placeholders). keep shape: (batch,)
        keep = rel_embeds.flatten(start_dim=1).ne(0).any(dim=1)

        # Nothing trainable in this batch
        if not keep.any():
            continue

        keep_flags = keep.tolist()
        qry_embeds   = qry_embeds[keep].to(device)
        rel_embeds   = rel_embeds[keep].to(device)
        irrel_embeds = irrel_embeds[keep].to(device)
        # Lengths stay as plain Python ints on the CPU, as pack_padded_sequence requires.
        q_lens = [n for n, flag in zip(q_lens, keep_flags) if flag]
        r_lens = [n for n, flag in zip(r_lens, keep_flags) if flag]
        i_lens = [n for n, flag in zip(i_lens, keep_flags) if flag]

        qry_vecs    = qry_tower(qry_embeds, q_lens)
        rel_vecs    = doc_tower(rel_embeds, r_lens)
        irrel_vecs  = doc_tower(irrel_embeds, i_lens)

        loss = triplet_loss_function(
            qry_vecs, rel_vecs, irrel_vecs,
            distance_function=distance_function,
            margin=margin
        )

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        batches_used += 1

    # Average over the batches that were trained on, not over len(dataloader),
    # so skipped batches do not deflate the reported loss.
    mean_epoch_loss = epoch_loss / max(batches_used, 1)
    print(f"Epoch {epoch+1}/{num_epochs} | Loss: {mean_epoch_loss:.4f}")


# Save final model after all epochs are done
torch.save({
    'epoch': num_epochs,
    'qry_tower_state_dict': qry_tower.state_dict(),
    'doc_tower_state_dict': doc_tower.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': epoch_loss,  # summed batch loss of the last epoch
}, "twotower_final.pt")

NameError: name 'dataloader' is not defined

In [None]:
# # -----------------------------
# # 7. Train the model (change hyperparameters as needed)
# # -----------------------------

# # embed_dim defined in section 3.

# # Define hyperparameters
# hidden_dim = 128  # Dimension of the hidden state in RNNs (GRU/LSTM - can be adjusted)
# margin = 0.2  # Margin for triplet loss (can be adjusted)
# distance_function = cosine_distance  # Choose distance function (cosine_distance, euclidean_distance, manhattan_distance, squared_euclidean_distance, chebyshev_distance, minkowski_distance)

# qry_tower = QueryTower(embed_dim, hidden_dim)
# doc_tower = DocTower(embed_dim, hidden_dim)

# optimizer = torch.optim.Adam(list(qry_tower.parameters()) + list(doc_tower.parameters()), lr=1e-3)
# num_epochs = 10  # Set as needed

# for epoch in range(num_epochs):
#     epoch_loss = 0.0
#     for batch in tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}"):
#         qry_embeds, rel_embeds, irrel_embeds, q_lens, r_lens, i_lens = batch

#         qry_vecs    = qry_tower(qry_embeds, q_lens)
#         rel_vecs    = doc_tower(rel_embeds, r_lens)
#         irrel_vecs  = doc_tower(irrel_embeds, i_lens)

#         loss = triplet_loss_function(
#             qry_vecs, rel_vecs, irrel_vecs,
#             distance_function=distance_function,
#             margin=margin
#         )

#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()

#         epoch_loss += loss.item()

#     print(f"Epoch {epoch+1}/{num_epochs} | Loss: {epoch_loss / len(dataloader):.4f}")


# # Save final model after all epochs are done
# torch.save({
#     'epoch': num_epochs,
#     'qry_tower_state_dict': qry_tower.state_dict(),
#     'doc_tower_state_dict': doc_tower.state_dict(),
#     'optimizer_state_dict': optimizer.state_dict(),
#     'loss': epoch_loss,  # from last epoch
# }, "twotower_final.pt")


Epoch 1/10: 100%|██████████| 313/313 [01:36<00:00,  3.24it/s]


Epoch 1/10 | Loss: 0.1173


Epoch 2/10: 100%|██████████| 313/313 [01:46<00:00,  2.93it/s]


Epoch 2/10 | Loss: 0.0490


Epoch 3/10: 100%|██████████| 313/313 [01:38<00:00,  3.16it/s]


Epoch 3/10 | Loss: 0.0368


Epoch 4/10: 100%|██████████| 313/313 [02:02<00:00,  2.55it/s]


Epoch 4/10 | Loss: 0.0879


Epoch 5/10: 100%|██████████| 313/313 [02:35<00:00,  2.01it/s]


Epoch 5/10 | Loss: 0.1506


Epoch 6/10: 100%|██████████| 313/313 [02:35<00:00,  2.01it/s]


Epoch 6/10 | Loss: 0.1439


Epoch 7/10: 100%|██████████| 313/313 [02:34<00:00,  2.03it/s]


Epoch 7/10 | Loss: 0.1373


Epoch 8/10: 100%|██████████| 313/313 [01:50<00:00,  2.84it/s]


Epoch 8/10 | Loss: 0.1374


Epoch 9/10: 100%|██████████| 313/313 [02:28<00:00,  2.11it/s]


Epoch 9/10 | Loss: 0.1361


Epoch 10/10: 100%|██████████| 313/313 [02:16<00:00,  2.29it/s]

Epoch 10/10 | Loss: 0.1296





In [None]:
# -----------------------------
# 12. Inference: Encode documents and queries, find top-k relevant docs
# -----------------------------

def encode_documents(doc_tower, all_doc_embeds, all_doc_lens, device='cpu', batch_size=128):
    """Encode every document sequence with the document tower.

    Args:
        doc_tower: module called as ``doc_tower(batch_embeds, batch_lens)``.
            It is switched to eval mode and moved to ``device`` in place.
        all_doc_embeds: tensor of padded document sequences, sliced along dim 0.
        all_doc_lens: per-document lengths (list or 1-D tensor), sliced in step
            with ``all_doc_embeds``.
        device: device the tower and each batch are placed on.
        batch_size: number of documents encoded per forward pass.

    Returns:
        CPU tensor with one encoded row per document, in input order.
    """
    # The inputs are moved to `device` below, so the tower must live there too;
    # otherwise the forward pass fails with a device mismatch.
    doc_tower = doc_tower.to(device)
    doc_tower.eval()
    all_vecs = []
    with torch.no_grad():
        for start in range(0, len(all_doc_embeds), batch_size):
            stop = start + batch_size
            batch_embeds = all_doc_embeds[start:stop].to(device)
            batch_lens = all_doc_lens[start:stop]
            if torch.is_tensor(batch_lens):
                # Sequence lengths must be on the CPU for pack_padded_sequence.
                batch_lens = batch_lens.cpu()
            vecs = doc_tower(batch_embeds, batch_lens)  # Shape: (batch, hidden_dim)
            all_vecs.append(vecs.cpu())
    return torch.cat(all_vecs, dim=0)  # Shape: (num_docs, hidden_dim)

def encode_query(qry_tower, query_embed, query_len, device='cpu'):
    """Encode one query sequence with the query tower.

    Args:
        qry_tower: module called as ``qry_tower(query_embed, query_len)``. It
            is switched to eval mode and moved to ``device`` in place.
        query_embed: query token embeddings; moved to ``device``.
        query_len: query length(s), as a list or tensor.
        device: device the tower and the query are placed on.

    Returns:
        The tower output moved to the CPU.
    """
    # Keep the tower and its input on the same device.
    qry_tower = qry_tower.to(device)
    qry_tower.eval()
    if torch.is_tensor(query_len):
        # Sequence lengths must be on the CPU for pack_padded_sequence,
        # whatever device the embeddings are on.
        query_len = query_len.cpu()
    with torch.no_grad():
        query_vec = qry_tower(query_embed.to(device), query_len)
    return query_vec.cpu()  # Shape: (1, hidden_dim)

def find_top_k(query_vec, doc_vecs, k=5, distance_fn=None):
    """Return the k documents closest to the query.

    Args:
        query_vec: (1, hidden_dim) or (hidden_dim,) query vector; a
            (num_docs, hidden_dim) tensor is also accepted and compared row by row.
        doc_vecs: (num_docs, hidden_dim) document vectors.
        k: number of matches wanted; capped at the number of documents.
        distance_fn: callable (q, d) -> (num_docs,) distances, smaller meaning
            more similar. Defaults to cosine distance.

    Returns:
        (indices, scores) as numpy arrays, ordered from closest to farthest;
        ``scores`` holds the distances of the selected documents.
    """
    if distance_fn is None:
        # Default to cosine distance
        def distance_fn(q, d):
            return 1 - torch.nn.functional.cosine_similarity(q, d)
    if query_vec.dim() == 1:
        query_vec = query_vec.unsqueeze(0)
    # Expand a single query to one row per document so any row-wise distance
    # function works, then compute the distances exactly once.
    if query_vec.shape[0] == 1:
        query_vec = query_vec.expand_as(doc_vecs)
    distances = distance_fn(query_vec, doc_vecs)
    # topk raises if asked for more entries than exist.
    k = min(k, distances.shape[0])
    topk = torch.topk(-distances, k)  # negative because smallest distance = highest relevance
    indices = topk.indices.detach().cpu().numpy()
    scores = -topk.values.detach().cpu().numpy()
    return indices, scores


In [None]:
# -----------------------------
# 13. Calculate all document embeddings and lengths
# -----------------------------

all_doc_texts = []    # relevant doc texts, num_docs slots per query
all_query_texts = []

for query, rels, irrels in triples:
    all_query_texts.append(query)
    # Take at most num_docs relevant texts and pad with "" up to num_docs so
    # every query occupies the same number of slots. Work on a copy so that
    # `triples` is not modified and the cell can be re-run safely.
    docs = list(rels[:num_docs])
    docs.extend([""] * (num_docs - len(docs)))
    all_doc_texts.extend(docs)

# Step 1: Gather the document embedding sequences and their lengths.
# Read them from `dataset` in index order rather than from `dataloader`: the
# dataloader shuffles, which would break the row <-> all_doc_texts alignment
# that the retrieval cells rely on.
all_doc_embeds_list = []
all_doc_lens_list = []

for idx in range(len(dataset)):
    _, rel_embed, _ = dataset[idx]               # rel_embed: (seq_len, embed_dim)
    all_doc_embeds_list.append(rel_embed)
    all_doc_lens_list.append(rel_embed.shape[0])

# Step 2: Pad all embeddings to the longest sequence.
# With batch_first=True pad_sequence returns (num_docs, max_seq_len, embed_dim).
padded = pad_sequence(all_doc_embeds_list, batch_first=True)

all_doc_embeds = padded  # (num_docs, max_seq_len, embed_dim)
all_doc_lens = torch.tensor(all_doc_lens_list)

# Row i of all_doc_embeds must describe all_doc_texts[i].
assert all_doc_embeds.shape[0] == len(all_doc_texts), "doc embeddings and texts are misaligned"

print("all_doc_embeds shape:", all_doc_embeds.shape)
print("all_doc_lens shape:", all_doc_lens.shape)
print("all_doc_texts length:", len(all_doc_texts))
print("all_query_texts length:", len(all_query_texts))


In [None]:
# -----------------------------
# 14. Encode documents and queries, find top-k relevant docs
# -----------------------------

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Towers and inputs must share a device for the forward passes below.
doc_tower.to(device)
qry_tower.to(device)

# 1. Encode all documents
doc_vecs = encode_documents(doc_tower, all_doc_embeds, all_doc_lens, device=device)
#print("Encoded document vectors shape:", doc_vecs.shape)

# 2. Encode a sample query. Use the embedding of the query whose text is
#    printed below, so the displayed query and the ranked results match.
sample_idx = 0
sample_query_embed = query_embeds[sample_idx]                # (q_seq_len, embed_dim)
query_embed = sample_query_embed.unsqueeze(0)                # (1, q_seq_len, embed_dim)
query_len = torch.tensor([sample_query_embed.shape[0]])      # kept on the CPU
query_vec = encode_query(qry_tower, query_embed, query_len, device=device)
#print("Encoded query vector shape:", query_vec.shape)

# 3. Retrieve top k relevant documents
k = 5
indices, scores = find_top_k(query_vec, doc_vecs, k=k)

print("Query: ", all_query_texts[sample_idx])
print("Top document matches:")
for rank, i in enumerate(indices):
    print(f"{rank+1}: {all_doc_texts[i]}")
    print(f"   (score: {scores[rank]:.4f})")


In [None]:
# -----------------------------
# 14b. Custom Query Inference
# -----------------------------

# Hardcode your query here
custom_query = "what is machine learning?"

# Embed the custom query with the same encoder used for the training data
# (tokenize_and_embed from section 3), so the query tower sees inputs of the
# kind it was trained on.
custom_embeds, _ = tokenize_and_embed([custom_query], batch_size=1, show_progress=False)
q_vecs = custom_embeds[0]  # (seq_len, embed_dim)

# Add the batch dimension the tower expects
query_embed = q_vecs.unsqueeze(0)  # (1, seq_len, embed_dim)
# The length stays on the CPU: pack_padded_sequence rejects lengths on the GPU.
query_len = torch.tensor([q_vecs.shape[0]])

# Encode the hardcoded query (encode_query moves the embedding to `device`)
query_vec = encode_query(qry_tower, query_embed, query_len, device=device)

# Retrieve top k relevant documents
k = 5
indices, scores = find_top_k(query_vec, doc_vecs, k=k)

print("Custom Query:", custom_query)
print("Top document matches:")
for rank, i in enumerate(indices):
    print(f"{rank+1}: {all_doc_texts[i]}")
    print(f"   (score: {scores[rank]:.4f})")


In [None]:
# -----------------------------
# 15. Redis test
# -----------------------------

# Smoke test against a Redis server on localhost:6379.
# NOTE(review): `redis` is not imported in any visible cell, so this raises
# NameError on a fresh kernel unless `import redis` exists elsewhere - confirm.
r = redis.Redis(host='localhost', port=6379, decode_responses=True)

# Write one key and read it back.
r.set('foo', 'bar')
# True
r.get('foo')
# bar

