# Introduction to JAX Tensors

JAX is NumPy on the CPU, GPU, and TPU, with great automatic differentiation for high-performance machine learning research.

```python
import jax
import jax.numpy as jnp
```
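
The introduction mentions automatic differentiation and support for several backends. Here is a minimal sketch of both. The function `f` is a hypothetical example and does not come from the original text. `jax.grad` turns a scalar-valued function into a new function that computes its gradient, and `jax.devices()` lists the devices JAX will run on.

```python
import jax
import jax.numpy as jnp

# A hypothetical scalar-valued function, written with jax.numpy as with NumPy.
def f(x):
    return jnp.sum(x ** 2)

# jax.grad returns a new function that computes the gradient of f.
df = jax.grad(f)

x = jnp.array([1.0, 2.0, 3.0])
print("f(x)  =", f(x))    # 1 + 4 + 9 = 14
print("df(x) =", df(x))   # the gradient of sum(x**2) is 2*x, i.e. 2, 4, 6

# The same code runs on whichever backend JAX finds (CPU, GPU, or TPU).
print("Devices:", jax.devices())
```

`jax.grad` requires the function to return a scalar. The gradient function it returns can itself be passed to `jax.jit`.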

## JAX Arrays

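JAX arrays are immutable: once created, their contents cannot be modified in place. This is a key difference from NumPy arrays, which support in-place item assignment.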

```python
x = jnp.array([1, 2, 3])
print("JAX Array:", x)

# Operations are just like NumPy
y = x * 2
print("Operation Result:", y)
```
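
Here is a minimal sketch of what immutability means in practice. Item assignment such as `x[0] = 10`, which works on a NumPy array, raises a `TypeError` on a JAX array. The functional alternative is the `.at[...]` property, whose `.set(...)` method returns a new array and leaves the original untouched. The variable names here are illustrative only.

```python
import numpy as np
import jax.numpy as jnp

x = jnp.array([1, 2, 3])

# In-place item assignment fails on a JAX array.
try:
    x[0] = 10
    assignment_failed = False
except TypeError as err:
    assignment_failed = True
    print("TypeError:", err)

# Functional update: .at[index].set(value) returns a NEW array.
y = x.at[0].set(10)
print("original:", x)   # unchanged
print("updated: ", y)   # first element is now 10

# For comparison, a NumPy array is modified in place.
n = np.array([1, 2, 3])
n[0] = 10
print("NumPy:   ", n)
```

Besides `.set`, the `.at` property offers other updates such as `.add` and `.mul`. Inside a `jax.jit`-compiled function, XLA can often perform these updates in place under the hood, so the functional style does not necessarily cost a copy.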

## Introduction to JIT (Just-In-Time) Compilation

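`jax.jit` compiles a function with XLA (Accelerated Linear Algebra), a compiler for linear-algebra programs. On the first call, JAX traces the function for the given input shapes and dtypes and compiles the result. Later calls with matching inputs reuse the compiled code, which can fuse several operations into one kernel and run much faster than executing them one at a time.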
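
To make "compiles to XLA" concrete, here is a minimal sketch using JAX's ahead-of-time lowering API. The function `affine` is a hypothetical example. `jax.jit(f).lower(...)` traces `f` for concrete input shapes and dtypes and exposes the program that JAX hands to XLA. `.compile()` then produces an executable for the current backend. The exact text printed depends on the JAX version; recent versions emit StableHLO.

```python
import jax
import jax.numpy as jnp

# Hypothetical function used only to illustrate lowering.
def affine(x):
    return 2.0 * x + 1.0

x = jnp.ones(4)

# Trace affine for a float32 array of shape (4,) and lower it to XLA's input IR.
lowered = jax.jit(affine).lower(x)
print(lowered.as_text())

# Ask XLA to compile the lowered program; the result is callable.
compiled = lowered.compile()
print(compiled(x))  # same values as affine(x)
```

Plain `jax.jit` performs these same lowering and compilation steps implicitly on the first call. The explicit API is mainly useful for inspecting what gets compiled.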

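```python
def selu(x, alpha=1.67, lmbda=1.05):
    # Scaled exponential linear unit: lmbda * x for x > 0,
    # lmbda * alpha * (exp(x) - 1) otherwise.
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = jnp.arange(1_000_000, dtype=jnp.float32)
selu_jit = jax.jit(selu)

# The first call traces and compiles selu for this input shape and dtype.
# JAX dispatches work asynchronously, so block_until_ready() waits until
# the computation has actually finished.
selu_jit(x).block_until_ready()

print("Function compiled and run!")
```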
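
The first call to `selu_jit` in the cell above pays the cost of tracing and compilation. Later calls with arrays of the same shape and dtype reuse the cached compiled code, while an input with a new shape triggers a fresh trace. This minimal sketch makes the caching visible. It uses a hypothetical `scaled_sum` function and a hypothetical counter, `trace_count`, that is incremented only while JAX traces the function.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical counter, not part of any JAX API

@jax.jit
def scaled_sum(x):
    # Python side effects run only while JAX traces the function,
    # not when the cached compiled version is executed.
    global trace_count
    trace_count += 1
    return jnp.sum(2.0 * x)

scaled_sum(jnp.ones(3)).block_until_ready()  # first call: trace + compile
scaled_sum(jnp.ones(3)).block_until_ready()  # same shape/dtype: cache hit
print("traces after two calls with shape (3,):", trace_count)  # 1

scaled_sum(jnp.ones(5)).block_until_ready()  # new shape: traced again
print("traces after a call with shape (5,):", trace_count)     # 2
```

Because retracing happens per input shape and dtype, a jitted function called on many different shapes can spend most of its time recompiling. Keeping shapes stable avoids that cost.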