In [5]:
import mlx.core as mx
from mlx import nn

In [2]:
mx.random.seed(3985)

In [11]:
B, T, C = 2, 4, 8  # batch_size, seq_len, d_embd
x = mx.random.uniform(shape=[B, T, C])
x

array([[[0.324384, 0.0728916, 0.692562, ..., 0.473036, 0.732917, 0.613625],
        [0.937988, 0.821932, 0.353415, ..., 0.699138, 0.0826687, 0.684287],
        [0.963293, 0.217973, 0.239089, ..., 0.435873, 0.525348, 0.027538],
        [0.332283, 0.22946, 0.521798, ..., 0.418693, 0.172103, 0.0305814]],
       [[0.380298, 0.500459, 0.931418, ..., 0.795432, 0.0933626, 0.835255],
        [0.328979, 0.496543, 0.868076, ..., 0.533726, 0.293846, 0.770115],
        [0.442788, 0.287145, 0.0224269, ..., 0.171359, 0.840324, 0.175392],
        [0.408249, 0.0423786, 0.482122, ..., 0.301558, 0.276174, 0.602193]]], dtype=float32)

In [63]:
class MixtureOfDepths(nn.Module):
    def __init__(self, block, capacity_factor, seq_len, n_embd):
        super().__init__()
        self.block = block
        self.capacity_factor = capacity_factor
        self.capacity = int(capacity_factor * seq_len)
        self.router = nn.Linear(n_embd, 1)

    def __call__(self, x):
        B, T = x.shape[:2]  # batch_size, seq_len

        # Top k expert choice
        r = self.router(x).squeeze(-1)
        capacity = min(self.capacity, self.capacity_factor*T)
        chosen_idx = mx.argpartition(-r, capacity, axis=1)[:, :capacity]

        # Sorted top k to preserve token causality
        # mx.sort does not support uint32?
        chosen_idx = mx.sort(chosen_idx.astype(mx.float32), axis=1).astype(mx.uint32)

        # Process chosen tokens
        batch_idx = mx.arange(B)[:, None]
        chosen_r = r[batch_idx, chosen_idx, None]
        chosen_x = x[batch_idx, chosen_idx, :]
        process_x = self.block(chosen_x)
        x[batch_idx, chosen_idx, :] += chosen_r * process_x

        # Auxiliary loss for training the router
        r_nll = -nn.log_softmax(chosen_r[..., 0], axis=-1).mean()

        return x, r_nll

mod = MixtureOfDepths(nn.Linear(C, C), 0.5, T, C)
mod(x)

(array([[[0.557108, 0.0590587, 0.603644, ..., 0.288003, 0.565407, 0.752665],
         [-0.286823, 0.1454, -0.0832467, ..., 0.714576, 0.233552, 0.925167],
         [1.15273, 1.4296, -0.337541, ..., 0.774492, 1.09024, -0.284072],
         [0.694623, 1.15832, 0.722264, ..., 0.151736, 0.672384, -0.699643]],
        [[0.947543, 2.18298, -0.407171, ..., 0.649991, 1.21402, -0.0334666],
         [0.0934433, -0.0856612, 0.686615, ..., 0.866722, 0.30294, 0.53751],
         [0.693493, 0.788395, -0.268603, ..., 0.715195, 1.10677, 0.194276],
         [0.0828968, 0.234539, 0.210647, ..., 0.543489, 0.562059, 0.0940933]]], dtype=float32),
 array(0.705366, dtype=float32))

In [47]:
from mlx.core.fast import scaled_dot_product_attention

class CausalSelfAttention(nn.Module):
    def __init__(self, n_head, d_embd, p_drop, **kwargs):
        assert d_embd % n_head == 0
        super().__init__()
        self.n_head = n_head
        self.scale = (d_embd / n_head) ** 0.5

        self.attn_proj = nn.Linear(d_embd, 3*d_embd, bias=False)
        self.out_proj = nn.Linear(d_embd, d_embd, bias=False)
        self.resid_drop = nn.Dropout(p_drop)

    def __call__(self, x):
        B, T = x.shape[:2]

        qkv = self.attn_proj(x).split(3, axis=-1)  # B, T, d_embd
        to_attn_weights = lambda z: z.reshape(B, T, self.n_head, -1).transpose(0, 2, 1, 3)
        Q, K, V = map(to_attn_weights, qkv)  # B, n_head, T, d_head

        # MLX SDPA does not support dropout?
        causal_mask = nn.MultiHeadAttention.create_additive_causal_mask(T)
        O = scaled_dot_product_attention(Q, K, V, scale=self.scale, mask=causal_mask)  # B, n_head, T, d_head
        O = O.transpose(0, 2, 1, 3).reshape(B, T, -1)  # B, T, d_embd

        output = self.resid_drop(self.out_proj(O))

        return output

attn = CausalSelfAttention(8, 128, 0.1)
attn(mx.random.uniform(shape=[2, 4, 128])).shape

(2, 4, 128)

In [51]:
class FeedForwardNet(nn.Module):
    def __init__(self, d_embd, p_drop, **kwargs):
        super().__init__()
        self.up_proj = nn.Linear(d_embd, 4*d_embd, bias=False)
        self.down_proj = nn.Linear(4*d_embd, d_embd, bias=False)
        self.dropout = nn.Dropout(p_drop)

    def __call__(self, x):
        x = nn.gelu(self.up_proj(x))
        x = self.dropout(self.down_proj(x))
        return x

In [52]:
class TransformerBlock(nn.Module):
    def __init__(self, d_embd, **kwargs):
        super().__init__()
        self.pre_norm = nn.LayerNorm(d_embd, bias=False)
        self.self_attn = CausalSelfAttention(d_embd=d_embd, **kwargs)
        self.post_norm = nn.LayerNorm(d_embd, bias=False)
        self.ffn = FeedForwardNet(d_embd=d_embd, **kwargs)

    def __call__(self, x):
        x = self.self_attn(self.pre_norm(x)) + x
        x = self.ffn(self.post_norm(x))
        return x

In [59]:
class GPT(nn.Module):
    def __init__(self, n_vocab, n_ctx, d_embd, p_drop, n_layers, **kwargs):
        super().__init__()

        self.tok_embd = nn.Embedding(n_vocab, d_embd)
        self.pos_embd = nn.Embedding(n_ctx, d_embd)
        self.dropout = nn.Dropout(p_drop)

        self.blocks = [
            TransformerBlock(d_embd=d_embd, p_drop=p_drop, **kwargs)
            for _ in range(n_layers)
        ]

        self.norm = nn.LayerNorm(d_embd, bias=False)
        self.lm_proj = nn.Linear(d_embd, n_vocab, bias=False)

    def __call__(self, tok_idx):
        T = tok_idx.shape[1]

        tok_embd = self.tok_embd(tok_idx)
        pos_embd = self.pos_embd(mx.arange(T))
        x = self.dropout(tok_embd + pos_embd)

        for block in self.blocks:
            x = block(x)

        logits = self.lm_proj(self.norm(x))

        return logits

In [57]:
from dataclasses import dataclass, asdict

@dataclass
class ModelConfig:
    n_vocab: int
    n_ctx: int
    n_layers: int
    d_embd: int
    n_head: int
    p_drop: float

cfg = ModelConfig(128, 32, 4, 256, 8, 0.1)
cfg

ModelConfig(n_vocab=128, n_ctx=32, n_layers=4, d_embd=256, n_head=8, p_drop=0.1)

In [61]:
model = GPT(**asdict(cfg))
model(mx.random.randint(0, cfg.n_vocab, shape=[2, cfg.n_ctx])).shape

(2, 32, 128)