# Multi-Node Training with TorchRec

In [None]:
import os
import torch
import torchrec
import torch.distributed as dist
from torchrec.distributed.model_parallel import DistributedModelParallel
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
from torchrec.distributed.types import ShardingType, ParameterConstraints, ShardingEnv
from utils.debugging import TorchRecDebugger
from utils.benchmark import TorchRecBenchmark

## Multi-Node Environment Setup

In [None]:
class MultiNodeConfig:
    """Describe the process layout of a multi-node training job.

    Args:
        world_size: Total number of worker processes across all nodes.
        num_nodes: Number of nodes taking part in the job.
        node_rank: Zero-based index of the node this config is built on.
        master_addr: Rendezvous host name or address, stored as given.
        master_port: Rendezvous port, kept as a string.

    Attributes:
        gpus_per_node: ``world_size // num_nodes``; the layout is assumed to
            be homogeneous, so the division must be exact.

    Raises:
        ValueError: If ``num_nodes`` or ``world_size`` is not positive, if
            ``world_size`` is not a multiple of ``num_nodes``, or if
            ``node_rank`` is outside ``[0, num_nodes)``.
    """
    def __init__(
        self,
        world_size: int,
        num_nodes: int,
        node_rank: int,
        master_addr: str = "localhost",
        master_port: str = "29500"
    ):
        if num_nodes <= 0:
            raise ValueError(f"num_nodes must be positive, got {num_nodes}")
        # A non-divisible layout would silently truncate gpus_per_node and
        # make the derived global ranks overlap or leave gaps.
        if world_size <= 0 or world_size % num_nodes != 0:
            raise ValueError(
                f"world_size ({world_size}) must be a positive multiple of "
                f"num_nodes ({num_nodes})"
            )
        if not 0 <= node_rank < num_nodes:
            raise ValueError(
                f"node_rank ({node_rank}) must be in [0, {num_nodes})"
            )

        self.world_size = world_size
        self.num_nodes = num_nodes
        self.node_rank = node_rank
        self.gpus_per_node = world_size // num_nodes
        self.master_addr = master_addr
        self.master_port = master_port

def setup_multi_node(config: MultiNodeConfig, local_rank: int):
    """Initialise the NCCL process group for one worker process.

    Exports the rendezvous settings (``MASTER_ADDR``, ``MASTER_PORT``,
    ``WORLD_SIZE``, ``RANK``, ``LOCAL_RANK``) into ``os.environ``, selects
    the CUDA device for this process and then joins the process group.

    Args:
        config: Job layout; ``node_rank``, ``gpus_per_node``, ``world_size``,
            ``master_addr`` and ``master_port`` are read from it.
        local_rank: Index of this process on its node; also used as the CUDA
            device index.

    Returns:
        The global rank, ``node_rank * gpus_per_node + local_rank``.

    Raises:
        ValueError: If ``local_rank`` is outside ``[0, gpus_per_node)``.
    """
    # An out-of-range local rank would alias the global rank of a process on
    # another node and leave the rendezvous hanging.
    if not 0 <= local_rank < config.gpus_per_node:
        raise ValueError(
            f"local_rank ({local_rank}) must be in "
            f"[0, {config.gpus_per_node})"
        )

    # Calculate global rank
    global_rank = config.node_rank * config.gpus_per_node + local_rank

    # init_process_group() is called without an init_method, so the
    # rendezvous settings are supplied through the environment.
    os.environ["MASTER_ADDR"] = config.master_addr
    os.environ["MASTER_PORT"] = config.master_port
    os.environ["WORLD_SIZE"] = str(config.world_size)
    os.environ["RANK"] = str(global_rank)
    os.environ["LOCAL_RANK"] = str(local_rank)

    # Bind this process to its GPU *before* creating the NCCL process group
    # so that the communicator is set up on the intended device.
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend="nccl")

    return global_rank

## Multi-Node Model Definition

In [None]:
def create_large_model():
    """Build the three-table embedding-bag collection used for multi-node training.

    The tables are ``large_sparse_table`` (10M rows, dim 128),
    ``dense_feature_table`` (1M rows, dim 256) and ``shared_table``
    (5M rows, dim 64, serving both ``feature1`` and ``feature2``).

    Returns:
        A ``torchrec.EmbeddingBagCollection`` created on the ``meta`` device.
    """
    # (table name, embedding dim, number of rows, feature names)
    table_specs = (
        ("large_sparse_table", 128, 10_000_000, ["large_sparse_features"]),
        ("dense_feature_table", 256, 1_000_000, ["dense_features"]),
        ("shared_table", 64, 5_000_000, ["feature1", "feature2"]),
    )

    table_configs = [
        torchrec.EmbeddingBagConfig(
            name=table_name,
            embedding_dim=dim,
            num_embeddings=rows,
            feature_names=features,
        )
        for table_name, dim, rows, features in table_specs
    ]

    meta_device = torch.device("meta")
    return torchrec.EmbeddingBagCollection(
        tables=table_configs,
        device=meta_device,
    )

## Multi-Node Sharding Strategy

In [None]:
def create_sharding_plan(model, config: MultiNodeConfig):
    """Produce a sharding plan for ``model`` on the cluster described by ``config``.

    Each table is pinned to exactly one sharding type: row-wise for
    ``large_sparse_table``, column-wise for ``dense_feature_table`` and
    table-row-wise for ``shared_table``.

    Args:
        model: Module to plan for; handed to the planner unchanged.
        config: Supplies ``world_size`` and ``gpus_per_node`` for the topology.

    Returns:
        Whatever ``EmbeddingShardingPlanner.collective_plan`` returns for the
        model and an ``EmbeddingBagCollectionSharder``.
    """
    # Table name -> the only sharding type the planner may choose for it.
    sharding_by_table = {
        "large_sparse_table": ShardingType.ROW_WISE,
        "dense_feature_table": ShardingType.COLUMN_WISE,
        "shared_table": ShardingType.TABLE_ROW_WISE,
    }
    constraints = {
        table_name: ParameterConstraints(sharding_types=[sharding.value])
        for table_name, sharding in sharding_by_table.items()
    }

    # local_world_size tells the planner how many devices share a node.
    cluster = Topology(
        world_size=config.world_size,
        compute_device="cuda",
        local_world_size=config.gpus_per_node,
    )
    planner = EmbeddingShardingPlanner(topology=cluster, constraints=constraints)

    sharders = [torchrec.distributed.embeddingbag.EmbeddingBagCollectionSharder()]
    return planner.collective_plan(model, sharders)

## Data Management Across Nodes

In [None]:
class MultiNodeDataManager:
    """Generate and inspect synthetic per-rank input batches.

    Args:
        config: Job layout; only ``world_size`` is read here.
        rank: Global rank of the owning process; it selects the id range of
            the generated values.
    """

    # Feature keys carried by every generated batch, in KJT key order.
    FEATURE_KEYS = (
        "large_sparse_features",
        "dense_features",
        "feature1",
        "feature2",
    )
    # Number of ids generated for each (feature, sample) pair.
    IDS_PER_SAMPLE = 10

    def __init__(self, config: MultiNodeConfig, rank: int):
        self.config = config
        self.rank = rank
        self.debugger = TorchRecDebugger()

    def generate_node_specific_batch(self, batch_size: int):
        """Generate a random batch whose ids are specific to this rank.

        Args:
            batch_size: Global batch size; each rank produces
                ``batch_size // world_size`` samples.

        Returns:
            A ``KeyedJaggedTensor`` on the current CUDA device with one
            length entry per (feature, sample) pair, each equal to
            ``IDS_PER_SAMPLE``, and ids drawn from
            ``[rank * 1000, (rank + 1) * 1000)``.
        """
        num_features = len(self.FEATURE_KEYS)

        # Calculate local batch size
        local_batch_size = batch_size // self.config.world_size
        num_bags = num_features * local_batch_size

        # The number of values must equal the sum of all lengths, so ids are
        # generated for every feature, not just for one of them.
        values = torch.randint(
            low=self.rank * 1000,
            high=(self.rank + 1) * 1000,
            size=(num_bags * self.IDS_PER_SAMPLE,)
        )

        # Lengths are counts: keep them integral and on the same device as
        # the values they describe.
        lengths = torch.full(
            (num_bags,), self.IDS_PER_SAMPLE, dtype=torch.int64
        )

        # Create KJT
        return torchrec.sparse.jagged_tensor.KeyedJaggedTensor.from_lengths_sync(
            keys=list(self.FEATURE_KEYS),
            values=values.cuda(),
            lengths=lengths.cuda()
        )

    def verify_data_distribution(self, batch):
        """Summarise a batch for a quick sanity check.

        Args:
            batch: Object exposing ``lengths()`` and ``values()`` tensors,
                laid out as produced by ``generate_node_specific_batch``.

        Returns:
            A dict with ``local_batch_size`` (length entries divided by the
            number of feature keys), ``value_range`` as a ``(min, max)``
            tuple and the ``device`` holding the values.
        """
        values = batch.values()
        return {
            "local_batch_size": len(batch.lengths()) // len(self.FEATURE_KEYS),
            "value_range": (values.min().item(),
                          values.max().item()),
            "device": values.device
        }

## Multi-Node Training Implementation

In [None]:
class MultiNodeTrainer:
    """Drive training of a sharded model on one rank.

    Args:
        model: The distributed model to train; its ``parameters()`` are
            handed to an Adam optimizer with default settings.
        config: Job layout, forwarded to the data manager.
        rank: Global rank of this process; only rank 0 prints progress.
    """
    def __init__(
        self,
        model: DistributedModelParallel,
        config: MultiNodeConfig,
        rank: int
    ):
        self.model = model
        self.config = config
        self.rank = rank
        self.optimizer = torch.optim.Adam(model.parameters())
        self.data_manager = MultiNodeDataManager(config, rank)
        self.benchmark = TorchRecBenchmark()

    def train_step(self, batch):
        """Run forward, backward and optimizer step on one batch.

        Args:
            batch: Input passed straight to the model.

        Returns:
            The loss as a Python float. The loss is the mean over all
            values of the model output (a placeholder objective).
        """
        self.optimizer.zero_grad()

        # The model call returns an awaitable; wait() yields the result.
        pooled = self.model(batch).wait()

        # Placeholder objective: average of every output value.
        loss = torch.mean(pooled.values())

        loss.backward()
        self.optimizer.step()

        return loss.item()

    def train_epoch(self, num_batches, batch_size):
        """Train on ``num_batches`` freshly generated batches.

        Args:
            num_batches: Number of training steps to run.
            batch_size: Global batch size passed to the data manager.

        Returns:
            ``(losses, timings)``: per-step loss values and per-step
            durations in milliseconds as measured with CUDA events.
        """
        losses, timings = [], []

        for i in range(num_batches):
            batch = self.data_manager.generate_node_specific_batch(batch_size)

            # CUDA events bracket the step; the elapsed time is read only
            # after a device-wide synchronize.
            step_begin = torch.cuda.Event(enable_timing=True)
            step_end = torch.cuda.Event(enable_timing=True)

            step_begin.record()
            loss = self.train_step(batch)
            step_end.record()

            torch.cuda.synchronize()
            timings.append(step_begin.elapsed_time(step_end))
            losses.append(loss)

            # Progress goes to stdout from rank 0 every tenth batch.
            if i % 10 != 0 or self.rank != 0:
                continue
            print(f"Batch {i}, Loss: {loss:.4f}, "
                  f"Time: {timings[-1]:.2f}ms")

        return losses, timings

## Inter-Node Communication Analysis

In [None]:
class CommunicationAnalyzer:
    """Derive rough communication statistics from a model's parameters.

    Args:
        config: Job layout; ``num_nodes`` is used to compute per-node volume.
    """
    def __init__(self, config: MultiNodeConfig):
        self.config = config

    def analyze_communication(self, model: DistributedModelParallel):
        """Count collective operations implied by the parameters.

        Only parameters that carry a ``linked_param`` attribute are
        considered. The string form of that attribute is searched for
        ``'row_wise'`` (one ``all_to_all``) and otherwise ``'column_wise'``
        (one ``reduce_scatter`` plus one ``all_gather``).

        NOTE: this is a substring match, so a description containing
        ``'table_row_wise'`` is counted under ``'row_wise'`` as well.

        Returns:
            Dict with integer counts for ``all_to_all``, ``reduce_scatter``
            and ``all_gather``.
        """
        counts = dict.fromkeys(("all_to_all", "reduce_scatter", "all_gather"), 0)

        for _name, param in model.named_parameters():
            if not hasattr(param, 'linked_param'):
                continue

            description = str(param.linked_param)
            if 'row_wise' in description:
                counts['all_to_all'] += 1
            elif 'column_wise' in description:
                counts['reduce_scatter'] += 1
                counts['all_gather'] += 1

        return counts

    def estimate_communication_volume(self, model: DistributedModelParallel):
        """Estimate communication volume from parameter sizes.

        Sums ``numel() * element_size()`` over every parameter that has a
        ``linked_param`` attribute.

        Returns:
            Dict with ``total_gb`` (bytes / 1e9) and ``gb_per_node`` (the
            same figure divided by ``num_nodes``).
        """
        total_bytes = sum(
            param.numel() * param.element_size()
            for _name, param in model.named_parameters()
            if hasattr(param, 'linked_param')
        )

        return {
            "total_gb": total_bytes / 1e9,
            "gb_per_node": total_bytes / (1e9 * self.config.num_nodes)
        }

## Performance Monitoring

In [None]:
class MultiNodeMonitor:
    """Report per-rank memory figures and gather them across the job.

    Args:
        config: Job layout; ``gpus_per_node`` and ``world_size`` are read.
        rank: Global rank of the owning process.
    """
    def __init__(self, config: MultiNodeConfig, rank: int):
        self.config = config
        self.rank = rank
        self.debugger = TorchRecDebugger()

    def monitor_step(self, model, batch):
        """Snapshot this rank's position and memory usage.

        Args:
            model: Accepted for interface symmetry; not used.
            batch: Accepted for interface symmetry; not used.

        Returns:
            Dict with ``rank``, ``node``, ``local_rank`` and the debugger's
            ``allocated`` / ``reserved`` figures divided by 1e9
            (presumably bytes to GB; confirm against ``TorchRecDebugger``).
        """
        memory_stats = self.debugger.memory_status()

        # Split the global rank into (node index, index within the node).
        node_index, local_index = divmod(self.rank, self.config.gpus_per_node)

        return {
            "rank": self.rank,
            "node": node_index,
            "local_rank": local_index,
            "memory_allocated_gb": memory_stats["allocated"] / 1e9,
            "memory_reserved_gb": memory_stats["reserved"] / 1e9
        }

    def collect_global_stats(self, local_stats):
        """Gather ``local_stats`` from every rank.

        This is a collective call: every rank in the job must invoke it.

        Returns:
            List of length ``world_size`` filled by
            ``dist.all_gather_object``.
        """
        # One placeholder slot per rank; all_gather_object fills them in.
        gathered = [None for _ in range(self.config.world_size)]
        dist.all_gather_object(gathered, local_stats)

        return gathered

## Launch Multi-Node Training

In [None]:
def main_worker(local_rank, config: MultiNodeConfig):
    """Entry point executed by every spawned worker process.

    Sets up the distributed environment, builds and shards the model, runs
    one epoch of 100 batches with a global batch size of 1024, gathers
    memory statistics from all ranks and prints a summary on rank 0. The
    process group is torn down when the function exits, including on error.

    Args:
        local_rank: Index of this process on its node.
        config: Job layout shared by all workers on this node.
    """
    try:
        # Setup distributed
        rank = setup_multi_node(config, local_rank)

        # Create model
        base_model = create_large_model()
        plan = create_sharding_plan(base_model, config)

        model = DistributedModelParallel(
            module=base_model,
            device=torch.device(f"cuda:{local_rank}"),
            plan=plan
        )

        # Initialize trainer
        trainer = MultiNodeTrainer(model, config, rank)
        monitor = MultiNodeMonitor(config, rank)
        analyzer = CommunicationAnalyzer(config)

        # Train
        batch_size = 1024
        num_batches = 100

        losses, timings = trainer.train_epoch(num_batches, batch_size)

        # Collect stats; the gather is collective, so it runs on every rank.
        local_stats = monitor.monitor_step(model, None)
        all_stats = monitor.collect_global_stats(local_stats)

        # Analysis (on rank 0)
        if rank == 0:
            comm_stats = analyzer.analyze_communication(model)
            comm_volume = analyzer.estimate_communication_volume(model)

            print("\nTraining Summary:")
            print(f"Average Loss: {sum(losses) / len(losses):.4f}")
            print(f"Average Batch Time: {sum(timings) / len(timings):.2f}ms")
            print("\nCommunication Analysis:")
            print(f"Communication Patterns: {comm_stats}")
            print(f"Communication Volume: {comm_volume}")
            print("\nNode Statistics:")
            for stats in all_stats:
                print(f"Node {stats['node']}, Rank {stats['rank']}: "
                      f"{stats['memory_allocated_gb']:.2f}GB allocated")
    finally:
        # Release the process group's resources. The guard covers the case
        # where setup failed before the process group was created.
        if dist.is_initialized():
            dist.destroy_process_group()

## Launch Script

In [None]:
def launch_multi_node(num_nodes, gpus_per_node):
    """Spawn one training worker per local GPU on this node.

    The node's position and the rendezvous endpoint are taken from the
    environment so that the same call works on every node of the job:

    * ``NODE_RANK`` - index of this node (default ``0``)
    * ``MASTER_ADDR`` - rendezvous host (default ``"localhost"``)
    * ``MASTER_PORT`` - rendezvous port (default ``"29500"``)

    Args:
        num_nodes: Number of nodes in the job.
        gpus_per_node: Number of worker processes to start on this node.
    """
    config = MultiNodeConfig(
        world_size=num_nodes * gpus_per_node,
        num_nodes=num_nodes,
        node_rank=int(os.environ.get("NODE_RANK", 0)),
        # Without these, every node would try to rendezvous on its own
        # localhost and nodes other than the master would never connect.
        master_addr=os.environ.get("MASTER_ADDR", "localhost"),
        master_port=os.environ.get("MASTER_PORT", "29500"),
    )

    # spawn() passes the process index as the first positional argument,
    # which main_worker receives as local_rank.
    torch.multiprocessing.spawn(
        main_worker,
        args=(config,),
        nprocs=gpus_per_node
    )