## Profiling jax.jit

### 🚀 Introduction to `jax.jit`

`jax.jit` applies **just-in-time (JIT) compilation**. It is one of the core transformations in JAX: it traces a Python function into a sequence of array operations and compiles that sequence with XLA (Accelerated Linear Algebra), which can **speed up execution** substantially.
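
To make the compile step more concrete, here is a minimal sketch that inspects what JAX records for a small function and then compiles it ahead of time. The function `affine` and its inputs are made up for illustration; `jax.make_jaxpr` and the `.lower()` / `.compile()` methods are part of JAX's public API, though the text they print differs between versions.

```python
import jax
import jax.numpy as jnp

def affine(w, x, b):
    # Hypothetical example function: matrix-vector product plus a bias.
    return jnp.dot(w, x) + b

w = jnp.ones((3, 4), dtype=jnp.float32)
x = jnp.ones((4,), dtype=jnp.float32)
b = jnp.zeros((3,), dtype=jnp.float32)

# Tracing records the array operations as a jaxpr (JAX's intermediate form).
jaxpr = jax.make_jaxpr(affine)(w, x, b)
print(jaxpr)

# Lowering hands the traced program to XLA; compile() builds the executable.
lowered = jax.jit(affine).lower(w, x, b)
compiled = lowered.compile()

# The compiled object is callable with arguments of the same shapes and dtypes.
result = compiled(w, x, b)
print(result)
```

In everyday use you rarely call `.lower()` or `.compile()` yourself; calling a `jax.jit`-wrapped function performs the same steps on first use.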

### 🔧 What it does:
- **Traces and compiles** your function the first time it is called.
- **Caches** the compiled version for reuse with the same input shapes/dtypes.
- **Reduces Python overhead**, enabling high-performance execution on **CPUs, GPUs, and TPUs**.
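
The caching behavior can be observed with a small sketch. The function `scaled_sum` and the `trace_log` list are hypothetical names introduced for this example; the list grows only when JAX traces the Python function, so its length shows how many times compilation was triggered.

```python
import jax
import jax.numpy as jnp

trace_log = []  # one entry is appended each time JAX traces the function

@jax.jit
def scaled_sum(x):
    # This Python statement runs only while tracing, not on cached calls.
    trace_log.append((x.shape, x.dtype))
    return jnp.sum(x * 2.0)

scaled_sum(jnp.ones(4, dtype=jnp.float32))   # first call: trace and compile
scaled_sum(jnp.zeros(4, dtype=jnp.float32))  # same shape and dtype: cache hit
traces_after_repeat = len(trace_log)

scaled_sum(jnp.ones(8, dtype=jnp.float32))   # new shape: traced again
scaled_sum(jnp.ones(8, dtype=jnp.int32))     # new dtype: traced again
traces_total = len(trace_log)

print(traces_after_repeat, traces_total)
```

Note that the second call passes different values (zeros instead of ones) and still hits the cache: the cache key depends on shapes and dtypes, not on array contents.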

### 📈 Why use it:
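- Often large speedups for computationally heavy functions that chain many operations (e.g., neural network forward passes). A single operation such as one matrix multiplication gains less, because un-JITed JAX already runs each operation as a compiled kernel.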
- Essential for **LLM training**, **model parallelism**, and **inference at scale**.
- Works seamlessly with JAX's functional programming style and other transformations (`grad`, `vmap`, etc.).
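
As a brief illustration of composing `jit` with other transformations, the sketch below wraps `jax.grad` and `jax.vmap` in `jax.jit`. The `loss` function and its inputs are hypothetical and chosen so the results are easy to check by hand.

```python
import jax
import jax.numpy as jnp

def loss(w, x):
    # Hypothetical scalar loss: squared norm of an elementwise product.
    return jnp.sum((w * x) ** 2)

# Compile the gradient function (gradient with respect to w, the first argument).
grad_loss = jax.jit(jax.grad(loss))

# Vectorize over a batch of x (axis 0) while sharing w, then compile.
batched_loss = jax.jit(jax.vmap(loss, in_axes=(None, 0)))

w = jnp.array([1.0, 2.0, 3.0])
xs = jnp.ones((5, 3))

g = grad_loss(w, xs[0])       # analytically: 2 * w * x**2
losses = batched_loss(w, xs)  # one loss value per row of xs

print(g)
print(losses)
```

The order of composition is a design choice: wrapping the outermost function in `jit` lets XLA see the whole transformed computation at once.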

### ⚠️ Watch out:
- Compilation takes time (noticeable on the **first call**).
- Changing input **shapes or dtypes** can trigger recompilation.
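- Python side effects (like `print` statements) run only while the function is being traced, so they fire on the first call for a given input signature and are skipped on later cached calls.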
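
The following sketch shows how Python-side `print` behaves under `jit`, next to `jax.debug.print`, which JAX provides for printing runtime values from compiled code. The function `double` is a made-up example, and capturing stdout with `redirect_stdout` is only a convenience for counting lines here.

```python
import contextlib
import io

import jax
import jax.numpy as jnp

@jax.jit
def double(x):
    print("tracing double")               # plain Python print: trace time only
    jax.debug.print("runtime x = {}", x)  # staged into the compiled computation
    return x * 2

buffer = io.StringIO()
with contextlib.redirect_stdout(buffer):
    for _ in range(3):
        double(jnp.arange(3)).block_until_ready()
    jax.effects_barrier()  # wait for any pending debug callbacks

captured = buffer.getvalue()
python_prints = captured.count("tracing double")

print(captured)
print("plain print fired", python_prints, "time(s) across 3 calls")
```

The plain `print` fires once, during tracing, while `jax.debug.print` is expected to emit a line on each call because it is part of the compiled program.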

> 💡 Pro Tip: Use `jax.jit` for **hot loops** or functions called repeatedly to benefit from its full power.

In [2]:
# jit_profiling.ipynb

# Cell 1: Imports and Setup
import jax
import jax.extend
import jax.numpy as jnp
import numpy as np
import time
import gc

print("JAX Device:", jax.extend.backend.get_backend().platform)

JAX Device: cpu


In [3]:
# Cell 2: Helper: host memory usage (process RSS only; accelerator memory is not included)
import os, psutil

def get_memory_usage_mb():
    process = psutil.Process(os.getpid())
    return process.memory_info().rss / 1024 / 1024  # in MB

In [4]:
# Cell 3: Define a Simple Function (Matrix Multiplication)
def matmul(x, y):
    return jnp.dot(x, y)

# Create test inputs
size = 2048
x = jnp.ones((size, size), dtype=jnp.float32)
y = jnp.ones((size, size), dtype=jnp.float32)

In [5]:
# Cell 4: Profile Un-JITed Execution
gc.collect()

start_mem = get_memory_usage_mb()
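start_time = time.perf_counter()  # monotonic, high-resolution timer

# Note: this first eager call also pays a one-time compilation cost for the dot op.
out = matmul(x, y).block_until_ready()

end_time = time.perf_counter()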
end_mem = get_memory_usage_mb()

print("Un-JITed Execution Time:", end_time - start_time, "seconds")
print("Memory Usage:", end_mem - start_mem, "MB")

Un-JITed Execution Time: 0.06286382675170898 seconds
Memory Usage: 24.796875 MB


In [6]:
# Cell 5: Profile JIT-Compiled Execution (First Call Includes Compilation)
gc.collect()

jit_matmul = jax.jit(matmul)

start_mem = get_memory_usage_mb()
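start_time = time.perf_counter()  # monotonic, high-resolution timer

# The first call traces and compiles before running.
out = jit_matmul(x, y).block_until_ready()

end_time = time.perf_counter()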
end_mem = get_memory_usage_mb()

print("JITed (First Call) Execution Time:", end_time - start_time, "seconds")
print("Memory Usage:", end_mem - start_mem, "MB")

JITed (First Call) Execution Time: 0.06277608871459961 seconds
Memory Usage: 21.859375 MB


In [7]:
# Cell 6: Profile JIT-Compiled Execution (Second Call Shows Runtime Gains)
gc.collect()

start_mem = get_memory_usage_mb()
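start_time = time.perf_counter()  # monotonic, high-resolution timer

# Same shapes and dtypes as before, so the cached executable is reused.
out = jit_matmul(x, y).block_until_ready()

end_time = time.perf_counter()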
end_mem = get_memory_usage_mb()

print("JITed (Second Call) Execution Time:", end_time - start_time, "seconds")
print("Memory Usage:", end_mem - start_mem, "MB")

JITed (Second Call) Execution Time: 0.04286909103393555 seconds
Memory Usage: 16.84375 MB
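
Single timings like the ones above are noisy. A more robust approach, sketched below, separates compilation from execution and reports a median over several runs. The helper `time_call`, the `repeats` parameter, and the smaller matrix size are assumptions made for this sketch and are not part of the original notebook.

```python
import statistics
import time

import jax
import jax.numpy as jnp

def time_call(fn, *args, repeats=10):
    """Return the median wall-clock seconds of fn(*args) over `repeats` runs."""
    samples = []
    for _ in range(repeats):
        start = time.perf_counter()
        jax.block_until_ready(fn(*args))  # wait for async dispatch to finish
        samples.append(time.perf_counter() - start)
    return statistics.median(samples)

def matmul(x, y):
    return jnp.dot(x, y)

size = 512  # smaller than the notebook's 2048 so the sketch runs quickly
x = jnp.ones((size, size), dtype=jnp.float32)
y = jnp.ones((size, size), dtype=jnp.float32)

# Measure compilation on its own using ahead-of-time lowering.
start = time.perf_counter()
compiled_matmul = jax.jit(matmul).lower(x, y).compile()
compile_seconds = time.perf_counter() - start

# Warm up the eager path once so its one-time setup is excluded as well.
jax.block_until_ready(matmul(x, y))

eager_seconds = time_call(matmul, x, y)
jit_seconds = time_call(compiled_matmul, x, y)

print(f"compile:      {compile_seconds:.4f} s")
print(f"eager median: {eager_seconds:.6f} s")
print(f"jit median:   {jit_seconds:.6f} s")
```

For a single matrix multiplication the eager and JIT medians may end up close, since both paths run a compiled kernel; the gap typically widens for functions that chain many operations.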


In [8]:
# Cell 7: Interpretation
from IPython.display import Markdown

Markdown("""
### 📊 Interpretation

- **First JIT Run** (about 0.063 s) includes tracing and compilation. In this CPU run it took about as long as the un-JITed call (about 0.063 s), which was itself a first call and so likely also paid a one-time compilation cost for the `dot` operation.
- **Second JIT Run** (about 0.043 s) reuses the cached executable and is roughly 30% faster than the first, which is where `jit` pays off: repeated execution.
- These are single-sample timings, so treat the differences as indicative rather than precise.
- The memory figures are changes in the host process's resident set size (RSS). They vary by platform (TPU vs GPU vs CPU) and with compilation artifacts, and they do not capture accelerator device memory.

**💡 Tip**: Use `jax.jit` for functions that run multiple times, so the compilation cost is amortized.
""")
