# 🔬 TPUsight Demo

**A comprehensive TPU profiler inspired by NVIDIA Nsight**

This notebook demonstrates the key features of TPUsight:

1. **Systolic Array Utilization** - MXU efficiency analysis
2. **Padding/Tiling Inefficiency** - Shape optimization
3. **Fusion Failure Explanations** - Why ops aren't fused
4. **Dynamic Shape + Cache Profiler** - JIT recompilation tracking
5. **Memory Traffic + Layout** - HBM bandwidth analysis
6. **TPU Doctor** - Actionable optimization suggestions
7. **Time Breakdown** - Compute vs memory vs compilation time
8. **Live Profiling** - Real-time monitoring with alerts


## Setup


In [None]:
# Install TPUsight (run once)
# !pip install -e ..

import sys
sys.path.insert(0, '..')


In [1]:
import jax
import jax.numpy as jnp
from jax import random

from tpusight import TPUsight

print(f"JAX version: {jax.__version__}")
print(f"Devices: {jax.devices()}")




JAX version: 0.7.2
Devices: [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)]


## Basic Usage

Create a profiler and trace your JAX functions:


In [2]:
# Create profiler instance
profiler = TPUsight(session_name="demo")

print(profiler)


TPUsight(session='demo', device='tpu', ops=0)


In [4]:
# Define some JAX functions to profile

@profiler.trace
def efficient_matmul(x, w):
    """Matmul that will be called with TPU-friendly shapes (all dims multiples of 128)."""
    return jnp.dot(x, w)

@profiler.trace
def inefficient_matmul(x, w):
    """Identical matmul; it will be called with shapes that are not multiples of 128."""
    return jnp.dot(x, w)

@profiler.trace
def mlp_layer(x, w1, w2, b1, b2):
    """Simple MLP layer with activation."""
    h = jnp.dot(x, w1) + b1
    h = jax.nn.gelu(h)
    return jnp.dot(h, w2) + b2
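`@profiler.trace` wraps each function so that every call can be recorded. This notebook does not show how TPUsight does this internally. As a rough, hypothetical sketch (`SimpleTracer` and its `records` list are invented names), a tracing decorator can be as small as the following, which logs the function name and input shapes on every call. One record per call would be consistent with the 13 operations reported later for the 5 + 5 + 3 calls in this notebook.

```python
import functools


class SimpleTracer:
    """Hypothetical minimal tracer: records the function name and input shapes per call."""

    def __init__(self):
        self.records = []

    def trace(self, fn):
        @functools.wraps(fn)  # keep fn.__name__ and the docstring on the wrapper
        def wrapper(*args, **kwargs):
            # Array-likes (NumPy arrays, jax.Array) expose `.shape`; other args record None.
            shapes = [getattr(a, "shape", None) for a in args]
            self.records.append({"name": fn.__name__, "shapes": shapes})
            return fn(*args, **kwargs)

        return wrapper

    @property
    def total_ops(self):
        return len(self.records)


tracer = SimpleTracer()


@tracer.trace
def scale(x, factor):
    return x * factor


for _ in range(3):
    scale(2.0, 3)

print(tracer.total_ops)   # 3
print(tracer.records[0])  # {'name': 'scale', 'shapes': [None, None]}
```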


In [9]:
# Create test data with different shapes
key = random.PRNGKey(42)
# Split the key so each tensor gets independent random values (never reuse a JAX key)
k1, k2, k3, k4, k5, k6 = random.split(key, 6)

# Efficient shapes (multiples of 128)
x_good = random.normal(k1, (256, 512))
w_good = random.normal(k2, (512, 256))

# Inefficient shapes (not aligned to 128)
x_bad = random.normal(k3, (100, 200))
w_bad = random.normal(k4, (200, 50))

# MLP weights
w1 = random.normal(k5, (512, 1024))
w2 = random.normal(k6, (1024, 512))
b1 = jnp.zeros(1024)
b2 = jnp.zeros(512)


In [6]:
# Run the profiled functions
print("Running efficient matmul...")
for _ in range(5):
    result1 = efficient_matmul(x_good, w_good)

print("Running inefficient matmul...")
for _ in range(5):
    result2 = inefficient_matmul(x_bad, w_bad)

print("Running MLP layer...")
for _ in range(3):
    result3 = mlp_layer(x_good, w1, w2, b1, b2)

print(f"\nProfiled {profiler.profile_data.total_ops} operations")


Running efficient matmul...
Running inefficient matmul...
Running MLP layer...

Profiled 13 operations


## Interactive Dashboard

Export a static HTML report and display it inline, or launch the full interactive dashboard:


In [None]:
from IPython.display import HTML, display

# Generate and display the report inline
profiler.export("tpu_report.html", format="html")

with open("tpu_report.html", "r", encoding="utf-8") as f:
    html_content = f.read()

display(HTML(html_content))

In [None]:
# Display the interactive dashboard
profiler.dashboard()


## Individual Analyzer Examples

You can also access each analyzer individually:

### Systolic Array Utilization


In [7]:
# Analyze MXU utilization
systolic_analysis = profiler.systolic.analyze()

if systolic_analysis['status'] == 'ok':
    metrics = systolic_analysis['metrics']
    print(f"Overall MXU Utilization: {metrics.overall_utilization:.1f}%")
    print(f"Total MatMul Operations: {metrics.total_matmul_ops}")
    print(f"Low Efficiency Operations: {metrics.low_util_ops}")
    print(f"Wasted FLOPS: {metrics.wasted_flops:,}")
    
    print("\nEfficiency Distribution:")
    for bucket, count in metrics.efficiency_buckets.items():
        print(f"  {bucket}: {count} ops")


Overall MXU Utilization: 70.7%
Total MatMul Operations: 13
Low Efficiency Operations: 5
Wasted FLOPS: 31,943,040

Efficiency Distribution:
  90-100%: 8 ops
  70-90%: 0 ops
  50-70%: 0 ops
  30-50%: 0 ops
  0-30%: 5 ops
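The figures above can be reproduced with a simple model, assuming TPUsight pads every matmul dimension (M, K, N) up to the next multiple of 128 (one 128x128 MXU tile) and counts 2 FLOPs per multiply-accumulate. This is inferred from the outputs in this notebook, not documented behaviour. Under it, each (100, 200) x (200, 50) call uses only 23.8% of the padded work, the aligned calls use 100%, and the mean over 8 aligned and 5 unaligned ops is 70.7%. `pad_up` and `matmul_efficiency` are hypothetical helpers:

```python
import math

TILE = 128  # assumed MXU tile edge; every matmul dim is padded to a multiple of this


def pad_up(dim, tile=TILE):
    """Round a dimension up to the next multiple of `tile`."""
    return math.ceil(dim / tile) * tile


def matmul_efficiency(m, k, n, tile=TILE):
    """Return (utilization, wasted_flops) for an (m, k) x (k, n) matmul."""
    useful_macs = m * k * n
    padded_macs = pad_up(m, tile) * pad_up(k, tile) * pad_up(n, tile)
    # 2 FLOPs (multiply + add) per multiply-accumulate
    return useful_macs / padded_macs, 2 * (padded_macs - useful_macs)


# (m, k, n, number of calls) for the ops profiled in this notebook
workload = [
    (256, 512, 256, 5),   # efficient_matmul
    (256, 512, 1024, 3),  # mlp_layer (assumed: one aligned op per call)
    (100, 200, 50, 5),    # inefficient_matmul
]

per_op = []
total_wasted = 0
for m, k, n, calls in workload:
    util, wasted = matmul_efficiency(m, k, n)
    per_op += [util] * calls
    total_wasted += wasted * calls

overall = 100 * sum(per_op) / len(per_op)
print(f"Overall MXU utilization: {overall:.1f}%")  # 70.7%
print(f"Wasted FLOPs: {total_wasted:,}")            # 31,943,040
```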


### Padding Analysis


In [8]:
# Analyze padding inefficiency
padding_analysis = profiler.padding.analyze()

if padding_analysis['status'] == 'ok':
    metrics = padding_analysis['metrics']
    print(f"Average Padding Waste: {metrics.total_wasted_compute_pct:.1f}%")
    print(f"Critical Shapes (>30% waste): {metrics.critical_ops}")
    print(f"Warning Shapes (10-30% waste): {metrics.warning_ops}")
    
    print("\nWorst Shapes:")
    for op in metrics.worst_operations[:3]:
        print(f"  {op['name']}: {op['shape']} -> {op['waste_pct']:.1f}% waste")
        if op.get('recommendation'):
            print(f"    Suggestion: {op['recommendation']}")


Average Padding Waste: 26.7%
Critical Shapes (>30% waste): 5

Worst Shapes:
  inefficient_matmul: (100, 50) -> 69.5% waste
    Suggestion: Consider reshaping to (128, 0) to reduce padding waste from 69.5% to ~0%
  inefficient_matmul: (100, 50) -> 69.5% waste
    Suggestion: Consider reshaping to (128, 0) to reduce padding waste from 69.5% to ~0%
  inefficient_matmul: (100, 50) -> 69.5% waste
    Suggestion: Consider reshaping to (128, 0) to reduce padding waste from 69.5% to ~0%
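These waste figures can be reproduced by assuming each dimension is padded up to a multiple of 128 (inferred from the numbers in this notebook, not documented TPUsight behaviour). Under that assumption the (100, 50) output of `inefficient_matmul` fills 5,000 of 16,384 padded elements, and averaging 5 such outputs with 8 aligned ones gives the 26.7% above. Rounding 50 up to a multiple of 128 gives 128, so the natural target is `(128, 128)`; a second dimension of 0, as in the suggestion text, cannot be a valid shape. `padding_waste_pct` is a hypothetical helper:

```python
import math


def padding_waste_pct(shape, tile=128):
    """Return (padded_shape, percent of the padded buffer that is padding)."""
    padded = tuple(math.ceil(d / tile) * tile for d in shape)
    real = math.prod(shape)
    return padded, 100.0 * (1 - real / math.prod(padded))


padded, waste = padding_waste_pct((100, 50))
print(padded, f"{waste:.1f}%")  # (128, 128) 69.5%

# 8 aligned outputs such as (256, 256) waste nothing; 5 outputs have shape (100, 50)
wastes = [padding_waste_pct((256, 256))[1]] * 8 + [waste] * 5
print(f"Average: {sum(wastes) / len(wastes):.1f}%")  # 26.7%
```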


In [22]:
# Get optimal shape suggestions for a specific tensor
suggestions = profiler.padding.suggest_optimal_shapes((100, 200))

print(f"Original shape: {suggestions['original']}")
print(f"Current waste: {suggestions['original_waste_pct']:.1f}%")
print("\nSuggestions:")
for s in suggestions['suggestions']:
    print(f"  {s['type']}: {s['shape']} - {s['description']}")


Original shape: (100, 200)
Current waste: 39.0%

Suggestions:
  pad_up: (128, 256) - Pad to (128, 256) - next multiple of 128
  tile_64: (128, 256) - Align to 64-element tiles: (128, 256)
  tile_32: (128, 224) - Align to 32-element tiles: (128, 224)
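The three suggestions appear to differ only in alignment granularity: rounding each dimension up to a multiple of 128, 64 or 32 reproduces all three shapes. A sketch under that assumption (`align` is a hypothetical helper, not TPUsight's implementation):

```python
import math


def align(shape, tile):
    """Round each dimension of `shape` up to a multiple of `tile`."""
    return tuple(math.ceil(d / tile) * tile for d in shape)


shape = (100, 200)
for name, tile in [("pad_up", 128), ("tile_64", 64), ("tile_32", 32)]:
    print(name, align(shape, tile))
# pad_up (128, 256)
# tile_64 (128, 256)
# tile_32 (128, 224)
```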


### TPU Doctor - All Recommendations


In [23]:
# Get comprehensive diagnosis
diagnosis = profiler.doctor.diagnose()

print(f"TPU Health Score: {diagnosis['health_score']}/100 ({diagnosis['health_status']})")
print(f"\nIssues Found:")
print(f"  Critical: {diagnosis['critical_count']}")
print(f"  Warnings: {diagnosis['warning_count']}")
print(f"  Info: {diagnosis['info_count']}")

print("\n=== Top Recommendations ===")
for i, rec in enumerate(diagnosis['top_recommendations'][:5], 1):
    severity_emoji = {'critical': '🔴', 'warning': '🟡', 'info': '🔵'}.get(rec['severity'], '⚪')
    print(f"\n{i}. {severity_emoji} {rec['title']}")
    print(f"   {rec['message']}")
    print(f"   Impact: {rec['impact_estimate']}")


TPU Health Score: 39/100 (needs_attention)

Issues Found:
  Critical: 2
  Info: 1

=== Top Recommendations ===

1. 🔴 Fusion Opportunity
   Fusion rate is very low (0.0%)
   Impact: 5-15% speedup

2. 🔴 Padding Inefficiency
   High padding overhead (26.7% average waste)
   Impact: ~high compute savings

3. 🟡 Compilation Cache Issue
   Cache hit rate is below optimal (76.9%)
   Impact: Minor compilation overhead

4. 🟡 Memory Issue
   5 operations are memory-bound
   Impact: high

5. 🟡 MXU Underutilization
   5/13 operations have <50% utilization
   Impact: 10-30% potential speedup


## Utility Functions

TPUsight also exposes standalone helpers for shape analysis and for formatting byte and FLOP counts:


In [9]:
from tpusight.utils.helpers import (
    calculate_padding_waste,
    estimate_mxu_utilization,
    format_bytes,
    format_flops,
)

# Analyze padding for a shape
shape = (100, 200)
padding = calculate_padding_waste(shape)
print(f"Shape {shape}:")
print(f"  Padded to: {padding['padded_shape']}")
print(f"  Waste: {padding['wasted_compute_pct']:.1f}%")
print(f"  Recommendation: {padding['recommendation']}")


Shape (100, 200):
  Padded to: (128, 256)
  Waste: 39.0%
  Recommendation: Consider reshaping to (128, 256) to reduce padding waste from 39.0% to ~0%


In [10]:
# Estimate MXU utilization for a matmul
# (M, K) x (K, N) = (M, N)
m, n, k = 100, 200, 150
mxu = estimate_mxu_utilization(m, n, k)

print(f"Matmul ({m}, {k}) x ({k}, {n}):")
print(f"  MXU Utilization: {mxu['mxu_utilization_pct']:.1f}%")
print(f"  Actual FLOPS: {format_flops(mxu['actual_flops'])}")
print(f"  Wasted FLOPS: {format_flops(mxu['wasted_flops'])}")
print(f"  Bottleneck: {mxu['bottleneck']}")

# Compare with optimal shape
m_opt, n_opt, k_opt = 128, 256, 128
mxu_opt = estimate_mxu_utilization(m_opt, n_opt, k_opt)

print(f"\nOptimal Matmul ({m_opt}, {k_opt}) x ({k_opt}, {n_opt}):")
print(f"  MXU Utilization: {mxu_opt['mxu_utilization_pct']:.1f}%")
print(f"  Wasted FLOPS: {format_flops(mxu_opt['wasted_flops'])}")


Matmul (100, 150) x (150, 200):
  MXU Utilization: 35.8%
  Actual FLOPS: 6.00 MFLOPS
  Wasted FLOPS: 10.78 MFLOPS
  Bottleneck: None

Optimal Matmul (128, 128) x (128, 256):
  MXU Utilization: 100.0%
  Wasted FLOPS: 0.00 FLOPS
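These numbers are consistent with padding each of M, K and N up to a multiple of 128 and counting 2 FLOPs per multiply-accumulate (an assumption inferred from this output). For the (100, 150) x (150, 200) case:

$$
\begin{aligned}
U &= \frac{MKN}{128\lceil M/128\rceil \cdot 128\lceil K/128\rceil \cdot 128\lceil N/128\rceil}
   = \frac{100 \cdot 150 \cdot 200}{128 \cdot 256 \cdot 256}
   = \frac{3{,}000{,}000}{8{,}388{,}608} \approx 35.8\% \\
\text{actual FLOPs} &= 2MKN = 6{,}000{,}000 = 6.00\ \text{MFLOPs} \\
\text{wasted FLOPs} &= 2\,(8{,}388{,}608 - 3{,}000{,}000) = 10{,}777{,}216 \approx 10.78\ \text{MFLOPs}
\end{aligned}
$$

For (128, 128) x (128, 256) every dimension is already a multiple of 128, so the denominator equals MKN and U = 100%.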


## Time Breakdown Analysis

See where time is spent: compute, memory wait, rematerialization, and compilation.


In [11]:
# Print formatted time breakdown
profiler.time_breakdown.print_breakdown()


No timing data available
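When no timing data is recorded, you can still measure wall-clock time yourself. Two JAX details matter: dispatch is asynchronous, so a timer must wait on the result (`block_until_ready()` on a `jax.Array`) or it only measures dispatch; and the first call of a jitted function includes tracing and compilation. A minimal sketch that reports the first call separately from steady-state calls (`time_call` is a hypothetical helper, not part of TPUsight):

```python
import statistics
import time


def time_call(fn, *args, repeats=5):
    """Return (first_call_ms, median_steady_ms) for fn(*args).

    The first call of a jax.jit-compiled function includes compilation,
    so it is reported separately from the steady-state calls.
    """

    def run_once():
        start = time.perf_counter()
        out = fn(*args)
        # jax.Array results are computed asynchronously; block so the timer covers the work.
        if hasattr(out, "block_until_ready"):
            out.block_until_ready()
        return (time.perf_counter() - start) * 1e3

    first = run_once()
    steady = [run_once() for _ in range(repeats)]
    return first, statistics.median(steady)


# In this notebook you might call e.g. time_call(jax.jit(jnp.dot), x_good, w_good);
# a plain-Python callable keeps the sketch runnable anywhere.
first_ms, steady_ms = time_call(sum, range(100_000))
print(f"first call: {first_ms:.3f} ms, steady state: {steady_ms:.3f} ms")
```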


In [12]:
# Access detailed breakdown programmatically
time_analysis = profiler.time_breakdown.analyze()

if time_analysis['status'] == 'ok':
    breakdown = time_analysis['breakdown']
    pct = time_analysis['percentages']
    
    print(f"Total time: {breakdown.total_time_ms:.2f} ms")
    print("\nTime breakdown:")
    print(f"  🟢 Compute:           {pct['compute']:.1f}%")
    print(f"  🔴 Memory Wait:       {pct['memory_wait']:.1f}%")
    print(f"  🟣 Rematerialization: {pct['rematerialization']:.1f}%")
    print(f"  🟡 Compilation:       {pct['compilation']:.1f}%")
    
    print(f"\nBottleneck: {time_analysis['bottleneck']}")
    print(f"  {time_analysis['bottleneck_description']}")


## Live Profiling Mode

Real-time monitoring with live alerts and auto-updating metrics.


In [4]:
from tpusight import LiveProfiler

# Create live profiler with custom alert thresholds
live = LiveProfiler(
    alert_thresholds={
        "mxu_utilization_warning": 60.0,  # Alert if MXU < 60%
        "padding_waste_high": 25.0,       # Alert if padding > 25%
    }
)

# Register alert callback - get notified immediately!
@live.on_alert
def handle_alert(alert):
    icon = {"critical": "🔴", "warning": "🟡", "info": "🔵"}.get(alert.severity, "⚪")
    print(f"{icon} ALERT: {alert.message} (op: {alert.operation})")

print("Live profiler created with custom thresholds")


Live profiler created with custom thresholds
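`alert_thresholds` and `@live.on_alert` follow a common pattern: compare each op's metrics against thresholds and pass every breach to the registered callbacks. A minimal, hypothetical sketch of that pattern (`ThresholdAlerter` and the `Alert` fields are invented here, modelled on the attributes used in `handle_alert`). The rule that anything below half the warning level is critical is an arbitrary illustrative choice, not TPUsight's rule:

```python
from dataclasses import dataclass


@dataclass
class Alert:
    severity: str
    message: str
    operation: str


class ThresholdAlerter:
    """Hypothetical sketch: checks per-op metrics against thresholds and notifies callbacks."""

    def __init__(self, mxu_warning=60.0, padding_high=25.0):
        self.mxu_warning = mxu_warning
        self.padding_high = padding_high
        self._callbacks = []
        self.alerts = []

    def on_alert(self, fn):
        # Returning fn lets this be used as a decorator, like @live.on_alert.
        self._callbacks.append(fn)
        return fn

    def check(self, operation, mxu_pct, padding_pct):
        new = []
        if mxu_pct < self.mxu_warning:
            # Illustrative rule: below half the warning level counts as critical.
            severity = "critical" if mxu_pct < self.mxu_warning / 2 else "warning"
            new.append(Alert(severity, f"Low MXU utilization: {mxu_pct:.1f}%", operation))
        if padding_pct > self.padding_high:
            new.append(Alert("warning", f"High padding waste: {padding_pct:.1f}%", operation))
        for alert in new:
            self.alerts.append(alert)
            for callback in self._callbacks:
                callback(alert)
        return new


alerter = ThresholdAlerter()
seen = []


@alerter.on_alert
def collect(alert):
    seen.append(alert)


alerter.check("aligned_matmul", mxu_pct=100.0, padding_pct=0.0)   # no alerts
alerter.check("unaligned_matmul", mxu_pct=23.8, padding_pct=69.5)  # two alerts
print([(a.severity, a.message) for a in seen])
```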


In [5]:
# Define functions with live tracing
@live.trace
def live_efficient_matmul(x, w):
    return jnp.dot(x, w)

@live.trace  
def live_inefficient_matmul(x, w):
    return jnp.dot(x, w)

# Start live profiling
live.start()


🔴 Live profiling started (session: live_13d70d9f)


In [12]:
# Run operations - alerts fire in real-time!
print("Running efficient operations...")
for i in range(5):
    _ = live_efficient_matmul(x_good, w_good)

print("\nRunning inefficient operations (watch for alerts!)...")
for i in range(5):
    _ = live_inefficient_matmul(x_bad, w_bad)

# Check current metrics
metrics = live.get_current_metrics()
print(f"\n📊 Live Metrics:")
print(f"  Total ops: {metrics.total_ops}")
print(f"  Ops/sec: {metrics.ops_per_second:.1f}")
print(f"  MXU util: {metrics.mxu_utilization:.1f}%")


Running efficient operations...

Running inefficient operations (watch for alerts!)...
🔴 ALERT: Very low MXU utilization: 23.8% (op: live_inefficient_matmul)
🟡 ALERT: High padding waste: 69.5% (op: live_inefficient_matmul)
🔴 ALERT: Very low MXU utilization: 23.8% (op: live_inefficient_matmul)
🟡 ALERT: High padding waste: 69.5% (op: live_inefficient_matmul)
🔴 ALERT: Very low MXU utilization: 23.8% (op: live_inefficient_matmul)
🟡 ALERT: High padding waste: 69.5% (op: live_inefficient_matmul)
🔴 ALERT: Very low MXU utilization: 23.8% (op: live_inefficient_matmul)
🟡 ALERT: High padding waste: 69.5% (op: live_inefficient_matmul)
🔴 ALERT: Very low MXU utilization: 23.8% (op: live_inefficient_matmul)
🟡 ALERT: High padding waste: 69.5% (op: live_inefficient_matmul)

📊 Live Metrics:
  Total ops: 0
  Ops/sec: 0.0
  MXU util: 0.0%
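An ops-per-second metric is typically a count of events over a recent time window. A hypothetical sketch with `collections.deque` (`RateMeter` is not a TPUsight class, and how `LiveProfiler` actually computes `ops_per_second` is not shown in this notebook). Timestamps are passed in explicitly to keep the example deterministic; live code would use `time.monotonic()`:

```python
from collections import deque


class RateMeter:
    """Hypothetical sliding-window counter: events per second over the last `window_s` seconds."""

    def __init__(self, window_s=5.0):
        self.window_s = window_s
        self._stamps = deque()

    def record(self, t):
        self._stamps.append(t)
        self._evict(t)

    def rate(self, now):
        self._evict(now)
        return len(self._stamps) / self.window_s

    def _evict(self, now):
        # Drop events that fall outside the window ending at `now`
        while self._stamps and self._stamps[0] <= now - self.window_s:
            self._stamps.popleft()


meter = RateMeter(window_s=5.0)
for i in range(10):
    meter.record(t=i * 0.5)  # 10 events, one every 0.5 s, at t = 0.0 .. 4.5

recent = meter.rate(now=4.5)  # 10 events / 5 s = 2.0 ops/sec
later = meter.rate(now=20.0)  # 0.0: every event has aged out of the window
print(recent, later)
```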


In [16]:
# View alert summary
alerts = live.get_recent_alerts(10)
alert_counts = live.get_alert_counts()

print(f"📋 Alert Summary:")
print(f"  Total alerts: {len(alerts)}")
for category, count in alert_counts.items():
    print(f"    {category}: {count}")

print(f"\n🚨 Recent Alerts:")
for alert in alerts[-5:]:
    print(f"  [{alert.severity}] {alert.message}")


📋 Alert Summary:
  Total alerts: 10
    mxu_utilization: 5
    padding: 5

🚨 Recent Alerts:
  [critical] Very low MXU utilization: 23.8%
  [critical] Very low MXU utilization: 23.8%


In [17]:
# Stop live profiling
live.stop()

# The collected data is compatible with TPUsight for full analysis
print(f"\nCollected {live.profile_data.total_ops} operations during live session")


⏹️  Live profiling stopped. Captured 10 operations, 10 alerts

Collected 10 operations during live session


### Live Dashboard (Works in Colab/Cursor!)

A simple HTML-based live dashboard that updates in real time, with no ipywidgets required:


In [7]:
from tpusight.visualization.live_dashboard import SimpleLiveDashboard
import time

# Create a new live profiler for dashboard demo
live_dash = LiveProfiler()

@live_dash.trace
def dash_matmul(x, w):
    return jnp.dot(x, w)

# Create the simple dashboard (works without ipywidgets!)
dashboard = SimpleLiveDashboard(live_dash)

# Start profiling and dashboard
live_dash.start()
dashboard.start(update_interval=0.5)  # Updates every 0.5 seconds

print("Dashboard started! Run the next cell to generate operations...")


🔴 Live profiling started (session: live_2dbdb069)


Dashboard started! Run the next cell to generate operations...


In [10]:
# Run operations - watch the dashboard update in real-time!
print("Running operations... watch the dashboard above update!")

for i in range(20):
    # Mix of efficient and inefficient operations
    if i % 3 == 0:
        _ = dash_matmul(x_bad, w_bad)  # Inefficient - will trigger alerts
    else:
        _ = dash_matmul(x_good, w_good)  # Efficient
    time.sleep(0.2)  # Slow down so you can see updates

print("Done! Check the dashboard for results.")


Running operations... watch the dashboard above update!
Done! Check the dashboard for results.


In [11]:
# Stop the live dashboard when done
dashboard.stop()
live_dash.stop()

print("Live session ended!")


⏹️  Live profiling stopped. Captured 20 operations, 14 alerts
Live session ended!
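`SimpleLiveDashboard` is described as HTML-based and widget-free. One way to build such a display (not necessarily how TPUsight does it) is IPython's display handles: `display(obj, display_id=True)` returns a handle whose `update()` replaces that output in place. A minimal sketch, where `render_metrics_html` and `start_live_html` are hypothetical helpers and `get_metrics` is any callable returning a dict. Refreshing from a background thread lets the cell return immediately, much like `dashboard.start()`:

```python
import html
import threading


def render_metrics_html(metrics):
    """Render a dict of metric name -> value as a small HTML table."""
    rows = "".join(
        f"<tr><td>{html.escape(str(k))}</td><td>{html.escape(str(v))}</td></tr>"
        for k, v in metrics.items()
    )
    return f"<table>{rows}</table>"


def start_live_html(get_metrics, interval=0.5):
    """Show metrics in the current cell and refresh them from a background thread.

    Must run inside Jupyter/IPython. Returns a threading.Event; call .set() to stop.
    """
    from IPython.display import HTML, display  # lazy import: the renderer works anywhere

    handle = display(HTML(render_metrics_html(get_metrics())), display_id=True)
    stop = threading.Event()

    def loop():
        # Event.wait returns False on timeout and True once stop.set() is called
        while not stop.wait(interval):
            handle.update(HTML(render_metrics_html(get_metrics())))

    threading.Thread(target=loop, daemon=True).start()
    return stop


# In a notebook, e.g.:
# stop = start_live_html(lambda: {"total_ops": live_dash.get_current_metrics().total_ops})
# ... later: stop.set()
print(render_metrics_html({"ops": 20, "alerts": 14}))
```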


## Next Steps

1. **Profile your own models** - Use `@profiler.trace` or `with profiler.trace_context()`
2. **Check the dashboard** - `profiler.dashboard()` for interactive analysis
3. **Use live profiling** - `LiveProfiler` for real-time monitoring during training
4. **Analyze time breakdown** - `profiler.time_breakdown.print_breakdown()` to see where time goes
5. **Follow recommendations** - Address critical issues first
6. **Iterate** - Profile again after optimizations to measure improvement

For more information, see the [README](../README.md).
