In [1]:
import torch
import torch.nn as nn

In [2]:
from deps.other_components import SiLU

In [3]:
class FeedForward_MoE(nn.Module):
    """Sparse Mixture-of-Experts feed-forward layer with SwiGLU experts.

    A linear router (`gate`) produces one logit per expert for each token.
    The top `cfg['n_experts_per_tok']` experts are kept, their logits are
    softmax-normalized among themselves, and the token's output is the
    probability-weighted sum of those experts' SwiGLU outputs.

    Config keys read: 'emb_dim', 'n_experts', 'n_experts_per_tok',
    'moe_intermediate_size', 'dtype'.
    """

    def __init__(self, cfg):
        super().__init__()
        self.num_experts_per_tok = cfg['n_experts_per_tok']
        self.num_experts = cfg['n_experts']

        # router: emb_dim -> one logit per expert
        self.gate = nn.Linear(cfg['emb_dim'], cfg['n_experts'], bias=False, dtype=cfg['dtype'])

        # Expert weights are created on the meta device (no storage) to reduce
        # memory when loading weights. They must be replaced with real tensors
        # before forward() can run on real inputs — presumably done by the
        # weight-loading code elsewhere (TODO confirm).
        meta_device = torch.device('meta')
        # SwiGLU gate projection per expert: emb_dim -> moe_intermediate_size
        self.fc1 = nn.ModuleList([
            nn.Linear(cfg['emb_dim'], cfg['moe_intermediate_size'], bias=False, dtype=cfg['dtype'],
                    device=meta_device)
            for _ in range(cfg['n_experts'])
        ])
        # SwiGLU up projection per expert: emb_dim -> moe_intermediate_size
        self.fc2 = nn.ModuleList([
            nn.Linear(cfg['emb_dim'], cfg['moe_intermediate_size'], bias=False, dtype=cfg['dtype'],
                    device=meta_device)
            for _ in range(cfg['n_experts'])
        ])
        # down projection per expert: moe_intermediate_size -> emb_dim
        self.fc3 = nn.ModuleList([
            nn.Linear(cfg['moe_intermediate_size'], cfg['emb_dim'], bias=False, dtype=cfg['dtype'],
                    device=meta_device)
            for _ in range(cfg['n_experts'])
        ])
        self.silu = SiLU()

    def forward(self, x):
        """Route each token to its top-k experts and mix their outputs.

        Args:
            x: tensor whose last dimension is emb_dim, e.g. (batch, seq_len, emb_dim).
               Any number of leading dimensions is accepted; they are treated
               as independent tokens.

        Returns:
            Tensor with the same shape as `x`.
        """
        k = self.num_experts_per_tok
        emb_dim = x.shape[-1]

        # (..., emb_dim) -> (..., n_experts)
        scores = self.gate(x)
        # (..., n_experts) -> (..., k); topk returns k distinct experts per token
        topk_scores, topk_idxs = torch.topk(scores, k=k, dim=-1)
        # normalize only over the selected experts
        topk_probas = torch.softmax(topk_scores, dim=-1)

        # flatten all leading dims into a single token axis
        x_flat = x.reshape(-1, emb_dim)              # (n_tokens, emb_dim)
        idxs_flat = topk_idxs.reshape(-1, k)         # (n_tokens, k)
        probas_flat = topk_probas.reshape(-1, k)     # (n_tokens, k)
        y_flat = torch.zeros_like(x_flat)            # (n_tokens, emb_dim)

        # Visit each expert that is actually used exactly once and run it on all
        # of its tokens in one batched call, regardless of which top-k slot
        # selected it (a per-slot loop could invoke the same expert k times).
        for e in torch.unique(idxs_flat).tolist():
            # (n_tokens, k); at most one True per row because topk indices are distinct
            slot_mask = idxs_flat == e
            # (n_e,) ascending indices of tokens routed to expert e
            token_idxs = slot_mask.any(dim=-1).nonzero(as_tuple=True)[0]
            # (n_e, 1) routing weight of expert e for each of those tokens;
            # boolean indexing is row-major, so its order matches token_idxs
            weights = probas_flat[slot_mask].unsqueeze(-1)

            # (n_e, emb_dim)
            selected = x_flat[token_idxs]
            # SwiGLU: (n_e, emb_dim) -> (n_e, moe_hidden)
            hidden = self.silu(self.fc1[e](selected)) * self.fc2[e](selected)
            # (n_e, moe_hidden) -> (n_e, emb_dim)
            out = self.fc3[e](hidden)

            # accumulate the weighted expert output into each routed token's row
            y_flat.index_add_(0, token_idxs, (weights * out).to(y_flat.dtype))

        return y_flat.reshape(x.shape)

In [4]:
from deps.other_components import RMSNorm_Qwen
from deps.other_components import precompute_rope_params, compute_rope
from deps.other_components import GroupedQueryAttention_Qwen