In [None]:
"""jax_ad.ipynb"""

# Cell 01 - Plot a numerical estimate of f'(x)

import matplotlib.pyplot as plt
import numpy as np


def f(x):
    """Test function f(x) = sin(x) * exp(-x^2), evaluated element-wise with NumPy."""
    envelope = np.exp(-np.power(x, 2))
    return np.sin(x) * envelope


def df(x, h=0.01, func=None):
    """Numerically estimate the derivative of ``func`` at ``x``.

    A 4th-order centered difference D4(step) is evaluated at ``h`` and ``h/2``.
    The two results are combined with Richardson extrapolation,
    (16 * D4(h/2) - D4(h)) / 15, which cancels the leading O(h^4) error term.

    Parameters
    ----------
    x : float or array_like
        Point(s) at which to estimate the derivative. Arrays are handled
        element-wise as long as ``func`` is vectorized (e.g. NumPy-based).
    h : float, optional
        Base step size (default 0.01). Very small steps amplify rounding error.
    func : callable, optional
        Function to differentiate. Defaults to the notebook-level ``f``,
        looked up at call time (the original behaviour).

    Returns
    -------
    float or ndarray
        The derivative estimate at ``x``.

    Raises
    ------
    ValueError
        If ``h`` is zero (the difference quotient would divide by zero).
    """
    if h == 0:
        raise ValueError("h must be non-zero")
    if func is None:
        func = f  # resolved at call time, like the original implicit global

    def d4(step):
        """4th-order centered difference of ``func`` at ``x`` with the given step."""
        numerator = (
            -func(x + 2 * step)
            + 8 * func(x + step)
            - 8 * func(x - step)
            + func(x - 2 * step)
        )
        return numerator / (12 * step)

    # Richardson extrapolation for a method whose error is O(step^4): 2^4 = 16.
    return (16 * d4(h / 2) - d4(h)) / 15


# Sample the interval [0, pi] densely enough for smooth curves.
a, b = 0, np.pi
x = np.linspace(a, b, 500)

plt.plot(x, f(x), label=r"$f(x)=\sin(x)\cdot e^{-x^2}$")
plt.plot(x, df(x), label=r"$f'(x)$ (finite-difference estimate)")
plt.title("Automatic Differentiation Demo")
plt.xlabel("$x$")
plt.ylabel("$y$")
plt.legend(loc="upper right")
plt.xlim(a, b)
# Matplotlib documents `visible` as a bool; pass True rather than the string "on".
plt.grid(True)

In [None]:
# Cell 02 - Compare the numerical estimate with the exact f'(x)


def f_prime(x):
    """Closed-form derivative of f: f'(x) = e^{-x^2} * (cos x - 2x sin x), in NumPy."""
    gaussian = np.exp(-np.power(x, 2))
    bracket = np.cos(x) - 2 * x * np.sin(x)
    return gaussian * bracket


plt.plot(x, f(x), label=r"$f(x)=\sin(x)\cdot e^{-x^2}$")
# Line: the finite-difference estimate from Cell 01.
# Dots: the exact closed form at every 5th sample, so agreement is visible.
plt.plot(x, df(x), label=r"$f'(x)$ (finite-difference estimate)")
plt.scatter(
    x[::5],
    f_prime(x[::5]),
    color="k",
    s=5,
    label=r"$f'(x)=e^{-x^2}(\cos x-2x\;\sin x)$ (exact)",
)
plt.title("Automatic Differentiation Demo")
plt.xlabel("$x$")
plt.ylabel("$y$")
plt.legend(loc="upper right")
plt.xlim(a, b)
plt.grid(True)

In [None]:
# Cell 03 - Plot f'(x) using JAX

# Automatic Differentiation (AD) is a set of techniques to
# compute derivatives accurately and efficiently by systematically
# applying the chain rule to elementary operations in a function

import jax
import jax.numpy as jnp

"""
JAX is a high-performance numerical computing library built by Google Research.
It combines the familiar NumPy API with:
  - Automatic differentiation (jax.grad, jax.jacobian, jax.vjp, etc.)
  - Just-in-time (JIT) compilation via jax.jit
  - Vectorization via jax.vmap
  - Parallelization via jax.pmap

JAX and XLA
    - The name "JAX" is not an official acronym; its authors have described it
      as originally standing for "Just After eXecution" (tracing happens after
      Python runs the function once), not "Just After XLA"
    - XLA stands for Accelerated Linear Algebra,
      which is a domain-specific compiler originally developed for TensorFlow
    - JAX uses XLA under the hood to compile and optimize numerical computations
      for speed and efficiency on CPUs, GPUs, and TPUs

pip install "jax[cpu]"

"""


def f(x):
    """JAX version of f(x) = sin(x) * exp(-x^2), written with jax.numpy so JAX can trace it.

    ``x ** 2`` is kept on purpose: it traces to the ``integer_pow`` primitive
    shown in the jaxpr listing of Cell 04.
    """
    gaussian = jnp.exp(-(x ** 2))
    return jnp.sin(x) * gaussian


# jax.grad(f) returns a new function computing df/dx for a scalar input.
# NOTE: this rebinds `f_prime` (Cell 02 bound it to the NumPy closed form).
f_prime = jax.grad(f)

# jax.grad differentiates scalar-output functions, so jax.vmap maps the
# scalar functions over every sample point in x.
f_vals = jax.vmap(f)(x)
df_vals = jax.vmap(f_prime)(x)

# Sanity check: JAX's gradient matches the closed form up to float32 precision.
np.testing.assert_allclose(
    df_vals, np.exp(-x**2) * (np.cos(x) - 2 * x * np.sin(x)), atol=1e-5
)

plt.plot(x, f_vals, label=r"$f(x)=\sin(x)\cdot e^{-x^2}$")
plt.plot(x, df_vals, label=r"$f'(x)$ as per JAX AD")
plt.title("Automatic Differentiation Demo")
plt.xlabel("$x$")
plt.ylabel("$y$")
plt.legend(loc="upper right")
plt.xlim(a, b)
plt.grid(True)


In [None]:
# Cell 04 - Display the jaxpr (JAX's intermediate representation) of f'(x)

# A jaxpr ("JAX expression") is the intermediate representation (IR) JAX
# builds when it traces a function. It is composed entirely of primitive
# operations that JAX understands, like add, mul, sin, exp, etc.
# The jaxpr of jax.grad(f) spells out the chain rule as primitive operations.

print(jax.make_jaxpr(f_prime)(1.0))

# Representative output (variable names, type annotations and ordering may
# differ between JAX versions). It is kept as comments: a bare string as the
# last expression of a cell would be echoed again as the cell's Out[] value.
#
# { lambda ; a:f32[]. let
#     b = sin a                  # b = sin(x)
#     c = cos a                  # c = cos(x)
#     d = integer_pow[y=2] a     # d = x^2
#     e = integer_pow[y=1] a     # e = x
#     f = mul 2.0 e              # f = 2x
#     g = neg d                  # g = -x^2
#     h = exp g                  # h = e^{-x^2}
#     _ = mul b h                # f(x) = sin(x) * e^{-x^2} (ignored here, just traced)
#
#     # Gradient path begins:
#     i = mul b 1.0              # i = sin(x)
#     j = mul 1.0 h              # j = e^{-x^2}
#     k = mul i h                # k = sin(x) * e^{-x^2}
#     l = neg k                  # l = -sin(x) * e^{-x^2}
#     m = mul l f                # m = -2x * sin(x) * e^{-x^2}
#     n = mul j c                # n = e^{-x^2} * cos(x)
#     o = add_any m n            # o = e^{-x^2} * cos(x) - 2x * sin(x) * e^{-x^2}
#   in (o,) }


In [None]:
# Cell 05 - Display all JAX Primitives

import jax.extend.core
import jax.lax


def get_primitives(module):
    """Collect the names of the JAX primitives exposed by ``module``.

    Two kinds of attributes are recognised:

    * ``Primitive`` instances themselves (the common case, e.g. ``jax.lax.sin_p``);
    * objects that carry a ``Primitive`` in a ``.primitive`` attribute.

    Parameters
    ----------
    module : module
        Module (or any object supporting ``dir``/``getattr``) to scan.

    Returns
    -------
    list of str
        Sorted, de-duplicated primitive names.
    """
    Primitive = jax.extend.core.Primitive
    names = set()
    for attr in dir(module):
        # Use a default: a name listed by dir() may still fail to resolve
        # (e.g. lazily loaded or deprecated attributes); skip it instead of crashing.
        obj = getattr(module, attr, None)
        if isinstance(obj, Primitive):
            names.add(obj.name)
            continue
        wrapped = getattr(obj, "primitive", None)
        # Only trust `.primitive` if it really is a Primitive; an unrelated
        # attribute with that name would otherwise fail on `.name`.
        if isinstance(wrapped, Primitive):
            names.add(wrapped.name)
    return sorted(names)


lax_primitives = get_primitives(jax.lax)
# Heading, then one primitive name per line.
print("Primitives found in jax.lax:")
print("\n".join(lax_primitives))