In [1]:
import mlx.core as mx
import mlx.nn as nn

In [2]:
# Load dataset
names = []
with open('names.txt', 'r') as file:
    while line := file.readline():
        names.append(line.rstrip())

In [3]:
# Get unique chars in dataset
chars = set()
for name in names:
    for c in name:
        chars.add(c)

# Assign value to each character -- this will be our embedding.
stoi = {ch:i+1 for i,ch in enumerate(sorted(chars))}
stoi['.'] = 0
itos = {i:s for s,i in stoi.items()}

In [4]:
# Build dataset
block_size = 3
X = []
y = []

for name in names[:5]:
    ctx = [0] * block_size
    for c in name + '.':
        idx = stoi[c]
        X.append(ctx)
        y.append(idx)
        ctx = ctx[1:] + [stoi[c]]

X = mx.array(X)
y = mx.array(y)

In [5]:
X.shape, X.dtype, y.shape, y.dtype

([32, 3], mlx.core.int32, [32], mlx.core.int32)

In [6]:
# Create 2D embedding for each letter
C = mx.random.normal([27, 2])

In [7]:
# Embed each letter in each of the examples in the training data, X
emb = C[X]
emb.shape

[32, 3, 2]

In [8]:
# Create new layer
# Input size = 3 x 2; This is b/c each example in `emb` contains 3 chars, each of which have 2 dims.
# Output size = 100; This is arbitrary.
W1 = mx.random.normal([6, 100])
b1 = mx.random.normal([100])

### First Layer
We want to multiply our embedded input by our first layer of weights, add the bias, and perform a tanh function over the results to normalize the output: 
    
    tanh(emb @ W1 + b1)


However, the current shapes of our tensors don't support this multiplication operation:

    emb.shape == [32, 3, 2]
    W1.shape == [6, 100]


To solve this, we need to combine the second and third dimensions of our embedded input tensor, giving us:
    
    emb.shape == [32, 6]


This represents 32 examples of 3 characters, each with a 2-dimensional embedding:
    
    Ex. [Char1FirstEmb, Char1SecondEmb, Char2FirstEmb, Char2SecondEmb, Char3FirstEmb, Char3SecondEmb]

In [9]:
# We can achieve this functionality by using `reshape()`.
# Provides a more memory-efficient way of re-shaping the array
emb_reshaped = mx.reshape(emb, (32, 6))
print(emb_reshaped[:5])
print(emb_reshaped.shape)

array([[0.406554, 1.14167, 0.406554, 1.14167, 0.406554, 1.14167],
       [0.406554, 1.14167, 0.406554, 1.14167, -0.879908, 0.838293],
       [0.406554, 1.14167, -0.879908, 0.838293, 0.261963, 0.0893494],
       [-0.879908, 0.838293, 0.261963, 0.0893494, 0.261963, 0.0893494],
       [0.261963, 0.0893494, 0.261963, 0.0893494, -0.430756, -0.792048]], dtype=float32)
[32, 6]


In [10]:
# Generalize the re-shaping of the tensor to accommodate arbitrary block_size's
emb_flattened = mx.flatten(emb, start_axis=1)
print(emb_flattened[:5])
print(emb_flattened.shape)

array([[0.406554, 1.14167, 0.406554, 1.14167, 0.406554, 1.14167],
       [0.406554, 1.14167, 0.406554, 1.14167, -0.879908, 0.838293],
       [0.406554, 1.14167, -0.879908, 0.838293, 0.261963, 0.0893494],
       [-0.879908, 0.838293, 0.261963, 0.0893494, 0.261963, 0.0893494],
       [0.261963, 0.0893494, 0.261963, 0.0893494, -0.430756, -0.792048]], dtype=float32)
[32, 6]


In [11]:
# Perform the matrix multiplication and apply tanh
h = mx.tanh(emb_flattened @ W1 + b1)
print(h[:5])
print(h.shape)

array([[-0.99985, -0.977032, -0.937331, ..., -0.92223, -0.723898, -0.413056],
       [-0.997375, -0.994642, -0.901601, ..., -0.962323, 0.705885, -0.556825],
       [-0.996248, -0.825479, -0.962809, ..., -0.954416, -0.226425, -0.999833],
       [-0.996994, -0.0197523, -0.924646, ..., -0.261878, -0.534077, -0.0220973],
       [-0.713046, 0.602526, -0.973489, ..., 0.0509802, 0.809074, -0.944905]], dtype=float32)
[32, 100]


### Next Layer
This layer consists of another set of weights and biases, W2 and b2. It produces logits by multiplying the outputs of the previous layer by W2 and adding the bias vector b2.

In [12]:
# Next layer produces logits
W2 = mx.random.normal([100, 27])
b2 = mx.random.normal([27])
logits = h @ W2 + b2
logits.shape

[32, 27]

### Final Layer
To make the logits useful, we must perform a softmax operation. This gives us a vector of normalized probabilities for each character in an example.

In [13]:
# Complete Softmax over all logits (manually)
counts = logits.exp()
prob = counts / counts.sum(1, keepdims=True)

In [14]:
prob.shape

[32, 27]

In [15]:
# Get the probability of the correct character produced by the model, as defined by `Y`
print(y)
prob[mx.arange(32), y]

array([5, 13, 13, ..., 9, 1, 0], dtype=int32)


array([2.38166e-05, 5.00042e-06, 1.00381e-07, ..., 3.27193e-08, 5.95435e-12, 2.14101e-09], dtype=float32)

### Calculate Loss
With these probabilities, we are able to calculate the loss (negative log likelihood).

In [16]:
# For each example, index into the y-th position to retrieve the probability calculated for the correct label.
loss = -prob[mx.arange(32), y].log().mean()
loss

array(13.8181, dtype=float32)

# Swap out explicit operations for library functions below

In [17]:
#mx.default_stream(mx.cpu)

In [18]:
# Build dataset
def build_dataset(words):
    block_size = 3
    X = []
    y = []
    
    for word in words:
        ctx = [0] * block_size
        for c in word + '.':
            idx = stoi[c]
            X.append(ctx)
            y.append(idx)
            ctx = ctx[1:] + [stoi[c]]
    
    X = mx.array(X)
    y = mx.array(y)
    print(X.shape, y.shape)
    return X, y

In [19]:
# Split dataset
import random
random.seed(42)
random.shuffle(names)
n1 = int(0.8 * len(names))
n2 = int(0.9 * len(names))
Xtr, Ytr = build_dataset(names[:n1])
Xdev, Ydev = build_dataset(names[n1:n2])
Xtest, Ytest = build_dataset(names[n2:])

[182625, 3] [182625]
[22655, 3] [22655]
[22866, 3] [22866]


In [20]:
print('Total # of words:', len(names))
print('# of words in training set:', n1)
print('# of words in dev set:', n2 - n1)
print('# of words in test set:', len(names) - n2)

Total # of words: 32033
# of words in training set: 25626
# of words in dev set: 3203
# of words in test set: 3204


In [21]:
# Peek at training dataset shape
Xtr.shape, Ytr.shape

([182625, 3], [182625])

In [22]:
mx.random.seed(42)
C = mx.random.normal([27, 2])
W1 = mx.random.normal([6, 300])
b1 = mx.random.normal([300])
W2 = mx.random.normal([300, 27])
b2 = mx.random.normal([27])
parameters = {'C': C, 'W1': W1, 'b1': b1, 'W2': W2, 'b2': b2}

In [23]:
# Total number of trainable parameters
sum(p.size for p in parameters.values())

10281

In [24]:
# Loss Function -- passing all parameters to calculate the gradient
# NOTE: This includes the forward pass.
def new_loss_fn(params, ix):
    emb_vals = mx.flatten(C[Xtr[ix]], start_axis=1)
    h = mx.tanh(emb_vals @ params['W1'] + params['b1'])
    logits = h @ params['W2'] + params['b2']
    return nn.losses.cross_entropy(logits, Ytr[ix], reduction='mean')

In [25]:
# Function to calculate loss and gradient
loss_and_grad_fn = mx.value_and_grad(new_loss_fn)

In [26]:
# Training Loop
for _ in range(20000):
    
    # Use a minibatch
    ix = mx.random.randint(0, Xtr.shape[0], (32,))

    # Calculate loss
    loss, grads = loss_and_grad_fn(parameters, ix)

    # Update
    for k in parameters.keys():
        parameters[k] += -0.01 * grads[k]

print(loss.item())

2.8529322147369385


In [27]:
# Evaluate loss on entire training set
emb_vals = mx.flatten(C[Xtr], start_axis=1)
h = mx.tanh(emb_vals @ parameters['W1'] + parameters['b1'])
logits = h @ parameters['W2'] + parameters['b2']
nn.losses.cross_entropy(logits, Ytr, reduction='mean').item()

2.825923204421997

In [28]:
# Evaluate loss on dev set
emb_vals = mx.flatten(C[Xdev], start_axis=1)
h = mx.tanh(emb_vals @ parameters['W1'] + parameters['b1'])
logits = h @ parameters['W2'] + parameters['b2']
nn.losses.cross_entropy(logits, Ydev, reduction='mean').item()

2.8314409255981445

## Sample from the model

In [68]:
import numpy as np

mx.random.seed(42 + 10)

# Sample 20 names
for _ in range(20):
    out = []
    context = [0] * block_size
    while True:
        emb_vals = mx.flatten(C[mx.array(context)])
        h = mx.tanh(emb_vals @ parameters['W1'] + parameters['b1'])
        logits = h @ parameters['W2'] + parameters['b2']
        probs = mx.softmax(logits)
        ixList = np.random.multinomial(1, probs.tolist())
        ix = np.where(ixList == 1)[0].item()
        context = context[1:] + [ix]
        out.append(ix)
        if ix == 0:
           break

    # Print result
    print(''.join(itos[i] for i in out))

kar.
ya.
lanei.
naa.
a.
gid.
ber.
maryo.
yahulr.
anaa.
kyrnn.
aumlrc.
basvymoda.
jayeenm.
ikhyldnn.
syzsha.
lan.
jadeavaaysen.
man.
ter.
