In [None]:
%%shell
# pip uninstall -y torch torchvision torchaudio torch_xla
pip install -q "torch==2.8.*" "torchvision==0.23.*" "torchaudio==2.8.*"
pip install -q "torch_xla[tpu]==2.8.*"


In [None]:
%%shell
pip install diffusers

In [None]:
import os, time, json, statistics, pathlib, re
import torch
import torch.nn as nn
import torch_xla
import torch_xla.core.xla_model as xm
from torch_xla.debug import profiler as xp
from torch_xla.debug import metrics as xmetrics
from torch.profiler import profile, ProfilerActivity
import torch.nn.functional as F
from transformers import (
    AutoModelForImageClassification,
    AutoModelForSequenceClassification,
    AutoModelForCausalLM,
    AutoTokenizer,
)
from diffusers import UNet2DConditionModel


In [None]:
# Resolve the default XLA device once and show which one the runtime picked.
device = torch_xla.device()
print(device)

# Show the compiler backends registered with TorchDynamo in this runtime.
print("Dynamo backends:", torch._dynamo.list_backends())

In [None]:
def _human_bytes(n):
    for u in ["B","KB","MB","GB","TB","PB"]:
        if n < 1024:
            return f"{n:.2f} {u}"
        n /= 1024
    return f"{n:.2f} EB"

def _summarize_times(series_ms):
    if not series_ms:
        return {}
    s = sorted(series_ms)
    idx95 = max(0, min(len(s)-1, int(0.95*len(s))-1))
    return {
        "count": len(series_ms),
        "mean_ms": statistics.mean(series_ms),
        "p50_ms": statistics.median(series_ms),
        "p95_ms": s[idx95],
        "min_ms": min(series_ms),
        "max_ms": max(series_ms),
    }

In [None]:
def _export_stablehlo(module_cpu, input_args_cpu, input_kwargs_cpu, out_path):
    """
    Export StableHLO for a (CPU) module using torch.export → save_as_stablehlo.
    We use CPU & float32 for widest compatibility.

    Args:
      module_cpu: nn.Module to export; it is moved to CPU/float32 and put in eval mode.
      input_args_cpu: iterable of positional example inputs.
      input_kwargs_cpu: dict of keyword example inputs (None is treated as empty).
      out_path: destination handed to `save_as_stablehlo`.

    Returns:
      `out_path`, unchanged.

    Only floating-point tensors are cast to float32. Integer/bool tensors
    (token ids, attention masks) keep their dtype: casting them to float
    would make embedding lookups fail during tracing.
    """
    # Prefer torch.export; fall back to torch._export for older PyTorch.
    try:
        from torch.export import export as _export
    except Exception:
        from torch._export import export as _export

    from torch_xla.stablehlo import save_as_stablehlo

    def _to_cpu_fp32(x):
        # Non-tensors pass through untouched.
        if not isinstance(x, torch.Tensor):
            return x
        x = x.detach().to("cpu")
        # bf16/fp16/fp64 -> fp32 for export robustness; leave index/mask dtypes alone.
        return x.to(torch.float32) if x.is_floating_point() else x

    args_cpu = tuple(_to_cpu_fp32(a) for a in input_args_cpu)
    kwargs_cpu = {k: _to_cpu_fp32(v) for k, v in (input_kwargs_cpu or {}).items()}

    module_cpu = module_cpu.to("cpu", dtype=torch.float32).eval()
    with torch.no_grad():
        ep = _export(module_cpu, args_cpu, kwargs_cpu)
    save_as_stablehlo(ep, out_path)
    return out_path

In [None]:
!pip install graphviz

from pathlib import Path
import re
from graphviz import Digraph

def visualize_mlir(mlir_path, func_name, graph_path):
  """Placeholder for rendering a StableHLO MLIR file as a Graphviz graph.

  Currently a no-op: all three arguments are ignored and None is returned.
  The disabled implementation is kept below as a reference.
  """
  # with open(mlir_path) as f:
  #     text = f.read()

  # # matches like: %2 = stablehlo.dot_general %0, %1, ...
  # pattern = re.compile(r"(%\w+)\s*=\s*([\w\.]+)\s+(.*)")
  # ops = []
  # for line in text.splitlines():
  #     line = line.strip()
  #     m = pattern.match(line)
  #     if not m:
  #         continue
  #     out, op, rest = m.groups()
  #     # Extract inputs like %0, %1
  #     inputs = re.findall(r"%\w+", rest)
  #     ops.append((out, op, inputs))

  # g = Digraph("StableHLO", format="png")
  # g.attr(rankdir="LR")

  # for out, op, inputs in ops:
  #     g.node(out, f"{out}\n{op}")

  #     for inp in inputs:
  #         g.edge(inp, out)

  # for arg in re.findall(r"(%arg\d+):", text):
  #     g.node(arg, f"{arg}\ninput", shape="box")

  # ret = re.search(r"return\s+(.*)\s*:", text)
  # if ret:
  #     outs = re.findall(r"%\w+", ret.group(1))
  #     for o in outs:
  #         g.edge(o, "return")
  #     g.node("return", "return", shape="box")

  # g.render(graph_path, view=True, cleanup=True)
  return None


In [None]:
# GEMM

def gemm_module_fn():
    """Return a zero-argument builder that creates a parameter-free matmul module."""
    def _build():
        class GEMMOp(nn.Module):
            """Wraps a plain matrix product so GEMM can be profiled like any module."""

            def forward(self, a, b):
                product = torch.matmul(a, b)
                return product

        op = GEMMOp()
        return op
    return _build


def gemm_input_fn(batch_m=128, hidden_dim=256, out_dim=1024):
    """Return a factory `(device, dtype) -> (args, kwargs)` for GEMM operands.

    The factory draws a random (batch_m, hidden_dim) left operand and a random
    (hidden_dim, out_dim) right operand, in that order, and passes no kwargs.
    """
    lhs_shape = (batch_m, hidden_dim)
    rhs_shape = (hidden_dim, out_dim)

    def _make(device, dt):
        lhs = torch.randn(*lhs_shape, device=device, dtype=dt)
        rhs = torch.randn(*rhs_shape, device=device, dtype=dt)
        positional = (lhs, rhs)
        return positional, {}
    return _make


In [None]:
# Multi-head attention

def mha_module_fn(embed_dim=256, num_heads=16, batch_first=True, dropout=0.0):
    """Return a zero-argument builder for an `nn.MultiheadAttention` layer (bias enabled)."""
    mha_kwargs = dict(
        embed_dim=embed_dim,
        num_heads=num_heads,
        batch_first=batch_first,
        dropout=dropout,
        bias=True,
    )

    def _build():
        return nn.MultiheadAttention(**mha_kwargs)
    return _build

def mha_input_fn(batch=128, seq_len=64, embed_dim=256):
    """Return a factory `(device, dtype) -> (args, kwargs)` for self-attention.

    One random (batch, seq_len, embed_dim) tensor is created and reused as
    query, key and value; no keyword arguments are produced.
    """
    def _make(device, dt):
        tokens = torch.randn(batch, seq_len, embed_dim, device=device, dtype=dt)
        qkv = (tokens, tokens, tokens)
        return qkv, {}
    return _make


In [None]:
# CNN
def conv_module_fn(in_ch=3, out_ch=3, kernel_size=3, stride=1, padding=1):
    """Return a zero-argument builder for a single biased `nn.Conv2d` layer."""
    conv_kwargs = dict(
        in_channels=in_ch,
        out_channels=out_ch,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        bias=True,
    )

    def _build():
        return nn.Conv2d(**conv_kwargs)
    return _build


def conv_input_fn(batch=128, in_ch=3, height=12, width=12):
    """Return a factory `(device, dtype) -> (args, kwargs)` yielding one random
    (batch, in_ch, height, width) tensor and no keyword arguments."""
    image_shape = (batch, in_ch, height, width)

    def _make(device, dt):
        images = torch.randn(*image_shape, device=device, dtype=dt)
        return (images,), {}
    return _make

In [None]:
def ffn_module_fn(embed_dim=256, hidden_dim=1024):
    """Return a zero-argument builder for a two-layer feed-forward block.

    The block computes `w2(relu(w1(x)))`, mapping embed_dim -> hidden_dim -> embed_dim.
    """
    def _build():
        class FFN(nn.Module):
            """Linear -> ReLU -> Linear applied over the last dimension."""

            def __init__(self):
                super().__init__()
                # Up-projection, then down-projection back to the model width.
                self.w1 = nn.Linear(embed_dim, hidden_dim)
                self.w2 = nn.Linear(hidden_dim, embed_dim)

            def forward(self, x):
                hidden = F.relu(self.w1(x))
                return self.w2(hidden)

        ffn = FFN()
        return ffn
    return _build


def ffn_input_fn(batch=128, seq_len=64, embed_dim=256):
    """Return a factory `(device, dtype) -> (args, kwargs)` yielding one random
    (batch, seq_len, embed_dim) activation tensor and no keyword arguments."""
    activation_shape = (batch, seq_len, embed_dim)

    def _make(device, dt):
        activations = torch.randn(*activation_shape, device=device, dtype=dt)
        return (activations,), {}
    return _make


In [None]:
def profile_module_on_tpu(
    module_fn,
    input_fn,
    *,
    iters=20,
    warmup=5,
    dtype="bf16",          # 'bf16' (recommended on TPU) or 'fp32'
    do_backward=False,     # include backward() in timing/memory
    trace_dir="/tmp/tpu_module_trace",
    stablehlo_out="/tmp/module.stablehlo.mlir",  # StableHLO export path (set None to skip)
    print_report=True
):
    """
    Profile any torch.nn.Module on a single TPU core for `iters` iterations AND export StableHLO.

    Args:
      module_fn: () -> nn.Module
      input_fn:  (device, dtype) -> (args, kwargs)
      iters:     timed iterations
      warmup:    warmup steps (inside profiler, reported separately)
      dtype:     'bf16' or 'fp32' for TPU execution
      do_backward: include backward() timing/memory
      trace_dir: directory created for trace output
      stablehlo_out: path to write StableHLO to (None to disable export and visualization)
      print_report: print JSON report

    Returns:
      Dict with timing, memory, throughput, trace info, and StableHLO export path (if any).

    Raises:
      KeyError:   `dtype` is not 'bf16' or 'fp32'.
      ValueError: `input_fn` did not return a tuple/list as its first element.

    Measurement notes:
      * Each step's result is kept referenced until after the sync, and the
        timer only stops after `xm.wait_device_ops()`.
        NOTE(review): this assumes PyTorch/XLA only materializes tensors that
        are still alive at the sync point and that `torch_xla.sync()` returns
        before the device finishes by default — confirm against the installed
        torch_xla; holding the result and waiting is harmless either way.
      * Forward-only runs execute with autograd disabled.
    """
    assert callable(module_fn) and callable(input_fn), "module_fn and input_fn must be callables"

    device = torch_xla.device()
    dtype_map = {"bf16": torch.bfloat16, "fp32": torch.float32}
    dt = dtype_map[dtype]

    # Build module and inputs for TPU run
    module = module_fn().to(device).to(dt)
    args_dev, kwargs_dev = input_fn(device, dt)
    if not isinstance(args_dev, (tuple, list)):
        raise ValueError("input_fn must return (args, kwargs) where args is tuple/list and kwargs is dict")

    # -------- StableHLO export (on CPU, float32) --------
    stablehlo_path = None
    export_ok = False
    if stablehlo_out:
        module_cpu = module_fn()
        args_cpu, kwargs_cpu = input_fn("cpu", torch.float32)
        try:
            pathlib.Path(os.path.dirname(stablehlo_out) or ".").mkdir(parents=True, exist_ok=True)
            stablehlo_path = _export_stablehlo(module_cpu, args_cpu, kwargs_cpu, stablehlo_out)
            export_ok = True
        except Exception as e:
            # Export is best-effort: record the failure in the report and keep profiling.
            stablehlo_path = f"EXPORT_FAILED: {type(e).__name__}: {e}"

    # Enable grads if requested
    if do_backward:
        for p in module.parameters():
            p.requires_grad_(True)

    # ---- helpers for memory & throughput ----
    def _compute_weight_bytes(mod: torch.nn.Module) -> int:
        param_bytes = sum(p.numel() * p.element_size() for p in mod.parameters())
        buffer_bytes = sum(b.numel() * b.element_size() for b in mod.buffers())
        return param_bytes + buffer_bytes

    def _infer_throughput_dims(args, kwargs):
        """Heuristic: try to detect batch size and 'tokens' / pixels per sample."""
        tensors = []

        def _collect(obj):
            if isinstance(obj, torch.Tensor):
                tensors.append(obj)
            elif isinstance(obj, (list, tuple)):
                for x in obj:
                    _collect(x)
            elif isinstance(obj, dict):
                for v in obj.values():
                    _collect(v)

        _collect(args)
        _collect(kwargs)
        if not tensors:
            return None  # can't infer

        # Prefer 3D (B, S, D) -> sequence, then 4D -> image, then 2D -> (B, features)
        cand_3d = [t for t in tensors if t.ndim == 3]
        cand_4d = [t for t in tensors if t.ndim == 4]
        cand_2d = [t for t in tensors if t.ndim == 2]

        mode = None
        main = None
        if cand_3d:
            main = max(cand_3d, key=lambda t: t.numel())
            B, S = int(main.shape[0]), int(main.shape[1])
            tokens_per_sample = S
            mode = "sequence_3d"
        elif cand_4d:
            main = max(cand_4d, key=lambda t: t.numel())
            B = int(main.shape[0])
            # Treat tokens as H*W regardless of channel layout
            H, W = int(main.shape[-2]), int(main.shape[-1])
            tokens_per_sample = H * W
            mode = "image_4d"
        elif cand_2d:
            # The largest 2D tensor wins; for a plain matmul that may be the
            # right-hand operand, so "batch_size" is only a heuristic here.
            main = max(cand_2d, key=lambda t: t.numel())
            B, S = int(main.shape[0]), int(main.shape[1])
            tokens_per_sample = S
            mode = "sequence_2d"
        else:
            return None

        tokens_per_step = B * tokens_per_sample
        return {
            "batch_size": B,
            "tokens_per_sample": tokens_per_sample,
            "tokens_per_step": tokens_per_step,
            "mode": mode,
            "example_shape": tuple(int(x) for x in main.shape),
        }

    def _primary_tensor(out):
        """Pick the tensor used to build a scalar loss from a module output."""
        if isinstance(out, torch.Tensor):
            return out
        if isinstance(out, (tuple, list)):
            return out[0]
        # Hugging Face style outputs expose the main prediction as `.logits`.
        logits = getattr(out, "logits", None)
        return logits if logits is not None else out

    # Step function
    def _one_step():
        """Run one forward (and optionally backward) pass; return what must stay alive."""
        if do_backward:
            # Drop the previous step's grads *before* this step. Zeroing them in
            # place after backward() would replace the pending gradient values
            # with constants before the sync ever sees them.
            module.zero_grad(set_to_none=True)
        with torch.set_grad_enabled(do_backward), \
                torch.autocast(device_type="xla", dtype=dt, enabled=(dtype != "fp32")):
            out = module(*args_dev, **kwargs_dev)
            if do_backward:
                loss = _primary_tensor(out).float().mean()
                loss.backward()
                return loss
        return out

    def _timed_step():
        """Run one step, flush it to the device, wait for completion; return elapsed ms."""
        start = time.perf_counter()
        live = _one_step()        # keep a reference so the step's graph is still alive at sync
        torch_xla.sync()
        xm.wait_device_ops()      # block until the device has actually finished
        elapsed_ms = (time.perf_counter() - start) * 1000.0
        del live
        return elapsed_ms

    # ---- profiling ----
    all_warmups = []
    times_ms = []
    mem_snapshots = []
    peak_used_bytes = 0
    min_free_bytes = None
    total_bytes = None

    # NOTE(review): the directory is created, but nothing in this function writes a
    # trace into it; the "trace_howto" text in the report assumes one exists.
    pathlib.Path(trace_dir).mkdir(parents=True, exist_ok=True)

    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.XPU],
                 record_shapes=True,
                 profile_memory=True) as prof:

        # Warmup iterations (inside profiler, but reported separately)
        for _ in range(max(0, warmup)):
            all_warmups.append(_timed_step())

        # Timed iterations
        for i in range(iters):
            times_ms.append(_timed_step())

            mem = xm.get_memory_info(device)  # {'bytes_used': '...', 'bytes_limit': '...'}
            used_bytes = int(mem["bytes_used"])
            total_bytes = int(mem["bytes_limit"])
            free_bytes = total_bytes - used_bytes

            mem_snapshots.append({
                "iter": i,
                "bytes_used": used_bytes,
                "bytes_free": free_bytes,
                "bytes_limit": total_bytes,
            })

            if min_free_bytes is None or free_bytes < min_free_bytes:
                min_free_bytes = free_bytes
            if used_bytes > peak_used_bytes:
                peak_used_bytes = used_bytes

            prof.step()  # advance profiler step if you later add schedules

    # Optional: visualize StableHLO graph. Skipped when export is disabled
    # (stablehlo_out=None used to raise TypeError here) or when it failed.
    if export_ok:
        visualize_mlir(
            os.path.join(stablehlo_out, "functions", "forward.mlir"),
            "",
            os.path.join(stablehlo_out, "graph_rep"),
        )

    # XLA runtime metrics
    xla_metrics_text = xmetrics.metrics_report()

    # ---- timing summary ----
    timing_summary = _summarize_times(times_ms)
    timing_summary["per_iter_latency_ms"] = times_ms
    timing_summary["warmup_latency_ms"] = all_warmups

    # ---- memory summary ----
    weight_bytes = _compute_weight_bytes(module)
    hbm_total_bytes = total_bytes or 0
    activation_est_bytes = max(0, peak_used_bytes - weight_bytes) if hbm_total_bytes > 0 else None
    usage_fraction = (peak_used_bytes / hbm_total_bytes) if hbm_total_bytes else None

    memory_summary = {
        # Raw numbers
        "hbm_total_bytes": hbm_total_bytes,
        "hbm_total_human": _human_bytes(hbm_total_bytes) if hbm_total_bytes else None,
        "peak_hbm_used_bytes": peak_used_bytes,
        "peak_hbm_used_human": _human_bytes(peak_used_bytes) if peak_used_bytes else None,
        "weights_bytes": weight_bytes,
        "weights_human": _human_bytes(weight_bytes),
        "activation_est_bytes": activation_est_bytes,
        "activation_est_human": _human_bytes(activation_est_bytes) if activation_est_bytes is not None else None,
        "min_free_bytes": min_free_bytes,
        "min_free_human": _human_bytes(min_free_bytes) if min_free_bytes is not None else None,
        "usage_fraction": usage_fraction,   # peak_used / total (0–1)

        # NOTE: true fragmentation (allocator internal fragmentation) is not observable
        # from this API; we only expose utilization numbers.
        "fragmentation_estimate": None,

        # Per-iter snapshots for debugging
        "per_iter_bytes_used": [s["bytes_used"] for s in mem_snapshots],
        "per_iter_bytes_free": [s["bytes_free"] for s in mem_snapshots],
    }

    # ---- throughput summary ----
    throughput_dims = _infer_throughput_dims(args_dev, kwargs_dev)
    throughput_summary = None
    if throughput_dims is not None and times_ms:
        avg_step_ms = timing_summary.get("mean_ms", sum(times_ms) / len(times_ms))
        avg_step_s = avg_step_ms / 1000.0

        B = throughput_dims["batch_size"]
        tokens_per_sample = throughput_dims["tokens_per_sample"]
        tokens_per_step = throughput_dims["tokens_per_step"]
        mode = throughput_dims["mode"]

        imgs_per_s = None
        if mode == "image_4d":
            imgs_per_s = B / avg_step_s

        tokens_per_s = tokens_per_step / avg_step_s

        throughput_summary = {
            "mode": mode,
            "batch_size": B,
            "tokens_per_sample": tokens_per_sample,
            "tokens_per_step": tokens_per_step,
            "avg_step_ms": avg_step_ms,
            "images_per_second": imgs_per_s,
            "tokens_per_second": tokens_per_s,
            "example_input_shape": throughput_dims["example_shape"],
        }

    report = {
        "device": str(device),
        "config": {
            "iters": iters,
            "warmup": warmup,
            "dtype": dtype,
            "backward": do_backward,
            "trace_dir": trace_dir,
        },
        "timing_ms": timing_summary,
        "memory": memory_summary,
        "throughput": throughput_summary,
        "trace_dir": trace_dir,
        "stablehlo_mlir": stablehlo_path,  # path or EXPORT_FAILED:...
        "xla_metrics_sample": xla_metrics_text[:2000] + ("..." if len(xla_metrics_text) > 2000 else ""),
        "trace_howto": (
            f"Trace written to: {trace_dir}\n"
            "In Colab:\n"
            "  %load_ext tensorboard\n"
            f"  %tensorboard --logdir {trace_dir}\n"
            "Open the Profile tab for TPU traces."
        ),
    }

    if print_report:
        print(json.dumps(report, indent=2))
    return report

In [None]:
# ERIC TEMP
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import time
import json
import os
import pathlib
from transformers import AutoModelForCausalLM, AutoTokenizer

# ---------------------------------------------------------
# 1. Wrapper (Still needed to sanitize outputs)
# ---------------------------------------------------------
class GPT2Wrapper(torch.nn.Module):
    """Thin adapter that makes a causal-LM model return only its logits."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input_ids, attention_mask):
        # The KV cache is switched off so the call returns a single logits tensor
        # (the original author notes this is critical for static shapes).
        call_options = {"attention_mask": attention_mask, "use_cache": False}
        outputs = self.model(input_ids, **call_options)
        return outputs.logits

# ---------------------------------------------------------
# 2. StableHLO Export Function
# ---------------------------------------------------------
def export_stablehlo(module, args, kwargs, output_path):
    """Export a module to StableHLO MLIR text and write it to `output_path`.

    Args:
      module: nn.Module to trace with `torch.export.export`.
      args: tuple of positional example inputs.
      kwargs: dict of keyword example inputs.
      output_path: file to write the MLIR text to; its parent directory is
        created if needed (for both conversion paths).

    Returns:
      `output_path` on success, or the string "EXPORT_FAILED: <error>" if any
      step raised. Failures are printed with a traceback rather than re-raised.
    """
    try:
        # Trace once; both conversion paths below consume the same ExportedProgram.
        exported_program = torch.export.export(module, args, kwargs=kwargs)

        try:
            # Option 1: Using torch_xla.stablehlo (if available)
            import torch_xla.stablehlo as xla_stablehlo

            stablehlo_module = xla_stablehlo.exported_program_to_stablehlo(exported_program)
            mlir_text = stablehlo_module.get_stablehlo_text()

        except (ImportError, AttributeError):
            # Option 2: legacy location of the converter.
            # NOTE(review): confirm `torch_xla.experimental.stablehlo` exists in the
            # torch_xla versions this fallback is meant for; if it does not, the
            # outer handler reports EXPORT_FAILED.
            print("torch_xla.stablehlo not available, trying alternative export method...")
            stablehlo_gm = torch_xla.experimental.stablehlo.exported_program_to_stablehlo(
                exported_program
            )
            mlir_text = str(stablehlo_gm)

        # Shared save step, so the fallback path also gets its directory created.
        pathlib.Path(os.path.dirname(output_path) or ".").mkdir(parents=True, exist_ok=True)
        with open(output_path, "w") as f:
            f.write(mlir_text)

        print(f"SUCCESS: StableHLO MLIR saved to {output_path}")
        return output_path

    except Exception as e:
        print(f"ERROR: StableHLO export failed: {e}")
        print(f"Exception type: {type(e).__name__}")
        import traceback
        traceback.print_exc()
        return f"EXPORT_FAILED: {e}"

# ---------------------------------------------------------
# 3. Profiler with StableHLO Export
# ---------------------------------------------------------
def profile_gpt2_export_on_tpu(
    module_fn,
    input_fn,
    *,
    iters=10,
    warmup=5,
    stablehlo_out="gpt2_stablehlo.mlir",
):
    """Export a module to StableHLO on CPU, then time forward passes on the TPU.

    This scratch profiler used to be called `profile_module_on_tpu`, which
    silently replaced the generic profiler defined earlier in the notebook
    (different signature: no dtype/do_backward/trace_dir/print_report, and
    `input_fn(device)` instead of `input_fn(device, dtype)`). Later cells that
    call `profile_module_on_tpu(..., dtype=...)` then fail on Run All, so this
    variant has its own name.

    Args:
      module_fn: () -> nn.Module
      input_fn:  (device) -> (args, kwargs)
      iters:     number of timed iterations (must be >= 1)
      warmup:    number of untimed warmup iterations
      stablehlo_out: file to write StableHLO MLIR to (falsy to skip export)

    Returns:
      {"mlir_path": path | "EXPORT_FAILED: ..." | None,
       "latency_ms": {"avg", "min", "max", "all"}}

    Raises:
      ValueError: if `iters` < 1 (the average would otherwise divide by zero).
    """
    if iters < 1:
        raise ValueError("iters must be >= 1")

    device = torch_xla.device()

    # Setup TPU module
    module_tpu = module_fn().to(device)
    args_dev, kwargs_dev = input_fn(device)

    # -------------------------------------------------------
    # A. StableHLO Export Phase (on CPU with tracing)
    # -------------------------------------------------------
    mlir_path = None
    if stablehlo_out:
        print("\n=== Exporting to StableHLO ===")
        try:
            # Create CPU version for export (StableHLO export works better on CPU)
            module_cpu = module_fn()
            args_cpu, kwargs_cpu = input_fn("cpu")

            # Export to StableHLO
            mlir_path = export_stablehlo(module_cpu, args_cpu, kwargs_cpu, stablehlo_out)

        except Exception as e:
            print(f"ERROR: StableHLO export failed: {e}")
            mlir_path = f"EXPORT_FAILED: {e}"

    # -------------------------------------------------------
    # B. Profiling Phase (on TPU)
    # -------------------------------------------------------
    print("\n=== Profiling on TPU ===")

    def _one_step():
        with torch.no_grad():
            return module_tpu(*args_dev, **kwargs_dev)

    def _run_and_wait():
        # Keep the output referenced until the sync and wait for the device, so
        # the timed region covers real execution.
        # NOTE(review): assumes only live tensors are materialized at sync and
        # that sync() does not block by default — confirm for the installed torch_xla.
        live = _one_step()
        torch_xla.sync()
        xm.wait_device_ops()
        del live

    # Warmup
    print(f"Warming up for {warmup} iterations...")
    for _ in range(warmup):
        _run_and_wait()

    # Timing
    print(f"Running {iters} timed iterations...")
    times = []
    for i in range(iters):
        t0 = time.perf_counter()
        _run_and_wait()
        t1 = time.perf_counter()
        times.append((t1 - t0) * 1000)

    avg_ms = sum(times) / len(times)
    min_ms = min(times)
    max_ms = max(times)

    print(f"\nLatency Statistics:")
    print(f"  Average: {avg_ms:.2f} ms")
    print(f"  Min:     {min_ms:.2f} ms")
    print(f"  Max:     {max_ms:.2f} ms")

    return {
        "mlir_path": mlir_path,
        "latency_ms": {
            "avg": avg_ms,
            "min": min_ms,
            "max": max_ms,
            "all": times
        }
    }

# ---------------------------------------------------------
# 4. Run It
# ---------------------------------------------------------

def gpt2_module_fn():
    """Load pretrained GPT-2 and wrap it so forward() returns logits only."""
    base = AutoModelForCausalLM.from_pretrained("gpt2")
    return GPT2Wrapper(base)

def gpt2_input_fn(device):
    """Tokenize a fixed sentence (batch of 2, padded to 64 tokens) onto `device`.

    Returns ((input_ids, attention_mask), {}).
    """
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    # GPT-2 ships without a pad token; reuse EOS so max_length padding works.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    text = "The quick brown fox jumps over the lazy dog."
    enc = tokenizer(
        [text] * 2,
        padding="max_length",
        truncation=True,
        max_length=64,
        return_tensors="pt"
    )

    # `.to("cpu")` leaves CPU tensors where they are, so no special case is needed.
    return (enc["input_ids"].to(device), enc["attention_mask"].to(device)), {}

print("\n================ NLP: GPT-2 small (StableHLO Export) ================")
result = profile_gpt2_export_on_tpu(
    module_fn=gpt2_module_fn,
    input_fn=gpt2_input_fn,
    stablehlo_out="gpt2_stablehlo.mlir"
)

print(f"\nResults: {json.dumps(result, indent=2)}")

In [None]:
# From Gemini: capture StableHLO for GPT-2 small directly from the XLA runtime.

import torch
import torch_xla
import torch_xla.core.xla_model as xm
import time
import os
import pathlib
from transformers import AutoModelForCausalLM, AutoTokenizer

# ---------------------------------------------------------
# 1. Wrapper (Essential for clean graph capture)
# ---------------------------------------------------------
class GPT2Wrapper(torch.nn.Module):
    """Adapter exposing a causal-LM model as `(input_ids, attention_mask) -> logits`."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input_ids, attention_mask):
        # Cache disabled (per the original note: keeps the graph static) and
        # dict-style output requested so `.logits` is available.
        call_kwargs = dict(
            input_ids=input_ids,
            attention_mask=attention_mask,
            use_cache=False,
            return_dict=True,
        )
        return self.model(**call_kwargs).logits

# ---------------------------------------------------------
# 2. Profiler with StableHLO Runtime Capture
# ---------------------------------------------------------
def profile_gpt2_runtime_capture_on_tpu(
    module_fn,
    input_fn,
    *,
    iters=5,
    warmup=2,
    stablehlo_out="gpt2_stablehlo.mlir",
):
    """Capture the lazily-traced graph of one forward pass, then time forward passes.

    This scratch profiler used to be called `profile_module_on_tpu`, which
    silently replaced the generic profiler defined earlier in the notebook
    (different signature and a one-argument `input_fn`). Later cells that call
    `profile_module_on_tpu(..., dtype=...)` then fail on Run All, so this
    variant has its own name.

    Args:
      module_fn: () -> nn.Module
      input_fn:  (device) -> (args, kwargs)
      iters:     number of timed iterations (must be >= 1)
      warmup:    number of untimed warmup iterations
      stablehlo_out: file to write the captured IR to (falsy to skip capture)

    Returns:
      {"mlir_path": path | "CAPTURE_FAILED: ..." | None, "latency_ms": average ms per iteration}

    Raises:
      ValueError: if `iters` < 1 (the average would otherwise divide by zero).
    """
    if iters < 1:
        raise ValueError("iters must be >= 1")

    # Same device/sync entry points as the rest of the notebook.
    device = torch_xla.device()

    # Instantiate model on TPU
    module = module_fn().to(device)
    args_dev, kwargs_dev = input_fn(device)

    # -------------------------------------------------------
    # A. Capture StableHLO directly from Runtime
    # -------------------------------------------------------
    mlir_path = None
    if stablehlo_out:
        print("Attempting Runtime StableHLO Capture...")
        try:
            # 1. Run Forward Pass (Lazy Tensors created)
            output = module(*args_dev, **kwargs_dev)

            # 2. Capture Graph
            # Try the modern StableHLO API first
            if hasattr(xm, "get_stablehlo"):
                # This returns the graph in StableHLO dialect
                graph_text = xm.get_stablehlo([output])
            else:
                # Fallback for older torch_xla versions (returns HLO, similar but not identical)
                print("WARNING: xm.get_stablehlo not found. Falling back to HLO capture.")
                graph_text = torch_xla._XLAC._get_xla_tensors_text([output])

            # 3. Save
            pathlib.Path(os.path.dirname(stablehlo_out) or ".").mkdir(parents=True, exist_ok=True)
            with open(stablehlo_out, "w") as f:
                f.write(graph_text)

            mlir_path = stablehlo_out
            print(f"SUCCESS: IR saved to {stablehlo_out}")

            # Clear graph to prevent memory accumulation
            torch_xla.sync()

        except Exception as e:
            print(f"ERROR: Capture failed: {e}")
            mlir_path = f"CAPTURE_FAILED: {e}"

    # -------------------------------------------------------
    # B. Performance Profiling
    # -------------------------------------------------------
    def _one_step():
        with torch.no_grad():
            return module(*args_dev, **kwargs_dev)

    def _run_step():
        # Keep the output referenced until the sync so the step's graph is alive.
        # NOTE(review): assumes only live tensors are materialized at sync — confirm.
        live = _one_step()
        torch_xla.sync()
        del live

    print(f"Running {iters} iterations...")
    # Warmup
    for _ in range(warmup):
        _run_step()
    # Drain warmup work so it is not billed to the timed region.
    xm.wait_device_ops()

    # Timing
    t0 = time.perf_counter()
    for _ in range(iters):
        _run_step()
    # Wait for queued device work to finish before stopping the timer.
    xm.wait_device_ops()
    t1 = time.perf_counter()

    avg_ms = ((t1 - t0) * 1000) / iters
    print(f"Avg Latency: {avg_ms:.2f} ms")

    return {"mlir_path": mlir_path, "latency_ms": avg_ms}

# ---------------------------------------------------------
# 3. Execution
# ---------------------------------------------------------
def gpt2_module_fn():
    """Load pretrained GPT-2 and wrap it so forward() returns logits only."""
    base = AutoModelForCausalLM.from_pretrained("gpt2")
    return GPT2Wrapper(base)

def gpt2_input_fn(device):
    """Tokenize a fixed sentence (batch of 2, padded to 64 tokens) onto `device`.

    Returns ((input_ids, attention_mask), {}).
    """
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    # GPT-2 ships without a pad token; reuse EOS so max_length padding works.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    text = "The quick brown fox jumps over the lazy dog."
    enc = tokenizer(
        [text] * 2,
        padding="max_length", truncation=True, max_length=64,
        return_tensors="pt"
    )
    return (enc["input_ids"].to(device), enc["attention_mask"].to(device)), {}

if __name__ == "__main__":
    profile_gpt2_runtime_capture_on_tpu(
        module_fn=gpt2_module_fn,
        input_fn=gpt2_input_fn,
        stablehlo_out="gpt2_stablehlo.mlir"
    )

In [None]:
EMBED, HEADS = 256, 16
BATCH, SEQ = 128, 64

# MHA
# The report is bound to a name: with print_report=True the JSON is already
# printed, and leaving the call as the cell's last expression would make
# Jupyter echo the whole dict a second time.
mha_report = profile_module_on_tpu(
    module_fn=mha_module_fn(embed_dim=EMBED, num_heads=HEADS, batch_first=True),
    input_fn=mha_input_fn(batch=BATCH, seq_len=SEQ, embed_dim=EMBED),
    iters=10,
    warmup=5,
    dtype="bf16",                     # 'bf16' recommended on TPU
    do_backward=False,
    trace_dir="tpu_mha_trace_generic",
    stablehlo_out="mha.stablehlo.mlir",  # <-- StableHLO export file
    print_report=True
)

In [None]:
EMBED = 256
BATCH = 128

# GEMM
# The operand sizes come from the constants above (same values as the
# gemm_input_fn defaults); HEADS/SEQ were defined here but never used.
# The report is bound to a name so Jupyter does not echo the dict again
# after print_report=True has already printed it.
gemm_report = profile_module_on_tpu(
    module_fn=gemm_module_fn(),
    input_fn=gemm_input_fn(batch_m=BATCH, hidden_dim=EMBED),
    iters=10,
    warmup=5,
    dtype="bf16",                     # 'bf16' recommended on TPU
    do_backward=False,
    trace_dir="tpu_gemm_trace_generic",
    stablehlo_out="gemm.stablehlo.mlir",  # <-- StableHLO export file
    print_report=True
)

In [None]:
EMBED = 256
BATCH, SEQ = 128, 64

# FFN
# Module width and input width both come from EMBED so they cannot drift
# apart (values match the previous defaults); HEADS was defined but unused.
# The report is bound to a name so Jupyter does not echo the dict again
# after print_report=True has already printed it.
ffn_report = profile_module_on_tpu(
    module_fn=ffn_module_fn(embed_dim=EMBED),
    input_fn=ffn_input_fn(batch=BATCH, seq_len=SEQ, embed_dim=EMBED),
    iters=10,
    warmup=5,
    dtype="bf16",                     # 'bf16' recommended on TPU
    do_backward=False,
    trace_dir="tpu_ffn_trace_generic",
    stablehlo_out="ffn.stablehlo.mlir",  # <-- StableHLO export file
    print_report=True
)

In [None]:
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr

def clear_tpu_state():
    """Best-effort reset of the TPU runtime between benchmarks.

    Flushes pending lazy ops and waits for the device, then tries to clear the
    compile cache. Neither step is allowed to raise, but failures are now
    reported instead of being swallowed: the success message is printed only
    when both steps actually succeeded.

    Returns:
      None.
    """
    failures = []

    try:
        xm.mark_step()
        xm.wait_device_ops()
    except Exception as exc:
        failures.append(f"flush pending ops: {type(exc).__name__}: {exc}")

    # Clear compile cache (supported on v5e)
    # NOTE(review): confirm `xr.clear_all_cache` exists in the installed torch_xla;
    # if it does not, the AttributeError is reported below.
    try:
        xr.clear_all_cache()
    except Exception as exc:
        failures.append(f"clear compile cache: {type(exc).__name__}: {exc}")

    if failures:
        print("⚠️ TPU state was not fully cleared:")
        for message in failures:
            print(f"   - {message}")
    else:
        print("✔️ Cleared TPU runtime & compile cache")

In [None]:
# ---------------------------------------------------------
# 1) Vision: ResNet-50
# ---------------------------------------------------------

def resnet50_module_fn(model_name="microsoft/resnet-50"):
    """Return a zero-argument builder that loads the pretrained image classifier
    identified by `model_name`. Nothing is loaded until the builder is called."""
    def _build():
        classifier = AutoModelForImageClassification.from_pretrained(model_name)
        return classifier
    return _build

def resnet50_input_fn(batch=8, height=224, width=224):
    """Return a factory `(device, dtype) -> (args, kwargs)` yielding one random
    3-channel image batch in NCHW layout and no keyword arguments."""
    # HF vision models expect NCHW
    pixel_shape = (batch, 3, height, width)

    def _make(device, dt):
        pixels = torch.randn(*pixel_shape, device=device, dtype=dt)
        return (pixels,), {}
    return _make

print("\n================ Vision: ResNet-50 (PyTorch) ================")
# Bound to a name: print_report=True already prints the JSON, and a bare call as
# the cell's last expression would make Jupyter echo the whole dict again.
resnet50_report = profile_module_on_tpu(
    module_fn=resnet50_module_fn(),
    input_fn=resnet50_input_fn(batch=8, height=224, width=224),
    iters=10,
    warmup=5,
    dtype="bf16",
    do_backward=False,
    trace_dir="tpu_resnet50_trace",
    stablehlo_out="resnet50.stablehlo.mlir",
    print_report=True,
)



In [None]:

# ---------------------------------------------------------
# 2) Vision: ViT-B/16
# ---------------------------------------------------------

def vit_module_fn(model_name="google/vit-base-patch16-224"):
    """Return a zero-argument builder that loads the pretrained image classifier
    identified by `model_name`. Nothing is loaded until the builder is called."""
    def _build():
        classifier = AutoModelForImageClassification.from_pretrained(model_name)
        return classifier
    return _build

def vit_input_fn(batch=8, height=224, width=224):
    """Return a factory `(device, dtype) -> (args, kwargs)` yielding one random
    3-channel image batch in NCHW layout and no keyword arguments."""
    # ViT image models also expect NCHW
    pixel_shape = (batch, 3, height, width)

    def _make(device, dt):
        pixels = torch.randn(*pixel_shape, device=device, dtype=dt)
        return (pixels,), {}
    return _make

print("\n================ Vision: ViT-B/16 (PyTorch) ================")
# Bound to a name: print_report=True already prints the JSON, and a bare call as
# the cell's last expression would make Jupyter echo the whole dict again.
vit_report = profile_module_on_tpu(
    module_fn=vit_module_fn(),
    input_fn=vit_input_fn(batch=8, height=224, width=224),
    iters=10,
    warmup=5,
    dtype="bf16",
    do_backward=False,
    trace_dir="tpu_vit_b16_trace",
    stablehlo_out="vit_b16.stablehlo.mlir",
    print_report=True,
)



In [None]:

# ---------------------------------------------------------
# 3) NLP: BERT-base (sequence classification)
# ---------------------------------------------------------

def bert_base_module_fn(model_name="bert-base-uncased"):
    """Return a zero-argument builder that loads the pretrained sequence
    classifier identified by `model_name`. Nothing is loaded until it is called."""
    def _build():
        classifier = AutoModelForSequenceClassification.from_pretrained(model_name)
        return classifier
    return _build

def bert_base_input_fn(
    model_name="bert-base-uncased",
    batch=4,
    seq_len=128,
):
    """Return a factory `(device, dtype) -> (args, kwargs)` for BERT-style inputs.

    The tokenizer for `model_name` is loaded once, when this function is
    called. Each factory call tokenizes `batch` copies of a dummy sentence,
    padded/truncated to `seq_len`, and returns
    `((input_ids,), {"attention_mask": attention_mask})` on `device`.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    def _make(device, dt):
        # `dt` is deliberately unused: token ids and masks keep the tokenizer's dtype.
        sentences = ["this is a dummy sentence"] * batch
        encoded = tokenizer(
            sentences,
            padding="max_length",
            truncation=True,
            max_length=seq_len,
            return_tensors="pt",
        )
        ids_on_device = encoded["input_ids"].to(device)
        mask_on_device = encoded["attention_mask"].to(device)
        # Call pattern: module(input_ids, attention_mask=...)
        return (ids_on_device,), {"attention_mask": mask_on_device}

    return _make

print("\n================ NLP: BERT-base (PyTorch) ================")
# Bound to a name: print_report=True already prints the JSON, and a bare call as
# the cell's last expression would make Jupyter echo the whole dict again.
bert_base_report = profile_module_on_tpu(
    module_fn=bert_base_module_fn(),
    input_fn=bert_base_input_fn(
        model_name="bert-base-uncased",
        batch=4,
        seq_len=128,
    ),
    iters=10,
    warmup=5,
    dtype="bf16",
    do_backward=False,
    trace_dir="tpu_bert_base_trace",
    stablehlo_out="bert_base.stablehlo.mlir",
    print_report=True,
)



In [None]:

# ---------------------------------------------------------
# 4) NLP: GPT-2 small
# ---------------------------------------------------------

def gpt2_module_fn(model_name="gpt2"):
    """
    Return a zero-argument builder for a causal language model.

    The returned callable runs `AutoModelForCausalLM.from_pretrained(model_name)`
    on every invocation; calling this factory alone loads nothing.
    """
    def _build():
        model = AutoModelForCausalLM.from_pretrained(model_name)
        return model

    return _build

def gpt2_input_fn(
    model_name="gpt2",
    batch=2,
    seq_len=64,
    vocab_example_text="The quick brown fox jumps over the lazy dog.",
):
    """
    Return an input builder `_make(device, dt)` for a GPT-2 style causal LM.

    The tokenizer for `model_name` is loaded once, when this factory is called.
    If it has no pad token, the EOS token is reused so that fixed-length
    padding works. Each call to `_make` tokenizes `batch` copies of
    `vocab_example_text`, padded/truncated to `seq_len`, moves the tensors to
    `device`, and returns `((input_ids,), {"attention_mask": attention_mask})`.
    `dt` is unused because token IDs are integers.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Padding to max_length needs a pad token; fall back to EOS when none is set.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    def _make(device, dt):
        prompts = [vocab_example_text] * batch
        encoded = tokenizer(
            prompts,
            padding="max_length",
            truncation=True,
            max_length=seq_len,
            return_tensors="pt",
        )
        ids, mask = (
            encoded[key].to(device) for key in ("input_ids", "attention_mask")
        )
        # The module is invoked as module(input_ids, attention_mask=...).
        positional = (ids,)
        keyword = {"attention_mask": mask}
        return positional, keyword

    return _make

print("\n================ NLP: GPT-2 small (PyTorch) ================")
# Forward-only bf16 run (do_backward=False) of the GPT-2 builder above;
# `profile_module_on_tpu` is defined elsewhere in this notebook.
profile_module_on_tpu(
    # What to run and what to feed it.
    module_fn=gpt2_module_fn(),
    input_fn=gpt2_input_fn(
        model_name="gpt2",
        batch=2,
        seq_len=64,
        vocab_example_text="The quick brown fox jumps over the lazy dog.",
    ),
    # How to run it.
    dtype="bf16",
    do_backward=False,
    iters=10,
    warmup=5,
    # Where the artifacts go.
    trace_dir="tpu_gpt2_small_trace",
    stablehlo_out="gpt2_small.stablehlo.mlir",
    print_report=True,
)



In [None]:

# ---------------------------------------------------------
# 5) NLP: GPT-Neo 1.3B
# ---------------------------------------------------------

def gptneo_module_fn(model_name="EleutherAI/gpt-neo-1.3B"):
    """
    Return a zero-argument builder for the GPT-Neo causal language model.

    The returned callable runs `AutoModelForCausalLM.from_pretrained(model_name)`
    on every invocation; calling this factory alone loads nothing.
    """
    def _build():
        model = AutoModelForCausalLM.from_pretrained(model_name)
        return model

    return _build

def gptneo_input_fn(
    model_name="EleutherAI/gpt-neo-1.3B",
    batch=1,
    seq_len=64,
    vocab_example_text="Dynamic pruning and batching for large language models.",
):
    """
    Return an input builder `_make(device, dt)` for the GPT-Neo causal LM.

    The tokenizer for `model_name` is loaded once, when this factory is called.
    If it has no pad token, the EOS token is reused so that fixed-length
    padding works. Each call to `_make` tokenizes `batch` copies of
    `vocab_example_text`, padded/truncated to `seq_len`, moves the tensors to
    `device`, and returns `((input_ids,), {"attention_mask": attention_mask})`.
    `dt` is unused because token IDs are integers.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Padding to max_length needs a pad token; fall back to EOS when none is set.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    def _make(device, dt):
        prompts = [vocab_example_text] * batch
        encoded = tokenizer(
            prompts,
            padding="max_length",
            truncation=True,
            max_length=seq_len,
            return_tensors="pt",
        )
        ids, mask = (
            encoded[key].to(device) for key in ("input_ids", "attention_mask")
        )
        positional = (ids,)
        keyword = {"attention_mask": mask}
        return positional, keyword

    return _make

print("\n================ NLP: GPT-Neo 1.3B (PyTorch) ================")
# Forward-only bf16 run (do_backward=False) of the GPT-Neo builder above;
# `profile_module_on_tpu` is defined elsewhere in this notebook.
profile_module_on_tpu(
    # What to run and what to feed it.
    module_fn=gptneo_module_fn(),
    input_fn=gptneo_input_fn(
        model_name="EleutherAI/gpt-neo-1.3B",
        batch=1,
        seq_len=64,
        vocab_example_text="Dynamic pruning and batching for large language models.",
    ),
    # How to run it.
    dtype="bf16",
    do_backward=False,
    iters=10,
    warmup=5,
    # Where the artifacts go.
    trace_dir="tpu_gptneo_1_3b_trace",
    stablehlo_out="gptneo_1_3b.stablehlo.mlir",
    print_report=True,
)



In [None]:

# ---------------------------------------------------------
# 6) Diffusion: UNet (Stable Diffusion)
# ---------------------------------------------------------

def unet_module_fn(model_name="runwayml/stable-diffusion-v1-5"):
    """
    Return a zero-argument builder for a conditional 2D UNet.

    The returned callable runs
    `UNet2DConditionModel.from_pretrained(model_name, subfolder="unet")` on
    every invocation; calling this factory alone loads nothing.
    """
    def _build():
        # Only the "unet" subfolder of the checkpoint is requested.
        unet = UNet2DConditionModel.from_pretrained(model_name, subfolder="unet")
        return unet

    return _build

def unet_input_fn(
    batch=1,
    height=64,
    width=64,
    channels=4,
    text_seq_len=77,
    text_embed_dim=768,
):
    """
    Return an input builder `_make(device, dt)` producing dummy UNet inputs.

    Parameters
    ----------
    batch, channels, height, width:
        Shape of the random latent `sample`: [batch, channels, height, width].
    text_seq_len, text_embed_dim:
        Shape of the random conditioning tensor `encoder_hidden_states`:
        [batch, text_seq_len, text_embed_dim]. The defaults (77, 768) keep the
        previous hard-coded shape; override them when pairing this with a
        checkpoint whose cross-attention expects a different width.

    Returns
    -------
    A callable `_make(device, dt)` returning
    `((sample, timesteps, encoder_hidden_states), {})`, where
    - `sample` and `encoder_hidden_states` are `torch.randn` tensors of dtype `dt`,
    - `timesteps` is a tensor of ones with shape [batch], always float32
      regardless of `dt`.
    All tensors are created directly on `device`.
    """
    def _make(device, dt):
        sample = torch.randn(batch, channels, height, width, device=device, dtype=dt)
        # Timesteps intentionally stay float32 and do not follow `dt`.
        timesteps = torch.ones(batch, device=device, dtype=torch.float32)
        encoder_hidden_states = torch.randn(
            batch, text_seq_len, text_embed_dim, device=device, dtype=dt
        )
        return (sample, timesteps, encoder_hidden_states), {}

    return _make

print("\n================ Diffusion: UNet (Stable Diffusion) ================")
# Forward-only bf16 run (do_backward=False) of the UNet builder above;
# `profile_module_on_tpu` is defined elsewhere in this notebook.
profile_module_on_tpu(
    # What to run and what to feed it.
    module_fn=unet_module_fn(),
    input_fn=unet_input_fn(batch=1, height=64, width=64, channels=4),
    # How to run it.
    dtype="bf16",
    do_backward=False,
    iters=10,
    warmup=5,
    # Where the artifacts go.
    trace_dir="tpu_unet_sd_trace",
    stablehlo_out="unet_sd.stablehlo.mlir",
    print_report=True,
)