In [2]:
import torch
import torch.nn.functional as F
import torch.nn as nn

In [3]:
x = torch.randn(2,8)

probs, indices = torch.topk(x, 2, dim=-1)
print(probs, indices)

F.softmax(probs, dim=-1)

tensor([[1.3119, 1.2256],
        [1.2041, 0.8192]]) tensor([[2, 5],
        [7, 1]])


tensor([[0.5215, 0.4785],
        [0.5950, 0.4050]])

In [4]:
print(indices)
mask = F.one_hot(indices, num_classes=8)
print(mask)
mask = mask.permute(2, 1, 0)
print(mask)

tensor([[2, 5],
        [7, 1]])
tensor([[[0, 0, 1, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 1, 0, 0]],

        [[0, 0, 0, 0, 0, 0, 0, 1],
         [0, 1, 0, 0, 0, 0, 0, 0]]])
tensor([[[0, 0],
         [0, 0]],

        [[0, 0],
         [0, 1]],

        [[1, 0],
         [0, 0]],

        [[0, 0],
         [0, 0]],

        [[0, 0],
         [0, 0]],

        [[0, 0],
         [1, 0]],

        [[0, 0],
         [0, 0]],

        [[0, 1],
         [0, 0]]])


In [5]:
print(mask[0])

tensor([[0, 0],
        [0, 0]])


In [6]:
idx,token_idx = torch.where(torch.Tensor([[1,0],[1,0]]))
print(idx, token_idx)

tensor([0, 1]) tensor([0, 0])


In [7]:
probs[token_idx, idx]

tensor([1.3119, 1.2256])

In [8]:
(torch.randn(2, 4, 5) * torch.randn(2, 4, 1)).shape

torch.Size([2, 4, 5])

In [9]:
class Expert(nn.Module):
    def __init__(self, f_in, f_out):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(f_in, f_out),
            nn.GELU(),
        )
    def forward(self, x):
        return self.net(x)

exp = Expert(10,20)
exp(torch.randn(2,10)).shape

torch.Size([2, 20])

In [10]:
class BasicMOE(nn.Module):
    def __init__(self, f_in, f_out, n_expert):
        super().__init__()
        self.nets = nn.ModuleList(
            [Expert(f_in, f_out) for _ in range(n_expert)]
        )
        # gate
        self.gate = nn.Linear(f_in, n_expert)

    def forward(self, x):
        weight = self.gate(x) # (B, n_expert)
        outputs = torch.cat([exp(x).unsqueeze(1) for exp in self.nets], dim=1) # (B, n_expert, f_out)
        x = weight.unsqueeze(1) @ outputs # (B, 1, f_out)
        return x.squeeze(1)


moe = BasicMOE(10,20,8)
moe(torch.randn(2,10)).shape

torch.Size([2, 20])

In [12]:
from dataclasses import dataclass
device = 'cuda' if torch.cuda.is_available() else 'cpu'

@dataclass
class MOEConfig:
    hidden_dim: int
    n_expert: int
    top_k: int
    n_share_expert: int = 2

class MOERouter(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.gate = nn.Linear(config.hidden_dim, config.n_expert)
        self.n_expert = config.n_expert
        self.top_k = config.top_k

    def forward(self, x):
        # gate logits
        router_logits = self.gate(x) # (B*ns, n_expert)

        # top k
        # weights (B*ns, top_k)
        weights, indices = torch.topk(router_logits, self.top_k, dim=-1)

        # norm
        weights = F.softmax(weights, dim=-1)

        # expert mask (B*ns, top_k, n_expert)
        expert_mask = F.one_hot(indices, num_classes=self.n_expert)
        # permute (n_expert, top_k, B*ns)
        expert_mask = expert_mask.permute(2, 1, 0)

        return router_logits, weights, indices, expert_mask
        
    
class SparseMOE(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.experts = nn.ModuleList(
            [Expert(config.hidden_dim, config.hidden_dim) for _ in range(config.n_expert)]
        )
        self.router = MOERouter(config)
        self.hidden_dim = config.hidden_dim
        self.n_expert = config.n_expert
        self.top_k = config.top_k

    def forward(self, x):
        B, ns, nh = x.size()
        # (B*ns, nh)
        hs = x.view(-1, nh)

        # router select
        router_logits, weights, indices, expert_mask = self.router(hs)

        print(router_logits.shape, weights.shape, expert_mask.shape)

        # 
        final_hs = torch.zeros((B*ns, nh), dtype=x.dtype).to(device)

        for idx in range(self.n_expert):
            expert_layer = self.experts[idx]
            # (n_expert, top_k, B)
            idx, token_idx = torch.where(expert_mask[idx])
            # (len(token_idx), nh)
            current_state = hs.unsqueeze(0)[:, token_idx, :].reshape(-1, nh)
            # current_hs * weights
            # weights (B*ns, top_k) -> (len(token_idx)*len(idx), 1)
            router_weights = weights[token_idx, idx].unsqueeze(-1)
            current_hs = expert_layer(current_state) 
            # (len(token_idx, nh) * (len(token_idx), 1)
            current_hs =  current_hs * router_weights
            # add 
            final_hs[token_idx]+=current_hs
        final_hs = final_hs.view(B, ns, nh)
        return final_hs, router_logits

x = torch.randn(2, 4, 8).to(device)
config = MOEConfig(8, 4, 2)
moe = SparseMOE(config).to(device)
hs, logits = moe(x)
hs
    

torch.Size([8, 4]) torch.Size([8, 2]) torch.Size([4, 2, 8])


tensor([[[ 2.0326e-02,  1.9251e-01, -6.2082e-02,  1.8105e-01,  4.0711e-04,
           1.7474e-01,  1.7426e-01, -1.6486e-01],
         [ 1.1799e-01, -5.3165e-02,  4.2430e-01, -8.8925e-02, -8.0590e-02,
           6.2828e-01,  2.1644e-01, -1.6243e-01],
         [ 1.5192e-01, -7.0823e-02, -8.2092e-02, -1.3237e-01,  2.4077e-01,
          -1.1497e-01,  5.0428e-03, -1.6479e-01],
         [-7.3801e-02,  1.8053e-01,  4.0537e-02, -3.7610e-02,  1.6313e-01,
           3.0721e-01,  7.9455e-01,  1.4426e-01]],

        [[ 4.1279e-01, -1.1055e-01,  2.3830e-01, -1.2703e-01,  8.6327e-03,
          -4.3897e-02,  1.5085e-01, -1.0080e-01],
         [-1.0584e-01,  6.0744e-01, -1.6124e-01, -6.6115e-02,  1.6061e-01,
           1.8657e-01, -1.6364e-01,  3.7994e-01],
         [-7.7603e-03,  6.1990e-01, -1.0862e-01,  1.3450e-01,  1.7599e-01,
           1.7986e-01,  5.4610e-01, -1.3224e-01],
         [ 5.1953e-02,  1.1769e+00, -4.3976e-02,  5.6574e-01,  7.1521e-01,
           6.2427e-02, -4.6838e-02,  4.3063e-01]

In [15]:
torch.stack([torch.randn(2,3), torch.randn(2,3)], dim=0).sum(dim=0).shape

torch.Size([2, 3])