In [1]:
import numpy as np
from numpy.random import randn
from itertools import count
import sys

# Pin the global NumPy RNG: every randn()/shuffle() call below draws from this
# stream, so the values written to data.npz are reproducible run to run.
np.random.seed(seed=1)

In [2]:
# Problem sizes. These are deliberately tiny and all distinct, so a transposed
# or mismatched operand fails loudly instead of silently broadcasting.
# Larger alternative used previously: 64 experts, expert_dim 512,
# emb_dim 256, 2048 tokens (tokens possibly from unrolled batches).
expert_num, expert_dim, emb_dim, token_num = 5, 7, 11, 13

# Bypass weight: the input is scaled by this and added to the final output,
# so a token that reaches no expert still produces a non-zero result.
b = 0.1

# Everything the C++ side reads lives in `exported`. Sizes stay Python ints
# (stored as NumPy's default integer type when saved), but the bypass weight is
# wrapped in a 1-element float32 array to match the precision of the tensors.
exported = {
    "expert_count": expert_num,
    "expert_size": expert_dim,
    "embedding_size": emb_dim,
    "token_count": token_num,
    "b": np.array([b], dtype=np.float32),
}

# Token embeddings, shape (token_num, emb_dim), forced into C (row-major)
# order so the memory layout matches the C++ implementation.
words = randn(token_num, emb_dim).astype(np.float32, order='C')
exported["src"] = words

In [3]:
# Preview the generated token embeddings: float32, shape (token_num, emb_dim).
words

array([[ 1.6243454 , -0.6117564 , -0.5281718 , -1.0729686 ,  0.86540765,
        -2.3015387 ,  1.7448118 , -0.7612069 ,  0.3190391 , -0.24937038,
         1.4621079 ],
       [-2.0601406 , -0.3224172 , -0.38405436,  1.1337694 , -1.0998913 ,
        -0.1724282 , -0.8778584 ,  0.04221375,  0.58281523, -1.1006192 ,
         1.1447237 ],
       [ 0.9015907 ,  0.50249434,  0.90085596, -0.68372786, -0.12289023,
        -0.93576944, -0.26788807,  0.53035545, -0.69166076, -0.39675352,
        -0.6871727 ],
       [-0.84520566, -0.6712461 , -0.0126646 , -1.1173104 ,  0.2344157 ,
         1.6598022 ,  0.74204415, -0.19183555, -0.887629  , -0.7471583 ,
         1.6924546 ],
       [ 0.05080776, -0.6369957 ,  0.19091548,  2.1002553 ,  0.12015896,
         0.6172031 ,  0.30017033, -0.35224986, -1.1425182 , -0.34934273,
        -0.20889424],
       [ 0.5866232 ,  0.8389834 ,  0.9311021 ,  0.2855873 ,  0.8851412 ,
        -0.7543979 ,  1.2528682 ,  0.5129298 , -0.29809284,  0.48851815,
        -0.075

In [4]:
def _randn_f32(*shape):
    """Draw standard-normal samples of `shape` as a C-contiguous float32 array."""
    return randn(*shape).astype(np.float32, order='C')


# Parameters of every expert's two-layer feed-forward net, stacked along a
# leading expert axis:
#   w1: (expert_num, emb_dim, expert_dim)   b1: (expert_num, 1, expert_dim)
#   w2: (expert_num, expert_dim, emb_dim)   b2: (expert_num, 1, emb_dim)
# The size-1 axis on the biases lets them broadcast over the token rows.
# Keep the draw order (w1, b1, w2, b2): each call consumes the seeded random
# stream, so reordering them changes the values written to data.npz.
experts_w1 = exported["experts_w1"] = _randn_f32(expert_num, emb_dim, expert_dim)
experts_b1 = exported["experts_b1"] = _randn_f32(expert_num, 1, expert_dim)
experts_w2 = exported["experts_w2"] = _randn_f32(expert_num, expert_dim, emb_dim)
experts_b2 = exported["experts_b2"] = _randn_f32(expert_num, 1, emb_dim)

In [5]:
def scramble(a, axis=-1):
    """
    Permute `a` in place along `axis`, independently for every 1-D slice.

    Each slice taken along `axis` (all other indices held fixed) is handed to
    numpy.random.shuffle, so this uses and advances the global NumPy RNG.

    Parameters
    ----------
    a : numpy.ndarray
        Writable array; it is modified directly.
    axis : int, optional
        Axis to shuffle along (default: the last axis).

    Returns
    -------
    None
    """
    # swapaxes returns a view with `axis` moved last, so shuffling that view's
    # rows reorders the elements of `a` itself.
    last_axis_view = a.swapaxes(axis, -1)
    outer_shape = last_axis_view.shape[:-1]
    # Visit the 1-D slices in C order (same order as np.ndindex yields them).
    slices = (last_axis_view[outer_index] for outer_index in np.ndindex(outer_shape))
    for one_d_slice in slices:
        np.random.shuffle(one_d_slice)
    return None

In [6]:
# Generate a random binary top-k router: every token is sent to exactly k
# distinct experts, with the subset chosen uniformly at random.

# Select K for TOP-K
k = 4
assert 0 < k <= expert_num, f"k={k} must be in [1, expert_num={expert_num}]"

# Built in [token, expert] layout. Mark the first k experts of every token,
# then shuffle each token's row independently so the k ones land on a random
# subset of experts.
router = np.zeros((token_num, expert_num), dtype=np.float32)
print(router.shape)
router[:, :k] = 1

# Print whole arrays (no "..." elision) in this and all later cells.
np.set_printoptions(threshold=sys.maxsize)

scramble(router, axis=-1)
assert (router.sum(axis=1) == k).all(), "every token must be routed to exactly k experts"

# Sanity check: sum per expert
print(f'Total is {np.sum(router)}, expected {token_num * k}')
print(f'Expect around {token_num * k / expert_num} per expert')
print(np.sum(router, axis=0))

# Flip to [expert, token] so the forward pass can iterate one expert at a time.
router = np.transpose(router)
print(router.shape)

# Only booleans are needed, so uint8 is enough. order='C' is required:
# np.transpose only returns a view with swapped strides, and astype's default
# order='K' would keep that layout, so the saved data would still be laid out
# as [token, expert].
exported["router"] = router.astype(np.uint8, order='C')

(13, 5)
Total is 52.0, expected 52
Expect around 10.4 per expert
[10. 10.  9. 12. 11.]
(5, 13)


In [7]:
# Routing mask after the transpose: shape (expert_num, token_num); a 1 means the
# token in that column is routed to the expert in that row.
router

array([[1., 1., 1., 0., 1., 1., 1., 0., 1., 0., 1., 1., 1.],
       [1., 0., 0., 1., 1., 1., 1., 1., 0., 1., 1., 1., 1.],
       [0., 1., 1., 1., 1., 0., 0., 1., 1., 1., 0., 1., 1.],
       [1., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0.]],
      dtype=float32)

In [16]:
# Reference MoE forward pass. Start from the bypass term: the input
# (token_num, emb_dim) scaled by `b`, so a token routed to no expert still
# ends up with a non-zero output.
total_output = b * words

# Process one expert at a time; that expert's row of `router` selects its tokens.
for n, (expert_w1, expert_b1, expert_w2, expert_b2, mask) in enumerate(
        zip(experts_w1, experts_b1, experts_w2, experts_b2, router)):
    print(f"Expert {n}:")
    print(f"Expert {n} gets {mask.sum()} tokens")
    assert mask.shape[0] == words.shape[0]

    # Indices of the tokens routed to this expert (non-zero mask entries).
    # Computed once and reused for both the gather and the scatter-add below.
    selected = mask.nonzero()

    # Gather this expert's tokens: (tokens_for_expert, emb_dim).
    expert_input = words[selected]

    # Two-layer feed-forward: ReLU(x @ W1 + b1), then ReLU(h @ W2 + b2).
    # The biases carry a size-1 row axis and broadcast over the token rows.
    # NOTE(review): a standard FFN usually omits the activation after the second
    # layer. It is kept because the exported reference data (and presumably the
    # C++ implementation) includes it; confirm before changing.
    expert_output1 = np.maximum(0.0, np.add(np.matmul(expert_input, expert_w1), expert_b1))
    expert_output2 = np.maximum(0.0, np.add(np.matmul(expert_output1, expert_w2), expert_b2))
    assert expert_output2.shape[0] == mask.sum()

    # Scatter-add back into the selected token rows. nonzero() yields unique
    # indices, so the fancy-indexed `+=` does not drop any contribution.
    # TODO: after loop: normalisation? We're now summing K experts + `b` of the original input.
    total_output[selected] += expert_output2

    print(f"  {expert_input.shape} * {expert_w1.shape} + {expert_b1.shape} = {expert_output1.shape}\n")
    print(f"  {expert_output1.shape} * {expert_w2.shape} + {expert_b2.shape} = {expert_output2.shape}\n")

    # Per-expert intermediates, so the C++ side can be checked expert by expert.
    exported[f"expert_{n}_src"] = expert_input.astype(np.float32, order='C')
    exported[f"expert_{n}_dst"] = expert_output2.astype(np.float32, order='C')

print(f"Total output: {total_output.shape}")
exported["dst"] = total_output.astype(np.float32, order='C')

Expert 0:
Expert 0 gets 10.0 tokens
  (10, 11) * (11, 7) + (1, 7) = (10, 7)

  (10, 7) * (7, 11) + (1, 11) = (10, 11)

Expert 1:
Expert 1 gets 10.0 tokens
  (10, 11) * (11, 7) + (1, 7) = (10, 7)

  (10, 7) * (7, 11) + (1, 11) = (10, 11)

Expert 2:
Expert 2 gets 9.0 tokens
  (9, 11) * (11, 7) + (1, 7) = (9, 7)

  (9, 7) * (7, 11) + (1, 11) = (9, 11)

Expert 3:
Expert 3 gets 12.0 tokens
  (12, 11) * (11, 7) + (1, 7) = (12, 7)

  (12, 7) * (7, 11) + (1, 11) = (12, 11)

Expert 4:
Expert 4 gets 11.0 tokens
  (11, 11) * (11, 7) + (1, 7) = (11, 7)

  (11, 7) * (7, 11) + (1, 11) = (11, 11)

Total output: (13, 11)


In [9]:
# Write every array collected in `exported` into one uncompressed .npz archive;
# each dict key becomes the array's name inside the archive.
output_path = "data.npz"
np.savez(output_path, **exported)