In [2]:
import numpy as np
from numpy.random import randn
from itertools import count

# Seed NumPy's legacy global RNG so every randn() draw in this notebook is reproducible.
np.random.seed(1)

# Name -> array registry; later cells add entries and the whole dict is saved with np.savez.
exported = {}

In [17]:
# Problem dimensions for the toy mixture-of-experts example.
expert_size = 3  # how many experts there are
emb_size = 5     # length of one embedding vector
word_size = 7    # how many tokens there are (possibly several batches unrolled into one list)

# One embedding row per token: shape (word_size, emb_size), stored as float32.
words = np.asarray(randn(word_size, emb_size), dtype=np.float32)
exported["src"] = words

In [18]:
# One square weight matrix per expert: shape (expert_size, emb_size, emb_size) = (3, 5, 5).
# Weights are drawn before biases so the global RNG stream is consumed in a fixed order.
experts_w = np.asarray(randn(expert_size, emb_size, emb_size), dtype=np.float32)

# One bias row per expert: shape (expert_size, 1, emb_size) = (3, 1, 5).
# The middle axis of length 1 lets the bias broadcast over every token given to that expert.
experts_b = np.asarray(randn(expert_size, 1, emb_size), dtype=np.float32)

exported["experts_w"] = experts_w
exported["experts_b"] = experts_b

In [19]:
# Routing mask, shape (experts, tokens) = (3, 7): a 1 at router[e, t] sends token t to expert e.
# Experts and tokens are numbered from 0 here, matching the "Expert {n}" lines printed by the loop below.
router = np.array([
    [1, 1, 0, 0, 0, 0, 0], # tokens 0-1 go to expert 0
    [0, 0, 1, 1, 0, 0, 0], # tokens 2-3 go to expert 1
    [0, 0, 0, 0, 1, 1, 0]  # tokens 4-5 go to expert 2; the very last token (6) doesn't go to any expert(!)
], dtype=np.uint8)
exported["router"] = router

In [20]:
# Sanity check: confirm the embeddings really are float32 after the cast above.
words.dtype

dtype('float32')

In [21]:
# Scale applied to the raw embedding of a token; this is what a token keeps
# when no expert overwrites it.
b = 0.1

# Shape (word_size, emb_size) = (7, 5). Start from b * words so that a word which
# doesn't go through any expert still has some value.
total_output = b * words

# zip() silently stops at the shortest input, so verify the per-expert arrays line up
# and that the router has one column per token before looping.
assert len(experts_w) == len(experts_b) == len(router), "need one weight matrix, bias row and mask per expert"
assert router.shape[1] == words.shape[0], "router needs exactly one column per token"

# Repeat for every expert, looking at its router row to determine which words go to it.
for n, (expert_w, expert_b, mask) in enumerate(zip(experts_w, experts_b, router)):
    print(f"Expert {n}:")

    # Indices of the tokens whose mask entry for this expert is non-zero.
    # Computed once and reused for both the gather and the scatter below.
    token_idx = mask.nonzero()
    expert_input = words[token_idx]

    # Feed-forward step: (tokens, emb) @ (emb, emb), plus the (1, emb) bias row
    # broadcast across the selected tokens.
    expert_output = expert_input @ expert_w + expert_b

    # Assignment, not addition: the b * words default is replaced for routed tokens.
    # If a token were routed to several experts, the last expert would win.
    total_output[token_idx] = expert_output

    print(f"  {expert_input.shape} * {expert_w.shape} + {expert_b.shape} = {expert_output.shape}\n")

    # Keep each expert's input/output pair so it can be checked individually.
    exported[f"expert_{n}_src"] = expert_input
    exported[f"expert_{n}_dst"] = expert_output

print(f"Total output: {total_output.shape}")
exported["dst"] = total_output

Expert 0:
  (2, 5) * (5, 5) + (1, 5) = (2, 5)

Expert 1:
  (2, 5) * (5, 5) + (1, 5) = (2, 5)

Expert 2:
  (2, 5) * (5, 5) + (1, 5) = (2, 5)

Total output: (7, 5)


In [26]:
# Write every collected array into a single data.npz archive; each key of `exported`
# becomes the name of one array inside the file.
np.savez("data.npz", **exported)

In [28]:
# Display the first expert's (emb_size, emb_size) weight matrix for a visual check.
experts_w[0]

array([[-0.51709443, -0.9970268 ,  0.24879916, -0.29664114,  0.49521133],
       [-0.17470317,  0.98633516,  0.21353391,  2.1906998 , -1.8963609 ],
       [-0.6469167 ,  0.9014869 ,  2.5283258 , -0.24863477,  0.04366899],
       [-0.22631425,  1.3314571 , -0.28730786,  0.68006986, -0.3198016 ],
       [-1.2725588 ,  0.31354773,  0.5031848 ,  1.2932259 , -0.11044703]],
      dtype=float32)