In [2]:
import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8' # Use 8 CPU devices
import jax
jax.config.update("jax_debug_nans", False)

import jax.numpy as jnp
mesh = jax.make_mesh((8,), ('tensor',))

In [None]:
def predict(params, inputs):
  for W, b in params:
    outputs = jnp.dot(inputs, W) + b
    inputs = jax.nn.relu(outputs)
  return outputs

def loss(params, batch):
  inputs, targets = batch
  predictions = predict(params, inputs)
  return jnp.mean(jnp.sum((predictions - targets)**2, axis=-1))

In [3]:
def init_layer(key, n_in, n_out):
  k1, k2 = jax.random.split(key)
  W = jax.random.normal(k1, (n_in, n_out)) / jnp.sqrt(n_in)
  b = jax.random.normal(k2, (n_out,))
  return W, b

def init(key, layer_sizes, batch_size):
  key, *keys = jax.random.split(key, len(layer_sizes))
  params = list(map(init_layer, keys, layer_sizes[:-1], layer_sizes[1:]))

  key, *keys = jax.random.split(key, 3)
  inputs = jax.random.normal(keys[0], (batch_size, layer_sizes[0]))
  targets = jax.random.normal(keys[1], (batch_size, layer_sizes[-1]))

  return params, (inputs, targets)

In [4]:
layer_sizes = [784, 128, 128, 128, 128, 128, 8]
batch_size = 32

params, batch = init(jax.random.key(0), layer_sizes, batch_size)

In [None]:
from functools import partial

from jax.sharding import NamedSharding, Mesh, PartitionSpec as P



# replicate initial params on all devices, shard data batch over devices
batch = jax.device_put(batch, NamedSharding(mesh, P('batch')))
params = jax.device_put(params, NamedSharding(mesh, P()))

# adapt the loss function to sum the losses across devices
def loss_dp(params, batch):
  @partial(jax.shard_map, mesh=mesh, in_specs=P('batch', None), out_specs=P())
  def loss_spmd(local_batch):
    inputs, targets = local_batch
    predictions = predict(params, inputs)  # use reference 'predict`
    local_loss = jnp.mean(jnp.sum((predictions - targets)**2, axis=-1))
    return jax.lax.pmean(local_loss, 'batch')
  return loss_spmd(batch)

In [6]:
print(jax.jit(loss)(params, batch))
print(jax.jit(loss_dp)(params, batch))

11.920299
11.920299


In [7]:
def allclose(a, b):
  return tree_all(tree_map(partial(jnp.allclose, atol=1e-2, rtol=1e-2), a, b))

print(allclose(jax.jit(jax.grad(loss))(params, batch),
               jax.jit(jax.grad(loss_dp))(params, batch)))

NameError: name 'tree_all' is not defined

In [7]:
L = len(params) - 2        # num layers, excluding first and last
N = batch_size             # batch size
F = params[0][0].shape[1]  # num features

# choose some pipeline parameters
S = 2      # number of stages
B = 8      # size of each microbatch
assert L % S == 0, "S (number of stages) must divide L (number of inner layers)"

# compute some useful quantities
M, ragged = divmod(N, B)  # M is number of microbatches
assert not ragged, "B (size of each microbatch) must divide total batch size"
K, ragged = divmod(M, S)  # K is microbatches per stage
assert not ragged, "S (number of stages) must divide number of microbatches"
print(f'{S} stages, {L // S} layer(s) per stage, {L} pipelined layers total')
print(f'{B} examples per microbatch, {M} microbatches total')

2 stages, 2 layer(s) per stage, 4 pipelined layers total
8 examples per microbatch, 4 microbatches total


In [8]:
mesh = Mesh(jax.devices()[:S], ('stages',))
def spmd_pipeline(fn, stage_params, inputs):
  stage = jax.lax.axis_index('stages')
  outputs = jnp.zeros_like(inputs) * jnp.nan
  state = jnp.zeros((L // S, B, F)) * jnp.nan
  for i in range(M+L-1):
    state = state.at[0].set(jnp.where(stage == 0, inputs[i % K], state[0]))
    state = jax.vmap(fn)(stage_params, state)
    outputs = outputs.at[(i-L+1) % K].set(jnp.where(stage == S-1, state[-1], outputs[(i-L+1) % K]))
    state, inputs, outputs = shift(i, state, inputs, outputs)
  outputs = jax.lax.ppermute(outputs, 'stages', [(i, (i+1) % S) for i in range(S)])
  return outputs

def shift(i, state, inputs, outputs):
  sh = lambda x, d: jax.lax.ppermute(x, 'stages', [(i, (i+d) % S) for i in range(S)])
  state = jnp.roll(state, +1, axis=0).at[0].set(sh(state[-1], +1))
  if (i % K) == (-1 % K):
    inputs = sh(inputs, +1)
  if ((i-L+1) % K) == (-1 % K):
    outputs = sh(outputs, +1)
  return state, inputs, outputs
def predict_pp(params, inputs):
  (W_first, b_first), inner_params, (W_last, b_last) = params
  inputs = jax.nn.relu(jnp.dot(inputs, W_first) + b_first)
  inputs = spmd_pipeline(lambda Wb, x: jax.nn.relu(x @ Wb[0] + Wb[1]),
                        inner_params, inputs)
  outputs = jnp.dot(inputs, W_last) + b_last
  return outputs

@partial(jax.shard_map, mesh=mesh, in_specs=((P(), P('stages'), P()), P('stages')),
         out_specs=P())
def loss_pp(params, batch):
  inputs, targets = batch
  predictions = predict_pp(params, inputs.reshape(K, B, -1)).reshape(K * B, -1)
  local_loss = jnp.mean(jnp.sum((predictions - targets)**2, axis=-1))
  return jax.lax.pmean(local_loss, 'stages')

In [9]:
first_params, *inner_params, last_params = params
Ws, bs = zip(*inner_params)
params_stacked = jnp.stack(Ws), jnp.stack(bs)
first_params = jax.device_put(first_params, NamedSharding(mesh, P()))
params_stacked = jax.device_put(params_stacked, NamedSharding(mesh, P('stages')))
last_params = jax.device_put(last_params, NamedSharding(mesh, P()))
params_ = first_params, params_stacked, last_params

batch_ = jax.device_put(batch, NamedSharding(mesh, P('stages')))

In [12]:
print(loss_pp(params_, batch_))

11.920298


In [11]:
_ = jax.jit(jax.grad(loss_pp))(params_, batch_)   # don't crash

In [14]:
_


((Array([[ 0.00083114, -0.02315384, -0.01116189, ...,  0.00821616,
          -0.03333623,  0.02138161],
         [ 0.00264254,  0.01418981, -0.00351758, ...,  0.00914318,
           0.00340969,  0.00174211],
         [ 0.00046268, -0.06213158, -0.00636453, ..., -0.00573286,
          -0.00366312,  0.0008602 ],
         ...,
         [ 0.00031812, -0.00743218,  0.02016719, ...,  0.00143266,
           0.02880489, -0.01309206],
         [ 0.00283473, -0.00373359,  0.01345176, ...,  0.00291418,
           0.04606411,  0.01251762],
         [ 0.00109067, -0.03741013,  0.01151825, ...,  0.00170438,
           0.01178429,  0.01915886]], dtype=float32),
  Array([-0.00184396,  0.01669584,  0.02164971, -0.00165527, -0.12654759,
          0.01433781, -0.00963732, -0.03216164,  0.01713463, -0.03323705,
          0.09363743,  0.07619127,  0.00747464,  0.15072134,  0.03830011,
         -0.05274152,  0.10495227, -0.01853032,  0.00505064, -0.03301641,
         -0.13752127,  0.04429763,  0.        ,  

In [13]:
A = jnp.array([[0, -1], [2, 3]])
def matrix_exp(A, n):
  for i in range(n-1):
    A = A @ A
  return A

In [15]:
matrix_exp(A, 7)

Array([[ 2,  1],
       [-2, -1]], dtype=int32)

In [23]:
A = jnp.array([[1/jnp.sqrt(2), -1/jnp.sqrt(2)], [1/jnp.sqrt(2), 1/jnp.sqrt(2)]])
B = jnp.array([[3, -1], [1, 1]])

In [19]:
(A.T @ A).astype(jnp.float16)

Array([[1., 0.],
       [0., 1.]], dtype=float16)

In [22]:
C  = (A.T @ B).astype(jnp.float16)

ValueError: matmul input operand 1 must have ndim at least 1, but it has ndim 0

In [26]:
A_T = A.T

In [None]:
C = A_T @ B

In [29]:
D = C @ A

In [None]:
D.astype(jnp.float16)

Array([[ 2.e+00, -2.e+00],
       [ 6.e-08,  2.e+00]], dtype=float16)

In [3]:
import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8' # Use 8 CPU devices
import jax
jax.config.update("jax_debug_nans", False)

import jax.numpy as jnp
mesh = jax.make_mesh((8,), ('tensor',))
mesh

Mesh(device_ids=array([0, 1, 2, 3, 4, 5, 6, 7]), axis_names=('tensor',), axis_types=(Auto,))

In [4]:
import flax.linen as nn
from jaxtyping import Array

class TestNet(nn.Module):
  features: int = 8

  @nn.compact
  def __call__(self, x: Array):
    # params = self.scope.get_variable("params", "Dense_0")
    # params = jax.lax.all_gather(params, 'data')
    out = nn.Dense(self.features)(x)
    return out

class TensorNet(nn.Module):
  features: int = 8

  @nn.compact
  def __call__(self, x: Array):
    local_out = nn.Dense(self.features)(x)
    if not self.is_mutable_collection("params"):
      local_out = jax.lax.psum_scatter(
        local_out,
        'tensor',
        scatter_dimension=local_out.ndim - 1,
        tiled=True
      )
    return local_out


In [5]:
model = TestNet()
x = jnp.ones((2, 4, 16))
init_key = jax.random.PRNGKey(0)
params = model.init(init_key, x)['params']

In [None]:
from functools import partial
from jax.sharding import PartitionSpec as P
model_tensor = TensorNet()

tensor_keys = jax.random.split(init_key, 8)
tensor_keys = jnp.array(tensor_keys)

var_spec = jax.eval_shape(lambda key, x: model_tensor.init(key, x)['params'],  init_key, x)
def get_sharding(current_shape):
  if(current_shape.ndim < 2):
    return P()
  else:
   return P('tensor')

out_spec = jax.tree.map(lambda x: get_sharding(x), var_spec)
@partial(jax.shard_map, mesh=mesh, in_specs=(P(None, None, 'tensor'), P('tensor')), out_specs=out_spec)
def init_tensor(x, key):
  return model_tensor.init(key[0], x)['params']

@partial(jax.shard_map, mesh=mesh, in_specs=(out_spec, P(None, None, 'tensor')), out_specs=(P(None, None, 'tensor')))
def predict_tensor(params, x):
  return model_tensor.apply({'params': params}, x)

In [13]:

params_tensor = init_tensor(x, tensor_keys)

In [8]:
jax.debug.visualize_array_sharding(params_tensor['Dense_0']['kernel'])

In [86]:
params['Dense_0']['kernel'].shape

(16, 8)

In [102]:
print(x.shape)
print(model.apply({'params': params}, x))

(2, 4, 16)
{'kernel': Array([[-0.4129499 ,  0.18863787, -0.34443754, -0.21300575,  0.22984762,
         0.32090548,  0.26101372, -0.01162406],
       [ 0.10569897,  0.19433817, -0.45485714,  0.21357062, -0.17656627,
         0.12733398,  0.30540565, -0.0687379 ],
       [-0.37083328, -0.33003575, -0.17782678, -0.00833714, -0.08470196,
        -0.1561097 , -0.527091  , -0.13473354],
       [-0.28729278, -0.1519707 ,  0.13622738, -0.18332122,  0.31884882,
         0.30622914,  0.42549726, -0.36607867],
       [ 0.1551552 ,  0.4243707 ,  0.2194014 , -0.06142971, -0.00613193,
        -0.21234725,  0.27880386,  0.11867873],
       [ 0.14646956, -0.06827474, -0.02934157,  0.17079332,  0.1213038 ,
         0.00091495,  0.09979176, -0.2644029 ],
       [ 0.09694615, -0.3536361 , -0.2172665 ,  0.20046584,  0.26204365,
         0.24415551,  0.48724577, -0.00846941],
       [-0.10895626,  0.2645487 ,  0.11885007, -0.09636065, -0.30997655,
         0.2348184 , -0.13377994, -0.16334581],
       [ 0

In [95]:
x_shard = predict_tensor(params_tensor, x)

x shape: (2, 4, 2)
local_out shape: (2, 4, 8)
2
(2, 4, 1)


In [98]:
jax.debug.visualize_array_sharding(x_shard[0])

In [94]:
x_shard

(2, 4, 8)

In [9]:
def display_param_sharding(params):
  jax.tree.map(
    lambda p: jax.debug.visualize_array_sharding(p), params
  )

In [10]:
display_param_sharding(params)

In [14]:
display_param_sharding(params_tensor)