## Static vs Traced Values and Operations

Whether a value or operation is traced or static determines when the XLA compiler needs to recompile a function. The cells below put this into practice:

In [1]:
import jax
import jax.numpy as jnp
from jax import jit
from jax import make_jaxpr



The cell below will break: the value of `neg` is not known at compile time, yet the result depends on it through the Python `if`. A traced `neg` therefore cannot be used here. If we declare it as a static parameter instead, it is not traced, and XLA recompiles the function only when its value changes (the value itself, not the shape or type).

In [None]:
@jit
def f(x, neg):
  print('hello')
  return -x if neg else x

f(1, True)
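
To see the failure without stopping the notebook, the call can be wrapped in a `try`/`except`. This is a minimal sketch; it assumes the error raised is a subclass of `jax.errors.ConcretizationTypeError`, and the exact subclass name may differ between JAX versions.

```python
import jax
from jax import jit

@jit
def f(x, neg):
    # `neg` is traced here, so the Python `if` cannot read its concrete value
    return -x if neg else x

try:
    f(1, True)
    failed = False
except jax.errors.ConcretizationTypeError as err:
    failed = True
    print(type(err).__name__)  # name of the concrete error class raised
```

The error is raised while tracing, before anything is compiled or run on the device.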

The `static_argnums` argument indicates that the second parameter (`neg`, at index 1) is static. Its value is not traced, and the function is recompiled whenever that value changes. For this parameter, what matters is the value itself rather than its shape and type.

In [None]:
from functools import partial

@partial(jit, static_argnums=(1,))
def f(x, neg):
  print('x: ', x)
  print('neg: ', neg)
  return -x if neg else x

f(1, True)


In [None]:
f(1, False)
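
One way to observe the recompilation is to record every time the Python body runs, since it only runs while JAX traces the function. The sketch below uses a hypothetical `trace_log` list and a renamed function `negate_if`; neither comes from the cells above.

```python
from functools import partial
import jax
import jax.numpy as jnp

trace_log = []  # hypothetical helper: one entry per trace of the Python body

@partial(jax.jit, static_argnums=(1,))
def negate_if(x, neg):
    trace_log.append(neg)  # side effect, executed only during tracing
    return -x if neg else x

x = jnp.float32(1.0)
r1 = negate_if(x, True)   # first call: traces and compiles for neg=True
r2 = negate_if(x, True)   # same static value: served from the cache
r3 = negate_if(x, False)  # new static value: traces and compiles again

print(trace_log)
```

The body is traced once per distinct value of the static argument, not once per call.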


## Static Operations

Static operations are evaluated at compile time; traced operations are compiled and evaluated at runtime.

For example, use `jnp` when you want an operation to be traced, and use `np` when you want the operation to be static.
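
A common case is computing a shape inside a jitted function. The sketch below (the function name `flatten` is only illustrative) uses `np.prod` on `x.shape`, which is a plain tuple of Python integers, so the product is computed at trace time and can be passed to `reshape`.

```python
import numpy as np
import jax.numpy as jnp
from jax import jit

@jit
def flatten(x):
    # x.shape is static, so np.prod yields a concrete integer at trace time
    n = np.prod(x.shape)
    return x.reshape((n,))

flat = flatten(jnp.ones((2, 3)))
print(flat.shape)
```

Replacing `np.prod` with `jnp.prod` here would turn the size into a traced value, which `reshape` cannot accept as a shape.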

## Side effects and functional programming

In [None]:
g = 15.
def impure_uses_globals(x):
  return x + g


In [None]:
# JAX captures the value of the global during the first run
print ("First call: ", jit(impure_uses_globals)(4.))
g = 10.  # Update the global

# Subsequent runs may silently use the cached value of the globals
print ("Second call: ", jit(impure_uses_globals)(5.))

# JAX re-runs the Python function when the type or shape of the argument changes
# This will end up reading the latest value of the global
print ("Third call, different type: ", jit(impure_uses_globals)(jnp.array([4.])))


In [None]:
g = 10.
def impure_saves_global(x):
    global g
    g = x + g
    return x

print(jit(impure_saves_global)(4.))
print(g)
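
A pure alternative is to pass the value in as an explicit argument instead of reading or writing a global. The following is a minimal sketch; the function name `pure_uses_arg` is hypothetical.

```python
from jax import jit

@jit
def pure_uses_arg(x, g):
    # Everything the result depends on arrives through the arguments
    return x + g

first = pure_uses_arg(4., 15.)
second = pure_uses_arg(5., 10.)  # the new value of g is used, no stale cache
print(first, second)
```

Because `g` is now a traced input, changing it between calls changes the result without any retracing.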

In [None]:
import jax.numpy as jnp
import jax.lax as lax
from jax import make_jaxpr

# lax.fori_loop
array = jnp.arange(10)
print(lax.fori_loop(0, 10, lambda i,x: x+array[i], 0)) # expected result 45
iterator = iter(range(10))
print(lax.fori_loop(0, 10, lambda i,x: x+next(iterator), 0)) # unexpected result 0


In [None]:
# lax.scan
def func11(arr, extra):
    ones = jnp.ones(arr.shape)
    def body(carry, aelems):
        ae1, ae2 = aelems
        return (carry + ae1 * ae2 + extra, carry)
    return lax.scan(body, 0., (arr, ones))
make_jaxpr(func11)(jnp.arange(16), 5.)
# make_jaxpr(func11)(iter(range(16)), 5.) # throws error


In [None]:
# lax.cond
array_operand = jnp.array([0.])
lax.cond(True, lambda x: x+1, lambda x: x-1, array_operand)
iter_operand = iter(range(10))
# lax.cond(True, lambda x: next(x)+1, lambda x: next(x)-1, iter_operand) # throws error


In [None]:
# Throws an error: an iterator is not a valid JAX type, so it cannot be traced
make_jaxpr(func11)(iter(range(16)), 5.)

### Parallelism across all the available devices with `pmap`

`vmap` vectorizes a function on a single device, while `pmap` distributes the work across all available devices; it is a device-parallelism approach. It implements SPMD (single program, multiple data) by sharding the data across all devices and running the same function on each shard:

In [None]:
import numpy as np
import jax.numpy as jnp
import jax

x = np.arange(5)
w = np.array([2., 3., 4.])

def convolve(x, w):
  output = []
  for i in range(1, len(x)-1):
    output.append(jnp.dot(x[i-1:i+2], w))
  return jnp.array(output)

convolve(x, w)


In [None]:
n_devices = jax.local_device_count() 
xs = np.arange(5 * n_devices).reshape(-1, 5)
ws = np.stack([w] * n_devices)

xs, ws


In [None]:
# Now let's vectorize the convolve operation

jax.vmap(convolve)(xs, ws)


In [None]:
# this runs the function across devices:
jax.pmap(convolve)(xs, ws)


The returned array is already sharded across all devices. If we run another `pmap` operation on it, the sharded data stays on the corresponding devices.

In [None]:
jax.pmap(convolve)(xs, jax.pmap(convolve)(xs, ws))


To control how the data is split across the devices, `pmap` provides the `in_axes` parameter.

The `xs` argument is split along its leading axis (axis 0), meaning each device gets a different slice of `xs`. The `w` argument, however, is broadcast to all devices (indicated by `None` in `in_axes`), meaning each device uses the same `w` value. This is an alternative approach to manually replicating `w` across devices, simplifying the code and potentially reducing memory usage.

Each device therefore computes its part of the result from its own slice of `xs` and the broadcast `w`:

In [None]:
jax.pmap(convolve, in_axes=(0, None))(xs, w)


For example, given this input matrix for xs:

```text
array([[ 0,  1,  2,  3,  4],
       [ 5,  6,  7,  8,  9],
       [10, 11, 12, 13, 14],
       [15, 16, 17, 18, 19]])
```

With 4 devices, the matrix is sharded as follows:

```text
Device 1 would receive [0, 1, 2, 3, 4]
Device 2 would receive [5, 6, 7, 8, 9]
Device 3 would receive [10, 11, 12, 13, 14]
Device 4 would receive [15, 16, 17, 18, 19]
```
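
As a quick check of this row-per-device picture, the sketch below compares `pmap` against `vmap` using the same `in_axes`. It sizes `xs` from `jax.local_device_count()` so that it also runs on a single device; the variable names `per_device` and `batched` are only illustrative.

```python
import numpy as np
import jax
import jax.numpy as jnp

def convolve(x, w):
    output = []
    for i in range(1, len(x) - 1):
        output.append(jnp.dot(x[i - 1:i + 2], w))
    return jnp.array(output)

n_devices = jax.local_device_count()
# One row of xs per device; w is shared by all of them
xs = np.arange(5 * n_devices, dtype=np.float32).reshape(-1, 5)
w = np.array([2., 3., 4.], dtype=np.float32)

per_device = jax.pmap(convolve, in_axes=(0, None))(xs, w)  # one row per device
batched = jax.vmap(convolve, in_axes=(0, None))(xs, w)     # same rows, one device

print(per_device.shape)
```

Both calls should produce the same values; the difference is only where each row is computed.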