In [2]:
import jax
import jax.numpy as jnp


# Performance Optimization
## XLA Compiler Insights
JAX relies on XLA (Accelerated Linear Algebra) to optimize and compile numerical computations.

### Fusion Optimization
Fusion is a key optimization technique in which XLA combines multiple operations into a single kernel, reducing memory transfers and improving performance.

In [3]:
def without_fusion(x):
  a = jnp.sin(x)
  b = jnp.cos(x)
  return a + b

In this example, without fusion, each operation is dispatched as a separate kernel, so the intermediate results a and b are written to device memory and read back before the final addition.

In [4]:
@jax.jit
def with_fusion(x):
  a = jnp.sin(x)
  b = jnp.cos(x)
  return a + b

In this case, **jax.jit** triggers XLA compilation and enables more aggressive fusion: XLA combines these operations into one kernel, which reduces memory transfers and improves performance.
Note: Not all operations can be fused.
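
To make the difference observable, the following is a minimal sketch that times the eager and jitted versions of the same computation and checks that they agree. The names sin_plus_cos, sin_plus_cos_jit, and time_fn are hypothetical helpers introduced only for this illustration, and the measured times depend heavily on the backend and array size, so no particular speedup is assumed.

```python
import time

import jax
import jax.numpy as jnp


def sin_plus_cos(x):
    # Same computation as the cells above, without jit.
    return jnp.sin(x) + jnp.cos(x)


# Jitted version of the same function.
sin_plus_cos_jit = jax.jit(sin_plus_cos)

x = jnp.linspace(0.0, 10.0, 1_000_000)

# The first call triggers tracing and XLA compilation; keep it out of the timing.
sin_plus_cos_jit(x).block_until_ready()


def time_fn(fn, arg, repeats=20):
    # block_until_ready() waits for JAX's asynchronous dispatch to finish,
    # so the measurement covers the actual computation.
    start = time.perf_counter()
    for _ in range(repeats):
        fn(arg).block_until_ready()
    return (time.perf_counter() - start) / repeats


eager_s = time_fn(sin_plus_cos, x)
jit_s = time_fn(sin_plus_cos_jit, x)
print(f"eager: {eager_s * 1e3:.3f} ms, jit: {jit_s * 1e3:.3f} ms")

# Optional: the optimized HLO text shows which operations XLA fused.
hlo_text = sin_plus_cos_jit.lower(x).compile().as_text()
```

Printing hlo_text and searching it for fusion instructions is one way to check what the compiler actually did on a given backend.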

### Memory Layout Management

JAX manages memory through two features: immutable arrays, and just-in-time compilation, which forces memory layouts to be determined ahead of time.
JAX also manages device memory by handling transfers between CPU and accelerators (GPU/TPU) automatically. It also leverages XLA for memory optimizations.
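
As an illustration of immutability and device placement, here is a small sketch. It assumes at least one device is visible through jax.devices() (on a machine without an accelerator this is the CPU); the variable names are chosen only for this example.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(5)

# JAX arrays are immutable: in-place assignment raises an error.
try:
    x[0] = 10
except TypeError as err:
    print("in-place update rejected:", type(err).__name__)

# The functional update returns a new array and leaves x untouched.
y = x.at[0].set(10)

# Explicit placement on a device; jax.devices() lists what is available.
device = jax.devices()[0]
x_on_device = jax.device_put(x, device)
print(x_on_device.devices())
```

Most code never needs the explicit jax.device_put call, since arrays are placed on the default device automatically; it is shown here only to make the transfer visible.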

### JAX Sharding: Distributed Array Storage

Sharding in JAX refers to splitting arrays across multiple devices (such as GPUs or TPUs) to enable parallel computation.

In [10]:
# This code may throw an error depending on your system; the configuration here is for 2x T4 GPUs
import jax
import numpy as np
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

# Create a 2D mesh of devices
devices = jax.devices()
mesh = Mesh(np.array(devices).reshape(1,2), ('batch', 'model'))
print(mesh)
# Define how to partition the array
spec = P('batch', 'model')  # Shard along both dimensions

# Create and shard an array
x = jnp.ones((16, 8))
x_sharded = jax.device_put(x, jax.sharding.NamedSharding(mesh, spec))

ValueError: cannot reshape array of size 1 into shape (1,2)
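
The ValueError above appears when only one device is visible, because a single device cannot be reshaped into a (1, 2) mesh. One possible workaround, sketched below, is to derive the mesh shape from the number of available devices instead of hard-coding it. The variable n_devices and the choice to size the array's second dimension as 8 * n_devices are assumptions made for this example so that the array divides evenly across devices.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = jax.devices()
n_devices = len(devices)

# Place every available device on the 'model' axis: mesh shape (1, n_devices).
mesh = Mesh(np.array(devices).reshape(1, n_devices), ('batch', 'model'))

# Choose a column count divisible by n_devices so every shard has the same shape.
x = jnp.ones((16, 8 * n_devices))
x_sharded = jax.device_put(x, NamedSharding(mesh, P('batch', 'model')))

# Inspect where each shard lives and how large it is.
print(x_sharded.sharding)
for shard in x_sharded.addressable_shards:
    print(shard.device, shard.data.shape)
```

With a single device this produces one shard holding the whole array; with two GPUs each device holds a (16, 8) slice.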