# TorchRec Embedding Tables Overview

In [None]:
import torch
import torchrec
from torchrec import EmbeddingBagCollection
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
from utils.data_generators import TorchRecDataGenerator, DataConfig
from utils.visualization import TorchRecVisualizer
from utils.debugging import TorchRecDebugger
from utils.benchmark import TorchRecBenchmark

## Basic Embedding Table Setup

In [None]:
# Describe a single embedding table: 1000 rows of width 16 that serves the
# "feature1" feature and combines the looked-up rows with SUM pooling.
basic_table = EmbeddingBagConfig(
    pooling=torchrec.PoolingType.SUM,
    feature_names=["feature1"],
    num_embeddings=1000,
    embedding_dim=16,
    name="basic_table",
)

# Wrap the table in a collection. It is created on the "meta" device for
# memory efficiency; this collection is only used to illustrate the setup.
basic_ebc = EmbeddingBagCollection(device=torch.device("meta"), tables=[basic_table])

# Echo the configuration back, one field per line.
print(
    "Basic Table Configuration:",
    f"Name: {basic_table.name}",
    f"Embedding Dimension: {basic_table.embedding_dim}",
    f"Number of Embeddings: {basic_table.num_embeddings}",
    f"Features: {basic_table.feature_names}",
    sep="\n",
)

## Multi-Table Setup

In [None]:
# Three tables that differ in vocabulary size, embedding width and pooling.

# Large product vocabulary, summed per bag.
products_table = EmbeddingBagConfig(
    name="products",
    num_embeddings=100_000,
    embedding_dim=64,
    pooling=torchrec.PoolingType.SUM,
    feature_names=["product_id"],
)

# Small category vocabulary, averaged per bag.
categories_table = EmbeddingBagConfig(
    name="categories",
    num_embeddings=1_000,
    embedding_dim=32,
    pooling=torchrec.PoolingType.MEAN,
    feature_names=["category_id"],
)

# One table shared by two features: both "search_term" and "query_token"
# look up rows in the same embedding weights.
shared_table = EmbeddingBagConfig(
    name="shared_features",
    num_embeddings=10_000,
    embedding_dim=16,
    pooling=torchrec.PoolingType.SUM,
    feature_names=["search_term", "query_token"],
)

tables = [products_table, categories_table, shared_table]

# Collection over all three tables, created on the "meta" device.
multi_ebc = EmbeddingBagCollection(device=torch.device("meta"), tables=tables)

## Memory Planning

In [None]:
def calculate_memory_requirement(tables, *, bytes_per_param=4):
    """Estimate the parameter count and weight memory of embedding tables.

    Args:
        tables: Iterable of table configs. Each item must expose ``name``,
            ``num_embeddings`` and ``embedding_dim`` attributes.
        bytes_per_param: Storage size of one embedding element in bytes.
            Defaults to 4 (float32); pass 2 for half-precision weights, etc.

    Returns:
        A ``(memory_per_table, total_params)`` tuple. ``memory_per_table``
        maps each table name to ``{"parameters": int, "memory_mb": float}``
        (MB meaning 1024 * 1024 bytes); ``total_params`` is the parameter
        count summed over all tables. If two tables share a name, the later
        one overwrites the earlier entry but both count towards the total.

    Raises:
        ValueError: If ``bytes_per_param`` is not positive.
    """
    if bytes_per_param <= 0:
        raise ValueError(
            f"bytes_per_param must be positive, got {bytes_per_param!r}"
        )

    bytes_per_mb = 1024 * 1024
    memory_per_table = {}
    total_params = 0

    for table in tables:
        # One row per embedding id, one column per embedding dimension.
        params = table.num_embeddings * table.embedding_dim
        total_params += params
        memory_per_table[table.name] = {
            "parameters": params,
            "memory_mb": params * bytes_per_param / bytes_per_mb,
        }

    return memory_per_table, total_params

# Size every table in `tables`, then print a per-table breakdown followed by
# the overall parameter count.
memory_per_table, total_params = calculate_memory_requirement(tables)

print("\nMemory Requirements:")
for table_name in memory_per_table:
    info = memory_per_table[table_name]
    print(
        f"{table_name}:",
        f"  Parameters: {info['parameters']:,}",
        f"  Memory: {info['memory_mb']:.2f} MB",
        sep="\n",
    )
print(f"\nTotal Parameters: {total_params:,}")

## Working with Embeddings

In [None]:
# Pick the GPU when one is available, otherwise fall back to the CPU, and
# build a collection with real (non-meta) weights on that device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ebc = EmbeddingBagCollection(tables=tables, device=device)

# Generate sample input data
data_gen = TorchRecDataGenerator(DataConfig(
    num_users=100,
    num_products=100_000,
    max_sequence_length=10,
    batch_size=32,
))

# Create input features for every feature name served by `tables`.
# The project helper's result is read below through its "keys", "values"
# and "lengths" entries.
kjt_inputs = data_gen.generate_kjt_inputs([
    "product_id",
    "category_id",
    "search_term",
    "query_token",
])

# Both tensors are moved to `device`. Previously only `values` was moved, so
# on a CUDA machine the jagged tensor mixed GPU values with CPU lengths and
# the embedding lookup failed with a device mismatch.
kjt = KeyedJaggedTensor.from_lengths_sync(
    keys=kjt_inputs["keys"],
    values=kjt_inputs["values"].to(device),
    lengths=kjt_inputs["lengths"].to(device),
)

# Forward pass: returns the pooled embeddings keyed by feature name.
embeddings = ebc(kjt)

print("\nEmbedding Output Structure:")
print(f"Keys: {embeddings.keys()}")
print(f"Values shape: {embeddings.values().shape}")
print(f"Length per key: {embeddings.length_per_key()}")

## Embedding Analysis

In [None]:
visualizer = TorchRecVisualizer()

# Split the pooled output into one tensor per feature key a single time.
# The conversion does not depend on the loop variable, so rebuilding the
# dictionary on every iteration was redundant work.
emb_dict = embeddings.to_dict()

# Analyze embedding distributions: print summary statistics and plot the
# value distribution for each feature's embeddings.
for key in embeddings.keys():
    key_embeddings = emb_dict[key]
    print(f"\nAnalyzing {key} embeddings:")
    print(f"Mean: {key_embeddings.mean().item():.4f}")
    print(f"Std: {key_embeddings.std().item():.4f}")
    visualizer.plot_embedding_distribution(key_embeddings)

## Performance Benchmarking

In [None]:
# Benchmark harness from the project utilities, configured with
# warmup_steps=3 and measure_steps=10 (their exact semantics are defined by
# the helper — presumably untimed warm-up runs followed by timed runs).
benchmark = TorchRecBenchmark(measure_steps=10, warmup_steps=3)

# Benchmark the forward pass of the collection on the sample batch.
result = benchmark.benchmark_forward(batch_size=32, sample_input=kjt, model=ebc)

# Report the timing, memory and throughput fields of the result.
print("\nPerformance Metrics:")
print(f"Average batch time: {result.batch_time_ms:.2f}ms")
print(f"Memory used: {result.memory_used_gb:.2f}GB")
print(f"Throughput: {result.throughput:.2f} examples/sec")

## Advanced Features

In [None]:
# Example with custom initialization: weight_init_min / weight_init_max set
# the bounds used when the table's weights are initialized.
custom_init_table = EmbeddingBagConfig(
    name="custom_init",
    embedding_dim=32,
    num_embeddings=1000,
    feature_names=["custom_feature"],
    pooling=torchrec.PoolingType.SUM,
    weight_init_max=0.1,
    weight_init_min=-0.1,
)

# Example with different pooling types
pooling_examples = {
    "sum": torchrec.PoolingType.SUM,
    "mean": torchrec.PoolingType.MEAN,
}

for name, pooling_type in pooling_examples.items():
    feature_name = f"feature_{name}"

    table = EmbeddingBagConfig(
        name=f"pooling_{name}",
        embedding_dim=16,
        num_embeddings=100,
        feature_names=[feature_name],
        pooling=pooling_type,
    )

    ebc_pool = EmbeddingBagCollection(
        tables=[table],
        device=device,
    )

    # Create sample input: two bags for the single feature, holding ids
    # [1, 2] and [3, 4, 5]. Both tensors are created directly on `device`.
    # Previously `lengths` stayed on the CPU while `values` was moved, which
    # breaks the lookup on a CUDA machine with a device mismatch.
    sample_kjt = KeyedJaggedTensor.from_lengths_sync(
        keys=[feature_name],
        values=torch.tensor([1, 2, 3, 4, 5], device=device),
        lengths=torch.tensor([2, 3], device=device),
    )

    output = ebc_pool(sample_kjt)
    print(f"\n{name.upper()} Pooling Output:")
    print(output.values())

## Best Practices and Tips

In [None]:
# Checklist of recommendations, grouped by topic. Insertion order is the
# order in which the topics are printed.
best_practices = {
    "Memory Management": [
        "Use meta device for initialization",
        "Calculate memory requirements beforehand",
        "Consider sharing embeddings for related features",
    ],
    "Performance": [
        "Choose appropriate embedding dimensions",
        "Use efficient pooling types",
        "Batch similar operations",
    ],
    "Architecture": [
        "Group related features",
        "Plan for vocabulary size growth",
        "Consider embedding dimension carefully",
    ],
}

# Assemble the report line by line and emit it with a single print: a title,
# then each topic heading (preceded by a blank line) and its bullet points.
report_lines = ["\nBest Practices:"]
for category, practices in best_practices.items():
    report_lines.append(f"\n{category}:")
    report_lines.extend(f"- {practice}" for practice in practices)
print("\n".join(report_lines))

# Cleanup: when CUDA is in use, release the GPU memory cache.
if torch.cuda.is_available():
    torch.cuda.empty_cache()