In [1]:
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union

import mlx.core as mx
import mlx.nn as nn
import numpy as np

from mlx_lm.models.base import BaseModelArgs, create_additive_causal_mask


  from .autonotebook import tqdm as notebook_tqdm


In [21]:
@dataclass
class ModelArgs(BaseModelArgs):
    """Hyperparameters consumed by Attention, MLP, TransformerBlock and GPT2Model."""

    model_type: str
    n_ctx: int  # not read by any module in this notebook
    n_embd: int  # embedding / hidden width
    n_head: int  # attention heads per block; must divide n_embd
    n_layer: int  # number of TransformerBlocks
    n_positions: int  # rows in the position-embedding table (max sequence length)
    layer_norm_epsilon: float
    vocab_size: int  # rows in the token-embedding table
    attn_pdrop: float  # dropout probability used by Attention.drop_attn
    embd_pdrop: float  # dropout probability applied after the embeddings
    resid_pdrop: float  # dropout probability on each residual branch


class Attention(nn.Module):
    """Multi-head self-attention with a fused QKV projection.

    ``c_attn`` maps ``n_embd`` to ``3 * n_embd`` (queries, keys and values
    concatenated on the last axis). ``c_proj`` maps the merged heads back to
    ``n_embd``. The mask, if given, is passed straight to the fused
    scaled-dot-product-attention kernel.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()

        # Hugging Face naming equivalents: n_embd ~ hidden_size,
        # n_head ~ num_attention_heads, n_positions ~ max_position_embeddings.
        assert args.n_embd % args.n_head == 0, "n_embd must be divisible by n_head"

        self.n_embd = args.n_embd
        self.n_head = args.n_head
        self.head_dim = args.n_embd // args.n_head
        # 1/sqrt(head_dim) temperature for the dot products.
        self.scale = self.head_dim**-0.5
        self.attn_pdrop = args.attn_pdrop

        width = self.n_embd
        self.c_attn = nn.Linear(width, 3 * width, bias=True)
        self.c_proj = nn.Linear(width, width, bias=True)
        self.drop_attn = nn.Dropout(self.attn_pdrop)

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
    ) -> mx.array:
        """Attend over ``x`` of shape ``(B, L, n_embd)`` and return the same shape."""
        batch, seq_len, _ = x.shape

        def to_heads(t: mx.array) -> mx.array:
            # (B, L, n_embd) -> (B, n_head, L, head_dim)
            return t.reshape(batch, seq_len, self.n_head, -1).transpose(0, 2, 1, 3)

        q, k, v = (to_heads(t) for t in mx.split(self.c_attn(x), 3, axis=-1))

        attended = mx.fast.scaled_dot_product_attention(
            q, k, v, scale=self.scale, mask=mask
        )
        # NOTE(review): dropout is applied to the attention *output*, not to the
        # softmax probabilities as in reference GPT-2 (the fused kernel does not
        # expose them). nn.Dropout only acts while the module is in training
        # mode; MLX modules start in training mode until `.eval()` is called.
        attended = self.drop_attn(attended)

        # (B, n_head, L, head_dim) -> (B, L, n_embd)
        merged = attended.transpose(0, 2, 1, 3).reshape(batch, seq_len, -1)
        return self.c_proj(merged)


class MLP(nn.Module):
    """Position-wise feed-forward network: n_embd -> 4 * n_embd -> n_embd."""

    def __init__(self, args: ModelArgs):
        super().__init__()

        self.n_embd = args.n_embd
        self.c_fc = nn.Linear(self.n_embd, 4 * self.n_embd)
        self.c_proj = nn.Linear(4 * self.n_embd, self.n_embd)

    def __call__(self, x) -> mx.array:
        """Apply expand -> GELU -> project; preserves the input's leading dims."""
        # GPT-2 uses the tanh-approximated GELU ("gelu_new" in Hugging Face),
        # which is what nn.gelu_approx computes. The exact erf-based nn.gelu
        # yields slightly different activations than pretrained weights expect.
        return self.c_proj(nn.gelu_approx(self.c_fc(x)))


class TransformerBlock(nn.Module):
    """One pre-LayerNorm decoder block.

    Each sub-layer (attention, then MLP) sees a LayerNorm-ed copy of the
    stream, and its output, after dropout, is added back onto the stream.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()

        self.n_head = args.n_head
        self.n_embd = args.n_embd
        self.layer_norm_epsilon = args.layer_norm_epsilon
        self.resid_pdrop = args.resid_pdrop

        self.attn = Attention(args)
        self.mlp = MLP(args)
        self.ln_1 = nn.LayerNorm(args.n_embd, eps=args.layer_norm_epsilon)
        self.ln_2 = nn.LayerNorm(args.n_embd, eps=args.layer_norm_epsilon)
        self.drop_resid = nn.Dropout(args.resid_pdrop)

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
    ) -> mx.array:
        """Run attention then MLP residual branches; output has the shape of ``x``."""
        # Residual branch 1: x + Dropout(Attention(LN1(x)))
        attn_branch = self.attn(self.ln_1(x), mask)
        stream = x + self.drop_resid(attn_branch)

        # Residual branch 2: stream + Dropout(MLP(LN2(stream)))
        mlp_branch = self.mlp(self.ln_2(stream))
        return stream + self.drop_resid(mlp_branch)


class GPT2Model(nn.Module):
    """GPT-2 trunk: token + position embeddings, ``n_layer`` blocks, final LayerNorm.

    Returns hidden states, not vocabulary logits.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.n_embd = args.n_embd
        self.n_positions = args.n_positions
        self.vocab_size = args.vocab_size
        self.n_layer = args.n_layer
        self.layer_norm_epsilon = args.layer_norm_epsilon
        self.embd_pdrop = args.embd_pdrop
        assert self.vocab_size > 0
        self.wte = nn.Embedding(self.vocab_size, self.n_embd)  # token embeddings
        self.wpe = nn.Embedding(self.n_positions, self.n_embd)  # position embeddings
        self.drop_embd = nn.Dropout(self.embd_pdrop)
        self.h = [TransformerBlock(args=args) for _ in range(self.n_layer)]
        self.ln_f = nn.LayerNorm(self.n_embd, eps=self.layer_norm_epsilon)

    def __call__(
        self,
        inputs: mx.array,
    ):
        """Embed token ids and run them through every transformer block.

        Args:
            inputs: integer token ids of shape ``(B, L)``.

        Returns:
            Hidden states of shape ``(B, L, n_embd)`` after the final LayerNorm.

        Raises:
            ValueError: if ``L`` exceeds ``n_positions`` (the position table has
                no row for such positions).
        """
        _, L = inputs.shape
        if L > self.n_positions:
            raise ValueError(
                f"sequence length {L} exceeds n_positions={self.n_positions}"
            )

        position_ids = mx.arange(L)
        hidden_states = self.wte(inputs) + self.wpe(position_ids)
        hidden_states = self.drop_embd(hidden_states)

        # The causal mask is over the *sequence* axis (L x L) and broadcasts over
        # batch and heads. Cast it so SDPA does not promote low-precision activations.
        mask = create_additive_causal_mask(L, 0).astype(hidden_states.dtype)

        for layer in self.h:
            hidden_states = layer(hidden_states, mask)

        return self.ln_f(hidden_states)


class Model(nn.Module):
    """GPT-2 language model: ``GPT2Model`` trunk plus a tied output projection.

    Logits reuse the token-embedding matrix via ``wte.as_linear``, so there are
    no separate LM-head weights.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        # Kept so config-derived properties (e.g. ``head_dim``) can read it.
        self.args = args
        self.model_type = args.model_type
        self.transformer = GPT2Model(args)

    def __call__(
        self,
        inputs: mx.array,
    ):
        """Map token ids ``(B, L)`` to logits ``(B, L, vocab_size)``."""
        out = self.transformer(inputs)
        out = self.transformer.wte.as_linear(out)
        return out

    @property
    def layers(self):
        """The list of TransformerBlocks in the trunk."""
        return self.transformer.h

    @property
    def head_dim(self):
        """Per-head width, ``n_embd // n_head``."""
        return self.args.n_embd // self.args.n_head


In [22]:
# GPT-2 hyperparameters used by every cell below: 12 blocks x 12 heads,
# 768-dim embeddings, 1024 positions, 50257-token vocabulary.
model_args = ModelArgs(
    model_type="gpt2",
    # vocabulary and context length
    vocab_size=50257,
    n_ctx=1024,
    n_positions=1024,
    # width and depth
    n_embd=768,
    n_head=12,
    n_layer=12,
    layer_norm_epsilon=1e-5,
    # dropout probabilities
    attn_pdrop=0.1,
    embd_pdrop=0.1,
    resid_pdrop=0.1,
)

In [23]:
# A single example sequence of 4 arbitrary token ids (batch size 1).
batch = mx.array([[123, 456, 789, 1011]])
print(batch.shape, batch, sep="\n")

(1, 4)
array([[123, 456, 789, 1011]], dtype=int32)


In [24]:
# Standalone, freshly initialized token-embedding table (no checkpoint loaded),
# used to walk through GPT2Model's forward pass one step at a time.
wte = nn.Embedding(model_args.vocab_size, model_args.n_embd)

In [25]:
# Token embeddings for the example batch: one n_embd-dim vector per token id.
hidden_states = wte(batch)

In [26]:
# Position ids 0 .. L-1 along the sequence axis (batch.shape[1]).
position_ids = mx.array(np.arange(batch.shape[1]))

In [27]:
# Freshly initialized position-embedding table, kept in a named variable so it
# can be inspected. Use `=` (recomputing from the token embeddings) rather than
# `+=`, so re-running this cell does not stack another positional offset onto
# an already-updated `hidden_states`.
wpe = nn.Embedding(model_args.n_positions, model_args.n_embd)
hidden_states = wte(batch) + wpe(position_ids)

In [28]:
# Additive causal mask over the sequence axis (L x L). As printed below, entries
# are 0 on/below the diagonal and -1e9 above it, so each position can attend
# only to itself and earlier positions.
mask = create_additive_causal_mask(hidden_states.shape[1], 0)
print(mask)

array([[-0, -1e+09, -1e+09, -1e+09],
       [-0, -0, -1e+09, -1e+09],
       [-0, -0, -0, -1e+09],
       [-0, -0, -0, -0]], dtype=float32)


In [29]:
# Shape check: LayerNorm followed by a freshly initialized Attention layer.
# NOTE(review): this rebinds `hidden_states` to the attention output (no
# residual add), so re-running the cell feeds attention output back into
# attention; use a separate name if the embeddings are needed again.
ln_1 = nn.LayerNorm(model_args.n_embd, eps=model_args.layer_norm_epsilon)
hidden_states = Attention(model_args)(ln_1(hidden_states), mask)
print(hidden_states.shape)

(1, 4, 768)


In [30]:
# Shape check: a freshly initialized MLP maps (B, L, n_embd) back to (B, L, n_embd).
MLP(model_args)(hidden_states).shape

(1, 4, 768)

In [31]:
# Build the full trunk (embeddings, blocks, final LayerNorm). This is GPT2Model,
# not Model, so calling it returns hidden states rather than vocabulary logits.
model = GPT2Model(model_args)

In [13]:
# NOTE(review): this cell's execution count (In [13]) predates the cells around
# it, and its output shows `h` as a Sequential, which no longer matches
# GPT2Model (where `h` is a list). The output below is stale; re-run this cell
# or drop it in favor of the later print(model) cell.
print(model)

GPT2Model(
  (wte): Embedding(50257, 768)
  (wpe): Embedding(1024, 768)
  (drop_embd): Dropout(p=0.09999999999999998)
  (h): Sequential(
    (layers.0): TransformerBlock(
      (attn): Attention(
        (c_attn): Linear(input_dims=768, output_dims=2304, bias=True)
        (c_proj): Linear(input_dims=768, output_dims=768, bias=True)
        (drop_attn): Dropout(p=0.09999999999999998)
      )
      (mlp): MLP(
        (c_fc): Linear(input_dims=768, output_dims=3072, bias=True)
        (c_proj): Linear(input_dims=3072, output_dims=768, bias=True)
      )
      (ln_1): LayerNorm(768, eps=1e-05, affine=True)
      (ln_2): LayerNorm(768, eps=1e-05, affine=True)
      (drop_resid): Dropout(p=0.09999999999999998)
    )
    (layers.1): TransformerBlock(
      (attn): Attention(
        (c_attn): Linear(input_dims=768, output_dims=2304, bias=True)
        (c_proj): Linear(input_dims=768, output_dims=768, bias=True)
        (drop_attn): Dropout(p=0.09999999999999998)
      )
      (mlp): MLP(
  

In [33]:
# Print the module tree; `h` is a Python list, so its blocks appear as h.0, h.1, ...
print(model)

GPT2Model(
  (wte): Embedding(50257, 768)
  (wpe): Embedding(1024, 768)
  (drop_embd): Dropout(p=0.09999999999999998)
  (h.0): TransformerBlock(
    (attn): Attention(
      (c_attn): Linear(input_dims=768, output_dims=2304, bias=True)
      (c_proj): Linear(input_dims=768, output_dims=768, bias=True)
      (drop_attn): Dropout(p=0.09999999999999998)
    )
    (mlp): MLP(
      (c_fc): Linear(input_dims=768, output_dims=3072, bias=True)
      (c_proj): Linear(input_dims=3072, output_dims=768, bias=True)
    )
    (ln_1): LayerNorm(768, eps=1e-05, affine=True)
    (ln_2): LayerNorm(768, eps=1e-05, affine=True)
    (drop_resid): Dropout(p=0.09999999999999998)
  )
  (h.1): TransformerBlock(
    (attn): Attention(
      (c_attn): Linear(input_dims=768, output_dims=2304, bias=True)
      (c_proj): Linear(input_dims=768, output_dims=768, bias=True)
      (drop_attn): Dropout(p=0.09999999999999998)
    )
    (mlp): MLP(
      (c_fc): Linear(input_dims=768, output_dims=3072, bias=True)
      (

In [40]:
# Inspect the first block's ln_1 bias; it prints as zeros here because no
# checkpoint has been loaded into this freshly built model.
print(model.h[0].ln_1.bias)

array([0, 0, 0, ..., 0, 0, 0], dtype=float32)
