In [1]:
#imports
import jax
import jax.numpy as jnp
import jax.random as random
import optax
from tokenizers import CharBPETokenizer


gpu_device = jax.device_get('gpu')[0]
cpu_device = jax.device_get('cpu')[0]
# LSTM
# xs = B, input_size = B, T, C
# h = c = y = B, output_size = B, T, logits_size = B, T, vocab_size


In [2]:
#dataset
with open('data/shakespeare.txt', 'r') as file:
  dataset = file.read()

# tokenize
vocab = sorted(list(set(dataset)))
print("vocab length:", len(vocab))

token_to_char = dict(enumerate(vocab))
char_to_token = dict([(v, k) for k, v in token_to_char.items()])
decode = lambda tokens: "".join([token_to_char[int(token)] for token in tokens])
encode = lambda chars: jnp.array([char_to_token[c] for c in chars])

print("dog", encode("dog"), decode(encode("dog")))

dataset_tokens = encode(dataset)

vocab length: 65
dog [42 53 45] dog


In [50]:
lstm_blocks = 1
model_size = 128

input_size = len(vocab) # just do one-hot for now
hidden_size = model_size
output_size = len(vocab) # logits => one-hot => tokens


# init LSTM params
def init_LSTM_params(key, lstm_blocks, input_size, model_size, output_size):
  layers = 8
  keys = random.split(key, layers*lstm_blocks + 2)
  hxconcat_size = input_size + model_size
  he = lambda rkey, shape: random.normal(rkey, shape=shape) * jnp.sqrt(2 / shape[0])
  params = [
    {
      "wU" : he(keys[layers*i + 0], (hxconcat_size, model_size)),
      "bU" : jnp.zeros((model_size,)),
      "wC" : he(keys[layers*i + 6], (hxconcat_size, model_size)),
      "bC" : jnp.zeros((model_size,)),
      "wF1": he(keys[layers*i + 1], (hxconcat_size, model_size)),
      "bF1": jnp.zeros((model_size,)),
      "wF2": he(keys[layers*i + 2], (hxconcat_size, model_size)),
      "bF2": jnp.zeros((model_size,)),
      "wO" : he(keys[layers*i + 3], (hxconcat_size, model_size)),
      "bO" : jnp.zeros((model_size,)),
      # this is for the y layer, which i am probably imlementing wrong.
      "wY1" : he(keys[layers*i + 4], (model_size, output_size)),
      "bY1" : jnp.zeros((output_size,)),
      "wY2" : he(keys[layers*i + 5], (output_size, output_size)),
      "bY2" : jnp.zeros((output_size,)),
    }
    for i in range(lstm_blocks)
  ]
  params[0].update(
    {
    "h0" : random.normal(keys[layers*(layers - 1) + 0], shape=(model_size, )) * jnp.sqrt(2 / model_size),
    "c0" : random.normal(keys[layers*(layers - 1) + 1], shape=(model_size, )) * jnp.sqrt(2 / model_size),
  })
  return params

# x: input_size
# h: hidden_size (same as model_size)
# c: hidden_size (same as model_size)



def lstm_step(lstm_params, xs, h, c, i):
  hxconcat = jax.lax.concatenate([h, xs[:, i]], dimension=1) #B, h ++ B, C => B, h+c
  # update gate
  update = jax.nn.sigmoid(hxconcat @ lstm_params[i]["wU"] + lstm_params[i]["bU"])
  candidate = jax.nn.tanh(hxconcat @ lstm_params[i]["wC"] + lstm_params[i]["bC"])
  c = c + update * candidate # (batch, c) => (batch, c)
  # forget gate
  forget1 = jax.nn.sigmoid(hxconcat @ lstm_params[i]["wF1"] + lstm_params[i]["bF1"])
  forget2 = jax.nn.tanh(hxconcat @ lstm_params[i]["wF2"] + lstm_params[i]["bF2"])
  forget = forget1 * forget2
  c = c + forget # (batch, c) => (batch, c)

  # output
  o = jax.nn.sigmoid(hxconcat @ lstm_params[i]["wO"] + lstm_params[i]["bO"])  # B, model_size
  h = jax.nn.tanh(c) * o # (B, model_size)
  return h, c

# LSTM forward
import functools
@functools.partial(jax.jit, static_argnames=['batches'])
def lstm_forward(batches, lstm_params, xs):
  logits_ts = []
  i = 0
  while i < lstm_blocks:
    if i == 0:
      # setup
      # step 0
      # initialize h-1 and c-1 as zeros
      h = jnp.tile(lstm_params[0]["h0"], (batches, 1))
      c = jnp.tile(lstm_params[0]["c0"], (batches, 1))
    hxconcat = jax.lax.concatenate([h, xs[:, i]], dimension=1) #B, h ++ B, C => B, h+c
    # update gate
    update = jax.nn.sigmoid(hxconcat @ lstm_params[i]["wU"] + lstm_params[i]["bU"])
    candidate = jax.nn.tanh(hxconcat @ lstm_params[i]["wC"] + lstm_params[i]["bC"])
    c = c + update * candidate # (batch, c) => (batch, c)
    # forget gate
    forget1 = jax.nn.sigmoid(hxconcat @ lstm_params[i]["wF1"] + lstm_params[i]["bF1"])
    forget2 = jax.nn.tanh(hxconcat @ lstm_params[i]["wF2"] + lstm_params[i]["bF2"])
    forget = forget1 * forget2
    c = c + forget # (batch, c) => (batch, c)

    # output
    o = jax.nn.sigmoid(hxconcat @ lstm_params[i]["wO"] + lstm_params[i]["bO"])  # B, model_size
    h = jax.nn.tanh(c) * o # (B, model_size)

    # this is wrong but.. whatever
    y = h @ lstm_params[i]['wY1'] + lstm_params[i]["bY1"]
    #y = y @ lstm_params[i]['wY2'] + lstm_params[i]["bY2"]

    logits_ts.append(y)
    i += 1
  logits = jnp.transpose(jnp.array(logits_ts), axes=(1, 0, 2)) # T, B, C => B, T, C
  return logits

@jax.jit
def loss(lstm_params, xs, ys):
  batches = xs.shape[0] # B, T, C
  logits = lstm_forward(batches, lstm_params, xs)
  vocab_size = logits.shape[-1]
  ys_one_hot = jax.nn.one_hot(ys, vocab_size, axis=-1)
  logprobs = jax.nn.log_softmax(logits, axis=-1)
  crossentropylosses = -jnp.sum(ys_one_hot * logprobs, axis=-1)
  crossentropyloss = jnp.mean(crossentropylosses)
  return crossentropyloss

lr = 2e-3 # thanks karpathy
optimizer = optax.adam(learning_rate=lr)

# make optimizer a static arg in jit or it breaks
@jax.jit
def train(lstm_params, xs, ys, opt_state):
  step_loss, grads = jax.value_and_grad(loss)(lstm_params, xs, ys)
  param_updates, updated_opt_state = optimizer.update(grads, opt_state, lstm_params)
  updated_lstm_params = optax.apply_updates(lstm_params, param_updates)
  return updated_lstm_params, updated_opt_state, step_loss, grads


# set up lstm params
keys = random.split(random.PRNGKey(123), 20)
lstm_params = init_LSTM_params(keys[0], lstm_blocks, input_size, model_size, output_size)
opt_state = optimizer.init(lstm_params)

# train
# for now just overfit on small sample idk lol
dataset_tokens = jnp.array(dataset_tokens)
train_tokens = dataset_tokens[:int(len(dataset_tokens)*0.9)]
test_tokens = dataset_tokens[int(len(dataset_tokens)*0.9):]

epochs = 10
for epoch in range(epochs):
  samples = (len(train_tokens) - 1) // lstm_blocks
  for i in range(0, len(train_tokens)-1, lstm_blocks):
    xtokens = train_tokens[i:i+lstm_blocks]
    ytokens = train_tokens[i+1:i+lstm_blocks+1]

    xtokens_batch = xtokens[None, :]
    ytokens_batch = ytokens[None, :] # artificially create batch of 1

    xembeds_batch = jax.nn.one_hot(xtokens_batch, len(vocab), axis=-1)

    # train example
    logits_batch = lstm_forward(xembeds_batch.shape[0], lstm_params, xembeds_batch)
    prediction_batch = jnp.argmax(logits_batch, axis=-1)

    # val batch
    val_batch_size = 3 # nx3
    j = i % ((len(test_tokens) - 1)//((val_batch_size)*lstm_blocks))
    idx = j*val_batch_size*lstm_blocks
    xtokens_val_batch = test_tokens[idx:idx+lstm_blocks*val_batch_size].reshape(-1, lstm_blocks) # batches of sequences lstm block count size
    ytokens_val_batch = test_tokens[idx+1:idx+lstm_blocks*val_batch_size+1].reshape(-1, lstm_blocks)
    xembeds_val_batch = jax.nn.one_hot(xtokens_val_batch, len(vocab), axis=-1)
    
    logits_val_batch = lstm_forward(xembeds_batch.shape[0], lstm_params, xembeds_batch)
    prediction_val_batch = jnp.argmax(logits_val_batch, axis=-1)
    ys_onehot = jax.nn.one_hot(ytokens_val_batch, len(vocab), axis=-1)
    logprobs = jax.nn.log_softmax(logits_val_batch, axis=-1)
    crossentropies = -jnp.sum(ys_onehot*logprobs,axis=-1)
    val_loss = jnp.mean(crossentropies) #lmao
    val_accuracy = jnp.mean(prediction_val_batch == ytokens_val_batch)


    lstm_params, opt_state, step_loss, grads = train(lstm_params, xembeds_batch, ytokens_batch, opt_state)
    print(epoch * samples + i, f"{step_loss:1.4f}", "pred:", xtokens_batch[0], "=>", ytokens_batch[0], "?=", prediction_batch[0])
    print("VAL::", f"loss: {val_loss:1.4f}", f"accuracy: {val_accuracy:1.4f}" )
    print()

0 4.1259 pred: [18 47] => [47 56] ?= [43 49]
VAL:: loss: 4.1918 accuracy: 0.0000

2 4.0629 pred: [56 57] => [57 58] ?= [43 45]
VAL:: loss: 4.0967 accuracy: 0.0000

4 4.1795 pred: [58  1] => [ 1 15] ?= [47 56]
VAL:: loss: 4.1362 accuracy: 0.1667

6 3.8378 pred: [15 47] => [47 58] ?= [47 56]
VAL:: loss: 4.1007 accuracy: 0.1667

8 4.0790 pred: [58 47] => [47 64] ?= [47 56]
VAL:: loss: 4.1417 accuracy: 0.0000

10 4.1257 pred: [64 43] => [43 52] ?= [47 58]
VAL:: loss: 4.0517 accuracy: 0.0000

12 4.2955 pred: [52 10] => [10  0] ?= [47 58]
VAL:: loss: 4.2543 accuracy: 0.0000

14 4.2744 pred: [ 0 14] => [14 43] ?= [47 58]
VAL:: loss: 4.1668 accuracy: 0.0000

16 4.4379 pred: [43 44] => [44 53] ?= [47 58]
VAL:: loss: 4.2509 accuracy: 0.0000

18 4.1248 pred: [53 56] => [56 43] ?= [47 58]
VAL:: loss: 4.2870 accuracy: 0.0000

20 4.0697 pred: [43  1] => [ 1 61] ?= [47 58]
VAL:: loss: 4.3451 accuracy: 0.0000

22 3.9501 pred: [61 43] => [43  1] ?= [47 58]
VAL:: loss: 4.2064 accuracy: 0.0000

24 3.7723

KeyboardInterrupt: 

In [51]:
def inference(chars):
  xtokens = encode(chars)
  xembed = jax.nn.one_hot(xtokens, len(vocab))[None, :] # artificial single batch
  logits = lstm_forward(xembed.shape[0], lstm_params, xembed)
  yhattokens = jnp.argmax(logits, axis=-1) # along channel axis in B T C. B T C => B T
  sequence = yhattokens[0] # first in 'batch'. B T => T
  return sequence

text = "Once"
print(text, end='')
for i in range(100):
  yseq = inference(text[-lstm_blocks:])
  #print(yseq)
  next_char = decode(yseq)[-1]
  text += next_char
  print(next_char, end='')

Once the the the the the the the the the the the the the the the the the the the the the the the the the

In [23]:
jax.tree_util.tree_map(jnp.linalg.norm, grads)

[{'bC': Array(0.00019623, dtype=float32),
  'bF1': Array(6.8030313e-06, dtype=float32),
  'bF2': Array(0.00020164, dtype=float32),
  'bO': Array(1.0678398e-05, dtype=float32),
  'bU': Array(8.155224e-06, dtype=float32),
  'bY1': Array(6.391086e-05, dtype=float32),
  'bY2': Array(7.6374214e-05, dtype=float32),
  'c0': Array(0.00026105, dtype=float32),
  'h0': Array(0.0003813, dtype=float32),
  'wC': Array(0.00035778, dtype=float32),
  'wF1': Array(1.2403322e-05, dtype=float32),
  'wF2': Array(0.00036764, dtype=float32),
  'wO': Array(1.9468913e-05, dtype=float32),
  'wU': Array(1.48686495e-05, dtype=float32),
  'wY1': Array(0.00026253, dtype=float32),
  'wY2': Array(0.00058146, dtype=float32)},
 {'bC': Array(1.3227721e-05, dtype=float32),
  'bF1': Array(1.1500442e-06, dtype=float32),
  'bF2': Array(1.30392145e-05, dtype=float32),
  'bO': Array(2.1790418e-06, dtype=float32),
  'bU': Array(1.158496e-06, dtype=float32),
  'bY1': Array(1.1653725e-05, dtype=float32),
  'bY2': Array(5.201777e