# Multi-Tower Ranking System with TorchRec

In [None]:
import torch
import torchrec
from typing import Dict, List, Tuple, NamedTuple, Optional
from dataclasses import dataclass
import numpy as np
from collections import defaultdict
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

## Multi-Tower Feature Definitions

In [None]:
@dataclass
class TowerFeatures:
    """Inputs for one tower: a dense tensor plus sparse ids (and optional weights).

    This is a plain container; nothing is validated on construction.
    """
    # Dense inputs. `Tower.forward` concatenates this with the embedding
    # output along dim=1, so it is consumed there as a 2-D tensor.
    dense_features: torch.Tensor
    # Ids to look up in the tower's embedding table. `Tower.forward` passes
    # them as the values of the "sparse_id" feature.
    sparse_ids: torch.Tensor
    # Optional per-id weights; defaults to None.
    # NOTE(review): `Tower.forward` in this file never reads this field —
    # confirm whether weighted pooling was intended.
    sparse_weights: Optional[torch.Tensor] = None

class MultiTowerFeatures:
    """Bundle of per-tower inputs handed to the multi-tower model.

    The user, item and context towers are always present; the sequence
    tower is optional and is stored as ``None`` when not supplied.
    """

    def __init__(
        self,
        user_tower: TowerFeatures,
        item_tower: TowerFeatures,
        context_tower: TowerFeatures,
        sequence_tower: Optional[TowerFeatures] = None
    ):
        self.user_tower, self.item_tower = user_tower, item_tower
        self.context_tower = context_tower
        self.sequence_tower = sequence_tower

    def to(self, device: torch.device) -> 'MultiTowerFeatures':
        """Build a new feature bundle whose tensors went through ``.to(device)``.

        Args:
            device: Target device forwarded to every tensor's ``.to``.

        Returns:
            A fresh ``MultiTowerFeatures`` holding fresh ``TowerFeatures``
            objects. ``sparse_weights`` stays ``None`` where it was ``None``
            and the sequence tower stays ``None`` when absent.
        """

        def relocate(tower: TowerFeatures) -> TowerFeatures:
            # Same per-tower recipe for every slot: dense, ids, then the
            # optional weights.
            dense = tower.dense_features.to(device)
            ids = tower.sparse_ids.to(device)
            weights = tower.sparse_weights
            if weights is not None:
                weights = weights.to(device)
            return TowerFeatures(
                dense_features=dense,
                sparse_ids=ids,
                sparse_weights=weights,
            )

        sequence = self.sequence_tower
        return MultiTowerFeatures(
            user_tower=relocate(self.user_tower),
            item_tower=relocate(self.item_tower),
            context_tower=relocate(self.context_tower),
            sequence_tower=None if sequence is None else relocate(sequence),
        )

## Tower Architecture

In [None]:
class Tower(torch.nn.Module):
    """Single tower of the multi-tower architecture.

    A tower looks up its sparse ids in ``embedding_tables``, concatenates
    the resulting embeddings with the dense features and runs the result
    through an MLP.

    Args:
        embedding_tables: Embedding collection queried with a single
            feature named ``"sparse_id"``.
        dense_dim: Width of ``TowerFeatures.dense_features``.
        embedding_dim: Width of the embedding output; added to
            ``dense_dim`` to size the first linear layer.
        hidden_dims: Output width of each MLP stage, in order.
        dropout: Dropout probability used in every stage.
        use_batch_norm: When true, each stage ends with ``BatchNorm1d``.
    """
    def __init__(
        self,
        embedding_tables: torchrec.EmbeddingBagCollection,
        dense_dim: int,
        embedding_dim: int,
        hidden_dims: List[int],
        dropout: float = 0.1,
        use_batch_norm: bool = True
    ):
        super().__init__()

        self.embedding_tables = embedding_tables

        # Each stage is Linear -> ReLU -> Dropout [-> BatchNorm1d]; the
        # module order is kept as-is so state_dict indices do not shift.
        layers = []
        prev_dim = dense_dim + embedding_dim
        for hidden_dim in hidden_dims:
            layers.extend([
                torch.nn.Linear(prev_dim, hidden_dim),
                torch.nn.ReLU(),
                torch.nn.Dropout(dropout)
            ])

            if use_batch_norm:
                layers.append(torch.nn.BatchNorm1d(hidden_dim))

            prev_dim = hidden_dim

        self.network = torch.nn.Sequential(*layers)

    def forward(self, features: TowerFeatures) -> torch.Tensor:
        """Embed the sparse ids, join them with the dense features, run the MLP.

        The batch size is taken from ``features.dense_features.shape[0]``
        and the (flattened) sparse ids are split evenly across the batch,
        so both one-id-per-sample inputs and fixed-length id lists
        (flattened or ``[batch, length]``) produce one embedding row per
        sample.

        Args:
            features: Tower inputs. ``sparse_weights`` is not used here.

        Returns:
            The MLP output, one row per sample.

        Raises:
            ValueError: If the number of sparse ids is not a multiple of
                the batch size.
        """
        dense = features.dense_features
        ids = features.sparse_ids.reshape(-1)
        batch_size = dense.shape[0]

        if batch_size > 0:
            ids_per_sample, remainder = divmod(ids.numel(), batch_size)
        else:
            ids_per_sample, remainder = 0, ids.numel()
        if remainder:
            raise ValueError(
                f"Cannot split {ids.numel()} sparse ids evenly across a "
                f"batch of {batch_size} samples"
            )

        # Lengths must be an integer tensor living next to the ids; the
        # previous float CPU tensor of ones was neither.
        lengths = torch.full(
            (batch_size,),
            ids_per_sample,
            dtype=torch.int64,
            device=ids.device
        )

        sparse_embeddings = self.embedding_tables(
            KeyedJaggedTensor.from_lengths_sync(
                keys=["sparse_id"],
                values=ids,
                lengths=lengths
            )
        )

        # Combine dense and sparse features along the feature axis.
        combined_features = torch.cat([
            dense,
            sparse_embeddings.values()
        ], dim=1)

        return self.network(combined_features)



## Cross-Tower Interaction

In [None]:
class CrossInteraction(torch.nn.Module):
    """Cross interaction between tower outputs.

    Tower outputs are stacked along a new "tower" axis, mixed with
    multi-head self-attention, refined by residual Linear/LayerNorm/ReLU
    layers and finally mean-pooled over the tower axis.

    Args:
        input_dim: Width of every tower output (and of the result).
        num_layers: Number of residual cross layers.
        num_heads: Number of attention heads; ``input_dim`` has to be
            divisible by it (enforced by ``torch.nn.MultiheadAttention``).
            Defaults to 4, the value that used to be hard-coded.
    """
    def __init__(
        self,
        input_dim: int,
        num_layers: int = 3,
        num_heads: int = 4
    ):
        super().__init__()

        self.layers = torch.nn.ModuleList([
            torch.nn.Sequential(
                torch.nn.Linear(input_dim, input_dim),
                torch.nn.LayerNorm(input_dim),
                torch.nn.ReLU()
            ) for _ in range(num_layers)
        ])

        self.attention = torch.nn.MultiheadAttention(
            embed_dim=input_dim,
            num_heads=num_heads,
            batch_first=True
        )

    def forward(self, tower_outputs: List[torch.Tensor]) -> torch.Tensor:
        """Fuse the per-tower outputs into one vector per sample.

        Args:
            tower_outputs: Equally shaped ``[batch_size, input_dim]``
                tensors, one per tower.

        Returns:
            A ``[batch_size, input_dim]`` tensor.
        """
        # Stack tower outputs
        stacked = torch.stack(tower_outputs, dim=1)  # [batch_size, num_towers, dim]

        # Self-attention across towers (attention weights are discarded)
        attended, _ = self.attention(stacked, stacked, stacked)

        # Cross layers with a residual connection around each one
        cross_features = attended
        for layer in self.layers:
            cross_features = layer(cross_features) + cross_features

        return cross_features.mean(dim=1)  # Pool across towers

## Multi-Tower Model

In [None]:
class MultiTowerModel(torch.nn.Module):
    """Complete multi-tower ranking model.

    One ``Tower`` is built per entry of ``tower_configs``; their outputs
    are fused by ``CrossInteraction`` and scored by a final MLP that ends
    in a single-unit linear layer.

    Args:
        tower_configs: Maps a tower name to a dict with the keys
            ``"num_embeddings"``, ``"dense_dim"`` and ``"hidden_dims"``.
            The name is also used to pick ``features.<name>_tower`` in
            ``forward``.
        final_hidden_dims: Hidden widths of the final MLP.
        embedding_dim: Embedding width shared by all towers.
        dropout: Dropout probability used in towers and in the final MLP.
        embedding_device: Device the embedding tables are created on.
            ``None`` keeps the previous behaviour of using the ``meta``
            device. NOTE(review): meta tensors carry no storage, so the
            tables presumably get materialised by a sharding wrapper
            elsewhere — confirm, or pass a real device for direct use.

    Raises:
        ValueError: If ``tower_configs`` is empty or the towers do not all
            end in the same hidden width (they are stacked in the cross
            interaction and therefore must agree).
    """
    def __init__(
        self,
        tower_configs: Dict[str, Dict],
        final_hidden_dims: List[int],
        embedding_dim: int = 64,
        dropout: float = 0.1,
        embedding_device: Optional[torch.device] = None
    ):
        super().__init__()

        if not tower_configs:
            raise ValueError("tower_configs must define at least one tower")
        if embedding_device is None:
            embedding_device = torch.device("meta")

        # Initialize towers
        self.towers = torch.nn.ModuleDict()
        output_dims = {}
        for tower_name, config in tower_configs.items():
            embedding_tables = torchrec.EmbeddingBagCollection(
                tables=[
                    torchrec.EmbeddingBagConfig(
                        name=f"{tower_name}_embeddings",
                        embedding_dim=embedding_dim,
                        num_embeddings=config["num_embeddings"],
                        feature_names=["sparse_id"]
                    )
                ],
                device=embedding_device
            )

            self.towers[tower_name] = Tower(
                embedding_tables=embedding_tables,
                dense_dim=config["dense_dim"],
                embedding_dim=embedding_dim,
                hidden_dims=config["hidden_dims"],
                dropout=dropout
            )
            output_dims[tower_name] = config["hidden_dims"][-1]

        # The shared width used to be read from whichever config the loop
        # saw last; check explicitly that every tower agrees instead.
        distinct_dims = set(output_dims.values())
        if len(distinct_dims) != 1:
            raise ValueError(
                "All towers must share the same final hidden dimension, "
                f"got {output_dims}"
            )
        tower_output_dim = distinct_dims.pop()

        # Cross interaction
        self.cross_interaction = CrossInteraction(
            input_dim=tower_output_dim
        )

        # Final layers
        layers = []
        prev_dim = tower_output_dim

        for hidden_dim in final_hidden_dims:
            layers.extend([
                torch.nn.Linear(prev_dim, hidden_dim),
                torch.nn.ReLU(),
                torch.nn.Dropout(dropout),
                torch.nn.BatchNorm1d(hidden_dim)
            ])
            prev_dim = hidden_dim

        layers.append(torch.nn.Linear(prev_dim, 1))

        self.final_layers = torch.nn.Sequential(*layers)

    def forward(
        self,
        features: MultiTowerFeatures
    ) -> Dict[str, torch.Tensor]:
        """Score a batch.

        Args:
            features: Must expose a ``<name>_tower`` attribute for every
                configured tower.

        Returns:
            A dict with ``"score"`` (output of the final layers),
            ``"tower_embeddings"`` (a dict mapping tower name to that
            tower's output) and ``"cross_features"`` (the fused vector).

        Raises:
            ValueError: If the features for a configured tower are
                ``None``.
        """
        # Process each tower, keeping outputs keyed by tower name
        tower_embeddings = {}
        for tower_name, tower in self.towers.items():
            tower_features = getattr(features, f"{tower_name}_tower")
            if tower_features is None:
                raise ValueError(
                    f"Features for tower '{tower_name}' are missing (None)"
                )
            tower_embeddings[tower_name] = tower(tower_features)

        # Cross interaction (dict preserves the tower iteration order)
        cross_features = self.cross_interaction(list(tower_embeddings.values()))

        # Final prediction
        ranking_score = self.final_layers(cross_features)

        return {
            "score": ranking_score,
            "tower_embeddings": tower_embeddings,
            "cross_features": cross_features
        }

## Advanced Loss Function

In [None]:
class MultiTowerLoss:
    """Multi-objective loss for the multi-tower model.

    The total loss is the sum of a binary cross-entropy ranking loss, a
    triplet-margin contrastive loss on the fused features and a weighted
    L2-norm penalty on every tower output.

    Args:
        margin: Margin of the triplet loss.
        tower_weights: Per-tower weight of the norm penalty. Defaults to
            ``{"user": 1.0, "item": 1.0, "context": 0.5}``. Towers missing
            from the mapping are weighted 1.0.
    """
    def __init__(
        self,
        margin: float = 0.1,
        tower_weights: Optional[Dict[str, float]] = None
    ):
        self.margin = margin
        self.tower_weights = tower_weights or {
            "user": 1.0,
            "item": 1.0,
            "context": 0.5
        }

    def compute_loss(
        self,
        outputs: Dict[str, torch.Tensor],
        labels: torch.Tensor
    ) -> Dict[str, torch.Tensor]:
        """Compute all loss terms for one batch.

        Args:
            outputs: Dict with ``"score"`` (one logit per sample),
                ``"tower_embeddings"`` (tower name -> 2-D tensor) and
                ``"cross_features"`` (2-D tensor, one row per sample).
            labels: One binary label per sample (1 = positive,
                0 = negative).

        Returns:
            Dict with ``"total_loss"``, ``"ranking_loss"``,
            ``"contrastive_loss"`` and one ``"<tower>_reg"`` entry per
            tower.
        """
        scores = outputs["score"]
        tower_embeddings = outputs["tower_embeddings"]
        cross_features = outputs["cross_features"]

        # Flatten rather than squeeze(): squeeze() turns a batch of one
        # into a scalar, which no longer matches the label shape.
        logits = scores.reshape(-1)
        flat_labels = labels.reshape(-1)

        # Main ranking loss
        ranking_loss = torch.nn.functional.binary_cross_entropy_with_logits(
            logits,
            flat_labels.float()
        )

        # Tower regularization; towers without a configured weight
        # (e.g. an optional extra tower) fall back to 1.0.
        tower_losses = {
            f"{tower_name}_reg": (
                self.tower_weights.get(tower_name, 1.0) *
                torch.norm(embeddings, p=2, dim=1).mean()
            )
            for tower_name, embeddings in tower_embeddings.items()
        }

        # Contrastive loss
        pos_features = cross_features[flat_labels == 1]
        neg_features = cross_features[flat_labels == 0]
        num_pos, num_neg = pos_features.shape[0], neg_features.shape[0]

        if num_pos > 0 and num_neg > 0:
            # The triplet loss needs equally sized anchor/positive/negative
            # sets. Cycle through the smaller set so unequal class counts
            # work; equal counts yield exactly the one-to-one pairing.
            triplet_idx = torch.arange(
                max(num_pos, num_neg), device=cross_features.device
            )
            pos_idx = triplet_idx % num_pos
            neg_idx = triplet_idx % num_neg
            contrastive_loss = torch.nn.functional.triplet_margin_loss(
                anchor=pos_features[pos_idx],
                # Each anchor is paired with the previous positive sample.
                positive=pos_features.roll(1, dims=0)[pos_idx],
                negative=neg_features[neg_idx],
                margin=self.margin
            )
        else:
            contrastive_loss = torch.tensor(0.0, device=scores.device)

        # Combine losses
        total_loss = ranking_loss + contrastive_loss
        for tower_loss in tower_losses.values():
            total_loss = total_loss + tower_loss

        return {
            "total_loss": total_loss,
            "ranking_loss": ranking_loss,
            "contrastive_loss": contrastive_loss,
            **tower_losses
        }

## Multi-Tower Data Generation

In [None]:
class MultiTowerDataGenerator:
    """Generate synthetic data for the multi-tower model.

    Random user, item and context feature tables plus pairwise interaction
    matrices are drawn once at construction; ``generate_batch`` then
    samples ids and looks everything up in those tables.

    Args:
        num_users: Number of synthetic users.
        num_items: Number of synthetic items.
        num_contexts: Number of synthetic contexts.
        user_dense_dim: Width of the dense user features.
        item_dense_dim: Width of the dense item features.
        context_dense_dim: Width of the dense context features.
        sequence_length: Number of history items sampled per example for
            the sequence tower.
    """
    def __init__(
        self,
        num_users: int,
        num_items: int,
        num_contexts: int,
        user_dense_dim: int = 10,
        item_dense_dim: int = 20,
        context_dense_dim: int = 5,
        sequence_length: int = 10
    ):
        self.num_users = num_users
        self.num_items = num_items
        self.num_contexts = num_contexts
        # Previously accepted but ignored; now drives the history length.
        self.sequence_length = sequence_length

        # Generate synthetic user features
        self.user_features = {
            "dense": torch.randn(num_users, user_dense_dim),
            "categories": torch.randint(0, 10, (num_users,)),
            "activity_level": torch.rand(num_users)
        }

        # Generate synthetic item features
        self.item_features = {
            "dense": torch.randn(num_items, item_dense_dim),
            "categories": torch.randint(0, 20, (num_items,)),
            "popularity": torch.rand(num_items)
        }

        # Generate synthetic context features
        self.context_features = {
            "dense": torch.randn(num_contexts, context_dense_dim),
            "time_bins": torch.randint(0, 24, (num_contexts,)),
            "device_types": torch.randint(0, 5, (num_contexts,))
        }

        # Generate interaction patterns
        self.interaction_patterns = self._generate_interaction_patterns()

    def _generate_interaction_patterns(self) -> Dict[str, torch.Tensor]:
        """Draw uniform-random affinity matrices for every entity pair.

        Returns:
            Dict with ``"user_item"``, ``"user_context"`` and
            ``"item_context"`` matrices. These are dense, so memory grows
            with the product of the two entity counts.
        """
        return {
            "user_item": torch.rand(self.num_users, self.num_items),
            "user_context": torch.rand(self.num_users, self.num_contexts),
            "item_context": torch.rand(self.num_items, self.num_contexts)
        }

    def generate_batch(
        self,
        batch_size: int,
        include_sequence: bool = True
    ) -> Tuple[MultiTowerFeatures, torch.Tensor]:
        """Sample one synthetic batch.

        Args:
            batch_size: Number of examples to sample.
            include_sequence: When true, also build a sequence tower from
                ``sequence_length`` random history items per example.

        Returns:
            ``(features, labels)`` where ``labels`` holds one Bernoulli
            draw (0.0 or 1.0) per example.
        """
        # Sample IDs
        user_ids = torch.randint(0, self.num_users, (batch_size,))
        item_ids = torch.randint(0, self.num_items, (batch_size,))
        context_ids = torch.randint(0, self.num_contexts, (batch_size,))

        # Generate tower features
        user_tower = TowerFeatures(
            dense_features=self.user_features["dense"][user_ids],
            sparse_ids=user_ids,
            sparse_weights=self.user_features["activity_level"][user_ids]
        )

        item_tower = TowerFeatures(
            dense_features=self.item_features["dense"][item_ids],
            sparse_ids=item_ids,
            sparse_weights=self.item_features["popularity"][item_ids]
        )

        context_tower = TowerFeatures(
            dense_features=self.context_features["dense"][context_ids],
            sparse_ids=context_ids
        )

        # Generate sequence features if needed
        sequence_tower = None
        if include_sequence:
            history_ids = torch.randint(
                0, self.num_items,
                (batch_size, self.sequence_length)
            )

            sequence_tower = TowerFeatures(
                # [batch, seq, dim] gather, averaged over the history axis;
                # replaces a per-example Python loop with one indexed op.
                dense_features=self.item_features["dense"][history_ids].mean(dim=1),
                sparse_ids=history_ids.view(-1)
            )

        features = MultiTowerFeatures(
            user_tower=user_tower,
            item_tower=item_tower,
            context_tower=context_tower,
            sequence_tower=sequence_tower
        )

        # Generate labels based on interaction patterns
        interaction_scores = (
            self.interaction_patterns["user_item"][user_ids, item_ids] * 0.4 +
            self.interaction_patterns["user_context"][user_ids, context_ids] * 0.3 +
            self.interaction_patterns["item_context"][item_ids, context_ids] * 0.3
        )

        # The weights sum to 1, so scores lie in [0, 1); scaling by 0.2
        # caps the positive probability at 20%. With uniform patterns the
        # average positive rate is about 10%.
        labels = torch.bernoulli(interaction_scores * 0.2)

        return features, labels