In [1]:
#imports
import jax
import jax.numpy as jnp
import jax.random as random
import optax
from tokenizers import CharBPETokenizer
import functools
import time


gpu_device = jax.device_get('gpu')[0]
cpu_device = jax.device_get('cpu')[0]
# LSTM
# xs = B, input_size = B, T, C
# h = c = y = B, output_size = B, T, logits_size = B, T, vocab_size


In [2]:
#dataset
with open('data/dnbt_posts.txt', 'r') as file:
  dataset = file.read()

removed_chars = []
frequencies = []
for c in set(dataset):
  frequencies.append((dataset.count(c), c, c.isalnum()))
  if dataset.count(c) < 50:
    removed_chars.append(c)
    dataset = dataset.replace(c, '')


# tokenize
vocab = sorted(list(set(dataset)))
print("vocab length:", len(vocab))

token_to_char = dict(enumerate(vocab))
char_to_token = dict([(v, k) for k, v in token_to_char.items()])
decode = lambda tokens: "".join([token_to_char[int(token)] for token in tokens])
encode = lambda chars: jnp.array([char_to_token[c] for c in chars])

dataset_tokens = encode(dataset)
split_ratio = 0.9
train_tokens = dataset_tokens[:int(len(dataset_tokens)*split_ratio)]
test_tokens = dataset_tokens[int(len(dataset_tokens)*split_ratio):]
del dataset
del dataset_tokens


print("removed:", "".join(removed_chars))
print("dog", encode("dog"), decode(encode("dog")))

vocab length: 84
removed: 🫡ɴ𝘂我𝗪ᴇ😎👀𝗼😤📈ʀᴛ#𝗲{ʟ’🤣👌𝘁ɪ🚀🤦🤷~𝗶ʜ𝘀|𝗿`ᴏ️🎉💪‍😁😭😉$👍們🍰𝗱[吧🌑}*]𝗯😆”𝗰𝗵“🧠🤔😢♂^ᴡ𝗻ᴄᴘᴀ走📉🤯☠
dog [59 70 62] dog


In [11]:
sorted(frequencies)

[(1, 'ɪ', True),
 (1, 'ɴ', True),
 (1, 'ʜ', True),
 (1, 'ᴘ', True),
 (1, 'ᴛ', True),
 (1, '☠', False),
 (1, '們', True),
 (1, '吧', True),
 (1, '我', True),
 (1, '走', True),
 (1, '𝗪', True),
 (1, '𝗯', True),
 (1, '𝗰', True),
 (1, '𝗱', True),
 (1, '𝗶', True),
 (1, '𝗿', True),
 (1, '𝘂', True),
 (1, '🌑', False),
 (1, '🍰', False),
 (1, '🎉', False),
 (1, '👀', False),
 (1, '📈', False),
 (1, '📉', False),
 (1, '😁', False),
 (1, '😉', False),
 (1, '😢', False),
 (1, '😤', False),
 (1, '🚀', False),
 (1, '🤦', False),
 (1, '🤯', False),
 (1, '🤷', False),
 (1, '🧠', False),
 (2, '|', False),
 (2, 'ʀ', True),
 (2, 'ʟ', True),
 (2, 'ᴏ', True),
 (2, '\u200d', False),
 (2, '♂', False),
 (2, '𝗵', True),
 (2, '𝗻', True),
 (2, '𝗼', True),
 (2, '𝘀', True),
 (2, '𝘁', True),
 (2, '😆', False),
 (3, 'ᴀ', True),
 (3, 'ᴄ', True),
 (3, 'ᴡ', True),
 (3, '’', False),
 (3, '“', False),
 (3, '”', False),
 (3, '️', False),
 (3, '𝗲', True),
 (3, '👌', False),
 (3, '👍', False),
 (3, '😭', False),
 (3, '🤣', False),
 (4, '}', False

In [3]:
# lstm network & other functions
def init_LSTM_params(key, lstm_layers, input_size, model_size, output_size):
  param_sets = 8 # manual, idc
  keys = random.split(key, param_sets*lstm_layers + 2)
  hxconcat_size = model_size + model_size
  he = lambda rkey, shape: random.normal(rkey, shape=shape) * jnp.sqrt(2 / shape[0])
  # supposedly xavier is better for networks using tanh
  xavier = lambda rkey, shape: random.normal(rkey, shape=shape) * jnp.sqrt(2 / (shape[0] + shape[1]))
  params = [
    {
      "wU" : xavier(keys[param_sets*i + 0], (hxconcat_size, model_size)),
      "bU" : jnp.zeros((model_size,)),
      "wC" : xavier(keys[param_sets*i + 6], (hxconcat_size, model_size)),
      "bC" : jnp.zeros((model_size,)),
      "wF": xavier(keys[param_sets*i + 1], (hxconcat_size, model_size)),
      "bF": jnp.zeros((model_size,)),
      "wO" : xavier(keys[param_sets*i + 3], (hxconcat_size, model_size)),
      "bO" : jnp.zeros((model_size,)),
      "h0" : jnp.zeros((model_size,)),
      "c0" : jnp.zeros((model_size,)),
      #"h0" : random.normal(keys[param_sets*i + 4], shape=(model_size)) * jnp.sqrt(2 / model_size),
      #"c0" : random.normal(keys[param_sets*i + 5], shape=(model_size)) * jnp.sqrt(2 / model_size),
    }
    for i in range(lstm_layers)
  ]
  params[0].update(
    {
    # then embedding table weight and bias
    "wEM" : xavier(keys[param_sets*(param_sets - 1) + 2], (input_size, model_size)),
    "bEM" : jnp.zeros((model_size,)),

  })
  params[-1].update(
    {
      # this is for the y layer, which i am probably imlementing wrong.
      "wY1" : xavier(keys[param_sets*(lstm_layers-1) + 4], (model_size, model_size)),
      "bY1" : jnp.zeros((model_size,)),
      "wY2" : xavier(keys[param_sets*(lstm_layers-1) + 5], (model_size, output_size)),
      "bY2" : jnp.zeros((output_size,)),
    }
  )
  return params


@functools.partial(jax.jit, static_argnames=[])
def dropout(dropout_key, original_tensor, dropout_rate):
  # generate random of same shape
  dropout_probs = random.uniform(dropout_key, shape=original_tensor.shape)
  # mask = random < dropout_rate
  mask = (dropout_probs > dropout_rate) / (1 - dropout_rate) # scale to keep avg the same
  return original_tensor * mask


@functools.partial(jax.jit, static_argnames=[]) # static dropout rate?
def lstm_step(step_dropout_key, lstm_layer_params, layer_h, layer_c, current_xt, dropout_rate):
  hxconcat = jax.lax.concatenate([layer_h, current_xt], dimension=1) #B, h ++ B, C => B, h+c
  # update gate
  forget_gate = jax.nn.sigmoid(hxconcat @ lstm_layer_params["wF"] + lstm_layer_params["bF"])
  #update = dropout(step_dropout_keys[0], update, dropout_rate)

  # forget
  layer_c = layer_c * forget_gate

  input_node = jax.nn.tanh(hxconcat @ lstm_layer_params["wC"] + lstm_layer_params["bC"])
  #candidate = dropout(step_dropout_keys[1], candidate, dropout_rate)
  update = jax.nn.sigmoid(
              hxconcat @ lstm_layer_params["wU"] + lstm_layer_params["bU"]
            )
  input_gate =  update * input_node

  # update
  layer_c = layer_c + input_gate

  # output
  layer_h = jax.nn.tanh(layer_c) * jax.nn.sigmoid(hxconcat @ lstm_layer_params["wO"] + lstm_layer_params["bO"]) # (B, model_size)

  next_layer_xt = dropout(step_dropout_key, layer_h, dropout_rate) # the next layer's input x is the current layer's hidden state
  # karpathy: dropout after EACH LAYER not several times in the block. lol.

  # i may also need to do dropout horizontally (i.e. dropout the hidden state memory each block)

  return (layer_h, layer_c), next_layer_xt


# LSTM forward
import functools
@functools.partial(jax.jit, static_argnames=[])
def lstm_forward(dropout_key, lstm_params, xembeds_batch, dropout_rate):
  batches = xembeds_batch.shape[0]
  lstm_layers = len(lstm_params)
  model_size = lstm_params[0]["h0"].size
  # initialize h and c as random/learnable params
  #h = jnp.tile(lstm_params[0]["h0"], (batches, lstm_layers, 1)) # B, lstm_layer, h_size
  #c = jnp.tile(lstm_params[0]["c0"], (batches, lstm_layers, 1)) # B, lstm_layer, c_size
  # wait.. these are the same for all of the layers.. maybe they shouldn't be
  T = xembeds_batch.shape[1]
  # take xembeds_batch and pass each xt through the same SINGULAR block. don't update the weight layer. there is only one layer.
  dropout_keys = random.split(dropout_key, lstm_layers)

  # for each layer:
    # scan over xt
    # carry : h, c
    # a: xt
    # b: h,c
    # f = lambda ((h, c), xt) : lstm_step(h, c, xt, everything else) => h, c
    # scans over xt
    # for next layer: xt = h of previous layer. h = h0 and c = c0
  
  current_embeddings_batch = jnp.transpose(xembeds_batch, (1, 0, 2)) # B, T, C => T, B, C
    # The reason for this is that jax.lax.scan only uses the leading dim. why? idk. its dumb, it needs an axis arg so i can scan over whatever

  for lstm_layer in range(lstm_layers):
    h = jnp.tile(lstm_params[lstm_layer]["h0"], (batches, 1))
    c = jnp.tile(lstm_params[lstm_layer]["c0"], (batches, 1))
    # zeroes makes the backprop faster
    #h = jnp.zeros((batches, model_size))
    #c = jnp.zeros((batches, model_size))
    layer_dropout_key = dropout_keys[lstm_layer] # it doesnt matter if this is the same across all layers
    # scan should be inexpensive since layer size is small while t size is usually LARGE
    # scan :: (c -> a -> (c, b)) -> c -> [a] -> (c, [b])
    # scan :: scanfunc -> h_and_c -> xs -> (h_and_c_final, hs_to_be_used_as_input_xt_in_next_layer)
    # scanfunc :: (c -> a -> (c, b))
    scanfunc = lambda hc, xt : lstm_step(layer_dropout_key, lstm_params[lstm_layer], hc[0], hc[1], xt, dropout_rate)
      # for xs: scan along the t dimension! it scans along B by default
      # to fix this, we transpose xs with jnp.transpose(current_embeddings_batch, (1, 0, 2))
    current_embeddings_batch = jax.lax.scan(scanfunc, (h, c), current_embeddings_batch)[1] # (c, [b]) => [b] ==> B, T, C
  

  # finally turn current_embeddings_batch into ys (logits)
  hs = jnp.transpose(current_embeddings_batch, (1, 0, 2)) # T, B, C => B, T, C
  ys = jax.nn.relu(hs @ lstm_params[-1]['wY1'] + lstm_params[-1]["bY1"]) # B, T, model_size => B, T, vocab_size
  ys = ys @ lstm_params[-1]['wY2'] + lstm_params[-1]["bY2"]
  return ys


@functools.partial(jax.jit, static_argnames=[])
def loss_func(dropout_key, lstm_params, xtokens_batch, ytokens_batch, dropout_rate):
  xembeds_batch = embed(lstm_params, xtokens_batch)
  logits = lstm_forward(dropout_key, lstm_params, xembeds_batch, dropout_rate)
  vocab_size = logits.shape[-1]
  ys_one_hot = jax.nn.one_hot(ytokens_batch, vocab_size, axis=-1)
  logprobs = jax.nn.log_softmax(logits, axis=-1)
  crossentropylosses = -jnp.sum(ys_one_hot * logprobs, axis=-1)
  crossentropyloss = jnp.mean(crossentropylosses)
  return crossentropyloss


@functools.partial(jax.jit, static_argnames=[])
def loss_and_value(dropout_key, lstm_params, xtokens_batch, ytokens_batch, dropout_rate):
  xembeds_batch = embed(lstm_params, xtokens_batch)
  logits = lstm_forward(dropout_key, lstm_params, xembeds_batch, dropout_rate)
  vocab_size = logits.shape[-1]
  ys_one_hot = jax.nn.one_hot(ytokens_batch, vocab_size, axis=-1)
  logprobs = jax.nn.log_softmax(logits, axis=-1)
  predictions = jnp.argmax(logprobs, axis=-1)
  crossentropylosses = -jnp.sum(ys_one_hot * logprobs, axis=-1)
  crossentropyloss = jnp.mean(crossentropylosses)
  return crossentropyloss, predictions


jitted_backwards_loss = jax.jit(jax.value_and_grad(loss_func, argnums=1), static_argnames=[])


@functools.partial(jax.jit, static_argnames=['vocab_size'])
def embed(lstm_params, xtokens, vocab_size=len(vocab)):
  xs_one_hot = jax.nn.one_hot(xtokens, vocab_size, axis=-1) #B, T, vocab_size
  activations = xs_one_hot @ lstm_params[0]["wEM"] + lstm_params[0]["bEM"]
  return activations

# make optimizer a static arg in jit or it breaks
@functools.partial(jax.jit, static_argnames=["optimizer"])
def train(dropout_key, lstm_params, xtokens_batch, ytokens_batch, opt_state, dropout_rate, optimizer):
  step_loss, grads = jitted_backwards_loss(dropout_key, lstm_params, xtokens_batch, ytokens_batch, dropout_rate)
  param_updates, updated_opt_state = optimizer.update(grads, opt_state, lstm_params)
  updated_lstm_params = optax.apply_updates(lstm_params, param_updates) 
  return updated_lstm_params, updated_opt_state, step_loss, grads




In [4]:
# laddering params
lstm_layers = 3
model_size = 512# 512

# laddering
# memorize, generalize, move up in complexity. memorize, generalize, move up in complexity.
# until desired complexity is reached

# strategy: aim for extremely low error*epochs
SEQ_LEN      = [2,    4,    10,   25,   50,   50,   50,   50,   ]
LR           = [2e-2, 2e-2, 2e-2, 2e-2, 2e-2, 1e-2, 5e-3, 2e-3,]
DROPOUT_RATE = [0.00, 0.00, 0.00, 0.1,  0.20, 0.20, 0.20, 0.20]
EPOCHS       = [1,    1,    1,    1,    1,    1,    1,    1,  ]
rungs = list(zip(SEQ_LEN, LR, DROPOUT_RATE, EPOCHS)) # causes problems if generator

# test epoch params
# use this to find the optimal LR for laddering steps. could be automated but whatever
target_rung = 2
sequence_length, lr, dropout_rate, test_epochs = rungs[target_rung]
resume_checkpoint = False # saves the checkpoint for rung-1, and reruns rung $rung. this was you dont have to re-climb every rung

print_every = 1000
train_batch_size = 100
val_batch_size = 200

# loss: 2.2663 || val_loss: 1.6317 val_acc: 0.5529 || 5 epochs total

In [5]:
# run test epoch

# setup vars
input_size = len(vocab) # just do one-hot for now
hidden_size = model_size
output_size = len(vocab) # logits => one-hot => tokens
keys = random.split(random.PRNGKey(123), 20)
losses = []
start = time.time()

# train
# train rungs
for r, rung in enumerate(rungs):
  if resume_checkpoint and r < target_rung:
    # skip rungs until the target rung
    continue

  if r > target_rung:
    print(f"ended after rung {r-1}")
    break
  
  if resume_checkpoint and r == target_rung:
    lstm_params = lstm_params_checkpoint

  sequence_length, lr, dropout_rate, epochs = rung
  # initialize if first rung
  if r == 0:
    lstm_params = init_LSTM_params(keys[0], lstm_layers, input_size, model_size, output_size)
    optimizer = optax.inject_hyperparams(optax.adam)(learning_rate=lr)
    opt_state = optimizer.init(lstm_params)
  else:
    opt_state.hyperparams['learning_rate'] = lr

  # train
  for epoch in range(epochs):
      steps = (len(train_tokens) // ((sequence_length+1)*train_batch_size)) - 2
      for step in range(steps): # probably wrong but w/e
        # train
        # B, T where T = sequence_length
        train_data_idx = step*sequence_length*train_batch_size
        next_train_data_idx = (step+1)*sequence_length*train_batch_size
        xtokens_batch = train_tokens[train_data_idx:next_train_data_idx].reshape(-1, sequence_length) #(B, T)
        ytokens_batch = train_tokens[train_data_idx+1:next_train_data_idx+1].reshape(-1, sequence_length) # (B,)

        dropout_key = random.PRNGKey(epoch*steps + step) # unique for every step
        lstm_params, opt_state, step_loss, grads = train(dropout_key, lstm_params, xtokens_batch, ytokens_batch, opt_state, dropout_rate, optimizer)

        losses.append(step_loss)

        if ((epoch*step + step) % print_every == 0) or (epoch + steps == 0):
          end = time.time()
          duration = end - start
          # train inference example (no dropout)
          xembeds_batch = embed(lstm_params, xtokens_batch[0][None, :]) # 1-batch - (1, T, C)
          last_logit_batch = lstm_forward(dropout_key, lstm_params, xembeds_batch, 0) # B, C
          prediction_batch = jnp.argmax(last_logit_batch, axis=-1) # B

          # val batch
          j = step % ((len(test_tokens) - 1)//((val_batch_size)*sequence_length))
          val_idx = j*val_batch_size*sequence_length
          next_val_idx = (j+1)*val_batch_size*sequence_length
          xtokens_val_batch = test_tokens[val_idx:next_val_idx].reshape(-1, sequence_length) # batches of sequences lstm block count size
          ytokens_val_batch = test_tokens[val_idx+1:next_val_idx+1].reshape(-1, sequence_length)
          
          val_loss, prediction_val_batch = loss_and_value(dropout_key, lstm_params, xtokens_val_batch, ytokens_val_batch, dropout_rate=0)
          val_accuracy = jnp.mean(prediction_val_batch == ytokens_val_batch)

          # print train status
          x = decode(xtokens_batch[0]).replace('\n', ' ')
          y = decode(ytokens_batch[0]).replace('\n', ' ')
          yhat = decode(prediction_batch[0]).replace('\n', ' ')
          #print(f'INPUT  ({len(x)}) | "{x}"')
          lines = [
            f'TARGET ({len(y)}) | "{y}"',
            f'PRED   ({len(yhat)}) | "{yhat}"',
            f"r,e,s | {r}/{len(rungs)}, {epoch}/{epochs}, {step}/{steps} || samples/sec: {train_batch_size*print_every/(duration):0.0f} || "
            f"loss: {sum(losses)/len(losses):1.4f} || val_loss: {val_loss:1.4f} val_acc: {val_accuracy:1.4f} || " 
            f"LR = {opt_state.hyperparams['learning_rate']:0.6f}",
          ]
          print("\n".join(lines))
          start = time.time()
  if r < target_rung:
    # stop saving checkpoint after training the target rung
    lstm_params_checkpoint = lstm_params


TARGET (2) | "ep"
PRED   (2) | "  "
r,e,s | 0/8, 0/1, 0/1234 || samples/sec: 17840 || loss: 5.0435 || val_loss: 4.9277 val_acc: 0.1700 || LR = 0.020000
TARGET (2) | "um"
PRED   (2) | "  "
r,e,s | 0/8, 0/1, 1000/1234 || samples/sec: 14780 || loss: 2.7572 || val_loss: 2.9164 val_acc: 0.2325 || LR = 0.020000
TARGET (4) | "eply"
PRED   (4) | "epue"
r,e,s | 1/8, 0/1, 0/739 || samples/sec: 19693 || loss: 2.7359 || val_loss: 2.6774 val_acc: 0.2637 || LR = 0.020000
TARGET (10) | "eply: @tri"
PRED   (10) | "eply: @aue"
r,e,s | 2/8, 0/1, 0/335 || samples/sec: 7177 || loss: 2.6325 || val_loss: 2.3668 val_acc: 0.3500 || LR = 0.020000
ended after rung 2


In [6]:
# train (laddering)


# init some parameters
input_size = len(vocab) # just do one-hot for now
hidden_size = model_size
output_size = len(vocab) # logits => one-hot => tokens
keys = random.split(random.PRNGKey(123), 20)
train_batch_size = 100
val_batch_size = 200
print_every = 100
j = 0
losses = []
start = time.time()


# train rungs
for r, rung in enumerate(rungs):
  print(f"new rung {r}")
  sequence_length, lr, dropout_rate, epochs = rung
  # initialize if first rung
  if r == 0:
    lstm_params = init_LSTM_params(keys[0], lstm_layers, input_size, model_size, output_size)
    optimizer = optax.inject_hyperparams(optax.adam)(learning_rate=lr)
    opt_state = optimizer.init(lstm_params)
  else:
    opt_state.hyperparams['learning_rate'] = lr

  # train
  for epoch in range(epochs):
      steps = (len(train_tokens) // ((sequence_length+1)*train_batch_size)) - 2
      for step in range(steps): # probably wrong but w/e
        # train
        # B, T where T = sequence_length
        train_data_idx = step*sequence_length*train_batch_size
        next_train_data_idx = (step+1)*sequence_length*train_batch_size
        xtokens_batch = train_tokens[train_data_idx:next_train_data_idx].reshape(-1, sequence_length) #(B, T)
        ytokens_batch = train_tokens[train_data_idx+1:next_train_data_idx+1].reshape(-1, sequence_length) # (B,)

        dropout_key = random.PRNGKey(epoch*steps + step) # unique for every step
        lstm_params, opt_state, step_loss, grads = train(dropout_key, lstm_params, xtokens_batch, ytokens_batch, opt_state, dropout_rate, optimizer)

        losses.append(step_loss)

        if ((epoch*step + step) % print_every == 0) or (epoch + steps == 0):
          end = time.time()
          duration = end - start
          # train inference example (no dropout)
          xembeds_batch = embed(lstm_params, xtokens_batch[0][None, :]) # 1-batch - (1, T, C)
          last_logit_batch = lstm_forward(dropout_key, lstm_params, xembeds_batch, 0) # B, C
          prediction_batch = jnp.argmax(last_logit_batch, axis=-1) # B

          # val batch
          j = step % ((len(test_tokens) - 1)//((val_batch_size)*sequence_length))
          val_idx = j*val_batch_size*sequence_length
          next_val_idx = (j+1)*val_batch_size*sequence_length
          xtokens_val_batch = test_tokens[val_idx:next_val_idx].reshape(-1, sequence_length) # batches of sequences lstm block count size
          ytokens_val_batch = test_tokens[val_idx+1:next_val_idx+1].reshape(-1, sequence_length)
          
          val_loss, prediction_val_batch = loss_and_value(dropout_key, lstm_params, xtokens_val_batch, ytokens_val_batch, dropout_rate=0)
          val_accuracy = jnp.mean(prediction_val_batch == ytokens_val_batch)

          # print train status
          x = decode(xtokens_batch[0]).replace('\n', ' ')
          y = decode(ytokens_batch[0]).replace('\n', ' ')
          yhat = decode(prediction_batch[0]).replace('\n', ' ')
          #print(f'INPUT  ({len(x)}) | "{x}"')
          lines = [
            f'TARGET ({len(y)}) | "{y}"',
            f'PRED   ({len(yhat)}) | "{yhat}"',
            f"r,e,s | {r}/{len(rungs)}, {epoch}/{epochs}, {step}/{steps} || samples/sec: {train_batch_size*print_every/(duration):0.0f} || "
            f"loss: {sum(losses)/len(losses):1.4f} || val_loss: {val_loss:1.4f} val_acc: {val_accuracy:1.4f} || " 
            f"LR = {opt_state.hyperparams['learning_rate']:0.6f}",
          ]
          print("\n".join(lines))
          start = time.time()

new rung 0
TARGET (2) | "ep"
PRED   (2) | "  "
r,e,s | 0/8, 0/1, 0/1234 || samples/sec: 3394 || loss: 5.0435 || val_loss: 4.9277 val_acc: 0.1700 || LR = 0.020000
TARGET (2) | "ep"
PRED   (2) | "  "
r,e,s | 0/8, 0/1, 100/1234 || samples/sec: 16252 || loss: 3.2576 || val_loss: 3.4519 val_acc: 0.2000 || LR = 0.020000
TARGET (2) | "sc"
PRED   (2) | "  "
r,e,s | 0/8, 0/1, 200/1234 || samples/sec: 16144 || loss: 3.0758 || val_loss: 2.6926 val_acc: 0.2325 || LR = 0.020000
TARGET (2) | "nt"
PRED   (2) | "  "
r,e,s | 0/8, 0/1, 300/1234 || samples/sec: 15518 || loss: 2.9792 || val_loss: 2.6449 val_acc: 0.2400 || LR = 0.020000
TARGET (2) | "ik"
PRED   (2) | "yn"
r,e,s | 0/8, 0/1, 400/1234 || samples/sec: 18901 || loss: 2.9016 || val_loss: 2.7475 val_acc: 0.2875 || LR = 0.020000
TARGET (2) | "  "
PRED   (2) | "  "
r,e,s | 0/8, 0/1, 500/1234 || samples/sec: 16580 || loss: 2.8631 || val_loss: 2.6574 val_acc: 0.2925 || LR = 0.020000
TARGET (2) | "ce"
PRED   (2) | "io"
r,e,s | 0/8, 0/1, 600/1234 || sa

In [4]:
# train engine parameters

# this is the function that takes the current hyperparameters
# and makes candidate ones to test.
#@jax.jit
from itertools import product
from jax import random as jrand
def make_candidates(key, current_sequence_length, current_learning_rate, current_dropout_rate):
  # get lr candidates
  upscale = jnp.array([1, 2, 10, 100], dtype=jnp.float32)
  downscale = 1.0 / upscale # 8x, 4x, 2x, 1x, 0.5x, 0.25x, etc
  scale = jnp.concatenate([upscale, downscale])
  lr_candidates = current_learning_rate * scale

  # get seq length candidates (this does have an effect)
  #sequence_length_candidates = jnp.array(list(set([current_sequence_length, 2, 4, 8, 15, 25, 50, 100])))
  sequence_length_candidates = jnp.array([current_sequence_length]) # dont change this

  # future: dropout
  dropout_candidates = jnp.array(list(set([0, 0.01, 0.05, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8])))

  # get all possible combinations as a generator
  # randomly shuffle by preshuffling the inputs
  keys = jrand.split(key, 10)
  candidates = product(
    jrand.permutation(keys[0], sequence_length_candidates, independent=True),
    jrand.permutation(keys[1], lr_candidates, independent=True),
    jrand.permutation(keys[2], dropout_candidates, independent=True),
  )

  return candidates


## PARAMETERS ##

# every n epochs, do another hyperparameter search
# lookahead k steps or epochs.
# pick the best set of hyperparameters and do the next n epochs with them. repeat
# in the future, update whenever final_loss < 0.95*start_loss
retune_min_epochs = 20 # how many epochs minimum to train for before rechecking hyperparameters. set super high for normal h tuning
lookahead_steps = 30 # steps to train to test candidate hyperparameters
candidate_eval_func = lambda final_loss: -final_loss # in this case, candidates are evaluated higher if their final loss is low

initial_lr = 3e-4


# model params
lstm_layers = 2
model_size = 512
initial_sequence_length = 100
initial_dropout_rate = 0.4 # for now

candidate_limit = 100


# general testing params
epochs = 10000
print_every = 100
train_batch_size = 100
val_batch_size = 200

test_key = jrand.PRNGKey(1203)
print(len(list(make_candidates(test_key, initial_lr, initial_sequence_length, 0.1))))


96


In [5]:
# train w retuning

# init some parameters
input_size = len(vocab) # just do one-hot for now
hidden_size = model_size
output_size = len(vocab) # logits => one-hot => tokens
keys = random.split(random.PRNGKey(123), 20)
train_batch_size = 100
val_batch_size = train_batch_size
print_every = 100_000
j = 0
losses = []
val_losses = []
val_accuracies = []
start = time.time()


# init state
lstm_params = init_LSTM_params(keys[0], lstm_layers, input_size, model_size, output_size)
optimizer = optax.inject_hyperparams(optax.adam)(learning_rate=initial_lr)
opt_state = optimizer.init(lstm_params)
sequence_length = initial_sequence_length
dropout_rate = initial_dropout_rate

previous_hyperparameters = (sequence_length, initial_lr, dropout_rate)


retune = True
retune_msg = "initial hyperparameter tuning\nthis will be slow as functions are jitted"


decay = 0.98
decay_epochs = 2

# train
for epoch in range(epochs):
    if epoch % decay_epochs == 0:
       opt_state.hyperparams['learning_rate'] = opt_state.hyperparams['learning_rate'] * decay
    # retune hyperparameters
    if retune: # includes the first epoch
      epochs_since_retune = 0
      print("retuning:", retune_msg)
      # just do lr for now
      current_lr = opt_state.hyperparams['learning_rate']
      current_sequence_length = sequence_length
      current_dropout_rate = dropout_rate
      candidate_key = jrand.PRNGKey(int(time.time()))
      candidates = make_candidates(candidate_key, current_sequence_length, current_lr, current_dropout_rate) # candidate 'moves'
      best_candidate = (None, -100000) # candidate, score
      i = 0
      for candidate in [previous_hyperparameters] + list(candidates):
        if i == candidate_limit:
           break
        if i % 10 == 0: print(i)
        i += 1
        # make copies of the current params and opt state, and train with them for a few steps
        candidate_params = lstm_params
        candidate_opt_state = opt_state
        # update copies with candidate hyperparams:
        sequence_length = candidate[0]
        candidate_opt_state.hyperparams['learning_rate'] = candidate[1]
        dropout_rate = candidate[2]
        # for now, just eval future positions on losses[-1].
        # future evals can be whatever. average accuracy over a val set is a good one.
        candidate_val_losses = []
        for step in range(lookahead_steps):
          # train each candidate hyperparam set on the exact same data
          train_data_idx = step*sequence_length*train_batch_size
          next_train_data_idx = (step+1)*sequence_length*train_batch_size
          if next_train_data_idx > len(train_tokens):
              break
          xtokens_batch = train_tokens[train_data_idx:next_train_data_idx].reshape(-1, sequence_length) #(B, T)
          ytokens_batch = train_tokens[train_data_idx+1:next_train_data_idx+1].reshape(-1, sequence_length) # (B,)
          dropout_key = random.PRNGKey(epoch*lookahead_steps + step) # unique for every step
          # train
          candidate_params, candidate_opt_state, step_loss, _ = train(
            dropout_key, candidate_params, xtokens_batch, ytokens_batch, candidate_opt_state, dropout_rate, optimizer
          )

          # get val accuracy after training
          j = step % ((len(test_tokens) - 1)//((val_batch_size)*sequence_length))
          val_idx = j*val_batch_size*sequence_length
          next_val_idx = (j+1)*val_batch_size*sequence_length
          xtokens_val_batch = test_tokens[val_idx:next_val_idx].reshape(-1, sequence_length) # batches of sequences lstm block count size
          ytokens_val_batch = test_tokens[val_idx+1:next_val_idx+1].reshape(-1, sequence_length)
          
          val_loss, prediction_val_batch = loss_and_value(dropout_key, candidate_params, xtokens_val_batch, ytokens_val_batch, dropout_rate=0)
          val_accuracy = jnp.mean(prediction_val_batch == ytokens_val_batch)

          candidate_val_losses.append(val_loss)

        candidate_result = -sum(candidate_val_losses)/len(candidate_val_losses)#val_loss # just do train loss for now. the best is probably accuracy though.
        candidate_score = candidate_result#candidate_eval_func(candidate_result)
        if candidate_score > best_candidate[1]: # replace the current winner if this one scores better
          best_candidate = (candidate, candidate_score)
        
        candidate_val_losses = []

      # finally, update to the winner
      print(f"old: todo")
      print(f"new: {best_candidate[0]} => {best_candidate[1]}")

      ## update hyperparams:
      new_hyperparameters = best_candidate[0]
      previous_hyperparameters = new_hyperparameters
      new_sequence_length, new_lr, new_dropout_rate = new_hyperparameters
      opt_state.hyperparams['learning_rate'] = new_lr
      sequence_length = int(new_sequence_length)
      dropout_rate = new_dropout_rate

      retune = False


    # do regular training with the current hyperparameters for $sprint_distance epochs
    # retesting hyperparameters every $sprint_distance is handled by the if block above
    steps = (len(train_tokens) // ((sequence_length+1)*train_batch_size)) - 2
    for step in range(steps): # probably wrong but w/e
      # train
      # B, T where T = sequence_length
      train_data_idx = step*sequence_length*train_batch_size
      next_train_data_idx = (step+1)*sequence_length*train_batch_size
      xtokens_batch = train_tokens[train_data_idx:next_train_data_idx].reshape(-1, sequence_length) #(B, T)
      ytokens_batch = train_tokens[train_data_idx+1:next_train_data_idx+1].reshape(-1, sequence_length) # (B,)

      dropout_key = random.PRNGKey(epoch*steps + step) # unique for every step
      lstm_params, opt_state, step_loss, grads = train(dropout_key, lstm_params, xtokens_batch, ytokens_batch, opt_state, dropout_rate, optimizer)

      losses.append(step_loss)

      # val
      j = step % ((len(test_tokens) - 1)//((val_batch_size)*sequence_length))
      val_idx = j*val_batch_size*sequence_length
      next_val_idx = (j+1)*val_batch_size*sequence_length
      xtokens_val_batch = test_tokens[val_idx:next_val_idx].reshape(-1, sequence_length) # batches of sequences lstm block count size
      ytokens_val_batch = test_tokens[val_idx+1:next_val_idx+1].reshape(-1, sequence_length)
      
      val_loss, prediction_val_batch = loss_and_value(dropout_key, lstm_params, xtokens_val_batch, ytokens_val_batch, dropout_rate=0)
      val_accuracy = jnp.mean(prediction_val_batch == ytokens_val_batch)

      val_losses.append(val_loss)
      val_accuracies.append(val_accuracy)

      if (step == steps - 1):
        end = time.time()
        duration = end - start
        # train inference example (no dropout)
        xembeds_batch = embed(lstm_params, xtokens_batch[0][None, :]) # 1-batch - (1, T, C)
        last_logit_batch = lstm_forward(dropout_key, lstm_params, xembeds_batch, 0) # B, C
        prediction_batch = jnp.argmax(last_logit_batch, axis=-1) # B

        # print train status
        x = decode(xtokens_batch[0]).replace('\n', ' ')
        y = decode(ytokens_batch[0]).replace('\n', ' ')
        yhat = decode(prediction_batch[0]).replace('\n', ' ')
        #print(f'INPUT  ({len(x)}) | "{x}"')
        avg_loss = sum(losses)/len(losses)
        avg_val_loss = sum(val_losses)/len(val_losses)
        avg_val_acc = sum(val_accuracies)/len(val_accuracies)
        lines = [
          f'TARGET | "{y}"',
          f'PRED   | "{yhat}"',
          f"e:{epoch}/{epochs} s:{step}/{steps} || samples/sec: {train_batch_size*print_every/(duration):0.0f} || "
          f"loss: {step_loss:1.4f} || val_loss: {avg_val_loss:1.4f} val_acc: {avg_val_acc:1.4f} || " 
          f"LR = {opt_state.hyperparams['learning_rate']:0.6f}",
        ]
        print("\n".join(lines))
        start = time.time()
    
    epochs_since_retune += 1
    # if the val error hasn't decreased to 90%, try to retune hyperparameters
    target_decrease = 0.9997
    if epoch > 0 and epochs_since_retune > retune_min_epochs and avg_val_loss > previous_epoch_val_loss*target_decrease:
       retune = True
       retune_msg = f"\nval_error: {avg_val_loss:0.4f} !< {target_decrease:0.4f}*{previous_epoch_val_loss:0.4f}"
    previous_epoch_val_loss = avg_val_loss

    losses = []
    val_losses = []

retuning: initial hyperparameter tuning
this will be slow as functions are jitted
0
10
20
30
40
50
60
70
80
90
old: todo
new: (Array(100, dtype=int32), Array(0.00294, dtype=float32), Array(0.8, dtype=float32)) => -3.318880081176758
TARGET | "d up with a bunch of half done git repos  not a fan not a fan🛑      reply: @ludwigABAP Another examp"
PRED   | "e toetot  teoeeee te teee toe  tot toete   ee teaeeetee tooeee      reply: @ooo  n     @oee    to t "
e:0/10000 s:33/34 || samples/sec: 14312 || loss: 2.8495 || val_loss: 3.2569 val_acc: 0.2062 || LR = 0.002940
TARGET | "d up with a bunch of half done git repos  not a fan not a fan🛑      reply: @ludwigABAP Another examp"
PRED   | "  tt ton  tntet eean tene ao   ton aersl   et tnton te  antor       reply: @aoteen   B @n   e  tt ne"
e:1/10000 s:33/34 || samples/sec: 1316590 || loss: 2.5645 || val_loss: 2.5969 val_acc: 0.2580 || LR = 0.002940
TARGET | "d up with a bunch of half done git repos  not a fan not a fan🛑      reply: @ludwigABAP Ano

KeyboardInterrupt: 

In [35]:
# normal train parameters

## PARAMETERS ##
lr = 0.0007
sequence_length = 100
dropout_rate = 0.25


decay_lr = False
decay = 0.98
decay_epochs = 3


# model params
lstm_layers = 2
model_size = 1024

resume_train_state = False

# general testing params
epochs = 10000
print_every = 100_000
train_batch_size = 46
val_batch_size = 50


In [34]:
# normal training
# train normally

# init some parameters
input_size = len(vocab) # just do one-hot for now
hidden_size = model_size
output_size = len(vocab) # logits => one-hot => tokens
keys = random.split(random.PRNGKey(123), 20)
losses = []
val_losses = []
val_accuracies = []
start = time.time()


# init state
if not resume_train_state:
  optimizer = optax.inject_hyperparams(optax.adam)(learning_rate=lr)
  lstm_params = init_LSTM_params(keys[0], lstm_layers, input_size, model_size, output_size)
  opt_state = optimizer.init(lstm_params)
else:
  opt_state.hyperparams['learning_rate'] = lr


# train
for epoch in range(epochs):
    if decay_lr and epoch != 0 and epoch % decay_epochs == 0:
       opt_state.hyperparams['learning_rate'] = opt_state.hyperparams['learning_rate'] * decay

    # train
    steps = (len(train_tokens) // ((sequence_length+1)*train_batch_size)) - 2
    for step in range(steps): # probably wrong but w/e
      # B, T where T = sequence_length
      train_data_idx = step*sequence_length*train_batch_size
      next_train_data_idx = (step+1)*sequence_length*train_batch_size
      xtokens_batch = train_tokens[train_data_idx:next_train_data_idx].reshape(-1, sequence_length) #(B, T)
      ytokens_batch = train_tokens[train_data_idx+1:next_train_data_idx+1].reshape(-1, sequence_length) # (B,)

      dropout_key = random.PRNGKey(epoch*steps + step) # unique for every step
      lstm_params, opt_state, step_loss, grads = train(dropout_key, lstm_params, xtokens_batch, ytokens_batch, opt_state, dropout_rate, optimizer)

      losses.append(step_loss)

      # val
      j = step % ((len(test_tokens) - 1)//((val_batch_size)*sequence_length))
      val_idx = j*val_batch_size*sequence_length
      next_val_idx = (j+1)*val_batch_size*sequence_length
      xtokens_val_batch = test_tokens[val_idx:next_val_idx].reshape(-1, sequence_length) # batches of sequences lstm block count size
      ytokens_val_batch = test_tokens[val_idx+1:next_val_idx+1].reshape(-1, sequence_length)
      
      val_loss, prediction_val_batch = loss_and_value(dropout_key, lstm_params, xtokens_val_batch, ytokens_val_batch, dropout_rate=0)
      val_accuracy = jnp.mean(prediction_val_batch == ytokens_val_batch)

      val_losses.append(val_loss)
      val_accuracies.append(val_accuracy)

      if (step == steps - 1):
        end = time.time()
        duration = end - start
        # train inference example (no dropout)
        xembeds_batch = embed(lstm_params, xtokens_batch[0][None, :]) # 1-batch - (1, T, C)
        last_logit_batch = lstm_forward(dropout_key, lstm_params, xembeds_batch, 0) # B, C
        prediction_batch = jnp.argmax(last_logit_batch, axis=-1) # B

        # print train status
        x = decode(xtokens_batch[0]).replace('\n', ' ')
        y = decode(ytokens_batch[0]).replace('\n', ' ')
        yhat = decode(prediction_batch[0]).replace('\n', ' ')
        #print(f'INPUT  ({len(x)}) | "{x}"')
        avg_loss = sum(losses)/len(losses)
        avg_val_loss = sum(val_losses)/len(val_losses)
        avg_val_acc = sum(val_accuracies)/len(val_accuracies)
        lines = [
          f'TARGET | "{y}"',
          f'PRED   | "{yhat}"',
          f"e:{epoch+1}/{epochs} s:{step+1}/{steps} || samples/sec: {train_batch_size*steps/(duration):0.0f} || "
          f"loss: {step_loss:1.4f} || val_loss: {avg_val_loss:1.4f} val_acc: {avg_val_acc:1.4f} || " 
          f"LR = {opt_state.hyperparams['learning_rate']:0.6f}",
        ]
        print("\n".join(lines))
        start = time.time()

    losses = []
    val_losses = []

TARGET | "relaxing🛑      reply: @angkul07 ive noticed this too, its what got me thinking🛑      reply: @kuberde"
PRED   | "te              epl:::                                                                epl:::        "
e:1/10000 s:100/100 || samples/sec: 541 || loss: 3.0564 || val_loss: 3.4078 val_acc: 0.1562 || LR = 0.000700
TARGET | "relaxing🛑      reply: @angkul07 ive noticed this too, its what got me thinking🛑      reply: @kuberde"
PRED   | "to en n        reply: @anleone  tne to  n   toen to   tn  toen to  to toen  n        reply: @aonlnee"
e:2/10000 s:100/100 || samples/sec: 1413 || loss: 2.6120 || val_loss: 2.7691 val_acc: 0.2176 || LR = 0.000700
TARGET | "relaxing🛑      reply: @angkul07 ive noticed this too, its what got me thinking🛑      reply: @kuberde"
PRED   | "teaen ng       reply: @andaene  tne to  ne  then th   tn  toen to  to theng ng       reply: @aonlnee"
e:3/10000 s:100/100 || samples/sec: 1366 || loss: 2.4226 || val_loss: 2.4638 val_acc: 0.2563 || LR = 0.000700

KeyboardInterrupt: 

In [64]:
# inference settings

temperature = 1.5   # from 0 to 2. 1 is normal.

reply_prompt = "reply: "
post_prompt = "post: "

prompt = post_prompt

In [68]:
# run the model!

def inference(key, chars, temperature):
  xtokens = encode(chars)[None, :]
  xembed = embed(lstm_params, xtokens) # artificial single batch
  logits = lstm_forward(key, lstm_params, xembed, 0)[0][-1] # logits of the first B and last T in the B T C. should be (C,)
  probs = jax.nn.softmax(logits/(temperature + 0.001))
  yhattokens = random.choice(key, a=logits.shape[0], p=probs) # no need for axis=-1 since logits are (C,)
  return yhattokens


steps = 1000
import time
seed = int(1000*time.time())
keys = random.split(random.PRNGKey(seed), steps)
text =  "\n"*50 + 'reply: '
print(text.replace('\n\n', ''), end='')
for i in range(steps):
  next_token = inference(keys[i], text[-sequence_length:], temperature)
  next_char = decode([next_token])[-1]
  if next_char == '🛑':
    print(next_char, end='')
    break
  text += next_char
  line_length = 50
  if (len(text) - 50) % line_length == 0:
    print()
  print(next_char, end='')

reply: bc VA E, = kAHDBl
IMa-
You'l lodp theys no
 mil,
Apjoa
hp 1x, n1L)
I can light 16fr  5'06 (ca
i builder
likely bro
289169x23% upsam much-crathik
ingy

you p4yl Set20but?🛑

In [205]:
jax.tree_util.tree_map(jnp.linalg.norm, grads)

[{'bC': Array(7.5398985e-18, dtype=float32),
  'bEM': Array(1.2155308e-16, dtype=float32),
  'bF': Array(5.38526e-19, dtype=float32),
  'bO': Array(2.959863e-19, dtype=float32),
  'bU': Array(7.320699e-19, dtype=float32),
  'bY1': Array(4.5822423e-18, dtype=float32),
  'bY2': Array(0.01353781, dtype=float32),
  'c0': Array(9.196565e-18, dtype=float32),
  'h0': Array(0., dtype=float32),
  'wC': Array(4.9677064e-16, dtype=float32),
  'wEM': Array(1.1980155e-16, dtype=float32),
  'wF': Array(2.0176042e-17, dtype=float32),
  'wO': Array(1.3135809e-17, dtype=float32),
  'wU': Array(2.7428173e-17, dtype=float32),
  'wY1': Array(9.819348e-17, dtype=float32),
  'wY2': Array(1.5380868e-16, dtype=float32)}]

In [None]:
getsize = lambda s: s.size
sizes = jax.tree_util.tree_map(getsize, grads)
total_params = 0
for layer in sizes:
  for _, v in layer.items():
    total_params += v

print(f"TOTAL_PARAMS: {total_params}")
print(f"DTYPE: {grads[0]['bC'].dtype}")
print(f"TOTAL_MEGABYTES: {total_params*4/1_000_000}")

In [48]:
import jax.profiler
jax.profiler.save_device_memory_profile('test.prof')

In [None]:
data = jnp.arange(1000)
seqlen = 10
bs = 4
steps = len(data) // (bs*seqlen)
idx = 24
data_idx = idx*seqlen*bs
next_data_idx = (idx+1)*seqlen*bs
print(
      f"steps: {steps}\n",
      data[data_idx:next_data_idx].reshape(-1, seqlen),
      '\n\n',
      data[data_idx+seqlen:next_data_idx+1:seqlen].reshape(-1, 1),
)