In [None]:
!pip install diffusers

In [None]:
import jax
import jax.numpy as jnp
from jax import random

# ------------------------------------------------------------
# End-to-end model profiling with Hugging Face + torchax + JAX
# ------------------------------------------------------------
import numpy as np

import torchax
from transformers import (
    AutoImageProcessor,
    AutoModelForImageClassification,
    AutoTokenizer,
    AutoModelForSequenceClassification,
    AutoModelForCausalLM,
    FlaxViTForImageClassification,
    FlaxBertForSequenceClassification,
    FlaxGPT2LMHeadModel,
    FlaxGPTNeoForCausalLM,
)

from huggingface_hub import login
from google.colab import userdata

# Optional: Diffusion UNet (requires diffusers)
try:
    from diffusers import UNet2DConditionModel, FlaxUNet2DConditionModel
    HAS_DIFFUSERS = True
except ImportError:
    HAS_DIFFUSERS = False

print(f"Successfully Installed Diffusers? {HAS_DIFFUSERS}")

In [None]:
# Read the Hugging Face access token from Colab's secret store rather than hardcoding it.
# NOTE(review): hf_token is not used by any cell visible here (`login` is imported but
# never called) — presumably meant for `login(hf_token)` before loading gated models.
hf_token = userdata.get('HF_TOKEN')

In [None]:
!pip install mlir-opt

In [None]:
# Shared config: base PRNG key (each benchmark cell splits it for its inputs) and
# the number of steady-state timing runs passed to profile_jax_function.
master_key = random.PRNGKey(0)
num_runs = 10

In [None]:
def gemm(a, b):
  """Dense matrix product `a @ b` (thin wrapper over jnp.matmul).

  Serves as the simplest GEMM kernel for the profiling harness.
  """
  product = jnp.matmul(a, b)
  return product

In [None]:
def ffn(x, w1, b1, w2, b2):
  """
  Two-layer feed-forward block: relu(x @ w1 + b1) @ w2 + b2.

  Shapes used by the benchmark in this notebook:
    x:  (128, 64, 256)
    w1: (256, 1024), b1: (1024,)
    w2: (1024, 256), b2: (256,)
  """
  hidden = jax.nn.relu(jnp.matmul(x, w1) + b1)
  projected = jnp.matmul(hidden, w2)
  return projected + b2

In [None]:
def mha(x, w_q, w_k, w_v, w_o, num_heads=16):
    """
    Multi-head self-attention (no biases, no masking) used as a benchmark kernel.

    Args:
        x: Activations of shape (batch, seq_len, embed_dim).
        w_q, w_k, w_v: Query/key/value projection matrices, each (embed_dim, embed_dim).
        w_o: Output projection matrix (embed_dim, embed_dim).
        num_heads: Number of attention heads (default 16). The per-head reshape
            only succeeds when embed_dim is divisible by num_heads.
    Returns:
        Array of shape (batch, seq_len, embed_dim).
    """
    batch, seq_len, embed_dim = x.shape
    head_dim = embed_dim // num_heads

    def split_heads(t):
        # (B, S, E) -> (B, S, H, D) -> (B, H, S, D)
        return t.reshape(batch, seq_len, num_heads, head_dim).transpose(0, 2, 1, 3)

    q_heads = split_heads(jnp.dot(x, w_q))
    k_heads = split_heads(jnp.dot(x, w_k))
    v_heads = split_heads(jnp.dot(x, w_v))

    # Scaled dot-product attention, computed independently for every head.
    logits = jnp.matmul(q_heads, jnp.swapaxes(k_heads, -2, -1)) / jnp.sqrt(head_dim)  # (B, H, S, S)
    probs = jax.nn.softmax(logits, axis=-1)
    context = jnp.matmul(probs, v_heads)  # (B, H, S, D)

    # (B, H, S, D) -> (B, S, H, D) -> (B, S, E), then project back to embed_dim.
    merged = context.transpose(0, 2, 1, 3).reshape(batch, seq_len, embed_dim)
    return jnp.dot(merged, w_o)

In [None]:
# TEMP
import jax
import jax.numpy as jnp

def mha_pytorch_style(x, w_qkv, b_qkv, w_o, b_o, num_heads=16):
    """
    Multi-head self-attention laid out like PyTorch's nn.MultiheadAttention
    (batch-first), so its lowered IR can be compared op-for-op with a PyTorch export.

    Args:
        x: Input (Batch, Seq, Embed). To mimic batch_first=False, transpose x
           before calling.
        w_qkv: Fused Q/K/V projection weight in JAX layout (Embed, 3 * Embed).
           PyTorch stores it as (3 * Embed, Embed); transpose before passing.
        b_qkv: Fused projection bias (3 * Embed,).
        w_o: Output projection weight (Embed, Embed).
        b_o: Output projection bias (Embed,).
        num_heads: Number of attention heads (default 16).

    Returns:
        (output, avg_weights):
          output: (Batch, Seq, Embed) attention output.
          avg_weights: (Batch, Seq, Seq) softmax weights averaged over heads,
            mirroring PyTorch's default average_attn_weights=True.
    """
    batch, seq_len, embed_dim = x.shape
    head_dim = embed_dim // num_heads
    merged_bh = batch * num_heads

    # Fused projection with an explicit bias add (dot_general followed by add),
    # then slice Q, K, V apart along the feature axis.
    q, k, v = jnp.split(jnp.dot(x, w_qkv) + b_qkv, 3, axis=-1)

    def fold_heads(t):
        # (B, S, E) -> (B, S, H, D) -> (B, H, S, D) -> (B*H, S, D):
        # PyTorch merges batch and heads into one dim for its batched matmuls
        # (e.g. 128 * 16 = 2048 leading dim for the benchmark shapes).
        t = t.reshape(batch, seq_len, num_heads, head_dim)
        t = t.transpose(0, 2, 1, 3)
        return t.reshape(merged_bh, seq_len, head_dim)

    q_bh, k_bh, v_bh = (fold_heads(t) for t in (q, k, v))

    # Scale Q by the reciprocal of sqrt(head_dim): a multiply rather than a
    # divide, mirroring the PyTorch graph.
    scale = 1.0 / jnp.sqrt(head_dim).astype(q.dtype)
    scores = jnp.matmul(q_bh * scale, k_bh.swapaxes(-2, -1))  # (B*H, S, S)
    attn_weights = jax.nn.softmax(scores, axis=-1)
    context_bh = jnp.matmul(attn_weights, v_bh)  # (B*H, S, D)

    # (B*H, S, D) -> (B, H, S, D) -> (B, S, H, D) -> (B, S, E)
    context = context_bh.reshape(batch, num_heads, seq_len, head_dim)
    context = context.transpose(0, 2, 1, 3).reshape(batch, seq_len, embed_dim)

    output = jnp.dot(context, w_o) + b_o

    # Head-averaged attention weights: sum over the head axis, divide by num_heads.
    per_head = attn_weights.reshape(batch, num_heads, seq_len, seq_len)
    avg_weights = per_head.sum(axis=1) / num_heads

    return output, avg_weights

In [None]:
def conv_layer(x, w):
  """
  2-D convolution with unit stride and 'SAME' padding, NCHW input / OIHW filters.

  Args:
    x: input tensor of shape (N, C, H, W) = (batch, input channels,
       input height, input width).
    w: filter tensor of shape (O, I, KH, KW) = (output channels,
       input channels, kernel height, kernel width); I must equal C.

  Returns:
    Output tensor in NCHW layout. With stride 1 and 'SAME' padding the spatial
    size matches the input: (N, O, H, W).
  """
  layouts = ('NCHW', 'OIHW', 'NCHW')  # (input, kernel, output) dimension orders
  strides = (1, 1)
  return jax.lax.conv_general_dilated(
      x,
      w,
      window_strides=strides,
      padding='SAME',
      dimension_numbers=layouts,
  )


In [None]:
!pip install graphviz

from pathlib import Path
import re
from graphviz import Digraph

def visualize_mlir(mlir_path, func_name, graph_path, view=True):
  """Render a rough dataflow graph of a StableHLO/MLIR text dump with Graphviz.

  Each `%result = dialect.op %a, %b ...` line becomes a node labelled with its
  result name and op, with edges from every `%name` referenced on that line.
  Function arguments (`%argN:`) are drawn as input boxes, and the operands of
  the first function-level `return` are wired to a single "return" box.

  This is a regex approximation, not an MLIR parser: ops inside nested regions
  (e.g. reduce bodies) also become nodes, and generic-form ops
  (`%0 = "dialect.op"(...)`) are not matched.

  Args:
    mlir_path: Path to the MLIR text file.
    func_name: Used as the graph label; falsy values mean no label.
    graph_path: Output path stem passed to Graphviz `render` (PNG format).
    view: Open the rendered image with the system viewer (default True).

  Returns:
    The path of the rendered file, as returned by `Digraph.render`.
  """
  with open(mlir_path, encoding="utf-8") as f:
    text = f.read()

  # Matches e.g.: %2 = stablehlo.dot_general %0, %1, ...
  op_pattern = re.compile(r"(%\w+)\s*=\s*([\w\.]+)\s+(.*)")
  ops = []
  for raw_line in text.splitlines():
    m = op_pattern.match(raw_line.strip())
    if m:
      out, op, rest = m.groups()
      # Operands such as %0, %arg1, %cst referenced on the rest of the line.
      ops.append((out, op, re.findall(r"%\w+", rest)))

  g = Digraph("StableHLO", format="png")
  g.attr(rankdir="LR")
  if func_name:
    g.attr(label=str(func_name))

  for out, op, inputs in ops:
    g.node(out, f"{out}\n{op}")
    for inp in inputs:
      g.edge(inp, out)

  for arg in re.findall(r"(%arg\d+):", text):
    g.node(arg, f"{arg}\ninput", shape="box")

  # Anchor at line start so region terminators such as `stablehlo.return`
  # (inside reduce bodies) are not mistaken for the function's return.
  ret = re.search(r"^[ \t]*(?:func\.)?return[ \t]+([^:\n]*)", text, re.MULTILINE)
  if ret:
    g.node("return", "return", shape="box")
    for o in re.findall(r"%\w+", ret.group(1)):
      g.edge(o, "return")

  return g.render(graph_path, view=view, cleanup=True)


In [None]:
import time, math
import jax, os
import jax.numpy as jnp
import shutil, glob, gzip
import numpy as np
import json
import gc

# Collect Python garbage and clear JAX's compilation caches so profiling below
# starts from a clean state (re-running this cell clears them again).
gc.collect()
jax.clear_caches()

# ---------- helpers ----------

def _human_bytes(num: int) -> str:
    if num is None:
        return "N/A"
    num = float(num)
    for unit in ["B", "KB", "MB", "GB", "TB"]:
        if num < 1024.0 or unit == "TB":
            return f"{num:.2f} {unit}"
        num /= 1024.0

def _is_array_like(x):
    return (
        isinstance(x, (jnp.ndarray, np.ndarray))
        or (hasattr(x, "shape") and hasattr(x, "dtype") and hasattr(x, "size"))
    )

def _collect_arrays(obj, out_list):
    """Depth-first walk of nested lists/tuples/dicts, appending array-like leaves to out_list.

    Dict values are visited in iteration order; leaves of any other type are
    ignored. Mutates out_list in place and returns None.
    """
    if _is_array_like(obj):
        out_list.append(obj)
        return
    if isinstance(obj, dict):
        children = obj.values()
    elif isinstance(obj, (list, tuple)):
        children = obj
    else:
        return
    for child in children:
        _collect_arrays(child, out_list)

def _arrays_and_bytes_from_args(args):
    """Flatten every array-like leaf found in `args` and sum their sizes in bytes.

    Returns:
        (arrays, total_bytes). Leaves whose dtype NumPy cannot interpret, or
        whose shape cannot be multiplied out, are still listed but add 0 bytes.
    """
    def nbytes_or_zero(arr):
        try:
            itemsize = int(np.dtype(arr.dtype).itemsize)
            count = int(np.prod(arr.shape))
        except Exception:
            return 0
        return itemsize * count

    arrays = []
    for arg in args:
        _collect_arrays(arg, arrays)
    return arrays, sum(nbytes_or_zero(arr) for arr in arrays)

def _tree_bytes(x):
    """Total bytes (itemsize * element count) of all array-like leaves in a nested
    list/tuple/dict. Leaves with an uninterpretable dtype or shape count as 0."""
    leaves = []
    _collect_arrays(x, leaves)
    total = 0
    for leaf in leaves:
        try:
            nbytes = int(np.dtype(leaf.dtype).itemsize) * int(np.prod(leaf.shape))
        except Exception:
            continue
        total += nbytes
    return total

def _infer_throughput_dims_from_args(args):
    """
    Heuristically infer throughput dimensions from the array-like leaves of `args`.

    Rank preference: 3-D arrays are read as (B, S, D) sequences, then 4-D arrays
    as images (tokens = product of the last two dims, i.e. H*W for NCHW), then
    2-D arrays as (B, S). Within the chosen rank the array with the most
    elements wins (first one on ties).

    Returns:
        dict with batch_size, tokens_per_sample, tokens_per_step, mode and
        example_shape, or None when no array of rank 2, 3 or 4 is present.
    """
    arrays, _ = _arrays_and_bytes_from_args(args)
    if not arrays:
        return None

    def element_count(arr):
        try:
            return int(np.prod(arr.shape))
        except Exception:
            return 0

    def rank_of(arr):
        return len(getattr(arr, "shape", ()))

    for rank, mode in ((3, "sequence_3d"), (4, "image_4d"), (2, "sequence_2d")):
        same_rank = [arr for arr in arrays if rank_of(arr) == rank]
        if not same_rank:
            continue
        main = max(same_rank, key=element_count)
        batch_size = int(main.shape[0])
        if rank == 4:
            per_sample = int(main.shape[-2]) * int(main.shape[-1])
        else:
            per_sample = int(main.shape[1])
        return {
            "batch_size": batch_size,
            "tokens_per_sample": per_sample,
            "tokens_per_step": batch_size * per_sample,
            "mode": mode,
            "example_shape": tuple(int(dim) for dim in main.shape),
        }
    return None

# ---------- TPU trace parsing ----------

def extract_tpu_ops(trace_json):
    """Render a plain-text listing of the device (TPU) ops in a profiler trace.

    An event counts as a device op when it is a complete event (ph == "X")
    whose "args" carry "device_duration_ps" or "hlo_category". For each op the
    event name, its "ts"/"dur" values and a fixed set of optional args fields
    are listed; missing fields print as None.

    Chrome trace-event format records "ts" and "dur" in microseconds, so they
    are labelled timestamp_us / duration_us.

    Args:
        trace_json: Parsed Chrome-trace-format dict (reads "traceEvents").

    Returns:
        One "--- TPU Op i ---" section per op (1-based), joined by newlines;
        "" when no device ops are found.
    """
    # Optional per-op fields copied from event["args"], in output order.
    arg_fields = (
        "bytes_accessed", "raw_bytes_accessed", "device_duration_ps",
        "device_offset_ps", "hlo_category", "model_flops",
        "shape_with_layout", "long_name", "tf_op",
    )

    lines = []
    op_index = 0
    for event in trace_json.get("traceEvents", []):
        if event.get("ph") != "X":
            continue
        # `or {}` also covers an explicit "args": null in the JSON.
        args = event.get("args") or {}
        if "device_duration_ps" not in args and "hlo_category" not in args:
            continue

        op_index += 1
        op = {
            "name": event.get("name", ""),
            "timestamp_us": event.get("ts"),
            "duration_us": event.get("dur"),
        }
        op.update((field, args.get(field)) for field in arg_fields)

        lines.append(f"--- TPU Op {op_index} ---")
        lines.extend(f"{k:20}: {v}" for k, v in op.items())
    return "\n".join(lines)

def compute_trace_memory_metrics(trace_json):
    """
    Summarize per-op memory *traffic* (the "bytes_accessed" arg) across the
    device ops of a profiler trace.

    Device ops are complete events (ph == "X") whose args carry
    "device_duration_ps" or "hlo_category". This is NOT peak HBM usage; it is
    the sum / max of bytes touched per op, useful as a bandwidth-style signal.
    Values that are missing or that int() cannot convert (e.g. "12.5", "n/a")
    are skipped.

    Returns:
        dict with ops_with_bytes_accessed, total_bytes_accessed (+ _human) and
        max_bytes_accessed (+ _human).
    """
    total_bytes_accessed = 0
    max_bytes_accessed = 0
    op_count_with_bytes = 0

    for event in trace_json.get("traceEvents", []):
        if event.get("ph") != "X":
            continue
        # `or {}` also covers an explicit "args": null in the JSON.
        args = event.get("args") or {}
        if "device_duration_ps" not in args and "hlo_category" not in args:
            continue

        raw = args.get("bytes_accessed")
        if raw is None:
            continue
        # The value may arrive as a string or a number; a single int() handles both.
        try:
            b = int(raw)
        except (TypeError, ValueError, OverflowError):
            continue

        op_count_with_bytes += 1
        total_bytes_accessed += b
        max_bytes_accessed = max(max_bytes_accessed, b)

    return {
        "ops_with_bytes_accessed": op_count_with_bytes,
        "total_bytes_accessed": total_bytes_accessed,
        "total_bytes_accessed_human": _human_bytes(total_bytes_accessed),
        "max_bytes_accessed": max_bytes_accessed,
        "max_bytes_accessed_human": _human_bytes(max_bytes_accessed),
    }

# ---------- basic timing helpers ----------

# Return function call time in ms.
def profile_func_call(func_name, compiled_func, get_trace, *args):
    """Time one blocking call of `compiled_func(*args)`; optionally capture a trace.

    On the first call of a freshly jitted function, the timed call includes
    tracing and compilation. Later calls measure steady-state execution.

    Args:
        func_name: Names the profiler log directory, logs/<func_name>.
        compiled_func: Callable (typically jax.jit-wrapped) to execute.
        get_trace: If True, run one additional, untimed call under the JAX
            profiler and write its trace to logs/<func_name>.
        *args: Positional arguments forwarded to compiled_func.

    Returns:
        Wall-clock time of the timed call, in milliseconds.
    """
    logs_dir = f"logs/{func_name}"

    start = time.perf_counter()
    jax.block_until_ready(compiled_func(*args))
    elapsed_ms = (time.perf_counter() - start) * 1000

    if get_trace:
        # The context manager stops the trace even if the call raises. Without
        # it, a failure would leave the profiler running and every later
        # start_trace would fail.
        with jax.profiler.trace(logs_dir):
            jax.block_until_ready(compiled_func(*args))

    return elapsed_ms

def write_trace_summary(root_profile_dir):
    """Summarize the newest *.json.gz trace under `root_profile_dir`.

    Writes a text listing of device ops to <root_profile_dir>/summary.txt.

    Returns:
        (summary_path, trace_mem_metrics), or (None, None) when no trace file
        exists (e.g. tracing was skipped or produced nothing).
    """
    trace_files = glob.glob(os.path.join(root_profile_dir, "**", "*.json.gz"), recursive=True)
    if not trace_files:
        return None, None  # no trace

    # glob order is filesystem-dependent. Pick the newest file (ties broken by
    # path) so the choice is deterministic when several traces/hosts exist.
    trace_file = max(trace_files, key=lambda p: (os.path.getmtime(p), p))
    with gzip.open(trace_file, "rt", encoding="utf-8") as f:
        trace_json = json.load(f)

    summary_path = os.path.join(root_profile_dir, "summary.txt")
    with open(summary_path, "w", encoding="utf-8") as f:
        f.write(extract_tpu_ops(trace_json))

    return summary_path, compute_trace_memory_metrics(trace_json)

# ---------- main profiling entrypoint ----------

def _split_weight_and_data_args(func_name, args):
    """Split positional benchmark args into (weight_args, data_args) tuples.

    Conventions of the benchmarks in this notebook:
      - "GEMM": both operands are inputs; there are no weights.
      - "FFN", "MHA", "Conv": args[0] is the input activation, the rest are weights.
      - anything else (model wrappers): args[0] is the params/weights PyTree and
        the remaining args are model inputs.
    """
    if func_name == "GEMM":
        return (), tuple(args)
    if func_name in ("FFN", "MHA", "Conv"):
        return tuple(args[1:]), tuple(args[:1])
    return tuple(args[:1]), tuple(args[1:])


def _read_total_hbm_bytes():
    """Return JAX_TOTAL_HBM_BYTES as a positive int, or None if unset or invalid."""
    env_val = os.environ.get("JAX_TOTAL_HBM_BYTES")
    if env_val is None:
        return None
    try:
        value = int(env_val)
    except ValueError:
        return None
    return value if value > 0 else None


def profile_jax_function(func_name, func, num_runs, *args, export_stablehlo=False):
    """
    Profile a JAX function on the current backend.

    The function is profiled in these steps:
      1. jit-compile `func`.
      2. Time one compile + execute call, capturing a profiler trace into
         logs/<func_name>.
      3. Time `num_runs` steady-state calls.
      4. Summarize the trace and save a device memory profile.
      5. Estimate memory and throughput from the argument shapes.

    Args:
        func_name: Benchmark name. It also selects how args are split into
            weights and data: "GEMM" has no weights; for "FFN"/"MHA"/"Conv"
            args[0] is the input; for anything else args[0] is the params PyTree.
        func: Function to jit and profile.
        num_runs: Number of steady-state timed runs.
        *args: Positional arguments passed to func on every call.
        export_stablehlo: If True, also write the lowered StableHLO text and a
            Graphviz rendering of it (via visualize_mlir) next to the trace.

    Memory section includes:
      - peak_hbm_bytes_est     (heuristic; weights + activations from args)
      - weight_bytes           (bytes of the weight args for this func_name)
      - activation_bytes_est   (total_arg_bytes - weight_bytes)
      - usage_fraction_est     (if JAX_TOTAL_HBM_BYTES env var is a positive int)

    Returns:
      report: dict with timing, throughput, memory, and artifact paths.
    """
    print(f"Profiling {func_name}.\n")

    logs_dir = f"logs/{func_name}"
    if os.path.exists(logs_dir):
        shutil.rmtree(logs_dir)

    compiled_func = jax.jit(func)

    # --- compile + first execution (also captures a trace) ---
    first_call_time = profile_func_call(func_name, compiled_func, True, *args)
    print(f"Run 0 (compile + execute): {first_call_time:.3f} ms.")

    # --- steady-state timings ---
    all_times = []
    for i in range(num_runs):
        run_time = profile_func_call(func_name, compiled_func, False, *args)
        all_times.append(run_time)
        print(f"Run {i+1}: {run_time:.3f} ms")

    # Guard num_runs == 0 so np.mean/np.std do not warn on an empty list.
    avg_time = float(np.mean(all_times)) if all_times else float("nan")
    std_time = float(np.std(all_times)) if all_times else float("nan")
    print(f"\nAverage over {num_runs} runs: {avg_time:.3f} ms ± {std_time:.3f} ms.")

    root_profile_dir = f"{logs_dir}/plugins/profile"

    # --- write textual TPU-op summary + trace memory traffic ---
    summary_path, trace_mem_metrics = write_trace_summary(root_profile_dir)

    # Make sure the artifact directory exists even if no trace was written.
    os.makedirs(root_profile_dir, exist_ok=True)

    # --- optional StableHLO IR export + graph ---
    mlir_path = None
    graph_path = None
    if export_stablehlo:
        mlir_path = f"{root_profile_dir}/stablehlo.txt"
        graph_path = f"{root_profile_dir}/stablehlo"
        stablehlo_ir = compiled_func.lower(*args).compiler_ir("stablehlo")
        with open(mlir_path, "w", encoding="utf-8") as f:
            f.write(str(stablehlo_ir))
        try:
            visualize_mlir(mlir_path, func_name, graph_path)
        except Exception as e:  # best effort: Graphviz binary may be missing
            graph_path = f"GRAPH_FAILED: {type(e).__name__}: {e}"

    # --- device memory profile (for true HBM usage/fragmentation) ---
    memory_prof_path = os.path.join(root_profile_dir, "memory.prof")
    try:
        jax.profiler.save_device_memory_profile(memory_prof_path)
    except Exception as e:
        memory_prof_path = f"PROFILE_FAILED: {type(e).__name__}: {e}"

    # ---------- memory section ----------
    weight_args, data_args = _split_weight_and_data_args(func_name, args)

    # 1) Total size of all array-like args (params + inputs): a lower bound on
    #    what needs device memory at some point.
    _, total_arg_bytes = _arrays_and_bytes_from_args(args)

    # 2) Weight bytes, per the func_name convention above.
    weight_bytes = _tree_bytes(weight_args)

    # 3) Activation bytes (approx): "everything else".
    activation_bytes_est = max(total_arg_bytes - weight_bytes, 0)

    # 4) Peak HBM (approx): weights + these activations.
    peak_hbm_bytes_est = weight_bytes + activation_bytes_est

    # 5) Total HBM capacity (optional, from env), and usage fraction. Computed
    #    whenever the capacity is known (0 peak gives 0.0 rather than None).
    total_hbm_bytes = _read_total_hbm_bytes()
    usage_fraction_est = (peak_hbm_bytes_est / total_hbm_bytes) if total_hbm_bytes else None

    memory_summary = {
        # Core metrics:
        "peak_hbm_bytes_est": peak_hbm_bytes_est,
        "peak_hbm_human": _human_bytes(peak_hbm_bytes_est),
        "weight_bytes": weight_bytes,
        "weight_human": _human_bytes(weight_bytes),
        "activation_bytes_est": activation_bytes_est,
        "activation_human": _human_bytes(activation_bytes_est),
        "usage_fraction_est": usage_fraction_est,

        # Context:
        "total_hbm_bytes_config": total_hbm_bytes,
        "total_hbm_human": _human_bytes(total_hbm_bytes) if total_hbm_bytes else None,
        "arg_total_bytes": total_arg_bytes,
        "arg_total_human": _human_bytes(total_arg_bytes),
        "note": (
            "peak_hbm_bytes_est is a heuristic based on argument sizes. "
            "True peak HBM and fragmentation should be read from memory.prof "
            "via pprof or TensorBoard."
        ),
        "trace_memory": trace_mem_metrics,
    }

    # --- throughput metrics (tokens/s, images/s) ---
    # Only data args are inspected. Otherwise large weight tensors (e.g.
    # embedding or position tables) would be mistaken for the input batch.
    throughput_dims = _infer_throughput_dims_from_args(data_args)
    throughput_summary = None
    if throughput_dims is not None and avg_time > 0:
        avg_step_s = avg_time / 1000.0
        B = throughput_dims["batch_size"]
        tokens_per_step = throughput_dims["tokens_per_step"]
        mode = throughput_dims["mode"]

        tokens_per_second = tokens_per_step / avg_step_s
        images_per_second = B / avg_step_s if mode == "image_4d" else None

        throughput_summary = {
            "mode": mode,
            "batch_size": B,
            "tokens_per_sample": throughput_dims["tokens_per_sample"],
            "tokens_per_step": tokens_per_step,
            "avg_step_ms": avg_time,
            "tokens_per_second": tokens_per_second,
            "images_per_second": images_per_second,
            "example_input_shape": throughput_dims["example_shape"],
        }

        print("\nThroughput estimate:")
        print(f"  Mode: {mode}")
        print(f"  Batch size: {B}")
        print(f"  Tokens per step: {tokens_per_step}")
        print(f"  Tokens/sec (avg): {tokens_per_second:,.2f}")
        if images_per_second is not None:
            print(f"  Images/sec (avg): {images_per_second:,.2f}")
    else:
        print("\nThroughput estimate: could not infer from argument shapes.")

    # --- pretty-print some memory stats ---
    print("\nMemory estimates:")
    print(f"  Weight bytes:          {memory_summary['weight_human']}")
    print(f"  Activation bytes (est):{memory_summary['activation_human']}")
    print(f"  Peak HBM (est):        {memory_summary['peak_hbm_human']}")
    if usage_fraction_est is not None:
        print(f"  Total HBM (config):    {memory_summary['total_hbm_human']}")
        print(f"  Usage fraction (est):  {usage_fraction_est:.3f}")
    else:
        print("  Total HBM unknown (set JAX_TOTAL_HBM_BYTES to get usage fraction).")

    if trace_mem_metrics is not None:
        print("\nTrace memory traffic (bytes_accessed across TPU ops):")
        print(f"  Ops with bytes_accessed: {trace_mem_metrics['ops_with_bytes_accessed']}")
        print(f"  Total bytes_accessed:    {trace_mem_metrics['total_bytes_accessed_human']}")
        print(f"  Max bytes_accessed/op:   {trace_mem_metrics['max_bytes_accessed_human']}")

    not_exported = "not exported (pass export_stablehlo=True)"
    print("\nArtifacts:")
    print(f"  TPU trace summary: {summary_path}")
    print(f"  StableHLO IR:      {mlir_path or not_exported}")
    print(f"  StableHLO graph:   {graph_path or not_exported}")
    print(f"  Device memory profile (HBM/fragmentation via pprof/TensorBoard): {memory_prof_path}")

    report = {
        "func_name": func_name,
        "timing_ms": {
            "first_run_ms": first_call_time,
            "runs_ms": all_times,
            "mean_ms": avg_time,
            "std_ms": std_time,
        },
        "throughput": throughput_summary,
        "memory": memory_summary,
        "paths": {
            "trace_summary_txt": summary_path,
            "stablehlo_txt": mlir_path,
            "stablehlo_graph_dir": graph_path,
            "device_memory_profile": memory_prof_path,
            "tensorboard_logdir": logs_dir,
        },
    }

    print("\nJSON report:")
    print(json.dumps(report, indent=2, default=str))

    return report


In [None]:
# GEMM benchmark: (128, 256) @ (256, 1024) with random-normal inputs.
keys = random.split(master_key, 2)

A = random.normal(keys[0], (128, 256))
B = random.normal(keys[1], (256, 1024))

# func_name "GEMM" tells profile_jax_function that neither operand is a weight.
profile_jax_function("GEMM", gemm, num_runs, A, B)

# Trace had two steps: copy-done and fusion.
# The copy appears to move only tensor B (according to ChatGPT, it re-lays out the tensor in memory for more efficient computation).
  # %copy-done = f32[256,1024]{1,0:T(8,128)S(1)} copy-done((f32[256,1024]{1,0:T(8,128)S(1)}, f32[256,1024]{1,0:T(8,128)}, u32[]{:S(2)}) %copy-start)
# The fusion read all of A and B and wrote the output.
  # %fusion = f32[128,1024]{1,0:T(8,128)} fusion(f32[128,256]{1,0:T(8,128)} %Arg_0.1, f32[256,1024]{1,0:T(8,128)S(1)} %copy-done), kind=kOutput, calls=%fused_computation

In [None]:
# FFN benchmark: x (128, 64, 256) -> 1024 hidden units (ReLU) -> 256 outputs.
keys = random.split(master_key, 5)

x = random.normal(keys[0], (128, 64, 256))
w1 = random.normal(keys[1], (256, 1024))
b1 = random.normal(keys[2], (1024,))
w2 = random.normal(keys[3], (1024, 256))
b2 = random.normal(keys[4], (256,))

# func_name "FFN" makes profile_jax_function count args 1-4 (w1, b1, w2, b2) as weights.
profile_jax_function("FFN", ffn, num_runs, x, w1, b1, w2, b2)

# First matrix multiplication is broken up into fusion and copy.1
  # %copy-done = f32[256,1024]{1,0:T(8,128)S(1)} copy-done((f32[256,1024]{1,0:T(8,128)S(1)}, f32[256,1024]{1,0:T(8,128)}, u32[]{:S(2)}) %copy-start)
  # %fusion = f32[128,64,1024]{2,0,1:T(8,128)S(1)} fusion(f32[128,64,256]{2,1,0:T(8,128)} %Arg_0.1, f32[256,1024,1]{1,0,2:T(8,128)S(1)} %bitcast.8), kind=kOutput, calls=%fused_computation
  # %copy.1 = f32[128,64,1024]{2,1,0:T(8,128)} copy(f32[128,64,1024]{2,0,1:T(8,128)S(1)} %fusion)
# We add the first bias.
  # %broadcast_add_fusion = f32[128,64,1024]{2,1,0:T(8,128)} fusion(f32[128,64,1024]{2,1,0:T(8,128)} %Arg_0.1, f32[1024]{0:T(1024)} %Arg_1.2), kind=kLoop, calls=%fused_computation
# Then the ReLU (an elementwise maximum against a broadcast 0) runs as a loop fusion on the device.
  # %broadcast_maximum_fusion = f32[128,64,1024]{2,1,0:T(8,128)} fusion(f32[128,64,1024]{2,1,0:T(8,128)} %Arg_0.1), kind=kLoop, calls=%fused_computation
# Then the second matmul.
  # %fusion.3 = f32[128,64,256]{2,1,0:T(8,128)} fusion(f32[128,64,1024]{2,1,0:T(8,128)S(1)} %copy-done, f32[1024,256,1]{1,0,2:T(8,128)} %bitcast.9), kind=kOutput, calls=%fused_computation.2
# Then the second bias is added.
  # %broadcast_add_fusion = f32[128,64,256]{2,1,0:T(8,128)} fusion(f32[128,64,256]{2,1,0:T(8,128)} %Arg_0.1, f32[256]{0:T(256)} %Arg_1.2), kind=kLoop, calls=%fused_computation

# Maybe missing some details.

In [None]:
# MHA benchmark: self-attention over x (128, 64, 256) with four (256, 256) projections.
keys = random.split(master_key, 5)

# NOTE(review): `heads` is not passed to mha below; mha uses its default
# num_heads=16, which happens to match this value.
batch, seq, embed, heads = 128, 64, 256, 16

x = random.normal(keys[0], (batch, seq, embed))
w_q = random.normal(keys[1], (embed, embed))
w_k = random.normal(keys[2], (embed, embed))
w_v = random.normal(keys[3], (embed, embed))
w_o = random.normal(keys[4], (embed, embed))

# func_name "MHA" makes profile_jax_function count args 1-4 (the projections) as weights.
profile_jax_function("MHA", mha, num_runs, x, w_q, w_k, w_v, w_o)

In [None]:
# Uses JAX's built-in convolution primitive (jax.lax.conv_general_dilated), which is probably why it runs faster.

# Conv benchmark: 3x3 convolution, 3 -> 3 channels, on a (128, 3, 12, 12) NCHW batch.
keys = random.split(master_key, 2)

batch, in_channels, input_height, input_width = 128, 3, 12, 12
out_channels, kernel_height, kernel_width = 3, 3, 3

# Stride is 1 and padding is 'SAME' (one pixel per side for this 3x3 kernel);
# both are fixed inside conv_layer.

x = random.normal(keys[0], (batch, in_channels, input_height, input_width))   # NCHW
w = random.normal(keys[1], (out_channels, in_channels, kernel_height, kernel_width))       # OIHW

profile_jax_function("Conv", conv_layer, num_runs, x, w)

In [None]:
def make_image_model_jax_fn(model_name, batch_size=8, height=224, width=224):
    """
    Wrap a Hugging Face image classifier as a pure JAX function.

    Only "google/vit-base-patch16-224" takes the Flax path. Any other name
    (e.g. "microsoft/resnet-50") is loaded as a PyTorch model and converted
    with torchax.extract_jax.

    Args:
        model_name: Hugging Face model id.
        batch_size, height, width: Shape of the dummy input batch.

    Returns:
        (jax_forward, params_or_weights, dummy_pixel_values):
          jax_forward(params_or_weights, pixel_values) returns the logits.
          dummy_pixel_values is a float32 array of ones, shape
          (batch_size, 3, height, width) (NCHW).
    """
    pixels = jnp.ones((batch_size, 3, height, width), dtype=jnp.float32)

    if model_name != "google/vit-base-patch16-224":
        # PyTorch model converted to a functional JAX callable via torchax.
        torch_model = AutoModelForImageClassification.from_pretrained(model_name)
        torch_weights, torch_fn = torchax.extract_jax(torch_model)

        def jax_forward(weights, pixel_values):
            # With return_dict=False the model returns a tuple whose first
            # element is the logits.
            logits, *_ = torch_fn(
                weights,
                (),
                {"pixel_values": pixel_values, "return_dict": False},
            )
            return logits

        print("Using PyTorch and torchax!")
        return jax_forward, torch_weights, pixels

    flax_model = FlaxViTForImageClassification.from_pretrained(model_name)

    def jax_forward(params, pixel_values):
        return flax_model(
            pixel_values=pixel_values,
            params=params,
            train=False,
        ).logits

    print("Using Flax!")
    return jax_forward, flax_model.params, pixels

In [None]:
# 1) Vision: ResNet-50
# This name is not matched by the Flax branch of make_image_model_jax_fn, so
# the model goes through the PyTorch + torchax path. resnet_params holds the
# extracted weights and is passed first, so profile_jax_function treats it as
# the weights argument.
print("\n================ Vision: ResNet-50 ================")
resnet_fn, resnet_params, resnet_pixels = make_image_model_jax_fn(
    "microsoft/resnet-50",
    batch_size=8,
    height=224,
    width=224,
)
profile_jax_function(
    "ResNet-50",
    resnet_fn,
    num_runs,
    resnet_params,
    resnet_pixels,
)

In [None]:

# 2) Vision: ViT-B/16
# This exact model id selects the Flax branch of make_image_model_jax_fn;
# vit_params is the Flax parameter PyTree, passed first as the weights.
print("\n================ Vision: ViT-B/16 ================")
vit_fn, vit_params, vit_pixels = make_image_model_jax_fn(
    "google/vit-base-patch16-224",
    batch_size=8,
    height=224,
    width=224,
)
profile_jax_function(
    "ViT-B-16",
    vit_fn,
    num_runs,
    vit_params,
    vit_pixels,
)

In [None]:

def make_bert_flax_jax_fn(model_name="bert-base-uncased", batch_size=4, seq_len=128):
    """
    Load a Flax BERT sequence classifier together with a dummy tokenized batch.

    The batch is `batch_size` copies of one fixed sentence, padded/truncated to
    exactly `seq_len` tokens and returned as JAX arrays.

    Returns:
        (jax_forward, params, input_ids, attention_mask), where
        jax_forward(params, input_ids, attention_mask) runs the model with
        train=False and returns logits [batch, num_labels]. The params-first
        argument order is what profile_jax_function expects for model wrappers.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = FlaxBertForSequenceClassification.from_pretrained(model_name)

    batch = tokenizer(
        ["this is a dummy sentence"] * batch_size,
        padding="max_length",
        truncation=True,
        max_length=seq_len,
        return_tensors="jax",
    )

    def jax_forward(params, input_ids, attention_mask):
        return model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            params=params,
            train=False,
        ).logits

    return jax_forward, model.params, batch["input_ids"], batch["attention_mask"]

In [None]:
# 3) NLP: BERT-base (sequence classification)
# bert_params (the Flax parameter PyTree) is passed first, so profile_jax_function
# treats it as the weights argument; bert_ids / bert_mask are the (4, 128) inputs.
print("\n================ NLP: BERT-base (Flax) ================")
bert_fn, bert_params, bert_ids, bert_mask = make_bert_flax_jax_fn(
    model_name="bert-base-uncased",
    batch_size=4,
    seq_len=128,
)

profile_jax_function(
    "BERT-base-Flax",
    bert_fn,
    num_runs,
    bert_params,
    bert_ids,
    bert_mask,
)

In [None]:
def make_gpt_like_jax_fn(
    model_name,
    batch_size=2,
    seq_len=64,
    vocab_example_text="The quick brown fox jumps over the lazy dog.",
):
    """
    Build a GPT-like CausalLM forward pass for profile_jax_function.

    - If ``model_name`` is exactly "gpt2" or "EleutherAI/gpt-neo-1.3B", load
      the corresponding Hugging Face Flax causal-LM class and run it natively.
    - Otherwise (e.g., LLaMA-3.1-8B), load the PyTorch model and convert it
      to a pure JAX function with torchax.

    Args:
      model_name: Hugging Face Hub id.
      batch_size: number of copies of ``vocab_example_text`` in the batch.
      seq_len: every row is padded/truncated to exactly this many tokens.
      vocab_example_text: text tokenized to build the dummy batch.

    Returns:
      jax_forward(params_or_weights, input_ids, attention_mask) -> logits,
      params_or_weights,
      input_ids,
      attention_mask
    """
    # The tokenizer comes from the same repo as the weights. For gated repos
    # (e.g. meta-llama/*) it needs the same credentials, and it is fetched
    # first, so pass the token here too.
    tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)

    # Causal-LM tokenizers often ship without a pad token; reuse EOS so that
    # padding="max_length" works.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    encoded = tokenizer(
        [vocab_example_text] * batch_size,
        padding="max_length",
        max_length=seq_len,
        truncation=True,
        return_tensors="jax",
    )
    input_ids = encoded["input_ids"]
    attention_mask = encoded["attention_mask"]

    # ---- Flax path: exact-name dispatch to the HF Flax causal-LM classes ----
    flax_classes = {
        "gpt2": FlaxGPT2LMHeadModel,
        "EleutherAI/gpt-neo-1.3B": FlaxGPTNeoForCausalLM,
    }
    flax_cls = flax_classes.get(model_name)
    if flax_cls is not None:
        model = flax_cls.from_pretrained(model_name)

        def jax_forward(params, input_ids, attention_mask):
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                params=params,
                train=False,
            )
            return outputs.logits  # [batch, seq, vocab]

        print("Using Flax!")
        return jax_forward, model.params, input_ids, attention_mask

    # ---- Fallback: PyTorch CausalLM via torchax (e.g., LLaMA-3.1-8B) ----
    pt_model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=None, token=hf_token)
    weights, raw_func = torchax.extract_jax(pt_model)

    def jax_forward(weights, input_ids, attention_mask):
        # Disable cache and return_dict to avoid custom output types; the
        # model then returns a plain tuple whose first element is the logits.
        logits, *_ = raw_func(
            weights,
            (),
            {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "use_cache": False,
                "return_dict": False,
            },
        )
        return logits

    print("Using PyTorch and torchax!")
    return jax_forward, weights, input_ids, attention_mask

In [None]:

# 4) NLP: GPT-2 small
print("\n================ NLP: GPT-2 small (Flax) ================")
# "gpt2" is one of the exact names make_gpt_like_jax_fn serves with a Flax
# model; returns (forward fn, params, input_ids, attention_mask).
gpt2_fn, gpt2_params, gpt2_ids, gpt2_mask = make_gpt_like_jax_fn(
    model_name="gpt2",
    batch_size=2,
    seq_len=64,
    vocab_example_text="The quick brown fox jumps over the lazy dog.",
)
# Trailing args follow gpt2_fn's signature: (params, input_ids, attention_mask).
profile_jax_function(
    "GPT-2-small-Flax",
    gpt2_fn,
    num_runs,
    gpt2_params,
    gpt2_ids,
    gpt2_mask,
)


In [None]:
# 5) NLP: GPT-Neo 1.3B
print("\n================ NLP: GPT-Neo-1.3B (Flax) ================")
# "EleutherAI/gpt-neo-1.3B" is matched by name in make_gpt_like_jax_fn and
# loaded as a Flax model.
gptneo_fn, gptneo_params, gptneo_ids, gptneo_mask = make_gpt_like_jax_fn(
    model_name="EleutherAI/gpt-neo-1.3B",
    batch_size=1,       # keep small, it's big
    seq_len=64,
    vocab_example_text="Dynamic pruning and batching for large language models.",
)
# Trailing args follow gptneo_fn's signature: (params, input_ids, attention_mask).
profile_jax_function(
    "GPT-Neo-1.3B-Flax",
    gptneo_fn,
    num_runs,
    gptneo_params,
    gptneo_ids,
    gptneo_mask,
)

In [None]:

# 6) NLP: LLaMA-3.1 8B
print("\n================ NLP: LLaMA-3.1-8B (torchax) ================")
# This name is not one of make_gpt_like_jax_fn's Flax names, so it takes the
# PyTorch + torchax fallback, loaded with torch_dtype=None and hf_token.
# NOTE(review): meta-llama repos are presumably gated (the HF token must have
# access), and torch_dtype=None presumably loads full fp32 weights (~32 GB for
# 8B params) — confirm the runtime has enough host/device memory.
llama_fn, llama_weights, llama_ids, llama_mask = make_gpt_like_jax_fn(
    model_name="meta-llama/Meta-Llama-3.1-8B-Instruct",
    batch_size=1,
    seq_len=64,
    vocab_example_text="Large language models can be pruned dynamically per prompt.",
)
# Trailing args follow llama_fn's signature: (weights, input_ids, attention_mask).
profile_jax_function(
    "LLaMA-3.1-8B",
    llama_fn,
    num_runs,
    llama_weights,
    llama_ids,
    llama_mask,
)


In [None]:
def make_unet_jax_fn(
    model_name="runwayml/stable-diffusion-v1-5",
    batch_size=2,
    height=64,
    width=64,
    channels=4,
):
    """
    Build a Stable Diffusion UNet forward pass for profile_jax_function.

    First tries diffusers' FlaxUNet2DConditionModel, converting the PyTorch
    checkpoint with ``from_pt=True``. If that load fails for any reason, the
    reason is printed and the PyTorch UNet2DConditionModel is wrapped with
    torchax instead.

    Args:
      model_name: Hub id of a diffusers repo with a "unet" subfolder.
      batch_size: leading dimension of every dummy input.
      height, width: spatial size of the dummy latent.
      channels: channel count of the dummy latent.

    Returns:
      (jax_forward, params_or_weights, sample, timesteps, encoder_hidden_states),
      where jax_forward(params_or_weights, sample, timesteps,
      encoder_hidden_states) returns the UNet's output ``sample`` tensor.

    Raises:
      ImportError: if diffusers could not be imported (HAS_DIFFUSERS is False).
    """
    if not HAS_DIFFUSERS:
        raise ImportError("diffusers is not installed, cannot create UNet model.")

    # Dummy latent (batch, channels, height, width) + timestep + encoder hidden states
    sample = jnp.ones(
        (batch_size, channels, height, width), dtype=jnp.float32
    )
    timesteps = jnp.array([1.0] * batch_size, dtype=jnp.float32)
    # (batch, 77, 768): presumably the CLIP text-encoder shape for SD-1.x;
    # other checkpoints may expect a different width — confirm.
    encoder_hidden_states = jnp.ones(
        (batch_size, 77, 768), dtype=jnp.float32
    )

    # ---- Flax UNet if available ----
    # Only the load is guarded: a failure here means "no Flax UNet", but the
    # reason is reported instead of being silently swallowed.
    try:
        # diffusers' Flax models are linen Modules: from_pretrained returns a
        # (module, params) tuple, not an object carrying a .params attribute.
        flax_unet, flax_params = FlaxUNet2DConditionModel.from_pretrained(
            model_name,
            subfolder="unet",
            from_pt=True,  # load PT weights into Flax
        )
    except Exception as err:  # best-effort: any load failure -> torchax fallback
        print(f"Flax UNet unavailable ({type(err).__name__}: {err}); falling back to torchax.")
    else:
        def jax_forward(params, sample, timesteps, encoder_hidden_states):
            # Linen modules are executed through .apply with an explicit
            # variables dict; the output object exposes the result as .sample.
            outputs = flax_unet.apply(
                {"params": params},
                sample,
                timesteps,
                encoder_hidden_states,
                train=False,
            )
            return outputs.sample

        print("Using Flax!")
        return jax_forward, flax_params, sample, timesteps, encoder_hidden_states

    # ---- Fallback: PyTorch UNet via torchax ----
    unet = UNet2DConditionModel.from_pretrained(
        model_name,
        subfolder="unet"
    )
    weights, raw_func = torchax.extract_jax(unet)

    def jax_forward(weights, sample, timesteps, encoder_hidden_states):
        # return_dict=False makes the model return a 1-tuple (sample,).
        (out_sample,) = raw_func(
            weights,
            (),
            {
                "sample": sample,
                "timestep": timesteps,
                "encoder_hidden_states": encoder_hidden_states,
                "return_dict": False,
            },
        )
        return out_sample

    print("Using PyTorch and torchax!")
    return jax_forward, weights, sample, timesteps, encoder_hidden_states

In [None]:
# Diffusion: UNet (Stable Diffusion)
print("\n================ Diffusion: UNet (Stable Diffusion) ================")
# make_unet_jax_fn raises ImportError when diffusers is missing (HAS_DIFFUSERS).
# NOTE(review): the runwayml/stable-diffusion-v1-5 repo may no longer be
# available on the Hub; if loading fails, a mirror such as
# stable-diffusion-v1-5/stable-diffusion-v1-5 may be needed — confirm.
# `sample` / `timesteps` become notebook-level globals here; avoid reusing
# those names for other data in later cells.
unet_fn, unet_params, sample, timesteps, enc_hid = make_unet_jax_fn(
    model_name="runwayml/stable-diffusion-v1-5",
    batch_size=1,
    height=64,
    width=64,
    channels=4,
)
# Trailing args follow unet_fn's signature:
# (params_or_weights, sample, timesteps, encoder_hidden_states).
profile_jax_function(
    "UNet-StableDiffusion",
    unet_fn,
    num_runs,
    unet_params,
    sample,
    timesteps,
    enc_hid,
)