In [44]:
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from IPython.display import clear_output

In [45]:
# Training corpus. The character vocabulary below is derived from it, and
# vocab_size fixes the embedding/output shapes, so this presumably has to be
# the same file the checkpoint was trained on.
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# Vocabulary: every distinct character, sorted so token ids are deterministic.
chars = sorted(set(text))
vocab_size = len(chars)

# Lookup tables between integer token ids and characters (both directions).
itos = dict(enumerate(chars))
stoi = {ch: i for i, ch in itos.items()}

def encode(input_string):
    """Convert a string into its list of integer token ids.

    Raises KeyError if the string contains a character outside the vocabulary.
    """
    return list(map(stoi.__getitem__, input_string))

def decode(input_list):
    """Convert an iterable of integer token ids back into a string.

    Raises KeyError for any id outside the vocabulary.
    """
    return ''.join(map(itos.__getitem__, input_list))

# Model hyperparameters. Parameter shapes are derived from these values, so
# they need to agree with the checkpoint loaded further down.
n_layer = 4        # number of transformer blocks
n_head = 4         # attention heads per block
n_embd = 64        # embedding / residual-stream width
block_size = 32    # max context length (rows of the position table and causal mask)
dropout = 0.0      # 0.0 makes every Dropout layer a no-op


In [46]:
class Head(nn.Module):
    """A single causal (masked) self-attention head.

    Projects the input to keys, queries and values of width ``head_size``
    and lets each position attend only to itself and earlier positions.
    """

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(input_dims=n_embd, output_dims=head_size, bias=False)
        self.query = nn.Linear(input_dims=n_embd, output_dims=head_size, bias=False)
        self.value = nn.Linear(input_dims=n_embd, output_dims=head_size, bias=False)
        # Lower-triangular ones: entry (t, s) is 1 when s <= t (allowed) and 0
        # for future positions. NOTE(review): MLX treats public mx.array
        # attributes as parameters, so this is presumably stored in the
        # checkpoint as well -- confirm before renaming it.
        self.tril = mx.tril(mx.ones((block_size, block_size)))
        self.dropout = nn.Dropout(dropout)

    def __call__(self, X):
        _, T, C = X.shape
        keys = self.key(X)
        queries = self.query(X)
        # NOTE(review): scores are scaled by C**-0.5 where C is the input width
        # (n_embd), not the textbook head_size**-0.5. The loaded weights were
        # presumably trained with this factor, so changing it alters outputs.
        scale = C ** -0.5
        scores = (queries @ mx.swapaxes(keys, -1, -2)) * scale
        # Block attention to future positions by forcing their scores to -inf.
        causal = self.tril[:T, :T]
        scores = mx.where(causal == 0, mx.array(float('-inf')), scores)
        weights = self.dropout(nn.softmax(scores, axis=-1))
        values = self.value(X)
        return weights @ values
    
class MultiHeadAttention(nn.Module):
    """Runs ``num_heads`` independent causal attention heads side by side.

    The per-head outputs are concatenated along the last axis, mixed by a
    linear projection and passed through dropout.
    """

    def __init__(self, num_heads, head_size):
        super().__init__()
        # A plain list of sub-modules; MLX tracks the parameters of each entry.
        self.heads = [Head(head_size) for _ in range(num_heads)]
        self.proj = nn.Linear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def __call__(self, X):
        per_head = [head(X) for head in self.heads]
        merged = mx.concatenate(per_head, axis=-1)
        projected = self.proj(merged)
        return self.dropout(projected)
    
class FeedForward(nn.Module):
    """Position-wise MLP: expand to 4x width, ReLU, project back, dropout.

    The layers are kept in ``self.net`` (an ``nn.Sequential``) in this exact
    order; checkpoint parameter names depend on that attribute name and order.
    """

    def __init__(self, n_embd):
        super().__init__()
        hidden_dims = 4 * n_embd
        layers = [
            nn.Linear(n_embd, hidden_dims),
            nn.ReLU(),
            nn.Linear(hidden_dims, n_embd),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def __call__(self, X):
        # Linear layers act on the last axis, so every position is transformed
        # independently.
        out = self.net(X)
        return out
    
class Block(nn.Module):
    """Pre-norm transformer block.

    Applies causal multi-head self-attention and then a feed-forward network.
    Each sub-layer sees a LayerNorm'd input, and its output is added back to
    the residual stream.

    Args:
        n_embd: width of the residual stream.
        n_head: number of attention heads. It must divide ``n_embd`` so that
            the concatenated head outputs (``n_head * head_size``) are exactly
            ``n_embd`` wide, as the attention output projection expects.

    Raises:
        ValueError: if ``n_head`` is not a positive divisor of ``n_embd``.
    """

    def __init__(self, n_embd, n_head):
        super().__init__()
        # Fail early: integer division would silently truncate head_size and
        # only surface later as an opaque shape mismatch in the projection.
        if n_head <= 0 or n_embd % n_head != 0:
            raise ValueError(
                f"n_embd ({n_embd}) must be divisible by a positive n_head ({n_head})"
            )
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size)
        self.ffwd = FeedForward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def __call__(self, X):
        X = X + self.sa(self.ln1(X))
        X = X + self.ffwd(self.ln2(X))
        return X
    
class TransformerModel(nn.Module):
    """Decoder-only, character-level transformer language model.

    The model sums token and learned position embeddings, runs them through
    ``n_layer`` pre-norm ``Block``s and a final LayerNorm, and ends with a
    linear head that produces one logit per vocabulary entry. Sizes come from
    the module-level ``vocab_size``, ``n_embd``, ``n_head``, ``n_layer`` and
    ``block_size``.
    """

    def __init__(self):
        super().__init__()
        # Attribute names form the parameter tree that load_weights matches
        # against, so they must stay in sync with the saved checkpoint.
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def __call__(self, idx):
        """Compute next-token logits at every position.

        Args:
            idx: integer token ids of shape (B, T) with T <= block_size.

        Returns:
            Logits of shape (B, T, vocab_size).

        Raises:
            ValueError: if T exceeds block_size. The position table has only
                block_size rows, so longer inputs would index past its end.
        """
        _, T = idx.shape
        if T > block_size:
            raise ValueError(
                f"sequence length {T} exceeds block_size {block_size}; "
                "crop the context (as generate() does) before calling the model"
            )

        tok_emb = self.token_embedding_table(idx)              # (B, T, n_embd)
        pos_emb = self.position_embedding_table(mx.arange(T))  # (T, n_embd), broadcast over B
        x = tok_emb + pos_emb
        x = self.blocks(x)
        x = self.ln_f(x)
        return self.lm_head(x)

    def generate(self, idx, max_new_tokens, temperature=1.0, verbose=True):
        """Autoregressively sample ``max_new_tokens`` tokens after ``idx``.

        Args:
            idx: integer token ids of shape (B, T) used as the prompt.
            max_new_tokens: number of tokens to append.
            temperature: logits are divided by this before sampling. Values
                below 1 sharpen the distribution and values above 1 flatten
                it. The default of 1.0 samples from the model unchanged.
            verbose: when True (the default), decode and print each sampled
                token of batch row 0 as soon as it is produced.

        Returns:
            Array of shape (B, T + max_new_tokens) holding prompt plus samples.

        Raises:
            ValueError: if ``temperature`` is not positive.
        """
        if temperature <= 0:
            raise ValueError(f"temperature must be positive, got {temperature}")
        for _ in range(max_new_tokens):
            # The model only has block_size positions, so condition on the
            # most recent block_size tokens.
            idx_cond = idx[:, -block_size:]
            logits = self(idx_cond)[:, -1, :]  # last position only: (B, vocab_size)
            if temperature != 1.0:
                logits = logits / temperature
            # categorical samples from softmax(logits). Its index dtype may
            # differ from idx's, so cast to keep idx's dtype stable.
            idx_next = mx.random.categorical(logits, num_samples=1).astype(idx.dtype)
            idx = mx.concatenate((idx, idx_next), axis=-1)
            # Materialize each step so MLX's lazy graph does not keep growing
            # across iterations (matters when verbose=False skips .tolist()).
            mx.eval(idx)
            if verbose:
                print(decode(idx_next[0].tolist()), end='')
        return idx

In [47]:
model = TransformerModel()
# Used for inference only: switch off training-mode behaviour such as dropout.
# With dropout == 0.0 this is currently a no-op, but it keeps sampling correct
# if the dropout rate is ever raised.
model.eval()
model.load_weights('char_level_mpx.safetensors')

In [49]:
# Seed MLX's global PRNG so the sampled text is reproducible across
# Restart & Run All.
mx.random.seed(1337)
# Prompt: a single sequence consisting of token id 0.
context = mx.zeros((1, 1), dtype=mx.int32)
generated = model.generate(context, max_new_tokens=1000)[0].tolist()

SALANUS:
Lord All, thy botch may was ere, thou heaving she upon and negleher and begge heave
the havet Clatio though yuse homost, and so excopperoporforce,
Be
And we to at he pilles
deat them old. my have and wais:
Him:
Vuch thou life, he'rt, ay, one Lord fly a profemmine a creath,
that's know meany and by somembers
Her leave our life, my heart I crost.

First Servant:
As the prince that Benister not thing:
So he'rt up,
Ne's my limped in wither obe with of will persulace.

Fince:
My Provost, go and cries me nope, sit parish, chang gone
With yet I'll eady? not upen ving;
But to rembred ance for truphecy well.
Beshings, have is wore here
Stardertly twells for this neaturm and be night;
My all spare weak shall wear I rest offl men
We Presently me down, when is not unsterning. I we add.

LOUCESTER:
Yet peope as it in poors ords I dany,
And retiek the bairly heads make livise,
To shall chold that prhenclight chamb-seet,
So my darker no the warm impree danio,
That, my clipy in like of thy in