<a href="https://colab.research.google.com/github/lizhieffe/llm_knowledge/blob/main/examples/pytorch_dist/%5BDist%5D_PyTorch_EP_Practice_(MoE).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [5]:
import os

import torch
import torch.nn as nn

import torch.distributed as dist
import torch.multiprocessing as mp

# MLP

Each MLP instance is an expert in MoE.

In [11]:
class MLP(nn.Module):
  """Two-layer feed-forward block (ED -> 4*ED -> ED) with a ReLU in between.

  Each MLP instance is used as one expert of an MoE layer. There are no bias
  terms; both weight matrices are drawn from a standard normal distribution.
  """

  def __init__(self, emb_dim: int):
    """Create the two weight matrices.

    Args:
      emb_dim: Size ED of the last dimension of the input and of the output.
    """
    super().__init__()
    # Wrapped in nn.Parameter so that the module registers the weights: they
    # then appear in parameters()/state_dict() and follow .to()/.cuda().
    # The random draws (and their order) are the same as for plain tensors.
    self.l1 = nn.Parameter(torch.randn((emb_dim, emb_dim * 4), dtype=torch.float32))  # [ED, 4*ED]
    self.relu = nn.ReLU()
    self.l2 = nn.Parameter(torch.randn((emb_dim * 4, emb_dim), dtype=torch.float32))  # [4*ED, ED]

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Apply the MLP to the last dimension of `x`.

    Args:
      x: Tensor whose last dimension has size ED, e.g. [BS, N, ED] or [T, ED].

    Returns:
      Tensor with the same shape as `x`.
    """
    y = x @ self.l1     # [..., 4*ED]
    y = self.relu(y)    # [..., 4*ED]
    y = y @ self.l2     # [..., ED]
    return y            # [..., ED]

# Smoke test: an MLP maps a [BS, N, ED] input to an output of the same shape.
BS, N, emb_dim = 4, 16, 32

input = torch.randn(BS, N, emb_dim, dtype=torch.float32)
mlp = MLP(emb_dim=emb_dim)
output = mlp(input)
assert input.shape == output.shape

# Gate

In [12]:
class Gate(nn.Module):
  """Linear router that turns each token into a distribution over experts."""

  def __init__(self, emb_dim: int, n_experts: int):
    """Create the routing matrix, drawn from a standard normal distribution.

    Args:
      emb_dim: Size ED of the last dimension of the input.
      n_experts: Number NE of experts to route between.
    """
    super().__init__()
    # Wrapped in nn.Parameter so that the module registers the weight: it then
    # appears in parameters()/state_dict() and follows .to()/.cuda().
    self.l_gate = nn.Parameter(torch.randn((emb_dim, n_experts), dtype=torch.float32))  # [ED, NE]

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Compute per-token routing probabilities.

    Args:
      x: Tensor whose last dimension has size ED, e.g. [BS, N, ED] or [T, ED].

    Returns:
      Tensor shaped like `x` with the last dimension replaced by NE; every
      slice along the last dimension is a softmax output and sums to 1.
    """
    logit_gate = x @ self.l_gate            # [..., NE]
    prob_gate = logit_gate.softmax(dim=-1)  # [..., NE]
    return prob_gate                        # [..., NE]

# Smoke test: the gate yields one probability row (over experts) per token.
BS, N = 4, 16
emb_dim, n_experts = 32, 4

input = torch.randn(BS, N, emb_dim, dtype=torch.float32)
gate = Gate(emb_dim=emb_dim, n_experts=n_experts)
output = gate(input)
assert tuple(output.shape) == (BS, N, n_experts)
# Each row is a softmax output, so it has to sum to one.
assert torch.allclose(output.sum(dim=-1), torch.ones(BS, N))

# SwitchMoE

References
- https://1a3orn.com/sub/essays-intro-to-moe.html
- https://colab.research.google.com/github/1a3orn/very-simple-moe/blob/main/Switch_MoE.ipynb?authuser=1
- https://colab.research.google.com/drive/1ZgNfgg91JYcWEGDavYpdV59tg-krd060?authuser=1#scrollTo=v_aNH9z78Gox



In [13]:
class SwitchMoE(nn.Module):
  """Single-process Switch (top-1) mixture of experts.

  Every token is routed to exactly one expert: the one with the highest gate
  probability. The output for a token is the raw output of that expert (it is
  not scaled by the gate probability).
  """

  def __init__(self, emb_dim: int, n_experts: int):
    """Create `n_experts` MLP experts followed by the gate.

    Args:
      emb_dim: Size ED of the last dimension of the input and output.
      n_experts: Number NE of experts.
    """
    super().__init__()
    self.emb_dim = emb_dim
    self.n_experts = n_experts

    self.experts = nn.ModuleList(modules=[MLP(emb_dim=emb_dim) for _ in range(n_experts)])
    assert len(self.experts) == n_experts

    self.gate = Gate(emb_dim=emb_dim, n_experts=n_experts)

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Route each token to its top-1 expert and gather the expert outputs.

    Args:
      x: Tensor whose last dimension has size ED, typically [BS, N, ED].

    Returns:
      Tensor with the same shape as `x`.
    """
    # Shapes come from `x` itself rather than from module-level BS / N, so the
    # layer works for any batch size and sequence length.
    leading_shape = x.shape[:-1]
    tokens = x.reshape(-1, self.emb_dim)                        # [T, ED], T = BS*N
    n_tokens = tokens.shape[0]
    output = torch.zeros_like(tokens)                           # [T, ED]

    gate_prob = self.gate(tokens)                               # [T, NE]
    assert gate_prob.shape == (n_tokens, self.n_experts), f"{gate_prob.shape=}"

    gate_top1_idx = gate_prob.argmax(dim=-1)                    # [T]

    per_expert_tokens = []

    for i, expert in enumerate(self.experts):
      # Boolean mask of the tokens routed to the current expert.
      mask = gate_top1_idx == i         # [T]

      # Keep only the tokens for the current expert. The selection may be
      # empty; the expert then returns an empty tensor and nothing is written.
      expert_input = tokens[mask]       # [N_EXPERT_INPUT, ED]
      per_expert_tokens.append(len(expert_input))

      expert_output = expert(expert_input)
      assert expert_input.shape == expert_output.shape

      # Merge the expert output into the final output.
      output[mask] = expert_output

    # Top-1 routing: every token is handled by exactly one expert.
    assert sum(per_expert_tokens) == n_tokens

    return output.reshape(*leading_shape, self.emb_dim)

# Test: the MoE output has the same shape as its input.
BS, N = 4, 16
emb_dim, n_experts = 32, 4

torch.manual_seed(123)
input = torch.randn(BS, N, emb_dim, dtype=torch.float32)
moe = SwitchMoE(emb_dim=emb_dim, n_experts=n_experts)
output = moe(input)
assert input.shape == output.shape

# With a single expert every token goes to that expert, so the MoE has to
# reproduce a plain MLP built from the same seed.
input = torch.randn(BS, N, emb_dim, dtype=torch.float32)

torch.manual_seed(123)
moe = SwitchMoE(emb_dim=emb_dim, n_experts=1)
output = moe(input)

print(f"{input.shape=}")
torch.manual_seed(123)
mlp = MLP(emb_dim=emb_dim)
expected = mlp(input)

assert torch.allclose(output, expected)

input.shape=torch.Size([4, 16, 32])


# EP

TODO: the current implementation assumes that the number of replicas equals the number of experts, and that each replica gets exactly one expert.

In [23]:
class DistSwitchMoE(nn.Module):
  """Expert-parallel Switch (top-1) MoE with one expert per rank.

  Rank 0 owns the gate. It routes every token, sends each other rank the
  tokens for that rank's expert, runs its own expert, and merges all expert
  outputs. Every other rank receives its tokens from rank 0, runs its expert,
  and sends the result back.

  All ranks must call `forward` together; the point-to-point messages are
  exchanged in a fixed order (sizes, then inputs, then outputs).
  """

  def __init__(self, emb_dim: int, n_experts: int, rank: int, world_size: int):
    """Create this rank's expert and, on rank 0, the gate.

    Args:
      emb_dim: Size ED of the last dimension of the input and output.
      n_experts: Number NE of experts; must equal `world_size`.
      rank: Rank of the current process in the process group.
      world_size: Number of processes in the process group.
    """
    super().__init__()

    # TODO: remove this limitation.
    assert n_experts == world_size

    self.emb_dim = emb_dim
    self.n_experts = n_experts
    self.rank = rank
    self.world_size = world_size

    self.expert = MLP(emb_dim=emb_dim)
    if rank == 0:
      self.gate = Gate(emb_dim=emb_dim, n_experts=n_experts)

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Run one expert-parallel MoE step.

    Args:
      x: On rank 0, a tensor whose last dimension has size ED (typically
        [BS, N, ED]). The receive buffers are allocated as float32, so `x` is
        expected to be float32. On other ranks `x` is ignored.

    Returns:
      On rank 0, a tensor with the same shape as `x`. On other ranks, None.
    """
    if self.rank == 0:
      return self._forward_rank0(x)
    self._forward_worker()
    return None

  def _forward_rank0(self, x: torch.Tensor) -> torch.Tensor:
    """Route the tokens, scatter them to the experts and gather the outputs."""
    # 1. Prepare the per expert input. Shapes come from `x` itself rather than
    # from module-level BS / N.
    leading_shape = x.shape[:-1]
    tokens = x.reshape(-1, self.emb_dim)                        # [T, ED], T = BS*N
    n_tokens = tokens.shape[0]

    gate_prob = self.gate(tokens)                               # [T, NE]
    assert gate_prob.shape == (n_tokens, self.n_experts), f"{gate_prob.shape=}"
    gate_top1_idx = gate_prob.argmax(dim=-1)                    # [T]

    # masks[i] selects the tokens routed to the expert living on rank i.
    masks = [gate_top1_idx == i for i in range(self.world_size)]
    per_expert_input = [tokens[mask] for mask in masks]         # each [N_EXPERT_INPUT, ED]
    assert sum(len(expert_input) for expert_input in per_expert_input) == n_tokens

    remote_ranks = range(1, self.world_size)

    # 2. Communicate the expert input size. The message has shape [1] to match
    # the receive buffer on the worker side. The list keeps the tensors alive
    # until the sends have completed.
    len_msgs = [
        torch.tensor([len(per_expert_input[i])], dtype=torch.int32)
        for i in remote_ranks
    ]
    len_reqs = [dist.isend(msg, dst=i) for i, msg in zip(remote_ranks, len_msgs)]
    for req in len_reqs:
      req.wait()

    # 3. Distribute the expert inputs.
    input_reqs = [dist.isend(per_expert_input[i], dst=i) for i in remote_ranks]
    for req in input_reqs:
      req.wait()

    # 4. Computation of the local expert.
    expert_output = self.expert(per_expert_input[0])
    assert per_expert_input[0].shape == expert_output.shape

    # 5. Collect the computation results. Expert 0 may have received no
    # tokens; the assignment is then a no-op.
    output = torch.zeros_like(tokens)                           # [T, ED]
    output[masks[0]] = expert_output
    for i in remote_ranks:
      remote_expert_output = torch.zeros_like(per_expert_input[i], dtype=torch.float32)
      dist.recv(remote_expert_output, src=i)
      output[masks[i]] = remote_expert_output

    return output.reshape(*leading_shape, self.emb_dim)

  def _forward_worker(self) -> None:
    """Receive this rank's tokens from rank 0, run the expert, send the result back."""
    # 2. Receive the number of tokens routed to this rank's expert.
    expert_input_len = torch.zeros(1, dtype=torch.int32)
    dist.recv(expert_input_len, src=0)

    # 3. Receive the tokens themselves.
    expert_input = torch.zeros((int(expert_input_len.item()), self.emb_dim), dtype=torch.float32)
    dist.recv(expert_input, src=0)

    # 4. Computation
    expert_output = self.expert(expert_input)
    assert expert_input.shape == expert_output.shape

    # 5. Send the result back. Only the values are sent; detach() makes sure a
    # tensor attached to an autograd graph is never handed to the send call.
    dist.send(expert_output.detach(), dst=0)

# Test configuration: one process (rank) per expert.
BS, N = 4, 16
emb_dim, n_experts = 32, 4
world_size = n_experts

# Address and port that the processes use to find each other.
os.environ.update({
    'MASTER_ADDR': 'localhost',
    'MASTER_PORT': '12359',  # Pick a different port if this one is already in use.
})

torch.manual_seed(123)
input = torch.randn(BS, N, emb_dim, dtype=torch.float32)

def init_process_ep(rank, world_size):
  """Entry point of one expert-parallel worker process.

  Joins the gloo process group as `rank`, runs the module-level `input`
  through a DistSwitchMoE and, on rank 0, asserts that the result matches a
  single-process SwitchMoE holding the same weights.

  Args:
    rank: Rank of this process in the process group.
    world_size: Total number of processes in the process group.
  """
  print(f"Starting process with {rank=}, {world_size=}")

  # Use the gloo backend for CPU-based distributed processing
  dist.init_process_group(backend="gloo", world_size=world_size, rank=rank)
  try:
    assert rank == dist.get_rank()
    assert world_size == dist.get_world_size()
    dist.barrier()

    # Reference model, built identically on every rank.
    torch.manual_seed(123)
    reference_moe = SwitchMoE(emb_dim=emb_dim, n_experts=n_experts)

    torch.manual_seed(123)
    dist_moe = DistSwitchMoE(emb_dim=emb_dim, n_experts=n_experts, rank=rank, world_size=world_size)

    # Seeding alone cannot make the two models agree: SwitchMoE draws all
    # experts and then the gate, while DistSwitchMoE draws a single expert and
    # then (on rank 0) the gate. Share the reference weights instead, so that
    # rank i holds expert i and rank 0 holds the same gate.
    dist_moe.expert = reference_moe.experts[rank]
    if rank == 0:
      dist_moe.gate = reference_moe.gate

    output = dist_moe(input)

    if rank == 0:
      assert output.shape == input.shape
      print(f"{output.shape=}")

      # Verification: the distributed result must match the local one.
      expected = reference_moe(input)
      assert torch.allclose(expected, output)
  finally:
    # Release the process group even when a check above fails.
    dist.destroy_process_group()

# Launch one worker process per rank and wait for all of them to finish.
processes = []
for rank in range(world_size):
  p = mp.Process(target=init_process_ep, args=(rank, world_size))
  p.start()
  processes.append(p)

for p in processes:
  p.join()

# A worker that raised (e.g. a failed assert) exits with a non-zero exit code.
# join() alone would hide that, so surface it here.
failed = {p.name: p.exitcode for p in processes if p.exitcode != 0}
if failed:
  raise RuntimeError(f"EP worker process(es) failed: {failed}")

Starting process with rank=0, world_size=4
Starting process with rank=1, world_size=4
Starting process with rank=2, world_size=4Starting process with rank=3, world_size=4

output.shape=torch.Size([4, 16, 32])


In [15]:
# @title Test tensor idx

# A plain Python list of bools can be used as a boolean mask to index a tensor.
idx = [True, False, True]
x = torch.tensor([1, 2, 3])

print(f"{idx=}")
x[idx]

idx=[True, False, True]


tensor([1, 3])