In [28]:
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from mlx.utils import tree_flatten
import numpy as np
from functools import partial
import time
import math
import matplotlib.pyplot as plt # for making figures
%matplotlib inline

In [29]:
words = open('names.txt', 'r').read().splitlines()
chars = sorted(list(set(''.join(words))))
stoi = {s:i+1 for i,s in enumerate(chars)}
stoi['.'] = 0
itos = {i:s for s,i in stoi.items()}

In [40]:
batch_size = 32
context_size = 3 # Context size in tokens of the model
num_iters = 100_000 # Iterations to train for
learning_rate = 1e-1 # SGD learning rate
lr_warmup = 200 # LR linear warmup iterations
weight_decay = 1e-4 # Set the weight decay
steps_per_eval = 1000 # Number of training steps between validations
steps_per_report = 10 # Number of training steps between loss reporting
optimizer = optim.SGD(
  learning_rate=learning_rate,
  weight_decay=weight_decay
)

In [31]:
# build the dataset
def build_dataset(words, label=''):
  X, Y = [], [] # inputs, labels
  for w in words:
    # print(w)
    context = [0] * context_size
    for ch in w + '.':
      ix = stoi[ch]
      X.append(context)
      Y.append(ix)
      # print(''.join(itos[i] for i in context), '--->', itos[ix])
      context = context[1:] + [ix] # crop and append (rolling window of context)
  X = mx.array(X)
  Y = mx.array(Y)
  print(label, X.shape, Y.shape)
  return X, Y

np.random.seed(42)
np.random.shuffle(words)
n1 = int(0.8 * len(words))
n2 = int(0.9 * len(words))

X_train, Y_train = build_dataset(words[:n1], label='train')
X_val, Y_val = build_dataset(words[n1:n2], label='validation')
X_test, Y_test = build_dataset(words[n2:], label='test')

train (182671, 3) (182671,)
validation (22784, 3) (22784,)
test (22691, 3) (22691,)


In [32]:
class MLP(nn.Module):
  def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
    super().__init__()
    layer_sizes = [input_dim] + [hidden_dim] * num_layers + [output_dim]
    self.layers = [
      nn.Linear(idim, odim, bias=True) for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:])
    ]

  def __call__(self, x):
    for l in self.layers[:-1]:
      x = nn.tanh(l(x))
    return self.layers[-1](x)

In [33]:
key = mx.random.key(2147483647)
E = mx.random.normal(shape=(27, 10), key=key)
input_dim = 3 * 10
hidden_dim = 200
output_dim = 27
num_layers = 2
model = MLP(input_dim, hidden_dim, output_dim, num_layers)
mx.eval(model.parameters())

In [34]:
def loss_fn(model, x, y):
  logits = model(x)
  losses = nn.losses.cross_entropy(logits, y)
  return mx.mean(losses)

In [35]:
def to_samples(context_size, dataset):
  tokens = dataset.size
  window_size = context_size + 1  # include target
  samples = tokens - window_size + 1
  X = np.lib.stride_tricks.as_strided(
    dataset,
    shape=(samples, window_size),
    strides=(dataset.itemsize, dataset.itemsize),
  )
  return X[:, :-1], X[:, 1:]

In [36]:
def iterate_batches(batch_size, context_size, dataset):
  inputs, targets = to_samples(context_size, dataset)
  s = 0
  while True:
    if s == 0:
      # Reset permutation:
      perm = np.random.permutation(inputs.shape[0])
    ids = perm[s : s + batch_size]
    yield inputs[ids], targets[ids]
    s += batch_size
    if s >= inputs.shape[0]:
      s = 0

In [37]:
def eval_fn(context_size, dataset):
  inputs, targets = map(mx.array, to_samples(context_size, dataset))
  loss = 0
  for s in range(0, targets.shape[0], 32):
    bx, by = inputs[s : s + 32], targets[s : s + 32]
    bx, by = map(mx.array, (bx, by))
    losses = loss_fn(model, inputs, targets, reduce=False)
    loss += mx.sum(losses).item()
  return loss / len(targets)

In [38]:
state = [model.state, optimizer.state]

In [39]:
@partial(mx.compile, inputs=state, outputs=state)
def step(inputs, targets):
  loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
  loss, grads = loss_and_grad_fn(model, inputs, targets)
  optimizer.update(model, grads)
  return loss

In [42]:
iterations = []
losses = []
tic = time.perf_counter()
for it in range(num_iters):
  itx = mx.random.randint(0, X_train.shape[0], (batch_size,))

  inputs = E[X_train[itx]]
  inputs = inputs.reshape(inputs.shape[0], inputs.shape[1] * inputs.shape[2])
  targets = Y_train[itx]
  
  optimizer.learning_rate = min(1, it / lr_warmup) * learning_rate
  
  loss = step(inputs, targets)
  mx.eval(state)
  losses.append(loss.item())
  iterations.append(it)
  
  # if (it + 1) % steps_per_report == 0:
  #   train_loss = np.mean(losses)
  #   toc = time.perf_counter()
  #   print(
  #     f"Iter {it + 1}: Train loss {train_loss:.3f}, "
  #     f"It/sec {steps_per_report / (toc - tic):.3f}"
  #   )
  #   losses = []
  #   tic = time.perf_counter()
  # if (it + 1) % steps_per_eval == 0:
  #   val_loss = eval_fn(X_val)
  #   toc = time.perf_counter()
  #   print(
  #     f"Iter {it + 1}: "
  #     f"Val loss {val_loss:.3f}, "
  #     f"Val ppl {math.exp(val_loss):.3f}, "
  #     f"Val took {(toc - tic):.3f}s, "
  #   )
  #   tic = time.perf_counter()

# test_loss = eval_fn(X_test)
# test_ppl = math.exp(test_loss)
# print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")

ValueError: [eval] Attempting to eval an array without a primitive.

In [None]:
plt.plot(iterations, losses)