# Sequential Recommendation System with TorchRec

In [None]:
import torch
import torchrec
import numpy as np
from typing import Dict, List, Tuple, Optional
from dataclasses import dataclass
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
from utils.data_generators import TorchRecDataGenerator
from utils.debugging import TorchRecDebugger
from utils.benchmark import TorchRecBenchmark

## Data Structures

In [None]:
@dataclass()
class SequentialFeatures:
    """Jagged per-token features for a batch of user interaction sequences.

    Each per-token tensor holds the tokens of all sequences laid out back to
    back; ``lengths`` gives how many of those tokens belong to each sequence.
    """

    item_ids: torch.Tensor  # item id of every token
    timestamps: torch.Tensor  # time value of every token
    categories: torch.Tensor  # category id of every token
    positions: torch.Tensor  # position encoding of every token
    lengths: torch.Tensor  # number of tokens in each sequence

class SequenceEncoder:
    """Pack ``SequentialFeatures`` into a ``KeyedJaggedTensor``.

    The KJT produced by ``encode_sequence`` has the keys ``item_id``,
    ``category``, ``position`` and ``timestamp`` (in that order). Every key
    reuses the per-sequence ``lengths`` of the input features.
    """

    # KJT keys, in the same order in which the value tensors are concatenated.
    FEATURE_KEYS = ("item_id", "category", "position", "timestamp")

    def __init__(
        self,
        max_sequence_length: int,
        num_items: int,
        num_categories: int,
    ):
        # Stored for callers; encode_sequence does not currently read them.
        self.max_sequence_length = max_sequence_length
        self.num_items = num_items
        self.num_categories = num_categories

    def encode_sequence(
        self,
        features: SequentialFeatures
    ) -> KeyedJaggedTensor:
        """Convert sequence features to KJT format.

        Args:
            features: Jagged features. Each per-token tensor must contain
                exactly ``features.lengths.sum()`` entries.

        Returns:
            A KJT whose values are all int64. ``torch.cat`` promotes mixed
            dtypes, so concatenating integer ids with float timestamps or
            positions would otherwise turn every id into a float. Embedding
            lookups reject float ids. Fractional timestamps are truncated
            toward zero by this cast.

        Raises:
            ValueError: if a per-token tensor's size disagrees with ``lengths``.
        """
        lengths = features.lengths.reshape(-1).long()
        expected = int(lengths.sum())
        per_feature = (
            features.item_ids,
            features.categories,
            features.positions,
            features.timestamps,
        )
        for key, tensor in zip(self.FEATURE_KEYS, per_feature):
            if tensor.numel() != expected:
                raise ValueError(
                    f"feature {key!r} has {tensor.numel()} values but "
                    f"lengths sum to {expected}"
                )

        values = torch.cat([tensor.reshape(-1).long() for tensor in per_feature])

        # Key-major layout: the per-sequence lengths are repeated once per key,
        # matching the key-by-key concatenation of the values above.
        return KeyedJaggedTensor.from_lengths_sync(
            keys=list(self.FEATURE_KEYS),
            values=values,
            lengths=lengths.repeat(len(self.FEATURE_KEYS)),
        )

## Sequential Model Architecture

In [None]:
class SequentialRecommender(torch.nn.Module):
    """Sequential recommendation model with attention.

    Each token of a user's history is embedded as the sum of three embeddings:
    item, category and position. When a ``timestamp`` key is present, a linear
    encoding of the timestamp is added as well.

    Tokens are padded into a ``(batch, max_len, embedding_dim)`` tensor. They
    then pass through a transformer encoder layer and a multi-head
    self-attention layer. Both layers receive a key padding mask, so a
    sequence never attends to padding or to another sequence's tokens.

    The result is mean-pooled over the valid tokens and projected by a small
    MLP. Candidates are scored by dot product with their sum-pooled item
    embedding.
    """

    def __init__(
        self,
        num_items: int,
        num_categories: int,
        embedding_dim: int = 64,
        hidden_dim: int = 128,
        num_heads: int = 4,
        dropout: float = 0.1,
        max_sequence_length: int = 1000,
    ):
        """
        Args:
            num_items: Rows in the item embedding table.
            num_categories: Rows in the category embedding table.
            embedding_dim: Width of every embedding and of the transformer.
                Must be divisible by ``num_heads``.
            hidden_dim: Feed-forward width of the transformer layer and MLP.
            num_heads: Attention heads in both attention layers.
            dropout: Dropout probability.
            max_sequence_length: Rows in the position embedding table. Every
                position id must be smaller than this.
        """
        super().__init__()
        self.max_sequence_length = max_sequence_length

        # Tables are allocated on the default device rather than "meta".
        # Meta tensors hold no data, so ``model.to(device)`` (as done by
        # SequentialTrainer) would fail on them. The collection owns the three
        # weight tables; lookups are done per token in ``_embed``.
        self.embedding_tables = torchrec.EmbeddingBagCollection(
            tables=[
                torchrec.EmbeddingBagConfig(
                    name="item_embeddings",
                    embedding_dim=embedding_dim,
                    num_embeddings=num_items,
                    feature_names=["item_id"],
                ),
                torchrec.EmbeddingBagConfig(
                    name="category_embeddings",
                    embedding_dim=embedding_dim,
                    num_embeddings=num_categories,
                    feature_names=["category"],
                ),
                torchrec.EmbeddingBagConfig(
                    name="position_embeddings",
                    embedding_dim=embedding_dim,
                    num_embeddings=max_sequence_length,
                    feature_names=["position"],
                ),
            ],
        )

        # Time encoding: maps a scalar timestamp to an embedding_dim vector.
        self.time_encoder = torch.nn.Linear(1, embedding_dim)

        # Multi-head self-attention over each (padded) sequence.
        self.attention = torch.nn.MultiheadAttention(
            embed_dim=embedding_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True
        )

        # Transformer layer applied before the attention layer.
        self.transformer_encoder = torch.nn.TransformerEncoderLayer(
            d_model=embedding_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim,
            dropout=dropout,
            batch_first=True
        )

        # Output projection applied to the pooled sequence vector.
        self.output_layer = torch.nn.Sequential(
            torch.nn.Linear(embedding_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Dropout(dropout),
            torch.nn.Linear(hidden_dim, embedding_dim)
        )

    def _embed(self, table_name: str, ids: torch.Tensor) -> torch.Tensor:
        """Return one (unpooled) row of ``table_name`` per id, shape ``(N, D)``."""
        weight = self.embedding_tables.embedding_bags[table_name].weight
        return torch.nn.functional.embedding(ids.long(), weight)

    @staticmethod
    def _pad(
        tokens: torch.Tensor, lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Scatter jagged ``(N, D)`` tokens into a zero-padded ``(B, L, D)`` batch.

        Also returns the ``(B, L)`` boolean mask of valid (non-padding) slots.
        ``L`` is at least 1, so every row has a first slot.
        """
        batch_size = lengths.numel()
        max_len = max(int(lengths.max()), 1) if batch_size > 0 else 1
        slots = torch.arange(max_len, device=tokens.device)
        valid = slots.unsqueeze(0) < lengths.to(tokens.device).unsqueeze(1)
        padded = tokens.new_zeros(batch_size, max_len, tokens.size(-1))
        # Boolean-mask assignment fills slots in row-major order. That matches
        # the jagged layout: all tokens of sequence 0, then sequence 1, ...
        padded[valid] = tokens
        return padded, valid

    def encode_sequence(
        self,
        sequence_features: KeyedJaggedTensor
    ) -> torch.Tensor:
        """Encode each sequence into one fixed-length representation.

        Args:
            sequence_features: KJT with keys ``item_id``, ``category`` and
                ``position``, and optionally ``timestamp``. All keys must share
                the same per-sequence lengths.

        Returns:
            Tensor of shape ``(batch_size, embedding_dim)``.
        """
        features = sequence_features.to_dict()
        item_jt = features["item_id"]
        lengths = item_jt.lengths().long()

        # Per-token embedding: item + category + position (+ time).
        tokens = (
            self._embed("item_embeddings", item_jt.values())
            + self._embed("category_embeddings", features["category"].values())
            + self._embed("position_embeddings", features["position"].values())
        )
        if "timestamp" in features:
            timestamps = features["timestamp"].values().float().unsqueeze(-1)
            tokens = tokens + self.time_encoder(timestamps)

        padded, valid = self._pad(tokens, lengths)

        # key_padding_mask marks slots to IGNORE. An all-ignored row makes the
        # attention softmax produce NaN. So an empty sequence may attend to its
        # (all-zero) first slot; it is still excluded from pooling below.
        ignore = ~valid
        ignore[:, 0] = ignore[:, 0] & (lengths.to(ignore.device) > 0)

        hidden = self.transformer_encoder(padded, src_key_padding_mask=ignore)
        attended, _ = self.attention(
            hidden, hidden, hidden, key_padding_mask=ignore, need_weights=False
        )

        # Mean over valid tokens only. Empty sequences pool to a zero vector.
        weights = valid.unsqueeze(-1).to(attended.dtype)
        pooled = (attended * weights).sum(dim=1) / weights.sum(dim=1).clamp(min=1.0)
        return self.output_layer(pooled)

    def _candidate_embeddings(self, candidate_items: KeyedJaggedTensor) -> torch.Tensor:
        """Sum-pool the item embeddings of each candidate bag, shape ``(C, D)``.

        The EBC's own forward looks up every configured feature. A KJT that
        only carries ``item_id`` would therefore raise KeyError, so the item
        table is queried directly.
        """
        candidates = candidate_items.to_dict()["item_id"]
        ids = candidates.values()
        lengths = candidates.lengths().long().to(ids.device)
        token_emb = self._embed("item_embeddings", ids)
        owner = torch.repeat_interleave(
            torch.arange(lengths.numel(), device=ids.device), lengths
        )
        pooled = token_emb.new_zeros(lengths.numel(), token_emb.size(-1))
        return pooled.index_add(0, owner, token_emb)

    def forward(
        self,
        user_sequence: KeyedJaggedTensor,
        candidate_items: KeyedJaggedTensor,
    ) -> torch.Tensor:
        """Score every candidate against every user sequence.

        Returns:
            A ``(num_sequences, num_candidates)`` matrix of dot-product scores.
            Entry ``[i, j]`` scores candidate ``j`` for sequence ``i``.
        """
        sequence_repr = self.encode_sequence(user_sequence)
        candidate_embeddings = self._candidate_embeddings(candidate_items)
        return torch.matmul(sequence_repr, candidate_embeddings.t())

## Sequential Data Generation

In [None]:
class SequentialDataGenerator:
    """Generate synthetic batches of user interaction sequences.

    Item ids are drawn uniformly. Each item has a fixed random category,
    drawn once at construction. Within a sequence, positions and timestamps
    are both ``0, 1, ..., length - 1``.
    """

    def __init__(
        self,
        num_users: int,
        num_items: int,
        num_categories: int,
        max_sequence_length: int,
        min_sequence_length: int = 2,
    ):
        if not 0 <= min_sequence_length <= max_sequence_length:
            raise ValueError(
                "expected 0 <= min_sequence_length <= max_sequence_length, got "
                f"{min_sequence_length} and {max_sequence_length}"
            )
        self.num_users = num_users
        self.num_items = num_items
        self.num_categories = num_categories
        self.max_sequence_length = max_sequence_length
        self.min_sequence_length = min_sequence_length

        # Fixed item -> category mapping shared by every generated batch.
        self.item_categories = torch.randint(
            0, num_categories, (num_items,)
        )

        # The encoder only holds configuration, so build it once instead of
        # on every batch.
        self._encoder = SequenceEncoder(
            max_sequence_length=max_sequence_length,
            num_items=num_items,
            num_categories=num_categories,
        )

    def generate_sequence(
        self,
        batch_size: int
    ) -> Tuple[SequentialFeatures, KeyedJaggedTensor]:
        """Generate a batch of ``batch_size`` sequences.

        Returns:
            The raw ``SequentialFeatures`` and their KJT encoding. Positions
            are int64 (they are embedding ids); timestamps are float32.
        """
        lengths = torch.randint(
            self.min_sequence_length,
            self.max_sequence_length + 1,
            (batch_size,)
        )
        total = int(lengths.sum())

        item_ids = torch.randint(0, self.num_items, (total,))
        categories = self.item_categories[item_ids]

        # Index of each token within its own sequence: global token index
        # minus the start offset of the sequence it belongs to.
        starts = torch.cumsum(lengths, dim=0) - lengths
        positions = torch.arange(total) - torch.repeat_interleave(starts, lengths)
        # Timestamps increase within each sequence, mirroring the positions.
        timestamps = positions.to(torch.float32)

        features = SequentialFeatures(
            item_ids=item_ids,
            timestamps=timestamps,
            categories=categories,
            positions=positions,
            lengths=lengths
        )
        return features, self._encoder.encode_sequence(features)

## Training Infrastructure

In [None]:
class SequentialTrainer:
    """Single-device trainer for a sequential recommender, using BCE loss."""

    def __init__(
        self,
        model: SequentialRecommender,
        learning_rate: float = 0.001,
        device: str = "cuda",
    ):
        # Module.to moves parameters in place and returns the same module.
        self.model = model.to(device)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=learning_rate)
        self.device = device
        self.debugger = TorchRecDebugger()

    @staticmethod
    def _paired_logits(scores: torch.Tensor, num_labels: int) -> torch.Tensor:
        """Reduce model scores to one logit per label.

        Scores may be a square ``(B, B)`` matrix with ``B == num_labels``,
        where entry ``[i, j]`` scores candidate ``j`` for sequence ``i``. In
        that case only the diagonal is used: sequence ``i`` paired with its
        own candidate ``i``. Any other shape is flattened.

        Raises:
            ValueError: if the flattened scores do not give one logit per label.
        """
        if scores.dim() == 2 and scores.size(0) == scores.size(1) == num_labels:
            return scores.diagonal()
        logits = scores.reshape(-1)
        if logits.numel() != num_labels:
            raise ValueError(
                f"model produced {logits.numel()} scores for {num_labels} labels"
            )
        return logits

    def train_step(
        self,
        user_sequence: KeyedJaggedTensor,
        target_items: KeyedJaggedTensor,
        labels: torch.Tensor,
    ) -> float:
        """Run one optimisation step on a batch and return the scalar loss."""
        # Ensure dropout etc. are active even if eval() was called earlier.
        self.model.train()
        self.optimizer.zero_grad()

        # Move inputs to the training device.
        user_sequence = user_sequence.to(self.device)
        target_items = target_items.to(self.device)
        labels = labels.to(self.device).float().reshape(-1)

        scores = self.model(user_sequence, target_items)
        logits = self._paired_logits(scores, labels.numel())

        loss = torch.nn.functional.binary_cross_entropy_with_logits(logits, labels)
        loss.backward()
        self.optimizer.step()

        return loss.item()

## Training Loop

In [None]:
def train_sequential_model(
    num_users: int = 10000,
    num_items: int = 1000,
    num_categories: int = 100,
    max_sequence_length: int = 50,
    batch_size: int = 32,
    num_epochs: int = 5,
    device: str = "cuda",
    num_batches_per_epoch: int = 100,
):
    """Train a SequentialRecommender on synthetic data, printing progress.

    Returns:
        A list with the mean training loss of each epoch.
    """
    model = SequentialRecommender(
        num_items=num_items,
        num_categories=num_categories
    )

    data_gen = SequentialDataGenerator(
        num_users=num_users,
        num_items=num_items,
        num_categories=num_categories,
        max_sequence_length=max_sequence_length
    )

    trainer = SequentialTrainer(model, device=device)

    epoch_means = []
    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch + 1}")
        epoch_losses = []

        for batch in range(num_batches_per_epoch):
            features, sequence_kjt = data_gen.generate_sequence(batch_size)

            # Synthetic targets: one random candidate item per sequence.
            target_items = torch.randint(0, num_items, (batch_size,))
            target_kjt = KeyedJaggedTensor.from_lengths_sync(
                keys=["item_id"],
                values=target_items,
                # KJT lengths must be integral; torch.ones defaults to float32.
                lengths=torch.ones(batch_size, dtype=torch.long),
            )

            # Labels: 1 for the first pair, 0 for the rest (synthetic choice).
            labels = torch.zeros(batch_size)
            labels[0] = 1

            loss = trainer.train_step(sequence_kjt, target_kjt, labels)
            epoch_losses.append(loss)

            if batch % 10 == 0:
                print(f"Batch {batch}, Loss: {loss:.4f}")

        mean_loss = sum(epoch_losses) / len(epoch_losses) if epoch_losses else float("nan")
        epoch_means.append(mean_loss)
        print(f"Epoch {epoch + 1} Average Loss: {mean_loss:.4f}")

    return epoch_means


if __name__ == "__main__":
    train_sequential_model()