In [203]:
base_emb_dim = 128
batch_size = 2
length = 1024
inputs = jax.random.uniform(jax.random.PRNGKey(9876), [batch_size, length, base_emb_dim])

topn = 2
num_groups = inputs.shape[0]
token_shape = inputs.shape[:-1]
num_tokens = np.prod(token_shape)
m_dim = inputs.shape[-1]

tokens_per_group = num_tokens // num_groups
expert_capacity_factor = 1000
min_group_size = 1
num_experts = 8
expert_capacity = int(expert_capacity_factor * tokens_per_group / num_experts)

max_group_size = int(inputs.shape[1])
expert_capacity = min(expert_capacity, max_group_size)
expert_capacity = max(expert_capacity, min_group_size)
print(f'expert_capacity: {expert_capacity}')

# gsm
grouped_inputs = jnp.reshape(inputs, (num_groups, tokens_per_group, base_emb_dim))
# grouped_inputs = _split(grouped_inputs, (('replica', 'data'), None, 'mdl'))
token_inputs = jax.lax.convert_element_type(grouped_inputs, jnp.float32)
print(f'token_inputs: {token_inputs.shape}')

expert_capacity: 1024
token_inputs: (2, 1024, 128)


In [204]:
# params
_call_experts = jax.random.uniform(jax.random.PRNGKey(9876), [1, base_emb_dim, base_emb_dim])
call_experts = _call_experts.repeat(num_experts, 0)
router_gate = jax.random.uniform(jax.random.PRNGKey(9876), [base_emb_dim, num_experts])

# gsm
grouped_inputs = jnp.reshape(inputs, (num_groups, tokens_per_group, base_emb_dim))
token_inputs = jax.lax.convert_element_type(grouped_inputs, jnp.float32)
print(f'token_inputs: {token_inputs.shape}')
dtype = jnp.bfloat16
router_logits = jnp.einsum('bld,de->ble', token_inputs, router_gate)
# gse
router_probs = jax.nn.softmax(router_logits.astype(jnp.float32), axis=-1)
router_probs = router_probs.astype(dtype) # ble


token_inputs: (2, 1024, 128)


In [205]:
Array = jnp.ndarray


def _take_along_axis(array: Array, indices: Array, axis: int) -> Array:
    if array.ndim != indices.ndim:
        raise ValueError(
            'indices and array must have the same number of dimensions; '
            f'{indices.ndim} vs. {array.ndim}.')

    if (axis != -1 and axis != array.ndim - 1 and  # Not last dimension
        axis != 1 and axis != -array.ndim + 1):  # Not second dimension
        raise ValueError(
            'Only slices along the second or last dimension are supported; '
            f'array.ndim = {array.ndim}, while axis = {axis}.')

    if 1 or _favor_one_hot_slices():
        one_hot_length = array.shape[axis]
        one_hot_indices = jax.nn.one_hot(indices, one_hot_length, axis=axis)
        print(one_hot_indices.shape)

        if axis == -1 or array.ndim == 1:
            result = jnp.einsum(
                '...s,...is->...i',
                array,
                one_hot_indices,
                precision=jax.lax.Precision.HIGHEST)
        else:
            result = jnp.einsum(
                'ns...,nis...->ni...',
                array,
                one_hot_indices,
                precision=jax.lax.Precision.HIGHEST)
        return jax.lax.convert_element_type(result, array.dtype), one_hot_indices.max(2)
    else:
        return jnp.take_along_axis(array, indices, axis=axis), None


def _top_k(array, k: int):
    if 1 or _favor_one_hot_slices():
        top_k_indices = jax.lax.top_k(array, k)[-1]
        top_k_values, one_hot_indices = _take_along_axis(array, top_k_indices, axis=-1)
        print(one_hot_indices.shape)
        return top_k_values, top_k_indices, one_hot_indices
    else:
        return jax.lax.top_k(array, k), None


expert_gate, expert_index, one_hot_indices = _top_k(router_probs, k=topn)
router_probs *= one_hot_indices
router_probs /= router_probs.sum(-1, keepdims=True)

# g * 2 * s
expert_index = jnp.swapaxes(expert_index, 1, 2)
expert_index = expert_index.reshape(num_groups, -1)
# g * 2s
# g * 2s * e, expert_index 负值的地方忽略了?
expert_mask = jax.nn.one_hot(expert_index, num_experts, dtype=jnp.int32)
# g * 2s * e 
token_priority = jnp.cumsum(expert_mask, axis=1) * expert_mask - 1.0
# g * 2 * s * e
token_priority = token_priority.reshape(num_groups, topn, -1, num_experts)
token_priority = jnp.swapaxes(token_priority, 1, 2)
# g * s * 2 * e   ls: 每个token选择了2个专家，专家对应的位置的值表示当前编号专家选择的token数量
token_priority = jnp.max(token_priority, axis=2) 

compute_n_expert = num_experts
combined_outputs = None
_token_priority = token_priority
_router_probs = router_probs
# lsp： g * s * e * c  # 如果当前token选择了当前专家后，当前token被选中的总次数的one hot体现
_dispatch_mask = jax.nn.one_hot(_token_priority, int(expert_capacity), dtype=jnp.bool_)
# 把token选择专家的概率赋值到one_hot矩阵上
_combine_array = jnp.einsum('...se,...sec->...sec', _router_probs, _dispatch_mask)
_combine_array = jax.lax.convert_element_type(_combine_array, dtype)
# 专家的输入mask：gsm x gsec -> gecm
_expert_inputs = jnp.einsum('gs...,gsec->gec...', token_inputs, _dispatch_mask)
_expert_inputs = jax.lax.convert_element_type(_expert_inputs, dtype)
# gecm
# print(f'_expert_inputs: {_expert_inputs.shape}')
# # g * e * c * m
_expert_outputs = jnp.einsum('gecd,edm->gecm', _expert_inputs, call_experts)
# _expert_outputs = _call_experts(_expert_inputs, index, compute_n_expert, deterministic=deterministic)
_combined_outputs = jnp.einsum('gec...,gsec->gs...', _expert_outputs, _combine_array)

combined_outputs = _combined_outputs if combined_outputs is None else combined_outputs + _combined_outputs


(2, 1024, 2, 8)
(2, 1024, 8)


In [206]:
r = jnp.einsum('gcd,dm->gcm', token_inputs, _call_experts.squeeze())

In [208]:
r

Array([[[28.560665, 27.185635, 28.68104 , ..., 27.223991, 25.60635 ,
         30.83944 ],
        [30.88588 , 27.094418, 31.459332, ..., 29.879463, 27.927849,
         35.077682],
        [33.235268, 29.354126, 33.437965, ..., 31.392794, 28.466248,
         35.0506  ],
        ...,
        [33.00421 , 29.302488, 34.69037 , ..., 32.604965, 31.23965 ,
         36.159454],
        [31.797195, 29.900791, 32.5949  , ..., 31.50444 , 28.399307,
         33.13117 ],
        [32.66905 , 28.830624, 31.608982, ..., 30.723911, 29.060616,
         33.163406]],

       [[34.011265, 32.703957, 34.21184 , ..., 33.507072, 31.974867,
         37.61277 ],
        [28.907455, 29.038662, 32.866028, ..., 29.621647, 26.987638,
         34.60702 ],
        [29.792133, 27.072746, 30.789768, ..., 29.37832 , 27.429031,
         31.339783],
        ...,
        [31.87959 , 29.07582 , 30.04136 , ..., 29.387972, 28.34523 ,
         34.3482  ],
        [27.179043, 27.306644, 29.580395, ..., 27.039307, 25.54274 ,
   

In [207]:
combined_outputs

Array([[[28.5     , 27.125   , 28.625   , ..., 27.25    , 25.625   ,
         30.875   ],
        [30.875   , 27.125   , 31.5     , ..., 29.875   , 27.875   ,
         35.      ],
        [33.18506 , 29.317627, 33.43457 , ..., 31.31372 , 28.444336,
         34.93164 ],
        ...,
        [33.      , 29.25    , 34.75    , ..., 32.5     , 31.25    ,
         36.25    ],
        [31.718994, 29.845825, 32.46826 , ..., 31.469238, 28.34729 ,
         33.21753 ],
        [32.813965, 28.931396, 31.686768, ..., 30.810059, 29.05664 ,
         33.31494 ]],

       [[34.      , 32.75    , 34.25    , ..., 33.5     , 32.      ,
         37.5     ],
        [28.931396, 29.05664 , 32.813965, ..., 29.682861, 27.052734,
         34.567383],
        [29.808105, 27.177979, 30.810059, ..., 29.432373, 27.428467,
         31.43628 ],
        ...,
        [31.812744, 29.068115, 29.941406, ..., 29.317627, 28.31958 ,
         34.183105],
        [27.125   , 27.25    , 29.625   , ..., 27.      , 25.5     ,
   

In [195]:
combined_outputs

Array([[[0.35282516, 0.63547516],
        [0.0423975 , 0.08089638],
        [0.44726562, 0.86328125],
        [0.5703125 , 1.125     ]]], dtype=float32)

In [197]:
r

Array([[[0.35335103, 0.6379664 ],
        [0.04244471, 0.08118868],
        [0.4475174 , 0.86351776],
        [0.5697632 , 1.1248779 ]]], dtype=float32)

In [184]:
_dispatch_mask

Array([[[[False, False, False, False],
         [False, False, False,  True],
         [False, False, False, False],
         [ True, False, False, False]],

        [[False, False, False, False],
         [ True, False, False, False],
         [False, False, False, False],
         [False,  True, False, False]],

        [[False, False, False, False],
         [False,  True, False, False],
         [False, False, False, False],
         [False, False,  True, False]],

        [[False, False, False, False],
         [False, False,  True, False],
         [False, False, False, False],
         [False, False, False,  True]]]], dtype=bool)

In [187]:
_combine_array

Array([[[[0, 0, 0, 0],
         [0, 0, 0, 0.490234],
         [0, 0, 0, 0],
         [0.507812, 0, 0, 0]],

        [[0, 0, 0, 0],
         [0.503906, 0, 0, 0],
         [0, 0, 0, 0],
         [0, 0.494141, 0, 0]],

        [[0, 0, 0, 0],
         [0, 0.550781, 0, 0],
         [0, 0, 0, 0],
         [0, 0, 0.449219, 0]],

        [[0, 0, 0, 0],
         [0, 0, 0.59375, 0],
         [0, 0, 0, 0],
         [0, 0, 0, 0.40625]]]], dtype=bfloat16)

In [178]:
_token_priority

Array([[[-1.,  3., -1.,  0.],
        [-1.,  0., -1.,  1.],
        [-1.,  1., -1.,  2.],
        [-1.,  2., -1.,  3.]]], dtype=float32)

In [185]:
_router_probs

Array([[[0.        , 0.49084252, 0.        , 0.50915754],
        [0.        , 0.50579154, 0.        , 0.4942085 ],
        [0.        , 0.5517242 , 0.        , 0.4482759 ],
        [0.        , 0.5944171 , 0.        , 0.40558293]]], dtype=float32)

16.0

In [164]:
_dispatch_mask.shape

(1, 16, 4, 16)

In [118]:
a = jnp.array([1, 3 ,4])
b = jax.nn.one_hot(a, 2, dtype=jnp.bool_)
b

Array([[False,  True],
       [False, False],
       [False, False]], dtype=bool)

In [117]:
b.shape

(3, 5)

In [108]:
_token_priority.shape

(1, 4, 4)

In [113]:
_token_priority

Array([[[-1.,  3., -1.,  0.],
        [-1.,  0., -1.,  1.],
        [-1.,  1., -1.,  2.],
        [-1.,  2., -1.,  3.]]], dtype=float32)

In [94]:
token_priority

Array([[[[-1., -1., -1.,  0.],
         [-1.,  3., -1., -1.]],

        [[-1.,  0., -1., -1.],
         [-1., -1., -1.,  1.]],

        [[-1.,  1., -1., -1.],
         [-1., -1., -1.,  2.]],

        [[-1.,  2., -1., -1.],
         [-1., -1., -1.,  3.]]]], dtype=float32, weak_type=True)

In [92]:
token_priority

Array([[[[-1., -1., -1.,  0.],
         [-1.,  0., -1., -1.],
         [-1.,  1., -1., -1.],
         [-1.,  2., -1., -1.]],

        [[-1.,  3., -1., -1.],
         [-1., -1., -1.,  1.],
         [-1., -1., -1.,  2.],
         [-1., -1., -1.,  3.]]]], dtype=float32, weak_type=True)

In [90]:
token_priority

Array([[[-1., -1., -1.,  0.],
        [-1.,  0., -1., -1.],
        [-1.,  1., -1., -1.],
        [-1.,  2., -1., -1.],
        [-1.,  3., -1., -1.],
        [-1., -1., -1.,  1.],
        [-1., -1., -1.,  2.],
        [-1., -1., -1.,  3.]]], dtype=float32, weak_type=True)

In [87]:
expert_mask

Array([[[0, 0, 0, 1],
        [0, 1, 0, 0],
        [0, 1, 0, 0],
        [0, 1, 0, 0],
        [0, 1, 0, 0],
        [0, 0, 0, 1],
        [0, 0, 0, 1],
        [0, 0, 0, 1]]], dtype=int32)

In [88]:
expert_mask.shape

(1, 8, 4)

In [85]:
expert_index

Array([[3, 1, 1, 1, 1, 3, 3, 3]], dtype=int32)

In [121]:
_combine_array

Array([[[[0],
         [0],
         [0],
         [0.507812]],

        [[0],
         [0.503906],
         [0],
         [0]],

        [[0],
         [0],
         [0],
         [0]],

        [[0],
         [0],
         [0],
         [0]]]], dtype=bfloat16)

(1, 4, 2, 4)
(1, 4, 4)


In [71]:
token_priority

Array([[[-1.,  3., -1.,  0.],
        [-1.,  0., -1.,  1.],
        [-1.,  1., -1.,  2.],
        [-1.,  2., -1.,  3.]]], dtype=float32)

In [70]:
router_probs

Array([[[0.        , 0.49084252, 0.        , 0.50915754],
        [0.        , 0.50579154, 0.        , 0.4942085 ],
        [0.        , 0.5517242 , 0.        , 0.4482759 ],
        [0.        , 0.5944171 , 0.        , 0.40558293]]], dtype=float32)

In [69]:
_dispatch_mask.astype('b')[0]

Array([[[0],
        [0],
        [0],
        [1]],

       [[0],
        [1],
        [0],
        [0]],

       [[0],
        [0],
        [0],
        [0]],

       [[0],
        [0],
        [0],
        [0]]], dtype=int8)

In [50]:
combined_outputs

Array([[[0.17951965, 0.32333374],
        [0.02140617, 0.04084396],
        [0.        , 0.        ],
        [0.        , 0.        ]]], dtype=float32)

In [None]:
_expert_inputs

In [57]:
_expert_inputs[0]

Array([[[0, 0]],

       [[0.0296631, 0.0654297]],

       [[0, 0]],

       [[0.00136566, 0.734375]]], dtype=bfloat16)

In [51]:
r = jnp.einsum('gcd,dm->gcm', token_inputs, _call_experts.squeeze(0))

In [52]:
r

Array([[[0.35335103, 0.6379664 ],
        [0.04244471, 0.08118868],
        [0.4475174 , 0.86351776],
        [0.5697632 , 1.1248779 ]]], dtype=float32)

In [53]:
r.shape

(1, 4, 2)

In [14]:
_call_experts.squeeze(0)

Array([[0.01139784, 0.9637133 , 0.17142892, ..., 0.519408  , 0.40298927,
        0.50899124],
       [0.33128166, 0.11407924, 0.45493877, ..., 0.19184983, 0.6729857 ,
        0.6878849 ],
       [0.5041802 , 0.5340363 , 0.9612719 , ..., 0.7156762 , 0.5960239 ,
        0.79638684],
       ...,
       [0.00309777, 0.5363393 , 0.3114102 , ..., 0.04296672, 0.2589947 ,
        0.66974473],
       [0.9366919 , 0.9258721 , 0.4500034 , ..., 0.8394394 , 0.5943967 ,
        0.94559646],
       [0.93420994, 0.38898253, 0.5649245 , ..., 0.81798625, 0.601208  ,
        0.9051266 ]], dtype=float32)

In [25]:
_expert_inputs

Array([[[[0.671875, 0.212891, 0.365234, ..., 0.953125, 0.863281,
          0.0581055],
         [0.90625, 0.257812, 0.402344, ..., 0.703125, 0.198242,
          0.0559082],
         [0.769531, 0.237305, 0.789062, ..., 0.392578, 0.476562,
          0.217773]],

        [[0.480469, 0.166016, 0.154297, ..., 0.120117, 0.427734,
          0.621094],
         [0.828125, 0.451172, 0.964844, ..., 0.414062, 0.804688,
          0.527344],
         [0.010437, 0.296875, 0.773438, ..., 0.46875, 0.703125,
          0.832031]],

        [[0.480469, 0.166016, 0.154297, ..., 0.120117, 0.427734,
          0.621094],
         [0.929688, 0.503906, 0.300781, ..., 0.644531, 0.617188,
          0.102539],
         [0.828125, 0.451172, 0.964844, ..., 0.414062, 0.804688,
          0.527344]],

        [[0.199219, 0.636719, 0.761719, ..., 0.625, 0.202148, 0.408203],
         [0.259766, 0.390625, 0.0598145, ..., 0.00405884, 0.24707,
          0.625],
         [0.710938, 0.0791016, 0.683594, ..., 0.855469, 0.7929

In [13]:
token_inputs

Array([[[0.67017794, 0.21266353, 0.3643179 , ..., 0.95354056,
         0.86200583, 0.05801117],
        [0.90641916, 0.2581885 , 0.40254116, ..., 0.70177364,
         0.1986829 , 0.05588424],
        [0.76905775, 0.23724616, 0.7909447 , ..., 0.39260423,
         0.47644675, 0.21751463],
        ...,
        [0.01041508, 0.29660642, 0.7722591 , ..., 0.46841896,
         0.7017602 , 0.8304368 ],
        [0.2589029 , 0.3903619 , 0.05971932, ..., 0.00406814,
         0.24726093, 0.6248237 ],
        [0.7118826 , 0.07900417, 0.68311775, ..., 0.8543353 ,
         0.7920536 , 0.66045046]],

       [[0.35675788, 0.5723851 , 0.05660164, ..., 0.8213806 ,
         0.15491045, 0.13453877],
        [0.27500308, 0.0654732 , 0.5509267 , ..., 0.25192523,
         0.06606376, 0.5580566 ],
        [0.2960205 , 0.6988629 , 0.86140907, ..., 0.783445  ,
         0.1348989 , 0.54412246],
        ...,
        [0.07495975, 0.2266761 , 0.50004804, ..., 0.00884986,
         0.35092378, 0.6600951 ],
        [0.5

In [12]:
r

Array([[[30.616161, 30.490242, 32.09529 , ..., 31.686031, 29.929438,
         35.661602],
        [30.968822, 29.62692 , 32.57369 , ..., 29.640743, 28.168673,
         32.26693 ],
        [32.03814 , 31.602169, 31.933764, ..., 31.83572 , 29.167837,
         35.62698 ],
        ...,
        [26.817465, 26.582998, 28.993309, ..., 28.692616, 26.634424,
         31.390823],
        [29.20658 , 28.143347, 29.073826, ..., 28.767738, 26.71118 ,
         32.139866],
        [35.162743, 32.240837, 35.101612, ..., 33.50216 , 31.391727,
         37.2987  ]],

       [[30.344604, 27.864147, 31.761086, ..., 30.24715 , 29.06248 ,
         32.68872 ],
        [31.731127, 28.723663, 33.13195 , ..., 30.487007, 28.915962,
         31.881374],
        [33.392685, 30.383389, 32.017506, ..., 31.762093, 31.153368,
         35.39888 ],
        ...,
        [28.900892, 28.08874 , 29.180866, ..., 29.960375, 27.57775 ,
         34.689297],
        [30.796967, 30.372814, 33.367855, ..., 31.086185, 27.887077,
   

In [9]:
token_inputs.shape

(3, 12, 128)

In [7]:
combined_outputs.shape

(3, 12, 128)