In [215]:
# read in all the words
words = open('names.txt', 'r').read().splitlines()
words[:8]

['emma', 'olivia', 'ava', 'isabella', 'sophia', 'charlotte', 'mia', 'amelia']

In [216]:
# build the vocabulary of characters and mappings to/from integers
chars = sorted(list(set(''.join(words))))
stoi = {s:i+1 for i,s in enumerate(chars)}
stoi['.'] = 0
itos = {i:s for s,i in stoi.items()}
vocab_size = len(itos)
print(itos)
print(vocab_size)

{1: 'a', 2: 'b', 3: 'c', 4: 'd', 5: 'e', 6: 'f', 7: 'g', 8: 'h', 9: 'i', 10: 'j', 11: 'k', 12: 'l', 13: 'm', 14: 'n', 15: 'o', 16: 'p', 17: 'q', 18: 'r', 19: 's', 20: 't', 21: 'u', 22: 'v', 23: 'w', 24: 'x', 25: 'y', 26: 'z', 0: '.'}
27


In [217]:
import jax.numpy as jnp
from jax import grad, value_and_grad, jit, vmap
from jax import random as jrandom
from jax import nn
import numpy as np
from tqdm import tqdm

In [218]:
# build the dataset
block_size = 3 # context length: how many characters do we take to predict the next one?

def build_dataset(words):
  X, Y = [], []

  for w in words:
    context = [0] * block_size
    for ch in w + '.':
      ix = stoi[ch]
      X.append(context)
      Y.append(ix)
      context = context[1:] + [ix] # crop and append

  X = jnp.array(X)
  Y = jnp.array(Y)
  print(X.shape, Y.shape)
  return X, Y

import random as prandom
prandom.seed(42)
prandom.shuffle(words)
n1 = int(0.8*len(words))
n2 = int(0.9*len(words))

Xtr,  Ytr  = build_dataset(words[:n1])     # 80%
Xdev, Ydev = build_dataset(words[n1:n2])   # 10%
Xte,  Yte  = build_dataset(words[n2:])     # 10%


(182625, 3) (182625,)
(22655, 3) (22655,)
(22866, 3) (22866,)


In [219]:
key, *g = jrandom.split(jrandom.key(43), 6)

C = jrandom.normal(g[0], (27, 10)) * 0.01
W1 = jrandom.normal(g[1], (30, 200)) * 0.01
b1 = jrandom.normal(g[2], 200) * 0.01
W2 = jrandom.normal(g[3], (200, 27)) * 0.01
b2 = jrandom.normal(g[4], 27) * 0.01
parameters = [C, W1, b1, W2, b2]

In [220]:
sum(np.prod(list(p.shape)) for p in parameters)

np.int64(11897)

In [221]:
# lri = []
lossi = []
stepi = []

In [222]:
def cross_entropy_loss(logits, targets):
    log_probs = nn.log_softmax(logits)
    nll = -jnp.take_along_axis(log_probs, targets[:, jnp.newaxis], axis=1)
    return jnp.mean(nll)

In [223]:
@jit
def forward(parameters, X):
  C, W1, b1, W2, b2 = parameters
  emb = C[X] # (32, 3, 10)
  h = jnp.tanh(emb.reshape(-1, 30) @ W1 + b1) # (32, 200)
  logits = h @ W2 + b2 # (32, 27)
  return logits

In [224]:
@jit
def calc_loss(parameters, X, targets):
    logits = forward(parameters, X)
    log_probs = nn.log_softmax(logits)
    nll = -jnp.take_along_axis(log_probs, targets[:, jnp.newaxis], axis=1)
    return jnp.mean(nll)

In [225]:
with tqdm(total=5000) as pbar:
  for i in range(5000):

    # minibatch construct
    key, ix_key = jrandom.split(key)
    ix = jrandom.randint(ix_key, (32,), 0, Xtr.shape[0])

    # forward pass
    loss, grads = value_and_grad(calc_loss)(parameters, Xtr[ix], Ytr[ix])
    # print(loss.item())
    # print(grads)

    # update
    #lr = lrs[i]
    lr = jnp.array(0.1)

    for j in range(len(parameters)):
      parameters[j] += -lr * grads[j]

    # track stats
    #lri.append(lre[i])
    stepi.append(i)
    lossi.append(jnp.log10(loss).item())

    pbar.update(1)
    if i % 100 == 0:
      pbar.set_postfix(loss=loss.item())

    # break

  #print(loss.item())

100%|██████████| 5000/5000 [00:50<00:00, 98.68it/s, loss=2.48]


In [226]:
val_loss = calc_loss(parameters, Xdev, Ydev)
val_loss

Array(2.436199, dtype=float32)

In [227]:
for _ in range(10):
    out = []
    context = [0] * block_size # initialize with all ...
    while True:
      logits = forward(parameters, jnp.array([context]))
      probs = nn.softmax(logits, axis=1)
      key, key_split = jrandom.split(key)
      ix = jnp.argmax(jrandom.multinomial(key_split, 1, probs), axis=1).item()
      context = context[1:] + [ix]
      out.append(ix)
      if ix == 0:
        break

    print(''.join(itos[i] for i in out))

nainn.
ala.
ayrlya.
dat.
biunan.
yley.
kiya.
cabin.
hadyna.
raqa.
