# JAX Fundamentals and Transformations

This notebook introduces the fundamental concepts of JAX and its key transformations. We'll cover:
1. Automatic differentiation (grad)
2. Just-In-Time compilation (jit)
3. Vectorization (vmap)
4. Parallel computation (pmap)
5. Combining transformations

In [None]:
import jax
import jax.numpy as jnp
import time

print(f"JAX version: {jax.__version__}")
print(f"Available devices: {jax.devices()}")

## 1. Automatic Differentiation

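jax.grad takes a scalar-valued function and returns a new function that computes its gradient, by default with respect to the first argument. The example below differentiates the summed output of a small network with respect to its parameters.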

In [None]:
def simple_nn(params, x):
    """A simple neural network forward pass."""
    w1, b1, w2, b2 = params
    h1 = jax.nn.relu(jnp.dot(x, w1) + b1)
    return jnp.dot(h1, w2) + b2

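# Initialize parameters, splitting the key so each array gets independent random values
key = jax.random.PRNGKey(0)
k1, k2, k3, k4 = jax.random.split(key, 4)
w1 = jax.random.normal(k1, (2, 3))
b1 = jax.random.normal(k2, (3,))
w2 = jax.random.normal(k3, (3, 1))
b2 = jax.random.normal(k4, (1,))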
params = (w1, b1, w2, b2)

# Compute gradients
grad_fn = jax.grad(lambda p, x: jnp.sum(simple_nn(p, x)))
x = jnp.array([[1.0, 2.0]])
grads = grad_fn(params, x)

print("Gradients of first layer weights:")
print(grads[0])
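
As a smaller check on what jax.grad returns, the following sketch differentiates a toy scalar function and compares the result with the gradient derived by hand. The function loss and the arrays w and x are illustrative names chosen for this example, not part of the network above.

```python
import jax
import jax.numpy as jnp

def loss(w, x):
    # Scalar-valued function: jax.grad requires a scalar output.
    return jnp.sum((w * x) ** 2)

w = jnp.array([1.0, 2.0])
x = jnp.array([3.0, 4.0])

# By default, jax.grad differentiates with respect to the first argument.
dw = jax.grad(loss)(w, x)             # analytically 2 * w * x**2 = [18, 64]

# argnums selects a different argument to differentiate.
dx = jax.grad(loss, argnums=1)(w, x)  # analytically 2 * w**2 * x = [6, 32]

# value_and_grad returns the function value and the gradient in one call.
value, dw_again = jax.value_and_grad(loss)(w, x)

print(dw, dx, value)
```

When the first argument is a tuple of arrays, as with params above, the returned gradient is a tuple with the same structure.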

## 2. Just-In-Time Compilation

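jax.jit compiles a function with XLA the first time it is called for a given input shape and dtype, then reuses the compiled version on later calls. This can speed up numerical code considerably, although the gain depends on the function and the hardware.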

In [None]:
# Define a computation-heavy function
def slow_function(x):
    return jnp.sum(jnp.sin(x) ** 2 + jnp.cos(x) ** 2)

# Create a JIT-compiled version
fast_function = jax.jit(slow_function)

# Compare performance
x = jax.random.normal(key, (1000, 1000))

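# Warm-up: the first JIT call includes compilation time
slow_function(x).block_until_ready()
fast_function(x).block_until_ready()

# Time comparison; block_until_ready() waits for JAX's asynchronous dispatch to finish
start = time.perf_counter()
slow_function(x).block_until_ready()
print(f"Regular: {time.perf_counter() - start:.4f} seconds")

start = time.perf_counter()
fast_function(x).block_until_ready()
print(f"JIT: {time.perf_counter() - start:.4f} seconds")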
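
One way to see why the warm-up call matters is to record when the Python body of a jitted function actually runs. The sketch below assumes a hypothetical function scaled_sum and a list trace_log used only for illustration; the body executes while JAX traces the function, which happens once per new input shape rather than on every call.

```python
import jax
import jax.numpy as jnp

trace_log = []  # records each time the Python body is traced

@jax.jit
def scaled_sum(x):
    # Python side effect: runs only during tracing, not on cached calls.
    trace_log.append(x.shape)
    return jnp.sum(x * 2.0)

first = scaled_sum(jnp.ones(3))   # new shape (3,): traced and compiled
second = scaled_sum(jnp.ones(3))  # same shape: compiled version reused
third = scaled_sum(jnp.ones(4))   # new shape (4,): traced again

print("Number of traces:", len(trace_log))
```

Calling the function with many different shapes triggers a recompilation for each one, so any speedup is easiest to obtain when input shapes stay fixed.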

## 3. Vectorization with vmap

jax.vmap transforms a function written for a single example into one that maps over a batch axis (the leading axis by default), without a Python loop.

In [None]:
def single_example_fn(x):
    """Function that operates on a single example."""
    return jnp.sin(x) ** 2

# Create vectorized version
batch_fn = jax.vmap(single_example_fn)

# Test on batch of inputs
x_batch = jnp.array([1.0, 2.0, 3.0, 4.0])
result = batch_fn(x_batch)
print("Vectorized result:", result)
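
The example above maps over a single argument. When a function takes several arguments, the in_axes parameter of jax.vmap states which ones carry a batch axis. The following sketch assumes a hypothetical predict function in which the inputs are batched and the weight vector is shared across the batch.

```python
import jax
import jax.numpy as jnp

def predict(x, w):
    # Written for one example: x and w both have shape (3,), result is a scalar.
    return jnp.dot(x, w)

# Map over axis 0 of x; None means w is not batched and is reused for every example.
batched_predict = jax.vmap(predict, in_axes=(0, None))

xs = jnp.arange(6.0).reshape(2, 3)  # two examples with three features each
w = jnp.array([1.0, 0.0, -1.0])

out = batched_predict(xs, w)
print(out.shape, out)
```

The result matches the matrix-vector product xs @ w, but predict itself never had to be written with a batch dimension in mind.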

## 4. Parallel Computation with pmap

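jax.pmap compiles a function and runs one copy per device, splitting the input along its leading axis, so that axis can be no larger than the number of available devices.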

In [None]:
# Only runs if multiple devices are available
if len(jax.devices()) > 1:
    def parallel_fn(x):
        return jnp.sum(jnp.sin(x) ** 2)

    # Create parallel version
    parallel_mapped_fn = jax.pmap(parallel_fn)

    # Create data for each device
    n_devices = len(jax.devices())
    x_parallel = jax.random.normal(key, (n_devices, 1000))

    result = parallel_mapped_fn(x_parallel)
    print("Parallel computation result:", result)
else:
    print("This example requires multiple devices to run")
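
Devices running under pmap can also exchange values through collective operations. The sketch below is a minimal illustration, assuming an axis name "devices" and a helper global_sum invented for this example. It sizes the input with jax.local_device_count(), so it should also run on a single device, where the mapped axis simply has length 1.

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()

# One row of data per local device.
x = jnp.arange(n * 4, dtype=jnp.float32).reshape(n, 4)

def global_sum(row):
    # Each device sums its own row, then psum adds the partial sums across devices.
    return jax.lax.psum(jnp.sum(row), axis_name="devices")

# axis_name labels the mapped axis so collectives such as psum can refer to it.
totals = jax.pmap(global_sum, axis_name="devices")(x)

print(totals)  # one entry per device, each holding the sum over all of x
```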

## 5. Combining Transformations

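JAX transformations compose: each one returns an ordinary function that can be passed to another. The example below applies grad inside a function, vectorizes it over a batch with vmap, and compiles the result with jit.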

In [None]:
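# Define a function that computes gradients for each example in a batch
@jax.jit  # Compile the vectorized function
@jax.vmap  # Map over the leading (batch) axis
def batch_gradients(x):
    # Gradient of sum(sin(x)^2) with respect to a single example x
    return jax.grad(lambda v: jnp.sum(jnp.sin(v) ** 2))(x)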

# Test it
x_batch = jax.random.normal(key, (10, 5))
grads = batch_gradients(x_batch)
print("Combined transformation result shape:", grads.shape)
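
A common use of this kind of composition is computing per-example gradients of a loss with respect to shared parameters. The following is a sketch under assumed names (loss, per_example_grads, and the toy data are all illustrative): grad differentiates with respect to the weights, vmap maps over the examples while holding the weights fixed, and jit compiles the whole pipeline.

```python
import jax
import jax.numpy as jnp

def loss(w, x, y):
    # Squared error for a single example; x has shape (2,), y is a scalar.
    return (jnp.dot(x, w) - y) ** 2

# grad: differentiate with respect to w (the first argument).
# vmap: w is shared (None), while x and y are batched along axis 0.
# jit: compile the composed function.
per_example_grads = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0, 0)))

w = jnp.array([1.0, -1.0])
xs = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
ys = jnp.array([0.0, 1.0, 2.0])

g = per_example_grads(w, xs, ys)
print(g.shape)  # (3, 2): one gradient of shape (2,) per example
```

Each row of g can be checked against the hand-derived gradient 2 * (x . w - y) * x for the corresponding example.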