In [1]:
import numpy as np
from numpy.random import randn
from itertools import count

# Pin NumPy's legacy global RNG so the randn() draws in this notebook are
# identical on every Restart & Run All.
SEED = 1
np.random.seed(SEED)

In [2]:
# Toy problem dimensions.
expert_size = 3  # how many experts there are
emb_size = 5     # width of a single token embedding
word_size = 7    # how many tokens there are (think: batches unrolled into one flat list)

# Random token embeddings, one row per token: (word_size, emb_size).
words_shape = (word_size, emb_size)
words = randn(*words_shape)

In [10]:
# One square weight matrix per expert: (expert_size, emb_size, emb_size) = (3, 5, 5).
experts_w_shape = (expert_size, emb_size, emb_size)
experts_w = randn(*experts_w_shape)

# One bias row per expert: (expert_size, 1, emb_size) = (3, 1, 5).
# The singleton middle axis means each expert's bias is a (1, emb_size) row that
# broadcasts over however many tokens it is added to.
experts_b_shape = (expert_size, 1, emb_size)
experts_b = randn(*experts_b_shape)

In [15]:
# Routing table, shape (experts, tokens) = (3, 7).
# A 1 at [e, t] marks token t as routed to expert e.
# Column 6 (the very last token) is all zeros: that token goes to no expert(!)
router = np.asarray(
    [
        [1, 1, 0, 0, 0, 0, 0],  # expert 0 <- tokens 0 and 1
        [0, 0, 1, 1, 0, 0, 0],  # expert 1 <- tokens 2 and 3
        [0, 0, 0, 0, 1, 1, 0],  # expert 2 <- tokens 4 and 5
    ],
    dtype=np.uint8,
)

In [16]:
b = 0.1  # scale applied to the raw embeddings used as the fallback output

# (word_size, emb_size) = (7, 5). Start from a scaled copy of the inputs so that a
# word which doesn't go through any expert still has some value.
# `b * words` allocates a new array, so `words` itself is never modified and this
# cell can be re-run safely.
total_output = b * words

# Repeat for every expert, looking at its row of `router` to determine which
# words go to said expert.
for n, (expert_w, expert_b, mask) in enumerate(zip(experts_w, experts_b, router)):
    print(f"Expert {n}:")

    # Row indices of all words whose mask entry for this expert is non-zero.
    # Computed once and reused for both the gather and the scatter below.
    routed = mask.nonzero()
    expert_input = words[routed]

    # Affine feed-forward step: (k, emb) @ (emb, emb) + (1, emb), where k is the
    # number of routed words; the bias row broadcasts across those k rows.
    expert_output = expert_input @ expert_w + expert_b

    # Overwrite (rather than accumulate into) the rows of the routed words. If a
    # word were routed to several experts, the last expert visited would win;
    # switch to `+=` if the contributions should be summed instead.
    total_output[routed] = expert_output

    print(f"  {expert_input.shape} * {expert_w.shape} + {expert_b.shape} = {expert_output.shape}\n")

print(f"Total output: {total_output.shape}")

Expert 0:
  (2, 5) * (5, 5) + (1, 5) = (2, 5)

Expert 1:
  (2, 5) * (5, 5) + (1, 5) = (2, 5)

Expert 2:
  (2, 5) * (5, 5) + (1, 5) = (2, 5)

Total output: (7, 5)


In [17]:
# Write the inputs (src), the computed result (dst), the expert parameters and
# the routing table into a single uncompressed .npz archive.
np.savez(
    "data.npz",
    src=words,
    dst=total_output,
    experts_w=experts_w,
    experts_b=experts_b,
    router=router,
)