In [25]:

#for displaying figures in code editor
#%matplotlib inline
import matplotlib.pyplot as plt
from time import perf_counter
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn
import random




In [26]:

class FeedForward(nn.Module):
    """Feed-forward block: Linear -> ReLU -> Linear -> Dropout.

    The hidden layer is four times as wide as the embedding, and the output
    is projected back to the embedding size.

    Args:
        n_embed: Size of the last dimension of both the input and the output.
        dropout: Probability passed to the trailing ``nn.Dropout``. Defaults
            to 0.2, the value that used to be hard-coded.
    """

    def __init__(self, n_embed, dropout=0.2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embed, 4 * n_embed),  # expand to the 4x hidden width
            nn.ReLU(),
            nn.Linear(4 * n_embed, n_embed),  # project back to n_embed
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """Run ``x`` through the stack; the last dimension of ``x`` must be ``n_embed``."""
        return self.net(x)
        
class MoeLayer(nn.Module):
    """Mixture-of-experts layer with top-k routing.

    Every row of the flattened input is scored by ``gate``; the ``k``
    highest-scoring experts are applied to that row and their outputs are
    combined with softmax weights computed over the selected scores only.

    Args:
        experts: Non-empty sequence of modules. ``experts[i]`` is used for
            the rows whose selected gate index equals ``i``.
        gate: Module mapping the flattened input ``(N, D)`` to routing
            logits; column ``i`` of its output is the score of ``experts[i]``.
        k: Number of experts selected per row; must satisfy
            ``1 <= k <= len(experts)``.

    Raises:
        ValueError: If ``k`` is outside ``1..len(experts)``.
    """

    def __init__(self, experts, gate, k):
        super().__init__()
        assert len(experts) > 0
        # Fail at construction time instead of with an opaque torch.topk
        # error on the first forward pass.
        if not 1 <= k <= len(experts):
            raise ValueError(
                f"k must be between 1 and the number of experts "
                f"({len(experts)}), got {k}"
            )
        self.experts = nn.ModuleList(experts)
        self.gate = gate
        self.k = k

    def forward(self, inputs: torch.Tensor):
        """Route ``inputs`` through the selected experts.

        Args:
            inputs: Tensor whose last dimension is the feature size; all
                leading dimensions are flattened into one for routing.

        Returns:
            Tensor with the same shape as ``inputs``.
        """
        # reshape (not view) so non-contiguous inputs, e.g. the result of a
        # transpose, are accepted instead of raising a RuntimeError.
        inputs_squashed = inputs.reshape(-1, inputs.shape[-1])
        gate_logits = self.gate(inputs_squashed)
        weights, selected_experts = torch.topk(gate_logits, self.k)
        # Normalise over the k selected logits only; the softmax is computed
        # in float32 and then cast back to the input dtype.
        weights = nn.functional.softmax(
            weights,
            dim=1,
            dtype=torch.float,
        ).type_as(inputs)
        results = torch.zeros_like(inputs_squashed)
        for i, expert in enumerate(self.experts):
            # Rows routed to expert i and the top-k slot in which it was picked.
            # topk yields distinct indices per row, so a row appears at most
            # once here and the indexed accumulation below cannot collide.
            batch_idx, nth_expert = torch.where(selected_experts == i)
            results[batch_idx] += weights[batch_idx, nth_expert, None] * expert(
                inputs_squashed[batch_idx]
            )
        return results.reshape(inputs.shape)
     

In [27]:
# Configuration for the mixture-of-experts example.
num_experts = 4  # number of expert networks to build
top_k = 2        # intended number of experts selected per token
n_embd = 16      # embedding size (last dimension of the example tensor)
dropout = 0.1    # dropout probability setting
SEED = 0         # fixed seed so Restart & Run All reproduces the same tensors

# Seed before any random tensor or layer is created.
torch.manual_seed(SEED)

mh_output = torch.randn(4, 8, n_embd)  # Example multi-head attention output

In [28]:
# Use the configured `top_k` rather than a hard-coded k, so the routing
# width follows the configuration cell.
moe = MoeLayer(
    experts=[FeedForward(n_embd) for _ in range(num_experts)],
    gate=nn.Linear(n_embd, num_experts, bias=False),  # one logit per expert
    k=top_k,
)

In [30]:
moe_output = moe(mh_output)
# The layer must return a tensor of the same shape as its input.
assert moe_output.shape == mh_output.shape
moe_output.shape

torch.Size([4, 8, 16])