In [64]:
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np

In [1]:
# Read the whole training corpus into memory as one string.
with open('input.txt', mode='r', encoding='utf-8') as corpus_file:
    text = corpus_file.read()

In [2]:
# Character-level vocabulary: every distinct character of the corpus, in sorted order.
chars = sorted(set(text))
vocab_size = len(chars)
print(''.join(chars), vocab_size, sep='\n')


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
65


In [4]:
# Lookup tables between characters and integer ids; an id is the character's position in `chars`.
itos = dict(enumerate(chars))
stoi = {ch: i for i, ch in itos.items()}

def encode(input_string):
    """Map each character of `input_string` to its integer id via `stoi`.

    Returns a list with one int per character, in order. A character that is
    not a key of `stoi` raises KeyError.
    """
    token_ids = []
    for char in input_string:
        token_ids.append(stoi[char])
    return token_ids

def decode(input_list):
    """Turn a sequence of integer ids back into a string using `itos`.

    The characters are concatenated in order. An id that is not a key of
    `itos` raises KeyError.
    """
    return ''.join(itos[token_id] for token_id in input_list)

# Round-trip check: decoding the encoded ids should give back the original string.
round_trip_ids = encode('hi there')
print(round_trip_ids)
print(decode(round_trip_ids))

[46, 47, 1, 58, 46, 43, 56, 43]
hi there


In [36]:
# Encode the full corpus and keep it as one MLX array of token ids.
data = mx.array(encode(text))
print(data.shape, data.dtype, sep='\n')

(1115394,)
mlx.core.int32


In [37]:
# Train/validation split: the first 90% of the tokens are used for training,
# the remaining tokens are held out.
n = int(0.9 * len(data))
train_data, val_data = data[:n], data[n:]

In [38]:
# Number of tokens in one training sequence.
block_size = 8
# Peek at the first block_size + 1 tokens: one extra token so the last input has a target.
train_data[0:block_size + 1]

array([18, 47, 56, ..., 15, 47, 58], dtype=int32)

In [39]:
# One chunk of block_size + 1 tokens contains block_size training examples:
# for each position t, the inputs are the first t + 1 tokens and the target
# is the token that comes right after them.
x, y = train_data[:block_size], train_data[1:block_size + 1]
for t in range(block_size):
    context, target = x[:t + 1], y[t]
    print(f"when input is {context} the target: {target}")

when input is array([18], dtype=int32) the target: array(47, dtype=int32)
when input is array([18, 47], dtype=int32) the target: array(56, dtype=int32)
when input is array([18, 47, 56], dtype=int32) the target: array(57, dtype=int32)
when input is array([18, 47, 56, 57], dtype=int32) the target: array(58, dtype=int32)
when input is array([18, 47, 56, 57, 58], dtype=int32) the target: array(1, dtype=int32)
when input is array([18, 47, 56, 57, 58, 1], dtype=int32) the target: array(15, dtype=int32)
when input is array([18, 47, 56, ..., 58, 1, 15], dtype=int32) the target: array(47, dtype=int32)
when input is array([18, 47, 56, ..., 1, 15, 47], dtype=int32) the target: array(58, dtype=int32)


In [40]:
# Globals read by get_batch: sequences per batch, tokens per sequence.
batch_size, block_size = 4, 8

def get_batch(split):
    """Sample a random batch of input/target sequences from one data split.

    `split == 'train'` reads from `train_data`; any other value reads from
    `val_data`. The batch dimensions come from the globals `batch_size` and
    `block_size`.

    Returns a pair `(x, y)`, each stacked from `batch_size` slices of length
    `block_size`; every row of `y` is the matching row of `x` shifted one
    token forward in the source data.
    """
    source = val_data if split != 'train' else train_data
    # randint's upper bound is exclusive, so start + block_size + 1 never
    # runs past the end of the source array.
    starts = mx.random.randint(0, len(source) - block_size, (batch_size,)).tolist()
    inputs = [source[start:start + block_size] for start in starts]
    targets = [source[start + 1:start + block_size + 1] for start in starts]
    return mx.stack(inputs), mx.stack(targets)

# Draw one training batch and show the inputs next to their shifted targets.
xb, yb = get_batch('train')

print('input:', xb.shape, xb, sep='\n')
print('targets:', yb.shape, yb, sep='\n')

input:
(4, 8)
array([[60, 43, 1, ..., 58, 1, 39],
       [8, 0, 0, ..., 17, 17, 26],
       [53, 59, 58, ..., 43, 10, 0],
       [39, 47, 58, ..., 52, 1, 44]], dtype=int32)
targets:
(4, 8)
array([[43, 1, 51, ..., 1, 39, 1],
       [0, 0, 29, ..., 17, 26, 1],
       [59, 58, 1, ..., 10, 0, 21],
       [47, 58, 1, ..., 1, 44, 53]], dtype=int32)


In [80]:
class BigramLanguageModel(nn.Module):
    """Bigram language model: next-token logits depend only on the current token.

    The only layer is an embedding with `vocab_size` entries of dimension
    `vocab_size`, so looking up a token id directly yields one logit per
    vocabulary entry.
    """

    def __init__(self, vocab_size):
        super().__init__()
        # Embedding dimension equals vocab_size so a lookup is already a logit vector.
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def __call__(self, idx):
        """Return the embedding lookup (used as logits) for every token id in `idx`."""
        return self.token_embedding_table(idx)

    def generate(self, idx, max_new_tokens):
        """Append `max_new_tokens` sampled tokens to `idx` and return the result.

        `idx` is indexed as (batch, time). At every step the logits of the
        last time position are sampled once per row and the sample is
        concatenated onto the last axis of `idx`.
        """
        for _ in range(max_new_tokens):
            # Only the final time step is used to pick the next token.
            last_logits = self(idx)[:, -1, :]
            sampled = mx.random.categorical(last_logits, num_samples=1)
            idx = mx.concatenate((idx, sampled), axis=-1)
        return idx
    
def loss_fn(model, X, y):
    """Mean cross-entropy between the model's logits for `X` and the targets `y`.

    The model output is unpacked as (B, T, C). The first two axes are
    flattened so that every position counts as one classification example,
    and the per-example losses are averaged into a single scalar.
    """
    logits = model(X)
    B, T, C = logits.shape
    flat_logits = logits.reshape(B * T, C)
    flat_targets = y.reshape(B * T)
    return nn.losses.cross_entropy(flat_logits, flat_targets, reduction='mean')
        

In [113]:
# Instantiate the bigram model with one embedding row per vocabulary character.
m = BigramLanguageModel(vocab_size=vocab_size)

In [114]:
# Wrap loss_fn so that one call returns both the loss and the gradients
# (consumed in the training loop as `loss, grads = loss_and_grad_fn(m, xb, yb)`).
loss_and_grad_fn = nn.value_and_grad(m, loss_fn)

In [115]:
# AdamW optimizer with a learning rate of 1e-3.
optimizer = optim.AdamW(learning_rate=1e-3)

In [116]:
# Sample 100 tokens from the model, seeded with a single token of id 0, and print them as text.
start_tokens = mx.zeros((1, 1), dtype=mx.int32)
generated = m.generate(idx=start_tokens, max_new_tokens=100)
print(decode(generated[0].tolist()))


!,?s!qlTyA.Lmywz K
yw:v'ES'rCzb-k'S3RjlM?oUuW&ocO:X
j$iyBbzXySh.U- KMV hNA$p;dBBm--WSYuKAHdRFC3CzsNR


In [121]:
# Training loop: sample a mini-batch, compute the loss and gradients, then
# apply one optimizer step to the model.
batch_size = 32  # get_batch reads this global, so batches are larger from here on

max_steps = 5000     # number of optimizer steps to run
log_interval = 500   # print the current mini-batch loss every this many steps

for steps in range(max_steps):
    xb, yb = get_batch('train')
    loss, grads = loss_and_grad_fn(m, xb, yb)
    if steps % log_interval == 0:
        print(f'step {steps} loss = {loss}')

    optimizer.update(m, grads)
    # MLX evaluates lazily: without an explicit eval the parameter updates
    # only pile up as an ever-growing graph until something (such as the
    # periodic print above) forces them. Materialize the new parameters and
    # optimizer state once per step so each iteration does its own work.
    mx.eval(m.parameters(), optimizer.state)

# `steps` is kept as the loop variable name because it is inspected after the loop.
print(f'step {steps + 1} loss = {loss}')

step 0 loss = array(2.53058, dtype=float32)
step 500 loss = array(2.47609, dtype=float32)
step 1000 loss = array(2.54961, dtype=float32)
step 1500 loss = array(2.48677, dtype=float32)
step 2000 loss = array(2.43697, dtype=float32)
step 2500 loss = array(2.54745, dtype=float32)
step 3000 loss = array(2.41424, dtype=float32)
step 3500 loss = array(2.47758, dtype=float32)
step 4000 loss = array(2.45294, dtype=float32)
step 4500 loss = array(2.52854, dtype=float32)
step 5000 loss = array(2.45235, dtype=float32)


In [122]:
# Sample 100 tokens from the current model, seeded with a single token of id 0,
# and print them as text.
start_tokens = mx.zeros((1, 1), dtype=mx.int32)
generated = m.generate(idx=start_tokens, max_new_tokens=100)
print(decode(generated[0].tolist()))


Wh.
bo ofimyons.
BE been whe t o s t beee s d mape ars; my,
KI as.
Yo bongr tir ter than
Bertanthanc


In [120]:
# Display the current value of `steps`, the index variable of the training loop.
steps

4999