In [1]:
from dataclasses import dataclass, asdict

@dataclass
class ModelConfig:
    """
    Hyperparameters for the LLaMA-style model built in this notebook.

    Reference configuration:
    https://github.com/epfml/llm-baselines?tab=readme-ov-file#results-on-wikitext
    """
    n_layers: int = 24            # number of stacked TransformerBlock layers
    vocab_size: int = 32_000      # LLaMA 2 tokenizer
    d_embd: int = 768             # embedding / residual stream width
    n_heads: int = 12             # must divide d_embd (asserted in SelfAttention)
    seq_len: int = 512            # sequence length used by the smoke tests
    rope_theta: float = 10_000.0  # passed to nn.RoPE as `base`
    rope_scale: float = 1.0       # passed to nn.RoPE as `scale`
    ffn_mult: int = 256           # FFN hidden size is rounded up to a multiple of this
    norm_eps: float = 0.00001     # epsilon handed to nn.RMSNorm

# Default configuration instance shared by the test helper and the smoke tests in later cells.
cfg_m = ModelConfig()

In [2]:
def x_test(B=2):
    """
    Build a reproducible random activation tensor for smoke-testing layers.

    Reseeds MLX's global random state with a fixed seed on every call, then
    draws a uniform sample of shape [B, cfg_m.seq_len, cfg_m.d_embd].

    Args:
        B: Batch size (leading dimension of the returned array).

    Returns:
        The sampled mx.array.
    """
    # Fixed seed so repeated calls with the same B give the same tensor.
    mx.random.seed(3985)
    dims = [B, cfg_m.seq_len, cfg_m.d_embd]
    return mx.random.uniform(shape=dims)

In [3]:
import mlx.core as mx
from mlx import nn

In [32]:
from mlx.core.fast import scaled_dot_product_attention

class SelfAttention(nn.Module):
    """
    Multi-head self-attention with rotary position embeddings.

    A single fused linear layer produces queries, keys and values; rotary
    embeddings are applied to queries and keys before the fused
    scaled-dot-product attention kernel, and a final linear layer projects
    the merged heads back to d_embd.

    Args:
        d_embd: Model width; must be divisible by n_heads.
        n_heads: Number of attention heads.
        rope_theta: Base passed to nn.RoPE.
        rope_scale: Scale passed to nn.RoPE.
        **kwargs: Ignored; lets the full config dict be splatted in.
    """

    def __init__(self, d_embd, n_heads, rope_theta, rope_scale, **kwargs):
        super().__init__()
        assert d_embd % n_heads == 0
        self.d_head = d_embd // n_heads
        self.scale = self.d_head ** -0.5

        # Layers are created in a fixed order so parameter layout and
        # random initialisation stay stable.
        self.attn_proj = nn.Linear(d_embd, 3 * d_embd, bias=False)
        self.rope = nn.RoPE(self.d_head, base=rope_theta, scale=rope_scale)
        self.out_proj = nn.Linear(d_embd, d_embd, bias=False)

    def __call__(self, x, mask):
        """
        Apply self-attention.

        Args:
            x: Rank-3 input, unpacked as (bsz, seq_len, d_embd).
            mask: Forwarded unchanged to scaled_dot_product_attention.

        Returns:
            Array reshaped to (bsz, seq_len, d_embd) and passed through out_proj.
        """
        bsz, seq_len, d_embd = x.shape

        def split_heads(t):
            # [bsz, seq_len, d_embd] -> [bsz, n_heads, seq_len, d_head]
            per_head = t.reshape(bsz, seq_len, -1, self.d_head)
            return per_head.transpose(0, 2, 1, 3)

        # One projection, then cut the last axis into three equal parts.
        fused = self.attn_proj(x)
        queries, keys, values = (split_heads(part) for part in fused.split(3, axis=-1))

        # Rotary embeddings go on queries and keys only.
        queries = self.rope(queries)
        keys = self.rope(keys)

        # [bsz, n_heads, seq_len, d_head]
        attended = scaled_dot_product_attention(
            queries, keys, values, scale=self.scale, mask=mask
        )

        # Merge heads back: [bsz, seq_len, d_embd]
        merged = attended.transpose(0, 2, 1, 3).reshape(bsz, seq_len, d_embd)
        return self.out_proj(merged)

# Smoke test: run attention on the fixed test batch with a causal mask.
attn = SelfAttention(**asdict(cfg_m))
# Displays the output shape next to (seq_len, d_embd) for a visual comparison.
attn(x_test(), mask=nn.MultiHeadAttention.create_additive_causal_mask(cfg_m.seq_len)).shape, (cfg_m.seq_len, cfg_m.d_embd)

((2, 512, 768), (512, 768))

In [55]:
class FeedForwardNet(nn.Module):
    """
    Gated (SwiGLU) feed-forward network.

    The hidden width starts at two thirds of 4 * d_embd (truncated to an
    int) and is then rounded up to the next multiple of ffn_mult.

    Args:
        d_embd: Input and output width.
        ffn_mult: Granularity the hidden width is rounded up to.
        **kwargs: Ignored; lets the full config dict be splatted in.
    """

    def __init__(self, d_embd, ffn_mult, **kwargs):
        super().__init__()
        raw_width = int(4 * d_embd * 2 / 3)
        # Ceil-divide, then scale back up: the next multiple of ffn_mult.
        n_multiples = (raw_width + ffn_mult - 1) // ffn_mult
        hidden_dim = ffn_mult * n_multiples

        self.gate_proj = nn.Linear(d_embd, hidden_dim, bias=False)
        self.up_proj = nn.Linear(d_embd, hidden_dim, bias=False)
        self.down_proj = nn.Linear(hidden_dim, d_embd, bias=False)

    def __call__(self, x):
        """Return down_proj(silu(gate_proj(x)) * up_proj(x))."""
        gate = nn.silu(self.gate_proj(x))
        gated = gate * self.up_proj(x)  # SwiGLU
        return self.down_proj(gated)

# Smoke test: show the feed-forward output shape for the fixed test batch.
ffn = FeedForwardNet(**asdict(cfg_m))
ffn(x_test()).shape

(2, 512, 768)

In [56]:
class TransformerBlock(nn.Module):
    """
    Pre-norm transformer layer: attention and feed-forward sub-layers, each
    fed an RMS-normalised input and wrapped in a residual connection.

    Args:
        d_embd: Model width.
        norm_eps: Epsilon for both RMSNorm layers.
        **kwargs: Remaining config entries, forwarded to SelfAttention and
            FeedForwardNet (each ignores the keys it does not use).
    """

    def __init__(self, d_embd, norm_eps, **kwargs):
        super().__init__()
        # Sub-module creation order is kept fixed so parameter layout and
        # random initialisation stay stable.
        self.pre_norm = nn.RMSNorm(d_embd, eps=norm_eps)
        self.attn = SelfAttention(d_embd=d_embd, **kwargs)
        self.ffn_norm = nn.RMSNorm(d_embd, eps=norm_eps)
        self.ffn = FeedForwardNet(d_embd=d_embd, **kwargs)

    def __call__(self, x, mask):
        """Run both residual sub-layers; mask is handed to the attention module."""
        attn_out = self.attn(self.pre_norm(x), mask)
        h = x + attn_out
        ffn_out = self.ffn(self.ffn_norm(h))
        return h + ffn_out

# Smoke test: show the output shape of one transformer block on the fixed test batch.
layer = TransformerBlock(**asdict(cfg_m))
layer(x_test(), mask=nn.MultiHeadAttention.create_additive_causal_mask(cfg_m.seq_len)).shape

(2, 512, 768)

In [57]:
class LLaMA(nn.Module):
    """
    Decoder-only language model: token embedding, a stack of
    TransformerBlock layers, a final RMSNorm and a linear LM head.

    Args:
        vocab_size: Number of token ids; sizes the embedding and the LM head.
        n_layers: Number of TransformerBlock layers.
        d_embd: Model width.
        norm_eps: Epsilon for every RMSNorm.
        **kwargs: Remaining config entries, forwarded to each TransformerBlock.
    """

    def __init__(self, vocab_size, n_layers, d_embd, norm_eps, **kwargs):
        super().__init__()
        self.embd_toks = nn.Embedding(vocab_size, d_embd)
        self.layers = [
            TransformerBlock(d_embd=d_embd, norm_eps=norm_eps, **kwargs)
            for _ in range(n_layers)
        ]
        self.out_norm = nn.RMSNorm(d_embd, norm_eps)
        self.lm_head = nn.Linear(d_embd, vocab_size, bias=False)

    def __call__(self, tok_idxs):
        """
        Compute next-token logits.

        Args:
            tok_idxs: Integer token ids, indexed as (bsz, seq_len).

        Returns:
            Logits of shape (bsz, seq_len, vocab_size).
        """
        # bsz, seq_len, d_embd
        h = self.embd_toks(tok_idxs)

        # seq_len, seq_len additive mask, built once and shared by all layers.
        # It is created in the activation dtype: the helper defaults to
        # float32, which would not match half-precision activations inside
        # the fused attention call.
        causal_mask = nn.MultiHeadAttention.create_additive_causal_mask(
            h.shape[1], dtype=h.dtype
        )
        for layer in self.layers:
            h = layer(h, causal_mask)
        h = self.out_norm(h)

        # bsz, seq_len, vocab_size
        logits = self.lm_head(h)

        return logits

model = LLaMA(**asdict(cfg_m))
# Smoke test on random token ids; logits come out as (bsz, seq_len, vocab_size).
# Uses cfg_m, the config instance defined in the first cell (`cfg` is never defined).
model(mx.random.randint(0, cfg_m.vocab_size, shape=[2, cfg_m.seq_len])).shape

(2, 512, 32000)

In [50]:
# Display the module tree to inspect each layer's dimensions.
model

LLaMA(
  (embd_toks): Embedding(32000, 768)
  (layers.0): TransformerBlock(
    (pre_norm): RMSNorm(768, eps=1e-05)
    (attn): SelfAttention(
      (attn_proj): Linear(input_dims=768, output_dims=2304, bias=False)
      (rope): RoPE(64, traditional=False)
      (out_proj): Linear(input_dims=768, output_dims=768, bias=False)
    )
    (ffn_norm): RMSNorm(768, eps=1e-05)
    (ffn): FeedForwardNet(
      (gate_proj): Linear(input_dims=768, output_dims=2048, bias=False)
      (up_proj): Linear(input_dims=768, output_dims=2048, bias=False)
      (down_proj): Linear(input_dims=2048, output_dims=768, bias=False)
    )
  )
  (layers.1): TransformerBlock(
    (pre_norm): RMSNorm(768, eps=1e-05)
    (attn): SelfAttention(
      (attn_proj): Linear(input_dims=768, output_dims=2304, bias=False)
      (rope): RoPE(64, traditional=False)
      (out_proj): Linear(input_dims=768, output_dims=768, bias=False)
    )
    (ffn_norm): RMSNorm(768, eps=1e-05)
    (ffn): FeedForwardNet(
      (gate_proj): L