In [1]:
#importing libraries
import numpy as np 
import torch 
import torch.nn as nn
import torch.nn.functional as F 

### Gating Mechanism

In [2]:
class Gate(nn.Module):
    def __init__(self, input_dim, num_experts, k=1):
        super().__init__()
        self.k = k
        self.gate = nn.Linear(input_dim, num_experts)

    def forward(self, X):
        logits = self.gate(X)
        print(f"Gate Logits: {logits}")
        k_vals, k_indices = torch.topk(logits, self.k, dim = -1)
        print(f"Top k values are: {k_vals} and their indices are {k_indices}")
        scores = F.softmax(k_vals, dim = -1)
        return k_indices, scores

### Experts

In [3]:
class Experts(nn.Module):
    def __init__(self, num_experts, input_dim, hidden_dim):
        super().__init__()
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(input_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, input_dim)
            ) for _ in range(num_experts)
        ])

    def forward(self, X, indices):
        batch_size, k = indices.shape
        out = torch.zeros(batch_size, X.shape[1], device = X.device)
        for i in range(k):
            expert_idx = indices[:,i]
            print(f"Expert idx: {expert_idx}")
            for e in range(len(self.experts)):
                mask = (expert_idx == e)
                print("Mask: ",mask)
                if mask.sum() > 0:
                    x_selected = X[mask]
                    print(f" for expert {e} x selected is {x_selected}")
                    
                    out[mask] += self.experts[e](x_selected)
                    
                    

        return out


In [4]:
class MOE_layer(nn.Module):
    def __init__(self, input_dim,hidden_dim, num_experts=4, k=2):
        super().__init__()
        self.gate = Gate(input_dim, num_experts,k )
        self.experts = Experts(num_experts, input_dim, hidden_dim)
    
    def forward(self, x):
        topk_indices, topkscores = self.gate(x)
        expert_output = self.experts(x, topk_indices)
        return expert_output

In [5]:
input_dim = 12
hidden_dim = 24
batch_size = 4

x = torch.randn(batch_size, input_dim)
print(f"Input shape: {x.shape}")
print(f"X: {x}")
moe = MOE_layer(input_dim, hidden_dim, num_experts=4, k=2)

out = moe(x)
print("MoE Output:", out.shape)  

Input shape: torch.Size([4, 12])
X: tensor([[-0.3021,  0.7066, -1.4496, -1.4628, -0.4010,  0.1746,  1.2875, -0.3984,
          1.9735, -1.3014,  0.6499,  1.1704],
        [ 0.2978, -0.9248,  0.4805, -0.1547,  0.7873, -0.1263, -1.5822, -0.3890,
         -1.5515,  0.4783,  0.7481,  1.4117],
        [ 1.2824,  1.4317, -0.1390, -0.3534,  0.7705,  2.5527,  0.8637, -0.8920,
          0.5932,  0.8116,  0.0730,  0.5651],
        [ 0.9061,  1.9133,  0.5607, -2.0725,  0.1010, -0.3629, -0.0736,  1.9410,
          0.8914,  0.9148,  0.8736,  0.5167]])
Gate Logits: tensor([[ 0.3844,  0.6600,  0.9456, -0.5175],
        [ 0.5285,  0.9110,  0.1191, -0.8722],
        [ 0.4298, -0.2929,  0.8411,  0.4024],
        [ 0.2496, -0.3347,  1.3059, -0.0569]], grad_fn=<AddmmBackward0>)
Top k values are: tensor([[0.9456, 0.6600],
        [0.9110, 0.5285],
        [0.8411, 0.4298],
        [1.3059, 0.2496]], grad_fn=<TopkBackward0>) and their indices are tensor([[2, 1],
        [1, 0],
        [2, 0],
        [2, 0

In [6]:
out

tensor([[ 0.0273,  0.2877, -0.5854,  0.0574,  0.1356,  0.5124, -0.3650, -0.1347,
         -0.1133,  0.2097, -0.3669,  0.3900],
        [ 0.1926, -0.0400,  0.1400,  0.5526, -0.1984, -0.1806, -0.1975,  0.2709,
         -0.1425, -0.1440,  0.6351,  0.0976],
        [ 0.6171,  0.6376, -0.2093, -0.1238, -0.1605,  0.8789, -0.7390, -0.3176,
         -0.1437,  0.5077,  0.3651,  0.4078],
        [ 0.8661,  0.1895, -0.1989,  0.0059,  0.2289,  0.7442, -0.0927, -0.0368,
          0.2981,  0.4837,  0.4913,  0.7385]], grad_fn=<IndexPutBackward0>)