### 1. AllGather

In [25]:
import os

# Ask XLA for 4 host-platform devices so the pmap examples below have several
# devices to map over. The flag is assigned before jax is imported.
# NOTE(review): presumably it must be in place before JAX initialises its
# backend to have any effect - confirm.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"
import jax
import jax.numpy as jnp
from jax import lax

# Sanity check: number of devices visible to this process.
jax.local_device_count()

4

In [47]:
# 16 consecutive integers arranged as 4 blocks of shape (2, 2); the leading
# axis (size 4) is the one handed to pmap in the following cells.
inp = jnp.arange(16).reshape(4, 2, 2)
inp

Array([[[ 0,  1],
        [ 2,  3]],

       [[ 4,  5],
        [ 6,  7]],

       [[ 8,  9],
        [10, 11]],

       [[12, 13],
        [14, 15]]], dtype=int32)

In [54]:
def fn_0(x):
    """All-gather the per-device block over the named axis ``"i"``.

    With ``tiled=True`` the gathered blocks are concatenated along the
    existing local axis 0 rather than stacked on a new axis.
    """
    gathered = lax.all_gather(x, "i", axis=0, tiled=True)
    return gathered


# Parallel-map fn_0 over inp's leading axis; axis_name="i" is the name the
# collective inside fn_0 refers to.
fn_pmap0 = jax.pmap(fn_0, axis_name="i")

out = fn_pmap0(inp)

# Show only the first device's result.
out[0]

Array([[ 0,  1],
       [ 2,  3],
       [ 4,  5],
       [ 6,  7],
       [ 8,  9],
       [10, 11],
       [12, 13],
       [14, 15]], dtype=int32)

def fn_1(x):
    """All-gather the per-device block over ``"i"``, concatenating on local axis 1."""
    gather_opts = {"axis": 1, "tiled": True}
    return lax.all_gather(x, "i", **gather_opts)

# Same driver as for fn_0, but for the axis=1 variant of the all-gather.
fn_pmap1 = jax.pmap(fn_1,axis_name='i')

out = fn_pmap1(inp)

# Show only the first device's result.
out[0]

### 2. ReduceScatter

In [66]:
# Rebinds `inp`: 4 rows of 4 integers, one row per slice of the leading axis.
inp = jnp.arange(16).reshape(4, 4)
inp

Array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11],
       [12, 13, 14, 15]], dtype=int32)

In [67]:
def fn(x):
    """Sum ``x`` over the ``"i"`` axis and keep one tile of the total per participant.

    The reduced array is scattered along local dimension 0 with ``tiled=True``.
    """
    scatter_opts = {"scatter_dimension": 0, "tiled": True}
    return lax.psum_scatter(x, "i", **scatter_opts)


# `fn` is the psum_scatter wrapper defined just above in this cell.
fn_pmap = jax.pmap(fn, axis_name="i")

# Each row of inp goes to one participant; the result keeps one summed
# element per participant.
out = fn_pmap(inp)
out

Array([[24],
       [28],
       [32],
       [36]], dtype=int32)

### 3. AllReduce

In [69]:
# Rebinds `inp`: 4 rows of 2 integers, one row per slice of the leading axis.
inp = jnp.arange(8).reshape(4, 2)
inp

Array([[0, 1],
       [2, 3],
       [4, 5],
       [6, 7]], dtype=int32)

In [None]:
def fn(x):
    """Return the sum of ``x`` over every participant of the ``"i"`` axis."""
    total = lax.psum(x, "i")
    return total


# `fn` is the plain psum wrapper defined just above in this cell.
fn_pmap = jax.pmap(fn, axis_name="i")

# Every row of the result holds the same column-wise total of inp.
out = fn_pmap(inp)
out

Array([[12, 16],
       [12, 16],
       [12, 16],
       [12, 16]], dtype=int32)

In [71]:
# Rebinds `inp`: 4 rows of 4 integers for the grouped all-reduce example.
inp = jnp.arange(16).reshape(4, 4)
inp

Array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11],
       [12, 13, 14, 15]], dtype=int32)

In [72]:
def fn(x):
    """Sum ``x`` over the ``"i"`` axis within two independent groups.

    Axis indices 0 and 1 form one group, indices 2 and 3 the other; the
    reduction is performed separately inside each group.
    """
    groups = [[0, 1], [2, 3]]
    return lax.psum(x, "i", axis_index_groups=groups)


# `fn` is the grouped psum wrapper defined just above in this cell.
fn_pmap = jax.pmap(fn, axis_name="i")

# Rows 0-1 of the result share one total and rows 2-3 share another.
out = fn_pmap(inp)
out

Array([[ 4,  6,  8, 10],
       [ 4,  6,  8, 10],
       [20, 22, 24, 26],
       [20, 22, 24, 26]], dtype=int32)

### 4. AllToAll

In [74]:
# Each row is filled with its own device index, which makes it easy to see
# where every element ends up after the exchange.
inp = jnp.array(
    [
        [0, 0, 0, 0],  # data held by device 0
        [1, 1, 1, 1],  # data held by device 1
        [2, 2, 2, 2],  # data held by device 2
        [3, 3, 3, 3],  # data held by device 3
    ]
)

In [75]:
def fn(x):
    """All-to-all over ``"i"``: split local axis 0 and concatenate on local axis 0."""
    exchange = {"split_axis": 0, "concat_axis": 0}
    return lax.all_to_all(x, "i", **exchange)


# `fn` is the all_to_all(split_axis=0, concat_axis=0) wrapper from this cell.
fn_pmap = jax.pmap(fn, axis_name="i")

# After the exchange every row contains one element from each original row.
out = fn_pmap(inp)
out

Array([[0, 1, 2, 3],
       [0, 1, 2, 3],
       [0, 1, 2, 3],
       [0, 1, 2, 3]], dtype=int32)

In [77]:
# Rebinds `inp`: 64 consecutive integers as 4 blocks of shape (4, 4), shared
# by the all_to_all axis-combination cells that follow.
inp = jnp.arange(64).reshape(4, 4, 4)
inp

Array([[[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15]],

       [[16, 17, 18, 19],
        [20, 21, 22, 23],
        [24, 25, 26, 27],
        [28, 29, 30, 31]],

       [[32, 33, 34, 35],
        [36, 37, 38, 39],
        [40, 41, 42, 43],
        [44, 45, 46, 47]],

       [[48, 49, 50, 51],
        [52, 53, 54, 55],
        [56, 57, 58, 59],
        [60, 61, 62, 63]]], dtype=int32)

In [80]:
def fn(x):
    """All-to-all over ``"i"``: split local axis 1 and concatenate on local axis 0."""
    exchange = {"split_axis": 1, "concat_axis": 0}
    return lax.all_to_all(x, "i", **exchange)


# `fn` is the all_to_all(split_axis=1, concat_axis=0) wrapper from this cell.
fn_pmap = jax.pmap(fn, axis_name="i")

out = fn_pmap(inp)
out

Array([[[ 0,  4,  8, 12],
        [16, 20, 24, 28],
        [32, 36, 40, 44],
        [48, 52, 56, 60]],

       [[ 1,  5,  9, 13],
        [17, 21, 25, 29],
        [33, 37, 41, 45],
        [49, 53, 57, 61]],

       [[ 2,  6, 10, 14],
        [18, 22, 26, 30],
        [34, 38, 42, 46],
        [50, 54, 58, 62]],

       [[ 3,  7, 11, 15],
        [19, 23, 27, 31],
        [35, 39, 43, 47],
        [51, 55, 59, 63]]], dtype=int32)

In [84]:
def fn(x):
    """All-to-all over ``"i"``: split local axis 1 and concatenate on local axis 1."""
    exchange = {"split_axis": 1, "concat_axis": 1}
    return lax.all_to_all(x, "i", **exchange)


# `fn` is the all_to_all(split_axis=1, concat_axis=1) wrapper from this cell.
fn_pmap = jax.pmap(fn, axis_name="i")

out = fn_pmap(inp)
out

Array([[[ 0, 16, 32, 48],
        [ 4, 20, 36, 52],
        [ 8, 24, 40, 56],
        [12, 28, 44, 60]],

       [[ 1, 17, 33, 49],
        [ 5, 21, 37, 53],
        [ 9, 25, 41, 57],
        [13, 29, 45, 61]],

       [[ 2, 18, 34, 50],
        [ 6, 22, 38, 54],
        [10, 26, 42, 58],
        [14, 30, 46, 62]],

       [[ 3, 19, 35, 51],
        [ 7, 23, 39, 55],
        [11, 27, 43, 59],
        [15, 31, 47, 63]]], dtype=int32)

In [82]:
def fn(x):
    """All-to-all over ``"i"``: split local axis 0 and concatenate on local axis 1."""
    exchange = {"split_axis": 0, "concat_axis": 1}
    return lax.all_to_all(x, "i", **exchange)


# `fn` is the all_to_all(split_axis=0, concat_axis=1) wrapper from this cell.
fn_pmap = jax.pmap(fn, axis_name="i")

out = fn_pmap(inp)
out

Array([[[ 0, 16, 32, 48],
        [ 1, 17, 33, 49],
        [ 2, 18, 34, 50],
        [ 3, 19, 35, 51]],

       [[ 4, 20, 36, 52],
        [ 5, 21, 37, 53],
        [ 6, 22, 38, 54],
        [ 7, 23, 39, 55]],

       [[ 8, 24, 40, 56],
        [ 9, 25, 41, 57],
        [10, 26, 42, 58],
        [11, 27, 43, 59]],

       [[12, 28, 44, 60],
        [13, 29, 45, 61],
        [14, 30, 46, 62],
        [15, 31, 47, 63]]], dtype=int32)

In [83]:
def fn(x):
    """All-to-all over ``"i"``: split local axis 0 and concatenate on local axis 0."""
    split_on = concat_on = 0
    return lax.all_to_all(x, "i", split_axis=split_on, concat_axis=concat_on)


# `fn` is the all_to_all(split_axis=0, concat_axis=0) wrapper from this cell,
# applied here to the 3-D input.
fn_pmap = jax.pmap(fn, axis_name="i")

out = fn_pmap(inp)
out

Array([[[ 0,  1,  2,  3],
        [16, 17, 18, 19],
        [32, 33, 34, 35],
        [48, 49, 50, 51]],

       [[ 4,  5,  6,  7],
        [20, 21, 22, 23],
        [36, 37, 38, 39],
        [52, 53, 54, 55]],

       [[ 8,  9, 10, 11],
        [24, 25, 26, 27],
        [40, 41, 42, 43],
        [56, 57, 58, 59]],

       [[12, 13, 14, 15],
        [28, 29, 30, 31],
        [44, 45, 46, 47],
        [60, 61, 62, 63]]], dtype=int32)

In [85]:
# NOTE(review): this cell repeats the setup from the first cell. It does not
# import `lax`, which the functions below use, so it relies on the earlier
# `from jax import lax` having been executed.
import os

# Presumably a no-op if jax was already initialised by an earlier cell - confirm.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"
import jax
import jax.numpy as jnp

# Local shape per slice of the leading axis: (4, 4).
# Only shapes are inspected below, so an all-zeros array is used.
global_inp = jnp.zeros((4, 4, 4))


# 1. Test (split_axis=0, concat_axis=1) - labelled "flatten" by the author
def run_0_1(x):
    """All-to-all over ``"i"``: split local axis 0 and concatenate on local axis 1."""
    exchanged = lax.all_to_all(x, "i", 0, 1)
    return exchanged


# 2. Test (split_axis=1, concat_axis=0) - labelled "elongate" by the author
def run_1_0(x):
    """All-to-all over ``"i"``: split local axis 1 and concatenate on local axis 0."""
    exchanged = lax.all_to_all(x, "i", 1, 0)
    return exchanged


# Run both variants over the leading axis of global_inp.
out_0_1 = jax.pmap(run_0_1, axis_name="i")(global_inp)
out_1_0 = jax.pmap(run_1_0, axis_name="i")(global_inp)

print("原始 Local Shape: (4, 4)")
print("-" * 30)
# shape[1:] drops the leading mapped axis to report the per-device shape.
# NOTE(review): the recorded output shows (4, 4) for both cases, so the
# "got flatter" / "got longer" labels in these messages do not match the
# shapes actually observed.
print(f"(0, 1) 输出 Local Shape: {out_0_1.shape[1:]}  <-- 变扁了！")
print(f"(1, 0) 输出 Local Shape: {out_1_0.shape[1:]}  <-- 变长了！")

原始 Local Shape: (4, 4)
------------------------------
(0, 1) 输出 Local Shape: (4, 4)  <-- 变扁了！
(1, 0) 输出 Local Shape: (4, 4)  <-- 变长了！


In [86]:
# NOTE(review): repeated setup; `lax` is again not imported here and comes
# from the first cell.
import jax
import jax.numpy as jnp
import os

# Unlike the first cell, the flag is assigned after `import jax` here.
# Presumably it has no effect at this point - confirm.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"

# Prepare the data: 4 blocks of shape (4, 4) holding the integers 0..63.
global_inp = jnp.arange(64).reshape(4, 4, 4)


def run_debug(x, s, c):
    """Run ``lax.all_to_all`` over axis ``"i"`` and print the local result shape.

    Args:
        x: The local block for one participant of the ``"i"`` axis.
        s: ``split_axis`` forwarded to ``lax.all_to_all``.
        c: ``concat_axis`` forwarded to ``lax.all_to_all``.

    Returns:
        The array produced by ``lax.all_to_all``.
    """
    out = lax.all_to_all(x, axis_name="i", split_axis=s, concat_axis=c)
    # s, c and out.shape are plain Python values known at trace time. Passing
    # them as operands to jax.debug.print turns them into arrays, so the shape
    # used to render as "(Array(4, dtype=int32), Array(4, dtype=int32))".
    # Format them in Python instead; the text contains no braces, so it is
    # safe to use as the format string itself.
    message = f"Split={s} Concat={c} -> Shape={out.shape}"
    # Still printed from inside the mapped computation.
    jax.debug.print(message)
    return out


# NOTE(review): the recorded output reports the same 4-by-4 local shape for
# all three combinations, so the "flat" / "long" expectations below are not
# what was observed.

# 1. Check (0, 1) - expected to be flat
jax.pmap(lambda x: run_debug(x, 0, 1), axis_name="i")(global_inp)

# 2. Check (1, 0) - expected to be long
jax.pmap(lambda x: run_debug(x, 1, 0), axis_name="i")(global_inp)

# 3. Check (1, 1) - expected to be square (and the data should match the
#    output shown earlier); as the last expression, this result is displayed.
jax.pmap(lambda x: run_debug(x, 1, 1), axis_name="i")(global_inp)

Split=0 Concat=1 -> Shape=(Array(4, dtype=int32), Array(4, dtype=int32))
Split=0 Concat=1 -> Shape=(Array(4, dtype=int32), Array(4, dtype=int32))
Split=0 Concat=1 -> Shape=(Array(4, dtype=int32), Array(4, dtype=int32))
Split=0 Concat=1 -> Shape=(Array(4, dtype=int32), Array(4, dtype=int32))
Split=1 Concat=0 -> Shape=(Array(4, dtype=int32), Array(4, dtype=int32))
Split=1 Concat=0 -> Shape=(Array(4, dtype=int32), Array(4, dtype=int32))
Split=1 Concat=0 -> Shape=(Array(4, dtype=int32), Array(4, dtype=int32))
Split=1 Concat=0 -> Shape=(Array(4, dtype=int32), Array(4, dtype=int32))
Split=1 Concat=1 -> Shape=(Array(4, dtype=int32), Array(4, dtype=int32))
Split=1 Concat=1 -> Shape=(Array(4, dtype=int32), Array(4, dtype=int32))
Split=1 Concat=1 -> Shape=(Array(4, dtype=int32), Array(4, dtype=int32))
Split=1 Concat=1 -> Shape=(Array(4, dtype=int32), Array(4, dtype=int32))


Array([[[ 0, 16, 32, 48],
        [ 4, 20, 36, 52],
        [ 8, 24, 40, 56],
        [12, 28, 44, 60]],

       [[ 1, 17, 33, 49],
        [ 5, 21, 37, 53],
        [ 9, 25, 41, 57],
        [13, 29, 45, 61]],

       [[ 2, 18, 34, 50],
        [ 6, 22, 38, 54],
        [10, 26, 42, 58],
        [14, 30, 46, 62]],

       [[ 3, 19, 35, 51],
        [ 7, 23, 39, 55],
        [11, 27, 43, 59],
        [15, 31, 47, 63]]], dtype=int32)