# 🔬 TPUsight Demo

**A comprehensive TPU profiler inspired by NVIDIA Nsight**

This notebook demonstrates the key features of TPUsight:

1. **Systolic Array Utilization** - MXU efficiency analysis
2. **Padding/Tiling Inefficiency** - Shape optimization
3. **Fusion Failure Explanations** - Why ops aren't fused
4. **Dynamic Shape + Cache Profiler** - JIT recompilation tracking
5. **Memory Traffic + Layout** - HBM bandwidth analysis
6. **TPU Doctor** - Actionable optimization suggestions


## Setup


In [None]:
# Install TPUsight (run once)
# !pip install -e ..

import sys
sys.path.insert(0, '..')


In [None]:
import jax
import jax.numpy as jnp
from jax import random

from tpusight import TPUsight

print(f"JAX version: {jax.__version__}")
print(f"Devices: {jax.devices()}")


## Basic Usage

Create a profiler and trace your JAX functions:


In [None]:
# Create profiler instance
profiler = TPUsight(session_name="demo")

print(profiler)


In [None]:
# Define some JAX functions to profile

@profiler.trace
def efficient_matmul(x, w):
    """Matmul with TPU-friendly dimensions (multiples of 128)."""
    return jnp.dot(x, w)

@profiler.trace
def inefficient_matmul(x, w):
    """Same matmul as above; it is inefficient only when called with shapes not aligned to 128."""
    return jnp.dot(x, w)

@profiler.trace
def mlp_layer(x, w1, w2, b1, b2):
    """Simple MLP layer with activation."""
    h = jnp.dot(x, w1) + b1
    h = jax.nn.gelu(h)
    return jnp.dot(h, w2) + b2


In [None]:
# Create test data with different shapes
key = random.PRNGKey(42)
# Split the key so each array gets its own independent random values
k_xg, k_wg, k_xb, k_wb, k_w1, k_w2 = random.split(key, 6)

# Efficient shapes (multiples of 128)
x_good = random.normal(k_xg, (256, 512))
w_good = random.normal(k_wg, (512, 256))

# Inefficient shapes (not aligned to 128)
x_bad = random.normal(k_xb, (100, 200))
w_bad = random.normal(k_wb, (200, 50))

# MLP weights
w1 = random.normal(k_w1, (512, 1024))
w2 = random.normal(k_w2, (1024, 512))
b1 = jnp.zeros(1024)
b2 = jnp.zeros(512)


In [None]:
# Run the profiled functions
print("Running efficient matmul...")
for _ in range(5):
    result1 = efficient_matmul(x_good, w_good)

print("Running inefficient matmul...")
for _ in range(5):
    result2 = inefficient_matmul(x_bad, w_bad)

print("Running MLP layer...")
for _ in range(3):
    result3 = mlp_layer(x_good, w1, w2, b1, b2)

print(f"\nProfiled {profiler.profile_data.total_ops} operations")
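
Each loop above calls its function several times with identical shapes. If the traced functions end up being JIT-compiled, JAX caches the compiled program per argument shape and dtype, so only the first call with a given signature should trigger a compilation. The sketch below illustrates that bookkeeping in plain Python. The helpers `signature` and `count_compilations` are hypothetical names used for illustration; they are not part of TPUsight or JAX, and the real cache key contains more than shapes and dtypes.

```python
from collections import Counter


def signature(*args):
    """Build a hashable cache key from (shape, dtype) pairs (illustrative only)."""
    return tuple((tuple(shape), str(dtype)) for shape, dtype in args)


def count_compilations(calls):
    """Return (compiles, cache_hits) for a sequence of call signatures."""
    seen = Counter(calls)
    compiles = len(seen)                        # one compile per distinct signature
    cache_hits = sum(seen.values()) - compiles  # every repeat is a cache hit
    return compiles, cache_hits


# Five calls with the aligned shapes, then one call with a new shape.
calls = [signature(((256, 512), "float32"), ((512, 256), "float32"))] * 5
calls.append(signature(((100, 200), "float32"), ((200, 50), "float32")))

compiles, hits = count_compilations(calls)
print(f"compilations: {compiles}, cache hits: {hits}")
```

Under these assumptions, six calls with two distinct signatures give two compilations and four cache hits. A workload whose shapes change on every call would instead recompile each time, which is the pattern a dynamic-shape profiler is meant to surface.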


## Interactive Dashboard

Launch the full interactive dashboard:


In [None]:
# Display the interactive dashboard
profiler.dashboard()


## Individual Analyzer Examples

You can also access each analyzer individually:

### Systolic Array Utilization


In [None]:
# Analyze MXU utilization
systolic_analysis = profiler.systolic.analyze()

if systolic_analysis['status'] == 'ok':
    metrics = systolic_analysis['metrics']
    print(f"Overall MXU Utilization: {metrics.overall_utilization:.1f}%")
    print(f"Total MatMul Operations: {metrics.total_matmul_ops}")
    print(f"Low Efficiency Operations: {metrics.low_util_ops}")
    print(f"Wasted FLOPS: {metrics.wasted_flops:,}")
    
    print("\nEfficiency Distribution:")
    for bucket, count in metrics.efficiency_buckets.items():
        print(f"  {bucket}: {count} ops")
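
To build intuition for the utilization figure, here is a minimal sketch of one way to estimate it: treat the MXU as a 128x128 systolic array, round each matmul dimension up to a multiple of 128, and compare useful multiply-accumulates with those issued after padding. The tile size of 128 is an assumption taken from the "multiples of 128" comments in this notebook (it differs between TPU generations), and `mxu_utilization_pct` is a hypothetical helper, not the formula TPUsight itself uses.

```python
MXU_TILE = 128  # assumed systolic array edge; varies by TPU generation


def round_up(value, multiple=MXU_TILE):
    """Round value up to the next multiple using integer arithmetic."""
    return ((value + multiple - 1) // multiple) * multiple


def mxu_utilization_pct(m, k, n, tile=MXU_TILE):
    """Useful multiply-accumulates as a percentage of those issued after padding."""
    useful = m * k * n
    padded = round_up(m, tile) * round_up(k, tile) * round_up(n, tile)
    return 100.0 * useful / padded


# Shapes from this notebook: (M, K) x (K, N)
good = mxu_utilization_pct(256, 512, 256)  # x_good @ w_good
bad = mxu_utilization_pct(100, 200, 50)    # x_bad @ w_bad
print(f"aligned: {good:.1f}%  unaligned: {bad:.1f}%")
```

With this simplified model the aligned matmul uses every padded slot, while the (100, 200) x (200, 50) case pads to (128, 256) x (256, 128) and does useful work in under a quarter of them. Real utilization also depends on memory bandwidth and scheduling, which this estimate ignores.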


### Padding Analysis


In [None]:
# Analyze padding inefficiency
padding_analysis = profiler.padding.analyze()

if padding_analysis['status'] == 'ok':
    metrics = padding_analysis['metrics']
    print(f"Total Padding Waste: {metrics.total_wasted_compute_pct:.1f}%")
    print(f"Critical Shapes (>30% waste): {metrics.critical_ops}")
    print(f"Warning Shapes (10-30% waste): {metrics.warning_ops}")
    
    print("\nWorst Shapes:")
    for op in metrics.worst_operations[:3]:
        print(f"  {op['name']}: {op['shape']} -> {op['waste_pct']:.1f}% waste")
        if op.get('recommendation'):
            print(f"    Suggestion: {op['recommendation']}")


In [None]:
# Get optimal shape suggestions for a specific tensor
suggestions = profiler.padding.suggest_optimal_shapes((100, 200))

print(f"Original shape: {suggestions['original']}")
print(f"Current waste: {suggestions['original_waste_pct']:.1f}%")
print("\nSuggestions:")
for s in suggestions['suggestions']:
    print(f"  {s['type']}: {s['shape']} - {s['description']}")
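
The waste percentage for a single tensor can be approximated by hand. The sketch below assumes every dimension is padded up to the next multiple of 128, as the comments in this notebook suggest; actual TPU tiling depends on dtype and hardware generation, and the names `pad_up` and `padding_waste_pct` are hypothetical rather than TPUsight functions.

```python
TILE = 128  # assumed alignment; real tiling depends on dtype and TPU generation


def pad_up(dim, tile=TILE):
    """Smallest multiple of tile that is >= dim."""
    return ((dim + tile - 1) // tile) * tile


def padding_waste_pct(shape, tile=TILE):
    """Percentage of padded elements that hold no real data."""
    actual = 1
    padded = 1
    for dim in shape:
        actual *= dim
        padded *= pad_up(dim, tile)
    return 100.0 * (padded - actual) / padded


shape = (100, 200)
padded_shape = tuple(pad_up(d) for d in shape)
waste = padding_waste_pct(shape)
print(f"{shape} -> {padded_shape}: {waste:.1f}% waste")
```

Under this model (100, 200) pads to (128, 256), so roughly 39% of the padded elements carry no data. Choosing dimensions that are already multiples of the tile size brings the figure to zero.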


### TPU Doctor - All Recommendations


In [None]:
# Get comprehensive diagnosis
diagnosis = profiler.doctor.diagnose()

print(f"TPU Health Score: {diagnosis['health_score']}/100 ({diagnosis['health_status']})")
print("\nIssues Found:")
print(f"  Critical: {diagnosis['critical_count']}")
print(f"  Warnings: {diagnosis['warning_count']}")
print(f"  Info: {diagnosis['info_count']}")

print("\n=== Top Recommendations ===")
for i, rec in enumerate(diagnosis['top_recommendations'][:5], 1):
    severity_emoji = {'critical': '🔴', 'warning': '🟡', 'info': '🔵'}.get(rec['severity'], '⚪')
    print(f"\n{i}. {severity_emoji} {rec['title']}")
    print(f"   {rec['message']}")
    print(f"   Impact: {rec['impact_estimate']}")


## Utility Functions

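TPUsight also exposes standalone helpers that work directly on shapes, so you can estimate padding waste and MXU utilization for a candidate shape: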


In [None]:
from tpusight.utils.helpers import (
    calculate_padding_waste,
    estimate_mxu_utilization,
    format_bytes,
    format_flops,
)

# Analyze padding for a shape
shape = (100, 200)
padding = calculate_padding_waste(shape)
print(f"Shape {shape}:")
print(f"  Padded to: {padding['padded_shape']}")
print(f"  Waste: {padding['wasted_compute_pct']:.1f}%")
print(f"  Recommendation: {padding['recommendation']}")


In [None]:
# Estimate MXU utilization for a matmul
# (M, K) x (K, N) = (M, N)
m, n, k = 100, 200, 150
mxu = estimate_mxu_utilization(m, n, k)

print(f"Matmul ({m}, {k}) x ({k}, {n}):")
print(f"  MXU Utilization: {mxu['mxu_utilization_pct']:.1f}%")
print(f"  Actual FLOPS: {format_flops(mxu['actual_flops'])}")
print(f"  Wasted FLOPS: {format_flops(mxu['wasted_flops'])}")
print(f"  Bottleneck: {mxu['bottleneck']}")

# Compare with optimal shape
m_opt, n_opt, k_opt = 128, 256, 128
mxu_opt = estimate_mxu_utilization(m_opt, n_opt, k_opt)

print(f"\nOptimal Matmul ({m_opt}, {k_opt}) x ({k_opt}, {n_opt}):")
print(f"  MXU Utilization: {mxu_opt['mxu_utilization_pct']:.1f}%")
print(f"  Wasted FLOPS: {format_flops(mxu_opt['wasted_flops'])}")


## Next Steps

1. **Profile your own models** - Use `@profiler.trace` or `with profiler.trace_context()`
2. **Check the dashboard** - `profiler.dashboard()` for interactive analysis
3. **Follow recommendations** - Address critical issues first
4. **Iterate** - Profile again after optimizations to measure improvement

For more information, see the [README](../README.md).
