<a href="https://colab.research.google.com/github/lizhieffe/llm_knowledge/blob/main/examples/pytorch_dist/%5BDist%5D_PyTorch_EP_Practice_(MoE).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [4]:
import torch
import torch.nn as nn

# MLP

Each MLP instance is an expert in MoE.

In [35]:
class MLP(nn.Module):
  """Two-layer feed-forward block (ED -> 4*ED -> ED) with a ReLU in between.

  Each instance is used as one expert of the MoE layer.
  """

  def __init__(self, emb_dim: int):
    """Create the two randomly initialised weight matrices.

    Args:
      emb_dim: Size of the embedding (last) dimension of the input.
    """
    super().__init__()
    # Wrap the weights in nn.Parameter so the module registers them: only then
    # do they show up in parameters()/state_dict(), follow .to(device), and
    # receive gradients. The initial values are unchanged.
    self.l1 = nn.Parameter(
        torch.randn((emb_dim, emb_dim * 4), dtype=torch.float32))  # [ED, 4*ED]
    self.relu = nn.ReLU()
    self.l2 = nn.Parameter(
        torch.randn((emb_dim * 4, emb_dim), dtype=torch.float32))  # [4*ED, ED]

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Apply the MLP along the last dimension of `x`.

    Args:
      x: Tensor of shape [..., ED], e.g. [BS, N, ED] or flattened [BS*N, ED].

    Returns:
      Tensor with the same shape as `x`.
    """
    # Deliberately no `BS, N, ED = x.shape` unpacking: the values were unused
    # and the unpack rejected the 2-D [tokens, ED] input produced when tokens
    # are routed per expert.
    y = x @ self.l1     # [..., 4*ED]
    y = self.relu(y)    # [..., 4*ED]
    y = y @ self.l2     # [..., ED]
    return y            # [..., ED]

# Smoke test: the MLP must return a tensor shaped exactly like its input.
BS, N, emb_dim = 4, 16, 32

input = torch.randn(BS, N, emb_dim, dtype=torch.float32)  # [BS, N, ED]
mlp = MLP(emb_dim=emb_dim)
output = mlp(input)                                       # [BS, N, ED]
assert output.shape == input.shape

# Gate

In [82]:
class Gate(nn.Module):
  """Linear router that maps each token to a probability over the experts."""

  def __init__(self, emb_dim: int, n_experts: int):
    """Create the randomly initialised routing matrix.

    Args:
      emb_dim: Size of the embedding (last) dimension of the input.
      n_experts: Number of experts to produce probabilities for.
    """
    super().__init__()
    # Wrap the weight in nn.Parameter so the module registers it: only then
    # does it show up in parameters()/state_dict(), follow .to(device), and
    # receive gradients. The initial values are unchanged.
    self.l_gate = nn.Parameter(
        torch.randn((emb_dim, n_experts), dtype=torch.float32))  # [ED, NE]

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Return per-token expert probabilities.

    Args:
      x: Tensor of shape [..., ED], e.g. [BS, N, ED] or flattened [BS*N, ED].

    Returns:
      Tensor of shape [..., NE]; the softmax makes each row sum to 1.
    """
    logit_gate = x @ self.l_gate            # [..., NE]
    prob_gate = logit_gate.softmax(dim=-1)  # [..., NE]
    return prob_gate                        # [..., NE]

# Smoke test: the gate must emit one probability distribution per token.
BS, N = 4, 16
emb_dim, n_experts = 32, 4

input = torch.randn(BS, N, emb_dim, dtype=torch.float32)  # [BS, N, ED]
gate = Gate(emb_dim=emb_dim, n_experts=n_experts)
output = gate(input)                                      # [BS, N, NE]
assert output.shape == (BS, N, n_experts)

# Every token's expert probabilities sum to one.
per_token_total = output.sum(dim=-1)                      # [BS, N]
assert torch.allclose(per_token_total, torch.ones(BS, N))

# SwitchMoE

References
- https://1a3orn.com/sub/essays-intro-to-moe.html
- https://colab.research.google.com/github/1a3orn/very-simple-moe/blob/main/Switch_MoE.ipynb?authuser=1
- https://colab.research.google.com/drive/1ZgNfgg91JYcWEGDavYpdV59tg-krd060?authuser=1#scrollTo=v_aNH9z78Gox



In [144]:
# Problem sizes: batch, tokens per sequence, embedding width, expert count.
BS, N = 4, 16
emb_dim, n_experts = 32, 4

torch.manual_seed(123)

# One MLP per expert.
experts = nn.ModuleList(modules=[MLP(emb_dim=emb_dim) for _ in range(n_experts)])
assert len(experts) == n_experts

# Flatten the batch so that routing works on a plain list of tokens.
input = torch.randn((BS, N, emb_dim), dtype=torch.float32).reshape(-1, emb_dim)  # [BS*N, ED]
output = torch.zeros_like(input)                                                 # [BS*N, ED]

# Router probabilities for every token.
gate = Gate(emb_dim=emb_dim, n_experts=n_experts)
gate_prob = gate(input)                                     # [BS*N, NE]
assert gate_prob.shape == (BS * N, n_experts)

# Top-1 routing: each token goes to its most probable expert.
# NOTE(review): gate_prob is only used for the argmax here; the expert output
# is written to `output` unscaled.
gate_top1_idx = gate_prob.argmax(dim=-1)                    # [BS*N]
gate_top1_onehot = nn.functional.one_hot(gate_top1_idx, num_classes=n_experts)  # [BS*N, NE]
assert gate_top1_onehot.shape == (BS * N, n_experts)

per_expert_tokens = []

for expert_idx, expert in enumerate(experts):
  # Boolean mask selecting the tokens routed to this expert.
  mask = gate_top1_onehot[:, expert_idx].bool()             # [BS*N]
  assert mask.shape == (BS * N,)

  # Gather this expert's tokens and record how many it received.
  expert_input = input[mask]                                # [N_EXPERT_INPUT, ED]
  per_expert_tokens.append(expert_input.shape[0])

  # Run the expert; it must preserve the token shape.
  expert_output = expert(expert_input)
  assert expert_input.shape == expert_output.shape

  # Scatter the expert's results back into their original token slots.
  output[mask] = expert_output

  # After the first expert only its own rows may be non-zero.
  if expert_idx == 0:
    not_routed = ~mask
    written = output[mask]
    untouched = output[not_routed]
    assert not torch.allclose(written, torch.zeros_like(expert_output))
    assert untouched.shape[0] + written.shape[0] == BS * N
    assert torch.allclose(untouched, torch.zeros_like(untouched))

# Every token was handled by exactly one expert.
assert sum(per_expert_tokens) == BS * N

# Restore the [BS, N, ED] layout.
output = output.reshape(BS, N, -1)
assert output.shape == (BS, N, emb_dim)

In [145]:
# @title Test tensor idx

# Indexing a tensor with a list of Python booleans keeps the True positions.
idx = [True, False, True]
x = torch.tensor([1, 2, 3])

print(f"{idx=}")
x[idx]

idx=[True, False, True]


tensor([1, 3])