# Single Node Multi-GPU Training with TorchRec

In [None]:
import os
import torch
import torchrec
import torch.distributed as dist
from torchrec.distributed.model_parallel import DistributedModelParallel
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
from torchrec.distributed.types import ShardingType, ShardingEnv
from utils.debugging import TorchRecDebugger
from utils.benchmark import TorchRecBenchmark

## Distributed Environment Setup

In [None]:
def setup_distributed(rank: int, world_size: int):
    """Join this process to the single-node NCCL process group.

    Exports the rendezvous settings (``MASTER_ADDR``, ``MASTER_PORT``,
    ``WORLD_SIZE``, ``RANK``) as environment variables, initializes the
    default process group with the NCCL backend, and finally selects
    ``rank`` as the current CUDA device.

    Args:
        rank: Index of this process; also used as its CUDA device index.
        world_size: Total number of participating processes.
    """
    rendezvous = {
        "MASTER_ADDR": "localhost",
        "MASTER_PORT": "29500",
        "WORLD_SIZE": str(world_size),
        "RANK": str(rank),
    }
    # Existing values for these variables are overwritten unconditionally.
    os.environ.update(rendezvous)

    # No rank/world_size arguments are passed here; the call relies on the
    # environment variables exported just above.
    dist.init_process_group(backend="nccl")

    # One GPU per process: the device index equals the rank.
    torch.cuda.set_device(rank)

# Check available GPUs: query the number of CUDA devices visible to this
# process and report it. `num_gpus` stays available as a module-level value.
num_gpus = torch.cuda.device_count()
print(f"Available GPUs: {num_gpus}")

## Create Base Model

In [None]:
# Embedding tables, one SUM-pooled EmbeddingBagConfig per spec tuple.
# Each spec is (table name, embedding dim, number of rows, feature name).
tables = [
    torchrec.EmbeddingBagConfig(
        name=table_name,
        embedding_dim=dim,
        num_embeddings=rows,
        feature_names=[feature],
        pooling=torchrec.PoolingType.SUM,
    )
    for table_name, dim, rows, feature in (
        ("large_table", 128, 1_000_000, "large_features"),
        ("medium_table", 64, 100_000, "medium_features"),
        ("small_table", 32, 10_000, "small_features"),
    )
]

# Base (unsharded) model, constructed on the "meta" device.
model = torchrec.EmbeddingBagCollection(
    device=torch.device("meta"),
    tables=tables,
)

## Sharding Configuration

In [None]:
# Define sharding constraints, keyed by embedding table name. Each entry
# restricts the sharding types the planner may choose for that table.
#
# ParameterConstraints is provided by the planner package
# (torchrec.distributed.planner), not by torchrec.distributed.types, so it is
# referenced through the planner module here.
constraints = {
    # The largest table is limited to row-wise sharding.
    "large_table": torchrec.distributed.planner.ParameterConstraints(
        sharding_types=[ShardingType.ROW_WISE.value]
    ),
    # The two smaller tables are limited to table-wise sharding.
    "medium_table": torchrec.distributed.planner.ParameterConstraints(
        sharding_types=[ShardingType.TABLE_WISE.value]
    ),
    "small_table": torchrec.distributed.planner.ParameterConstraints(
        sharding_types=[ShardingType.TABLE_WISE.value]
    ),
}

## Initialize Distributed Model

In [None]:
def create_distributed_model(model, rank, world_size):
    """Shard ``model`` across ranks and wrap it for distributed execution.

    Initializes the distributed environment for this process, asks the
    embedding sharding planner for a plan that honors the module-level
    ``constraints``, and wraps ``model`` in ``DistributedModelParallel``
    on this rank's GPU.

    Args:
        model: The module to shard (the notebook passes the
            EmbeddingBagCollection built on the meta device).
        rank: This process's rank; also its CUDA device index.
        world_size: Total number of processes taking part.

    Returns:
        The ``DistributedModelParallel`` wrapper around ``model``.
    """
    # The process group must exist before the collective planning call.
    setup_distributed(rank, world_size)

    planner = EmbeddingShardingPlanner(
        topology=Topology(world_size=world_size, compute_device="cuda"),
        constraints=constraints,
    )

    # Plan sharding for the embedding bag collection only.
    sharders = [torchrec.distributed.embeddingbag.EmbeddingBagCollectionSharder()]
    sharding_plan = planner.collective_plan(model, sharders)

    return DistributedModelParallel(
        module=model,
        device=torch.device(f"cuda:{rank}"),
        plan=sharding_plan,
    )

## Generate Sample Data

In [None]:
def generate_batch(batch_size: int, num_gpus: int):
    """Generate a random KeyedJaggedTensor batch for the three features.

    Every sample gets exactly 10 ids per feature, drawn uniformly from
    ``[0, 1000)``, which is within the row count of the smallest table
    defined in this notebook (10,000 rows).

    Args:
        batch_size: Number of samples in the batch.
        num_gpus: Unused; kept so existing callers keep working.

    Returns:
        A ``KeyedJaggedTensor`` (on the default device) with keys
        ``large_features``, ``medium_features`` and ``small_features``.
    """
    feature_keys = ["large_features", "medium_features", "small_features"]
    ids_per_sample = 10

    # One independent draw of ids per feature, concatenated in key order.
    values = torch.cat(
        [
            torch.randint(0, 1000, (batch_size * ids_per_sample,))
            for _ in feature_keys
        ]
    )

    # Lengths are per-sample id counts and must be an integer tensor;
    # `torch.ones(batch_size) * 10` would produce float32 lengths instead.
    lengths = torch.full(
        (batch_size * len(feature_keys),),
        ids_per_sample,
        dtype=torch.int64,
    )

    # Create KJT
    return torchrec.sparse.jagged_tensor.KeyedJaggedTensor.from_lengths_sync(
        keys=feature_keys,
        values=values,
        lengths=lengths,
    )

## Training Loop Setup

In [None]:
def train_step(model, batch, optimizer):
    """Run one optimization step and return the scalar loss.

    The loss is the mean over all values of the model output.

    Args:
        model: Callable whose result exposes ``wait()``; the waited result
            must expose ``values()`` returning a tensor.
        batch: Input passed straight to ``model``.
        optimizer: Optimizer whose gradients are cleared before the forward
            pass and which is stepped after the backward pass.

    Returns:
        The loss as a Python float.
    """
    optimizer.zero_grad()

    # The forward call hands back an awaitable; block until it resolves.
    pooled = model(batch).wait()

    objective = pooled.values().mean()
    objective.backward()
    optimizer.step()

    return objective.item()

## Complete Training Example

In [None]:
def run_training(rank, world_size, num_iterations=10):
    """Run the complete training loop on one rank.

    Builds the sharded model for this rank, trains it for
    ``num_iterations`` steps on freshly generated random batches of 32
    samples, and prints the loss from rank 0. The process group is torn
    down on exit even when a step raises, so a failure does not leave the
    group initialized.

    Args:
        rank: This process's rank; also its CUDA device index.
        world_size: Total number of processes taking part.
        num_iterations: Number of training steps to run.
    """
    # Loop-invariant: every batch goes to this rank's GPU.
    device = torch.device(f"cuda:{rank}")

    try:
        # Create distributed model (this also initializes the process group).
        dist_model = create_distributed_model(model, rank, world_size)

        # NOTE(review): TorchRec sharded embedding tables may be updated by a
        # fused in-backward optimizer; confirm that Adam over
        # dist_model.parameters() updates them as intended.
        optimizer = torch.optim.Adam(dist_model.parameters())

        for iteration in range(num_iterations):
            batch = generate_batch(batch_size=32, num_gpus=world_size).to(device)
            loss = train_step(dist_model, batch, optimizer)

            if rank == 0:
                print(f"Iteration {iteration}, Loss: {loss:.4f}")
    finally:
        # Cleanup. Guarded because setup may have failed before the group
        # was created, in which case destroying it would raise again and
        # mask the original error.
        if dist.is_initialized():
            dist.destroy_process_group()

## Performance Monitoring

In [None]:
def monitor_performance(rank, model, batch):
    """Print memory usage and forward-pass timing for one rank.

    Args:
        rank: Rank label used in the printed section headers.
        model: Model handed to the forward benchmark.
        batch: Input batch handed to the forward benchmark.
    """
    # Both helpers are constructed before any measurement is taken.
    mem_probe = TorchRecDebugger()
    forward_bench = TorchRecBenchmark()

    # Memory report. Values are divided by 1e9 for display as GB
    # (presumably the raw figures are byte counts -- verify against the
    # debugger helper).
    memory_stats = mem_probe.memory_status()
    print(f"\nRank {rank} Memory Usage:")
    print(f"Allocated: {memory_stats['allocated'] / 1e9:.2f} GB")
    print(f"Reserved: {memory_stats['reserved'] / 1e9:.2f} GB")

    # Forward-pass benchmark at a batch size of 32.
    results = forward_bench.benchmark_forward(model, batch, batch_size=32)
    print(f"\nRank {rank} Performance:")
    print(f"Batch Time: {results.batch_time_ms:.2f} ms")
    print(f"Throughput: {results.throughput:.2f} examples/sec")

## Launch Training

In [None]:
import torch.multiprocessing as mp

def main():
    """Launch one training process per visible GPU.

    Prints a notice and does nothing else when fewer than two CUDA devices
    are available; otherwise spawns ``run_training`` once per device and
    waits for all processes to finish.
    """
    gpu_count = torch.cuda.device_count()

    # Guard clause: the example needs more than one GPU.
    if gpu_count <= 1:
        print("Multiple GPUs required for this example")
        return

    # mp.spawn supplies the process index as the first argument (the rank);
    # the world size is passed through `args`.
    mp.spawn(run_training, args=(gpu_count,), nprocs=gpu_count, join=True)

## Best Practices

In [None]:
# Tips grouped by category; insertion order is the display order.
distributed_tips = dict(
    Initialization=[
        "Always use meta device initially",
        "Set appropriate sharding constraints",
        "Verify process group initialization",
    ],
    Performance=[
        "Monitor per-GPU memory usage",
        "Use appropriate batch sizes",
        "Consider communication overhead",
    ],
    Debug=[
        "Start with small tables",
        "Monitor memory usage",
        "Check device placement",
    ],
)

# Print a heading, then each category followed by its tips as a bullet list.
print("\nDistributed Training Tips:")
for category, tips in distributed_tips.items():
    print(f"\n{category}:")
    for tip in tips:
        print(f"- {tip}")

# Entry point: launch training only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()