<a href="https://colab.research.google.com/github/durml91/State-Space-Models/blob/main/JAX_Pallas_explo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

https://astralord.github.io/posts/exploring-parallel-strategies-with-jax/

In [4]:
import os
import jax
import jax.numpy as jnp

In [5]:
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'

In [6]:
jax.devices()

[CpuDevice(id=0),
 CpuDevice(id=1),
 CpuDevice(id=2),
 CpuDevice(id=3),
 CpuDevice(id=4),
 CpuDevice(id=5),
 CpuDevice(id=6),
 CpuDevice(id=7)]

In [7]:
batch_size, embed_dim = 16, 8
x = jnp.zeros((batch_size, embed_dim))
print(x.device())

TFRT_CPU_0


  print(x.device())


In [8]:
# put array x on specific device
jax.device_put(x, jax.devices()[5]).device()

  jax.device_put(x, jax.devices()[5]).device()


CpuDevice(id=5)

In [10]:
from jax.sharding import PositionalSharding

sharding = PositionalSharding(jax.devices())

In [11]:
G = jax.local_device_count()

In [14]:
sharding.shape

(8,)

In [15]:
G

8

In [16]:
sharded_x = jax.device_put(x, sharding.reshape(1, G))

In [18]:
from jax.debug import visualize_array_sharding
import matplotlib as mpl

In [22]:
def visualise(tensor, color_map="Set3"):
  return visualize_array_sharding(tensor, color_map=mpl.colormaps[color_map])

In [23]:
visualise(sharded_x)

In [25]:
visualise(jax.device_put(x, sharding.reshape(2,4)))

In [26]:
visualise(jax.device_put(x, sharding.reshape(4,2)))

https://jameschen.io/jekyll/update/2024/02/12/mamba.html

In [None]:
!pip install einops

In [4]:
from einops import einsum
import jax
import jax.lax as lax
import jax.numpy as jnp
import jax.random as jr

In [5]:
key = jr.PRNGKey(seed=2024)

In [6]:
B = 1 # batch size
L = 8192 # context length
N = 64 # hidden state size
D = 2 # num in channels
V = 1 # num out channels

In [7]:
# create log normal distributed array of size context_dim, num_in_dims
def generate_random_xs(key, num_inputs=L, num_channels=D):
  key, subkey = jr.split(key)
  xs = jr.lognormal(subkey, shape=(L, D))
  return key, xs

In [8]:
# generate A matrix of size - time variant therefore context length x square matrix of hidden state dim
def generate_random_As(key, num_inputs=L, state_size=N):
  key, subkey = jr.split(key)
  As = jr.lognormal(subkey, shape=(L, N, N))
  return key, As

In [9]:
# same idea as A amtrix but the num of columns is num_in channels given state space ODE
def generate_random_Bxs(key, num_inputs=L, state_size=N, num_channels=D):
  key, subkey = jr.split(key)
  Bxs = jr.lognormal(subkey, shape=(L, N, D))
  return key, Bxs

In [10]:
# matrix mutliply matrix b and xs
def get_bs(xs, Bxs):
  return einsum(Bxs, xs, "l n d, l d -> l n")

In [11]:
def extract(c, state_size):
  assert c.ndim == 1
  assert c.shape[0] == state_size * state_size + state_size
  return (
    c[:state_size * state_size].reshape((state_size, state_size)),
    c[-state_size:].reshape((state_size,))
  )

In [18]:
def operator(c_prev, c_curr, num_inputs=L, state_size=N, num_channels=D):
  prev_a, prev_b = extract(c_prev, state_size)
  curr_a, curr_b = extract(c_curr, state_size)
  # output weird operator
  return jnp.concatenate([
      jnp.ravel(curr_a @ prev_a),
      jnp.ravel(curr_a @ prev_b + curr_b)
  ])

vectorised_operator = jax.vmap(operator, in_axes=(0, 0), out_axes=0)

In [13]:
key, xs = generate_random_xs(key)
key, Bxs = generate_random_Bxs(key)
key, As = generate_random_As(key)

In [14]:
bs = get_bs(xs, Bxs)
cs = jnp.concatenate([As.reshape(-1, N*N), bs], axis=1)

In [19]:
# this computes prefix sum using associative binary operation
lax_scanned = lax.associative_scan(vectorised_operator, cs)[:, -N:]

In [20]:
def naive_scan_hs(h_0, As, Bxs, xs):
  output = [h_0]
  for a, bx, x in zip(As, Bxs, xs):
    b = einsum(bx, x, "n d, d -> n") # get b via matrix multiplication
    output.append(a @ output[-1] + b) # new hidden state
    return output[1:] # output all but h_0

In [21]:
naive_hs = jnp.vstack(
    naive_scan_hs(jnp.zeros((N,)), As, Bxs, xs)
)

In [22]:
jnp.allclose(naive_hs, lax_scanned)

Array(False, dtype=bool)