# Introduction to JAX Neural Networks

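JAX does not ship a full-featured neural network library comparable to PyTorch's `torch.nn` or TensorFlow's Keras. Its primitives (`jax.numpy` arrays, `jax.grad` for automatic differentiation, `jax.jit` for compilation) are enough to build one by hand, and higher-level libraries such as Flax, Haiku, and Equinox are built on top of those same primitives.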

In [None]:
import jax
import jax.numpy as jnp
from jax import random

## Manual Linear Layer

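A linear (dense) layer computes `y = xW + b`. Here `x` has shape `(batch, in_features)`, `W` has shape `(in_features, out_features)`, and `b` has shape `(out_features,)`. Because the batch dimension comes first, the code multiplies `x @ W` rather than `W @ x`, and `b` is broadcast across every row of the batch.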

In [None]:
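def linear(params, x):
    # params is a (w, b) tuple; x has shape (batch, in_features)
    w, b = params
    return jnp.dot(x, w) + b  # (batch, out_features); b is broadcast over the batch

# Split one PRNG key into independent keys for the weights and the bias
key = random.PRNGKey(0)
key_w, key_b = random.split(key)
w = random.normal(key_w, (3, 2))  # (in_features=3, out_features=2)
b = random.normal(key_b, (2,))
params = (w, b)

x = jnp.ones((5, 3))  # batch of 5 inputs with 3 features each
preds = linear(params, x)
print("Predictions shape:", preds.shape)  # (5, 2)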
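A common pattern is to pair each layer's apply function with an init function that returns its parameters. The sketch below assumes a hypothetical `init_linear` helper (not part of JAX) that scales the weights by `1/sqrt(in_dim)` and starts the biases at zero; it also shows that the same `linear` function works on a single unbatched input.

```python
import jax
import jax.numpy as jnp
from jax import random


def init_linear(key, in_dim, out_dim):
    """Hypothetical helper (not a JAX API): create (w, b) for a linear layer.

    Assumes a 1/sqrt(in_dim) weight scale and zero biases, an illustrative choice.
    """
    w = random.normal(key, (in_dim, out_dim)) / jnp.sqrt(in_dim)
    b = jnp.zeros((out_dim,))
    return (w, b)


def linear(params, x):
    # x: (batch, in_dim) or (in_dim,); w: (in_dim, out_dim); b: (out_dim,)
    w, b = params
    return jnp.dot(x, w) + b


params = init_linear(random.PRNGKey(0), in_dim=3, out_dim=2)

x = jnp.ones((5, 3))
preds = linear(params, x)
print(preds.shape)  # (5, 2)

# A single example without a batch dimension: jnp.dot((3,), (3, 2)) -> (2,)
single = linear(params, jnp.ones((3,)))
print(single.shape)  # (2,)
```

Keeping initialization separate from the forward pass means `linear` stays a pure function of `(params, x)`, which is exactly the form `jax.grad` needs.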

## Training Loop Concept

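JAX is functional: parameters are not stored inside a layer object but passed explicitly to every function that uses them. Training therefore means computing the gradient of a loss with `jax.grad` and building a new set of parameters from the old ones and those gradients, rather than updating anything in place.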

In [None]:
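def loss_fn(params, x, y):
    # Mean squared error between predictions and targets
    preds = linear(params, x)
    return jnp.mean((preds - y) ** 2)

# jax.grad returns a new function that computes the gradient of loss_fn
# with respect to its first argument (params), i.e. argnums=0 by default.
grad_fn = jax.grad(loss_fn)
y = jnp.zeros((5, 2))  # illustrative targets matching preds' shape
grads = grad_fn(params, x, y)  # same (w, b) tuple structure as params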
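Putting the pieces together, one gradient-descent step computes the loss and gradients, then returns a new parameter tuple. The sketch below assumes plain SGD with a fixed learning rate of 0.1 and synthetic, noise-free data generated from made-up `true_w`/`true_b` values. `jax.value_and_grad`, `jax.jit`, and `jax.tree_util.tree_map` are real JAX APIs; `sgd_step` and `LEARNING_RATE` are illustrative names.

```python
import jax
import jax.numpy as jnp
from jax import random

LEARNING_RATE = 0.1  # assumed fixed step size for this sketch


def linear(params, x):
    w, b = params
    return jnp.dot(x, w) + b


def loss_fn(params, x, y):
    preds = linear(params, x)
    return jnp.mean((preds - y) ** 2)


@jax.jit
def sgd_step(params, x, y):
    # value_and_grad returns (loss, grads); grads mirrors the (w, b) structure
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # Build NEW params leaf by leaf; nothing is mutated in place
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - LEARNING_RATE * g, params, grads
    )
    return new_params, loss


# Synthetic regression data from hypothetical "true" parameters
key = random.PRNGKey(0)
k_x, k_w, k_b = random.split(key, 3)
true_w = jnp.array([[1.0, -2.0], [0.5, 0.0], [3.0, 1.0]])
true_b = jnp.array([0.5, -1.0])
x = random.normal(k_x, (100, 3))
y = linear((true_w, true_b), x)

# Random starting point, then a simple training loop
params = (random.normal(k_w, (3, 2)), random.normal(k_b, (2,)))
first_loss = loss_fn(params, x, y)
for _ in range(200):
    params, loss = sgd_step(params, x, y)
final_loss = loss_fn(params, x, y)
print("loss before:", first_loss, "after:", final_loss)
```

Using `tree_map` keeps the update independent of how many arrays `params` holds, so the same step would work unchanged for a deeper model whose parameters are, say, a list of `(w, b)` tuples.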