In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import random

In [2]:
# --- Sample Synthetic Data (in real usage, load MS MARCO preprocessed triples) ---
data = [
    ("what is ai", "artificial intelligence explanation", "banana nutrition"),
    ("define deep learning", "deep learning is a subset of machine learning", "weather in london"),
    ("machine learning applications", "uses of machine learning in finance", "cooking pasta")
]


In [3]:
# --- Tokenizer and Vocab ---
from collections import defaultdict
from itertools import chain

class SimpleTokenizer:
    def __init__(self):
        self.word2idx = defaultdict(self._new_index)
        self.idx = 0
        self.locked = False

    def _new_index(self):
        if self.locked:
            return 0  # default to <PAD> for unknown tokens
        idx = self.idx
        self.idx += 1
        return idx

    def encode(self, sentence):
        return [self.word2idx[word.lower()] for word in sentence.split()]

    def vocab_size(self):
        return self.idx

    def lock_vocab(self):
        self.locked = True


In [4]:
# Build vocab
tokenizer = SimpleTokenizer()
_ = [tokenizer.encode(text) for triplet in data for text in triplet]  # populate vocab
tokenizer.lock_vocab()

In [5]:
# --- Dataset Class ---
class TripletDataset(Dataset):
    def __init__(self, data, tokenizer):
        self.data = data
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        query, pos_doc, neg_doc = self.data[idx]
        return (torch.tensor(self.tokenizer.encode(query)),
                torch.tensor(self.tokenizer.encode(pos_doc)),
                torch.tensor(self.tokenizer.encode(neg_doc)))

In [6]:
# --- Collate Function ---
def collate_fn(batch):
    queries, pos_docs, neg_docs = zip(*batch)
    return (
        pad_sequence(queries, batch_first=True),
        pad_sequence(pos_docs, batch_first=True),
        pad_sequence(neg_docs, batch_first=True)
    )

In [7]:
# --- Dual RNN Encoder Model ---
class RNNEncoder(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.rnn = nn.GRU(embed_dim, hidden_dim, batch_first=True)

    def forward(self, x):
        x = self.embedding(x)
        _, h_n = self.rnn(x)
        return h_n.squeeze(0)  # shape: (batch, hidden_dim)

In [8]:
# --- Triplet Loss Function ---
def triplet_loss_function(triplet, distance_function, margin):
    query, pos_doc, neg_doc = triplet
    d_pos = distance_function(query, pos_doc)
    d_neg = distance_function(query, neg_doc)
    return torch.clamp(d_pos - d_neg + margin, min=0.0).mean()


In [9]:
# --- Training Setup ---
VOCAB_SIZE = tokenizer.vocab_size()
EMBED_DIM = 64
HIDDEN_DIM = 128
MARGIN = 1.0

query_encoder = RNNEncoder(VOCAB_SIZE, EMBED_DIM, HIDDEN_DIM)
doc_encoder = RNNEncoder(VOCAB_SIZE, EMBED_DIM, HIDDEN_DIM)

optimizer = torch.optim.Adam(list(query_encoder.parameters()) + list(doc_encoder.parameters()), lr=0.001)
dataset = TripletDataset(data, tokenizer)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)

In [10]:
# --- Training Loop ---
for epoch in range(5):
    total_loss = 0
    for query_batch, pos_batch, neg_batch in dataloader:
        q_vec = query_encoder(query_batch)
        pos_vec = doc_encoder(pos_batch)
        neg_vec = doc_encoder(neg_batch)

        loss = triplet_loss_function((q_vec, pos_vec, neg_vec), F.pairwise_distance, MARGIN)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    print(f"Epoch {epoch+1}, Loss: {total_loss:.4f}")

Epoch 1, Loss: 1.9686
Epoch 2, Loss: 0.2344
Epoch 3, Loss: 0.4333
Epoch 4, Loss: 0.0000
Epoch 5, Loss: 0.0000


In [11]:
# --- Inference Function ---
def search(query_text, documents, tokenizer, query_encoder, doc_encoder):
    with torch.no_grad():
        query_tensor = pad_sequence([torch.tensor(tokenizer.encode(query_text))], batch_first=True)
        query_vec = query_encoder(query_tensor)

        doc_tensors = pad_sequence([torch.tensor(tokenizer.encode(doc)) for doc in documents], batch_first=True)
        doc_vecs = doc_encoder(doc_tensors)

        scores = F.cosine_similarity(query_vec, doc_vecs)
        top_indices = torch.argsort(scores, descending=True)
        return [(documents[i], scores[i].item()) for i in top_indices]

In [12]:
# --- Example Usage ---
documents = ["deep learning applications", "banana smoothie", "introduction to ai", "machine learning in banking"]
results = search("what is ai", documents, tokenizer, query_encoder, doc_encoder)

print("\nSearch results:")
for doc, score in results:
    print(f"{doc} (score: {score:.4f})")



Search results:
deep learning applications (score: 0.0267)
introduction to ai (score: -0.0974)
machine learning in banking (score: -0.1749)
banana smoothie (score: -0.3350)
