# TorchRec Data Types Overview

In [None]:
import torch
import torchrec
from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor

# NOTE(review): DataConfig is used by the data-generation cells but was never
# imported; its home in utils.data_generators is assumed — confirm.
from utils.data_generators import DataConfig, TorchRecDataGenerator
from utils.visualization import TorchRecVisualizer
from utils.debugging import TorchRecDebugger

## JaggedTensor Basics

In [None]:
# Build a JaggedTensor holding 4 variable-length sequences. The lengths
# [2, 0, 3, 2] partition the 7 flat values as [1, 2] | [] | [3, 4, 5] | [6, 7].
values = torch.tensor([1, 2, 3, 4, 5, 6, 7])
lengths = torch.tensor([2, 0, 3, 2])  # 4 sequences of different lengths
jt = JaggedTensor(values=values, lengths=lengths)

# Inspect the flat values, the per-sequence lengths, and the offsets that
# JaggedTensor derives from those lengths.
print("JaggedTensor Structure:")
for label, getter in (("Values", jt.values), ("Lengths", jt.lengths), ("Offsets", jt.offsets)):
    print(f"{label}: {getter()}")

### Working with JaggedTensor

In [None]:
# to_dense() yields a list with one tensor per sequence (not a single 2-D tensor).
dense_form = jt.to_dense()
print("\nDense representation:")
for seq_idx, seq_tensor in enumerate(dense_form):
    print(f"Sequence {seq_idx}: {seq_tensor}")

# to_padded_dense() packs the sequences into one 2-D tensor, filling each row
# with `padding_value` up to `desired_length` columns. A length of 3 matches
# the longest sequence in this example.
pad_length, pad_value = 3, 0
padded = jt.to_padded_dense(desired_length=pad_length, padding_value=pad_value)
print("\nPadded dense representation:")
print(padded)

## KeyedJaggedTensor Basics

In [None]:
# Two features over a batch of 2 samples. Each `lengths` entry is how many
# values one sample contributes to that feature:
#   product_id : [1, 2] | [3, 4, 5]
#   category_id: [10, 20] | [30]
product_values = torch.tensor([1, 2, 3, 4, 5])
product_lengths = torch.tensor([2, 3])
category_values = torch.tensor([10, 20, 30])
category_lengths = torch.tensor([2, 1])

# Wrap each feature in its own JaggedTensor, then key them by feature name.
product_jt = JaggedTensor(values=product_values, lengths=product_lengths)
category_jt = JaggedTensor(values=category_values, lengths=category_lengths)
feature_jts = {"product_id": product_jt, "category_id": category_jt}
kjt = KeyedJaggedTensor.from_jt_dict(feature_jts)

print("\nKeyedJaggedTensor Structure:")
for label, getter in (("Keys", kjt.keys), ("Values", kjt.values), ("Lengths", kjt.lengths)):
    print(f"{label}: {getter()}")

### KeyedJaggedTensor Operations

In [None]:
# Split the KeyedJaggedTensor back into one JaggedTensor per feature key.
jt_dict = kjt.to_dict()

print("\nIndividual features:")
# Use a distinct loop name so the `jt` JaggedTensor defined in the earlier
# cell is not silently overwritten by the last feature in this loop.
for key, feature_jt in jt_dict.items():
    print(f"\n{key}:")
    print(f"Values: {feature_jt.values()}")
    print(f"Lengths: {feature_jt.lengths()}")

## Real-World Example: User Interaction Data

In [None]:
# Configure a small synthetic-data generator that produces 3 samples per batch.
data_config = DataConfig(
    num_users=1000,
    num_products=10000,
    max_sequence_length=5,
    min_sequence_length=1,
    batch_size=3,
)

data_gen = TorchRecDataGenerator(data_config)

# Generate one batch of user interaction data for three sparse features.
# NOTE(review): the result is assumed to be a dict with "keys", "values" and
# "lengths" entries (the only fields read below); confirm against the generator.
interaction_data = data_gen.generate_kjt_inputs([
    "product_history",
    "category_history",
    "search_terms",
])

# Create the KJT from the flat values plus per-(key, sample) lengths.
interaction_kjt = KeyedJaggedTensor.from_lengths_sync(
    keys=interaction_data["keys"],
    values=interaction_data["values"],
    lengths=interaction_data["lengths"],
)

print("\nUser Interaction Batch:")
print(f"Number of features: {len(interaction_kjt.keys())}")
# stride() is the KJT's per-key batch size. For a KJT built from lengths
# without an explicit stride, it equals len(lengths) // len(keys), and it
# does not divide by zero when there are no keys.
print(f"Batch size: {interaction_kjt.stride()}")
print(f"Total values: {len(interaction_kjt.values())}")

## Data Validation and Debugging

In [None]:
debugger = TorchRecDebugger()

# Run the debugger's structural checks on the interaction KJT.
# NOTE(review): the result is iterated as {check_name: truthy/falsy}; confirm.
validation_results = debugger.validate_kjt(interaction_kjt)
print("\nValidation Results:")
for check_name, passed in validation_results.items():
    marker = '✅' if passed else '❌'
    print(f"{check_name}: {marker}")

## Device Management

In [None]:
# Only demonstrate the device move when a CUDA device is present;
# otherwise this cell prints nothing.
cuda_available = torch.cuda.is_available()
if cuda_available:
    # KeyedJaggedTensor.to() returns a copy of the KJT on the target device.
    gpu_kjt = interaction_kjt.to(torch.device("cuda"))

    print("\nDevice Location:")
    for label, tensor_view in (("Values device", gpu_kjt.values()), ("Lengths device", gpu_kjt.lengths())):
        print(f"{label}: {tensor_view.device}")

## Performance Considerations

In [None]:
# Build a larger batch so the effect on GPU memory is visible.
large_config = DataConfig(
    num_users=10000,
    num_products=100000,
    max_sequence_length=50,
    batch_size=32,
)

large_data_gen = TorchRecDataGenerator(large_config)
large_batch = large_data_gen.generate_kjt_inputs(["product_history"])

if torch.cuda.is_available():
    # Memory usage before the batch is placed on the GPU.
    print("\nMemory Usage:")
    print("Before large batch:")
    print(f"Allocated: {torch.cuda.memory_allocated() / 1e6:.2f}MB")

    # Move BOTH values and lengths to the GPU. Moving only the values would
    # produce a KJT whose tensors live on different devices, which breaks
    # downstream ops that combine them.
    large_kjt = KeyedJaggedTensor.from_lengths_sync(
        keys=large_batch["keys"],
        values=large_batch["values"].cuda(),
        lengths=large_batch["lengths"].cuda(),
    )

    print("\nAfter large batch:")
    print(f"Allocated: {torch.cuda.memory_allocated() / 1e6:.2f}MB")

    # Clean up. Drop the reference first: memory held by a live tensor cannot
    # be reclaimed. NOTE(review): clear_memory() presumably frees cached CUDA
    # memory (e.g. torch.cuda.empty_cache()); confirm.
    del large_kjt
    debugger.clear_memory()

## Visualization

In [None]:
visualizer = TorchRecVisualizer()

# Draw 100 random embedding vectors of dimension 16 from a standard normal
# and hand them to the visualizer.
num_samples, embedding_dim = 100, 16
sample_embeddings = torch.randn(num_samples, embedding_dim)
visualizer.plot_embedding_distribution(sample_embeddings)

## Summary of Operations

In [None]:
# Quick-reference table of the APIs covered in this notebook, keyed by type
# name and then by category (creation methods, operations, use cases).
summary = dict(
    JaggedTensor=dict(
        creation_methods=["direct", "from_dense", "from_lengths"],
        operations=["to_dense", "to_padded_dense", "lengths", "offsets"],
        use_cases=["variable length sequences", "sparse features"],
    ),
    KeyedJaggedTensor=dict(
        creation_methods=["from_jt_dict", "from_lengths_sync", "from_offsets_sync"],
        operations=["to_dict", "keys", "values", "lengths"],
        use_cases=["multi-feature data", "batch processing", "embedding lookups"],
    ),
)

print("\nData Types Summary:")
for dtype_name in summary:
    details = summary[dtype_name]
    print(f"\n{dtype_name}:")
    for category in details:
        print(f"  {category}: {details[category]}")