# Advanced Ranking System with TorchRec

In [None]:
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List, NamedTuple, Optional, Tuple

import numpy as np
import torch
import torchrec
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

from utils.benchmark import TorchRecBenchmark
from utils.data_generators import TorchRecDataGenerator
from utils.debugging import TorchRecDebugger

## Advanced Feature Engineering

In [None]:
@dataclass
class RankingFeatures:
    """Complex feature set for ranking.

    Instances are iterable: iterating yields the field values in declaration
    order, so ``RankingFeatures(*[f.to(device) for f in features])`` rebuilds
    an equivalent instance. ``to(device)`` is a shorthand for exactly that.
    """
    # Dense features
    user_features: torch.Tensor      # age, gender, etc.
    item_features: torch.Tensor      # price, age, etc.
    context_features: torch.Tensor   # time, device, etc.

    # Sparse features (IDs)
    user_id: torch.Tensor
    item_id: torch.Tensor
    category_id: torch.Tensor

    # Interaction features
    historical_ctr: torch.Tensor    # Historical click-through rate
    user_item_similarity: torch.Tensor

    # Sequence features
    history_length: torch.Tensor
    position_ids: torch.Tensor

    def __iter__(self):
        """Yield every field value in declaration order.

        A plain dataclass is not iterable, which breaks the
        ``RankingFeatures(*[... for f in features])`` idiom; this makes that
        positional round trip work.
        """
        # Local import: the file only imports `dataclass` at module level.
        from dataclasses import fields

        return (getattr(self, spec.name) for spec in fields(self))

    def to(self, device) -> "RankingFeatures":
        """Return a new instance with every tensor moved via ``Tensor.to(device)``."""
        return type(self)(*(tensor.to(device) for tensor in self))

class FeatureProcessor:
    """Standardize continuous features with statistics learned by ``fit``.

    ``feature_config`` maps a feature name to a dict whose ``"type"`` entry
    decides the treatment: ``"continuous"`` features are z-scored, every
    other type is passed through untouched.
    """

    def __init__(self, feature_config: Dict[str, Dict], eps: float = 1e-8):
        """
        Args:
            feature_config: per-feature settings; each value needs a
                ``"type"`` key.
            eps: standard deviations at or below this threshold (or
                non-finite ones) are replaced by 1 so the division in
                ``transform`` cannot produce NaN/inf.
        """
        self.feature_config = feature_config
        self.eps = eps
        self.scalers = {}

    @staticmethod
    def _as_float(tensor: torch.Tensor) -> torch.Tensor:
        """Return ``tensor`` unchanged if floating point, else cast to float32."""
        return tensor if tensor.is_floating_point() else tensor.float()

    def fit(self, features: Dict[str, torch.Tensor]):
        """Compute and store mean/std for every configured continuous feature.

        Raises:
            KeyError: if a configured continuous feature is missing from
                ``features``.
        """
        for name, config in self.feature_config.items():
            if config["type"] != "continuous":
                continue

            # mean()/std() are undefined for integer tensors.
            values = self._as_float(features[name])
            std = values.std()
            # A constant feature has std 0; dividing by it would yield NaN/inf.
            usable = torch.isfinite(std) & (std > self.eps)
            self.scalers[name] = {
                "mean": values.mean(),
                "std": torch.where(usable, std, torch.ones_like(std)),
            }

    def transform(self, features: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Return a new dict with continuous features z-scored.

        Features that are not continuous, or that have no entry in
        ``feature_config``, are returned as-is.

        Raises:
            KeyError: if a continuous feature was never seen by ``fit``.
        """
        normalized = {}
        for name, tensor in features.items():
            config = self.feature_config.get(name)

            if config is None or config["type"] != "continuous":
                normalized[name] = tensor
                continue

            if name not in self.scalers:
                raise KeyError(
                    f"feature {name!r} has no fitted statistics; call fit() first"
                )

            stats = self.scalers[name]
            normalized[name] = (self._as_float(tensor) - stats["mean"]) / stats["std"]

        return normalized

## Cross Network Architecture

In [None]:
class CrossNetwork(torch.nn.Module):
    """Stack of explicit feature-crossing layers.

    Every layer multiplies the original input element-wise with a linear
    projection of the current state, adds the result back onto the current
    state, and applies layer normalization.
    """

    def __init__(self, input_dim: int, num_layers: int):
        super().__init__()
        self.num_layers = num_layers

        # One square projection per cross layer.
        self.cross_layers = torch.nn.ModuleList(
            torch.nn.Linear(input_dim, input_dim) for _ in range(num_layers)
        )

        # One normalization per cross layer, applied after the residual sum.
        self.layer_norms = torch.nn.ModuleList(
            torch.nn.LayerNorm(input_dim) for _ in range(num_layers)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply all cross layers in order and return the final state."""
        base = x      # the untouched input takes part in every crossing
        state = x
        for project, normalize in zip(self.cross_layers, self.layer_norms):
            interaction = base * project(state)
            state = normalize(state + interaction)
        return state

## Advanced Ranking Model

In [None]:
class AdvancedRankingModel(torch.nn.Module):
    """Ranking model with cross network and attention.

    Pipeline: pooled user/item/category embeddings -> dense projection ->
    cross network -> self-attention -> concatenation with a position
    embedding -> MLP producing one score per row.
    """

    # Feature keys, in the order their values are concatenated for the
    # KeyedJaggedTensor. The pooled output is keyed by these names.
    _SPARSE_KEYS = ("user_id", "item_id", "category_id")

    def __init__(
        self,
        num_users: int,
        num_items: int,
        num_categories: int,
        embedding_dim: int = 64,
        hidden_dim: int = 256,
        num_heads: int = 4,
        num_cross_layers: int = 3,
        dropout: float = 0.1,
        max_position: int = 100,
        embedding_device: Optional[torch.device] = None
    ):
        """
        Args:
            num_users: rows in the user embedding table.
            num_items: rows in the item embedding table.
            num_categories: rows in the category embedding table.
            embedding_dim: width of each sparse embedding.
            hidden_dim: width of the dense, cross and attention layers.
            num_heads: attention heads (``hidden_dim`` must be divisible by it).
            num_cross_layers: number of layers in the cross network.
            dropout: dropout probability used throughout.
            max_position: size of the position-bias embedding table.
            embedding_device: device for the embedding tables. ``None`` lets
                TorchRec pick its default so the model is directly usable;
                pass ``torch.device("meta")`` only when the tables are
                materialized later (e.g. by a sharding wrapper).
        """
        super().__init__()

        # Embedding tables. These were previously always created on the
        # "meta" device, which holds no data and cannot be copied with
        # `.to(device)`, so the un-sharded model could not run.
        self.embedding_tables = torchrec.EmbeddingBagCollection(
            tables=[
                torchrec.EmbeddingBagConfig(
                    name="user_embeddings",
                    embedding_dim=embedding_dim,
                    num_embeddings=num_users,
                    feature_names=["user_id"],
                ),
                torchrec.EmbeddingBagConfig(
                    name="item_embeddings",
                    embedding_dim=embedding_dim,
                    num_embeddings=num_items,
                    feature_names=["item_id"],
                ),
                torchrec.EmbeddingBagConfig(
                    name="category_embeddings",
                    embedding_dim=embedding_dim,
                    num_embeddings=num_categories,
                    feature_names=["category_id"],
                ),
            ],
            device=embedding_device
        )

        # Dense feature processing
        self.dense_network = torch.nn.Sequential(
            torch.nn.Linear(embedding_dim * 3, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.BatchNorm1d(hidden_dim),
            torch.nn.Dropout(dropout)
        )

        # Cross network for feature interactions
        self.cross_network = CrossNetwork(
            input_dim=hidden_dim,
            num_layers=num_cross_layers
        )

        # Self-attention for feature refinement
        self.self_attention = torch.nn.MultiheadAttention(
            embed_dim=hidden_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True
        )

        # Final ranking layers
        self.ranking_layers = torch.nn.Sequential(
            torch.nn.Linear(hidden_dim * 2, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Dropout(dropout),
            torch.nn.Linear(hidden_dim, 1)
        )

        # Embedding for position bias
        self.position_embedding = torch.nn.Embedding(
            num_embeddings=max_position,
            embedding_dim=hidden_dim
        )

    def forward(
        self,
        features: RankingFeatures,
        return_embeddings: bool = False
    ) -> Dict[str, torch.Tensor]:
        """Score every row of ``features``.

        Args:
            features: batch of inputs; only ``user_id``, ``item_id``,
                ``category_id`` and ``position_ids`` are read here.
            return_embeddings: also return intermediate representations.

        Returns:
            ``{"scores": tensor}`` with one score per row in a trailing
            dimension of size 1, plus an ``"embeddings"`` dict (keys
            ``user``, ``item``, ``cross``, ``attention``) when requested.
        """
        batch_size = features.user_id.shape[0]

        # One id per (key, row): lengths must be an integer tensor living on
        # the same device as the values.
        sparse_input = KeyedJaggedTensor.from_lengths_sync(
            keys=list(self._SPARSE_KEYS),
            values=torch.cat([
                features.user_id,
                features.item_id,
                features.category_id
            ]),
            lengths=torch.ones(
                batch_size * len(self._SPARSE_KEYS),
                dtype=torch.int64,
                device=features.user_id.device
            )
        )

        # The pooled output is keyed by feature name, not by table name.
        sparse_embeddings = self.embedding_tables(sparse_input).to_dict()
        user_emb = sparse_embeddings["user_id"]
        item_emb = sparse_embeddings["item_id"]
        category_emb = sparse_embeddings["category_id"]

        # Combine embeddings
        combined_embeddings = torch.cat([user_emb, item_emb, category_emb], dim=1)

        # Process dense features
        dense_features = self.dense_network(combined_embeddings)

        # Apply cross network
        cross_features = self.cross_network(dense_features)

        # Apply self-attention.
        # NOTE: each row is fed as a length-1 sequence, so the attention
        # weight is always 1 and this stage reduces to the value/output
        # projections of the attention module.
        tokens = cross_features.unsqueeze(1)
        attended_features, _ = self.self_attention(tokens, tokens, tokens)
        attended_features = attended_features.squeeze(1)

        # Add position embeddings. Positions outside the table share the
        # nearest valid bucket instead of raising an index error.
        last_position = self.position_embedding.num_embeddings - 1
        position_emb = self.position_embedding(
            features.position_ids.clamp(min=0, max=last_position)
        )

        # Combine features
        final_features = torch.cat([
            attended_features,
            position_emb
        ], dim=1)

        # Generate ranking scores
        scores = self.ranking_layers(final_features)

        if return_embeddings:
            return {
                "scores": scores,
                "embeddings": {
                    "user": user_emb,
                    "item": item_emb,
                    "cross": cross_features,
                    "attention": attended_features
                }
            }

        return {"scores": scores}

## Advanced Loss Functions

In [None]:
class RankingLoss:
    """Combined ranking loss functions.

    Every method accepts scores shaped ``(N,)`` or ``(N, 1)``; inputs are
    flattened first so a trailing singleton dimension cannot trigger
    unintended broadcasting against the 1-D labels.
    """

    def __init__(
        self,
        lambda_pair: float = 1.0,
        lambda_list: float = 1.0,
        temperature: float = 1.0
    ):
        """
        Args:
            lambda_pair: weight of the pairwise term in ``combined_loss``.
            lambda_list: weight of the listwise term in ``combined_loss``.
            temperature: divisor applied to scores before the listwise
                softmax.

        Raises:
            ValueError: if ``temperature`` is not strictly positive.
        """
        if temperature <= 0:
            raise ValueError("temperature must be positive")
        self.lambda_pair = lambda_pair
        self.lambda_list = lambda_list
        self.temperature = temperature

    def pointwise_loss(
        self,
        scores: torch.Tensor,
        labels: torch.Tensor
    ) -> torch.Tensor:
        """Binary cross entropy (with logits) between scores and labels."""
        # reshape(-1) rather than squeeze(): squeeze() collapses a batch of
        # one into a 0-dim tensor that no longer matches the labels.
        return torch.nn.functional.binary_cross_entropy_with_logits(
            scores.reshape(-1),
            labels.reshape(-1).float()
        )

    def pairwise_loss(
        self,
        scores: torch.Tensor,
        labels: torch.Tensor,
        group_ids: torch.Tensor
    ) -> torch.Tensor:
        """Pairwise hinge ranking loss, weighted by the label gap.

        Pairs are formed only inside a group and only between items with
        different labels. Returns a zero scalar when no such pair exists.
        """
        scores = scores.reshape(-1)
        labels = labels.reshape(-1).float()
        group_ids = group_ids.reshape(-1)

        losses = []

        # Process each group (e.g., query) separately
        for group_id in group_ids.unique():
            mask = group_ids == group_id
            group_scores = scores[mask]
            group_labels = labels[mask]

            # Entry [i, j] holds value_j - value_i.
            score_diff = group_scores.unsqueeze(0) - group_scores.unsqueeze(1)
            label_diff = group_labels.unsqueeze(0) - group_labels.unsqueeze(1)

            # Valid pairs have different labels
            valid_pairs = label_diff != 0
            if not valid_pairs.any():
                continue

            pair_scores = score_diff[valid_pairs]
            pair_labels = label_diff[valid_pairs]

            # Hinge on the score gap; the target sign says which item of the
            # pair should score higher.
            pair_losses = torch.nn.functional.margin_ranking_loss(
                pair_scores,
                torch.zeros_like(pair_scores),
                torch.sign(pair_labels),
                reduction='none'
            )

            # Weight each pair by how far apart its labels are.
            losses.append((pair_losses * pair_labels.abs()).mean())

        if not losses:
            return scores.new_zeros(())
        return torch.stack(losses).mean()

    def listwise_loss(
        self,
        scores: torch.Tensor,
        labels: torch.Tensor,
        group_ids: torch.Tensor
    ) -> torch.Tensor:
        """ListNet-style loss with temperature scaling.

        For each group: cross entropy between softmax(labels) and
        softmax(scores / temperature), averaged over groups. Returns a zero
        scalar when there are no groups.
        """
        scores = scores.reshape(-1)
        labels = labels.reshape(-1).float()
        group_ids = group_ids.reshape(-1)

        losses = []

        for group_id in group_ids.unique():
            mask = group_ids == group_id

            # log_softmax is the numerically stable form of log(softmax(x)).
            log_score_probs = torch.nn.functional.log_softmax(
                scores[mask] / self.temperature, dim=0
            )
            label_probs = torch.nn.functional.softmax(labels[mask], dim=0)

            # Cross entropy between distributions
            losses.append(-(label_probs * log_score_probs).sum())

        if not losses:
            return scores.new_zeros(())
        return torch.stack(losses).mean()

    def combined_loss(
        self,
        scores: torch.Tensor,
        labels: torch.Tensor,
        group_ids: torch.Tensor
    ) -> Dict[str, torch.Tensor]:
        """Return the weighted total together with each component.

        ``total_loss = point_loss + lambda_pair * pair_loss
        + lambda_list * list_loss``.
        """
        point_loss = self.pointwise_loss(scores, labels)
        pair_loss = self.pairwise_loss(scores, labels, group_ids)
        list_loss = self.listwise_loss(scores, labels, group_ids)

        total_loss = point_loss + self.lambda_pair * pair_loss + self.lambda_list * list_loss

        return {
            "total_loss": total_loss,
            "point_loss": point_loss,
            "pair_loss": pair_loss,
            "list_loss": list_loss
        }

## Training Infrastructure

In [None]:
class RankingTrainer:
    """Training infrastructure for advanced ranking model.

    Owns the AdamW optimizer and a cosine-annealing learning-rate schedule,
    and exposes a single ``train_step`` that runs one optimization step.
    """

    def __init__(
        self,
        model: AdvancedRankingModel,
        loss_fn: RankingLoss,
        learning_rate: float = 0.001,
        device: str = "cuda",
        weight_decay: float = 0.01,
        total_steps: int = 1000,
        max_grad_norm: float = 1.0
    ):
        """
        Args:
            model: model to optimize; it is moved to ``device``.
            loss_fn: object providing ``combined_loss(scores, labels, group_ids)``.
            learning_rate: initial AdamW learning rate.
            device: device used for the model and every batch.
            weight_decay: AdamW weight decay.
            total_steps: ``T_max`` of the cosine schedule, i.e. the number of
                ``train_step`` calls over which the rate is annealed.
            max_grad_norm: gradient-norm clipping threshold.
        """
        self.device = device
        self.model = model.to(device)
        self.loss_fn = loss_fn
        self.max_grad_norm = max_grad_norm
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=learning_rate,
            weight_decay=weight_decay
        )
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer,
            T_max=total_steps
        )

    def _features_to_device(self, features: RankingFeatures) -> RankingFeatures:
        """Return a copy of ``features`` with every field moved to ``self.device``.

        Goes through the dataclass field list instead of iterating the
        instance, because a plain dataclass is not iterable.
        """
        # Local import: the file only imports `dataclass` at module level.
        from dataclasses import fields

        moved = {
            spec.name: getattr(features, spec.name).to(self.device)
            for spec in fields(features)
        }
        return type(features)(**moved)

    def train_step(
        self,
        features: RankingFeatures,
        labels: torch.Tensor,
        group_ids: torch.Tensor
    ) -> Dict[str, float]:
        """Run one forward/backward/update cycle.

        Returns:
            Every entry of ``loss_fn.combined_loss`` converted to a Python
            float.
        """
        self.optimizer.zero_grad()

        # Move everything to device
        features = self._features_to_device(features)
        labels = labels.to(self.device)
        group_ids = group_ids.to(self.device)

        # Forward pass; only the scores are needed for the loss.
        outputs = self.model(features)

        # Compute loss
        losses = self.loss_fn.combined_loss(
            outputs["scores"],
            labels,
            group_ids
        )

        # Backward pass
        losses["total_loss"].backward()

        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)

        # Update weights, then advance the learning-rate schedule
        self.optimizer.step()
        self.scheduler.step()

        return {name: value.item() for name, value in losses.items()}

## Advanced Ranking Metrics

In [None]:
class RankingMetrics:
    """Comprehensive ranking evaluation metrics.

    The per-list metrics take the scores and labels of ONE ranked list.
    Scores may be shaped ``(N,)`` or ``(N, 1)``; they are flattened before
    use. ``evaluate_rankings`` splits a batch into lists by ``group_ids``
    and averages the per-list results.
    """

    @staticmethod
    def ndcg_at_k(scores: torch.Tensor, labels: torch.Tensor, k: int) -> float:
        """Normalized Discounted Cumulative Gain over the top ``k`` items.

        ``k`` is capped at the list length so short lists are supported.
        Returns 0.0 when the list is empty, ``k`` is not positive, or the
        ideal DCG is zero.
        """
        scores = scores.reshape(-1)
        labels = labels.reshape(-1).float()

        # Without the cap, lists shorter than k produce label and discount
        # vectors of different lengths.
        k = min(k, scores.numel())
        if k <= 0:
            return 0.0

        # Sort by predicted scores
        top_indices = scores.argsort(descending=True)[:k]
        sorted_labels = labels[top_indices]

        # Ideal ordering
        ideal_labels = labels.sort(descending=True).values[:k]

        # Position discounts
        positions = torch.arange(1, k + 1, dtype=torch.float, device=scores.device)
        discounts = 1.0 / torch.log2(positions + 1)

        # Calculate DCG and IDCG
        dcg = (sorted_labels * discounts).sum()
        idcg = (ideal_labels * discounts).sum()

        return (dcg / idcg).item() if idcg > 0 else 0.0

    @staticmethod
    def mean_reciprocal_rank(scores: torch.Tensor, labels: torch.Tensor) -> float:
        """Reciprocal rank of the highest-scored relevant item.

        An item counts as relevant when its label is greater than zero.
        Returns 0.0 when the list holds no relevant item.
        """
        scores = scores.reshape(-1)
        labels = labels.reshape(-1)

        sorted_labels = labels[scores.argsort(descending=True)]

        # Find position of first relevant item
        relevant_positions = (sorted_labels > 0).nonzero(as_tuple=True)[0]
        if relevant_positions.numel() == 0:
            return 0.0
        return 1.0 / (relevant_positions[0].item() + 1)

    @staticmethod
    def precision_recall_at_k(
        scores: torch.Tensor,
        labels: torch.Tensor,
        k: int
    ) -> Tuple[float, float]:
        """Precision and recall of the top ``k`` items.

        Precision divides by the requested ``k`` even when the list is
        shorter. Returns ``(0.0, 0.0)`` when ``k`` is not positive.
        """
        if k <= 0:
            return 0.0, 0.0

        scores = scores.reshape(-1)
        labels = labels.reshape(-1)

        top_labels = labels[scores.argsort(descending=True)][:k]

        relevant_items = top_labels.sum().item()
        total_relevant = labels.sum().item()

        precision = relevant_items / k
        recall = relevant_items / total_relevant if total_relevant > 0 else 0.0

        return precision, recall

    @staticmethod
    def evaluate_rankings(
        scores: torch.Tensor,
        labels: torch.Tensor,
        group_ids: torch.Tensor,
        k: int = 10
    ) -> Dict[str, float]:
        """Average NDCG@k, MRR, precision@k and recall@k over all groups.

        Returns:
            Dict with keys ``ndcg``, ``mrr``, ``precision`` and ``recall``;
            an empty dict when ``group_ids`` contains no group.
        """
        scores = scores.reshape(-1)
        labels = labels.reshape(-1)
        group_ids = group_ids.reshape(-1)

        per_group = {"ndcg": [], "mrr": [], "precision": [], "recall": []}

        for group_id in group_ids.unique():
            mask = group_ids == group_id
            group_scores = scores[mask]
            group_labels = labels[mask]

            per_group["ndcg"].append(
                RankingMetrics.ndcg_at_k(group_scores, group_labels, k)
            )
            per_group["mrr"].append(
                RankingMetrics.mean_reciprocal_rank(group_scores, group_labels)
            )
            precision, recall = RankingMetrics.precision_recall_at_k(
                group_scores, group_labels, k
            )
            per_group["precision"].append(precision)
            per_group["recall"].append(recall)

        return {
            name: float(np.mean(values))
            for name, values in per_group.items()
            if values
        }

## Advanced Data Generation

In [None]:
class RankingDataGenerator:
    """Synthetic data source for ranking experiments.

    Random user/item feature tables, an item-to-category mapping and a
    per-item base click-through rate are drawn once at construction;
    ``generate_batch`` samples from them.
    """

    def __init__(
        self,
        num_users: int,
        num_items: int,
        num_categories: int,
        feature_dims: Dict[str, int],
        max_list_length: int = 20
    ):
        # Keep the configuration around.
        self.num_users, self.num_items = num_users, num_items
        self.num_categories = num_categories
        self.feature_dims = feature_dims
        self.max_list_length = max_list_length

        # Fixed random feature tables and the item -> category assignment.
        user_dim = feature_dims["user"]
        item_dim = feature_dims["item"]
        self.user_features = torch.randn(num_users, user_dim)
        self.item_features = torch.randn(num_items, item_dim)
        self.item_categories = torch.randint(0, num_categories, (num_items,))

        # Base CTR of every item, drawn from Beta(2, 10).
        ctr_prior = torch.distributions.Beta(2, 10)
        self.item_base_ctr = ctr_prior.sample((num_items,))

    def generate_batch(
        self,
        batch_size: int,
        list_size: int
    ) -> Tuple[RankingFeatures, torch.Tensor, torch.Tensor]:
        """Sample ``batch_size`` lists of ``list_size`` candidate items each.

        Returns:
            ``(features, binary_labels, query_ids)``, each covering
            ``batch_size * list_size`` rows; ``query_ids`` marks which list
            a row belongs to.
        """
        num_rows = batch_size * list_size

        # One user per list; rows of the same list share a query id.
        sampled_users = torch.randint(0, self.num_users, (batch_size,))
        query_ids = torch.arange(batch_size).repeat_interleave(list_size)

        # Candidate items are sampled independently for every row.
        sampled_items = torch.randint(0, self.num_items, (num_rows,))

        # Random draws happen in a fixed order so seeded runs stay reproducible.
        per_row_user_features = self.user_features[sampled_users].repeat_interleave(
            list_size, dim=0
        )
        context = torch.randn(num_rows, self.feature_dims["context"])
        similarity = torch.randn(num_rows)
        history = torch.randint(1, 50, (num_rows,))
        positions = torch.arange(list_size).repeat(batch_size)

        features = RankingFeatures(
            user_features=per_row_user_features,
            item_features=self.item_features[sampled_items],
            context_features=context,
            user_id=sampled_users.repeat_interleave(list_size),
            item_id=sampled_items,
            category_id=self.item_categories[sampled_items],
            historical_ctr=self.item_base_ctr[sampled_items],
            user_item_similarity=similarity,
            history_length=history,
            position_ids=positions
        )

        # Click probability: rises with similarity and base CTR, falls with position.
        click_logits = (
            0.1 * features.user_item_similarity +
            0.2 * features.historical_ctr +
            -0.1 * features.position_ids.float()
        )
        click_probs = torch.sigmoid(click_logits)

        # Draw the binary click labels.
        binary_labels = torch.bernoulli(click_probs)

        return features, binary_labels, query_ids

## Training Loop

In [None]:
def train_ranking_model(
    num_epochs: int = 5,
    batches_per_epoch: int = 100,
    batch_size: int = 32,
    list_size: int = 10,
    eval_batches: int = 10,
    eval_batch_size: int = 100
):
    """Train and evaluate the ranking model on synthetic data.

    Each epoch runs ``batches_per_epoch`` training steps, prints the mean of
    every loss component, then prints ranking metrics averaged over
    ``eval_batches`` freshly generated batches.

    Args:
        num_epochs: number of train/eval rounds.
        batches_per_epoch: training steps per epoch.
        batch_size: lists per training batch.
        list_size: candidate items per list (train and eval).
        eval_batches: evaluation batches per epoch.
        eval_batch_size: lists per evaluation batch.
    """
    # Local imports keep this function independent of the module header,
    # which pulls `defaultdict` from `typing` and does not import `fields`.
    from collections import defaultdict
    from dataclasses import fields

    # Initialize components
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    feature_dims = {
        "user": 10,
        "item": 20,
        "context": 5
    }

    model = AdvancedRankingModel(
        num_users=100_000,
        num_items=1_000_000,
        num_categories=1000,
        embedding_dim=64,
        hidden_dim=256
    )

    loss_fn = RankingLoss(
        lambda_pair=0.5,
        lambda_list=0.5,
        temperature=0.1
    )

    trainer = RankingTrainer(model, loss_fn, device=device)
    data_gen = RankingDataGenerator(
        num_users=100_000,
        num_items=1_000_000,
        num_categories=1000,
        feature_dims=feature_dims
    )

    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch + 1}")

        # Training
        model.train()
        epoch_losses = defaultdict(list)

        for batch in range(batches_per_epoch):
            features, labels, group_ids = data_gen.generate_batch(
                batch_size, list_size
            )

            losses = trainer.train_step(features, labels, group_ids)

            for name, value in losses.items():
                epoch_losses[name].append(value)

            if batch % 10 == 0:
                print(f"Batch {batch}, Loss: {losses['total_loss']:.4f}")

        # Report the per-epoch averages that were being collected.
        print("\nMean Training Losses:")
        for name, values in epoch_losses.items():
            print(f"{name}: {np.mean(values):.4f}")

        # Evaluation
        model.eval()
        eval_metrics = defaultdict(list)

        with torch.no_grad():
            for _ in range(eval_batches):
                features, labels, group_ids = data_gen.generate_batch(
                    batch_size=eval_batch_size,
                    list_size=list_size
                )

                # A dataclass instance is not iterable, so move the tensors
                # field by field.
                features = RankingFeatures(**{
                    spec.name: getattr(features, spec.name).to(device)
                    for spec in fields(features)
                })

                scores = model(features)["scores"]
                metrics = RankingMetrics.evaluate_rankings(
                    scores, labels.to(device), group_ids.to(device)
                )

                for name, value in metrics.items():
                    eval_metrics[name].append(value)

        print("\nEvaluation Metrics:")
        for name, values in eval_metrics.items():
            print(f"{name}: {np.mean(values):.4f}")

if __name__ == "__main__":
    train_ranking_model()