# TorchRec Sharding Strategies Deep Dive

In [None]:
import torch
import torchrec
import torch.distributed as dist
from torchrec.distributed.types import ShardingType, ParameterConstraints
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
from torchrec.distributed.model_parallel import DistributedModelParallel
from utils.debugging import TorchRecDebugger
from utils.visualization import TorchRecVisualizer
from utils.benchmark import TorchRecBenchmark

## Understanding Different Sharding Types

In [None]:
def create_sample_tables():
    """Build the three demo embedding-bag table configs.

    Returns:
        A list of ``torchrec.EmbeddingBagConfig`` objects, in this order:
        ``large_table`` (1,000,000 rows x 128 dims), ``wide_table``
        (100,000 rows x 256 dims) and ``small_table`` (10,000 rows x 32
        dims). Each table is fed by exactly one feature name.
    """
    # (table name, embedding_dim, num_embeddings, feature name)
    table_specs = (
        ("large_table", 128, 1_000_000, "large_features"),
        ("wide_table", 256, 100_000, "wide_features"),
        ("small_table", 32, 10_000, "small_features"),
    )
    return [
        torchrec.EmbeddingBagConfig(
            name=table_name,
            embedding_dim=dim,
            num_embeddings=rows,
            feature_names=[feature],
        )
        for table_name, dim, rows, feature in table_specs
    ]

# Base model: an EmbeddingBagCollection over the sample tables, placed on
# the "meta" device.
base_model = torchrec.EmbeddingBagCollection(
    device=torch.device("meta"),
    tables=create_sample_tables(),
)

## Sharding Type Analysis

In [None]:
# Reference notes for each sharding type: a one-line description, the kind
# of table it suits best, and its pros/cons.
sharding_types = {}

sharding_types["TABLE_WISE"] = {
    "description": "Entire table on one device",
    "best_for": "Small to medium tables",
    "trade_offs": {
        "pros": ["Low communication overhead", "Simple implementation"],
        "cons": ["Limited by single GPU memory", "Potential load imbalance"],
    },
}

sharding_types["ROW_WISE"] = {
    "description": "Split tables by rows across devices",
    "best_for": "Tables with many embeddings",
    "trade_offs": {
        "pros": ["Scales with embedding count", "Good memory distribution"],
        "cons": ["All-to-all communication", "Complex lookup patterns"],
    },
}

sharding_types["COLUMN_WISE"] = {
    "description": "Split embedding dimensions across devices",
    "best_for": "Tables with large embedding dimensions",
    "trade_offs": {
        "pros": ["Scales with embedding dim", "Balanced computation"],
        "cons": ["All-to-all communication", "Complex reduction"],
    },
}

sharding_types["DATA_PARALLEL"] = {
    "description": "Full table replica on each device",
    "best_for": "Small tables with high lookup frequency",
    "trade_offs": {
        "pros": ["Fast forward pass", "Simple implementation"],
        "cons": ["High memory usage", "Gradient synchronization overhead"],
    },
}

print(sharding_types)

## Implement Different Sharding Strategies

In [None]:
def create_sharding_configs():
    """Build per-table planner constraints for each strategy under test.

    Returns:
        A dict keyed by strategy name (``"table_wise"``, ``"row_wise"``,
        ``"column_wise"``, ``"mixed"``, in that order). Each value maps the
        three sample table names to a ``ParameterConstraints`` that lists
        exactly one sharding type for that table.
    """
    table_names = ("large_table", "wide_table", "small_table")

    def only(sharding_type):
        # A fresh constraint object per table, limited to one sharding type.
        return ParameterConstraints(sharding_types=[sharding_type.value])

    # Uniform strategies: every table gets the same sharding type.
    uniform_strategies = (
        ("table_wise", ShardingType.TABLE_WISE),
        ("row_wise", ShardingType.ROW_WISE),
        ("column_wise", ShardingType.COLUMN_WISE),
    )
    configs = {
        strategy: {table: only(sharding_type) for table in table_names}
        for strategy, sharding_type in uniform_strategies
    }

    # Mixed sharding (realistic scenario): a different type per table.
    configs["mixed"] = {
        "large_table": only(ShardingType.ROW_WISE),
        "wide_table": only(ShardingType.COLUMN_WISE),
        "small_table": only(ShardingType.TABLE_WISE),
    }

    return configs

## Benchmark Different Strategies

In [None]:
class ShardingBenchmark:
    """Shard a model under a constraint set and measure one forward pass.

    Args:
        world_size: Number of ranks handed to the planner ``Topology``.
        device: Device string used for the topology, the sharded model and
            the generated batch. Defaults to ``"cuda"``.
    """

    def __init__(self, world_size, device="cuda"):
        self.world_size = world_size
        self.device = device
        # Project helpers (utils.debugging / utils.benchmark) used by
        # benchmark_strategy for memory stats and timing.
        self.debugger = TorchRecDebugger()
        self.benchmark = TorchRecBenchmark()

    def create_sharded_model(self, base_model, sharding_config):
        """Plan and wrap ``base_model`` according to ``sharding_config``.

        Args:
            base_model: Module to shard.
            sharding_config: Dict of parameter name -> ``ParameterConstraints``
                passed to the planner as its constraints.

        Returns:
            A ``DistributedModelParallel`` wrapping ``base_model`` with the
            generated plan, placed on ``self.device``.
        """
        topology = Topology(
            world_size=self.world_size,
            compute_device=self.device
        )

        planner = EmbeddingShardingPlanner(
            topology=topology,
            constraints=sharding_config
        )

        # NOTE(review): collective_plan presumably needs an initialized
        # torch.distributed process group — confirm against the caller.
        plan = planner.collective_plan(
            base_model,
            [torchrec.distributed.embeddingbag.EmbeddingBagCollectionSharder()]
        )

        return DistributedModelParallel(
            module=base_model,
            plan=plan,
            device=torch.device(self.device)
        )

    @staticmethod
    def _sample_values_and_lengths(num_keys, batch_size, ids_per_bag=10, num_ids=1000):
        """Return ``(values, lengths)`` CPU tensors for a fixed-length batch.

        ``lengths`` holds one integer entry per (key, sample) pair and
        ``values`` holds exactly ``lengths.sum()`` random ids drawn from
        ``[0, num_ids)``.
        """
        num_bags = num_keys * batch_size
        # Integer dtype: lengths are counts, not floats.
        lengths = torch.full((num_bags,), ids_per_bag, dtype=torch.int64)
        values = torch.randint(0, num_ids, (num_bags * ids_per_bag,))
        return values, lengths

    def generate_sample_batch(self, batch_size=32):
        """Generate a sample ``KeyedJaggedTensor`` batch for testing.

        Every one of the three feature keys gets ``batch_size`` bags of 10
        ids each, with ids in ``[0, 1000)``.

        Args:
            batch_size: Number of samples per feature key.

        Returns:
            A ``KeyedJaggedTensor`` whose values and lengths live on
            ``self.device``.
        """
        keys = ["large_features", "wide_features", "small_features"]
        # The lengths tensor must cover every key, i.e. have
        # len(keys) * batch_size entries, and values must match its sum.
        values, lengths = self._sample_values_and_lengths(len(keys), batch_size)

        return torchrec.sparse.jagged_tensor.KeyedJaggedTensor.from_lengths_sync(
            keys=keys,
            values=values.to(self.device),
            lengths=lengths.to(self.device),
        )

    def benchmark_strategy(self, model, batch, batch_size=32):
        """Benchmark one sharded model on one batch.

        Args:
            model: Sharded model to run.
            batch: Input batch forwarded to the benchmark helper.
            batch_size: Batch size reported to the benchmark helper; should
                match the size used to build ``batch``. Defaults to 32.

        Returns:
            Dict with ``"batch_time_ms"``, ``"throughput"`` and
            ``"memory_gb"`` (the debugger's ``"allocated"`` value divided
            by 1e9).
        """
        results = self.benchmark.benchmark_forward(model, batch, batch_size=batch_size)
        memory_stats = self.debugger.memory_status()

        return {
            "batch_time_ms": results.batch_time_ms,
            "throughput": results.throughput,
            "memory_gb": memory_stats["allocated"] / 1e9
        }

## Run Comparative Analysis

In [None]:
def run_sharding_comparison(world_size=2):
    """Benchmark every strategy from ``create_sharding_configs`` in turn.

    Args:
        world_size: Forwarded to ``ShardingBenchmark``.

    Returns:
        Dict mapping strategy name to the metrics dict produced by
        ``ShardingBenchmark.benchmark_strategy``.
    """
    runner = ShardingBenchmark(world_size)
    strategy_configs = create_sharding_configs()
    results = {}

    for strategy_name in strategy_configs:
        print(f"\nTesting {strategy_name} sharding strategy...")

        # Shard the shared base model, build a fresh batch, then measure.
        sharded = runner.create_sharded_model(
            base_model, strategy_configs[strategy_name]
        )
        sample = runner.generate_sample_batch()
        results[strategy_name] = runner.benchmark_strategy(sharded, sample)

        # Drop the sharded model and release cached CUDA memory before
        # moving on to the next strategy.
        del sharded
        torch.cuda.empty_cache()

    return results

## Visualization and Analysis

In [None]:
def visualize_results(results):
    """Plot per-strategy latency and memory usage as two bar charts.

    Args:
        results: Dict mapping strategy name to a metrics dict containing
            ``"batch_time_ms"`` and ``"memory_gb"`` entries.
    """
    # Local import: matplotlib is not imported at module level in this file,
    # so ``plt`` must be brought into scope here.
    import matplotlib.pyplot as plt

    # NOTE(review): this instance is never used below; kept in case its
    # constructor has side effects (e.g. plot styling) — confirm and drop.
    visualizer = TorchRecVisualizer()

    strategies = list(results.keys())

    # (metric key, chart title, y-axis label) — one bar chart per entry.
    charts = (
        ("batch_time_ms", "Latency by Sharding Strategy", "Batch Time (ms)"),
        ("memory_gb", "Memory Usage by Sharding Strategy", "Memory (GB)"),
    )
    for metric, title, ylabel in charts:
        heights = [res[metric] for res in results.values()]
        plt.figure(figsize=(10, 5))
        plt.bar(strategies, heights)
        plt.title(title)
        plt.ylabel(ylabel)
        plt.xticks(rotation=45)
        plt.show()

## Implementation Guidelines

In [None]:
# Rule-of-thumb guidance, grouped by the factor that drives the choice.
sharding_guidelines = {}

sharding_guidelines["Table Size Based"] = {
    "Large Tables (>1M rows)": "Consider ROW_WISE sharding",
    "Wide Tables (>256 dim)": "Consider COLUMN_WISE sharding",
    "Small Tables (<100K rows)": "Consider TABLE_WISE or DATA_PARALLEL",
}

sharding_guidelines["Access Pattern Based"] = {
    "High Frequency Access": "Prefer DATA_PARALLEL or TABLE_WISE",
    "Sparse Access": "ROW_WISE can be more efficient",
    "Mixed Access": "Consider mixed sharding strategy",
}

sharding_guidelines["Hardware Considerations"] = {
    "Limited GPU Memory": "Prefer ROW_WISE or COLUMN_WISE",
    "Fast GPU Interconnect": "All-to-all communication less problematic",
    "Multiple Nodes": "Consider communication overhead carefully",
}

print(sharding_guidelines)

## Practical Example

In [None]:
def create_production_sharding_plan(tables_info):
    """Pick one sharding type per table from its size characteristics.

    Rules, checked in order:
      1. more than 1,000,000 embeddings -> ROW_WISE
      2. embedding dim greater than 256 -> COLUMN_WISE
      3. anything else                  -> TABLE_WISE

    Args:
        tables_info: Dict mapping table name to a dict with
            ``"num_embeddings"`` and ``"embedding_dim"`` entries.

    Returns:
        Dict mapping table name to a ``ParameterConstraints`` restricted to
        the chosen sharding type.
    """
    def pick_sharding_type(info):
        # Row count takes priority over width; small tables fall through.
        if info["num_embeddings"] > 1_000_000:
            return ShardingType.ROW_WISE
        if info["embedding_dim"] > 256:
            return ShardingType.COLUMN_WISE
        return ShardingType.TABLE_WISE

    return {
        table_name: ParameterConstraints(
            sharding_types=[pick_sharding_type(info).value]
        )
        for table_name, info in tables_info.items()
    }

# Example usage: three tables described as (name, num_embeddings, embedding_dim)
tables_info = {
    table_name: {"num_embeddings": rows, "embedding_dim": dim}
    for table_name, rows, dim in (
        ("large_table", 1_500_000, 128),
        ("wide_table", 100_000, 512),
        ("small_table", 10_000, 32),
    )
}

production_constraints = create_production_sharding_plan(tables_info)

## Monitoring and Debugging

In [None]:
def analyze_sharding_plan(plan, bytes_per_element=4):
    """Summarize how a sharding plan spreads tables over devices.

    Args:
        plan: Mapping of table name -> sharding entry. Each entry exposes
            ``sharding_type`` and ``sharding_spec``; a non-``None`` spec
            exposes ``shards``, each with ``placement`` and a
            two-element ``shard_sizes``.
        bytes_per_element: Size of one embedding element in bytes.
            Defaults to 4 (float32).

    Returns:
        Dict with three keys:
          * ``"sharding_types"``: sharding type -> number of tables using it.
          * ``"memory_distribution"``: shard placement -> total bytes of the
            shards placed there.
          * ``"communication_patterns"``: always an empty dict; this
            function does not populate it.
    """
    analysis = {
        "sharding_types": {},
        "memory_distribution": {},
        "communication_patterns": {},
    }
    type_counts = analysis["sharding_types"]
    bytes_by_device = analysis["memory_distribution"]

    for sharding in plan.values():
        # Count tables per sharding type.
        stype = sharding.sharding_type
        type_counts[stype] = type_counts.get(stype, 0) + 1

        # Entries without a sharding spec (e.g. data-parallel parameters)
        # have no shards to account for; they still count above.
        spec = sharding.sharding_spec
        if spec is None:
            continue

        # Accumulate shard bytes per placement.
        for shard in spec.shards:
            device = shard.placement
            memory = shard.shard_sizes[0] * shard.shard_sizes[1] * bytes_per_element
            bytes_by_device[device] = bytes_by_device.get(device, 0) + memory

    return analysis