In [None]:
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import mlx.utils as utils
import numpy as np
import math

In [None]:
with open("input.txt", "r", encoding='utf-8') as f:
    text = f.read()

vocab = sorted(list(set(text)))
vocab_size = len(vocab)

print(''.join(vocab))
print(vocab_size)

In [None]:
iots = {i: c for i, c in enumerate(vocab)}
stoi = {c: i for i, c in enumerate(vocab)}
def encode(x): return [stoi[c] for c in x]
def decode(x): return ''.join([iots[i] for i in x])


print(encode("hello world"))
print(decode(encode("hello world")))

In [None]:
data = encode(text)
split = int(0.9 * len(data))
train_data = data[:split]
val_data = data[split:]

In [None]:
ctx_len = 128
n_emb = 128
dropout = 0.1
head_size = 128
n_heads = 4
n_layers = 3
num_epochs = 30
batch_size = 64
lr = 1e-3

print(train_data[:ctx_len + 1])

In [None]:
print("inputs: ", train_data[:ctx_len])
print("labels ", train_data[1:ctx_len+1])

In [None]:
X_train = mx.array([train_data[i: i+ctx_len]
                   for i in range(0, len(train_data) - ctx_len, ctx_len)])
y_train = mx.array([train_data[i+1: i+ctx_len+1]
                   for i in range(0, len(train_data) - ctx_len, ctx_len)])
X_val = mx.array([train_data[i: i+ctx_len]
                 for i in range(0, len(val_data) - ctx_len, ctx_len)])
y_val = mx.array([train_data[i+1: i + ctx_len+1]
                 for i in range(0, len(val_data) - ctx_len, ctx_len)])

In [None]:
def get_batches(X, y, b_size, shuffle=True):
    if shuffle:
        ix = np.arange(X.shape[0])
        np.random.shuffle(ix)
        ix = mx.array(ix)
        X = X[ix]
        y = y[ix]

    for i in range(0, X.shape[0], b_size):
        input = X[i: i+b_size]
        label = y[i: i+b_size]
        yield input, label

In [None]:
import math


class Attention(nn.Module):
    def __init__(self, head_size):
        super().__init__()
        self.head_size = head_size
        self.k_proj = nn.Linear(n_emb, head_size, bias=False)
        self.q_proj = nn.Linear(n_emb, head_size, bias=False)
        self.v_proj = nn.Linear(n_emb, head_size, bias=False)

        indices = mx.arange(ctx_len)
        mask = indices[:, None] < indices[None]
        self.causal_mask = mask * -1e9
        # self.c_proj = nn.Linear(head_size, n_emb)
        self.resid_dropout = nn.Dropout(dropout)

    def __call__(self, x):
        B, T, C = x.shape
        K = self.k_proj(x)
        Q = self.q_proj(x)
        V = self.v_proj(x)
        attn_weights = (Q @ K.transpose([0, 2, 1])) / math.sqrt(self.head_size)
        casual_mask = self.causal_mask[:T, :T]
        attn_weights = attn_weights + casual_mask
        attn_weights = mx.softmax(attn_weights, axis=-1)
        attn_weights = self.resid_dropout(attn_weights)
        o = (attn_weights @ V)
        # o = self.c_proj(self.resid_dropout(o))
        # o = self.resid_dropout(o)
        return o

In [None]:
class MultiHeadedAttention(nn.Module):
    def __init__(self):
        super().__init__()
        self.heads = [Attention(head_size//n_heads) for _ in range(n_heads)]

    def __call__(self, x):
        return mx.concatenate([head(x) for head in self.heads], axis=-1)

In [None]:
class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.c_fc = nn.Linear(n_emb, 4*n_emb)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(4*n_emb, n_emb)
        self.dropout = nn.Dropout(dropout)

    def __call__(self, x):
        x = self.gelu(self.c_fc(x))
        x = self.c_proj(x)
        x = self.dropout(x)
        return x

In [None]:
class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.mlp = MLP()
        self.mha = MultiHeadedAttention()
        self.ln_1 = nn.LayerNorm(dims=n_emb)
        self.ln_2 = nn.LayerNorm(dims=n_emb)

    def __call__(self, x):
        x = x + self.mha(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x

In [None]:
class GPT(nn.Module):
    def __init__(self):
        super().__init__()
        self.wte = nn.Embedding(vocab_size, n_emb)
        self.wpe = nn.Embedding(ctx_len, n_emb)
        self.blocks = nn.Sequential(*[Block() for _ in range(n_layers)])
        self.ln_f = nn.LayerNorm(dims=n_emb)
        self.lm_head = nn.Linear(n_emb, vocab_size)
        self._init_parameters()
        total_params = sum(
            [p.size for n, p in utils.tree_flatten(self.parameters())])
        print(f"Total params: {(total_params / 1e6):.3f}M")

    def __call__(self, x):
        B, T = x.shape  # (B = batch_size, T = ctx_len)
        tok_emb = self.wte(x)  # (B, T, n_emb)
        pos_emb = self.wpe(mx.arange(T))  # (T, n_emb)
        x = tok_emb + pos_emb  # (B, T, n_emb)
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)
        return logits

    def generate(self, max_new_tokens):
        ctx = mx.zeros((1, 1), dtype=mx.int32)
        for _ in range(max_new_tokens):
            logits = self(ctx[:, -ctx_len:])
            logits = logits[:, -1, :]
            next_tok = mx.random.categorical(logits, num_samples=1)
            ctx = mx.concatenate((ctx, next_tok), axis=1)
        return ctx

    def _init_parameters(self):
        normal_init = nn.init.normal(mean=0.0, std=0.02)
        residual_init = nn.init.normal(
            mean=0.0, std=(0.02 / math.sqrt(2 * n_layers)))
        new_params = []
        for name, module in self.named_modules():
            if isinstance(module, nn.layers.linear.Linear):
                if 'c_proj' in name:
                    new_params.append(
                        (name + '.weight', residual_init(module.weight)))
                else:
                    new_params.append(
                        (name + '.weight', normal_init(module.weight)))
                if 'bias' in module:
                    new_params.append(
                        (name + '.bias', mx.zeros(module.bias.shape)))
            elif isinstance(module, nn.layers.embedding.Embedding):
                new_params.append(
                    (name + '.weight', normal_init(module.weight)))
        self = self.update(utils.tree_unflatten(new_params))

In [14]:
def loss_fn(model, x, y):
    logits = model(x)
    B, T, C = logits.shape  # (batch_size, seq_len, vocab_size)
    logits = logits.reshape(B*T, C)
    y = y.reshape(B*T)
    loss = nn.losses.cross_entropy(logits, y, reduction='mean')
    return loss


model = GPT()
mx.eval(model.parameters())  # Create the model params (mlx is lazy evaluation)
loss_and_grad = nn.value_and_grad(model, loss_fn)
optimizer = optim.AdamW(learning_rate=lr)


for epoch in range(num_epochs):
    model.train(True)
    running_loss = 0
    batch_cnt = 0
    for input, label in get_batches(X_train, y_train, batch_size):
        batch_cnt += 1
        loss, grads = loss_and_grad(model, input, label)
        optimizer.update(model, grads)
        running_loss += loss.item()
        # compute new parameters and optimizer state
        mx.eval(model.parameters(), optimizer.state)
    avg_train_loss = running_loss / batch_cnt
    model.train(False)  # set eval mode
    running_loss = 0
    batch_cnt = 0
    for input, label in get_batches(X_val, y_val, batch_size):
        batch_cnt += 1
        loss = loss_fn(model, input, label)
        running_loss += loss.item()
    avg_val_loss = running_loss / batch_cnt
    print(
        f"Epoch {epoch:2} | train = {avg_train_loss:.4f} | val = {avg_val_loss:.4f}")

Epoch  0 | train = 3.0429 | val = 2.6857
Epoch  1 | train = 2.5997 | val = 2.4566
Epoch  2 | train = 2.3942 | val = 2.2620
Epoch  3 | train = 2.2442 | val = 2.1274
Epoch  4 | train = 2.1144 | val = 1.9901
Epoch  5 | train = 1.9933 | val = 1.8585
Epoch  6 | train = 1.8854 | val = 1.7536
Epoch  7 | train = 1.7985 | val = 1.6778
Epoch  8 | train = 1.7311 | val = 1.6095
Epoch  9 | train = 1.6791 | val = 1.5661
Epoch 10 | train = 1.6366 | val = 1.5311
Epoch 11 | train = 1.6025 | val = 1.5045
Epoch 12 | train = 1.5760 | val = 1.4750
Epoch 13 | train = 1.5528 | val = 1.4576
Epoch 14 | train = 1.5318 | val = 1.4413
Epoch 15 | train = 1.5145 | val = 1.4254
Epoch 16 | train = 1.4991 | val = 1.4080
Epoch 17 | train = 1.4839 | val = 1.3884
Epoch 18 | train = 1.4723 | val = 1.3807
Epoch 19 | train = 1.4615 | val = 1.3732


In [15]:
completion = decode(model.generate(1000)[0].tolist())
print(completion)
with open('completions.txt', 'w') as f:
    f.write(completion)


That be doubtled, but lady the false can to away. Guliet:
So wady it all in the grows of girl still.

PRINCE EDWARD:
Bou give to him.

ISABELLA:
A patient to unjurniors,
Alas, take to him ill be proad ministate.

BENVOLIO:
To make you his
youness dear?

QUEEN MARGARET:
How no.
Lay, long.

CORIOLANUS:
My life honourable repurage us, stays
That I see three it know up, my lord.

CAMILLO:
They dost Cliffondly? I here thumb, I'll well;
Or not eagle steel the playes, end your glown,
But he finet armanly spake, which he dods maid
To make into be titl title your wexts,
For you royal weeping tune that they bear in me weight;
The will will I stroke your tongue,
That hastised the flatter of do shall be tyrangled.

NORTHUMBERLAND:
What I do you come to your day; hast let him
The jalle of senator'd title to justifiation
you be so you the England, as have with you:
I doss that I will sellow; for here nood,
Some brew I lost by thy desire ring: the nobless,
You were else not to of stears and all thei