In [1]:
import torch
import torch.nn as nn

In [2]:
from deps.other_components import SiLU

In [3]:
class FeedForward_MoE(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.num_experts_per_tok = cfg['n_experts_per_tok']
        self.num_experts = cfg['n_experts']

        self.gate = nn.Linear(cfg['emb_dim'], cfg['n_experts'], bias=False, dtype=cfg['dtype'])

        meta_device = torch.device('meta')  # to reduce memory when loading weights
        self.fc1 = nn.ModuleList([
            nn.Linear(cfg['emb_dim'], cfg['moe_intermediate_size'], bias=False, dtype=cfg['dtype'],
                    device=meta_device)
            for _ in range(cfg['n_experts'])
        ])
        self.fc2 = nn.ModuleList([
            nn.Linear(cfg['emb_dim'], cfg['moe_intermediate_size'], bias=False, dtype=cfg['dtype'],
                    device=meta_device)
            for _ in range(cfg['n_experts'])
        ])
        self.fc3 = nn.ModuleList([
            nn.Linear(cfg['moe_intermediate_size'], cfg['emb_dim'], bias=False, dtype=cfg['dtype'],
                    device=meta_device)
            for _ in range(cfg['n_experts'])
        ])
        self.silu = SiLU()
    
    def forward(self, x):
        # (batch, seq_len, emb_dim) -> (batch, seq_len, n_experts)
        scores = self.gate(x)
        topk_scores, topk_idxs = torch.topk(scores, k=self.num_experts_per_tok, dim=-1)
        topk_probas = torch.softmax(topk_scores, dim=-1)
        
        # (batch, seq_len, emb_dim)
        y = torch.zeros_like(x)

        for i in range(self.num_experts_per_tok):
            # work on ith entry in top-k

            # (batch, seq_len, n_experts) -> (batch, seq_len)
            expert_idxs = topk_idxs[..., i]
            # (batch, seq_len, n_experts) -> (batch, seq_len, 1)
            expert_proba = topk_probas[..., i].unsqueeze(-1)

            # each expert processes only assigned tokens
            for e in range(self.num_experts):
                # (batch, seq_len)
                mask = (expert_idxs == e)
                # check if any token in any batch is assigned to this expert
                if mask.any():
                    # (batch, seq_len, emb_dim) -> (n_tokens_e, emb_dim)
                    # first 2 dims are indexed by mask
                    selected = x[mask]
                    
                    # SwiGLU
                    # (n_tokens_e, emb_dim) -> (n_tokens_e, moe_hidden)
                    hidden = self.silu(self.fc1[e](selected)) * self.fc2[e](selected)
                    # (n_tokens_e, moe_hidden) -> (n_tokens_e, emb_dim)
                    out = self.fc3[e](hidden)
                    y[mask] += expert_proba[mask] * out
            
        return y

In [4]:
from deps.other_components import RMSNorm_Qwen
from deps.other_components import precompute_rope_params, compute_rope
from deps.other_components import GroupedQueryAttention_Qwen

In [6]:
class TransformerBlock(nn.Module):
    def __init__(self, cfg):
        super().__init__()

        self.attn = GroupedQueryAttention_Qwen(
            d_in=cfg['emb_dim'],
            num_heads=cfg['n_heads'],
            head_dim=cfg['head_dim'],
            num_kv_groups=cfg['n_kv_groups'],
            dtype=cfg['dtype']
        )

        self.ff = FeedForward_MoE(cfg)
        self.norm1 = RMSNorm_Qwen(cfg['emb_dim'], eps=1e-6)
        self.norm2 = RMSNorm_Qwen(cfg['emb_dim'], eps=1e-6)

    def forward(self, x, mask, cos, sin):
        shortcut = x
        x = self.norm1(x)
        x = self.attn(x, mask, cos, sin)
        x = x + shortcut

        shortcut = x
        x = self.norm2(x)
        x = self.ff(x)
        x = x + shortcut

        return x

In [7]:
class Qwen3MoEModel(nn.Module):
    def __init__(self, cfg):
        super().__init__()

        self.tok_emb = nn.Embedding(cfg['vocab_size'], cfg['emb_dim'], dtype=cfg['dtype'])
        self.trf_blocks = nn.ModuleList([
            TransformerBlock(cfg)
            for _ in range(cfg['n_layers'])
        ])

        self.final_norm = RMSNorm_Qwen(cfg['emb_dim'])
        self.out_head = nn.Linear(cfg['emb_dim'], cfg['emb_dim'], bias=False, dtype=cfg['dtype'])

        cos, sin = precompute_rope_params(
            cfg['head_dim'] if cfg['head_dim'] is not None else (cfg['emb_dim'] // cfg['n_heads']),
            cfg['rope_base'],
            cfg['context_len'],
        )

        self.register_buffer('cos', cos, persistent=True)
        self.register_buffer('sin', sin, persistent=True)
        self.cfg = cfg
    
    def forward(self, in_idx):
        tok_embs = self.tok_emb(in_idx)
        x = tok_embs

        num_tokens = x.shape[1]
        mask = torch.triu(
            torch.ones(num_tokens, num_tokens, device=x.device, dtype=torch.bool),
            diagonal=1,
        )

        for block in self.trf_blocks:
            x = block(x, mask, self.cos, self.sin)
        
        x = self.final_norm(x)
        logits = self.out_head(x.to(self.cfg['dtype']))
        
        return logits

In [8]:
QWEN3_CONFIG = {
    'vocab_size': 151_936,
    'context_len': 262_144,
    'emb_dim': 2048,
    'n_heads': 32,
    'n_layers': 48,
    'head_dim': 128,
    'n_kv_groups': 4,
    'rope_base': 10_000_000.0,
    'dtype': torch.bfloat16,
    'n_experts': 128,
    'n_experts_per_tok': 8,
    'moe_intermediate_size': 768
}

In [9]:
# verify if it works

TEST_QWEN3_CONFIG = {
    'vocab_size': 64,
    'context_len': 256,
    'emb_dim': 1024,
    'n_heads': 8,
    'n_layers': 2,
    'head_dim': 128,
    'n_kv_groups': 2,
    'rope_base': 10_000_000.0,
    'dtype': torch.bfloat16,
    'n_experts': 32,
    'n_experts_per_tok': 8,
    'moe_intermediate_size': 324
}

model = Qwen3MoEModel(TEST_QWEN3_CONFIG)

In [10]:
model(torch.tensor([1,2,3]).unsqueeze(0))

tensor([[[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.2949, -0.2734, -0.4121,  ..., -0.3066, -1.1406, -0.4082]]],
       dtype=torch.bfloat16, grad_fn=<UnsafeViewBackward0>)

In [12]:
total_params = sum(p.numel() for p in model.parameters())
print(f"Total params: {total_params:,}")

total_params = total_params - model.tok_emb.weight.numel()
print(f"Total unique params: {total_params:,}")

Total params: 70,129,152
Total unique params: 70,063,616
