# Advanced Ranking System with TorchRec

In [None]:
import torch
import torchrec
import numpy as np
from typing import Dict, List, Tuple, Optional, NamedTuple
from dataclasses import dataclass
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
from utils.data_generators import TorchRecDataGenerator
from utils.debugging import TorchRecDebugger
from utils.benchmark import TorchRecBenchmark

## Advanced Feature Engineering

In [None]:
@dataclass
class RankingFeatures:
    """Complex feature set for ranking: one batch of ranking inputs.

    A plain container; no validation, dtype or shape checks are performed here.

    ``AdvancedRankingModel.forward`` reads only ``user_id``, ``item_id``,
    ``category_id`` and ``position_ids``; the remaining fields are carried
    along but not consumed by that model.

    NOTE(review): tensor shapes/dtypes are not declared anywhere visible. The
    model calls ``len(features.user_id)`` and concatenates the three ID tensors
    into one KeyedJaggedTensor with one id per example per key. That suggests
    1-D integer tensors of length ``batch_size``; confirm against the data
    generator.
    """
    # Dense features
    user_features: torch.Tensor      # age, gender, etc. (not read by AdvancedRankingModel)
    item_features: torch.Tensor      # price, age, etc. (not read by AdvancedRankingModel)
    context_features: torch.Tensor   # time, device, etc. (not read by AdvancedRankingModel)
    
    # Sparse features (IDs)
    user_id: torch.Tensor            # looked up in the "user_embeddings" table
    item_id: torch.Tensor            # looked up in the "item_embeddings" table
    category_id: torch.Tensor        # looked up in the "category_embeddings" table
    
    # Interaction features
    historical_ctr: torch.Tensor    # Historical click-through rate
    user_item_similarity: torch.Tensor  # presumably a precomputed user-item affinity score; TODO confirm
    
    # Sequence features
    history_length: torch.Tensor     # presumably the length of the user's interaction history; TODO confirm
    position_ids: torch.Tensor       # indexes the model's position-bias embedding; values must be < its size (100 rows by default)

class FeatureProcessor:
    """Z-score normalize continuous features using statistics learned by ``fit``.

    ``feature_config`` maps each feature name to a config dict. Features whose
    ``"type"`` is ``"continuous"`` are standardized as ``(x - mean) / std``.
    Every other feature type is passed through unchanged.
    """

    def __init__(self, feature_config: Dict[str, Dict], eps: float = 1e-8):
        """
        Args:
            feature_config: per-feature config; each entry must have a ``"type"`` key.
            eps: a fitted std at or below this value (or a non-finite std) is
                replaced by 1, so constant features are only centered instead
                of being divided by ~0.
        """
        self.feature_config = feature_config
        self.eps = eps
        self.scalers = {}

    def fit(self, features: Dict[str, torch.Tensor]):
        """Compute mean/std for every continuous feature in ``feature_config``.

        Statistics are reduced over *all* elements of each tensor (no per-column
        statistics). Integer tensors are promoted to float first, because
        ``torch.std`` is only defined for floating-point dtypes. A std that is
        non-finite (e.g. the unbiased std of a single element is NaN) or
        ``<= eps`` (e.g. a constant feature) is replaced by 1.

        Raises:
            KeyError: if a configured continuous feature is missing from ``features``.
        """
        for name, config in self.feature_config.items():
            if config["type"] != "continuous":
                continue
            values = features[name]
            if not values.is_floating_point():
                values = values.float()
            mean = values.mean()
            std = values.std()
            if not torch.isfinite(std) or std <= self.eps:
                std = torch.ones_like(std)
            self.scalers[name] = {"mean": mean, "std": std}

    def transform(self, features: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Return a new dict with continuous features standardized.

        Raises:
            KeyError: if a feature name is not in ``feature_config``, or if a
                continuous feature has not been fitted yet.
        """
        normalized = {}
        for name, tensor in features.items():
            config = self.feature_config[name]

            if config["type"] == "continuous":
                try:
                    scaler = self.scalers[name]
                except KeyError:
                    raise KeyError(
                        f"feature {name!r} has not been fitted; call fit() first"
                    ) from None
                normalized[name] = (tensor - scaler["mean"]) / scaler["std"]
            else:
                normalized[name] = tensor

        return normalized

## Cross Network Architecture

In [None]:
class CrossNetwork(torch.nn.Module):
    """Stack of residual cross layers for explicit feature interactions.

    Each layer ``l`` computes::

        x_{l+1} = LayerNorm_l(x_l + x_0 * (W_l @ x_l + b_l))

    where ``x_0`` is the network input and ``*`` is element-wise. Input and
    output both have trailing dimension ``input_dim``.
    """

    def __init__(self, input_dim: int, num_layers: int):
        super().__init__()
        self.num_layers = num_layers

        # Build the linear maps first, then the norms, so parameter creation
        # (and therefore random initialization) happens in a fixed order.
        projections = [torch.nn.Linear(input_dim, input_dim) for _ in range(num_layers)]
        self.cross_layers = torch.nn.ModuleList(projections)

        norms = [torch.nn.LayerNorm(input_dim) for _ in range(num_layers)]
        self.layer_norms = torch.nn.ModuleList(norms)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        base = x
        hidden = x
        for project, normalize in zip(self.cross_layers, self.layer_norms):
            # Residual update with the element-wise cross term, then normalize.
            hidden = normalize(hidden + base * project(hidden))
        return hidden

## Advanced Ranking Model

In [None]:
class AdvancedRankingModel(torch.nn.Module):
    """Ranking model with cross network and attention.

    Pipeline (see ``forward``):
      1. Pool one embedding each for ``user_id``, ``item_id`` and ``category_id``
         through a TorchRec ``EmbeddingBagCollection``.
      2. Concatenate the three embeddings and project them to ``hidden_dim``
         (Linear -> ReLU -> BatchNorm1d -> Dropout).
      3. Model feature interactions with ``CrossNetwork``.
      4. Apply multi-head self-attention over a length-1 sequence. The softmax
         over a single key is always 1, so apart from attention dropout in
         training this acts as a learned linear projection.
      5. Concatenate a learned position-bias embedding and score with an MLP.

    Args:
        num_users, num_items, num_categories: rows in each embedding table.
        embedding_dim: width of each sparse embedding.
        hidden_dim: width of the dense/cross/attention layers; must be
            divisible by ``num_heads`` (required by ``torch.nn.MultiheadAttention``).
        num_heads: number of attention heads.
        num_cross_layers: number of cross layers.
        dropout: dropout probability for the MLPs and the attention weights.
        max_position: rows in the position-bias embedding; every
            ``position_ids`` value must lie in ``[0, max_position)``.
        embedding_device: device for the embedding tables. ``None`` keeps the
            original behaviour of allocating them on the ``meta`` device.
            Meta tensors hold no data, so the tables must be materialized before
            ``forward`` can run (presumably by sharding the model; TODO confirm).
            Pass a real device such as ``torch.device("cpu")`` to run unsharded.
    """

    def __init__(
        self,
        num_users: int,
        num_items: int,
        num_categories: int,
        embedding_dim: int = 64,
        hidden_dim: int = 256,
        num_heads: int = 4,
        num_cross_layers: int = 3,
        dropout: float = 0.1,
        max_position: int = 100,
        embedding_device: Optional[torch.device] = None,
    ):
        super().__init__()

        if embedding_device is None:
            embedding_device = torch.device("meta")

        # One pooled embedding table per sparse ID feature.
        self.embedding_tables = torchrec.EmbeddingBagCollection(
            tables=[
                torchrec.EmbeddingBagConfig(
                    name="user_embeddings",
                    embedding_dim=embedding_dim,
                    num_embeddings=num_users,
                    feature_names=["user_id"],
                ),
                torchrec.EmbeddingBagConfig(
                    name="item_embeddings",
                    embedding_dim=embedding_dim,
                    num_embeddings=num_items,
                    feature_names=["item_id"],
                ),
                torchrec.EmbeddingBagConfig(
                    name="category_embeddings",
                    embedding_dim=embedding_dim,
                    num_embeddings=num_categories,
                    feature_names=["category_id"],
                ),
            ],
            device=embedding_device,
        )

        # Projects the concatenated sparse embeddings to hidden_dim.
        self.dense_network = torch.nn.Sequential(
            torch.nn.Linear(embedding_dim * 3, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.BatchNorm1d(hidden_dim),
            torch.nn.Dropout(dropout),
        )

        # Cross network for feature interactions.
        self.cross_network = CrossNetwork(
            input_dim=hidden_dim,
            num_layers=num_cross_layers,
        )

        # Self-attention for feature refinement (sequence length 1, see class doc).
        self.self_attention = torch.nn.MultiheadAttention(
            embed_dim=hidden_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )

        # Scores [attended features | position embedding] -> one logit.
        self.ranking_layers = torch.nn.Sequential(
            torch.nn.Linear(hidden_dim * 2, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Dropout(dropout),
            torch.nn.Linear(hidden_dim, 1),
        )

        # Learned position bias, indexed by features.position_ids.
        self.position_embedding = torch.nn.Embedding(
            num_embeddings=max_position,
            embedding_dim=hidden_dim,
        )

    def forward(
        self,
        features: RankingFeatures,
        return_embeddings: bool = False
    ) -> Dict[str, torch.Tensor]:
        """Score a batch of examples.

        Args:
            features: batch inputs. Assumes ``user_id``, ``item_id``,
                ``category_id`` and ``position_ids`` are 1-D integer tensors of
                equal length ``batch_size`` (one id per example); TODO confirm.
            return_embeddings: also return intermediate representations.

        Returns:
            ``{"scores": logits}`` with logits of shape ``(batch_size, 1)``.
            When ``return_embeddings`` is true, an extra ``"embeddings"`` dict
            holds the pooled user/item embeddings and the cross-network and
            attention outputs.
        """
        batch_size = len(features.user_id)

        # KJT values are key-major: all user ids, then item ids, then category
        # ids. Every example contributes exactly one id per key, so every
        # length is 1. Lengths must be an integer tensor on the values' device.
        sparse_input = KeyedJaggedTensor.from_lengths_sync(
            keys=["user_id", "item_id", "category_id"],
            values=torch.cat([
                features.user_id,
                features.item_id,
                features.category_id,
            ]),
            lengths=torch.ones(
                batch_size * 3,
                dtype=torch.int32,
                device=features.user_id.device,
            ),
        )

        # The pooled KeyedTensor is keyed by *feature* name, not table name.
        sparse_embeddings = self.embedding_tables(sparse_input).to_dict()
        user_embedding = sparse_embeddings["user_id"]
        item_embedding = sparse_embeddings["item_id"]
        category_embedding = sparse_embeddings["category_id"]

        combined_embeddings = torch.cat(
            [user_embedding, item_embedding, category_embedding], dim=1
        )

        dense_features = self.dense_network(combined_embeddings)
        cross_features = self.cross_network(dense_features)

        # Treat each example as a length-1 sequence: (batch, 1, hidden_dim).
        attention_input = cross_features.unsqueeze(1)
        attended_features, _ = self.self_attention(
            attention_input, attention_input, attention_input
        )
        attended_features = attended_features.squeeze(1)

        position_emb = self.position_embedding(features.position_ids)

        final_features = torch.cat([attended_features, position_emb], dim=1)
        scores = self.ranking_layers(final_features)

        if return_embeddings:
            return {
                "scores": scores,
                "embeddings": {
                    "user": user_embedding,
                    "item": item_embedding,
                    "cross": cross_features,
                    "attention": attended_features,
                },
            }

        return {"scores": scores}

## Advanced Loss Functions

In [None]:
class RankingLoss:
    """Combined ranking loss functions.

    ``total = pointwise + lambda_pair * pairwise + lambda_list * listwise``

    Every method accepts ``scores`` of shape ``(batch,)`` or ``(batch, 1)``,
    the latter being the shape ``AdvancedRankingModel`` returns. ``labels``
    and ``group_ids`` have shape ``(batch,)``. Inputs are flattened to 1-D, and
    labels are cast to the scores' dtype before use. A trailing singleton
    dimension therefore never broadcasts against the labels, and integer or
    bool labels work everywhere.
    """

    def __init__(
        self,
        lambda_pair: float = 1.0,
        lambda_list: float = 1.0,
        temperature: float = 1.0
    ):
        """
        Args:
            lambda_pair: weight of the pairwise term in ``combined_loss``.
            lambda_list: weight of the listwise term in ``combined_loss``.
            temperature: divides the scores before the listwise softmax.
        """
        self.lambda_pair = lambda_pair
        self.lambda_list = lambda_list
        self.temperature = temperature

    @staticmethod
    def _flatten(
        scores: torch.Tensor,
        labels: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return 1-D scores and 1-D labels cast to the scores' dtype."""
        flat_scores = scores.reshape(-1)
        flat_labels = labels.reshape(-1).to(flat_scores.dtype)
        return flat_scores, flat_labels

    def pointwise_loss(
        self,
        scores: torch.Tensor,
        labels: torch.Tensor
    ) -> torch.Tensor:
        """Mean binary cross-entropy between score logits and binary labels."""
        flat_scores, flat_labels = self._flatten(scores, labels)
        return torch.nn.functional.binary_cross_entropy_with_logits(
            flat_scores,
            flat_labels
        )

    def pairwise_loss(
        self,
        scores: torch.Tensor,
        labels: torch.Tensor,
        group_ids: torch.Tensor
    ) -> torch.Tensor:
        """Pairwise hinge loss (margin 0) within each group, weighted by label gap.

        For every ordered pair (i, j) in a group with ``l_i != l_j`` the
        per-pair loss is ``max(0, -sign(l_j - l_i) * (s_j - s_i))``. Only
        misordered pairs are penalized, by their score gap. Each pair's loss is
        multiplied by ``|l_j - l_i|`` and averaged per group. The result is the
        mean over groups that contain at least one valid pair, or 0 if none do.
        """
        flat_scores, flat_labels = self._flatten(scores, labels)
        flat_groups = group_ids.reshape(-1)
        losses = []

        # Process each group (e.g., query) separately.
        for group_id in flat_groups.unique():
            mask = flat_groups == group_id
            group_scores = flat_scores[mask]
            group_labels = flat_labels[mask]

            # Entry [i, j] holds s_j - s_i and l_j - l_i respectively.
            score_diff = group_scores.unsqueeze(0) - group_scores.unsqueeze(1)
            label_diff = group_labels.unsqueeze(0) - group_labels.unsqueeze(1)

            # Valid pairs have different labels.
            valid_pairs = label_diff != 0
            if not valid_pairs.any():
                continue

            pair_score_diff = score_diff[valid_pairs]
            pair_label_diff = label_diff[valid_pairs]
            lambda_weights = torch.abs(pair_label_diff)

            pair_losses = torch.nn.functional.margin_ranking_loss(
                pair_score_diff,
                torch.zeros_like(pair_score_diff),
                torch.sign(pair_label_diff),
                reduction='none'
            )
            losses.append((pair_losses * lambda_weights).mean())

        if not losses:
            return flat_scores.new_zeros(())
        return torch.stack(losses).mean()

    def listwise_loss(
        self,
        scores: torch.Tensor,
        labels: torch.Tensor,
        group_ids: torch.Tensor
    ) -> torch.Tensor:
        """ListNet loss with temperature scaling.

        Per group, this is the cross-entropy between ``softmax(labels)`` and
        ``softmax(scores / temperature)``. It returns the mean over groups, or
        0 for empty input.
        """
        flat_scores, flat_labels = self._flatten(scores, labels)
        flat_groups = group_ids.reshape(-1)
        losses = []

        for group_id in flat_groups.unique():
            mask = flat_groups == group_id
            # log_softmax is the numerically stable form of log(softmax(.)).
            log_score_probs = torch.nn.functional.log_softmax(
                flat_scores[mask] / self.temperature, dim=0
            )
            label_probs = torch.nn.functional.softmax(flat_labels[mask], dim=0)
            losses.append(-(label_probs * log_score_probs).sum())

        if not losses:
            return flat_scores.new_zeros(())
        return torch.stack(losses).mean()

    def combined_loss(
        self,
        scores: torch.Tensor,
        labels: torch.Tensor,
        group_ids: torch.Tensor
    ) -> Dict[str, torch.Tensor]:
        """Return the weighted total and each component loss.

        Keys: ``"total_loss"``, ``"point_loss"``, ``"pair_loss"``, ``"list_loss"``.
        """
        point_loss = self.pointwise_loss(scores, labels)
        pair_loss = self.pairwise_loss(scores, labels, group_ids)
        list_loss = self.listwise_loss(scores, labels, group_ids)

        total_loss = point_loss + self.lambda_pair * pair_loss + self.lambda_list * list_loss

        return {
            "total_loss": total_loss,
            "point_loss": point_loss,
            "pair_loss": pair_loss,
            "list_loss": list_loss
        }