# Embedding-Based Search with TorchRec

In [None]:
import heapq
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List, Tuple, Optional, NamedTuple, Sequence, Union

import torch
import torchrec
import faiss
import numpy as np

## Embedding Search Infrastructure

In [None]:
@dataclass
class EmbeddingIndex:
    """Bundle of a FAISS index together with its dimensionality and metadata."""

    # Declared dimensionality of the indexed vectors (not checked by `search`).
    dimension: int
    # FAISS index that answers the nearest-neighbour queries.
    index: faiss.Index
    # Per-vector metadata; presumably keyed by the ids stored in `index` — verify.
    id_to_metadata: Dict[int, Dict]

    def search(
        self,
        query_vector: np.ndarray,
        k: int = 10
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Look up the ``k`` nearest neighbours of one query vector.

        The vector is flattened into a single-row matrix before being handed
        to ``index.search``; FAISS's ``(scores, ids)`` pair is returned as-is.
        """
        single_row = query_vector.reshape(1, -1)
        scores_and_ids = self.index.search(single_row, k)
        return scores_and_ids

class DualEncoder(torch.nn.Module):
    """Two-tower (query / document) text encoder over a shared token embedding.

    Each tower is a bidirectional LSTM whose top-layer final forward and
    backward hidden states are concatenated and passed through an MLP that
    maps them back to ``embedding_dim`` and applies LayerNorm.

    Args:
        vocab_size: Number of rows in the shared token embedding table.
        embedding_dim: Size of the token embeddings and of the output vectors.
        hidden_dim: LSTM hidden size per direction.
        num_layers: Number of stacked LSTM layers in each tower.
        dropout: Dropout between LSTM layers (only applied when
            ``num_layers > 1``) and inside the projection MLPs.
    """
    def __init__(
        self,
        vocab_size: int,
        embedding_dim: int = 128,
        hidden_dim: int = 256,
        num_layers: int = 2,
        dropout: float = 0.1
    ):
        super().__init__()

        # Shared embedding layer used by both towers
        self.embedding = torch.nn.Embedding(vocab_size, embedding_dim)

        # Query encoder
        self.query_encoder = torch.nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0,
            bidirectional=True
        )

        # Document encoder
        self.doc_encoder = torch.nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0,
            bidirectional=True
        )

        # Projection layers: (2 * hidden_dim) -> embedding_dim
        self.query_projection = torch.nn.Sequential(
            torch.nn.Linear(hidden_dim * 2, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Dropout(dropout),
            torch.nn.Linear(hidden_dim, embedding_dim),
            torch.nn.LayerNorm(embedding_dim)
        )

        self.doc_projection = torch.nn.Sequential(
            torch.nn.Linear(hidden_dim * 2, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Dropout(dropout),
            torch.nn.Linear(hidden_dim, embedding_dim),
            torch.nn.LayerNorm(embedding_dim)
        )

    def _encode(
        self,
        tokens: torch.Tensor,
        lengths: torch.Tensor,
        lstm: torch.nn.LSTM,
        projection: torch.nn.Module
    ) -> torch.Tensor:
        """Embed ``tokens``, run one LSTM tower over them and project the result."""
        embedded = self.embedding(tokens)

        # Packing makes the LSTM stop at each row's true length, so padded
        # positions never influence the final hidden state.
        packed = torch.nn.utils.rnn.pack_padded_sequence(
            embedded,
            lengths.cpu(),
            batch_first=True,
            enforce_sorted=False
        )

        _, (hidden, _) = lstm(packed)

        # For a bidirectional LSTM, hidden is (num_layers * 2, batch, hidden_dim);
        # the last two entries are the top layer's forward and backward states.
        hidden = torch.cat([hidden[-2], hidden[-1]], dim=1)

        return projection(hidden)

    def encode_query(
        self,
        tokens: torch.Tensor,
        lengths: torch.Tensor
    ) -> torch.Tensor:
        """Encode a padded batch of queries.

        Args:
            tokens: Batch-first padded token ids, ``(batch, seq_len)``.
            lengths: True (unpadded) length of each row; every length must be > 0.

        Returns:
            ``(batch, embedding_dim)`` LayerNorm-ed query embeddings.
        """
        return self._encode(tokens, lengths, self.query_encoder, self.query_projection)

    def encode_document(
        self,
        tokens: torch.Tensor,
        lengths: torch.Tensor
    ) -> torch.Tensor:
        """Encode a padded batch of documents.

        Args:
            tokens: Batch-first padded token ids, ``(batch, seq_len)``.
            lengths: True (unpadded) length of each row; every length must be > 0.

        Returns:
            ``(batch, embedding_dim)`` LayerNorm-ed document embeddings.
        """
        return self._encode(tokens, lengths, self.doc_encoder, self.doc_projection)

## Index Building and Search

In [None]:
class EmbeddingSearchEngine:
    """Engine for embedding-based search.

    Documents are encoded with ``encoder.encode_document``, L2-normalised and
    stored in a FAISS inner-product index, so returned scores are cosine
    similarities between the query and document embeddings.

    Args:
        encoder: Model exposing ``encode_query`` / ``encode_document``.
        embedding_dim: Dimensionality of the encoder's output vectors.
        index_type: ``"IVFFlat"`` (inverted file, trained on the corpus),
            ``"HNSW"`` (HNSWFlat graph, M=32) or anything else for an exact
            flat inner-product index.
        n_cells: Number of IVF cells (only used by ``"IVFFlat"``).
        n_probes: Cells visited per query for indexes with ``nprobe``.
        device: Device the encoder runs on.
    """
    def __init__(
        self,
        encoder: "DualEncoder",
        embedding_dim: int,
        index_type: str = "IVFFlat",
        n_cells: int = 100,
        n_probes: int = 10,
        device: str = "cuda"
    ):
        self.encoder = encoder.to(device)
        self.embedding_dim = embedding_dim
        self.device = device

        # Keep the index configuration so build_index can start from a
        # brand-new index on every (re)build.
        self.index_type = index_type
        self.n_cells = n_cells
        self.n_probes = n_probes

        self.index = self._create_index()
        self.document_store = {}

    def _create_index(self):
        """Create an empty FAISS index for the configured ``index_type``."""
        if self.index_type == "IVFFlat":
            quantizer = faiss.IndexFlatIP(self.embedding_dim)
            return faiss.IndexIVFFlat(
                quantizer,
                self.embedding_dim,
                self.n_cells,
                faiss.METRIC_INNER_PRODUCT
            )
        if self.index_type == "HNSW":
            return faiss.IndexHNSWFlat(
                self.embedding_dim,
                32,  # M parameter for HNSW
                faiss.METRIC_INNER_PRODUCT
            )
        return faiss.IndexFlatIP(self.embedding_dim)

    @staticmethod
    def _pad_token_batch(
        token_seqs: Sequence[Sequence[int]],
        device
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Right-pad variable-length token id sequences with 0.

        Returns ``(tokens, lengths)``: a ``(len(token_seqs), max_len)`` int64
        tensor and the true length of each row, both on ``device``.
        """
        lengths = [len(seq) for seq in token_seqs]
        padded = torch.zeros(
            (len(token_seqs), max(lengths, default=0)),
            dtype=torch.long
        )
        for row, seq in enumerate(token_seqs):
            if lengths[row]:
                padded[row, :lengths[row]] = torch.as_tensor(
                    np.asarray(seq, dtype=np.int64)
                )
        return (
            padded.to(device),
            torch.tensor(lengths, dtype=torch.long, device=device)
        )

    @staticmethod
    def _tokenize_query(query: Union[str, Sequence[int]]) -> List[int]:
        """Turn a query into token ids.

        Strings use a placeholder character-level tokenizer (``ord(c) % 100``);
        any other iterable is treated as already-tokenised ids.
        """
        if isinstance(query, str):
            return [ord(c) % 100 for c in query]  # Simplified tokenization
        return [int(token) for token in query]

    def build_index(
        self,
        documents: List[Dict],
        batch_size: int = 32
    ):
        """(Re)build the search index from ``documents``.

        Each document needs a ``"tokens"`` sequence; documents may have
        different lengths (they are padded per batch). A fresh index and
        document store are created on every call, so rebuilding after more
        training neither duplicates vectors nor reuses stale IVF centroids.

        Raises:
            ValueError: If ``documents`` is empty.
        """
        if not documents:
            raise ValueError("build_index requires at least one document")

        print("Encoding documents...")
        self.index = self._create_index()
        self.document_store = {}
        all_embeddings = []

        self.encoder.eval()
        with torch.no_grad():
            for start in range(0, len(documents), batch_size):
                batch_docs = documents[start:start + batch_size]

                tokens, lengths = self._pad_token_batch(
                    [doc["tokens"] for doc in batch_docs],
                    self.device
                )

                embeddings = self.encoder.encode_document(tokens, lengths)
                all_embeddings.append(embeddings.cpu().numpy())

                # FAISS ids are positions in `documents`
                for offset, doc in enumerate(batch_docs):
                    self.document_store[start + offset] = doc

        # faiss.normalize_L2 works in place and requires contiguous float32
        all_embeddings = np.ascontiguousarray(
            np.vstack(all_embeddings), dtype=np.float32
        )
        faiss.normalize_L2(all_embeddings)

        # IVF indexes must be trained (k-means on the corpus) before adding
        if not self.index.is_trained:
            print("Training index...")
            self.index.train(all_embeddings)

        print("Adding vectors to index...")
        self.index.add(all_embeddings)

        if hasattr(self.index, 'nprobe'):
            self.index.nprobe = self.n_probes

    def search(
        self,
        query: Union[str, Sequence[int]],
        k: int = 10,
        return_scores: bool = True
    ) -> List[Dict]:
        """Return up to ``k`` stored documents most similar to ``query``.

        Args:
            query: Raw text (placeholder tokenizer) or a sequence of token ids.
            k: Maximum number of results.
            return_scores: Whether to add a ``"score"`` key (cosine similarity)
                to each returned document copy.

        Returns:
            Shallow copies of the matching documents, best first. An empty
            query yields an empty list.
        """
        token_ids = self._tokenize_query(query)
        if not token_ids:
            # The LSTM cannot pack a zero-length sequence
            return []

        query_tokens, query_length = self._pad_token_batch(
            [token_ids], self.device
        )

        self.encoder.eval()
        with torch.no_grad():
            query_embedding = self.encoder.encode_query(
                query_tokens,
                query_length
            )

        query_embedding = np.ascontiguousarray(
            query_embedding.cpu().numpy(), dtype=np.float32
        )
        faiss.normalize_L2(query_embedding)

        scores, indices = self.index.search(query_embedding, k)

        results = []
        for score, idx in zip(scores[0], indices[0]):
            if idx == -1:  # FAISS returns -1 for not enough results
                continue

            result = self.document_store[int(idx)].copy()
            if return_scores:
                result["score"] = float(score)
            results.append(result)

        return results

## Training Infrastructure

In [None]:
class DualEncoderTrainer:
    """Contrastive (InfoNCE) trainer for a dual encoder.

    Args:
        encoder: Model exposing ``encode_query`` / ``encode_document``.
        learning_rate: AdamW learning rate.
        temperature: Divisor applied to cosine similarities before softmax.
        device: Device the encoder and every batch are moved to.
        scheduler_t_max: ``T_max`` of the cosine-annealing LR schedule,
            counted in calls to :meth:`train_step`.
    """
    def __init__(
        self,
        encoder: "DualEncoder",
        learning_rate: float = 0.001,
        temperature: float = 0.1,
        device: str = "cuda",
        scheduler_t_max: int = 1000
    ):
        self.encoder = encoder.to(device)
        self.temperature = temperature
        self.device = device

        # Optimizer
        self.optimizer = torch.optim.AdamW(
            self.encoder.parameters(),
            lr=learning_rate,
            weight_decay=0.01
        )

        # Learning rate scheduler (stepped once per train_step)
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer,
            T_max=scheduler_t_max
        )

    def train_step(
        self,
        query_tokens: torch.Tensor,
        query_lengths: torch.Tensor,
        pos_doc_tokens: torch.Tensor,
        pos_doc_lengths: torch.Tensor,
        neg_doc_tokens: torch.Tensor,
        neg_doc_lengths: torch.Tensor
    ) -> Dict[str, float]:
        """Run one optimisation step and return scalar training metrics.

        There is one positive document per query. ``neg_doc_tokens`` may hold
        any positive multiple of the batch size: rows ``[i*n, (i+1)*n)`` are
        the ``n`` negatives of query ``i`` (the order in which
        ``EmbeddingTrainingDataGenerator.generate_batch`` emits them).

        Similarities are cosine similarities divided by ``temperature``, which
        matches the L2-normalised inner-product index used at search time.

        Raises:
            ValueError: If the positive/negative row counts do not fit the
                query batch size.
        """
        batch_size = query_tokens.size(0)
        num_neg_rows = neg_doc_tokens.size(0)
        if pos_doc_tokens.size(0) != batch_size:
            raise ValueError("expected exactly one positive document per query")
        if batch_size == 0 or num_neg_rows == 0 or num_neg_rows % batch_size:
            raise ValueError(
                "number of negative documents must be a positive multiple "
                "of the query batch size"
            )
        num_negatives = num_neg_rows // batch_size

        self.optimizer.zero_grad()

        # Move to device
        query_tokens = query_tokens.to(self.device)
        query_lengths = query_lengths.to(self.device)
        pos_doc_tokens = pos_doc_tokens.to(self.device)
        pos_doc_lengths = pos_doc_lengths.to(self.device)
        neg_doc_tokens = neg_doc_tokens.to(self.device)
        neg_doc_lengths = neg_doc_lengths.to(self.device)

        # Encode and L2-normalise so dot products are cosine similarities
        query_embeddings = torch.nn.functional.normalize(
            self.encoder.encode_query(query_tokens, query_lengths), dim=-1
        )
        pos_doc_embeddings = torch.nn.functional.normalize(
            self.encoder.encode_document(pos_doc_tokens, pos_doc_lengths), dim=-1
        )
        neg_doc_embeddings = torch.nn.functional.normalize(
            self.encoder.encode_document(neg_doc_tokens, neg_doc_lengths), dim=-1
        ).reshape(batch_size, num_negatives, -1)

        # (B,) positive scores and (B, n) negative scores per query
        pos_similarities = (
            query_embeddings * pos_doc_embeddings
        ).sum(dim=1) / self.temperature
        neg_similarities = torch.einsum(
            "bd,bnd->bn", query_embeddings, neg_doc_embeddings
        ) / self.temperature

        # InfoNCE: the positive sits in column 0 of the (B, 1 + n) logits
        logits = torch.cat([pos_similarities.unsqueeze(1), neg_similarities], dim=1)
        labels = torch.zeros(batch_size, dtype=torch.long, device=self.device)

        loss = torch.nn.functional.cross_entropy(logits, labels)

        # Fraction of queries whose positive outranks every negative
        accuracy = (logits.argmax(dim=1) == labels).float().mean()

        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.encoder.parameters(), 1.0)
        self.optimizer.step()
        self.scheduler.step()

        return {
            "loss": loss.item(),
            "accuracy": accuracy.item(),
            "pos_similarity": pos_similarities.mean().item(),
            "neg_similarity": neg_similarities.mean().item()
        }

## Evaluation Metrics

In [None]:
class EmbeddingSearchEvaluator:
    """Offline retrieval metrics (recall@k, precision@k, MRR) for a search engine.

    Args:
        k_values: Positive rank cutoffs at which recall and precision are
            reported; the largest one is the number of results requested.

    Raises:
        ValueError: If ``k_values`` is empty or contains a non-positive cutoff.
    """
    def __init__(self, k_values: Sequence[int] = (1, 5, 10, 100)):
        if not k_values or any(k <= 0 for k in k_values):
            raise ValueError("k_values must be a non-empty list of positive ints")
        # Copy so instances never share (or mutate) a default container
        self.k_values = list(k_values)

    @torch.no_grad()
    def evaluate(
        self,
        search_engine: "EmbeddingSearchEngine",
        eval_queries: List[Dict],
        relevant_docs: Dict[int, List[int]],
        batch_size: int = 32
    ) -> Dict[str, float]:
        """Run every query through ``search_engine`` and average the metrics.

        Args:
            search_engine: Object whose ``search(query, k=..., return_scores=...)``
                returns documents carrying an ``"id"`` key.
            eval_queries: Dicts with an ``"id"`` and either ``"text"`` or
                ``"tokens"``. Token-only queries are passed to ``search`` as
                id sequences, which requires a ``search`` that accepts them.
            relevant_docs: Relevant document ids per query id.
            batch_size: Unused; kept for backward compatibility (queries are
                always issued one at a time).

        Returns:
            Mean of each metric over all queries.
        """
        metrics = defaultdict(list)
        max_k = max(self.k_values)

        for query in eval_queries:
            # The synthetic generator's queries carry only "tokens"
            query_input = query["text"] if "text" in query else query["tokens"]
            results = search_engine.search(
                query_input,
                k=max_k,
                return_scores=True
            )

            retrieved_ids = [doc["id"] for doc in results]
            relevant_ids = set(relevant_docs[query["id"]])

            for name, value in self._compute_metrics(retrieved_ids, relevant_ids).items():
                metrics[name].append(value)

        return {name: float(np.mean(values)) for name, values in metrics.items()}

    def _compute_metrics(
        self,
        retrieved_ids: List[int],
        relevant_ids: set
    ) -> Dict[str, float]:
        """Compute per-query recall@k, precision@k and reciprocal rank.

        precision@k always divides by ``k`` (missing results count as
        misses), so every query contributes to every metric. recall@k is 0.0
        when a query has no relevant documents.
        """
        metrics = {}
        num_relevant = len(relevant_ids)

        for k in self.k_values:
            hits = len(set(retrieved_ids[:k]) & relevant_ids)
            metrics[f"recall@{k}"] = hits / num_relevant if num_relevant else 0.0

        for k in self.k_values:
            hits = len(set(retrieved_ids[:k]) & relevant_ids)
            metrics[f"precision@{k}"] = hits / k

        # Reciprocal rank of the first relevant hit (0.0 if none)
        metrics["mrr"] = next(
            (
                1.0 / rank
                for rank, doc_id in enumerate(retrieved_ids, start=1)
                if doc_id in relevant_ids
            ),
            0.0
        )

        return metrics

## Data Generation

In [None]:
class EmbeddingTrainingDataGenerator:
    """Generate a synthetic corpus, synthetic queries and training batches.

    All randomness comes from NumPy's global RNG (``np.random``), so results
    are reproducible via ``np.random.seed``.

    Args:
        vocab_size: Token ids are drawn uniformly from ``[0, vocab_size)``.
        max_query_length: Inclusive upper bound on query length (lower bound 3,
            or ``max_query_length`` if that is smaller).
        max_doc_length: Inclusive upper bound on document length (lower bound
            20, or ``max_doc_length`` if that is smaller).
        num_docs: Number of documents to generate.
    """
    def __init__(
        self,
        vocab_size: int = 10000,
        max_query_length: int = 10,
        max_doc_length: int = 100,
        num_docs: int = 100000
    ):
        self.vocab_size = vocab_size
        self.max_query_length = max_query_length
        self.max_doc_length = max_doc_length

        # Generate synthetic document corpus
        self.documents = self._generate_documents(num_docs)

        # Generate synthetic query-document relevance
        self.relevance = self._generate_relevance()

    def _generate_documents(self, num_docs: int) -> List[Dict]:
        """Generate ``num_docs`` random documents whose ids are their list positions."""
        min_length = min(20, self.max_doc_length)
        documents = []
        for i in range(num_docs):
            # randint's upper bound is exclusive, hence the +1
            length = np.random.randint(min_length, self.max_doc_length + 1)
            tokens = np.random.randint(0, self.vocab_size, size=length)

            documents.append({
                "id": i,
                "tokens": tokens,
                "length": length,
                "embedding_quality": np.random.random()  # Synthetic quality score
            })
        return documents

    def _generate_relevance(self) -> Dict[int, Dict]:
        """Generate random queries, each with 1-4 random relevant documents.

        Returns:
            ``{query_id: {"query": {"id", "tokens", "length"},
            "relevant_docs": [doc ids]}}``; empty for an empty corpus.
        """
        relevance = {}
        num_docs = len(self.documents)
        if num_docs == 0:
            return relevance

        # One query per ten documents, but at least one for a tiny corpus
        num_queries = max(1, num_docs // 10)
        min_length = min(3, self.max_query_length)

        for i in range(num_queries):
            # Generate synthetic query
            length = np.random.randint(min_length, self.max_query_length + 1)
            tokens = np.random.randint(0, self.vocab_size, size=length)

            # Cap at the corpus size so sampling without replacement succeeds
            num_relevant = min(np.random.randint(1, 5), num_docs)
            relevant_docs = np.random.choice(
                num_docs,
                size=num_relevant,
                replace=False
            )

            relevance[i] = {
                "query": {
                    "id": i,
                    "tokens": tokens,
                    "length": length
                },
                "relevant_docs": relevant_docs.tolist()
            }

        return relevance

    def generate_batch(
        self,
        batch_size: int,
        neg_ratio: int = 4
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Sample a training batch with random negatives.

        Returns ``(query_tokens, query_lengths, pos_doc_tokens, pos_doc_lengths,
        neg_doc_tokens, neg_doc_lengths)``. There is one positive per query and
        ``neg_ratio`` negatives per query, grouped consecutively per query.

        Raises:
            ValueError: If there are no queries, or a sampled query has every
                document marked relevant (no negative can be drawn).
        """
        if not self.relevance:
            raise ValueError("no queries available to sample a batch from")

        # Randomly select queries (with replacement)
        query_ids = np.random.choice(list(self.relevance.keys()), batch_size)

        query_tokens = []
        query_lengths = []
        pos_doc_tokens = []
        pos_doc_lengths = []
        neg_doc_tokens = []
        neg_doc_lengths = []

        for qid in query_ids:
            entry = self.relevance[qid]
            relevant = set(entry["relevant_docs"])
            if neg_ratio > 0 and len(relevant) >= len(self.documents):
                raise ValueError(
                    "cannot sample negatives: every document is relevant"
                )

            query = entry["query"]
            query_tokens.append(query["tokens"])
            query_lengths.append(query["length"])

            # Positive document
            pos_doc_id = np.random.choice(entry["relevant_docs"])
            pos_doc = self.documents[pos_doc_id]
            pos_doc_tokens.append(pos_doc["tokens"])
            pos_doc_lengths.append(pos_doc["length"])

            # Negative documents: rejection-sample ids outside the relevant set
            for _ in range(neg_ratio):
                neg_doc_id = np.random.randint(0, len(self.documents))
                while neg_doc_id in relevant:
                    neg_doc_id = np.random.randint(0, len(self.documents))

                neg_doc = self.documents[neg_doc_id]
                neg_doc_tokens.append(neg_doc["tokens"])
                neg_doc_lengths.append(neg_doc["length"])

        # Convert to tensors with padding
        return (
            self._pad_sequences(query_tokens),
            torch.tensor(query_lengths),
            self._pad_sequences(pos_doc_tokens),
            torch.tensor(pos_doc_lengths),
            self._pad_sequences(neg_doc_tokens),
            torch.tensor(neg_doc_lengths)
        )

    def _pad_sequences(self, sequences: List[np.ndarray]) -> torch.Tensor:
        """Right-pad sequences with 0 into a ``(len(sequences), max_len)`` int64 tensor."""
        max_len = max((len(seq) for seq in sequences), default=0)
        padded = np.zeros((len(sequences), max_len), dtype=np.int64)

        for i, seq in enumerate(sequences):
            padded[i, :len(seq)] = seq

        return torch.tensor(padded)

## Complete Training Loop

In [None]:
def train_embedding_search(
    num_epochs: int = 10,
    batch_size: int = 32,
    batches_per_epoch: int = 100,
    num_docs: int = 100000,
    eval_every: int = 2
):
    """Train a DualEncoder on synthetic data and periodically evaluate retrieval.

    Args:
        num_epochs: Number of training epochs.
        batch_size: Queries per training batch.
        batches_per_epoch: Training steps per epoch.
        num_docs: Size of the synthetic document corpus.
        eval_every: Rebuild the index and evaluate every ``eval_every``
            epochs, starting with the first epoch.

    Raises:
        ValueError: If ``eval_every`` is smaller than 1.
    """
    if eval_every < 1:
        raise ValueError("eval_every must be >= 1")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    vocab_size = 10000
    embedding_dim = 128

    encoder = DualEncoder(
        vocab_size=vocab_size,
        embedding_dim=embedding_dim
    )

    # The engine and the trainer share the same encoder instance
    search_engine = EmbeddingSearchEngine(
        encoder=encoder,
        embedding_dim=embedding_dim,
        index_type="IVFFlat",
        device=device
    )

    trainer = DualEncoderTrainer(encoder, device=device)

    data_gen = EmbeddingTrainingDataGenerator(
        vocab_size=vocab_size,
        num_docs=num_docs
    )

    evaluator = EmbeddingSearchEvaluator()

    # The evaluation set never changes, so build it once
    eval_queries = [entry["query"] for entry in data_gen.relevance.values()]
    relevant_docs = {
        entry["query"]["id"]: entry["relevant_docs"]
        for entry in data_gen.relevance.values()
    }

    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch + 1}")

        # build_index/search switch the encoder to eval mode, so reset it
        encoder.train()
        epoch_metrics = defaultdict(list)

        for batch in range(batches_per_epoch):
            metrics = trainer.train_step(*data_gen.generate_batch(batch_size))

            for name, value in metrics.items():
                epoch_metrics[name].append(value)

            if batch % 10 == 0:
                print(f"Batch {batch}, Loss: {metrics['loss']:.4f}")

        print("\nTraining Metrics:")
        for name, values in epoch_metrics.items():
            print(f"{name}: {np.mean(values):.4f}")

        if epoch % eval_every == 0:
            print("\nRebuilding search index...")
            search_engine.build_index(data_gen.documents)

            print("\nEvaluating...")
            eval_metrics = evaluator.evaluate(
                search_engine=search_engine,
                eval_queries=eval_queries,
                relevant_docs=relevant_docs
            )

            print("\nEvaluation Metrics:")
            for metric, value in eval_metrics.items():
                print(f"{metric}: {value:.4f}")


## Example Usage

In [None]:
def example_search(k: int = 5):
    """Demo: index 1,000 synthetic documents with an untrained encoder and
    print the top-``k`` results for a few sample text queries.

    Args:
        k: Number of results to request and print per query.
    """
    # EmbeddingSearchEngine defaults to device="cuda"; pick the device
    # explicitly so the demo also runs on CPU-only machines.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    encoder = DualEncoder(vocab_size=10000, embedding_dim=128)

    search_engine = EmbeddingSearchEngine(
        encoder=encoder,
        embedding_dim=128,
        index_type="IVFFlat",
        device=device
    )

    data_gen = EmbeddingTrainingDataGenerator(num_docs=1000)

    search_engine.build_index(data_gen.documents)

    queries = [
        "sample query 1",
        "sample query 2",
        "sample query 3"
    ]

    for query in queries:
        print(f"\nQuery: {query}")
        results = search_engine.search(query, k=k)

        print(f"\nTop {k} Results:")
        for rank, result in enumerate(results, 1):
            print(f"{rank}. Document ID: {result['id']}, Score: {result['score']:.4f}")

if __name__ == "__main__":
    # Run the synthetic training demo first, then the standalone search example.
    for demo in (train_embedding_search, example_search):
        demo()