In [1]:
import torch
import torch.nn as nn

In [2]:
from deps.other_components import FeedForward_SwiGLU

In [3]:
class RMSNorm_Qwen(nn.Module):
    """Root-mean-square normalisation over the last dimension.

    The statistics are computed in float32 regardless of the input dtype and
    the result is cast back to the input dtype before being returned.

    Args:
        emb_dim: Size of the last dimension of the input (length of the
            learnable vectors).
        eps: Constant added to the mean square before the reciprocal square
            root, for numerical stability.
        bias: When true, a learnable additive ``shift`` vector is created in
            addition to the multiplicative ``scale``.
    """

    def __init__(self, emb_dim, eps=1e-6, bias=False):
        super().__init__()
        self.eps = eps

        # The multiplicative gain is always present (the Llama variant had
        # only this one parameter); the additive shift is opt-in.
        self.scale = nn.Parameter(torch.ones(emb_dim))
        if bias:
            self.shift = nn.Parameter(torch.zeros(emb_dim))
        else:
            self.shift = None

    def forward(self, x):
        """Normalise ``x`` along its last axis and return it in ``x``'s dtype."""
        orig_dtype = x.dtype
        x_fp32 = x.float()  # Qwen: do the arithmetic in full precision

        # RMSNorm: divide by sqrt(mean(x^2) + eps)
        mean_sq = x_fp32.pow(2).mean(dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_sq + self.eps)

        # Learnable gain is applied while still in float32
        out = x_fp32 * inv_rms * self.scale

        if self.shift is None:
            return out.to(orig_dtype)
        return (out + self.shift).to(orig_dtype)

In [4]:
from deps.other_components import precompute_rope_params, compute_rope

In [5]:
class GroupedQueryAttention(nn.Module):
    """Grouped-query attention with RMS-normalised queries and keys (QK-norm).

    ``num_heads`` query heads share ``num_kv_groups`` key/value heads; every
    key/value head serves ``num_heads // num_kv_groups`` query heads.

    Args:
        d_in: Size of the last dimension of the input and of the output.
        num_heads: Number of query heads.
        num_kv_groups: Number of key/value heads; must divide ``num_heads``.
        head_dim: Per-head width. Defaults to ``d_in // num_heads``, in which
            case the concatenated head width ``d_out`` equals ``d_in``.
        dtype: dtype forwarded to the linear projections.
    """

    def __init__(self, d_in, num_heads, num_kv_groups,
                 head_dim=None, dtype=None):
        super().__init__()
        assert num_heads % num_kv_groups == 0, 'num_heads must be divisible by num_kv_groups'

        if head_dim is None:
            assert d_in % num_heads == 0, 'd_in must be divisible by num_heads if head_dim is not provided'
            head_dim = d_in // num_heads

        self.num_heads = num_heads
        self.num_kv_groups = num_kv_groups
        self.group_size = num_heads // num_kv_groups
        self.head_dim = head_dim
        # With the default head_dim this is num_heads * (d_in / num_heads) = d_in.
        self.d_out = num_heads * head_dim

        # K and V are projected to only num_kv_groups heads. With the default
        # head_dim that is d_in / group_size features; repeating each head
        # group_size times in forward() brings the width back up to d_in.
        kv_width = num_kv_groups * head_dim
        linear_opts = dict(bias=False, dtype=dtype)

        # Creation order is kept fixed so seeded initialisation is unchanged.
        self.W_q = nn.Linear(d_in, self.d_out, **linear_opts)
        self.W_k = nn.Linear(d_in, kv_width, **linear_opts)
        self.W_v = nn.Linear(d_in, kv_width, **linear_opts)

        self.out_proj = nn.Linear(self.d_out, d_in, **linear_opts)

        # QK-norm: one normaliser each for queries and keys, over head_dim
        self.q_norm = RMSNorm_Qwen(head_dim, eps=1e-6)
        self.k_norm = RMSNorm_Qwen(head_dim, eps=1e-6)

    def forward(self, x, mask, cos, sin):
        """Apply attention to ``x`` of shape (b, num_tokens, d_in).

        Args:
            x: Input tensor with three dimensions.
            mask: Boolean tensor handed to ``masked_fill``; positions that are
                true receive a score of -inf before the softmax.
            cos, sin: Forwarded unchanged to ``compute_rope`` for both the
                queries and the keys.

        Returns:
            The result of ``out_proj``, shape (b, num_tokens, d_in).
        """
        batch, seq_len, _ = x.shape

        def split_heads(flat, n_heads):
            # (b, num_tokens, n_heads * head_dim) -> (b, n_heads, num_tokens, head_dim)
            return flat.view(batch, seq_len, n_heads, self.head_dim).transpose(1, 2)

        q = split_heads(self.W_q(x), self.num_heads)
        k = split_heads(self.W_k(x), self.num_kv_groups)
        v = split_heads(self.W_v(x), self.num_kv_groups)

        # Normalise per head first, then hand over to the rotary helper
        q = compute_rope(self.q_norm(q), cos, sin)
        k = compute_rope(self.k_norm(k), cos, sin)

        # Expand the head axis (dim 1):
        #   (b, num_kv_groups, num_tokens, head_dim) -> (b, num_heads, num_tokens, head_dim)
        # because num_heads = num_kv_groups * group_size.
        k = k.repeat_interleave(self.group_size, dim=1)
        v = v.repeat_interleave(self.group_size, dim=1)

        scores = (q @ k.transpose(2, 3)).masked_fill(mask, -torch.inf)
        weights = torch.softmax(scores / self.head_dim**0.5, dim=-1)

        # Merge the heads back into a single feature axis of width d_out
        merged = (weights @ v).transpose(1, 2).reshape(batch, seq_len, self.d_out)
        return self.out_proj(merged)

In [6]:
class TransformerBlock(nn.Module):
    """Pre-norm block: attention and feed-forward, each with a residual add.

    Args:
        cfg: Mapping that provides 'emb_dim', 'n_heads', 'head_dim',
            'n_kv_groups' and 'dtype'. The whole mapping is also passed to
            ``FeedForward_SwiGLU``.
    """

    def __init__(self, cfg):
        super().__init__()
        width = cfg['emb_dim']

        self.attn = GroupedQueryAttention(
            d_in=width,
            num_heads=cfg['n_heads'],
            head_dim=cfg['head_dim'],
            num_kv_groups=cfg['n_kv_groups'],
            dtype=cfg['dtype'],
        )
        self.ff = FeedForward_SwiGLU(cfg)

        # One normaliser in front of each sub-layer
        self.norm1 = RMSNorm_Qwen(width, eps=1e-6)
        self.norm2 = RMSNorm_Qwen(width, eps=1e-6)

    def forward(self, x, mask, cos, sin):
        """Run both sub-layers; ``mask``, ``cos`` and ``sin`` go to attention."""
        # Attention sub-layer: normalise, attend, add the input back
        attn_out = self.attn(self.norm1(x), mask, cos, sin)
        x = attn_out + x

        # Feed-forward sub-layer: normalise, transform, add the input back
        ff_out = self.ff(self.norm2(x))
        return ff_out + x


In [7]:
class Qwen3Model(nn.Module):
    """Token embedding, a stack of TransformerBlocks, final norm, output head.

    Args:
        cfg: Mapping that provides 'vocab_size', 'emb_dim', 'dtype',
            'n_layers', 'n_heads', 'head_dim' (may be None), 'rope_base' and
            'context_len', plus whatever ``TransformerBlock`` reads from it.
    """

    def __init__(self, cfg):
        super().__init__()
        vocab_size, emb_dim, dtype = cfg['vocab_size'], cfg['emb_dim'], cfg['dtype']

        self.tok_emb = nn.Embedding(vocab_size, emb_dim, dtype=dtype)

        self.trf_blocks = nn.ModuleList()
        for _ in range(cfg['n_layers']):
            self.trf_blocks.append(TransformerBlock(cfg))

        self.final_norm = RMSNorm_Qwen(emb_dim)
        self.out_head = nn.Linear(emb_dim, vocab_size, bias=False, dtype=dtype)

        # Same fallback as the attention layer: split emb_dim evenly across
        # the heads when no explicit head width is configured.
        head_dim = cfg['head_dim']
        if head_dim is None:
            head_dim = emb_dim // cfg['n_heads']

        cos, sin = precompute_rope_params(
            head_dim,
            theta_base=cfg['rope_base'],
            context_len=cfg['context_len'],
        )
        # Non-persistent: the tables follow the module across devices but are
        # left out of the state dict.
        for buf_name, buf in (('cos', cos), ('sin', sin)):
            self.register_buffer(buf_name, buf, persistent=False)

        self.cfg = cfg

    def forward(self, in_idx):
        """Map token ids of shape (b, num_tokens) to logits over the vocabulary."""
        x = self.tok_emb(in_idx)

        # Strictly-upper-triangular boolean mask: true entries (j > i) are
        # the positions that attention fills with -inf.
        seq_len = x.shape[1]
        causal_mask = torch.ones(
            seq_len, seq_len, device=x.device, dtype=torch.bool
        ).triu(diagonal=1)

        for block in self.trf_blocks:
            x = block(x, causal_mask, self.cos, self.sin)

        normed = self.final_norm(x)
        return self.out_head(normed.to(self.cfg['dtype']))

In [8]:
# Hyperparameters for the 0.6B model
QWEN_0_6B_CONFIG = dict(
    vocab_size=151_936,        # rows of the embedding table / width of the logits
    context_len=40_960,        # passed to precompute_rope_params as context_len
    emb_dim=1024,              # embedding width
    n_heads=16,                # query heads per attention layer
    n_layers=28,               # number of TransformerBlocks
    hidden_dim=3072,           # not read in this notebook; presumably consumed by FeedForward_SwiGLU
    head_dim=128,              # explicit per-head width (None would mean emb_dim // n_heads)
    n_kv_groups=8,             # key/value heads per attention layer
    rope_base=1_000_000.0,     # passed to precompute_rope_params as theta_base
    dtype=torch.bfloat16,      # dtype of the embedding and linear layers
)

# Seed first so that the randomly initialised weights are reproducible
torch.manual_seed(123)
model = Qwen3Model(QWEN_0_6B_CONFIG)

In [9]:
# Smoke test: push one 3-token sequence (a batch of one) through the model.
# The nested literal already has shape (1, 3), so no unsqueeze is required.
model(torch.tensor([[1, 2, 3]]))

tensor([[[-0.2197, -0.0109, -0.7227,  ...,  0.4316,  0.1216,  1.0781],
         [-0.6445,  0.5391, -0.0767,  ..., -0.0771,  0.5312,  0.3008],
         [-0.4727, -0.1582,  0.1172,  ..., -0.2305,  0.2334,  0.6367]]],
       dtype=torch.bfloat16, grad_fn=<UnsafeViewBackward0>)

In [11]:
# Number of elements across every parameter tensor registered on the model
total_params = sum(map(torch.Tensor.numel, model.parameters()))
print(f'Total Parameters: {total_params:,}')

# Count for the weight-tied case: drop one copy of the embedding matrix.
# NOTE: Qwen3Model as defined in this notebook allocates separate tensors for
# tok_emb and out_head, so this is the figure that applies once the two share
# storage, not what the freshly built model holds in memory.
total_params -= model.tok_emb.weight.numel()
print(f'Total Unique Parameters: {total_params:,}')

Total Parameters: 751,632,384
Total Unique Parameters: 596,049,920
