In [1]:
import os
import numpy as np
import jax

# Optional: uncomment to ask XLA for 8 virtual host devices (the flag name
# suggests this is meant for running the notebook without real accelerators).
# NOTE(review): XLA flags are presumably read when the JAX backend starts, so
# this line likely has to run before the first device query — confirm, and if
# so move it above `import jax`.
#os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'

# 1-D PositionalSharding built from every device returned by jax.devices();
# later cells call .reshape(...) / .replicate(...) on it to describe layouts.
sharding = jax.sharding.PositionalSharding(jax.devices())
# Bare expression so the notebook shows the sharding's repr as the cell output.
sharding

PositionalSharding([{TPU 0} {TPU 1} {TPU 2} {TPU 3} {TPU 4} {TPU 5} {TPU 6} {TPU 7}], memory_kind=device, shape=(8,))

In [2]:
import jax.numpy as jnp

def _show_sharding(title, array, placement=None):
    '''
    Print `title`, place `array` with `placement` when one is given, draw the
    resulting device layout and hand the (possibly re-placed) array back.
    '''
    print(title)
    placed = array if placement is None else jax.device_put(array, placement)
    jax.debug.visualize_array_sharding(placed)
    return placed


B, E = 16, 8
G = jax.local_device_count()

# Baseline: a freshly created array with no explicit placement.
x = jnp.zeros((B, E))
_show_sharding("no sharding", x)

# All G devices laid out along the second (column) axis.
sharded_x = _show_sharding("column wise sharding", x, sharding.reshape(1, G))

# Same device list, but marked as replicated along axis 0.
replicated_x = _show_sharding("replication", x, sharding.replicate(0))

# Devices arranged as a 2 x (G // 2) grid over rows and columns.
twod_x = _show_sharding("2D sharding", x, sharding.reshape(2, G//2))

# The 2-D grid again, replicated along one grid axis at a time.
combined_x = _show_sharding(
    "2D sharding - replication across row axis",
    x,
    sharding.reshape(2, G//2).replicate(0),
)
combined_x = _show_sharding(
    "2D sharding - replication across column axis",
    x,
    sharding.reshape(2, G//2).replicate(1),
)

no sharding


column wise sharding


replication


2D sharding


2D sharding - replication across row axis


2D sharding - replication across column axis


In [32]:
from typing import NamedTuple, List
from jax.typing import ArrayLike

class Params(NamedTuple):
    '''
    Weights of one feed-forward block, consumed by `ffn`.
    '''
    # First projection; `ffn` applies it as `x @ w1` before the max(., 0).
    w1: jnp.ndarray
    # Second projection; `ffn` applies it to the activated hidden values.
    w2: jnp.ndarray

def ffn(x: jnp.array, params: Params):
    '''
    Feed-forward pass: multiply by `params.w1`, clamp negatives to zero,
    then multiply by `params.w2`.
    '''
    pre_activation = x @ params.w1
    hidden = jnp.maximum(pre_activation, 0)
    return hidden @ params.w2

def init_ffn_weights(embed_dim: int, hidden_dim: int, rng: ArrayLike):
    '''
    Create FFN weights with Xavier initialization.

    Both matrices are standard-normal draws scaled by
    sqrt(2 / (embed_dim + hidden_dim)); `rng` is split so each matrix gets its
    own key. Returns Params with w1 of shape (embed_dim, hidden_dim) and w2 of
    shape (hidden_dim, embed_dim).
    '''
    fan_sum = embed_dim + hidden_dim
    scale = jnp.sqrt(2/fan_sum)
    key_in, key_out = jax.random.split(rng)
    return Params(
        scale * jax.random.normal(key_in, (embed_dim, hidden_dim)),
        scale * jax.random.normal(key_out, (hidden_dim, embed_dim)),
    )

def sample_data(batch_size: int, embed_dim: int, rng: ArrayLike):
    '''
    Draw random features and derive targets from them.

    Features are standard-normal with shape (batch_size, embed_dim), keyed by
    `rng`. Targets are sin(features @ projection) and share that shape.
    Returns the pair (features, targets).
    '''
    # The projection always comes from PRNGKey(1), so it is the same matrix on
    # every call regardless of `rng`.
    projection = jax.random.normal(jax.random.PRNGKey(1), (embed_dim, embed_dim))
    features = jax.random.normal(rng, (batch_size, embed_dim))
    targets = jnp.sin(features @ projection)
    return features, targets

In [33]:
# Sizes passed below as batch_size, embed_dim and hidden_dim respectively.
B, E, H = 16, 1024, 2048

# Independent seeds for the weights and for the data.
weight_key = jax.random.PRNGKey(42)
data_key = jax.random.PRNGKey(0)

params = init_ffn_weights(embed_dim=E, hidden_dim=H, rng=weight_key)
x, y = sample_data(batch_size=B, embed_dim=E, rng=data_key)

# Forward pass with no explicit placement, then draw where the output lives.
y_pred = ffn(x, params)
jax.debug.visualize_array_sharding(y_pred)

In [34]:
# Place x with the device list reshaped to (G, 1): all G devices sit along the
# first axis of x, none along the second.
sharded_x = jax.device_put(x, sharding.reshape(G, 1))
jax.debug.visualize_array_sharding(sharded_x)
# params is passed as-is (no explicit placement); the second drawing shows the
# layout the output ends up with.
y_pred = ffn(sharded_x, params)
jax.debug.visualize_array_sharding(y_pred)

In [35]:
@jax.jit
def model(x: ArrayLike, params: Params) -> jax.Array:
    '''
    Residual stack of FFN layers: for every layer, x <- x + ffn(x, layer).

    `params` may be
      * a single flat Params (w1 / w2 are weight arrays) -> one residual layer,
      * a sequence of Params, e.g. the list built by `init_weights`,
      * a Params whose two fields are themselves Params -> two residual layers
        (the only input shape the previous implementation could run on).

    The earlier body called `ffn(x, params.w1)`, handing a bare weight matrix
    to `ffn`, which then failed on `.w1` for the annotated Params type.
    '''
    # The pytree structure is static under jit, so this check runs at trace time.
    if isinstance(params, Params) and not isinstance(params.w1, Params):
        layers = (params,)
    else:
        layers = params
    for layer in layers:
        x = x + ffn(x, layer)
    return x

@jax.jit
def criterion(y_pred: ArrayLike, y_true: ArrayLike) -> float:
    '''
    Mean squared error: the average over all elements of (y_pred - y_true)^2.
    '''
    residual = y_pred - y_true
    squared_error = residual ** 2
    return squared_error.mean()

def step_fn(params: Params, x: ArrayLike, y_true: ArrayLike = None) -> float:
    '''
    Run `model` on `x` and score the prediction with `criterion`.

    params: weights forwarded to `model`.
    x:      input batch.
    y_true: targets to compare against. When omitted, the notebook-level
            global `y` is used, which is what this function always did before.
            Passing the targets explicitly is safer because other cells
            reassign `y`.

    Returns the scalar loss produced by `criterion`.
    '''
    if y_true is None:
        # Backward-compatible fallback to the notebook global.
        y_true = y
    y_pred = model(x, params)
    return criterion(y_pred, y_true)

def create_dataset(num_samples: int, batch_size: int, embed_dim: int) -> jax.Array:
    '''
    Build `num_samples` batches with `sample_data`, seeding batch i with
    jax.random.PRNGKey(i).

    Each (x, y) pair returned by `sample_data` is stacked into one array, so
    the result has shape (num_samples, 2, batch_size, embed_dim): index 0 of
    the second axis holds the features, index 1 the targets.
    '''
    # Use the fully qualified jax.random: a bare `random` is never imported in
    # this notebook.
    return jnp.array([
        sample_data(batch_size, embed_dim, jax.random.PRNGKey(i))
        for i in range(num_samples)
    ])

def init_weights(embed_dim: int, hidden_dim: int, num_layers: int, rng: ArrayLike) -> List[Params]:
    '''
    Create weights for `num_layers` FFN layers.

    `rng` is split into one key per layer and each key is passed to
    `init_ffn_weights`, so every layer gets independent draws.
    Returns a list of `num_layers` Params.
    '''
    layer_keys = jax.random.split(rng, num_layers)
    # Iterate over the keys themselves; the previous loop used each key as an
    # index back into `layer_keys`.
    return [
        init_ffn_weights(embed_dim, hidden_dim, layer_key)
        for layer_key in layer_keys
    ]


In [41]:
# One row of E values per device: pmap maps over the leading axis of size G.
x = jnp.arange(G * E).reshape(G, E)
# pmap also maps over the leading axis of the weights, so prepend a length-G
# axis holding one copy of each matrix per device.
# jax.tree_util.tree_map is the stable spelling; the top-level jax.tree_map
# alias is deprecated and no longer present in newer JAX releases.
tiled_params = jax.tree_util.tree_map(lambda p: jnp.tile(p, (G, 1, 1)), params)
y = jax.pmap(ffn, axis_name='G')(x, tiled_params)
jax.debug.visualize_array_sharding(y)

In [42]:
def split(arr: ArrayLike, num_sections: int=None, axis: int=0) -> jax.Array:
  '''
  Cut `arr` into `num_sections` pieces along `axis` and return the pieces
  stacked along a new leading axis.
  '''
  pieces = jnp.split(arr, num_sections, axis=axis)
  return jnp.array(pieces)

In [None]:
from functools import partial

