In [None]:
import os

flags = os.environ.get("XLA_FLAGS", "")
flags += " --xla_force_host_platform_device_count=8"
os.environ["XLA_FLAGS"] = flags

os.environ["CUDA_VISIBLE_DEVICES"] = ""


In [8]:

import functools
from pprint import pprint
from typing import Any, Callable, Dict, Tuple

import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
import optax
from flax.core.frozen_dict import FrozenDict
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh
from jax.sharding import PartitionSpec as P
from ml_collections import ConfigDict
jax.config.update("jax_platforms", 'cpu')

# Helper types
PyTree = Any
Parameter = jax.Array | nn.Partitioned
Metrics = Dict[str, Tuple[jax.Array, ...]]

In [9]:
devices = jax.devices()[:4]
mesh = Mesh(devices=devices, axis_names=('model'))
mesh


Mesh(device_ids=array([0, 1, 2, 3]), axis_names=('model',), axis_types=(Auto,))

In [10]:
class SimpleDenseNetwork(nn.Module):
  hidden_dim: int = 4
  output_dim: int = 4

  @nn.compact
  def __call__(self, x):
    x = nn.Dense(self.hidden_dim)(x)
    x = nn.Dense(self.output_dim)(x)
    return x


In [11]:
model = SimpleDenseNetwork()
L = 8
key, *init_key = jax.random.split(jax.random.key(0), L + 1)
params = [model.init(k, jnp.ones((1, 4)))['params'] for k in init_key]

params = jax.tree.map(
  lambda *x: jnp.concat(x, axis=0),
  *params
)

jax.tree.map(
  lambda x: x.shape,
  params
)


{'Dense_0': {'bias': (32,), 'kernel': (32, 4)},
 'Dense_1': {'bias': (32,), 'kernel': (32, 4)}}

In [None]:
from functools import partial
from jax import vjp

L = 8


@jax.jit
@partial(jax.shard_map, mesh=mesh, in_specs=(P('model')), out_specs=(P('model')))
def pipeline_params_init(params):
   return params


def fwd_pass(params, buffer):

  n_devices = jax.lax.psum(1, 'model')
  layers_per_device = L // n_devices

  def slice_fn(x, i):

    c = x.shape[0] // layers_per_device
    start_idx = (i * c, *(0 for _ in range(x.ndim - 1)))
    len_idx = (c, *x.shape[1:])
    return jax.lax.dynamic_slice(x, start_idx, len_idx)

  def step(params, x):
      fn = lambda params, x: model.apply(
          {'params': params,},
          x
      )
      out, grad_fn = vjp(fn, params, x)
      return out, grad_fn


  grad_fns = []
  for i in range(layers_per_device):

      params_current = jax.tree.map(
          jax.tree_util.Partial(slice_fn, i=i),
          params
      )

      buffer, grad_fn = step(params_current, buffer)
      grad_fns.append(grad_fn)



  return buffer[None, :], grad_fns

# @jax.jit
@partial(jax.shard_map,mesh=mesh, in_specs=(P('model'), P('model')), out_specs=(P('model'), P('model')))
def pipeline_fwd_pass(params, buffer):

    n_devices = jax.lax.psum(1, 'model')
    perm = [(i, (i + 1) % n_devices) for i in range(n_devices)]
    grad_fn_stages = []

    for i in range(n_devices):
      buffer, grad_fn = fwd_pass(params, buffer[0])
      grad_fn_stages.append(grad_fn)
      if i < n_devices - 1:
        buffer = jax.lax.ppermute(buffer, 'model', perm)

    return buffer, grad_fn_stages


In [None]:
@jax.value_and_grad
def loss(input):
  return jnp.sum(input ** 2)

@jax.value_and_grad
def loss2(params, input):

  buffer = jnp.concat(
    [
      x[None, :],
      jnp.zeros((mesh.devices.shape[0] - 1, *(2,4)))
    ], axis=0
  )

  out, grad_fns = pipeline_fwd_pass(params_shard, buffer)

  loss = jnp.sum(out[-1] ** 2)

  return loss

x = jnp.ones((2,4))
params_shard = pipeline_params_init(params)

# out, grad_fn = pipeline_fwd_pass(params_shard, x)

out, grad = loss2(params_shard, x)

In [None]:
def bwd_pass(grads_x, grad_fn):

  n_devices = jax.lax.psum(1, 'model')
  layers_per_device = L // n_devices
  idx = jax.lax.axis_index('model')

  grads = []
  for i in range(layers_per_device -1, -1,-1):

    grads_params, grads_x = grad_fn[i](grads_x)
    grads.append(grads_params)

  grads = jax.tree.map(
     lambda *x: jnp.concat(x, axis=0),
      *grads
  )

  return grads, grads_x[None, :]

@jax.jit
@partial(jax.shard_map, mesh=mesh, in_specs=(P('model'),P('model'),P('model')), out_specs=P('model'))
def pipeline_bwd_pass(params, grads_x, grad_fn):

  n_devices = jax.lax.psum(1, 'model')
  perm = [(i, (i-1) % n_devices) for i in range(n_devices)]

  grads = jax.tree.map(
    lambda x: jnp.zeros_like(x),
    params
  )

  for i in range(n_devices - 1, -1, -1):

    grads_current, grads_x = bwd_pass(grads_x[0], grad_fn[i])
    grads = jax.tree.map(
      lambda x, y: x + y,
      grads,
      grads_current
    )

    if i > 0:
       grads_x = jax.lax.ppermute(grads_x, 'model', perm)

  return grads

grads = pipeline_bwd_pass(params_shard, grad_x, grad_fn)

In [54]:
grads

{'Dense_0': {'bias': Array([-0.00270163,  0.2153322 ,  0.05457603, -0.04286464,  0.00995323,
         -0.01980369, -0.00869385, -0.03147201, -0.00302013, -0.25277025,
         -0.05154405, -0.02218282, -0.04597821,  0.16051143, -0.12673308,
         -0.14682448, -0.56531256, -0.13443986, -0.31572703, -0.19429468,
         -0.06365026, -0.05744242,  0.20348167,  0.6087539 ,  0.3360747 ,
         -0.05146679, -0.3376532 , -0.11463025, -0.4256289 , -0.2123735 ,
          0.1171791 , -0.52579033], dtype=float32),
  'kernel': Array([[ 5.91858849e-03, -4.71737891e-01, -1.19562164e-01,
           9.39054862e-02],
         [ 1.48717128e-03, -1.18534185e-01, -3.00425366e-02,
           2.35957522e-02],
         [-1.72251160e-03,  1.37291849e-01,  3.47966775e-02,
          -2.73297075e-02],
         [ 5.69343334e-03, -4.53792036e-01, -1.15013771e-01,
           9.03331339e-02],
         [ 9.95322503e-03, -1.98036917e-02, -8.69385153e-03,
          -3.14720087e-02],
         [ 9.95322503e-03, -1.

In [59]:
grads

{'Dense_0': {'bias': Array([-0.00270163,  0.2153322 ,  0.05457603, -0.04286464,  0.00995323,
         -0.01980369, -0.00869385, -0.03147201, -0.00302013, -0.25277025,
         -0.05154405, -0.02218282, -0.04597821,  0.16051143, -0.12673308,
         -0.14682448, -0.56531256, -0.13443986, -0.31572703, -0.19429468,
         -0.06365026, -0.05744242,  0.20348167,  0.6087539 ,  0.3360747 ,
         -0.05146679, -0.3376532 , -0.11463025, -0.4256289 , -0.2123735 ,
          0.1171791 , -0.52579033], dtype=float32),
  'kernel': Array([[ 5.91858849e-03, -4.71737891e-01, -1.19562164e-01,
           9.39054862e-02],
         [ 1.48717128e-03, -1.18534185e-01, -3.00425366e-02,
           2.35957522e-02],
         [-1.72251160e-03,  1.37291849e-01,  3.47966775e-02,
          -2.73297075e-02],
         [ 5.69343334e-03, -4.53792036e-01, -1.15013771e-01,
           9.03331339e-02],
         [ 9.95322503e-03, -1.98036917e-02, -8.69385153e-03,
          -3.14720087e-02],
         [ 9.95322503e-03, -1.

In [None]:
_ = jax.tree.map(
  lambda x: x[:],
  params_shard
)
# model.apply(
#   {'params': _},
#   jnp.ones((2,4))
# )

a_1, a_2  = jnp.split(_['Dense_0']['kernel'], 2, axis=0)
b_1, b_2 = jnp.split(_['Dense_1']['kernel'], 2, axis=0)

print(a_1)
print(a_2)


print(jnp.ones((1,4)) @ a_1.T @ b_1.T @ a_2.T@ b_2.T)

[[-0.82300586  0.13979685 -1.0663078   0.21559261]
 [-0.49311835  0.2886271  -0.9571687  -0.27748352]
 [ 0.82770365 -0.7702767  -0.7409361  -0.94439197]
 [-0.3156069   0.0227289  -0.39290708  1.0027606 ]
 [-0.9370494  -0.41238844  0.7707996  -0.53318447]
 [ 0.365841    0.05689541 -0.23631333  0.81044877]
 [-0.64488286 -0.6312494  -0.27232984  0.963054  ]
 [ 0.31055728  0.01419805 -0.26563522  0.6040163 ]
 [-0.37554166  0.16121252 -0.36710387 -0.368435  ]
 [ 0.4102413   0.4437943  -0.40597242 -0.56065243]
 [ 0.29791793 -0.9014654  -0.48906383 -0.86474806]
 [ 0.6632264   0.4023416  -0.02816472 -0.34976414]
 [-0.3684946   1.0404862  -0.9486073  -0.30995214]
 [-0.9629016   0.6509631   0.08371046 -0.37072763]
 [ 0.06373774 -0.03931175 -0.6528216  -0.14480054]
 [ 0.17033684 -0.5886188  -1.0977224   0.9977888 ]]
[[ 3.55339259e-01 -7.73246229e-01 -2.98115402e-03  1.19543038e-01]
 [ 8.80670846e-02  6.59837782e-01  4.95789260e-01 -2.31808230e-01]
 [ 3.37153584e-01 -1.43336654e-01 -9.80363041e-02

TypeError: dot_general requires contracting dimensions to have the same shape, got (16,) and (4,).

In [165]:
mesh.devices

array([CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3)],
      dtype=object)

In [149]:
jax.tree.map(
  lambda x: x.sharding,
  params_new
)

{'Dense_0': {'bias': NamedSharding(mesh=Mesh('model': 4, axis_types=(Auto,)), spec=PartitionSpec('model',), memory_kind=unpinned_host),
  'kernel': NamedSharding(mesh=Mesh('model': 4, axis_types=(Auto,)), spec=PartitionSpec('model',), memory_kind=unpinned_host)},
 'Dense_1': {'bias': NamedSharding(mesh=Mesh('model': 4, axis_types=(Auto,)), spec=PartitionSpec('model',), memory_kind=unpinned_host),
  'kernel': NamedSharding(mesh=Mesh('model': 4, axis_types=(Auto,)), spec=PartitionSpec('model',), memory_kind=unpinned_host)}}

In [150]:
jax.tree.map(
  lambda x: print(x.shape),
  params_new
)

(32,)
(32, 4)
(32,)
(32, 4)


{'Dense_0': {'bias': None, 'kernel': None},
 'Dense_1': {'bias': None, 'kernel': None}}

In [None]:
jax.tree.map(
  lambda x: x.shape,
  params_new
)

{'Dense_0': {'bias': (8,), 'kernel': (8, 4)},
 'Dense_1': {'bias': (8,), 'kernel': (8, 4)}}

In [162]:
jax.tree.map(
  lambda x: jax.debug.visualize_array_sharding(x),
  params_new
)

{'Dense_0': {'bias': None, 'kernel': None},
 'Dense_1': {'bias': None, 'kernel': None}}

In [None]:
{
  Dense_0: {Dense_0: {(L, kernel), bias}, Dense_1}
  Dnese_1; {Dense_0: }
  ...
  Dense_n: {}
}