In [9]:
#imports
import jax
import jax.numpy as jnp
import jax.random as random
import optax
import tokenizers
from tokenizers import CharBPETokenizer
import functools


gpu_device = jax.device_get('gpu')[0]
cpu_device = jax.device_get('cpu')[0]
# LSTM
# xs = B, input_size = B, T, C
# h = c = y = B, output_size = B, T, logits_size = B, T, vocab_size


In [10]:
#dataset
max_vocab_size = 100

with open('data/dnbt_posts.txt', 'r') as file:
  dataset = file.read()

# init tokenizer
tokenizer = tokenizers.Tokenizer(tokenizers.models.WordPiece())
tokenizer.pre_tokenizer = tokenizers.pre_tokenizers.Whitespace()
trainer = tokenizers.trainers.WordPieceTrainer(vocab_size=max_vocab_size, special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"])
tokenizer.train_from_iterator([dataset], trainer=trainer)
tokenizer.decoder = tokenizers.decoders.WordPiece()

# tokenize
vocab = list(tokenizer.get_vocab())
vocab_size = len(vocab)
print("vocab length:", len(vocab))

encode = lambda characters: tokenizer.encode(characters).ids
decode = lambda token_ids: tokenizer.decode(token_ids)

test_phrase = "sup dog! its me fam. aaa  aaa"
print(test_phrase, encode(test_phrase), decode(encode(test_phrase)))

dataset_tokens = encode(dataset)




vocab length: 273
sup dog! its me fam. aaa  aaa [84, 168, 172, 69, 181, 195, 5, 74, 166, 170, 78, 162, 71, 165, 164, 18, 66, 165, 165, 66, 165, 165] sup dog! its me fam. aaa aaa


In [11]:
lstm_layers = 2
sequence_length = 100
model_size = 512

input_size = len(vocab) # just do one-hot for now
hidden_size = model_size
output_size = len(vocab) # logits => one-hot => tokens


# init LSTM params
def init_LSTM_params(key, lstm_layers, input_size, model_size, output_size):
  param_sets = 8 # manual, idc
  keys = random.split(key, param_sets*lstm_layers + 2)
  hxconcat_size = model_size + model_size
  he = lambda rkey, shape: random.normal(rkey, shape=shape) * jnp.sqrt(2 / shape[0])
  params = [
    {
      "wU" : he(keys[param_sets*i + 0], (hxconcat_size, model_size)),
      "bU" : jnp.zeros((model_size,)),
      "wC" : he(keys[param_sets*i + 6], (hxconcat_size, model_size)),
      "bC" : jnp.zeros((model_size,)),
      "wF1": he(keys[param_sets*i + 1], (hxconcat_size, model_size)),
      "bF1": jnp.zeros((model_size,)),
      "wF2": he(keys[param_sets*i + 2], (hxconcat_size, model_size)),
      "bF2": jnp.zeros((model_size,)),
      "wO" : he(keys[param_sets*i + 3], (hxconcat_size, model_size)),
      "bO" : jnp.zeros((model_size,)),
    }
    for i in range(lstm_layers)
  ]
  params[0].update(
    {
    # h0 and c0 are unique for each block.
    "h0" : random.normal(keys[param_sets*(param_sets - 1) + 0], shape=(lstm_layers, model_size)) * jnp.sqrt(2 / model_size),
    "c0" : random.normal(keys[param_sets*(param_sets - 1) + 1], shape=(lstm_layers, model_size)) * jnp.sqrt(2 / model_size),
    # then embedding table weight and bias
    "wEM" : he(keys[param_sets*(param_sets - 1) + 2], (input_size, model_size)),
    "bEM" : jnp.zeros((model_size,)),

  })
  params[-1].update(
    {
      # this is for the y layer, which i am probably imlementing wrong.
      "wY1" : he(keys[param_sets*(lstm_layers-1) + 4], (model_size, output_size)),
      "bY1" : jnp.zeros((output_size,)),
    }
  )
  return params


@functools.partial(jax.jit, static_argnames=["dropout_rate"])
def dropout(key, original_tensor, dropout_rate):
  # generate random of same shape
  dropout_probs = random.uniform(key, shape=original_tensor.shape)
  # mask = random < dropout_rate
  mask = (dropout_probs > dropout_rate) / (1 - dropout_rate) # scale to keep avg the same
  return original_tensor * mask


# LSTM forward
import functools
@functools.partial(jax.jit, static_argnames=['batches', 'dropout_rate'])
def lstm_forward(key, batches, lstm_params, xs, dropout_rate):
  logits_ts = []
  lstm_layer = 0
  lstm_layers = len(lstm_params)
  steps = 0
  # initialize h and c as random/learnable params
  h = jnp.tile(lstm_params[0]["h0"], (batches, 1, 1)) # B, lstm_layer, model_size
  c = jnp.tile(lstm_params[0]["c0"], (batches, 1, 1)) # B, lstm_layer, model_size
  T = xs.shape[1]
  # take xs and pass each xt through the same SINGULAR block. don't update the weight layer. there is only one layer.
  dropout_layers = 6*T*lstm_layers # manual
  dropout_keys = random.split(key, dropout_layers)
  while steps < T:
    # iterate through all LSTM blocks, and get the output from the final one
    # send h and c from one block to the next?
    while lstm_layer < lstm_layers:
      hxconcat = jax.lax.concatenate([h[:, lstm_layer], xs[:, steps]], dimension=1) #B, h ++ B, C => B, h+c
      # update gate
      update = jax.nn.sigmoid(hxconcat @ lstm_params[lstm_layer]["wU"] + lstm_params[lstm_layer]["bU"])
      update = dropout(dropout_keys[steps*lstm_layer + 0], update, dropout_rate)
      candidate = jax.nn.tanh(hxconcat @ lstm_params[lstm_layer]["wC"] + lstm_params[lstm_layer]["bC"])
      candidate = dropout(dropout_keys[steps*lstm_layer + 1], candidate, dropout_rate)
      c = c.at[:, lstm_layer].set(c[:, lstm_layer] + update * candidate) # (batch, c) => (batch, c)
      # forget gate
      forget1 = jax.nn.sigmoid(hxconcat @ lstm_params[lstm_layer]["wF1"] + lstm_params[lstm_layer]["bF1"])
      forget1 = dropout(dropout_keys[steps*lstm_layer + 2], forget1, dropout_rate)
      forget2 = jax.nn.tanh(hxconcat @ lstm_params[lstm_layer]["wF2"] + lstm_params[lstm_layer]["bF2"])
      forget2 = dropout(dropout_keys[steps*lstm_layer + 3], forget2, dropout_rate)
      forget = forget1 * forget2
      c = c.at[:, lstm_layer].set(c[:, lstm_layer] + forget) # (batch, c) => (batch, c)

      # output
      o = jax.nn.sigmoid(hxconcat @ lstm_params[lstm_layer]["wO"] + lstm_params[lstm_layer]["bO"])  # B, model_size
      o = dropout(dropout_keys[steps*lstm_layer + 4], o, dropout_rate)
      h = h.at[:, lstm_layer].set(jax.nn.tanh(c[:, lstm_layer]) * o) # (B, model_size)

      xs = xs.at[:, steps].set(h[:, lstm_layer]) # the next layer's input is the current layer's hidden state

      lstm_layer += 1
    
    lstm_layer = len(lstm_params) - 1
    y = h[:, lstm_layer] @ lstm_params[lstm_layer]['wY1'] + lstm_params[lstm_layer]["bY1"] # generate y with the last layer's hidden state
    y = dropout(dropout_keys[steps*lstm_layer + 5], y, dropout_rate)
    #y = y @ lstm_params[i]['wY2'] + lstm_params[i]["bY2"]

    logits_ts.append(y)
    steps += 1
  logits = jnp.transpose(jnp.array(logits_ts), axes=(1, 0, 2)) # T, B, C => B, T, C
  return logits

@functools.partial(jax.jit, static_argnames=["dropout_rate"])
def loss(dropout_key, lstm_params, xs, ys, dropout_rate):
  batches = xs.shape[0] # B, T, C
  logits = lstm_forward(dropout_key, batches, lstm_params, xs, dropout_rate)
  vocab_size = logits.shape[-1]
  ys_one_hot = jax.nn.one_hot(ys, vocab_size, axis=-1)
  logprobs = jax.nn.log_softmax(logits, axis=-1)
  crossentropylosses = -jnp.sum(ys_one_hot * logprobs, axis=-1)
  crossentropyloss = jnp.mean(crossentropylosses)
  return crossentropyloss

@functools.partial(jax.jit, static_argnames=['vocab_size'])
def embed(xtokens, lstm_params, vocab_size=len(vocab)):
  xs_one_hot = jax.nn.one_hot(xtokens, vocab_size, axis=-1) #B, T, vocab_size
  activations = xs_one_hot @ lstm_params[0]["wEM"] + lstm_params[0]["bEM"]
  xembeds = jax.nn.tanh(activations) # B, T, C
  return activations # TODO ??? this doesnt seem to make a difference btw.


lr = 4e-3
lr_decay = 0.995
decay_after = 10
optimizer = optax.inject_hyperparams(optax.adam)(learning_rate=lr)

# make optimizer a static arg in jit or it breaks
@jax.jit
def train(dropout_key, lstm_params, xtokens_batch, ytokens_batch, opt_state, dropout_rate):
  xembeds_batch = embed(xtokens_batch, lstm_params)
  step_loss, grads = jax.value_and_grad(loss, argnums=1)(dropout_key, lstm_params, xembeds_batch, ytokens_batch, dropout_rate)
  param_updates, updated_opt_state = optimizer.update(grads, opt_state, lstm_params)
  updated_lstm_params = optax.apply_updates(lstm_params, param_updates)
  return updated_lstm_params, updated_opt_state, step_loss, grads


In [None]:
# train
# set up lstm params
keys = random.split(random.PRNGKey(123), 20)
lstm_params = init_LSTM_params(keys[0], lstm_layers, input_size, model_size, output_size)
opt_state = optimizer.init(lstm_params)

# train
# for now just overfit on small sample idk lol
dataset_tokens = jnp.array(dataset_tokens)
split_ratio = 0.8
train_tokens = dataset_tokens[:int(len(dataset_tokens)*split_ratio)]
test_tokens = dataset_tokens[int(len(dataset_tokens)*split_ratio):]

train_batch_size = 1
val_batch_size = 1

dropout_rate = 0.5

epochs = 1000
decay_every = 1
for epoch in range(epochs):
  if epoch > decay_after:
    if epoch % decay_every == 0:
      lr *= lr_decay
      opt_state.hyperparams['learning_rate'] = lr
  samples = (len(train_tokens) - 1) // sequence_length
  for i in range(0, len(train_tokens) - 1 - sequence_length*train_batch_size, sequence_length*train_batch_size):
    print('training')
    # train
    xtokens = train_tokens[i:i+sequence_length*train_batch_size]
    ytokens = train_tokens[i+1:i+sequence_length*train_batch_size+1]

    xtokens_batch = xtokens.reshape(-1, sequence_length)
    ytokens_batch = ytokens.reshape(-1, sequence_length) # B, T where T = sequence_length

    dropout_key = random.PRNGKey(epoch*samples + i) # unique for every step

    lstm_params, opt_state, step_loss, grads = train(dropout_key, lstm_params, xtokens_batch, ytokens_batch, opt_state, dropout_rate)
    print('train done')

    # train inference example (no dropout)
    xembeds_batch = embed(xtokens_batch, lstm_params)
    logits_batch = lstm_forward(dropout_key, xembeds_batch.shape[0], lstm_params, xembeds_batch, 0)
    prediction_batch = jnp.argmax(logits_batch, axis=-1)
    print('train inference done')

    # val batch
    print('valing')
    j = i % ((len(test_tokens) - 1)//((val_batch_size)*sequence_length))
    idx = j*val_batch_size*sequence_length
    xtokens_val_batch = test_tokens[idx:idx+sequence_length*val_batch_size].reshape(-1, sequence_length) # batches of sequences lstm block count size
    ytokens_val_batch = test_tokens[idx+1:idx+sequence_length*val_batch_size+1].reshape(-1, sequence_length)
    xembeds_val_batch = embed(xtokens_val_batch, lstm_params)
    
    logits_val_batch = lstm_forward(dropout_key, xembeds_val_batch.shape[0], lstm_params, xembeds_val_batch, 0)
    prediction_val_batch = jnp.argmax(logits_val_batch, axis=-1)
    ys_onehot = jax.nn.one_hot(ytokens_val_batch, len(vocab), axis=-1)
    logprobs = jax.nn.log_softmax(logits_val_batch, axis=-1)
    crossentropies = -jnp.sum(ys_onehot*logprobs,axis=-1)
    val_loss = jnp.mean(crossentropies) #lmao
    val_accuracy = jnp.mean(prediction_val_batch == ytokens_val_batch)
    print('val done')

    x = decode(xtokens_batch[0]).replace('\n', ' ')
    y = decode(ytokens_batch[0]).replace('\n', ' ')
    yhat = decode(prediction_batch[0]).replace('\n', ' ')
    #print(epoch, epoch * samples + i, f"{step_loss:1.4f}", "pred:", x, "=>", y, "?=", yhat)
    print(f'INPUT | "{x}"')
    print(f'PRED  | "{yhat}"')
    print(f"step {(epoch, epoch * samples + i)} || loss: {step_loss:1.4f} || val_loss: {val_loss:1.4f} val_acc: {val_accuracy:1.4f} || LR = ", opt_state.hyperparams['learning_rate'] )
    print()


training




train done


In [11]:
def inference_chars(key, chars):
  xtokens = encode(chars)
  xembed = jax.nn.one_hot(xtokens, len(vocab))[None, :] # artificial single batch
  logits = lstm_forward(xembed.shape[0], lstm_params, xembed)[0][-1] # logits of the first B and last T in the B T C. should be (C,)
  yhattokens = random.choice(key, a=logits.shape[0], p=jax.nn.softmax(logits)) # no need for axis=-1 since logits are (C,)
  sequence = yhattokens
  return sequence


steps = 1000
keys = random.split(random.PRNGKey(1221133), steps)
temperature = 0.5
text = '\n'*sequence_length + 'post: '
tokens = encode(text)
print(text.replace('\n\n', ''), end='')
for i in range(steps):
  xtokens = jnp.array(tokens[-sequence_length:])
  print(xtokens.shape[0], xembed.shape)
  xembed = embed(xtokens, lstm_params)[None, :]
  logits = lstm_forward(random.PRNGKey(0), xembed.shape[0], lstm_params, xembed, 0)[0][-1]
  yseq = random.choice(keys[i], a=logits.shape[0], p=jax.nn.softmax(logits))
  tokens.append(yseq)
  print(decode(tokens).replace('\n\n', ''))
print()

post: 

NameError: name 'xembed' is not defined

In [13]:
jax.tree_util.tree_map(jnp.linalg.norm, grads)

[{'bC': Array(1.6914073e-05, dtype=float32),
  'bEM': Array(0., dtype=float32),
  'bF1': Array(3.101523e-06, dtype=float32),
  'bF2': Array(1.9935575e-05, dtype=float32),
  'bO': Array(3.250238e-05, dtype=float32),
  'bU': Array(3.8769945e-06, dtype=float32),
  'bY1': Array(0., dtype=float32),
  'bY2': Array(0., dtype=float32),
  'c0': Array(0.05276936, dtype=float32),
  'h0': Array(0.00066209, dtype=float32),
  'wC': Array(3.488135e-05, dtype=float32),
  'wEM': Array(0., dtype=float32),
  'wF1': Array(5.667383e-06, dtype=float32),
  'wF2': Array(4.112061e-05, dtype=float32),
  'wO': Array(5.918855e-05, dtype=float32),
  'wU': Array(7.188513e-06, dtype=float32),
  'wY1': Array(0., dtype=float32),
  'wY2': Array(0., dtype=float32)},
 {'bC': Array(0.00494138, dtype=float32),
  'bF1': Array(0.00284158, dtype=float32),
  'bF2': Array(0.003767, dtype=float32),
  'bO': Array(0.05332662, dtype=float32),
  'bU': Array(0.00265908, dtype=float32),
  'bY1': Array(0.0179696, dtype=float32),
  'bY2

In [34]:
getsize = lambda s: s.size
sizes = jax.tree_util.tree_map(getsize, grads)
total_params = 0
for layer in sizes:
  for _, v in layer.items():
    total_params += v

print(f"TOTAL_PARAMS: {total_params}")
print(f"DTYPE: {grads[0]['bC'].dtype}")
print(f"TOTAL_MEGABYTES: {total_params*4/1_000_000}")

TOTAL_PARAMS: 1183262
DTYPE: float32
TOTAL_MEGABYTES: 4.733048
