In [2]:
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

In [1]:
# hyperparameters

batch_size = 16      # sequences per training batch
block_size = 32      # context length: size of the position table and causal mask
max_iters = 5000     # optimizer steps
eval_interval = 100  # steps between loss reports
learning_rate = 1e-3
eval_iters = 200     # batches averaged per split when estimating loss
n_embd = 64          # embedding width
n_head = 4           # attention heads per block
n_layer = 4          # transformer blocks
dropout = 0.0
seed = 1337

# Seed MLX's global PRNG so weight init, batch sampling and generation are reproducible.
mx.random.seed(seed)

In [3]:
# Read the training corpus and build a character-level vocabulary.
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

chars = sorted(set(text))  # every distinct character, in sorted order
vocab_size = len(chars)

# Lookup tables in both directions between characters and integer ids.
itos = dict(enumerate(chars))
stoi = {ch: i for i, ch in itos.items()}

def encode(input_string):
    """Map each character of ``input_string`` to its id via ``stoi``.

    Raises KeyError for a character that is not in the vocabulary.
    """
    return list(map(stoi.__getitem__, input_string))

def decode(input_list):
    """Inverse of ``encode``: turn an iterable of ids back into a string via ``itos``."""
    pieces = (itos[token_id] for token_id in input_list)
    return ''.join(pieces)

In [4]:
# Tokenize the full corpus into a single 1-D array of character ids.
data = mx.array(encode(text))

# Hold out the final 10% of the corpus for validation.
n = int(len(data) * 0.9)
train_data, val_data = data[:n], data[n:]

def get_batch(split):
    """Sample a random batch of (input, target) windows from one data split.

    Args:
        split: 'train' selects ``train_data``; any other value selects ``val_data``.

    Returns:
        (x, y), each of shape (batch_size, block_size). ``y[b, t]`` is the token
        that follows ``x[b, t]`` in the corpus (the next-character target).
    """
    source = train_data if split == 'train' else val_data
    # Start offsets are drawn below len - block_size so the shifted target window fits.
    starts = mx.random.randint(0, len(source) - block_size, (batch_size,))
    # (batch_size, block_size) matrix of absolute positions. One gather replaces
    # per-element .item() host syncs plus Python-level slicing and stacking.
    positions = starts[:, None] + mx.arange(block_size)
    x = source[positions]
    y = source[positions + 1]
    return x, y

def loss_fn(model, X, y):
    """Mean token-level cross-entropy of ``model(X)`` against targets ``y``.

    The model's 3-D logits (batch, time, vocab) and the targets are flattened
    so every position counts as one classification example.
    """
    logits = model(X)
    batch, steps, vocab = logits.shape
    flat_logits = logits.reshape(batch * steps, vocab)
    flat_targets = y.reshape(batch * steps)
    per_token = nn.losses.cross_entropy(flat_logits, flat_targets)
    return mx.mean(per_token)

def estimate_loss(model):
    """Average the loss over ``eval_iters`` random batches of each split.

    Dropout is disabled while measuring. The model's previous training/eval
    mode is restored afterwards, even if an exception is raised.

    Returns:
        dict mapping 'train' and 'val' to a scalar mx.array (callers use ``.item()``).
    """
    was_training = model.training
    model.eval()
    out = {}
    try:
        for split in ['train', 'val']:
            batch_losses = []
            for _ in range(eval_iters):
                X, y = get_batch(split)
                batch_losses.append(loss_fn(model, X, y).item())
            out[split] = mx.mean(mx.array(batch_losses))
    finally:
        model.train(was_training)
    return out


In [5]:
class Head(nn.Module):
    """A single causal self-attention head.

    Projects the input (width ``n_embd``) to keys, queries and values of width
    ``head_size``. Future positions are masked out, and the attention weights
    (with dropout) are applied to the values.
    """

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(input_dims=n_embd, output_dims=head_size, bias=False)
        self.query = nn.Linear(input_dims=n_embd, output_dims=head_size, bias=False)
        self.value = nn.Linear(input_dims=n_embd, output_dims=head_size, bias=False)
        # Lower-triangular causal mask. mlx.nn.Module treats every public mx.array
        # attribute as a parameter. Freezing keeps it out of the gradients and the
        # optimizer (AdamW's weight decay would otherwise shrink it each step) while
        # keeping the saved-weights layout unchanged.
        self.tril = mx.tril(mx.ones((block_size, block_size)))
        self.freeze(keys=["tril"], recurse=False)
        self.dropout = nn.Dropout(dropout)

    def __call__(self, X):
        """Attend over ``X`` (batch, T, n_embd) with T <= block_size; returns (batch, T, head_size)."""
        _, T, _ = X.shape
        k = self.key(X)
        q = self.query(X)
        # Scaled dot-product attention divides by sqrt(d_k), the key width
        # (head_size). The input width n_embd is the wrong divisor here.
        wei = q @ k.swapaxes(-1, -2) * k.shape[-1] ** -0.5
        wei = mx.where(self.tril[:T, :T] == 0, mx.array(float('-inf')), wei)
        wei = nn.softmax(wei, axis=-1)
        wei = self.dropout(wei)
        v = self.value(X)
        return wei @ v
    
class MultiHeadAttention(nn.Module):
    """Runs ``num_heads`` causal ``Head``s in parallel and fuses their outputs.

    Head outputs are concatenated along the last axis and passed through an
    ``n_embd -> n_embd`` projection, so ``num_heads * head_size`` must equal
    ``n_embd``. Dropout is then applied.
    """

    def __init__(self, num_heads, head_size):
        super().__init__()
        heads = []
        for _ in range(num_heads):
            heads.append(Head(head_size=head_size))
        self.heads = heads
        self.proj = nn.Linear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def __call__(self, X):
        head_outputs = [head(X) for head in self.heads]
        fused = mx.concatenate(head_outputs, axis=-1)
        projected = self.proj(fused)
        return self.dropout(projected)
    
class FeedForward(nn.Module):
    """Position-wise MLP: expand to 4 * n_embd, ReLU, project back, then dropout."""

    def __init__(self, n_embd):
        super().__init__()
        hidden_dims = 4 * n_embd
        layers = [
            nn.Linear(n_embd, hidden_dims),
            nn.ReLU(),
            nn.Linear(hidden_dims, n_embd),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def __call__(self, X):
        return self.net(X)
    
class Block(nn.Module):
    """Pre-LayerNorm transformer block.

    Applies a residual self-attention sub-layer, then a residual feed-forward
    sub-layer. Each sub-layer sees a LayerNorm-ed copy of its input.
    """

    def __init__(self, n_embd, n_head):
        super().__init__()
        self.sa = MultiHeadAttention(n_head, n_embd // n_head)
        self.ffwd = FeedForward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def __call__(self, X):
        attended = X + self.sa(self.ln1(X))
        return attended + self.ffwd(self.ln2(attended))
    
class TransformerModel(nn.Module):
    """Decoder-only, character-level transformer language model.

    Token and learned position embeddings are summed, then passed through
    ``n_layer`` ``Block``s and a final LayerNorm. The result is projected to
    ``vocab_size`` logits.
    """

    def __init__(self):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def __call__(self, idx):
        """Map token ids ``idx`` of shape (B, T) to logits of shape (B, T, vocab_size).

        Raises:
            ValueError: if T exceeds ``block_size``. The position table has only
                ``block_size`` rows, so longer inputs cannot be embedded.
        """
        _, T = idx.shape
        if T > block_size:
            raise ValueError(f"sequence length {T} exceeds block_size {block_size}")

        tok_emb = self.token_embedding_table(idx)
        pos_emb = self.position_embedding_table(mx.arange(T))  # (T, n_embd), broadcast over B
        x = tok_emb + pos_emb
        x = self.blocks(x)
        x = self.ln_f(x)
        return self.lm_head(x)

    def generate(self, idx, max_new_tokens):
        """Autoregressively append ``max_new_tokens`` sampled tokens to ``idx``.

        Args:
            idx: (B, T0) integer token ids used as the prompt.
            max_new_tokens: number of tokens to sample.

        Returns:
            (B, T0 + max_new_tokens) array of ids with the same dtype as ``idx``.

        Sampling runs with dropout disabled. The model's previous training/eval
        mode is restored afterwards.
        """
        was_training = self.training
        self.eval()
        try:
            for _ in range(max_new_tokens):
                # Only the last block_size tokens fit the position table.
                idx_cond = idx[:, -block_size:]
                logits = self(idx_cond)[:, -1, :]
                idx_next = mx.random.categorical(logits, num_samples=1)
                # Cast so the sampled ids don't change idx's dtype through promotion.
                idx = mx.concatenate((idx, idx_next.astype(idx.dtype)), axis=-1)
                # Evaluate each step so MLX's lazy graph doesn't grow with every token.
                mx.eval(idx)
        finally:
            self.train(was_training)
        return idx

In [6]:
# Build the model and materialize its parameters now. MLX initializes lazily, so
# without this the first training step would also carry the init computation.
model = TransformerModel()
mx.eval(model.parameters())

In [7]:
# Wrap loss_fn so a single call yields (loss, grads) for the training loop.
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
# AdamW consumes those grads via optimizer.update(model, grads).
optimizer = optim.AdamW(learning_rate=learning_rate)

In [8]:
for step in range(max_iters):
    # Report averaged train/val loss periodically, and always on the final step.
    if step % eval_interval == 0 or step == max_iters - 1:
        losses = estimate_loss(model)
        print(f"step {step}: train loss {losses['train'].item():.4f}, val loss {losses['val'].item():.4f}")

    xb, yb = get_batch('train')

    loss, grads = loss_and_grad_fn(model, xb, yb)
    optimizer.update(model, grads)
    # MLX is lazy: force the parameter and optimizer-state update every step so the
    # compute graph doesn't accumulate across iterations until the next .item().
    mx.eval(model.parameters(), optimizer.state)

step 0: train loss 4.4852, val loss 4.4830
step 100: train loss 2.5373, val loss 2.5373


: 

In [None]:
# Prompt with a single token of id 0 and sample 500 new characters.
context = mx.zeros((1, 1), dtype=mx.int32)
sampled_ids = model.generate(context, max_new_tokens=500)
print(decode(sampled_ids[0].tolist()))


More good, my kinger him heart. I will me, shalt in myself,
More madeather flikestee ard wart's grated my years too!
Tybals ause to expecotegh; I way sir, Varmemean your Kecour
Tore.

GRETS:
The villain:
One myspity blen you reads men every, conce privilegether, lyself, by my respy;
And thiness to be me nagmeman peak?

KING RICHARD II:
If I like I am to will
Oxink, rame grment to thing but hit
he was
year?

MENENIUS:
The will way, an Citizen 'liege will doo,
till I disteou one of stribysernick,



In [None]:
# NOTE(review): 'mpx' in the filename is presumably meant to be 'mlx'. It is left
# as-is so existing paths that load this file keep working.
weights_path = 'char_level_mpx.safetensors'
model.save_weights(weights_path)