# Advanced JAX: Vectorization and Parallelization

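JAX really shines with its `vmap` (vectorizing map) and `pmap` (parallelizing map) transformations. `vmap` batches a function written for a single example, and `pmap` runs a function in parallel across multiple devices (for example, several GPUs or TPU cores).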
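Here is a minimal sketch of `pmap`, assuming only that JAX is installed. `pmap` requires the input's leading axis to equal the number of devices it maps over, so the code sizes that axis with `jax.local_device_count()`. This means it also runs on a single-CPU machine, where the "parallel" axis has length 1. The function `scale_and_sum` and the axis name `"devices"` are illustrative choices. `jax.lax.psum` sums a value across every device that shares that axis name. Recent JAX documentation steers new multi-device code toward `jax.jit` with sharding or `shard_map`, so treat this only as an illustration of `pmap` itself.

```python
import jax
import jax.numpy as jnp

# pmap maps over the leading axis, one slice per device, so that axis
# must have length equal to the number of local devices.
n_dev = jax.local_device_count()
x = jnp.arange(n_dev * 4, dtype=jnp.float32).reshape(n_dev, 4)

def scale_and_sum(row):
    # `row` is the shape-(4,) slice that lives on one device.
    scaled = row * 2.0
    # Sum locally, then add the partial sums across all devices.
    total = jax.lax.psum(jnp.sum(scaled), axis_name="devices")
    return scaled, total

scaled, totals = jax.pmap(scale_and_sum, axis_name="devices")(x)
print(scaled.shape, totals)  # every device reports the same global total
```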

```python
import jax
import jax.numpy as jnp
```

## Vectorization with `vmap`

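Instead of writing an explicit loop over a batch, we can pass `jax.vmap` a function written for a single example. It returns a version of that function that maps over a leading array axis (axis 0 by default).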
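The following sketch shows the loop that `vmap` replaces. `normalize` is a hypothetical example function written for one vector of shape `(d,)`. Calling it directly on a `(3, 2)` batch would divide every entry by the norm of the whole matrix, which is the wrong result. Both the explicit Python loop and `jax.vmap` apply it row by row, and they agree.

```python
import jax
import jax.numpy as jnp

def normalize(v):
    # Written for a single vector of shape (d,).
    return v / jnp.linalg.norm(v)

batch = jnp.array([[3.0, 4.0], [1.0, 0.0], [0.0, 2.0]])

# Explicit loop: call the function once per row and stack the results.
looped = jnp.stack([normalize(row) for row in batch])

# vmap: the same per-row semantics without a Python loop.
vectorized = jax.vmap(normalize)(batch)

# Note: normalize(batch) would divide by the Frobenius norm of the whole matrix.
print(vectorized)  # rows [0.6, 0.8], [1.0, 0.0], [0.0, 1.0]
```

The vectorized version is a single traced computation, so it can also be wrapped in `jax.jit` without unrolling a Python loop.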

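```python
def simple_func(x, y):
    # Written as if x and y were single (scalar) examples.
    return x + y

xs = jnp.arange(5)
ys = jnp.arange(5)

# jax.vmap maps simple_func over axis 0 of both inputs, so the function body
# never has to handle a batch dimension itself. (For an elementwise function
# like this one, plain broadcasting, xs + ys, gives the same result; vmap
# matters when the function assumes unbatched inputs.)
vectorized_func = jax.vmap(simple_func)
result = vectorized_func(xs, ys)
print("Vectorized result:", result)  # Vectorized result: [0 2 4 6 8]
```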
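By default `vmap` maps over axis 0 of every argument. The `in_axes` parameter controls this for each argument individually, and `None` marks an argument that should be shared across the whole batch instead of mapped. The sketch below uses a hypothetical `dot` function written for two single vectors. Mapping both arguments gives a row-wise dot product, something `jnp.dot(batch, batch)` cannot do because it fails with a shape mismatch. Passing `in_axes=(0, None)` instead scores every row against one shared `query` vector.

```python
import jax
import jax.numpy as jnp

def dot(a, b):
    # Written for two single vectors of shape (d,).
    return jnp.dot(a, b)

batch = jnp.arange(12.0).reshape(4, 3)   # 4 vectors of length 3
query = jnp.array([1.0, 0.0, -1.0])      # one shared vector

# Map over axis 0 of both arguments: row i of batch dotted with row i of batch.
row_wise = jax.vmap(dot)(batch, batch)            # shape (4,)

# Map over the first argument only; broadcast `query` to every row.
scores = jax.vmap(dot, in_axes=(0, None))(batch, query)  # shape (4,)

print(row_wise, scores)
```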

## Auto-vectorizing complex functions

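`vmap` is especially useful for batch processing in neural networks. You can write a model or loss function for a single example and let `vmap` add the batch dimension, instead of managing batch dimensions by hand in every function.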
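The following is a minimal sketch under stated assumptions. It uses a hypothetical one-layer model `predict(params, x)`, written for a single input vector, with a squared-error `loss`. The layer sizes and random data are arbitrary. Setting `in_axes=(None, 0)` shares the parameters across the batch while mapping over the inputs. Because `vmap` composes with `jax.grad`, the same pattern also gives per-example gradients, which are awkward to get from a batch-level loss.

```python
import jax
import jax.numpy as jnp

def predict(params, x):
    # Hypothetical single-example model: x has shape (in_dim,).
    w, b = params
    return jnp.tanh(w @ x + b)

def loss(params, x, y):
    # Squared error for one example.
    return jnp.sum((predict(params, x) - y) ** 2)

key = jax.random.PRNGKey(0)
k_w, k_x, k_y = jax.random.split(key, 3)
in_dim, out_dim, batch = 3, 2, 5
params = (jax.random.normal(k_w, (out_dim, in_dim)), jnp.zeros(out_dim))
xs = jax.random.normal(k_x, (batch, in_dim))
ys = jax.random.normal(k_y, (batch, out_dim))

# Share params (None), map over the batch axis of xs (0).
batched_predict = jax.vmap(predict, in_axes=(None, 0))
preds = batched_predict(params, xs)  # shape (batch, out_dim)

# Per-example gradients: grad of the single-example loss, mapped over the batch.
# Each leaf of the returned (w, b) tuple gains a leading batch axis.
per_example_grads = jax.vmap(jax.grad(loss), in_axes=(None, 0, 0))(params, xs, ys)
grad_w, grad_b = per_example_grads  # shapes (batch, out_dim, in_dim), (batch, out_dim)
```

Averaging the per-example gradients over the batch axis recovers the gradient of the mean loss. Keeping them separate is useful for things like per-example clipping.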