In [None]:
import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8' # Use 8 CPU devices
from functools import partial

import jax
import jax.numpy as jnp

from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental import mesh_utils
from jax.experimental.shard_map import shard_map

In [None]:
# 4x2 device mesh: axis 'x' spans 4 devices, axis 'y' spans 2 (8 CPU devices, see XLA_FLAGS above).
devices = mesh_utils.create_device_mesh((4, 2))
mesh = Mesh(devices, axis_names=('x', 'y'))

# Small float operands for the basic shard_map matmul demo: a is (8, 16), b is (16, 4).
a = jnp.arange( 8 * 16.).reshape(8, 16)
b = jnp.arange(16 *  4.).reshape(16, 4)

@partial(shard_map, mesh=mesh, in_specs=(P('x', 'y'), P('y', None)),
         out_specs=P('x', None))
def matmul_basic(a_block, b_block):
  """Blocked matmul: multiply the local shards, then sum the partials over 'y'.

  On the 4x2 mesh above each device receives a_block f32[2, 8] and
  b_block f32[8, 4]. The local product covers only this device's slice of the
  contraction dimension. psum over 'y' completes it, giving the f32[2, 4] rows
  of c owned by this 'x' position.
  """
  return jax.lax.psum(a_block @ b_block, 'y')

# Each device's (2, 4) row-block assembles into the full f32[8, 4] product.
c = matmul_basic(a, b)   # c: f32[8, 4]

In [None]:
from jax.tree_util import tree_map, tree_all

def allclose(a, b, atol=1e-2, rtol=1e-2):
  """Return True if every pair of corresponding leaves of pytrees `a` and `b` is elementwise close.

  The default tolerances (1e-2 absolute and relative) are deliberately loose
  and match the checks used throughout this notebook. Pass tighter `atol` /
  `rtol` when an exact comparison matters.

  Args:
    a, b: pytrees with matching structure whose leaves are arrays.
    atol: absolute tolerance forwarded to `jnp.allclose`.
    rtol: relative tolerance forwarded to `jnp.allclose`.
  Returns:
    A Python bool: True only if every leaf pair is close.
  """
  close = partial(jnp.allclose, atol=atol, rtol=rtol)
  return tree_all(tree_map(close, a, b))

# Compare the shard_map result against a plain, unsharded jnp.dot.
allclose(c, jnp.dot(a, b))

In [None]:
# Inspect the name of the first mesh axis ('x' for the mesh above).
mesh.axis_names[0]

In [None]:
# Show how c's rows are laid out across devices (out_specs=P('x', None)).
jax.debug.visualize_array_sharding(c)

In [None]:
from jax.sharding import NamedSharding

# Commit the inputs to the same layouts matmul_basic's in_specs used, so the
# jit/GSPMD reference below starts from identical shardings.
a = jax.device_put(a, NamedSharding(mesh, P('x', 'y')))
b = jax.device_put(b, NamedSharding(mesh, P('y', None)))

@jax.jit
def matmul_reference(a, b):
  """jit/GSPMD reference: a plain dot whose output is constrained to rows-on-'x'.

  `mesh` is read from the notebook globals when the function is traced.
  """
  row_sharded = NamedSharding(mesh, P('x', None))
  product = jnp.dot(a, b)
  return jax.lax.with_sharding_constraint(product, row_sharded)

# The GSPMD reference should agree with the unsharded product too.
c_ref = matmul_reference(a, b)
allclose(c_ref, jnp.dot(a, b))

In [None]:
# Device layout of the operands and the result.
# NOTE(review): the last line shows `c` (from matmul_basic), not `c_ref`; both are
# requested as P('x', None), so the pictures should match — confirm that was the intent.
print('a blocks:'); jax.debug.visualize_array_sharding(a)
print('b blocks:'); jax.debug.visualize_array_sharding(b)
print('c blocks:'); jax.debug.visualize_array_sharding(c)

In [None]:
import numpy as np
# Rebind the global `mesh` to a 1-D mesh over the first 4 devices, axis 'i'.
# check_shmap below reads this global at call time.
devices = np.array(jax.devices()[:4])
mesh = Mesh(devices, ('i',))  # mesh.shape['i'] = 4

def check_shmap(f, y):
  """Print whether shard_map(f) over axis 'i' equals applying f blockwise and concatenating.

  `y` is split into mesh.shape['i'] equal row-blocks. The reference is f applied
  to each block, with the results concatenated in order; the comparison uses the
  notebook's `allclose`. Uses the global `mesh`.
  """
  sharded_f = shard_map(f, mesh=mesh, in_specs=P('i'), out_specs=P('i'))
  ans = sharded_f(y)
  n_blocks = mesh.shape['i']
  blockwise = [f(block) for block in jnp.split(y, n_blocks)]
  print(allclose(ans, jnp.concatenate(blockwise)))


# Each (2, 4) block yields a (4, 4) x.T @ x; four blocks concatenate to (16, 4).
check_shmap(lambda x: x.T @ x, jnp.arange(32).reshape(8, 4))

In [None]:
# Problem sizes; presumably batch, sequence, heads, head_dim — TODO confirm.
B,S,H,D = (4,16, 12,32)

# 4x2 mesh named ('row', 'col'). The GSPMD_* kernels below bind this mesh when
# they are decorated, so later reassignments of `mesh` do not affect them.
devices = mesh_utils.create_device_mesh((4, 2))
mesh = Mesh(devices, axis_names=('row', 'col'))

# X: (B*S, H*D), W: (H*D, 4*H*D), both integer aranges.
# NOTE(review): under JAX's default 32-bit mode these are int32, and the largest
# entries of X @ W exceed the int32 range, so products wrap. The allclose checks
# presumably still agree because the wraparound is exact on both sides. float32
# inputs (as used further down) would be a more meaningful test.
X = jnp.arange( B*S*H*D).reshape(B*S, H*D)
W = jnp.arange(H*D*4*H*D).reshape(H*D, 4*H*D)

# Both operands are 2-D sharded: rows over 'row', columns over 'col'.
Xo = jax.device_put(X, NamedSharding(mesh, P('row', 'col')))
Wo = jax.device_put(W, NamedSharding(mesh, P('row', 'col')))
@jax.jit
@partial(shard_map, mesh=mesh, in_specs=(P('row', 'col'), P('row', 'col')),
         out_specs=P('row', 'col'))
def GSPMD_OS(Xij, Wij):
    """Output-stationary matmul: gather the full contraction dim, then multiply locally.

    Xij is the (row i, col j) shard of X and Wij the (row i, col j) shard of W.
    Gathering Xij over 'col' gives this device's full row-block of X.
    Gathering Wij over 'row' gives its full column-block of W. Their product is
    exactly output tile (i, j), so no reduction is needed afterwards.
    """
    x_row_block = jax.lax.all_gather(Xij, 'col', tiled=True, axis=1)
    w_col_block = jax.lax.all_gather(Wij, 'row', tiled=True, axis=0)
    out_tile = x_row_block @ w_col_block
    return out_tile

# Check GSPMD_OS against an unsharded jnp.dot (the inputs are resharded by jit).
y_ref = GSPMD_OS(X, W)
#y_ref.shape
#jnp.dot(X, W).shape
allclose(y_ref, jnp.dot(X, W))
#y_ref

In [None]:
# IS layout: X1 is (B*S, 4*H*D); W1 is (H*D, 4*H*D) after the transpose, i.e.
# stored as (out, in), so the kernel computes X1 @ W1.T.
X1 = jnp.arange( B*S*4*H*D).reshape(B*S, 4*H*D)
W1 = jnp.arange(H*D*4*H*D).reshape(4*H*D, H*D).transpose()

Xi = jax.device_put(X1, NamedSharding(mesh, P('row', 'col')))
Wi = jax.device_put(W1, NamedSharding(mesh, P('row', 'col')))

@jax.jit
@partial(shard_map, mesh=mesh, in_specs=(P('row', 'col'), P('row', 'col')),
         out_specs=P('row', 'col'))
def GSPMD_IS(Xij, Wji):
    """Input-stationary matmul Y = X @ Wt.T, with the weight Wt stored as (out, in).

    Xij: local shard of X, (rows/R, in/C). Wji: local shard of Wt, (out/R, in/C).
    R and C are the 'row' and 'col' axis sizes.
    Gathering Wji over 'row' gives every output feature for this device's slice
    of the input dimension. The local einsum therefore yields a partial sum over
    that slice; psum_scatter adds the partials across 'col' and splits the
    output dimension, leaving a (rows/R, out/C) tile.
    """
    Wj = jax.lax.all_gather(Wji, 'row', tiled=True, axis=0)
    # (rows/R, in/C) x (out, in/C) -> (rows/R, out): partial over this device's inputs.
    Yp = jnp.einsum('mn,kn->mk', Xij, Wj)
    return jax.lax.psum_scatter(Yp, 'col', scatter_dimension=1, tiled=True)

# Check GSPMD_IS against X1 @ W1.T computed unsharded.
y_ref = GSPMD_IS(Xi, Wi)
# NOTE: not the cell's last expression, so this shape is evaluated but not displayed.
y_ref.shape
#jnp.dot(X, W).shape
allclose(y_ref, jnp.dot(X1, W1.transpose()))

In [None]:
# WS layout: X1 is (B*S, H*D) and is passed transposed as (H*D, B*S); W1 is (H*D, 4*H*D).
X1 = jnp.arange( B*S*H*D).reshape(B*S,H*D)
W1 = jnp.arange(H*D*4*H*D).reshape(H*D, 4*H*D)

Xw = jax.device_put(X1.transpose(), NamedSharding(mesh, P('row', 'col')))
Ww = jax.device_put(W1, NamedSharding(mesh, P('row', 'col')))

@jax.jit
@partial(shard_map, mesh=mesh, in_specs=(P('row', 'col'), P('row', 'col')),
         out_specs=P('row', 'col'))
def GSPMD_WS(Xji, Wij):
    """Weight-stationary matmul Y = Xt.T @ W, with X supplied transposed as Xt.

    Xji: local shard of Xt, (in/R, tokens/C). Wij: local shard of W, (in/R, out/C).
    Gathering Xji over 'col' gives all tokens for this device's input rows, so
    the local einsum is a (tokens, out/C) partial sum over those rows.
    psum_scatter adds the partials across 'row' and splits the tokens, leaving a
    (tokens/R, out/C) tile.
    """
    xt_all_tokens = jax.lax.all_gather(Xji, 'col', tiled=True, axis=1)
    partial_y = jnp.einsum('ib,io->bo', xt_all_tokens, Wij)
    return jax.lax.psum_scatter(partial_y, 'row', scatter_dimension=0, tiled=True)

# Xw is X1 transposed, so GSPMD_WS(Xw, Ww) should equal X1 @ W1.
y_ref = GSPMD_WS(Xw, Ww)
# NOTE: not the cell's last expression, so this shape is evaluated but not displayed.
y_ref.shape
#jnp.dot(X, W).shape
allclose(y_ref, jnp.dot(X1, W1))

In [None]:
# Chain the three kernels. Yo = Xo @ Wo (OS); then Yo @ Wo.T (IS) should have
# Xo's shape and Xo.T @ Yo (WS) should have Wo's shape.
Yo = GSPMD_OS(Xo,Wo)
print(Yo.shape)
Xp = GSPMD_IS(Yo,Wo)
assert Xp.shape == Xo.shape
print(Xp.shape)
Wp = GSPMD_WS(Xo,Yo)
# Compare with Wo's shape; comparing Wp with itself would always pass.
assert Wp.shape == Wo.shape
print(Wp.shape)

In [None]:
import math
# Larger problem for the ring-based (Wang_*) and sliced (Systolic_*) kernels.
B,S,H,D = (4,256, 48,64)

# NOTE(review): this rebinds `mesh` to a (2, 4) layout. GSPMD_OS/IS/WS were
# decorated with the earlier (4, 2) mesh and keep using it, so calling them on
# arrays placed on this mesh presumably forces a reshard. Keep that in mind
# when comparing timings against the kernels defined below.
devices = mesh_utils.create_device_mesh((2, 4))
mesh = Mesh(devices, axis_names=('row', 'col'))

# float32 inputs scaled into [0, 1) so products stay small.
X = jnp.arange( B*S*H*D,dtype=jnp.float32).reshape(B*S, H*D)/(B*S*H*D)
W = jnp.arange(H*D*4*H*D,dtype=jnp.float32).reshape(H*D, 4*H*D) / (4*H*D*H*D)

Xo = jax.device_put(X, NamedSharding(mesh, P('row', 'col')))
Wo = jax.device_put(W, NamedSharding(mesh, P('row', 'col')))

@jax.jit
@partial(shard_map, mesh=mesh, in_specs=(P('row', 'col'), P('row', 'col')),
         out_specs=P('row', 'col'))
def Wang_OS(Xij, Wij):
    """Output-stationary matmul with a bidirectional ring over 'col' instead of an X all-gather.

    Wij is all-gathered over 'row', giving this device's full column-block of W.
    The local columns of Xij are split into a "lo" half and a "hi" half. The
    half-block size is Wj.shape[0] // size // 2, so device idx's halves match
    W row-blocks 2*idx and 2*idx+1. This assumes X's contraction dim is sharded
    evenly over 'col', as the in_specs do.

    The lo halves rotate one step "up" the 'col' ring per iteration and the hi
    halves one step "down". After i steps this device holds the lo half from
    device idx-i and the hi half from device idx+i, and multiplies each by the
    matching W block. After size-1 steps every piece of X's row-block has been
    used, so out_block is the complete output tile.
    """
    Wj = jax.lax.all_gather(Wij, 'row', tiled=True, axis=0)
    size = jax.lax.psum(1, 'col')
    idx = jax.lax.axis_index('col')
    shift_up = partial(jax.lax.ppermute, axis_name='col',
                       perm=[(i, (i + 1) % size) for i in range(size)])
    shift_dn = partial(jax.lax.ppermute, axis_name='col',
                       perm=[(i, (i - 1) % size) for i in range(size)])

    blk = Wj.shape[0] // size // 2  # rows per half-block (local name avoids shadowing global B)
    w_blocks = lambda i, hi: jax.lax.dynamic_slice_in_dim(Wj, (2*i+hi) * blk, blk, 0)

    x_lo, x_hi = jnp.split(Xij, 2, axis=1)

    # This device's own halves first; no communication needed.
    out_block = x_lo @ w_blocks(idx, 0)
    out_block += x_hi @ w_blocks(idx, 1)

    for i in range(1, size):
        x_lo = shift_up(x_lo)   # now holds the lo half of device idx - i
        x_hi = shift_dn(x_hi)   # now holds the hi half of device idx + i
        out_block += x_lo @ w_blocks((idx - i) % size, 0)
        out_block += x_hi @ w_blocks((idx + i) % size, 1)
    return out_block
    
    

# Check the ring-based kernel against an unsharded jnp.dot.
y_ref = Wang_OS(X, W)
#y_ref.shape
#jnp.dot(X, W).shape
allclose(y_ref, jnp.dot(X, W))



In [None]:
# Timing: ring-based kernel vs. all-gather baseline.
# NOTE(review): GSPMD_OS is bound to the earlier (4, 2) mesh while Xo/Wo live on
# the current (2, 4) mesh, so this is presumably not a like-for-like comparison.
%timeit Wang_OS(Xo,Wo).block_until_ready()
%timeit GSPMD_OS(Xo,Wo).block_until_ready()
allclose(GSPMD_OS(X,W), Wang_OS(X,W))

In [None]:
# IS layout: X as before; W passed transposed, i.e. (4*H*D, H*D) = (out, in).
Xi = jax.device_put(X, NamedSharding(mesh, P('row', 'col')))
Wi = jax.device_put(W.transpose(), NamedSharding(mesh, P('row', 'col')))

@jax.jit
@partial(shard_map, mesh=mesh, in_specs=(P('row', 'col'), P('row', 'col')),
         out_specs=P('row', 'col'))
def Wang_IS(Xij, Wji):
    """Input-stationary matmul Y = X @ Wt.T with a bidirectional ring reduce-scatter over 'col'.

    Xij: local shard of X, (rows/R, in/C). Wji: local shard of the transposed
    weight Wt, (out/R, in/C). R and C are the 'row' and 'col' axis sizes.
    Wji is all-gathered over 'row' so every output feature is available. Device
    idx owns output half-blocks 2*idx ("lo") and 2*idx+1 ("hi") of its
    (rows/R, out/C) result tile.

    Two partial-sum accumulators travel in opposite directions around the 'col'
    ring. After step ii, y_lo on device idx holds the partial sum for lo block
    (idx - 1 - ii) mod size, and y_hi holds hi block (idx + 1 + ii) mod size.
    Each device adds its own contribution before passing the accumulator on.
    After size - 1 steps both accumulators are complete and belong to this
    device.
    """
    Wj = jax.lax.all_gather(Wji, 'row', tiled=True, axis=0)
    size = jax.lax.psum(1, 'col')
    idx = jax.lax.axis_index('col')
    shift_up = partial(jax.lax.ppermute, axis_name='col',
                       perm=[(i, (i + 1) % size) for i in range(size)])
    shift_dn = partial(jax.lax.ppermute, axis_name='col',
                       perm=[(i, (i - 1) % size) for i in range(size)])

    blk = Wj.shape[0] // size // 2  # output features per half-block
    w_blocks = lambda i, hi: jax.lax.dynamic_slice_in_dim(Wj, (2*i+hi) * blk, blk, 0)

    # Seed each accumulator with this device's contribution to the block it will
    # hand to its neighbour first.
    y_lo = jnp.einsum('bi,oi->bo', Xij, w_blocks((idx-1) % size, 0))
    y_hi = jnp.einsum('bi,oi->bo', Xij, w_blocks((idx+1) % size, 1))

    def body_fn(ii, carry, myidx, mysize):
        low, high = carry
        low = shift_up(low)
        high = shift_dn(high)
        low += jnp.einsum('bi,oi->bo', Xij, w_blocks((myidx - ii - 1) % mysize, 0))
        high += jnp.einsum('bi,oi->bo', Xij, w_blocks((myidx + ii + 1) % mysize, 1))
        return (low, high)
    mybody = partial(body_fn, myidx=idx, mysize=size)

    y_lo, y_hi = jax.lax.fori_loop(1, size, mybody, init_val=(y_lo, y_hi))
    return jnp.concatenate([y_lo, y_hi], axis=1)

# Compare against the all-gather / psum_scatter baseline.
y_ref = Wang_IS(Xi, Wi)
#y_ref.shape
#jnp.dot(X, W).shape
allclose(y_ref, GSPMD_IS(Xi,Wi))

In [None]:
# NOTE(review): GSPMD_IS still uses the (4, 2) mesh captured at its definition,
# while Xi/Wi are on the (2, 4) mesh; its timing presumably includes resharding.
%timeit GSPMD_IS(Xi, Wi).block_until_ready()
%timeit Wang_IS(Xi, Wi).block_until_ready()

In [None]:

def Systolic_OS_(Xij, Wij, B, K):
    """Output-stationary matmul that gathers the contraction dim in K interleaved slices.

    The local contraction axis (Xij's columns, Wij's rows) is viewed as
    (chunks, K, B). Slice k is the B-wide piece at position k of every chunk.
    Each slice is all-gathered (X over 'col', W over 'row') and multiplied, and
    the K partial products are summed. Gathers for later slices can presumably
    overlap with earlier matmuls.

    Args:
      Xij: local (rows, cols) shard of X; cols must be divisible by K*B.
      Wij: local (rows, cols) shard of W; rows must be divisible by K*B.
      B: width of the smallest contiguous piece of the contraction dim.
      K: number of slices.
    Returns:
      This device's (Xij.shape[0], Wij.shape[1]) output tile.
    """
    n_rows = Xij.shape[0]
    n_cols = Wij.shape[1]
    x_view = Xij.reshape(n_rows, Xij.shape[1]//(K*B), K, B)
    w_view = Wij.reshape(Wij.shape[0]//(K*B), K, B, n_cols)

    def slice_product(k):
        # Both views index the contraction dim as (chunk, k, b), so flattening
        # the gathered slices lines X's columns up with W's rows.
        x_k = jax.lax.all_gather(jax.lax.dynamic_slice_in_dim(x_view, k, 1, 2),
                                 'col', tiled=True, axis=1)
        w_k = jax.lax.all_gather(jax.lax.dynamic_slice_in_dim(w_view, k, 1, 1),
                                 'row', tiled=True, axis=0)
        return x_k.reshape(n_rows, -1) @ w_k.reshape(-1, n_cols)

    acc = slice_product(0)
    for k in range(1, K):
        acc = acc + slice_product(k)
    return acc
        

# Bind Systolic_OS_ to the current (2, 4) mesh with block size B=2 and K=8 slices.
# (B and K here are keyword arguments of Systolic_OS_, not the global problem size B.)
Systolic_OS = jax.jit(partial(shard_map, mesh=mesh, in_specs=(P('row', 'col'), P('row', 'col')),
         out_specs=P('row', 'col'))(partial(Systolic_OS_, B=2,K=8)))

y_ref = Systolic_OS(X, W)
#y_ref.shape
#jnp.dot(X, W).shape
allclose(y_ref, jnp.dot(X, W))

In [None]:
# More rows (B*S = 2048) for the OS timing comparison.
B,S,H,D = (16,128, 48,64)

X = jnp.arange( B*S*H*D,dtype=jnp.float32).reshape(B*S, H*D)/(B*S*H*D)
W = jnp.arange(H*D*4*H*D,dtype=jnp.float32).reshape(H*D, 4*H*D) / (4*H*D*H*D)

Xo = jax.device_put(X, NamedSharding(mesh, P('row', 'col')))
Wo = jax.device_put(W, NamedSharding(mesh, P('row', 'col')))

# Rebind Systolic_OS with larger pieces: B=8, K=8.
Systolic_OS = jax.jit(partial(shard_map, mesh=mesh, in_specs=(P('row', 'col'), P('row', 'col')),
         out_specs=P('row', 'col'))(partial(Systolic_OS_, B=8,K=8)))

@jax.jit
@partial(shard_map, mesh=mesh, in_specs=(P('row', 'col'), P('row', 'col')),
         out_specs=P('row', 'col'))
def onlyAGX(Xij, Wij):
    """Benchmark helper: only the X all-gather over 'col' used by the OS kernels.

    Wij is accepted so the call signature matches the matmul kernels, but it is unused.
    """
    gathered_x = jax.lax.all_gather(Xij, 'col', tiled=True, axis=1)
    return gathered_x

@jax.jit
@partial(shard_map, mesh=mesh, in_specs=(P('row', 'col'), P('row', 'col')),
         out_specs=P('row', 'col'))
def onlyAGW(Xij, Wij):
    """Benchmark helper: only the W all-gather over 'row' used by the OS kernels.

    Xij is accepted so the call signature matches the matmul kernels, but it is unused.
    """
    gathered_w = jax.lax.all_gather(Wij, 'row', tiled=True, axis=0)
    return gathered_w

# Timing breakdown: each all-gather on its own, then the three OS kernels.
# NOTE(review): GSPMD_OS is bound to the earlier (4, 2) mesh, so its number
# presumably includes resharding Xo/Wo from the current (2, 4) mesh.
%timeit onlyAGX(Xo,Wo).block_until_ready()
%timeit onlyAGW(Xo,Wo).block_until_ready()
%timeit Systolic_OS(Xo,Wo).block_until_ready()
%timeit Wang_OS(Xo,Wo).block_until_ready()
%timeit GSPMD_OS(Xo,Wo).block_until_ready()

allclose(GSPMD_OS(X,W), Systolic_OS(X,W))

In [None]:
# IS inputs for the current sizes: W passed transposed as (out, in).
Xi = jax.device_put(X, NamedSharding(mesh, P('row', 'col')))
Wi = jax.device_put(W.transpose(), NamedSharding(mesh, P('row', 'col')))

def Systolic_IS_(Xij, Wji, B, K): #smallest block size B
    """Input-stationary matmul Y = X @ Wt.T, reduce-scattering the output in K slices.

    The local output-feature axis of Wji (its rows) is viewed as (chunks, K, B).
    For each slice k:
      1. the B-wide piece at position k of every chunk is all-gathered over 'row';
      2. it is multiplied against the local Xij, giving a partial sum over this
         device's input features;
      3. the result is psum_scatter'ed over 'col' along the chunk axis.
    Each slice's result is written into slot k of a (rows, chunks/C, K, B)
    buffer, which is finally flattened to this device's output tile.

    Args:
      Xij: local (rows, in) shard of X.
      Wji: local (out, in) shard of the transposed weight; its row count must be
        divisible by K*B.
      B: width of the smallest contiguous piece of the output dim.
      K: number of slices.
    Returns:
      This device's (Xij.shape[0], O / C) output tile. O = Wji.shape[0] times
      the 'row' axis size; C = the 'col' axis size.
    """
    O = Wji.shape[0]*jax.lax.psum(1, 'row')   # total output features
    C = jax.lax.psum(1,'col')
    W_split = Wji.reshape(Wji.shape[0]//K//B,K,B, Wji.shape[1])

    def slice_output(k):
        # (out/R/K/B, 1, B, in) -> gathered over 'row' -> (out/K/B, 1, B, in)
        Wk = jax.lax.all_gather(jax.lax.dynamic_slice_in_dim(W_split, k, 1, 1),
                                'row', tiled=True, axis=0)
        Yk = jnp.einsum('ni,okbi->nokb',Xij,Wk)  # (rows, out/K/B, 1, B), partial over 'col'
        # Sum partials across 'col' and split the chunk axis: (rows, out/K/B/C, 1, B).
        return jax.lax.psum_scatter(Yk, 'col', scatter_dimension=1, tiled=True)

    def body_fn(kk, carry):
        return jax.lax.dynamic_update_index_in_dim(carry, slice_output(kk), kk, axis=2)

    Yij = jnp.zeros((Xij.shape[0],O//(K*B*C), K,B),Xij.dtype)
    # Slice 0 is computed outside the loop with a static index; slices 1..K-1 run in fori_loop.
    Yij = body_fn(0, Yij)
    Yij = jax.lax.fori_loop(1,K,body_fn,init_val=Yij)
    return Yij.reshape(Xij.shape[0],-1)
    
        

# Bind Systolic_IS_ to the current mesh with B=8, K=8.
Systolic_IS = jax.jit(partial(shard_map, mesh=mesh, in_specs=(P('row', 'col'), P('row', 'col')),
         out_specs=P('row', 'col'))(partial(Systolic_IS_, B=8,K=8)))

y_ref = Systolic_IS(Xi, Wi)
#y_ref.shape
#jnp.dot(X, W).shape
allclose(y_ref, GSPMD_IS(Xi,Wi))


In [None]:
# Larger IS/WS problem: X is (B*S, 4*H*D) = (8192, 8192) float32 (256 MiB) and
# W is (4*H*D, H*D) = (8192, 2048).
# NOTE(review): X is divided by B*S*H*D rather than its own element count
# (B*S*4*H*D), so its values span [0, 4) rather than [0, 1). Confirm this is intended.
B,S,H,D = (16,512, 32,64)

X = jnp.arange( B*S*4*H*D,dtype=jnp.float32).reshape(B*S, 4*H*D)/(B*S*H*D)
W = jnp.arange(H*D*4*H*D,dtype=jnp.float32).reshape(4*H*D, H*D) / (4*H*D*H*D)

Xi = jax.device_put(X, NamedSharding(mesh, P('row', 'col')))
Wi = jax.device_put(W.transpose(), NamedSharding(mesh, P('row', 'col')))

# Rebind Systolic_IS for this shape with smaller pieces and more slices: B=4, K=16.
Systolic_IS = jax.jit(partial(shard_map, mesh=mesh, in_specs=(P('row', 'col'), P('row', 'col')),
         out_specs=P('row', 'col'))(partial(Systolic_IS_, B=4,K=16)))
print(allclose(Systolic_IS(Xi,Wi), Wang_IS(Xi,Wi)))

%timeit GSPMD_IS(Xi,Wi).block_until_ready()
%timeit Wang_IS(Xi,Wi).block_until_ready()
%timeit Systolic_IS(Xi,Wi).block_until_ready()

In [None]:
# WS layout: X passed transposed as (4*H*D, B*S); W as-is, (4*H*D, H*D).
Xw = jax.device_put(X.transpose(), NamedSharding(mesh, P('row', 'col')))
Ww = jax.device_put(W, NamedSharding(mesh, P('row', 'col')))

@jax.jit
@partial(shard_map, mesh=mesh, in_specs=(P('row', 'col'), P('row', 'col')),
         out_specs=P('row', 'col'))
def Wang_WS(Xji, Wij):
    """Weight-stationary matmul Y = Xt.T @ W using bidirectional ring collectives.

    Xji is the local shard of X stored transposed, (in/R, tokens/C). Wij is the
    local shard of W, (in/R, out/C). R and C are the 'row' and 'col' axis sizes.
    Returns this device's (tokens/R, out/C) output tile.

    The branch is chosen at trace time from static shapes:
      * C * Xji.shape[0] > R * Wij.shape[1]: all-gather X over 'col', then
        reduce-scatter the output over 'row'. Two partial-sum accumulators
        travel in opposite directions around the 'row' ring.
      * otherwise: keep X local and rotate its two column halves in opposite
        directions around 'col'. This fills a (tokens, out/C) buffer of
        partial sums, which is then psum_scatter'ed over 'row'.
    """
    R = jax.lax.psum(1,'row')
    C = jax.lax.psum(1,'col')
    #input traffic is bigger, so osplit
    if C * Xji.shape[0] > R*Wij.shape[1]:
        Xi = jax.lax.all_gather(Xji, 'col', tiled=True, axis=1)  # (in/R, tokens)
        size = R
        idx = jax.lax.axis_index('row')
        shift_up = partial(jax.lax.ppermute, axis_name='row',
                           perm=[(i, (i + 1) % size) for i in range(size)])
        shift_dn = partial(jax.lax.ppermute, axis_name='row',
                           perm=[(i, (i - 1) % size) for i in range(size)])

        blk = Xi.shape[1] // size // 2  # tokens per half-block
        x_blocks = lambda i, hi: jax.lax.dynamic_slice_in_dim(Xi, (2*i+hi) * blk, blk, 1)

        # After step ii, y_lo holds partial sums for token half-block
        # 2*((idx-1-ii) mod size) and y_hi for 2*((idx+1+ii) mod size)+1. The
        # final step (ii = size-1) lands on this device's own half-blocks.
        y_lo = jnp.einsum('ib,io->bo', x_blocks((idx-1) % size, 0), Wij)
        y_hi = jnp.einsum('ib,io->bo', x_blocks((idx+1) % size, 1), Wij)
        def body_fn(ii, carry, myidx, mysize):
            low, high = carry
            low = shift_up(low)
            high = shift_dn(high)
            low += jnp.einsum('ib,io->bo', x_blocks((myidx - ii - 1) % mysize, 0), Wij)
            high += jnp.einsum('ib,io->bo', x_blocks((myidx + ii + 1) % mysize, 1), Wij)
            return (low, high)
        mybody = partial(body_fn, myidx=idx, mysize=size)

        y_lo, y_hi = jax.lax.fori_loop(1, size, mybody, init_val=(y_lo, y_hi))
        return jnp.concatenate([y_lo, y_hi], axis=0)
    #output traffic is bigger, so isplit
    else:
        size = C
        idx = jax.lax.axis_index('col')
        shift_up = partial(jax.lax.ppermute, axis_name='col',
                           perm=[(i, (i + 1) % size) for i in range(size)])
        shift_dn = partial(jax.lax.ppermute, axis_name='col',
                           perm=[(i, (i - 1) % size) for i in range(size)])
        Yj = jnp.zeros((Xji.shape[1]*C, Wij.shape[1]), Xji.dtype)  # (tokens, out/C) partial sums
        blk = Yj.shape[0] // size // 2  # tokens per half-block

        # This device's own columns of Xt fill token rows [2*idx*blk, 2*(idx+1)*blk).
        Yk = jnp.einsum('ib,io->bo', Xji, Wij)
        Yj = jax.lax.dynamic_update_slice_in_dim(Yj, Yk, 2*((idx)%size)*blk, axis=0)
        x_lo, x_hi = jnp.split(Xji, 2, 1)
        def body_fn(ii, carry, myidx, mysize):
            lo, hi, Yj = carry
            # After ii shifts, lo came from device myidx-ii and hi from myidx+ii.
            lo = shift_up(lo)
            hi = shift_dn(hi)
            Y_lo = jnp.einsum('ib,io->bo', lo, Wij)
            Yj = jax.lax.dynamic_update_slice_in_dim(Yj, Y_lo, 2*((myidx-ii)%mysize)*blk, axis=0)
            Y_hi = jnp.einsum('ib,io->bo', hi, Wij)
            Yj = jax.lax.dynamic_update_slice_in_dim(Yj, Y_hi, 2*((myidx+ii)%mysize)*blk+blk, axis=0)
            return (lo, hi, Yj)
        mybody = partial(body_fn, myidx=idx, mysize=size)
        Yj = jax.lax.fori_loop(1, size, mybody, init_val=(x_lo, x_hi, Yj))[2]
        # Sum the per-device partials over 'row' and keep this device's token rows.
        return jax.lax.psum_scatter(Yj, 'row', scatter_dimension=0, tiled=True)
        

def Systolic_WS_(Xji, Wij, B, K): #smallest block size B
    """Weight-stationary matmul Y = Xt.T @ W, reduce-scattering the output in K slices.

    The local token axis of Xji (its columns) is viewed as (chunks, K, B).
    For each slice k:
      1. the B-wide piece at position k of every chunk is all-gathered over 'col';
      2. it is multiplied against the local Wij, giving a partial sum over this
         device's rows of the contraction dim;
      3. the result is psum_scatter'ed over 'row' along the chunk axis.
    Each slice's result goes into slot k of a (chunks/R, K, B, out) buffer,
    which is finally flattened to this device's output tile.

    Args:
      Xji: local (in, tokens) shard of X stored transposed; its column count
        must be divisible by K*B.
      Wij: local (in, out) shard of W.
      B: width of the smallest contiguous piece of the token dim.
      K: number of slices.
    Returns:
      This device's (N / R, Wij.shape[1]) output tile. N = Xji.shape[1] times
      the 'col' axis size; R = the 'row' axis size.
    """
    N = Xji.shape[1]*jax.lax.psum(1, 'col')   # total tokens
    R = jax.lax.psum(1,'row')
    X_split = Xji.reshape(Xji.shape[0], Xji.shape[1]//K//B,K,B)

    def slice_output(k):
        # (in, tokens/C/K/B, 1, B) -> gathered over 'col' -> (in, tokens/K/B, 1, B)
        Xk = jax.lax.all_gather(jax.lax.dynamic_slice_in_dim(X_split, k, 1, 2),
                                'col', tiled=True, axis=1)
        Yk = jnp.einsum('inkb,io->nkbo',Xk,Wij)  # partial over 'row'
        # Sum partials across 'row' and split the chunk axis: (tokens/K/B/R, 1, B, out).
        return jax.lax.psum_scatter(Yk, 'row', scatter_dimension=0, tiled=True)

    def body_fn(kk, carry):
        return jax.lax.dynamic_update_index_in_dim(carry, slice_output(kk), kk, axis=1)

    Yij = jnp.zeros((N//(K*B*R), K,B, Wij.shape[1]),Xji.dtype)
    # Slice 0 is computed outside the loop with a static index; slices 1..K-1 run in fori_loop.
    Yij = body_fn(0, Yij)
    Yij = jax.lax.fori_loop(1,K,body_fn,init_val=Yij)
    return Yij.reshape(-1,Wij.shape[1])
    
        

# Bind Systolic_WS_ to the current mesh with B=8, K=8.
Systolic_WS = jax.jit(partial(shard_map, mesh=mesh, in_specs=(P('row', 'col'), P('row', 'col')),
         out_specs=P('row', 'col'))(partial(Systolic_WS_, B=8,K=8)))

# Cross-check the three WS implementations on the large problem.
y_ref = Wang_WS(Xw, Ww)
#y_ref.shape
#jnp.dot(X, W).shape
print(allclose(y_ref, GSPMD_WS(Xw,Ww)))
print(allclose(Systolic_WS(Xw,Ww), y_ref))
%timeit Systolic_WS(Xw,Ww).block_until_ready()

In [None]:
# NOTE(review): GSPMD_WS uses the earlier (4, 2) mesh captured at its definition,
# so its timing presumably includes resharding Xw/Ww.
%timeit GSPMD_WS(Xw,Ww).block_until_ready()
%timeit Wang_WS(Xw,Ww).block_until_ready()
%timeit Systolic_WS(Xw,Ww).block_until_ready()

In [None]:
# Small problem to time all three dataflows for each implementation:
# OS computes X @ W; IS(Yo, Wo) computes Yo @ Wo.T (shape of X);
# WS(Xo, Yo) computes Xo.T @ Yo (shape of W).
B,S,H,D = (4,128, 48,64)

X = jnp.arange( B*S*H*D,dtype=jnp.float32).reshape(B*S, H*D)/(B*S*H*D)
W = jnp.arange(H*D*4*H*D,dtype=jnp.float32).reshape(H*D, 4*H*D) / (4*H*D*H*D)

Xo = jax.device_put(X, NamedSharding(mesh, P('row', 'col')))
Wo = jax.device_put(W, NamedSharding(mesh, P('row', 'col')))
OS, IS, WS = GSPMD_OS, GSPMD_IS, GSPMD_WS
# NOTE(review): Yo comes from GSPMD_OS, which is bound to the earlier (4, 2) mesh;
# the Wang_*/Systolic_* calls below presumably reshard it onto the (2, 4) mesh.
Yo = GSPMD_OS(Xo,Wo)

%timeit OS(Xo,Wo).block_until_ready()
%timeit IS(Yo,Wo).block_until_ready()
%timeit WS(Xo,Yo).block_until_ready()

In [None]:
%timeit Wang_OS(Xo,Wo).block_until_ready()
%timeit Wang_IS(Yo,Wo).block_until_ready()
%timeit Wang_WS(Xo,Yo).block_until_ready()

In [None]:
# Systolic_IS here is the most recently bound B=4, K=16 variant;
# Systolic_OS and Systolic_WS use B=8, K=8.
%timeit Systolic_OS(Xo,Wo).block_until_ready()
%timeit Systolic_IS(Yo,Wo).block_until_ready()
%timeit Systolic_WS(Xo,Yo).block_until_ready()