In [1]:
import jax

# JAX XLA Compiler
XLA (Accelerated Linear Algebra) is a domain-specific compiler for linear algebra that plays a pivotal role in JAX's performance and flexibility. It enables JAX to generate optimized code for various hardware backends (CPUs, GPUs, TPUs) by compiling the array computations traced from Python functions into efficient machine code.

JAX uses just-in-time (JIT) compilation: the first time a function wrapped in *jax.jit* is called, JAX traces it and hands the resulting computation to XLA, which compiles it into an optimized executable at runtime.
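As a small illustration of the tracing step, the sketch below jit-compiles a toy function and prints the intermediate representation that JAX records for it. The function name `scale_and_shift` is made up for this example.

```python
import jax
import jax.numpy as jnp

def scale_and_shift(x):
    # Plain Python function made only of array operations
    return 2.0 * x + 1.0

# Wrapping with jax.jit compiles the function the first time it is called
jitted_scale_and_shift = jax.jit(scale_and_shift)

x = jnp.arange(3.0)
result = jitted_scale_and_shift(x)
print(result)

# make_jaxpr shows the traced computation that is handed to XLA
print(jax.make_jaxpr(scale_and_shift)(x))
```

The jitted function returns the same values as the plain Python version; only the way it is executed changes.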

# Functional Programming & JAX

The miracles of JAX (transformations such as *jit*) only manifest themselves reliably when code is written in a purely functional paradigm.

*What are pure functions?*

A function is said to be "pure" if its return value is solely determined by its input parameters and it has no side effects.
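To make the definition concrete, here is a minimal sketch contrasting an impure function, which reads mutable state from outside its arguments, with a pure one. The names `config`, `impure_add` and `pure_add` are hypothetical and exist only for this example.

```python
import jax

# Mutable state living outside the function
config = {"offset": 0.0}

def impure_add(x):
    # Impure: the result depends on state that is not an argument
    return x + config["offset"]

def pure_add(x, offset):
    # Pure: the result depends only on the arguments
    return x + offset

jitted_impure = jax.jit(impure_add)
jitted_pure = jax.jit(pure_add)

first = jitted_impure(4.0)    # traced while the offset is 0.0
config["offset"] = 10.0
second = jitted_impure(5.0)   # reuses the computation compiled with offset 0.0

print(first, second)
print(jitted_pure(5.0, 10.0))
```

The jitted impure function keeps using the offset it saw during tracing, so the later change to `config` is silently ignored. Passing the offset as an argument avoids the problem.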

# Installation & Setup
Follow the instructions here:
[Installation](https://docs.jax.dev/en/latest/installation.html)

# JAX Array programming

In [2]:
import jax.numpy as jnp

In [3]:
# We create an array from a python list
arr = jnp.array([1, 2, 3])
arr

Array([1, 2, 3], dtype=int32)

In [4]:
print(f'shape: {arr.shape}')
print(f'dtype: {arr.dtype}')
print(f'Device:{arr.devices()}')

shape: (3,)
dtype: int32
Device:{CudaDevice(id=0)}


## Immutability
JAX arrays are immutable, so we can't assign values to individual items in place. Instead, we take a functional approach: the *.at[...]* update methods return a new array and leave the original unchanged.

In [5]:
# arr[0] = 10   # Raises a TypeError: JAX arrays do not support item assignment

In [6]:
# Functional approach
new_arr = arr.at[0].set(10)
new_arr

Array([10,  2,  3], dtype=int32)
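The two cells above can be combined into one self-contained check. This sketch catches the error raised by in-place assignment and confirms that the functional update leaves the original array untouched.

```python
import jax.numpy as jnp

arr = jnp.array([1, 2, 3])

# In-place assignment is rejected because JAX arrays are immutable
assignment_failed = False
try:
    arr[0] = 10
except TypeError as err:
    assignment_failed = True
    print(f"In-place assignment failed with {type(err).__name__}")

# The functional update returns a new array
new_arr = arr.at[0].set(10)

print(arr)      # original values are unchanged
print(new_arr)  # updated copy
```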

In [7]:
# Chained .at[] updates: each call returns a new array, so updates compose

def update_array(arr):
  return arr.at[1:3].add(2).at[0].multiply(5)
update_array(arr)
# [1, 2, 3] --> [1*5, 2+2, 3+2] --> [5, 4, 5]

Array([5, 4, 5], dtype=int32)

## Device Placement

In [8]:
from jax import devices
print(devices())  # All available devices

# For example, I have a CPU and a GPU, and I want to place the computation on the GPU
gpu_arr = jax.device_put(arr, devices('gpu')[0])


[CudaDevice(id=0)]
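The cell above assumes that a GPU is present. On a machine without one, requesting the GPU backend raises a RuntimeError, so a slightly more defensive sketch falls back to the CPU. The variable names `target` and `placed` are just for illustration.

```python
import jax
import jax.numpy as jnp

# Prefer the first GPU, but fall back to the CPU when no GPU backend exists
try:
    target = jax.devices("gpu")[0]
except RuntimeError:
    target = jax.devices("cpu")[0]

placed = jax.device_put(jnp.array([1, 2, 3]), target)

# devices() returns the set of devices that hold the array
print(placed.devices())
```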


### Distributed Arrays

To distribute arrays across available devices we use the **jax.sharding** module. Sharding means splitting an array into smaller chunks that are distributed across devices. At the moment, we are just going to use **Positional Sharding**, which distributes array chunks based on their position in the array.

In [9]:
from jax.sharding import PositionalSharding

sharding = PositionalSharding(jax.devices())
distributed_arr = jax.device_put(jnp.arange(16), sharding)  # pass the sharding so the array is actually split
distributed_arr

Array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15],      dtype=int32)
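Recent JAX documentation centers on **NamedSharding** for the same task, so a hedged alternative to the positional-sharding cell is sketched below. It assumes a JAX version that provides `Mesh`, `NamedSharding` and `PartitionSpec` in `jax.sharding`. The axis name `"x"` is an arbitrary label, and the array length is chosen as a multiple of the device count so that it splits evenly.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# One-dimensional mesh over every available device, with a named axis "x"
device_list = jax.devices()
mesh = Mesh(np.array(device_list), axis_names=("x",))

# Split the first (and only) array axis along the mesh axis "x"
named_sharding = NamedSharding(mesh, PartitionSpec("x"))

# Length is a multiple of the device count so the shards are equal in size
n = 16 * len(device_list)
sharded_arr = jax.device_put(jnp.arange(n), named_sharding)

# Each shard records the device it lives on and the chunk it holds
for shard in sharded_arr.addressable_shards:
    print(shard.device, shard.data.shape)
```

On a single-device machine this prints one shard holding the whole array. With several devices, each one holds an equal slice.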

## Process management
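JAX dispatches work asynchronously: an operation returns control to Python right away, possibly before the result has been computed on the device. When we need to be sure that the computation has finished, for example when timing code, we call the array's *block_until_ready* method.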

In [10]:
wait_arr = arr * 2 + 5

block_process = wait_arr.block_until_ready()
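A rough way to see why this matters is to time the same operation with and without waiting for the result. This is only a sketch: the actual numbers depend on the hardware, and on some backends the two timings can be close.

```python
import time
import jax.numpy as jnp

x = jnp.ones((1000, 1000))

# Without blocking, the timer may stop before the device has finished
start = time.perf_counter()
y = jnp.dot(x, x)
dispatch_time = time.perf_counter() - start

# With blocking, the timer covers the whole computation
start = time.perf_counter()
y = jnp.dot(x, x).block_until_ready()
full_time = time.perf_counter() - start

print(f"dispatch only: {dispatch_time:.6f}s")
print(f"with block_until_ready: {full_time:.6f}s")
```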

## XLA Optimization

We're going to compare the same function with and without *jax.jit*.

In [40]:
def unoptimized_fn(x):
  return x @ x.T + jnp.diag(jnp.ones(x.shape[0]))

@jax.jit
def optimized_fn(x):
  return x @ x.T + jnp.diag(jnp.ones(x.shape[0]))

In [41]:
large_arr = jax.random.normal(jax.random.PRNGKey(1), (5000, 5000))

In [42]:
# Before going straight to the benchmarking code, we warm up the JIT compilation
# and wait for the result so the warm-up run does not leak into the timings
_ = optimized_fn(large_arr).block_until_ready()

In [43]:
import time
unop_start = time.time()
unop_opt = unoptimized_fn(large_arr)
unop_opt.block_until_ready()
unop_end = time.time()

print(f'Execution Time for the unoptimized function {unop_end - unop_start}s')

Execution Time for the unoptimized function 0.15828752517700195s


In [44]:
op_start = time.time()
op_opt = optimized_fn(large_arr)
op_opt.block_until_ready()
op_end = time.time()

print(f'Execution Time for the optimized function {op_end - op_start}s')

Execution Time for the optimized function 0.08936452865600586s
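A single timed run can be noisy, so a slightly more robust comparison averages several runs. The helper `benchmark` below is a hypothetical utility written for this notebook, not part of JAX, and the smaller 512 x 512 input is an assumption chosen to keep the example quick.

```python
import time
import jax
import jax.numpy as jnp

def add_identity(x):
    return x @ x.T + jnp.diag(jnp.ones(x.shape[0]))

jitted_add_identity = jax.jit(add_identity)

def benchmark(fn, x, repeats=10):
    """Average wall-clock time of fn(x) over several runs."""
    fn(x).block_until_ready()  # warm-up, triggers compilation for jitted functions
    start = time.perf_counter()
    for _ in range(repeats):
        fn(x).block_until_ready()  # wait for the result before the next run
    return (time.perf_counter() - start) / repeats

x = jax.random.normal(jax.random.PRNGKey(1), (512, 512))

eager_time = benchmark(add_identity, x)
jit_time = benchmark(jitted_add_identity, x)

print(f"without jit: {eager_time:.6f}s per call")
print(f"with jit:    {jit_time:.6f}s per call")
```

How large the gap is depends on the backend and the input size, so the timings printed earlier in this section should be read as one sample from one machine.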
