# Multi-Task Recommendation System with TorchRec

In [None]:
from dataclasses import dataclass
from typing import Dict, List, NamedTuple, Optional, Tuple

import numpy as np
import torch
import torchrec
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

from utils.benchmark import TorchRecBenchmark
from utils.data_generators import TorchRecDataGenerator
from utils.debugging import TorchRecDebugger

## Multi-Task Data Structures

In [None]:
@dataclass
class TaskConfig:
    """Static settings describing one prediction task.

    Fields are declared in the order ``name, weight, metric_type,
    loss_type``; that order is the positional constructor order.
    """

    # Task identifier, e.g. "click".
    name: str
    # Multiplier applied to this task's loss when losses are combined.
    weight: float
    # Metric family: 'binary', 'regression' or 'ranking'.
    metric_type: str
    # Loss selector: 'bce', 'mse' or 'hinge'.
    loss_type: str

class MultiTaskTargets(NamedTuple):
    """Immutable bundle holding one label tensor per task.

    Field order is ``click, purchase, watch_time, rating``; the tuple can be
    unpacked or iterated in that order.
    """

    # Labels for the click task.
    click: torch.Tensor
    # Labels for the purchase task.
    purchase: torch.Tensor
    # Labels for the watch-time task.
    watch_time: torch.Tensor
    # Labels for the rating task.
    rating: torch.Tensor

## Multi-Task Model Architecture

In [None]:
class MultiTaskTower(torch.nn.Module):
    """Feed-forward head producing the output for a single task.

    Every entry of ``hidden_dims`` contributes one stage of
    Linear -> ReLU -> BatchNorm1d -> Dropout; a final Linear layer projects
    to ``output_dim``. With an empty ``hidden_dims`` the tower is just that
    final Linear layer.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dims: List[int],
        output_dim: int,
        dropout: float = 0.1
    ):
        super().__init__()

        nn = torch.nn
        # Consecutive pairs of widths give the (in, out) sizes of each stage.
        widths = [input_dim, *hidden_dims]

        stages = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            stages += [
                nn.Linear(fan_in, fan_out),
                nn.ReLU(),
                nn.BatchNorm1d(fan_out),
                nn.Dropout(dropout),
            ]

        # Output projection from the last width (input_dim if no hidden stage).
        stages.append(nn.Linear(widths[-1], output_dim))

        self.network = nn.Sequential(*stages)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the layer stack and return the result."""
        output = self.network(x)
        return output

class MultiTaskRecommender(torch.nn.Module):
    """Shared-bottom multi-task recommendation model.

    Three pooled embeddings (user history, candidate item id, candidate
    category id) are concatenated, passed through a shared MLP, and the
    resulting representation is fed to one tower per task: ``click``,
    ``purchase``, ``watch_time`` and ``rating``.

    Args:
        num_items: Vocabulary size of the item-id and history tables.
        num_categories: Vocabulary size of the category table.
        embedding_dim: Width of every embedding table.
        hidden_dim: Width of the shared network and tower input.
        tower_hidden_dims: Hidden widths of each task tower. ``None``
            selects ``[256, 128]``.
        dropout: Dropout probability used in the shared network and in
            every tower.
        embedding_device: Device on which the embedding tables are
            allocated. ``None`` selects CPU so the module owns real storage
            and can be moved with ``.to(device)``. Pass
            ``torch.device("meta")`` only when the model will be
            materialised by a sharding wrapper.
    """
    def __init__(
        self,
        num_items: int,
        num_categories: int,
        embedding_dim: int = 64,
        hidden_dim: int = 128,
        tower_hidden_dims: Optional[List[int]] = None,
        dropout: float = 0.1,
        embedding_device: Optional[torch.device] = None
    ):
        super().__init__()

        # Avoid a shared mutable default; build a fresh list per instance.
        if tower_hidden_dims is None:
            tower_hidden_dims = [256, 128]
        # A "meta" table has no storage and cannot be copied by .to(device),
        # so real (CPU) storage is the default.
        if embedding_device is None:
            embedding_device = torch.device("cpu")

        # Shared embedding layer
        self.embedding_tables = torchrec.EmbeddingBagCollection(
            tables=[
                torchrec.EmbeddingBagConfig(
                    name="item_embeddings",
                    embedding_dim=embedding_dim,
                    num_embeddings=num_items,
                    feature_names=["item_id"],
                ),
                torchrec.EmbeddingBagConfig(
                    name="category_embeddings",
                    embedding_dim=embedding_dim,
                    num_embeddings=num_categories,
                    feature_names=["category_id"],
                ),
                torchrec.EmbeddingBagConfig(
                    name="user_history",
                    embedding_dim=embedding_dim,
                    num_embeddings=num_items,
                    feature_names=["history"],
                ),
            ],
            device=embedding_device
        )

        # Shared bottom network; input is the three concatenated embeddings.
        self.shared_network = torch.nn.Sequential(
            torch.nn.Linear(embedding_dim * 3, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.BatchNorm1d(hidden_dim),
            torch.nn.Dropout(dropout),
            torch.nn.Linear(hidden_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.BatchNorm1d(hidden_dim)
        )

        # Task-specific towers (one scalar output each).
        self.click_tower = MultiTaskTower(
            input_dim=hidden_dim,
            hidden_dims=list(tower_hidden_dims),
            output_dim=1,
            dropout=dropout
        )

        self.purchase_tower = MultiTaskTower(
            input_dim=hidden_dim,
            hidden_dims=list(tower_hidden_dims),
            output_dim=1,
            dropout=dropout
        )

        self.watch_time_tower = MultiTaskTower(
            input_dim=hidden_dim,
            hidden_dims=list(tower_hidden_dims),
            output_dim=1,
            dropout=dropout
        )

        self.rating_tower = MultiTaskTower(
            input_dim=hidden_dim,
            hidden_dims=list(tower_hidden_dims),
            output_dim=1,
            dropout=dropout
        )

        # Task configurations: (name, loss weight, metric type, loss type).
        self.task_configs = {
            "click": TaskConfig("click", 1.0, "binary", "bce"),
            "purchase": TaskConfig("purchase", 2.0, "binary", "bce"),
            "watch_time": TaskConfig("watch_time", 0.5, "regression", "mse"),
            "rating": TaskConfig("rating", 1.0, "regression", "mse")
        }

    def forward(
        self,
        user_features: KeyedJaggedTensor,
        item_features: KeyedJaggedTensor
    ) -> Dict[str, torch.Tensor]:
        """Compute one prediction tensor per task.

        Args:
            user_features: Sparse features carrying the ``history`` key.
            item_features: Sparse features carrying the ``item_id`` and
                ``category_id`` keys.

        Returns:
            Mapping from task name to that tower's output.
        """
        # The collection looks up every configured feature in a single call,
        # so both inputs are merged into one KeyedJaggedTensor first.
        all_features = KeyedJaggedTensor.concat([user_features, item_features])

        # Pooled embeddings are keyed by feature name, not by table name.
        pooled = self.embedding_tables(all_features).to_dict()

        # Concatenate embeddings
        combined_embeddings = torch.cat([
            pooled["history"],
            pooled["item_id"],
            pooled["category_id"]
        ], dim=1)

        # Shared representation
        shared_repr = self.shared_network(combined_embeddings)

        # Task-specific predictions
        return {
            "click": self.click_tower(shared_repr),
            "purchase": self.purchase_tower(shared_repr),
            "watch_time": self.watch_time_tower(shared_repr),
            "rating": self.rating_tower(shared_repr)
        }

## Multi-Task Loss Functions

In [None]:
class MultiTaskLoss:
    """Combine per-task losses into one weighted total.

    Each task's configuration selects a loss by ``loss_type`` ('bce', 'mse'
    or 'hinge') and scales the averaged loss by ``weight``.
    """
    def __init__(self, task_configs: Dict[str, TaskConfig]):
        self.task_configs = task_configs

        # Element-wise losses; reduction is done in compute_loss so that an
        # optional mask can be applied first.
        self.loss_fns = {
            "bce": torch.nn.BCEWithLogitsLoss(reduction='none'),
            "mse": torch.nn.MSELoss(reduction='none'),
            "hinge": torch.nn.HingeEmbeddingLoss(reduction='none')
        }

    @staticmethod
    def _align_to_target(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Give ``pred`` the shape of ``target`` without touching the batch axis."""
        if pred.numel() == target.numel():
            # e.g. (N, 1) -> (N,); also correct when N == 1, where a bare
            # squeeze() would collapse the batch dimension as well.
            return pred.reshape(target.shape)
        return pred.squeeze()

    def compute_loss(
        self,
        predictions: Dict[str, torch.Tensor],
        targets: MultiTaskTargets,
        mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Return the weighted total loss and the weighted per-task losses.

        Args:
            predictions: Mapping from task name to model output. Every key
                must exist in ``task_configs`` and as a field of ``targets``.
            targets: Container exposing one target tensor per task name.
            mask: Optional element-wise weights (typically 0/1) multiplied
                into each task's loss. When given, the loss is averaged over
                the mask's total weight rather than over all elements.

        Returns:
            ``(total_loss, task_losses)`` where ``task_losses`` maps task
            name to its weighted scalar loss and ``total_loss`` is their sum
            (the float ``0.0`` if ``predictions`` is empty).

        Raises:
            KeyError: If a task name or its ``loss_type`` is unknown.
        """
        task_losses = {}
        total_loss = 0.0

        for task_name, pred in predictions.items():
            config = self.task_configs[task_name]
            target = getattr(targets, task_name).float()

            # Compute task-specific element-wise loss
            loss = self.loss_fns[config.loss_type](
                self._align_to_target(pred, target),
                target
            )

            if mask is not None:
                masked = loss * mask
                # Normalise by the amount of unmasked weight so masked-out
                # elements do not dilute the average; clamp guards an
                # all-zero mask against division by zero.
                denom = torch.broadcast_to(
                    mask.to(masked.dtype), masked.shape
                ).sum().clamp(min=1.0)
                mean_loss = masked.sum() / denom
            else:
                mean_loss = loss.mean()

            # Weight the averaged loss
            task_loss = mean_loss * config.weight
            task_losses[task_name] = task_loss
            total_loss += task_loss

        return total_loss, task_losses


## Multi-Task Data Generation

In [None]:
class MultiTaskDataGenerator:
    """Generate synthetic data for multi-task learning.

    Each item is assigned a random category once at construction time;
    batches then draw random user histories, candidate items and targets.

    Args:
        num_users: Stored on the instance; not used when generating batches.
        num_items: Item ids are drawn from ``[0, num_items)``.
        num_categories: Category ids are drawn from ``[0, num_categories)``.
        max_history_length: Upper bound (inclusive) on the number of history
            items per user; every user has at least one.
    """
    def __init__(
        self,
        num_users: int,
        num_items: int,
        num_categories: int,
        max_history_length: int = 10
    ):
        self.num_users = num_users
        self.num_items = num_items
        self.num_categories = num_categories
        self.max_history_length = max_history_length

        # Fixed item -> category lookup shared by every batch.
        self.item_categories = torch.randint(
            0, num_categories, (num_items,)
        )

    def generate_batch(
        self,
        batch_size: int
    ) -> Tuple[KeyedJaggedTensor, KeyedJaggedTensor, MultiTaskTargets]:
        """Draw one random batch.

        Args:
            batch_size: Number of (user, candidate item) examples.

        Returns:
            ``(user_features, item_features, targets)`` where
            ``user_features`` holds the variable-length ``history`` feature,
            ``item_features`` holds one ``item_id`` and one ``category_id``
            per example, and ``targets`` holds one tensor of length
            ``batch_size`` per task.
        """
        # Generate user history: between 1 and max_history_length items each.
        history_lengths = torch.randint(
            1, self.max_history_length + 1,
            (batch_size,)
        )

        # Convert the total to a plain int before using it as a size.
        total_history = int(history_lengths.sum())
        history_ids = torch.randint(
            0, self.num_items,
            (total_history,)
        )

        user_features = KeyedJaggedTensor.from_lengths_sync(
            keys=["history"],
            values=history_ids,
            lengths=history_lengths
        )

        # Generate candidate items and look up their categories.
        item_ids = torch.randint(0, self.num_items, (batch_size,))
        categories = self.item_categories[item_ids]

        # Values are laid out key by key (all item ids, then all category
        # ids), each example contributing exactly one value per key. Lengths
        # must be integers, so the dtype is set explicitly (torch.ones
        # defaults to float).
        item_features = KeyedJaggedTensor.from_lengths_sync(
            keys=["item_id", "category_id"],
            values=torch.cat([item_ids, categories]),
            lengths=torch.ones(batch_size * 2, dtype=torch.int64)
        )

        # Generate targets
        targets = MultiTaskTargets(
            click=torch.bernoulli(torch.rand(batch_size) * 0.2),
            purchase=torch.bernoulli(torch.rand(batch_size) * 0.05),
            watch_time=torch.rand(batch_size) * 3600,  # seconds
            rating=torch.randint(1, 6, (batch_size,)).float()
        )

        return user_features, item_features, targets

## Multi-Task Trainer

In [None]:
class MultiTaskTrainer:
    """Bundle a multi-task model, its loss and an Adam optimizer.

    The model is moved to ``device`` at construction; ``train_step`` runs a
    single optimisation step on one batch.
    """

    def __init__(
        self,
        model: MultiTaskRecommender,
        loss_fn: MultiTaskLoss,
        learning_rate: float = 0.001,
        device: str = "cuda"
    ):
        self.model = model.to(device)
        self.loss_fn = loss_fn
        self.optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
        self.device = device
        self.debugger = TorchRecDebugger()

    def train_step(
        self,
        user_features: KeyedJaggedTensor,
        item_features: KeyedJaggedTensor,
        targets: MultiTaskTargets
    ) -> Dict[str, float]:
        """Run forward, backward and optimizer update for one batch.

        Returns:
            ``"total_loss"`` plus one ``"<task>_loss"`` entry per task, all
            converted to Python floats.
        """
        # Clear gradients left over from the previous step.
        self.optimizer.zero_grad()

        # Put every input on the training device.
        target_device = self.device
        user_features = user_features.to(target_device)
        item_features = item_features.to(target_device)
        targets = MultiTaskTargets(*(label.to(target_device) for label in targets))

        # Forward pass followed by the weighted multi-task loss.
        outputs = self.model(user_features, item_features)
        combined, per_task = self.loss_fn.compute_loss(outputs, targets)

        # Backpropagate and apply the update.
        combined.backward()
        self.optimizer.step()

        # Report plain floats so callers can accumulate them freely.
        report = {"total_loss": combined.item()}
        for task, value in per_task.items():
            report[f"{task}_loss"] = value.item()
        return report

## Evaluation Metrics

In [None]:
class MultiTaskEvaluator:
    """Evaluate multi-task model performance.

    Binary tasks report accuracy at a 0.5 probability threshold; regression
    tasks report MSE and MAE. Tasks with any other ``metric_type`` produce
    no metrics.
    """
    def __init__(self, task_configs: Dict[str, TaskConfig]):
        self.task_configs = task_configs

    @staticmethod
    def _align_to_target(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Give ``pred`` the shape of ``target`` without touching the batch axis."""
        if pred.numel() == target.numel():
            return pred.reshape(target.shape)
        return pred.squeeze()

    @torch.no_grad()
    def evaluate(
        self,
        predictions: Dict[str, torch.Tensor],
        targets: MultiTaskTargets
    ) -> Dict[str, float]:
        """Compute metrics for every predicted task.

        Args:
            predictions: Mapping from task name to model output (logits for
                binary tasks, raw values for regression tasks).
            targets: Container exposing one target tensor per task name.
                Targets may live on a different device than the
                predictions; they are moved to match.

        Returns:
            ``"<task>_accuracy"`` for binary tasks and ``"<task>_mse"`` /
            ``"<task>_mae"`` for regression tasks, as Python floats.

        Raises:
            KeyError: If a predicted task has no configuration.
        """
        metrics = {}

        for task_name, pred in predictions.items():
            config = self.task_configs[task_name]
            # Predictions usually come from the accelerator while freshly
            # generated targets are on CPU; compare on the prediction device.
            target = getattr(targets, task_name).to(pred.device)
            pred = self._align_to_target(pred, target)

            if config.metric_type == "binary":
                # Accuracy for binary tasks (logits -> probability -> 0/1).
                pred_prob = torch.sigmoid(pred)
                pred_binary = (pred_prob >= 0.5).float()

                metrics[f"{task_name}_accuracy"] = (
                    (pred_binary == target).float().mean().item()
                )

            elif config.metric_type == "regression":
                # MSE and MAE for regression tasks
                error = pred - target
                metrics[f"{task_name}_mse"] = (
                    torch.mean(error ** 2).item()
                )
                metrics[f"{task_name}_mae"] = (
                    torch.mean(torch.abs(error)).item()
                )

        return metrics

## Training Loop

In [None]:
def train_multi_task_model(
    num_epochs: int = 5,
    batch_size: int = 64,
    batches_per_epoch: int = 100,
    eval_batch_size: int = 1000
):
    """Train and evaluate the multi-task recommender on synthetic data.

    Every epoch runs ``batches_per_epoch`` training steps on freshly
    generated batches, prints the averaged losses, then evaluates on one
    freshly generated batch and prints the metrics.

    Args:
        num_epochs: Number of training epochs.
        batch_size: Examples per training batch.
        batches_per_epoch: Training steps per epoch.
        eval_batch_size: Examples in each per-epoch evaluation batch.
    """
    # Initialize components
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = MultiTaskRecommender(
        num_items=10000,
        num_categories=100,
        embedding_dim=64,
        hidden_dim=128
    )

    loss_fn = MultiTaskLoss(model.task_configs)
    trainer = MultiTaskTrainer(model, loss_fn, device=device)
    evaluator = MultiTaskEvaluator(model.task_configs)

    data_gen = MultiTaskDataGenerator(
        num_users=10000,
        num_items=10000,
        num_categories=100
    )

    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch + 1}")

        # One accumulator per reported loss, derived from the model's tasks
        # so the keys always match what train_step returns.
        epoch_metrics = {"total_loss": 0.0}
        for task_name in model.task_configs:
            epoch_metrics[f"{task_name}_loss"] = 0.0

        # Training
        model.train()
        for batch in range(batches_per_epoch):
            # Generate batch
            user_features, item_features, targets = data_gen.generate_batch(batch_size)

            # Training step
            metrics = trainer.train_step(user_features, item_features, targets)

            # Update metrics
            for k, v in metrics.items():
                epoch_metrics[k] += v

            if batch % 10 == 0:
                print(f"Batch {batch}, Loss: {metrics['total_loss']:.4f}")

        # Average metrics over the steps actually run (guard against zero).
        for k in epoch_metrics:
            epoch_metrics[k] /= max(batches_per_epoch, 1)

        print("\nEpoch Metrics:")
        for k, v in epoch_metrics.items():
            print(f"{k}: {v:.4f}")

        # Evaluation
        model.eval()
        eval_user_features, eval_item_features, eval_targets = data_gen.generate_batch(
            eval_batch_size
        )
        # Targets must sit on the same device as the predictions they are
        # compared against.
        eval_targets = MultiTaskTargets(*(t.to(device) for t in eval_targets))

        with torch.no_grad():
            eval_predictions = model(
                eval_user_features.to(device),
                eval_item_features.to(device)
            )
            eval_metrics = evaluator.evaluate(eval_predictions, eval_targets)

        print("\nEvaluation Metrics:")
        for k, v in eval_metrics.items():
            print(f"{k}: {v:.4f}")

if __name__ == "__main__":
    train