## Testing MeshFlow Algorithms with JAX on CPU

This notebook verifies the correctness of MeshFlow and other 2D GeMM algorithms.

Instead of a real 2D device mesh, we use JAX's CPU emulation of a multi-device mesh.

Only the CPU version of JAX is required to run this notebook.

To understand the code in detail, see the JAX shard\_map [tutorial](https://jax.readthedocs.io/en/latest/notebooks/shard_map.html).

In [None]:
import os

# Emulate a multi-device mesh with 8 host (CPU) devices. This must be set before JAX starts up.
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'

from functools import partial

# See TensorParallel.py for the 2D GeMM implementations.
from TensorParallel import SPMD, createMultihostMatrix

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.sharding import NamedSharding
from jax.experimental import mesh_utils
from jax.experimental.shard_map import shard_map
from jax.tree_util import tree_map, tree_all

jax.config.update('jax_platform_name', 'cpu')

# Check that two arrays (or pytrees of arrays) match within a loose tolerance.
def allclose(a, b):
  return tree_all(tree_map(partial(jnp.allclose, atol=1e-2, rtol=1e-2), a, b))
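
To confirm that the emulation took effect, the sketch below counts the devices and inspects how an array sharded with `P('row', 'col')` is laid out. It assumes a fresh Python process, because `XLA_FLAGS` is only read when JAX initializes its backend; `os.environ.setdefault` is used so that an existing `XLA_FLAGS` value is not overwritten. On a 4x2 mesh, the (512, 256) matrix `X` should be split into one (128, 128) block per device.

```python
import os

# Must be in place before JAX initializes its backend (i.e. in a fresh process).
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=8")

import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

print(jax.device_count())  # 8 emulated CPU devices in a fresh process

devices = mesh_utils.create_device_mesh((4, 2))
mesh = Mesh(devices, axis_names=('row', 'col'))

B, I = 512, 256
X = jnp.arange(B * I, dtype=jnp.float32).reshape(B, I) / (B * I)
X_p = jax.device_put(X, NamedSharding(mesh, P('row', 'col')))

# Rows of X are split 4 ways over 'row', columns 2 ways over 'col'.
shard_shapes = {s.data.shape for s in X_p.addressable_shards}
print(len(X_p.addressable_shards), shard_shapes)
```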

In [None]:
# Batch size, input dimension, and output dimension
B,I,O = (512, 256, 1024)

# 8 devices arranged as a 4x2 device mesh
devices = mesh_utils.create_device_mesh((4, 2))
mesh = Mesh(devices, axis_names=('row', 'col'))

#Create the reference matrices
X = jnp.arange(B*I,dtype=jnp.float32).reshape(B,I)/(B*I)
W = jnp.arange(I*O,dtype=jnp.float32).reshape(I, O) / (I*O)
Y = X@W

# First, check the correctness of the collective algorithm.

#Partition the matrices to the device mesh
X_p = jax.device_put(X, NamedSharding(mesh, P('row', 'col')))
W_p = jax.device_put(W, NamedSharding(mesh, P('row', 'col')))

#Get the collective output-stationary algorithm
collective = SPMD(mesh,'collective')
collective_os = collective.OS()

#Compare the collective OS result with the reference
Y_p = collective_os(X_p, W_p)
print(allclose(Y_p,Y))
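
TensorParallel.py is not shown here, so the following is only a hypothetical sketch of how an all-gather-based ("collective") output-stationary GeMM could be written with `shard_map`; it is not the notebook's implementation. Each device (i, j) all-gathers its row of `X` blocks along `'col'` and its column of `W` blocks along `'row'`, then computes its own output block, which never moves. In the comments, subscript i denotes a 4-way slice along `'row'` and j a 2-way slice along `'col'`. The names `os_block` and `sketch_os` are illustrative.

```python
import os

# Emulate 8 CPU devices; only effective if set before JAX initializes (fresh process).
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=8")

import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

try:
    from jax import shard_map  # newer JAX releases
except ImportError:
    from jax.experimental.shard_map import shard_map  # older releases, as in this notebook

mesh = Mesh(mesh_utils.create_device_mesh((4, 2)), axis_names=('row', 'col'))
spec = P('row', 'col')

def os_block(x_blk, w_blk):
    # x_blk = X[B_i, I_j] with shape (B/4, I/2); w_blk = W[I_i, O_j] with shape (I/4, O/2).
    x_rows = jax.lax.all_gather(x_blk, 'col', axis=1, tiled=True)  # X[B_i, :]
    w_cols = jax.lax.all_gather(w_blk, 'row', axis=0, tiled=True)  # W[:, O_j]
    # The output block Y[B_i, O_j] is produced where it is stored: output-stationary.
    return x_rows @ w_cols

sketch_os = jax.jit(shard_map(os_block, mesh=mesh, in_specs=(spec, spec), out_specs=spec))

B, I, O = 512, 256, 1024
X = jnp.arange(B * I, dtype=jnp.float32).reshape(B, I) / (B * I)
W = jnp.arange(I * O, dtype=jnp.float32).reshape(I, O) / (I * O)
sh = NamedSharding(mesh, spec)
Y_sketch = sketch_os(jax.device_put(X, sh), jax.device_put(W, sh))
print(jnp.allclose(Y_sketch, X @ W, atol=1e-2, rtol=1e-2))
```

The all-gathers replicate the inputs each device needs, which is simple but moves whole rows and columns of blocks at once; algorithms such as MeshFlow presumably organize this communication differently.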


In [None]:
#Now, we demonstrate the backpropagation computations using the dataflows.
#The input for the backward pass is the output gradient, dY.

dY = jnp.ones(B*O,dtype=jnp.float32).reshape(B,O)/(B*O)
dY_p = jax.device_put(dY, NamedSharding(mesh, P('row', 'col')))

# The backward data pass computes dX = dY @ W^T
dX = dY @ W.transpose()

# The LS algorithm computes this: LS(dY, W) = dY @ W^T
collective_ls = collective.LS()
dX_p = collective_ls(dY_p, W_p)

print(allclose(dX_p, dX))

# The backward weight pass computes dW = X^T @ dY
dW = X.transpose() @ dY

# The RS algorithm computes this: RS(X, dY) = X^T @ dY
collective_rs = collective.RS()
dW_p = collective_rs(X_p, dY_p)

print(allclose(dW_p, dW))
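
Since TensorParallel.py is not shown, here is a hypothetical sketch of how collective versions of the two backward products could be built from `jax.lax.all_gather` and a reduce-scatter (`jax.lax.psum_scatter`). In this version the `dY` blocks stay in place: for LS, each device multiplies its `dY` block by the gathered `W` columns, and the partial `dX` rows are summed and scattered across `'col'`; for RS, the gathered `X` rows are transposed and multiplied by `dY`, and the partial `dW` is summed and scattered across `'row'`. Subscript i denotes a 4-way slice along `'row'` and j a 2-way slice along `'col'`. The results are compared against `jax.vjp` of `Y = X @ W`, which confirms that `dY @ W^T` and `X^T @ dY` are the gradients. The names `ls_block`, `rs_block`, `sketch_ls`, and `sketch_rs` are illustrative.

```python
import os

# Emulate 8 CPU devices; only effective if set before JAX initializes (fresh process).
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=8")

import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

try:
    from jax import shard_map  # newer JAX releases
except ImportError:
    from jax.experimental.shard_map import shard_map  # older releases

mesh = Mesh(mesh_utils.create_device_mesh((4, 2)), axis_names=('row', 'col'))
spec = P('row', 'col')

def ls_block(dy_blk, w_blk):
    # dy_blk = dY[B_i, O_j] stays in place; w_blk = W[I_i, O_j].
    w_cols = jax.lax.all_gather(w_blk, 'row', axis=0, tiled=True)  # W[:, O_j]
    partial = dy_blk @ w_cols.T  # contribution of O_j to dX[B_i, :]
    # Sum contributions across 'col'; device (i, j) keeps dX[B_i, I_j].
    return jax.lax.psum_scatter(partial, 'col', scatter_dimension=1, tiled=True)

def rs_block(x_blk, dy_blk):
    # dy_blk = dY[B_i, O_j] stays in place; x_blk = X[B_i, I_j].
    x_rows = jax.lax.all_gather(x_blk, 'col', axis=1, tiled=True)  # X[B_i, :]
    partial = x_rows.T @ dy_blk  # contribution of B_i to dW[:, O_j]
    # Sum contributions across 'row'; device (i, j) keeps dW[I_i, O_j].
    return jax.lax.psum_scatter(partial, 'row', scatter_dimension=0, tiled=True)

sketch_ls = jax.jit(shard_map(ls_block, mesh=mesh, in_specs=(spec, spec), out_specs=spec))
sketch_rs = jax.jit(shard_map(rs_block, mesh=mesh, in_specs=(spec, spec), out_specs=spec))

B, I, O = 512, 256, 1024
X = jnp.arange(B * I, dtype=jnp.float32).reshape(B, I) / (B * I)
W = jnp.arange(I * O, dtype=jnp.float32).reshape(I, O) / (I * O)
dY = jnp.ones((B, O), dtype=jnp.float32) / (B * O)

# Reference gradients: pull dY back through Y = X @ W with autodiff.
_, vjp_fn = jax.vjp(lambda x, w: x @ w, X, W)
dX_ref, dW_ref = vjp_fn(dY)

sh = NamedSharding(mesh, spec)
X_p, W_p, dY_p = (jax.device_put(a, sh) for a in (X, W, dY))
dX_sketch = sketch_ls(dY_p, W_p)
dW_sketch = sketch_rs(X_p, dY_p)

# Relative tolerance only, because every entry of dX and dW is far below 1e-2.
print(jnp.allclose(dX_sketch, dX_ref, rtol=1e-4, atol=0.0))
print(jnp.allclose(dW_sketch, dW_ref, rtol=1e-4, atol=0.0))
```

Whether TensorParallel's LS and RS use this exact communication pattern is an assumption; the sketch only shows one valid way to produce the same sharded `dX` and `dW`.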

In [None]:
# Verify the MeshFlow algorithms, using the collective results as the reference.
meshflow = SPMD(mesh, 'meshflow')
meshflow_os = meshflow.OS()
meshflow_ls = meshflow.LS()
meshflow_rs = meshflow.RS()

Y_p2 = meshflow_os(X_p, W_p)
dX_p2 = meshflow_ls(dY_p, W_p)
dW_p2 = meshflow_rs(X_p, dY_p)

print(allclose(Y_p, Y_p2))
print(allclose(dX_p, dX_p2))
print(allclose(dW_p, dW_p2))
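
One caveat about the tolerances: because `dY` is scaled by 1/(B*O), every entry of `dX` and `dW` is on the order of 1e-3 or smaller, below the `atol=1e-2` used by `allclose`. The backward checks in this notebook, including the MeshFlow comparisons of `dX_p` and `dW_p`, would therefore also pass for an all-zero result. The snippet below demonstrates this with the notebook's own matrices and adds a hypothetical scale-aware helper, `rel_err`, that normalizes the maximum error by the reference's magnitude.

```python
import jax.numpy as jnp

B, I, O = 512, 256, 1024
X = jnp.arange(B * I, dtype=jnp.float32).reshape(B, I) / (B * I)
W = jnp.arange(I * O, dtype=jnp.float32).reshape(I, O) / (I * O)
dY = jnp.ones((B, O), dtype=jnp.float32) / (B * O)

dX = dY @ W.T
dW = X.T @ dY
print(float(jnp.abs(dX).max()), float(jnp.abs(dW).max()))  # both below 1e-2

# With atol=1e-2, an all-zero "result" passes the notebook's check.
print(jnp.allclose(jnp.zeros_like(dX), dX, atol=1e-2, rtol=1e-2))  # True

def rel_err(a, b):
    # Hypothetical helper: maximum absolute error divided by the reference's peak magnitude.
    return float(jnp.max(jnp.abs(a - b)) / jnp.max(jnp.abs(b)))

print(rel_err(jnp.zeros_like(dX), dX))  # 1.0: the zero result is now clearly wrong
print(rel_err(dX, dX))                  # 0.0
```

In the notebook, a check such as `rel_err(dX_p, dX) < 1e-4` would catch a zero or badly scaled gradient; the 1e-4 threshold is an illustrative choice, not a value from the original.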