In [21]:
#imports
import jax
import jax.numpy as jnp
import jax.random as random
import optax
from tokenizers import CharBPETokenizer
import functools
import time


gpu_device = jax.device_get('gpu')[0]
cpu_device = jax.device_get('cpu')[0]
# LSTM
# xs = B, input_size = B, T, C
# h = c = y = B, output_size = B, T, logits_size = B, T, vocab_size


In [22]:
#dataset
with open('data/dnbt_posts.txt', 'r') as file:
  dataset = file.read()

# tokenize
vocab = sorted(list(set(dataset)))
print("vocab length:", len(vocab))

token_to_char = dict(enumerate(vocab))
char_to_token = dict([(v, k) for k, v in token_to_char.items()])
decode = lambda tokens: "".join([token_to_char[int(token)] for token in tokens])
encode = lambda chars: jnp.array([char_to_token[c] for c in chars])

print("dog", encode("dog"), decode(encode("dog")))

dataset_tokens = encode(dataset)
split_ratio = 0.8
train_tokens = dataset_tokens[:int(len(dataset_tokens)*split_ratio)]
test_tokens = dataset_tokens[int(len(dataset_tokens)*split_ratio):]
del dataset
del dataset_tokens

vocab length: 155
dog [66 77 69] dog


In [23]:
lstm_layers = 4
sequence_length = 20
model_size = 512

input_size = len(vocab) # just do one-hot for now
hidden_size = model_size
output_size = len(vocab) # logits => one-hot => tokens


# init LSTM params
def init_LSTM_params(key, lstm_layers, input_size, model_size, output_size):
  param_sets = 8 # manual, idc
  keys = random.split(key, param_sets*lstm_layers + 2)
  hxconcat_size = model_size + model_size
  he = lambda rkey, shape: random.normal(rkey, shape=shape) * jnp.sqrt(2 / shape[0])
  params = [
    {
      "wU" : he(keys[param_sets*i + 0], (hxconcat_size, model_size)),
      "bU" : jnp.zeros((model_size,)),
      "wC" : he(keys[param_sets*i + 6], (hxconcat_size, model_size)),
      "bC" : jnp.zeros((model_size,)),
      "wF1": he(keys[param_sets*i + 1], (hxconcat_size, model_size)),
      "bF1": jnp.zeros((model_size,)),
      "wF2": he(keys[param_sets*i + 2], (hxconcat_size, model_size)),
      "bF2": jnp.zeros((model_size,)),
      "wO" : he(keys[param_sets*i + 3], (hxconcat_size, model_size)),
      "bO" : jnp.zeros((model_size,)),
      "h0" : jnp.zeros((model_size,)),
      "c0" : jnp.zeros((model_size,)),
      #"h0" : random.normal(keys[param_sets*i + 4], shape=(model_size)) * jnp.sqrt(2 / model_size),
      #"c0" : random.normal(keys[param_sets*i + 5], shape=(model_size)) * jnp.sqrt(2 / model_size),
    }
    for i in range(lstm_layers)
  ]
  params[0].update(
    {
    # then embedding table weight and bias
    "wEM" : he(keys[param_sets*(param_sets - 1) + 2], (input_size, model_size)),
    "bEM" : jnp.zeros((model_size,)),

  })
  params[-1].update(
    {
      # this is for the y layer, which i am probably imlementing wrong.
      "wY1" : he(keys[param_sets*(lstm_layers-1) + 4], (model_size, output_size)),
      "bY1" : jnp.zeros((output_size,)),
    }
  )
  return params


@functools.partial(jax.jit, static_argnames=["dropout_rate"])
def dropout(dropout_key, original_tensor, dropout_rate):
  # generate random of same shape
  dropout_probs = random.uniform(dropout_key, shape=original_tensor.shape)
  # mask = random < dropout_rate
  mask = (dropout_probs > dropout_rate) / (1 - dropout_rate) # scale to keep avg the same
  return original_tensor * mask


@functools.partial(jax.jit, static_argnames=["dropout_rate"])
def lstm_step(step_dropout_key, lstm_layer_params, layer_h, layer_c, current_xt, dropout_rate):
  hxconcat = jax.lax.concatenate([layer_h, current_xt], dimension=1) #B, h ++ B, C => B, h+c
  # update gate
  update = jax.nn.sigmoid(hxconcat @ lstm_layer_params["wU"] + lstm_layer_params["bU"])
  #update = dropout(step_dropout_keys[0], update, dropout_rate)
  candidate = jax.nn.tanh(hxconcat @ lstm_layer_params["wC"] + lstm_layer_params["bC"])
  #candidate = dropout(step_dropout_keys[1], candidate, dropout_rate)

  # forget gate
  forget = jax.nn.sigmoid(
              hxconcat @ lstm_layer_params["wF1"] + lstm_layer_params["bF1"]
            ) * jax.nn.tanh(
              hxconcat @ lstm_layer_params["wF2"] + lstm_layer_params["bF2"]
            )

  # update c with update and forget
  layer_c = layer_c + update * candidate + forget # (batch, c) => (batch, c)

  # output
  layer_h = jax.nn.tanh(layer_c) * jax.nn.sigmoid(hxconcat @ lstm_layer_params["wO"] + lstm_layer_params["bO"]) # (B, model_size)

  next_layer_xt = dropout(step_dropout_key, layer_h, dropout_rate) # the next layer's input x is the current layer's hidden state
  # karpathy: dropout after EACH LAYER not several times in the block. lol.

  return (layer_h, layer_c), next_layer_xt


# LSTM forward
import functools
@functools.partial(jax.jit, static_argnames=['dropout_rate', 'lstm_layers'])
def lstm_forward(dropout_key, lstm_params, xembeds_batch, dropout_rate, lstm_layers=lstm_layers):
  batches = xembeds_batch.shape[0]
  lstm_layers = len(lstm_params)
  # initialize h and c as random/learnable params
  #h = jnp.tile(lstm_params[0]["h0"], (batches, lstm_layers, 1)) # B, lstm_layer, h_size
  #c = jnp.tile(lstm_params[0]["c0"], (batches, lstm_layers, 1)) # B, lstm_layer, c_size
  # wait.. these are the same for all of the layers.. maybe they shouldn't be
  T = xembeds_batch.shape[1]
  # take xembeds_batch and pass each xt through the same SINGULAR block. don't update the weight layer. there is only one layer.
  dropout_layers = 6*T*lstm_layers # manual
  dropout_keys = random.split(dropout_key, dropout_layers)

  # for each layer:
    # scan over xt
    # carry : h, c
    # a: xt
    # b: h,c
    # f = lambda ((h, c), xt) : lstm_step(h, c, xt, everything else) => h, c
    # scans over xt
    # for next layer: xt = h of previous layer. h = h0 and c = c0
  
  current_embeddings_batch = jnp.transpose(xembeds_batch, (1, 0, 2)) # B, T, C => T, B, C
    # The reason for this is that jax.lax.scan only uses the leading dim. why? idk. its dumb, it needs an axis arg so i can scan over whatever

  for lstm_layer in range(lstm_layers):
    h = jnp.tile(lstm_params[lstm_layer]["h0"], (batches, 1))
    c = jnp.tile(lstm_params[lstm_layer]["c0"], (batches, 1))
    layer_dropout_key = dropout_keys[lstm_layer] # it doesnt matter if this is the same across all layers
    # scan should be inexpensive since layer size is small while t size is usually LARGE
    # scan :: (c -> a -> (c, b)) -> c -> [a] -> (c, [b])
    # scan :: scanfunc -> h_and_c -> xs -> (h_and_c_final, hs_to_be_used_as_input_xt_in_next_layer)
    # scanfunc :: (c -> a -> (c, b))
    scanfunc = lambda hc, xt : lstm_step(layer_dropout_key, lstm_params[lstm_layer], hc[0], hc[1], xt, dropout_rate)
      # for xs: scan along the t dimension! it scans along B by default
      # to fix this, we transpose xs with jnp.transpose(current_embeddings_batch, (1, 0, 2))
    current_embeddings_batch = jax.lax.scan(scanfunc, (h, c), current_embeddings_batch)[1] # (c, [b]) => [b] ==> B, T, C
  

  # finally turn current_embeddings_batch into ys (logits)
  hs = jnp.transpose(current_embeddings_batch, (1, 0, 2)) # T, B, C => B, T, C
  ys = hs @ lstm_params[-1]['wY1'] + lstm_params[-1]["bY1"] # B, T, model_size => B, T, vocab_size
  return ys


@functools.partial(jax.jit, static_argnames=["dropout_rate"])
def loss(dropout_key, lstm_params, xtokens_batch, ytokens_batch, dropout_rate):
  xembeds_batch = embed(lstm_params, xtokens_batch)
  logits = lstm_forward(dropout_key, lstm_params, xembeds_batch, dropout_rate)
  vocab_size = logits.shape[-1]
  ys_one_hot = jax.nn.one_hot(ytokens_batch, vocab_size, axis=-1)
  logprobs = jax.nn.log_softmax(logits, axis=-1)
  crossentropylosses = -jnp.sum(ys_one_hot * logprobs, axis=-1)
  crossentropyloss = jnp.mean(crossentropylosses)
  return crossentropyloss


@functools.partial(jax.jit, static_argnames=["dropout_rate"])
def final_token_loss(dropout_key, lstm_params, xtokens_batch, ytokens_batch, dropout_rate):
  xembeds_batch = embed(lstm_params, xtokens_batch)
  logit = lstm_forward(dropout_key, lstm_params, xembeds_batch, dropout_rate)[:, -1] # get last logit
  vocab_size = logit.shape[-1]
  ys_one_hot = jax.nn.one_hot(ytokens_batch, vocab_size, axis=-1)[:, -1] # get last y (the target)
  logprobs = jax.nn.log_softmax(logit, axis=-1)
  crossentropylosses = -jnp.sum(ys_one_hot * logprobs, axis=-1)
  crossentropyloss = jnp.mean(crossentropylosses)
  return crossentropyloss



jitted_backwards_loss = jax.jit(jax.value_and_grad(final_token_loss, argnums=1), static_argnames=["dropout_rate"])


@functools.partial(jax.jit, static_argnames=['vocab_size'])
def embed(lstm_params, xtokens, vocab_size=len(vocab)):
  xs_one_hot = jax.nn.one_hot(xtokens, vocab_size, axis=-1) #B, T, vocab_size
  activations = xs_one_hot @ lstm_params[0]["wEM"] + lstm_params[0]["bEM"]
  xembeds = jax.nn.tanh(activations) # B, T, C
  return xembeds # TODO ??? this doesnt seem to make a difference btw.


lr = 2e-3
lr_decay = 0.97
decay_after = 10
decay_every = 5
optimizer = optax.inject_hyperparams(optax.adam)(learning_rate=lr)


# make optimizer a static arg in jit or it breaks
@functools.partial(jax.jit, static_argnames=["dropout_rate"])
def train(dropout_key, lstm_params, xtokens_batch, ytokens_batch, opt_state, dropout_rate):
  step_loss, grads = jitted_backwards_loss(dropout_key, lstm_params, xtokens_batch, ytokens_batch, dropout_rate)
  param_updates, updated_opt_state = optimizer.update(grads, opt_state, lstm_params)
  updated_lstm_params = optax.apply_updates(lstm_params, param_updates) 
  return updated_lstm_params, updated_opt_state, step_loss, grads



In [24]:
# train
# set up lstm params
keys = random.split(random.PRNGKey(123), 20)
lstm_params = init_LSTM_params(keys[0], lstm_layers, input_size, model_size, output_size)
opt_state = optimizer.init(lstm_params)


# train
# for now just overfit on small sample idk lol
train_batch_size = 100
val_batch_size = 10

dropout_rate = 0.2

epochs = 1000

print_every = 20
j = 0
losses = []
start = time.time()
for epoch in range(epochs):
  if epoch > decay_after:
    if epoch % decay_every == 0:
      lr *= lr_decay
      opt_state.hyperparams['learning_rate'] = lr
  samples = (len(train_tokens) - 1) // sequence_length
  for i in range(0, len(train_tokens) - 1 - sequence_length*train_batch_size, sequence_length*train_batch_size): # probably wrong but w/e
    # train
    # B, T where T = sequence_length
    xtokens_batch = train_tokens[i:i+sequence_length*train_batch_size].reshape(-1, sequence_length)
    ytokens_batch = train_tokens[i+1:i+sequence_length*train_batch_size+1].reshape(-1, sequence_length)

    dropout_key = random.PRNGKey(epoch*samples + i) # unique for every step

    lstm_params, opt_state, step_loss, grads = train(dropout_key, lstm_params, xtokens_batch, ytokens_batch, opt_state, dropout_rate)

    j += 1
    losses.append(step_loss)

    if j % print_every == 0:
      end = time.time()
      duration = end - start
      # train inference example (no dropout)
      xembeds_batch = embed(lstm_params, xtokens_batch[0][None, :]) # 1-batch
      logits_batch = lstm_forward(dropout_key, lstm_params, xembeds_batch, 0)
      prediction_batch = jnp.argmax(logits_batch, axis=-1)

      # val batch
      j = i % ((len(test_tokens) - 1)//((val_batch_size)*sequence_length))
      idx = j*val_batch_size*sequence_length
      xtokens_val_batch = test_tokens[idx:idx+sequence_length*val_batch_size].reshape(-1, sequence_length) # batches of sequences lstm block count size
      ytokens_val_batch = test_tokens[idx+1:idx+sequence_length*val_batch_size+1].reshape(-1, sequence_length)
      xembeds_val_batch = embed(lstm_params, xtokens_val_batch)
      
      logits_val_batch = lstm_forward(dropout_key, lstm_params, xembeds_val_batch, 0)
      prediction_val_batch = jnp.argmax(logits_val_batch, axis=-1)
      ys_onehot = jax.nn.one_hot(ytokens_val_batch, len(vocab), axis=-1)
      logprobs = jax.nn.log_softmax(logits_val_batch, axis=-1)
      crossentropies = -jnp.sum(ys_onehot*logprobs,axis=-1)
      val_loss = jnp.mean(crossentropies) #lmao
      val_accuracy = jnp.mean(prediction_val_batch == ytokens_val_batch)

      x = decode(xtokens_batch[0]).replace('\n', ' ')
      y = decode(ytokens_batch[0]).replace('\n', ' ')
      yhat = decode(prediction_batch[0]).replace('\n', ' ')
      #print(epoch, epoch * samples + i, f"{step_loss:1.4f}", "pred:", x, "=>", y, "?=", yhat)
      print(f'TARGET ({len(y)}) | "{y}"')
      print(f'PRED   ({len(yhat)}) | "{yhat}"')
      print(f"step {(epoch, epoch * samples + j)} || samples/sec: {train_batch_size*print_every/(duration):0.0f} || loss: {sum(losses)/len(losses):1.4f} || val_loss: {val_loss:1.4f} val_acc: {val_accuracy:1.4f} || LR = {opt_state.hyperparams['learning_rate']:0.6f}" )
      print()
      start = time.time()


TARGET (20) | "   reply: @andrew_py"
PRED   (20) | "                    "
step (0, 441) || samples/sec: 327 || loss: 3.8576 || val_loss: 3.4963 val_acc: 0.1350 || LR = 0.002000

TARGET (20) | " @kuberdenis @startu"
PRED   (20) | "                    "
step (0, 353) || samples/sec: 1190 || loss: 3.6470 || val_loss: 3.2829 val_acc: 0.1850 || LR = 0.002000

TARGET (20) | "imo. as a kid i had "
PRED   (20) | "                    "
step (0, 70) || samples/sec: 3450 || loss: 3.6179 || val_loss: 3.3287 val_acc: 0.1600 || LR = 0.002000

TARGET (20) | "joy using your time "
PRED   (20) | "                    "
step (0, 497) || samples/sec: 2359 || loss: 3.5855 || val_loss: 3.3263 val_acc: 0.2000 || LR = 0.002000

TARGET (20) | "roductive when you w"
PRED   (20) | "                    "
step (0, 149) || samples/sec: 9196 || loss: 3.5780 || val_loss: 3.2557 val_acc: 0.1550 || LR = 0.002000

TARGET (20) | " aside though, why w"
PRED   (20) | "                    "
step (0, 460) || samples/sec: 214

KeyboardInterrupt: 

In [26]:
def inference(key, chars):
  xtokens = encode(chars)[None, :]
  xembed = embed(lstm_params, xtokens) # artificial single batch
  logits = lstm_forward(key, lstm_params, xembed, 0)[0][-1] # logits of the first B and last T in the B T C. should be (C,)
  yhattokens = random.choice(key, a=logits.shape[0], p=jax.nn.softmax(logits)) # no need for axis=-1 since logits are (C,)
  sequence = yhattokens
  return sequence

steps = 1000
import time
seed = int(time.time())
keys = random.split(random.PRNGKey(seed), steps)
temperature = 0.5
text =  "\n"*50 + 'post: '
print(text.replace('\n\n', ''), end='')
for i in range(steps):
  yseq = inference(keys[i], text[-sequence_length:])
  next_char = decode([yseq])[-1]
  if next_char == '🛑':
    break
  text += next_char
  print(next_char, end='')

post: a piosg @Ays it atf sabcson threts forcret drasper plr intapale fmafer sraqded reat mikeyod mow te repre froth thit thea and menn a min ceos to 8Ishl 4lafan a porttps://t.co/tovqG6L4xE yourascorf,G this yo biach the. yau pemf yao exang at carlt ec artls als seery susfd ray ors you mes Iv pA)o boster 4 c pywe stirings ctogerl reeld time sry lohes inlad, baph seme ser nes ot modcer the s,hes at os as fan ar,orfiky meol af slarlise wor tou doar rancminus of like &rad biseoding (E5P5Q siverd the eotf is o hos sot i got at and ib ono sot the an to ceearsty would if the e cind for etf tre muncss i canctus f nela. soil as youl, his newur to ceog f the praa casp, hoc thy:WTre meot enars ar pore?

e in the tor lew it youd ther jogatioln nejertacts giws

-f Y*merungs frocks. maging bad is bring ML the remre meef ti ne ined a saff wirisibe thiuld

In [11]:
jax.tree_util.tree_map(jnp.linalg.norm, grads)

[{'bC': Array(0.00024058, dtype=float32),
  'bEM': Array(0.0003749, dtype=float32),
  'bF1': Array(1.8734916e-05, dtype=float32),
  'bF2': Array(0.0002467, dtype=float32),
  'bO': Array(5.4436307e-05, dtype=float32),
  'bU': Array(2.0121175e-05, dtype=float32),
  'c0': Array(0.00227709, dtype=float32),
  'h0': Array(0.00339398, dtype=float32),
  'wC': Array(0.00184112, dtype=float32),
  'wEM': Array(0.00018773, dtype=float32),
  'wF1': Array(0.00021604, dtype=float32),
  'wF2': Array(0.00192728, dtype=float32),
  'wO': Array(0.00067066, dtype=float32),
  'wU': Array(0.00023371, dtype=float32)},
 {'bC': Array(0.00018975, dtype=float32),
  'bF1': Array(2.3798742e-05, dtype=float32),
  'bF2': Array(0.0001887, dtype=float32),
  'bO': Array(5.7448495e-05, dtype=float32),
  'bU': Array(2.3484547e-05, dtype=float32),
  'wC': Array(0.00251008, dtype=float32),
  'wF1': Array(0.00043969, dtype=float32),
  'wF2': Array(0.00262315, dtype=float32),
  'wO': Array(0.00120152, dtype=float32),
  'wU': 

In [34]:
getsize = lambda s: s.size
sizes = jax.tree_util.tree_map(getsize, grads)
total_params = 0
for layer in sizes:
  for _, v in layer.items():
    total_params += v

print(f"TOTAL_PARAMS: {total_params}")
print(f"DTYPE: {grads[0]['bC'].dtype}")
print(f"TOTAL_MEGABYTES: {total_params*4/1_000_000}")

TOTAL_PARAMS: 1183262
DTYPE: float32
TOTAL_MEGABYTES: 4.733048


In [6]:
jnp.arange(1, 10)[, None].shape

(9, 1)

In [16]:
import jax.profiler
jax.profiler.save_device_memory_profile('test.prof')