# Murat's JAX tutorial notebook



## 0) Setup & Installation (quick notes)

JAX has a different installation process for each backend: GPU (specifically CUDA, so you need an NVIDIA card), CPU (works for everything else), and, experimentally, Apple Metal (for the GPU in Apple M-series processors). At this stage Metal support is at beta level, and I ran into a lot of issues trying to make it run on my Mac, so I stuck with CPU.

- **CPU‑only (any platform):**
  ```bash
  pip install -U jax
  ```
- **NVIDIA GPU (CUDA):** you need a CUDA‑compatible wheel that matches your CUDA/cuDNN.
  ```bash
  pip install -U "jax[cuda12]"
  ```
- **TPU (Google's special-purpose processors, designed specifically to run ML code):**
  ```bash
  pip install -U "jax[tpu]"
  ```

### Quick note: **jax.numpy as jnp** mirrors most of the NumPy API, so most of the functions you know from NumPy are available with a `jnp` prefix instead of `np`
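
A minimal sketch of moving data between NumPy and `jnp`, assuming JAX's default configuration (64-bit mode not enabled). The variable names are only illustrative. The API looks the same, but the default dtype can differ.

```python
import numpy as np
import jax.numpy as jnp

np_arr = np.linspace(0.0, 1.0, 5)   # NumPy defaults to float64
jax_arr = jnp.asarray(np_arr)       # NumPy -> JAX (float32 unless 64-bit mode is enabled)
back = np.asarray(jax_arr)          # JAX -> NumPy

print(np_arr.dtype, jax_arr.dtype, back.dtype)
print(jnp.sin(jax_arr))             # same function name as np.sin, just with the jnp prefix
```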

In [3]:
import math, time, functools, sys, platform
import numpy as np

import jax
import jax.numpy as jnp
from jax import jit, grad, value_and_grad, vmap, random

import matplotlib.pyplot as plt

print("Python:", sys.version.split()[0])
print("Platform:", platform.platform())
print("JAX:", jax.__version__)
print("JAX devices:", jax.devices())
print("Backend:", jax.default_backend())

Python: 3.12.3
Platform: macOS-15.6.1-arm64-arm-64bit
JAX: 0.5.0
JAX devices: [CpuDevice(id=0)]
Backend: cpu


## 1) JAX arrays: feel like NumPy, run on accelerators

**Key ideas**
- Use `jax.numpy` as `jnp` (NumPy‑like API).
- Arrays live on device (CPU/GPU/TPU) and are **immutable** (functional style).
- JAX does not support in‑place array modification, so instead of in‑place writes you have to use `.at[...].set/add/multiply`, which return an updated copy.
- Favor **pure functions**: same inputs ⇒ same outputs, no hidden background or global state. This is because JAX traces your function and compiles it Just In Time (JIT); transformations such as `jit` and `grad` only see the pure computation, which is also what makes autograd efficient.
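
To see why purity matters, here is a small sketch of a function with a Python side effect under `jit`. The names `trace_log` and `impure_double` are made up for this illustration. The side effect runs while JAX traces the function, not on later cached calls.

```python
import jax
import jax.numpy as jnp

trace_log = []  # hypothetical global, used only to expose the side effect

@jax.jit
def impure_double(x):
    trace_log.append("traced")  # Python side effect: runs during tracing only
    return 2.0 * x

x = jnp.ones(3)
y1 = impure_double(x)   # first call: traces and compiles
y2 = impure_double(x)   # same shape/dtype: reuses the compiled code

print(len(trace_log))   # → 1
```

The numerical result is the same on both calls, but the append happened only once, so any logic that relies on hidden state would silently break.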

In [4]:
# Create arrays, notice that logic is very similar to numpy and/or matlab
a = jnp.array([1.0, 2.0, 3.0])
b = jnp.arange(3)            # 0,1,2
c = jnp.linspace(0, 1, 5)    # 5 points from 0 to 1

print("a:", a)
print("b dtype:", b.dtype, "| c:", c)

# Broadcasting like NumPy and/or matlab
print("a + 10:", a + 10)
print("a * b:", a * b)

# No in-place updates; .at[...] returns a new array and leaves the original untouched
x = jnp.zeros((3,))
x2 = x.at[1].set(7.0)
print("x original:", x)
print("x2 updated:", x2)

a: [1. 2. 3.]
b dtype: int32 | c: [0.   0.25 0.5  0.75 1.  ]
a + 10: [11. 12. 13.]
a * b: [0. 2. 6.]
x original: [0. 0. 0.]
x2 updated: [0. 7. 0.]
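
The cell above only uses `.set`. A short sketch of the other `.at` updates, and of what happens if you try a NumPy-style in-place write (variable names are illustrative):

```python
import jax.numpy as jnp

x = jnp.arange(4.0)                # [0. 1. 2. 3.]
added = x.at[0].add(10.0)          # [10. 1. 2. 3.]
scaled = x.at[1:3].multiply(3.0)   # [0. 3. 6. 3.]

try:
    x[0] = 1.0                     # NumPy-style in-place write
    in_place_ok = True
except TypeError:
    in_place_ok = False            # JAX arrays reject item assignment

print(added, scaled, in_place_ok)
```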


## 2) Automatic differentiation: `grad` and `value_and_grad`

`grad(f)` returns a function that computes **∂f/∂x**. Your function must be **pure**, return a scalar, and take floating-point inputs (JAX arrays or Python floats).

In [5]:
def f_scalar(x):
    return 3.0 * x**2 + 2.0 * x + 1.0

df = grad(f_scalar)

x0 = 5.0
print("f(x0) =", f_scalar(x0))
print("df/dx at x0 =", df(x0))

# Works on vectorized functions via vmap (next section), or write f to accept vectors.

f(x0) = 86.0
df/dx at x0 = 32.0
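
The section title also mentions `value_and_grad`, which returns the function value and the gradient in one call. A minimal sketch, reusing `f_scalar` from above and adding a hypothetical `sum_of_squares` to show a vector input:

```python
import jax.numpy as jnp
from jax import grad, value_and_grad

def f_scalar(x):
    return 3.0 * x**2 + 2.0 * x + 1.0

# One call gives both f(x0) and df/dx at x0
value, slope = value_and_grad(f_scalar)(5.0)
print(value, slope)   # 86.0 and 32.0, matching the cell above

# grad also works when the input is a vector, as long as the output is a scalar
def sum_of_squares(v):
    return jnp.sum(v**2)

g = grad(sum_of_squares)(jnp.array([1.0, 2.0, 3.0]))
print(g)              # gradient is 2*v: [2. 4. 6.]
```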


## 3) JIT compilation: `@jit`

Basically, before any heavy computation, use `jit` to compile the function to XLA for big speedups (especially for large arrays / loops). The first call is slow because it includes compilation, but subsequent calls with the same input shapes and dtypes will be much faster. This is very helpful since Python itself is a very slow language when it comes to numerical calculation.

Example: the first call includes compile time; subsequent calls are fast.

In [8]:
import time
import jax.numpy as jnp
from jax import jit

# Function without JIT
def heavy_poly_nojit(x):
    y = 0.0
    for i in range(200):
        y = y + (i + 1) * x**2 - (i - 1) * x + 3.0
    return y

# Function with JIT
@jit
def heavy_poly_jit(x):
    y = 0.0
    for i in range(200):
        y = y + (i + 1) * x**2 - (i - 1) * x + 3.0
    return y

# Input data
x_demo = jnp.linspace(0, 1, 10_000)

# --- Run without JIT ---
t0 = time.time()
_ = heavy_poly_nojit(x_demo).block_until_ready() # block_until_ready waits for JAX's asynchronous
                                                # computation to finish, so we capture the true
                                                # run time on CPU/GPU; it is only needed for
                                                # timing, we don't need to use it in our research
t1 = time.time()

# --- Run with JIT (first call: compile + run) ---
t2 = time.time()
_ = heavy_poly_jit(x_demo).block_until_ready()
t3 = time.time()

# --- Run with JIT (second call: cached run) ---
t4 = time.time()
_ = heavy_poly_jit(x_demo).block_until_ready()
t5 = time.time()

print(f"Without JIT:               {t1 - t0:.6f} s")
print(f"With JIT (1st call):       {t3 - t2:.6f} s")
print(f"With JIT (2nd call):       {t5 - t4:.6f} s")


Without JIT:               0.018598 s
With JIT (1st call):       0.182805 s
With JIT (2nd call):       0.000354 s
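
A side note on loops: under `jit`, a Python `for` loop is unrolled while tracing, so 200 iterations become 200 copies of the loop body in the compiled program. One alternative is `jax.lax.fori_loop`, which keeps the loop as a single compiled loop. The sketch below assumes the same polynomial as above. The function names are made up for the comparison, and it only checks that both versions agree rather than claiming which one is faster.

```python
import jax
import jax.numpy as jnp

@jax.jit
def heavy_poly_unrolled(x):
    y = jnp.zeros_like(x)
    for i in range(200):                 # unrolled at trace time
        y = y + (i + 1) * x**2 - (i - 1) * x + 3.0
    return y

@jax.jit
def heavy_poly_fori(x):
    def body(i, y):                      # i is a traced integer, y is the running sum
        return y + (i + 1) * x**2 - (i - 1) * x + 3.0
    return jax.lax.fori_loop(0, 200, body, jnp.zeros_like(x))

x_demo = jnp.linspace(0, 1, 10_000)
out_unrolled = heavy_poly_unrolled(x_demo)
out_fori = heavy_poly_fori(x_demo)

# float32 rounding may differ slightly between the two compiled programs
max_diff = jnp.max(jnp.abs(out_unrolled - out_fori))
print("max abs difference:", max_diff)
```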


## 5) Batch without for‑loops: `vmap`

`vmap` automatically vectorizes a function across a batch dimension.

In [10]:
# A simple function: square and add 1
def simple_fn(x):
    return x**2 + 1

# A single input
print("Single input:", simple_fn(3.0))   # → 10.0

# A batch of inputs
xs = jnp.arange(5.0)   # [0, 1, 2, 3, 4]

# --- Without vmap (manual loop) ---
results_loop = jnp.array([simple_fn(x) for x in xs])
print("With loop:", results_loop)

# --- With vmap (automatic batching) ---
results_vmap = vmap(simple_fn)(xs)
print("With vmap:", results_vmap)

Single input: 10.0
With loop: [ 1.  2.  5. 10. 17.]
With vmap: [ 1.  2.  5. 10. 17.]
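
`vmap` also composes with `grad`, which is the vectorized derivative hinted at in section 2. A minimal sketch, assuming the same `f_scalar` as before:

```python
import jax.numpy as jnp
from jax import grad, vmap

def f_scalar(x):
    return 3.0 * x**2 + 2.0 * x + 1.0

xs = jnp.arange(5.0)                # [0, 1, 2, 3, 4]

# grad(f_scalar) handles one scalar; vmap applies it to every element of xs
slopes = vmap(grad(f_scalar))(xs)
print(slopes)                       # 6x + 2 at each point: 2, 8, 14, 20, 26
```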


### **Now a more complicated example**

In [15]:
import jax.numpy as jnp
from jax import random, vmap

# Define a function for one input vector
def single_logit(params, x):
    """
    Compute a single logit: w·x + b
    - params is a tuple (w, b):
        w = weight vector
        b = bias (a scalar)
    - x is one input vector
    """
    w, b = params
    return jnp.dot(w, x) + b  # dot product = sum of w[i] * x[i]

# -------------------------------
# Create random parameters and data
# -------------------------------
key = random.PRNGKey(42)         # PRNG key (needed in JAX for reproducible randomness)
key, kW, kX = random.split(key, 3)  # split into subkeys

# Generate random weights for 3 features
w = random.normal(kW, (3,))      # shape (3,)
b = 0.5                          # bias is just a number
params = (w, b)                  # bundle them into a tuple

# Generate a dataset of 1000 input vectors, each of size 3
X = random.normal(kX, (1000, 3)) # shape (1000, 3)

# -------------------------------
# Apply single_logit to every row of X
# -------------------------------

# Normally, we would loop over rows of X like this:
# logits = jnp.array([single_logit(params, x) for x in X])

# Instead, vmap does this automatically and efficiently:
batched_logits = vmap(lambda x: single_logit(params, x))(X)

print("Batched logits shape:", batched_logits.shape)
# -> (1000,)   because each of the 1000 inputs produces one logit


Batched logits shape: (1000,)
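
Instead of closing over `params` with a lambda, you can tell `vmap` which arguments to batch via `in_axes`. A sketch under the same setup as above (the PRNG seed and the `reference` check are additions for illustration):

```python
import jax.numpy as jnp
from jax import random, vmap

def single_logit(params, x):
    w, b = params
    return jnp.dot(w, x) + b

key = random.PRNGKey(0)
kW, kX = random.split(key)
params = (random.normal(kW, (3,)), 0.5)
X = random.normal(kX, (1000, 3))

# in_axes=(None, 0): do not batch over params, batch over axis 0 of X
batched = vmap(single_logit, in_axes=(None, 0))(params, X)

# The same computation written as a plain matrix-vector product
reference = X @ params[0] + params[1]

print(batched.shape)
```

Both forms should give the same numbers. `in_axes` is handy once a function has several arguments and only some of them carry a batch dimension.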


## 6) Pytrees (nested parameter containers)

A **pytree** is any nested structure of lists/tuples/dicts of arrays. JAX knows how to map/stack/flatten them.

A **tuple** is a built-in Python container: basically an immutable, ordered sequence whose elements can have different types, and (for arrays) different shapes and dimensions as well.

Common ops:
- `jax.tree_util.tree_map(fn, pytree)`
- `jax.tree_util.tree_flatten(pytree)`
- People usually store NN weights as a dict of arrays.

In [8]:
from jax import tree_util

params = {
    "layer1": {"W": jnp.ones((2, 3)), "b": jnp.zeros((3,))},
    "layer2": {"W": jnp.ones((3, 1))*2, "b": jnp.array([0.1])},
}

# Apply a function to every array in the pytree
scaled = tree_util.tree_map(lambda x: x * 0.5, params)

leaves, treedef = tree_util.tree_flatten(params)
print("Num leaves:", len(leaves))
print("First leaf shape:", leaves[0].shape)

# Restore from leaves + structure
restored = tree_util.tree_unflatten(treedef, leaves)
assert all([jnp.allclose(a, b) for a,b in zip(leaves, tree_util.tree_flatten(restored)[0])])
print("Restored pytree OK.")

Num leaves: 4
First leaf shape: (2, 3)
Restored pytree OK.
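
Pytrees pay off when combined with `grad`: the gradient comes back with the same nested structure as the parameters, so one `tree_map` can update every array. A minimal sketch of a single gradient-descent step on a made-up linear model (the `loss` function, the data, and the step size `lr` are all assumptions for illustration):

```python
import jax.numpy as jnp
from jax import grad, tree_util

def loss(params, x, y):
    pred = x @ params["W"] + params["b"]
    return jnp.mean((pred - y) ** 2)

params = {"W": jnp.zeros((3,)), "b": jnp.array(0.0)}
x = jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
y = jnp.array([1.0, 2.0])

# grads is a dict with the same keys and shapes as params
grads = grad(loss)(params, x, y)

lr = 0.01  # illustrative step size
# tree_map with two pytrees pairs up matching leaves (p from params, g from grads)
new_params = tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

print(loss(params, x, y), loss(new_params, x, y))
```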
