In [1]:
import jax
import jax.numpy as jnp
from jax import random



In [2]:
# Root PRNG key (seed 0); each benchmark cell splits it to derive its input tensors.
master_key = random.PRNGKey(0)
# Number of timed calls made per benchmark after the initial (compile + trace) call.
num_runs = 10

In [3]:
def gemm(a, b):
  """General matrix multiply: return the product of `a` and `b` via `jnp.matmul`."""
  product = jnp.matmul(a, b)
  return product

In [4]:
def ffn(x, w1, b1, w2, b2):
  """
  Two-layer feed-forward network: linear -> ReLU -> linear.

  Shapes used by the benchmark in this notebook:
    x=(128, 64, 256)
    w1=(256, 1024), b1=(1024,)
    w2=(1024, 256), b2=(256,)

  Returns `relu(x @ w1 + b1) @ w2 + b2`.
  """
  # Hidden activation: first projection, bias, then ReLU.
  hidden = jax.nn.relu(jnp.matmul(x, w1) + b1)
  # Second projection back to the output width, then bias.
  projected = jnp.matmul(hidden, w2)
  return projected + b2

In [5]:
def mha(x, w_q, w_k, w_v, w_o, num_heads=16):
    """
    Multi-head self-attention (MHA) benchmark kernel.

    Args:
        x: Input tensor of shape (batch, seq_len, embed_dim).
        w_q, w_k, w_v: Projection weights for Q, K, V, each (embed_dim, embed_dim).
        w_o: Output projection weight (embed_dim, embed_dim).
        num_heads: Number of attention heads (default 16). The per-head width
            is `embed_dim // num_heads`.
    Returns:
        Tensor of shape (batch, seq_len, embed_dim).
    """
    batch, seq_len, embed_dim = x.shape
    head_dim = embed_dim // num_heads

    def split_heads(projected):
        # (batch, seq_len, embed_dim) -> (batch, num_heads, seq_len, head_dim)
        per_head = projected.reshape(batch, seq_len, num_heads, head_dim)
        return per_head.transpose(0, 2, 1, 3)

    # Project every token to its query / key / value vectors, one slice per head.
    q_heads = split_heads(jnp.dot(x, w_q))
    k_heads = split_heads(jnp.dot(x, w_k))
    v_heads = split_heads(jnp.dot(x, w_v))

    # Scaled dot-product attention, computed independently for each head.
    logits = jnp.matmul(q_heads, jnp.swapaxes(k_heads, -2, -1)) / jnp.sqrt(head_dim)  # (B, H, S, S)
    weights = jax.nn.softmax(logits, axis=-1)
    context = jnp.matmul(weights, v_heads)                                            # (B, H, S, head_dim)

    # Merge the heads back into a single embedding axis.
    merged = context.transpose(0, 2, 1, 3).reshape(batch, seq_len, embed_dim)

    # Project back to the input embedding dimension.
    return jnp.dot(merged, w_o)

In [6]:
def conv_layer(x, w):
  """
  2-D convolution with stride 1 and 'SAME' padding.

    x = input tensor of shape (N, C, H, W):
        N = batch size, C = input channels (like RGB),
        H = input height, W = input width
    w = weight tensor of shape (O, I, KH, KW):
        O = output channels (number of filters), I = input channels,
        KH = kernel height, KW = kernel width

  The result is returned in NCHW layout.
  """
  layout = ('NCHW', 'OIHW', 'NCHW')  # (input, kernel, output) dimension order
  unit_stride = (1, 1)
  return jax.lax.conv_general_dilated(
      x, w, unit_stride, 'SAME', dimension_numbers=layout)


In [7]:
!pip install graphviz

from pathlib import Path
import re
from graphviz import Digraph

def visualize_mlir(mlir_path, func_name, graph_path):
  """Render a rough dataflow graph of an MLIR text dump as a PNG via graphviz.

  Every line of the form `%result = op.name <rest>` becomes a node labelled
  with the result and op name, with one edge from each `%value` mentioned in
  `<rest>`. Function arguments (`%argN:`) are drawn as boxes, and the values
  named in the first `return ... :` statement are linked to a `return` box.

  Args:
    mlir_path: Path of the MLIR text file to read.
    func_name: Unused; kept so existing callers keep working.
    graph_path: Output path handed to graphviz `render` (called with
      view=True and cleanup=True).
  """
  with open(mlir_path) as mlir_file:
    mlir_text = mlir_file.read()

  # matches like: %2 = stablehlo.dot_general %0, %1, ...
  assignment_re = re.compile(r"(%\w+)\s*=\s*([\w\.]+)\s+(.*)")
  value_re = re.compile(r"%\w+")

  # One (result, op name, operand list) triple per matching line, in file order.
  line_matches = (assignment_re.match(raw.strip()) for raw in mlir_text.splitlines())
  statements = [
      (m.group(1), m.group(2), value_re.findall(m.group(3)))
      for m in line_matches
      if m
  ]

  graph = Digraph("StableHLO", format="png")
  graph.attr(rankdir="LR")

  for result, op_name, operands in statements:
    graph.node(result, f"{result}\n{op_name}")
    for operand in operands:
      graph.edge(operand, result)

  # Function arguments are declared as `%argN: type` in the signature.
  for arg_name in re.findall(r"(%arg\d+):", mlir_text):
    graph.node(arg_name, f"{arg_name}\ninput", shape="box")

  return_match = re.search(r"return\s+(.*)\s*:", mlir_text)
  if return_match is not None:
    for returned in value_re.findall(return_match.group(1)):
      graph.edge(returned, "return")
    graph.node("return", "return", shape="box")

  graph.render(graph_path, view=True, cleanup=True)


Collecting graphviz
  Downloading graphviz-0.21-py3-none-any.whl.metadata (12 kB)
Downloading graphviz-0.21-py3-none-any.whl (47 kB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/47.3 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m47.3/47.3 kB[0m [31m2.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: graphviz
Successfully installed graphviz-0.21


In [8]:
import time, math
import jax, os
import jax.numpy as jnp
import shutil, glob, gzip
import numpy as np

import json

import gc

# Run a garbage-collection pass and clear JAX's compilation caches whenever
# this cell is (re-)run — presumably so that the first profiled call of each
# benchmark pays the full compile cost again; confirm intent.
gc.collect()
jax.clear_caches()

def extract_tpu_ops(trace_json):
    """Summarise the device-op events of a Chrome-trace-style profile as text.

    An event is kept when its phase (`ph`) is "X" and its `args` contain
    `device_duration_ps` or `hlo_category`. For every kept event the name,
    `ts`, `dur` and a fixed set of optional `args` fields are listed; fields
    missing from the event are reported as None.

    Args:
        trace_json: Parsed trace dictionary with an optional "traceEvents" list.

    Returns:
        A newline-joined string with one "--- TPU Op N ---" section per kept
        event (empty string when nothing matches).
    """
    # Optional per-op attributes copied from the event's "args", in output order.
    # (Previously this list existed twice; the unused copy was missing a comma,
    # which silently fused "long_name" and "shape_with_layout" into one string.)
    arg_fields = ("bytes_accessed", "raw_bytes_accessed", "device_duration_ps",
                  "device_offset_ps", "hlo_category", "model_flops",
                  "shape_with_layout", "long_name", "tf_op")

    tpu_ops = []

    # `or []` / `or {}` tolerate keys that are present but null in the JSON.
    for event in trace_json.get("traceEvents") or []:
        if event.get("ph") != "X":
            continue

        args = event.get("args") or {}
        if "device_duration_ps" not in args and "hlo_category" not in args:
            continue

        # NOTE(review): the Chrome trace-event format documents `ts`/`dur` in
        # microseconds, so the `_ns` labels may be misleading. They are kept
        # unchanged so existing summaries stay comparable — confirm the unit.
        op = {
            "name": event.get("name", ""),
            "timestamp_ns": event.get("ts"),
            "duration_ns": event.get("dur"),
        }
        op.update((field, args.get(field)) for field in arg_fields)

        tpu_ops.append(op)

    lines = []
    for i, op in enumerate(tpu_ops, 1):
        lines.append(f"--- TPU Op {i} ---")
        lines.extend(f"{k:20}: {v}" for k, v in op.items())
    return "\n".join(lines)

# Return function call time in ms.
def profile_func_call(func_name, compiled_func, get_trace, *args):
    """Time one call of `compiled_func(*args)` and optionally capture a trace.

    Args:
        func_name: Benchmark name; the trace is written under `logs/<func_name>`.
        compiled_func: Callable to invoke (typically a `jax.jit`-wrapped function).
        get_trace: When true, the function is called a second time under the
            JAX profiler after the timed call.
        *args: Positional arguments forwarded to `compiled_func`.

    Returns:
        Wall-clock duration of the first (timed) call in milliseconds. The
        traced call is not included in this number.
    """
    logs_dir = f"logs/{func_name}"

    # Timed call. On the first invocation of a freshly jitted function this
    # includes compile time. `jax.block_until_ready` waits for every array in
    # the result, so functions returning tuples/dicts of arrays work too.
    t_s = time.perf_counter()
    jax.block_until_ready(compiled_func(*args))
    t_e = time.perf_counter()

    # Generate trace.
    if get_trace:
        jax.profiler.start_trace(logs_dir)
        try:
            jax.block_until_ready(compiled_func(*args))
        finally:
            # Always stop the profiler, otherwise a failure here would leave
            # it running and break the next start_trace call.
            jax.profiler.stop_trace()

    return (t_e - t_s) * 1000

def write_trace_summary(root_profile_dir):
    """Write `summary.txt` describing the device ops of a captured trace.

    Searches `root_profile_dir` recursively for gzip-compressed JSON traces
    (`*.json.gz`), parses the most recently modified one, and writes the text
    produced by `extract_tpu_ops` to `<root_profile_dir>/summary.txt`.

    Args:
        root_profile_dir: Directory that contains the profiler output.

    Raises:
        FileNotFoundError: If no `*.json.gz` file exists under the directory.
    """
    trace_files = glob.glob(os.path.join(root_profile_dir, "**", "*.json.gz"), recursive=True)
    if not trace_files:
        # Previously this surfaced as a bare IndexError on trace_files[0].
        raise FileNotFoundError(
            f"No *.json.gz trace found under {root_profile_dir!r}; "
            "was a profiler trace captured?"
        )

    # glob returns matches in arbitrary order; choose the newest trace so the
    # result is well-defined when several traces are present.
    latest_trace = max(trace_files, key=os.path.getmtime)

    with gzip.open(latest_trace, "rt", encoding="utf-8") as f:
        trace_json = json.load(f)

    summary_text = extract_tpu_ops(trace_json)

    summary_path = os.path.join(root_profile_dir, "summary.txt")
    # Write with the same encoding the trace was read with.
    with open(summary_path, "w", encoding="utf-8") as f:
        f.write(summary_text)

def profile_jax_function(func_name, func, num_runs, *args):
    """Benchmark `func` under `jax.jit` and dump its profiling artifacts.

    Steps, in order:
      1. Delete any previous `logs/<func_name>` directory.
      2. Jit `func`, time the first call ("Run 0") and capture a profiler trace.
      3. Time `num_runs` further calls and print their mean and standard deviation.
      4. Write the trace summary, the StableHLO text and its rendered graph
         under `logs/<func_name>/plugins/profile`.

    Args:
        func_name: Label used in the printed output and the log directory name.
        func: Function to jit and profile.
        num_runs: Number of timed calls made after the first one.
        *args: Positional arguments passed to `func` on every call.
    """
    print(f"Profiling {func_name}.\n")

    # Start from a clean log directory so only this run's trace is present.
    logs_dir = f"logs/{func_name}"
    if os.path.exists(logs_dir):
        shutil.rmtree(logs_dir)

    jitted = jax.jit(func)

    # First call: timed separately and also traced.
    first_ms = profile_func_call(func_name, jitted, True, *args)
    print(f"Run 0: {first_ms:.3f} ms.")

    timings_ms = []
    for run_idx in range(1, num_runs + 1):
        elapsed_ms = profile_func_call(func_name, jitted, False, *args)
        timings_ms.append(elapsed_ms)
        print(f"Run {run_idx}: {elapsed_ms:.3f} ms")

    mean_ms = np.mean(timings_ms)
    spread_ms = np.std(timings_ms)
    print(f"\nAverage over {num_runs} runs: {mean_ms:.3f} ms ± {spread_ms:.3f} ms.")

    profile_dir = f"{logs_dir}/plugins/profile"
    write_trace_summary(profile_dir)

    # Dump the lowered StableHLO next to the trace and render it as a graph.
    lowered_ir = jitted.lower(*args).compiler_ir("stablehlo")
    mlir_path = f"{profile_dir}/stablehlo.txt"
    graph_path = f"{profile_dir}/stablehlo"

    with open(mlir_path, "w") as mlir_file:
        mlir_file.write(str(lowered_ir))

    visualize_mlir(mlir_path, func_name, graph_path)

In [9]:
# GEMM benchmark: (128, 256) x (256, 1024), one PRNG subkey per operand.
keys = random.split(master_key, 2)

A = random.normal(keys[0], (128, 256))
B = random.normal(keys[1], (256, 1024))

profile_jax_function("GEMM", gemm, num_runs, A, B)

# Observations from the captured trace: it had two steps, copy-done and fusion.
# The copy appears to move only tensor B (according to GPT it rearranges the tensor for more efficient computation; unverified).
  # %copy-done = f32[256,1024]{1,0:T(8,128)S(1)} copy-done((f32[256,1024]{1,0:T(8,128)S(1)}, f32[256,1024]{1,0:T(8,128)}, u32[]{:S(2)}) %copy-start)
# The fusion reads all of A and B and writes the output.
  # %fusion = f32[128,1024]{1,0:T(8,128)} fusion(f32[128,256]{1,0:T(8,128)} %Arg_0.1, f32[256,1024]{1,0:T(8,128)S(1)} %copy-done), kind=kOutput, calls=%fused_computation

Profiling GEMM.

Run 0: 55.935 ms.
Run 1: 0.387 ms
Run 2: 0.246 ms
Run 3: 0.231 ms
Run 4: 0.202 ms
Run 5: 0.208 ms
Run 6: 0.193 ms
Run 7: 0.180 ms
Run 8: 0.188 ms
Run 9: 0.188 ms
Run 10: 0.188 ms

Average over 10 runs: 0.221 ms ± 0.059 ms.


In [10]:
# FFN benchmark: one PRNG subkey for the input and one per weight/bias tensor.
keys = random.split(master_key, 5)

x = random.normal(keys[0], (128, 64, 256))
w1 = random.normal(keys[1], (256, 1024))
b1 = random.normal(keys[2], (1024,))
w2 = random.normal(keys[3], (1024, 256))
b2 = random.normal(keys[4], (256,))

profile_jax_function("FFN", ffn, num_runs, x, w1, b1, w2, b2)

# Observations from the captured trace.
# The first matrix multiplication is broken up into a fusion and copy.1.
  # %copy-done = f32[256,1024]{1,0:T(8,128)S(1)} copy-done((f32[256,1024]{1,0:T(8,128)S(1)}, f32[256,1024]{1,0:T(8,128)}, u32[]{:S(2)}) %copy-start)
  # %fusion = f32[128,64,1024]{2,0,1:T(8,128)S(1)} fusion(f32[128,64,256]{2,1,0:T(8,128)} %Arg_0.1, f32[256,1024,1]{1,0,2:T(8,128)S(1)} %bitcast.8), kind=kOutput, calls=%fused_computation
  # %copy.1 = f32[128,64,1024]{2,1,0:T(8,128)} copy(f32[128,64,1024]{2,0,1:T(8,128)S(1)} %fusion)
# Then the first bias is added.
  # %broadcast_add_fusion = f32[128,64,1024]{2,1,0:T(8,128)} fusion(f32[128,64,1024]{2,1,0:T(8,128)} %Arg_0.1, f32[1024]{0:T(1024)} %Arg_1.2), kind=kLoop, calls=%fused_computation
# Some work happens on the CPU in between.
# (The broadcast_maximum fusion below is presumably the ReLU; TODO confirm.)
  # %broadcast_maximum_fusion = f32[128,64,1024]{2,1,0:T(8,128)} fusion(f32[128,64,1024]{2,1,0:T(8,128)} %Arg_0.1), kind=kLoop, calls=%fused_computation
# Then the second matmul.
  # %fusion.3 = f32[128,64,256]{2,1,0:T(8,128)} fusion(f32[128,64,1024]{2,1,0:T(8,128)S(1)} %copy-done, f32[1024,256,1]{1,0,2:T(8,128)} %bitcast.9), kind=kOutput, calls=%fused_computation.2
# Then the second bias is added.
  # %broadcast_add_fusion = f32[128,64,256]{2,1,0:T(8,128)} fusion(f32[128,64,256]{2,1,0:T(8,128)} %Arg_0.1, f32[256]{0:T(256)} %Arg_1.2), kind=kLoop, calls=%fused_computation

# NOTE: this walkthrough may be missing some details.

Profiling FFN.

Run 0: 565.346 ms.
Run 1: 0.345 ms
Run 2: 0.273 ms
Run 3: 0.243 ms
Run 4: 0.234 ms
Run 5: 0.200 ms
Run 6: 0.210 ms
Run 7: 0.196 ms
Run 8: 0.197 ms
Run 9: 0.199 ms
Run 10: 0.211 ms

Average over 10 runs: 0.231 ms ± 0.045 ms.


In [11]:
# MHA benchmark: one PRNG subkey for the input and one per projection weight.
keys = random.split(master_key, 5)

# NOTE: `heads` is not passed to mha below; the call relies on mha's default
# num_heads=16, which equals the value set here.
batch, seq, embed, heads = 128, 64, 256, 16

x = random.normal(keys[0], (batch, seq, embed))
w_q = random.normal(keys[1], (embed, embed))
w_k = random.normal(keys[2], (embed, embed))
w_v = random.normal(keys[3], (embed, embed))
w_o = random.normal(keys[4], (embed, embed))

profile_jax_function("MHA", mha, num_runs, x, w_q, w_k, w_v, w_o)

Profiling MHA.

Run 0: 1015.022 ms.
Run 1: 0.503 ms
Run 2: 0.442 ms
Run 3: 0.428 ms
Run 4: 0.400 ms
Run 5: 0.407 ms
Run 6: 0.411 ms
Run 7: 0.417 ms
Run 8: 0.404 ms
Run 9: 0.391 ms
Run 10: 0.397 ms

Average over 10 runs: 0.420 ms ± 0.031 ms.


In [12]:
# Conv benchmark. conv_layer uses JAX's built-in convolution primitive
# (probably why it is faster).

keys = random.split(master_key, 2)

batch, in_channels, input_height, input_width = 128, 3, 12, 12
out_channels, kernel_height, kernel_width = 3, 3, 3

# Stride 1 and 'SAME' padding are hard-coded inside conv_layer; for this 3x3
# kernel, 'SAME' is equivalent to a padding of 1 on each side.

x = random.normal(keys[0], (batch, in_channels, input_height, input_width))   # NCHW
w = random.normal(keys[1], (out_channels, in_channels, kernel_height, kernel_width))       # OIHW

profile_jax_function("Conv", conv_layer, num_runs, x, w)

Profiling Conv.

Run 0: 128.949 ms.
Run 1: 0.322 ms
Run 2: 0.244 ms
Run 3: 0.188 ms
Run 4: 0.171 ms
Run 5: 0.158 ms
Run 6: 0.139 ms
Run 7: 0.121 ms
Run 8: 0.119 ms
Run 9: 0.119 ms
Run 10: 0.125 ms

Average over 10 runs: 0.171 ms ± 0.063 ms.
