# Search Ranking System with TorchRec

In [None]:
import torch
import torchrec
from typing import Dict, List, Tuple, Optional, NamedTuple
import numpy as np
from collections import defaultdict
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
from dataclasses import dataclass

## Search-Specific Features

In [None]:
@dataclass
class SearchFeatures:
    """Container for every per-example input the ranking code reads.

    Each attribute is expected to be a ``torch.Tensor``; ``to`` moves all
    tensor attributes to a target device in one call.
    """
    # --- Query side ---
    query_tokens: torch.Tensor          # Token ids of the query
    query_length: torch.Tensor          # Number of real (non-pad) query tokens
    query_embeddings: torch.Tensor      # Pre-computed per-token query embeddings

    # --- Document side ---
    doc_tokens: torch.Tensor            # Token ids of the document
    doc_length: torch.Tensor            # Number of real (non-pad) document tokens
    doc_embeddings: torch.Tensor        # Pre-computed per-token document embeddings
    doc_static_features: torch.Tensor   # Query-independent document signals

    # --- Query/document interaction ---
    exact_match_signals: torch.Tensor   # Lexical match signals (BM25, TF-IDF, ...)
    semantic_match_signals: torch.Tensor # Semantic match signals (cosine similarity, ...)

    # --- User context ---
    user_id: torch.Tensor
    search_history: torch.Tensor
    session_features: torch.Tensor

    def to(self, device: torch.device) -> 'SearchFeatures':
        """Return a new ``SearchFeatures`` whose tensors live on ``device``.

        Attributes that are not tensors are carried over unchanged.
        """
        relocated = {}
        for attr_name, attr_value in vars(self).items():
            if isinstance(attr_value, torch.Tensor):
                relocated[attr_name] = attr_value.to(device)
            else:
                relocated[attr_name] = attr_value
        return SearchFeatures(**relocated)

## Search Encoder Modules

In [None]:
class TransformerEncoder(torch.nn.Module):
    """Encode a padded token-embedding sequence into one vector per example.

    The sequence is run through a stack of transformer encoder layers, the
    non-padding positions are mean-pooled, and the pooled vector goes through
    a ``Linear`` + ``Tanh`` head.

    Args:
        embedding_dim: Size of each token embedding (and of the output vector).
        num_heads: Attention heads per layer.
        num_layers: Number of stacked encoder layers.
        dropout: Dropout probability inside the encoder layers.
    """
    def __init__(
        self,
        embedding_dim: int,
        num_heads: int = 4,
        num_layers: int = 2,
        dropout: float = 0.1
    ):
        super().__init__()

        # NOTE(review): this layer stays registered as a submodule in addition
        # to whatever the encoder stack below holds; presumably the stack works
        # on its own copies, which would leave these parameters unused --
        # confirm before relying on every parameter receiving a gradient.
        self.encoder_layer = torch.nn.TransformerEncoderLayer(
            d_model=embedding_dim,
            nhead=num_heads,
            dim_feedforward=embedding_dim * 4,
            dropout=dropout,
            batch_first=True
        )

        self.transformer = torch.nn.TransformerEncoder(
            self.encoder_layer,
            num_layers=num_layers
        )

        self.pooling = torch.nn.Sequential(
            torch.nn.Linear(embedding_dim, embedding_dim),
            torch.nn.Tanh()
        )

    def forward(
        self,
        embeddings: torch.Tensor,
        lengths: torch.Tensor
    ) -> torch.Tensor:
        """Encode and pool a batch of sequences.

        Args:
            embeddings: ``(batch, seq_len, embedding_dim)`` padded embeddings.
            lengths: ``(batch,)`` count of real tokens per row; positions at or
                beyond the length are treated as padding.

        Returns:
            ``(batch, embedding_dim)`` pooled encodings. A row whose length is
            zero pools to a zero vector before the pooling head instead of
            producing NaN.
        """
        max_len = embeddings.size(1)
        positions = torch.arange(max_len, device=embeddings.device)
        pad_mask = positions[None, :] >= lengths[:, None]

        # A row with length 0 would mask every key, leaving attention with
        # nothing to normalise over (NaN). Keep position 0 visible to
        # attention; rows with length >= 1 already have it unmasked, so this
        # only changes the empty rows, which are still excluded from pooling.
        attention_mask = pad_mask.clone()
        if max_len > 0:
            attention_mask[:, 0] = False

        # Encode sequence
        encoded = self.transformer(embeddings, src_key_padding_mask=attention_mask)

        # Mean-pool over real tokens only; clamp the divisor so empty rows
        # yield 0 rather than 0 / 0.
        masked_encoded = encoded.masked_fill(pad_mask.unsqueeze(-1), 0)
        token_counts = lengths.clamp(min=1).unsqueeze(-1)
        pooled = masked_encoded.sum(dim=1) / token_counts

        return self.pooling(pooled)

class CrossAttention(torch.nn.Module):
    """Two-way attention between a query sequence and a document sequence.

    A single multi-head attention module and a single layer norm are shared by
    both directions (query attending to document, document attending to query).
    """
    def __init__(
        self,
        embedding_dim: int,
        num_heads: int = 4,
        dropout: float = 0.1
    ):
        super().__init__()

        self.attention = torch.nn.MultiheadAttention(
            embed_dim=embedding_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True
        )
        self.norm = torch.nn.LayerNorm(embedding_dim)
        self.dropout = torch.nn.Dropout(dropout)

    def _attend(
        self,
        target: torch.Tensor,
        source: torch.Tensor,
        source_mask: Optional[torch.Tensor]
    ) -> torch.Tensor:
        """Let ``target`` attend over ``source``; apply residual + layer norm."""
        context, _ = self.attention(
            target, source, source,
            key_padding_mask=source_mask
        )
        return self.norm(target + self.dropout(context))

    def forward(
        self,
        query: torch.Tensor,
        document: torch.Tensor,
        query_mask: Optional[torch.Tensor] = None,
        doc_mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run both attention directions.

        Args:
            query: Query sequence embeddings (batch-first).
            document: Document sequence embeddings (batch-first).
            query_mask: Optional key-padding mask for the query sequence,
                used when the document attends to the query.
            doc_mask: Optional key-padding mask for the document sequence,
                used when the query attends to the document.

        Returns:
            ``(query_attended, document_attended)``, each with the shape of
            the corresponding input sequence.
        """
        # Query -> document first, then document -> query.
        query_side = self._attend(query, document, doc_mask)
        document_side = self._attend(document, query, query_mask)
        return query_side, document_side

## Search Ranking Model

In [None]:
class SearchRankingModel(torch.nn.Module):
    """Neural ranking model that scores (query, document, user) examples.

    Seven ``embedding_dim``-wide vectors are concatenated and fed to an MLP
    that emits one logit per example: the pooled query encoding, the pooled
    document encoding, the two pooled cross-attention outputs, the encoded
    static document features, the encoded session features and the user
    embedding.

    Args:
        embedding_dim: Width of token embeddings, static/session feature
            vectors and the user embedding.
        hidden_dim: Hidden width of the feature MLPs and the scoring MLP.
        num_transformer_layers: Encoder layers in each sequence encoder.
        dropout: Dropout probability used throughout.
        num_users: Number of rows in the user embedding table.
        embedding_device: Device on which the user embedding table is
            created. ``None`` leaves the choice to torchrec's default. The
            table must be backed by real storage for the model to be moved
            with ``.to(device)`` and trained; a ``meta`` table has no data.
    """
    def __init__(
        self,
        embedding_dim: int = 128,
        hidden_dim: int = 256,
        num_transformer_layers: int = 2,
        dropout: float = 0.1,
        num_users: int = 1_000_000,
        embedding_device: Optional[torch.device] = None
    ):
        super().__init__()

        # Query encoder
        self.query_encoder = TransformerEncoder(
            embedding_dim=embedding_dim,
            num_layers=num_transformer_layers,
            dropout=dropout
        )

        # Document encoder
        self.doc_encoder = TransformerEncoder(
            embedding_dim=embedding_dim,
            num_layers=num_transformer_layers,
            dropout=dropout
        )

        # Cross attention
        self.cross_attention = CrossAttention(
            embedding_dim=embedding_dim,
            dropout=dropout
        )

        # User embedding table, keyed by the "user_id" feature.
        self.user_embedding = torchrec.EmbeddingBagCollection(
            tables=[
                torchrec.EmbeddingBagConfig(
                    name="user_embeddings",
                    embedding_dim=embedding_dim,
                    num_embeddings=num_users,
                    feature_names=["user_id"]
                )
            ],
            device=embedding_device
        )

        # Static feature processing
        self.static_encoder = torch.nn.Sequential(
            torch.nn.Linear(embedding_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Dropout(dropout),
            torch.nn.Linear(hidden_dim, embedding_dim)
        )

        # Session feature processing
        self.session_encoder = torch.nn.Sequential(
            torch.nn.Linear(embedding_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Dropout(dropout),
            torch.nn.Linear(hidden_dim, embedding_dim)
        )

        # Final scoring layers: seven concatenated embedding_dim-wide inputs.
        self.scoring_layers = torch.nn.Sequential(
            torch.nn.Linear(embedding_dim * 7, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Dropout(dropout),
            torch.nn.Linear(hidden_dim, hidden_dim // 2),
            torch.nn.ReLU(),
            torch.nn.Dropout(dropout),
            torch.nn.Linear(hidden_dim // 2, 1)
        )

    @staticmethod
    def _padding_mask(lengths: torch.Tensor, max_len: int) -> torch.Tensor:
        """Return a ``(batch, max_len)`` bool mask that is True on pad positions."""
        positions = torch.arange(max_len, device=lengths.device)
        return positions[None, :] >= lengths[:, None]

    @staticmethod
    def _masked_mean(
        sequence: torch.Tensor,
        pad_mask: torch.Tensor,
        lengths: torch.Tensor
    ) -> torch.Tensor:
        """Average ``sequence`` over dim 1, ignoring positions flagged in ``pad_mask``."""
        summed = sequence.masked_fill(pad_mask.unsqueeze(-1), 0).sum(dim=1)
        # Clamp so a zero length divides by 1 instead of 0.
        return summed / lengths.clamp(min=1).unsqueeze(-1)

    def forward(self, features: "SearchFeatures") -> Dict[str, torch.Tensor]:
        """Score a batch of examples.

        Args:
            features: Batch of inputs. Reads ``query_embeddings``,
                ``query_length``, ``doc_embeddings``, ``doc_length``,
                ``user_id``, ``doc_static_features`` and ``session_features``.

        Returns:
            Dict with ``"score"`` (one logit per example, shape ``(batch, 1)``),
            ``"query_encoding"``, ``"doc_encoding"`` and
            ``"attention_weights"``. Despite its name, the last entry holds
            the attended query/document *sequences* returned by the
            cross-attention module, not attention probabilities.
        """
        # Encode query and document
        query_encoding = self.query_encoder(
            features.query_embeddings,
            features.query_length
        )

        doc_encoding = self.doc_encoder(
            features.doc_embeddings,
            features.doc_length
        )

        # Cross attention. Padding positions are masked out as keys so that
        # pad embeddings cannot leak into the attended representations.
        # Assumes every length is >= 1; a fully padded row would leave
        # attention with no keys.
        query_mask = self._padding_mask(
            features.query_length, features.query_embeddings.size(1)
        )
        doc_mask = self._padding_mask(
            features.doc_length, features.doc_embeddings.size(1)
        )
        query_attended, doc_attended = self.cross_attention(
            features.query_embeddings,
            features.doc_embeddings,
            query_mask=query_mask,
            doc_mask=doc_mask
        )

        # Get user embedding: one id per example, so every bag has length 1.
        user_embedding = self.user_embedding(
            KeyedJaggedTensor.from_lengths_sync(
                keys=["user_id"],
                values=features.user_id,
                lengths=torch.ones_like(features.user_id)
            )
        ).values()

        # Process static and session features
        static_features = self.static_encoder(features.doc_static_features)
        session_features = self.session_encoder(features.session_features)

        # Combine all features. The attended sequences are averaged over real
        # tokens only, so the result does not depend on how much padding the
        # batch happens to contain.
        combined_features = torch.cat([
            query_encoding,
            doc_encoding,
            self._masked_mean(query_attended, query_mask, features.query_length),
            self._masked_mean(doc_attended, doc_mask, features.doc_length),
            static_features,
            session_features,
            user_embedding
        ], dim=1)

        # Generate ranking score
        ranking_score = self.scoring_layers(combined_features)

        return {
            "score": ranking_score,
            "query_encoding": query_encoding,
            "doc_encoding": doc_encoding,
            "attention_weights": {
                "query_to_doc": query_attended,
                "doc_to_query": doc_attended
            }
        }

## Search-Specific Loss

In [None]:
class SearchRankingLoss:
    """Weighted sum of relevance, semantic-matching and diversity losses.

    Args:
        margin: Stored on the instance; not read by any method in this class.
        lambda_semantic: Weight of the semantic-matching loss.
        lambda_relevance: Weight of the relevance loss.
        lambda_diversity: Weight of the diversity loss.
    """
    def __init__(
        self,
        margin: float = 0.1,
        lambda_semantic: float = 0.1,
        lambda_relevance: float = 1.0,
        lambda_diversity: float = 0.1
    ):
        self.margin = margin
        self.lambda_semantic = lambda_semantic
        self.lambda_relevance = lambda_relevance
        self.lambda_diversity = lambda_diversity

    def compute_loss(
        self,
        outputs: Dict[str, torch.Tensor],
        labels: torch.Tensor,
        features: "SearchFeatures"
    ) -> Dict[str, torch.Tensor]:
        """Compute every loss component and their weighted total.

        Args:
            outputs: Model outputs; reads ``"score"``, ``"query_encoding"``
                and ``"doc_encoding"``.
            labels: Binary relevance labels, one per example.
            features: Batch inputs; only ``semantic_match_signals`` is read.

        Returns:
            Dict with ``"total_loss"``, ``"relevance_loss"``,
            ``"semantic_loss"`` and ``"diversity_loss"`` tensors.
        """
        scores = outputs["score"]
        query_encoding = outputs["query_encoding"]
        doc_encoding = outputs["doc_encoding"]

        relevance_loss = self._compute_relevance_loss(scores, labels)

        semantic_loss = self._compute_semantic_loss(
            query_encoding,
            doc_encoding,
            features.semantic_match_signals
        )

        diversity_loss = self._compute_diversity_loss(doc_encoding)

        total_loss = (
            self.lambda_relevance * relevance_loss +
            self.lambda_semantic * semantic_loss +
            self.lambda_diversity * diversity_loss
        )

        return {
            "total_loss": total_loss,
            "relevance_loss": relevance_loss,
            "semantic_loss": semantic_loss,
            "diversity_loss": diversity_loss
        }

    def _compute_relevance_loss(
        self,
        scores: torch.Tensor,
        labels: torch.Tensor
    ) -> torch.Tensor:
        """Binary cross-entropy between score logits and relevance labels.

        Both tensors are flattened to 1-D. A bare ``squeeze()`` would collapse
        a one-example batch to a 0-d tensor, which no longer matches the
        ``(1,)`` label shape and makes the loss call raise.
        """
        return torch.nn.functional.binary_cross_entropy_with_logits(
            scores.reshape(-1),
            labels.float().reshape(-1)
        )

    def _compute_semantic_loss(
        self,
        query_encoding: torch.Tensor,
        doc_encoding: torch.Tensor,
        semantic_signals: torch.Tensor
    ) -> torch.Tensor:
        """MSE between query/document cosine similarity and the given signals."""
        pred_similarity = torch.nn.functional.cosine_similarity(
            query_encoding,
            doc_encoding
        )
        return torch.nn.functional.mse_loss(
            pred_similarity,
            semantic_signals
        )

    def _compute_diversity_loss(
        self,
        doc_encoding: torch.Tensor
    ) -> torch.Tensor:
        """Mean pairwise dot product between document encodings in the batch.

        Self-similarities on the diagonal are zeroed before averaging; the
        mean is taken over the full ``batch x batch`` matrix.
        """
        similarity_matrix = torch.mm(
            doc_encoding,
            doc_encoding.t()
        )

        # Zero the diagonal so a document is not penalised for matching itself.
        diagonal = torch.eye(
            similarity_matrix.size(0),
            device=similarity_matrix.device
        )
        similarity_matrix = similarity_matrix * (1 - diagonal)

        return similarity_matrix.mean()


## Search Data Generation

In [None]:
class SearchDataGenerator:
    """Generate synthetic search data for training and evaluation.

    All tables are random: word embeddings, per-document features and
    per-user session embeddings are drawn once in ``__init__`` and held in
    memory (``num_docs x embedding_dim`` and ``num_users x embedding_dim``
    float tensors among them).

    Args:
        vocab_size: Number of distinct token ids.
        num_users: Number of synthetic users.
        num_docs: Number of synthetic documents.
        embedding_dim: Width of word, static-document and session embeddings.
        max_query_length: Longest query, in tokens (shortest is 2).
        max_doc_length: Longest document, in tokens (shortest is 10).
    """
    def __init__(
        self,
        vocab_size: int = 50000,
        num_users: int = 100000,
        num_docs: int = 1000000,
        embedding_dim: int = 128,
        max_query_length: int = 10,
        max_doc_length: int = 100
    ):
        self.vocab_size = vocab_size
        self.num_users = num_users
        self.num_docs = num_docs
        self.embedding_dim = embedding_dim
        self.max_query_length = max_query_length
        self.max_doc_length = max_doc_length

        # Generate synthetic word embeddings
        self.word_embeddings = torch.randn(vocab_size, embedding_dim)

        # Generate synthetic document features
        self.doc_features = {
            "static": torch.randn(num_docs, embedding_dim),
            "pagerank": torch.rand(num_docs),
            "freshness": torch.rand(num_docs),
            "quality_score": torch.rand(num_docs)
        }

        # Generate user features
        self.user_features = {
            "session_embeddings": torch.randn(num_users, embedding_dim),
            "search_history": self._generate_search_history()
        }

    def _generate_search_history(self) -> Dict[int, List[int]]:
        """Draw 5-19 distinct document ids for every user.

        The legacy ``np.random.choice(..., replace=False)`` permutes the whole
        population on every call, which is prohibitive for a large ``num_docs``
        repeated once per user; ``Generator.choice`` samples without doing so.
        The generator is seeded from the global NumPy state so that
        ``np.random.seed`` still makes the result reproducible. The history
        size is capped at ``num_docs`` so small corpora do not raise.
        """
        rng = np.random.default_rng(np.random.randint(0, 2**31 - 1))
        history_sizes = rng.integers(5, 20, size=self.num_users).tolist()
        return {
            user_id: rng.choice(
                self.num_docs,
                size=min(history_size, self.num_docs),
                replace=False
            ).tolist()
            for user_id, history_size in enumerate(history_sizes)
        }

    def _generate_sequence(
        self,
        min_length: int,
        max_length: int
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Draw a random token sequence of length ``min_length..max_length``.

        Returns ``(tokens, length, embeddings)`` where ``length`` is a 0-d
        tensor and ``embeddings`` are the word embeddings of ``tokens``.
        """
        length = torch.randint(min_length, max_length + 1, (1,)).item()
        tokens = torch.randint(0, self.vocab_size, (length,))
        embeddings = self.word_embeddings[tokens]
        return tokens, torch.tensor(length), embeddings

    def _generate_query(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Generate a synthetic query of 2 to ``max_query_length`` tokens."""
        return self._generate_sequence(2, self.max_query_length)

    def _generate_document(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Generate a synthetic document of 10 to ``max_doc_length`` tokens."""
        return self._generate_sequence(10, self.max_doc_length)

    def _compute_relevance(
        self,
        query_emb: torch.Tensor,
        doc_emb: torch.Tensor,
        doc_features: Dict[str, torch.Tensor]
    ) -> float:
        """Compute a synthetic relevance score for one query/document pair.

        The score is ``0.5 * cosine(mean query emb, mean doc emb)
        + 0.3 * quality_score + 0.2 * freshness``. Because the cosine term can
        be negative, the score is not guaranteed to lie in ``[0, 1]``.
        """
        semantic_sim = torch.nn.functional.cosine_similarity(
            query_emb.mean(0, keepdim=True),
            doc_emb.mean(0, keepdim=True)
        ).item()

        quality = doc_features["quality_score"].item()
        freshness = doc_features["freshness"].item()

        return (semantic_sim * 0.5 + quality * 0.3 + freshness * 0.2)

    @staticmethod
    def _sample_labels(
        relevance_scores: torch.Tensor,
        pos_ratio: float
    ) -> torch.Tensor:
        """Sample 0/1 labels with probability ``relevance * pos_ratio``.

        The product is clamped to ``[0, 1]`` first: relevance scores can be
        negative and ``pos_ratio`` can exceed 1, and Bernoulli sampling
        rejects probabilities outside that range.
        """
        probabilities = (relevance_scores * pos_ratio).clamp(0.0, 1.0)
        return torch.bernoulli(probabilities)

    def generate_batch(
        self,
        batch_size: int,
        pos_ratio: float = 0.2
    ) -> Tuple["SearchFeatures", torch.Tensor]:
        """Generate one batch of synthetic examples and their labels.

        Args:
            batch_size: Number of (query, document, user) examples.
            pos_ratio: Scale applied to the relevance score to obtain the
                probability of a positive label.

        Returns:
            ``(features, labels)``: padded, batch-first features and a
            ``(batch_size,)`` tensor of 0/1 labels.
        """
        # Generate queries and documents
        queries = [self._generate_query() for _ in range(batch_size)]
        docs = [self._generate_document() for _ in range(batch_size)]

        # Generate user IDs and features
        user_ids = torch.randint(0, self.num_users, (batch_size,))
        session_features = self.user_features["session_embeddings"][user_ids]

        # Sample which corpus document each example refers to, instead of
        # always reusing rows 0..batch_size-1 of the document tables (which
        # also fails once batch_size exceeds num_docs).
        doc_ids = torch.randint(0, self.num_docs, (batch_size,))

        # Compute match signals
        exact_match = torch.tensor([
            len(set(q[0].tolist()) & set(d[0].tolist())) / min(len(q[0]), len(d[0]))
            for q, d in zip(queries, docs)
        ])

        semantic_match = torch.tensor([
            torch.nn.functional.cosine_similarity(
                q[2].mean(0, keepdim=True),
                d[2].mean(0, keepdim=True)
            ).item()
            for q, d in zip(queries, docs)
        ])

        # Compute relevance scores and generate labels
        relevance_scores = torch.tensor([
            self._compute_relevance(q[2], d[2], {
                k: v[doc_id] for k, v in self.doc_features.items()
            })
            for q, d, doc_id in zip(queries, docs, doc_ids.tolist())
        ])

        labels = self._sample_labels(relevance_scores, pos_ratio)

        # Create features object
        pad = torch.nn.utils.rnn.pad_sequence
        features = SearchFeatures(
            query_tokens=pad([q[0] for q in queries], batch_first=True),
            query_length=torch.stack([q[1] for q in queries]),
            query_embeddings=pad([q[2] for q in queries], batch_first=True),
            doc_tokens=pad([d[0] for d in docs], batch_first=True),
            doc_length=torch.stack([d[1] for d in docs]),
            doc_embeddings=pad([d[2] for d in docs], batch_first=True),
            doc_static_features=self.doc_features["static"][doc_ids],
            exact_match_signals=exact_match,
            semantic_match_signals=semantic_match,
            user_id=user_ids,
            # Only the first five history entries are kept so that the rows
            # form a rectangular tensor.
            search_history=torch.tensor([
                self.user_features["search_history"][uid.item()][:5]
                for uid in user_ids
            ]),
            session_features=session_features
        )

        return features, labels

## Search Training Infrastructure

In [None]:
class SearchTrainer:
    """Training infrastructure for search ranking.

    Owns the optimizer (AdamW) and a one-cycle learning-rate schedule, and
    runs single optimisation steps.

    Args:
        model: Model to train; moved to ``device``.
        loss_fn: Object whose ``compute_loss(outputs, labels, features)``
            returns a dict containing ``"total_loss"``.
        learning_rate: Peak learning rate of the one-cycle schedule.
        device: Device the model and every batch are moved to.
        total_steps: Length of the one-cycle schedule, in optimisation steps.
            Steps taken beyond this keep the last scheduled learning rate
            instead of stepping the scheduler again.
    """
    def __init__(
        self,
        model: "SearchRankingModel",
        loss_fn: "SearchRankingLoss",
        learning_rate: float = 0.001,
        device: str = "cuda",
        total_steps: int = 1000
    ):
        self.model = model.to(device)
        self.loss_fn = loss_fn
        self.device = device
        self.total_steps = total_steps
        self._scheduler_steps = 0

        # Optimizer
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=learning_rate,
            weight_decay=0.01
        )

        # Learning rate scheduler with warmup
        self.scheduler = torch.optim.lr_scheduler.OneCycleLR(
            self.optimizer,
            max_lr=learning_rate,
            total_steps=total_steps,
            pct_start=0.1
        )

    def train_step(
        self,
        features: "SearchFeatures",
        labels: torch.Tensor
    ) -> Dict[str, float]:
        """Run one forward/backward/update step.

        Args:
            features: Batch inputs; moved to the trainer's device.
            labels: Batch labels; moved to the trainer's device.

        Returns:
            Every entry of the loss dict converted to a Python float.
        """
        # Make sure dropout is active even if the model was last used for
        # evaluation.
        self.model.train()
        self.optimizer.zero_grad()

        # Move to device
        features = features.to(self.device)
        labels = labels.to(self.device)

        # Forward pass
        outputs = self.model(features)

        # Compute loss
        losses = self.loss_fn.compute_loss(outputs, labels, features)

        # Backward pass
        losses["total_loss"].backward()

        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)

        # Update weights
        self.optimizer.step()

        # The one-cycle scheduler raises when stepped past its configured
        # length, so stop advancing it once the schedule is exhausted.
        if self._scheduler_steps < self.total_steps:
            self.scheduler.step()
            self._scheduler_steps += 1

        return {k: v.item() for k, v in losses.items()}

## Search Evaluation Metrics

class SearchEvaluator:
    """Evaluation metrics for search ranking.

    Args:
        k_values: Cut-offs for NDCG@k and precision@k. Defaults to
            ``[1, 3, 5, 10]``. The list is copied, so instances never share
            or alias the caller's list.
    """
    def __init__(self, k_values: Optional[List[int]] = None):
        self.k_values = [1, 3, 5, 10] if k_values is None else list(k_values)

    @torch.no_grad()
    def evaluate(
        self,
        model: "SearchRankingModel",
        features: "SearchFeatures",
        labels: torch.Tensor
    ) -> Dict[str, float]:
        """Score ``features`` with ``model`` and compute all metrics.

        The whole batch is ranked as a single result list. The model is put in
        eval mode for the forward pass and restored to its previous mode
        afterwards.

        Args:
            model: Model returning ``"score"``, ``"query_encoding"`` and
                ``"doc_encoding"``.
            features: Batch inputs; ``semantic_match_signals`` is also read
                directly.
            labels: Binary relevance labels, one per example.

        Returns:
            Dict of metric name to value.
        """
        was_training = model.training
        model.eval()
        try:
            outputs = model(features)
        finally:
            model.train(was_training)

        # Flatten rather than squeeze(): squeeze() would turn a one-example
        # batch into a 0-d tensor that cannot be sliced or measured.
        scores = outputs["score"].reshape(-1)
        labels = labels.reshape(-1)

        metrics = {}

        # Ranking metrics
        metrics.update(self._compute_ranking_metrics(scores, labels))

        # Diversity metrics
        metrics.update(self._compute_diversity_metrics(
            outputs["doc_encoding"],
            scores,
            labels
        ))

        # Semantic matching metrics
        metrics.update(self._compute_semantic_metrics(
            outputs["query_encoding"],
            outputs["doc_encoding"],
            features.semantic_match_signals
        ))

        return metrics

    def _compute_ranking_metrics(
        self,
        scores: torch.Tensor,
        labels: torch.Tensor
    ) -> Dict[str, float]:
        """Compute NDCG@k, precision@k and reciprocal rank of the first hit."""
        metrics = {}

        # Sort by scores
        sorted_indices = torch.argsort(scores, descending=True)
        sorted_labels = labels[sorted_indices]

        # NDCG@k
        for k in self.k_values:
            metrics[f"ndcg@{k}"] = self._compute_ndcg(sorted_labels, k)

        # Precision@k
        for k in self.k_values:
            metrics[f"precision@{k}"] = sorted_labels[:k].float().mean().item()

        # Reciprocal rank of the highest-ranked positive (0 if there is none)
        pos_indices = (sorted_labels == 1).nonzero()
        if len(pos_indices) > 0:
            metrics["mrr"] = 1.0 / (pos_indices[0].item() + 1)
        else:
            metrics["mrr"] = 0.0

        return metrics

    def _compute_diversity_metrics(
        self,
        doc_encodings: torch.Tensor,
        scores: torch.Tensor,
        labels: torch.Tensor
    ) -> Dict[str, float]:
        """Mean pairwise dot product among the 10 top-scored documents.

        The mean is over the full similarity matrix, diagonal included.
        ``labels`` is accepted but not used.
        """
        metrics = {}

        # Get top-k document encodings
        top_k = 10
        top_indices = torch.argsort(scores, descending=True)[:top_k]
        top_encodings = doc_encodings[top_indices]

        # Compute pairwise similarities
        similarity_matrix = torch.mm(top_encodings, top_encodings.t())

        # Average similarity (lower is more diverse)
        metrics["diversity"] = similarity_matrix.mean().item()

        return metrics

    def _compute_semantic_metrics(
        self,
        query_encoding: torch.Tensor,
        doc_encoding: torch.Tensor,
        semantic_signals: torch.Tensor
    ) -> Dict[str, float]:
        """Correlate query/document cosine similarity with the given signals.

        Returns NaN for fewer than two examples, where a correlation is
        undefined.
        """
        metrics = {}

        # Compute predicted similarities
        pred_similarity = torch.nn.functional.cosine_similarity(
            query_encoding,
            doc_encoding
        )

        # Correlation with pre-computed semantic signals
        if pred_similarity.numel() < 2:
            correlation = float("nan")
        else:
            correlation = torch.corrcoef(
                torch.stack([pred_similarity, semantic_signals])
            )[0, 1].item()

        metrics["semantic_correlation"] = correlation

        return metrics

    @staticmethod
    def _compute_ndcg(labels: torch.Tensor, k: int) -> float:
        """NDCG@k for binary ``labels`` already ordered by descending score.

        The ideal ordering is obtained by sorting all of ``labels``; ``k`` is
        capped at the number of labels. Returns 0.0 when there is no positive.
        """
        if k > len(labels):
            k = len(labels)

        dcg = 0.0
        idcg = 0.0

        # Calculate DCG
        for i in range(k):
            if labels[i] == 1:
                dcg += 1 / np.log2(i + 2)

        # Calculate IDCG
        sorted_labels = torch.sort(labels, descending=True)[0]
        for i in range(k):
            if sorted_labels[i] == 1:
                idcg += 1 / np.log2(i + 2)

        return float(dcg / idcg) if idcg > 0 else 0.0

## Complete Training Loop

In [None]:
def train_search_model():
    """Train the search ranking model on synthetic data and report metrics.

    Runs 10 epochs of 100 batches (64 examples each), printing the total loss
    every 10th batch, and evaluates on a fresh 1000-example batch after every
    epoch. Uses CUDA when available, otherwise the CPU.
    """
    num_epochs = 10
    batch_size = 64
    batches_per_epoch = 100
    log_every = 10
    eval_batch_size = 1000

    # Build the components (model first, data generator last).
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = SearchRankingModel(embedding_dim=128, hidden_dim=256)
    loss_fn = SearchRankingLoss()
    trainer = SearchTrainer(model, loss_fn, device=device)
    evaluator = SearchEvaluator()
    data_gen = SearchDataGenerator(
        vocab_size=50000,
        num_users=100000,
        num_docs=1000000,
        embedding_dim=128
    )

    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch + 1}")

        # --- training phase ---
        model.train()
        epoch_losses = defaultdict(list)

        for batch in range(batches_per_epoch):
            train_features, train_labels = data_gen.generate_batch(batch_size)
            losses = trainer.train_step(train_features, train_labels)

            for loss_name, loss_value in losses.items():
                epoch_losses[loss_name].append(loss_value)

            if not batch % log_every:
                print(f"Batch {batch}, Loss: {losses['total_loss']:.4f}")

        # --- evaluation phase ---
        print("\nEvaluation:")
        eval_features, eval_labels = data_gen.generate_batch(eval_batch_size)
        eval_features = eval_features.to(device)
        eval_labels = eval_labels.to(device)
        eval_metrics = evaluator.evaluate(model, eval_features, eval_labels)

        for metric, value in eval_metrics.items():
            print(f"{metric}: {value:.4f}")

# Run the full training loop only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    train_search_model()