In [1]:
#imports
import jax
import jax.numpy as jnp
import jax.random as random
import optax
import tokenizers


gpu_device = jax.device_get('gpu')[0]
cpu_device = jax.device_get('cpu')[0]
# LSTM
# xs = B, input_size = B, T, C
# h = c = y = B, output_size = B, T, logits_size = B, T, vocab_size


In [27]:
#dataset
max_vocab_size = 1000

with open('data/dnbt_posts.txt', 'r') as file:
  dataset = file.read()

# init tokenizer
tokenizer = tokenizers.Tokenizer(tokenizers.models.WordPiece())
tokenizer.pre_tokenizer = tokenizers.pre_tokenizers.Whitespace()
trainer = tokenizers.trainers.WordPieceTrainer(vocab_size=max_vocab_size, special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"])
tokenizer.train_from_iterator([dataset], trainer=trainer)
tokenizer.decoder = tokenizers.decoders.WordPiece()

# tokenize
vocab = list(tokenizer.get_vocab())
vocab_size = len(vocab)
print("vocab length:", len(vocab))

encode = lambda characters: tokenizer.encode(characters).ids
decode = lambda token_ids: tokenizer.decode(token_ids)

test_phrase = "sup dog! its me fam. aaa  aaa"
print(test_phrase, encode(test_phrase), decode(encode(test_phrase)))

dataset_tokens = encode(dataset)




vocab length: 1000
sup dog! its me fam. aaa  aaa [366, 182, 346, 171, 5, 380, 384, 71, 325, 18, 66, 172, 172, 66, 172, 172] sup dog! its me fam. aaa aaa


In [31]:
lstm_blocks = 2 # thanks karpathy
sequence_length = 100 # thanks karpathy
model_size = 128*4 # thanks karpathy

input_size = len(vocab) # just do one-hot for now
hidden_size = model_size
output_size = len(vocab) # logits => one-hot => tokens


# init LSTM params
def init_LSTM_params(key, lstm_blocks, input_size, model_size, output_size):
  layers = 8
  keys = random.split(key, layers*lstm_blocks + 2)
  hxconcat_size = input_size + model_size
  he = lambda rkey, shape: random.normal(rkey, shape=shape) * jnp.sqrt(2 / shape[0])
  params = [
    {
      "wU" : he(keys[layers*i + 0], (hxconcat_size, model_size)),
      "bU" : jnp.zeros((model_size,)),
      "wC" : he(keys[layers*i + 6], (hxconcat_size, model_size)),
      "bC" : jnp.zeros((model_size,)),
      "wF1": he(keys[layers*i + 1], (hxconcat_size, model_size)),
      "bF1": jnp.zeros((model_size,)),
      "wF2": he(keys[layers*i + 2], (hxconcat_size, model_size)),
      "bF2": jnp.zeros((model_size,)),
      "wO" : he(keys[layers*i + 3], (hxconcat_size, model_size)),
      "bO" : jnp.zeros((model_size,)),
      # this is for the y layer, which i am probably imlementing wrong.
      "wY1" : he(keys[layers*i + 4], (model_size, output_size)),
      "bY1" : jnp.zeros((output_size,)),
      "wY2" : he(keys[layers*i + 5], (output_size, output_size)),
      "bY2" : jnp.zeros((output_size,)),
    }
    for i in range(lstm_blocks)
  ]
  params[0].update(
    {
    "h0" : random.normal(keys[layers*(layers - 1) + 0], shape=(model_size, )) * jnp.sqrt(2 / model_size),
    "c0" : random.normal(keys[layers*(layers - 1) + 1], shape=(model_size, )) * jnp.sqrt(2 / model_size),
  })
  return params


def lstm_step(lstm_params, xs, h, c, i):
  hxconcat = jax.lax.concatenate([h, xs[:, i]], dimension=1) #B, h ++ B, C => B, h+c
  # update gate
  update = jax.nn.sigmoid(hxconcat @ lstm_params[i]["wU"] + lstm_params[i]["bU"])
  candidate = jax.nn.tanh(hxconcat @ lstm_params[i]["wC"] + lstm_params[i]["bC"])
  c = c + update * candidate # (batch, c) => (batch, c)
  # forget gate
  forget1 = jax.nn.sigmoid(hxconcat @ lstm_params[i]["wF1"] + lstm_params[i]["bF1"])
  forget2 = jax.nn.tanh(hxconcat @ lstm_params[i]["wF2"] + lstm_params[i]["bF2"])
  forget = forget1 * forget2
  c = c + forget # (batch, c) => (batch, c)

  # output
  o = jax.nn.sigmoid(hxconcat @ lstm_params[i]["wO"] + lstm_params[i]["bO"])  # B, model_size
  h = jax.nn.tanh(c) * o # (B, model_size)
  return h, c

# LSTM forward
import functools
@functools.partial(jax.jit, static_argnames=['batches'])
def lstm_forward(batches, lstm_params, xs):
  logits_ts = []
  lstm_block = 0
  steps = 0
  # initialize h and c as random/learnable params
  h = jnp.tile(lstm_params[0]["h0"], (batches, 1))
  c = jnp.tile(lstm_params[0]["c0"], (batches, 1))
  T = xs.shape[1]
  # take xs and pass each xt through the same SINGULAR block. don't update the weight layer. there is only one layer.
  while steps < T:
    # iterate through all LSTM blocks, and get the output from the final one
    # send h and c from one block to the next?
    while lstm_block < len(lstm_params):
      hxconcat = jax.lax.concatenate([h, xs[:, steps]], dimension=1) #B, h ++ B, C => B, h+c
      # update gate
      update = jax.nn.sigmoid(hxconcat @ lstm_params[lstm_block]["wU"] + lstm_params[lstm_block]["bU"])
      candidate = jax.nn.tanh(hxconcat @ lstm_params[lstm_block]["wC"] + lstm_params[lstm_block]["bC"])
      c = c + update * candidate # (batch, c) => (batch, c)
      # forget gate
      forget1 = jax.nn.sigmoid(hxconcat @ lstm_params[lstm_block]["wF1"] + lstm_params[lstm_block]["bF1"])
      forget2 = jax.nn.tanh(hxconcat @ lstm_params[lstm_block]["wF2"] + lstm_params[lstm_block]["bF2"])
      forget = forget1 * forget2
      c = c + forget # (batch, c) => (batch, c)

      # output
      o = jax.nn.sigmoid(hxconcat @ lstm_params[lstm_block]["wO"] + lstm_params[lstm_block]["bO"])  # B, model_size
      h = jax.nn.tanh(c) * o # (B, model_size)

      lstm_block += 1
    
    lstm_block = len(lstm_params) - 1
    y = h @ lstm_params[lstm_block]['wY1'] + lstm_params[lstm_block]["bY1"]
    #y = y @ lstm_params[i]['wY2'] + lstm_params[i]["bY2"]

    logits_ts.append(y)
    steps += 1
  logits = jnp.transpose(jnp.array(logits_ts), axes=(1, 0, 2)) # T, B, C => B, T, C
  return logits


@jax.jit
def loss(lstm_params, xs, ys):
  batches = xs.shape[0] # B, T, C
  logits = lstm_forward(batches, lstm_params, xs)
  vocab_size = logits.shape[-1]
  ys_one_hot = jax.nn.one_hot(ys, vocab_size, axis=-1)
  logprobs = jax.nn.log_softmax(logits, axis=-1)
  crossentropylosses = -jnp.sum(ys_one_hot * logprobs, axis=-1)
  crossentropyloss = jnp.mean(crossentropylosses)
  return crossentropyloss


lr = 2e-2 # thanks karpathy
lr_decay = 0.97 # thanks karpathy
decay_after = 10 # thanks karpathy
optimizer = optax.inject_hyperparams(optax.adam)(learning_rate=lr)


# make optimizer a static arg in jit or it breaks
@jax.jit
def train(lstm_params, xs, ys, opt_state):
  step_loss, grads = jax.value_and_grad(loss)(lstm_params, xs, ys)
  param_updates, updated_opt_state = optimizer.update(grads, opt_state, lstm_params)
  updated_lstm_params = optax.apply_updates(lstm_params, param_updates)
  return updated_lstm_params, updated_opt_state, step_loss, grads


# set up lstm params
keys = random.split(random.PRNGKey(123), 20)
lstm_params = init_LSTM_params(keys[0], lstm_blocks, input_size, model_size, output_size)
opt_state = optimizer.init(lstm_params)

# train
# for now just overfit on small sample idk lol
dataset_tokens = jnp.array(dataset_tokens)
train_tokens = dataset_tokens[:int(len(dataset_tokens)*0.9)]
test_tokens = dataset_tokens[int(len(dataset_tokens)*0.9):]

train_batch_size = 50 # thanks karpathy
val_batch_size = 3 # nx3

epochs = 1000
for epoch in range(epochs):
  if epoch > decay_after:
    lr *= lr_decay
    opt_state.hyperparams['learning_rate'] = lr
  samples = (len(train_tokens) - 1) // sequence_length
  for i in range(0, len(train_tokens) - 1 - sequence_length*train_batch_size, sequence_length*train_batch_size):
    xtokens = train_tokens[i:i+sequence_length*train_batch_size]
    ytokens = train_tokens[i+1:i+sequence_length*train_batch_size+1]

    xtokens_batch = xtokens.reshape(-1, sequence_length)
    ytokens_batch = ytokens.reshape(-1, sequence_length) # B, T where T = sequence_length

    xembeds_batch = jax.nn.one_hot(xtokens_batch, len(vocab), axis=-1)

    # train example
    logits_batch = lstm_forward(xembeds_batch.shape[0], lstm_params, xembeds_batch)
    prediction_batch = jnp.argmax(logits_batch, axis=-1)

    # val batch
    
    j = i % ((len(test_tokens) - 1)//((val_batch_size)*sequence_length))
    idx = j*val_batch_size*sequence_length
    xtokens_val_batch = test_tokens[idx:idx+sequence_length*val_batch_size].reshape(-1, sequence_length) # batches of sequences lstm block count size
    ytokens_val_batch = test_tokens[idx+1:idx+sequence_length*val_batch_size+1].reshape(-1, sequence_length)
    xembeds_val_batch = jax.nn.one_hot(xtokens_val_batch, len(vocab), axis=-1)
    
    logits_val_batch = lstm_forward(xembeds_val_batch.shape[0], lstm_params, xembeds_val_batch)
    prediction_val_batch = jnp.argmax(logits_val_batch, axis=-1)
    ys_onehot = jax.nn.one_hot(ytokens_val_batch, len(vocab), axis=-1)
    logprobs = jax.nn.log_softmax(logits_val_batch, axis=-1)
    crossentropies = -jnp.sum(ys_onehot*logprobs,axis=-1)
    val_loss = jnp.mean(crossentropies) #lmao
    val_accuracy = jnp.mean(prediction_val_batch == ytokens_val_batch)

    lstm_params, opt_state, step_loss, grads = train(lstm_params, xembeds_batch, ytokens_batch, opt_state)
    x = decode(xtokens_batch[0]).replace('\n', ' ')
    y = decode(ytokens_batch[0]).replace('\n', ' ')
    yhat = decode(prediction_batch[0]).replace('\n', ' ')
    #print(epoch, epoch * samples + i, f"{step_loss:1.4f}", "pred:", x, "=>", y, "?=", yhat)
    print(f'INPUT | "{x}"')
    print(f'PRED  | "{yhat}"')
    print(f"step {(epoch, epoch * samples + i)} || loss: {step_loss:1.4f} || val_loss: {val_loss:1.4f} val_acc: {val_accuracy:1.4f} || LR = ", opt_state.hyperparams['learning_rate'] )
    print()

INPUT | "reply : @ trickylabyrinth A better model architecture not based on tokens 🛑 reply : @ lelouchdaily Its a muscle that can be trained though through practice, luckily 🛑 reply : @ 0xluffyb kick their ass 🛑 post : RT @ gizmobly : Hi it ' s # buildinpublic time Myself and distinguish"
PRED  | "pixqc pixqc Nomin Nominjectjectjectjectject H H 9 9 9 9 9 9 9 9 com com com com com com com comchch ev ev ev evbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb ev ev ev ev ev ev ev ev why why why why why why why why why why why why why why why why whyom whyomomomomomomomomom"
step (0, 0) || loss: 7.1067 || val_loss: 7.0515 val_acc: 0.0000 || LR =  0.02

INPUT | "a lot industrious people succeed more if you build 10x more stuff than the next guy youre way more likely to succeed pretty awesome story btw 🛑 reply : @ Brycicle77 & gt ; the act of telling someone your goals can paradoxically make you less likely to pursue them & gt ; I do wonder why both cases of my pin"
PRED  | "at idkri get get get get get get ge

KeyboardInterrupt: 

In [34]:
def inference_chars(key, chars):
  xtokens = encode(chars)
  xembed = jax.nn.one_hot(xtokens, len(vocab))[None, :] # artificial single batch
  logits = lstm_forward(xembed.shape[0], lstm_params, xembed)[0][-1] # logits of the first B and last T in the B T C. should be (C,)
  yhattokens = random.choice(key, a=logits.shape[0], p=jax.nn.softmax(logits)) # no need for axis=-1 since logits are (C,)
  sequence = yhattokens
  return sequence

def inference_tokens(key, xtokens):
  xembed = jax.nn.one_hot(xtokens, len(vocab))[None, :] # artificial single batch
  logits = lstm_forward(xembed.shape[0], lstm_params, xembed)[0][-1] # logits of the first B and last T in the B T C. should be (C,)
  yhattokens = random.choice(key, a=logits.shape[0], p=jax.nn.softmax(logits)) # no need for axis=-1 since logits are (C,)
  sequence = yhattokens
  return sequence

steps = 1000
keys = random.split(random.PRNGKey(1221133), steps)
temperature = 0.5
text = "\n"*50 + 'reply: '
tokens = encode(text)
print(text.replace('\n\n', ''), end='')
for i in range(steps):
  yseq = inference_tokens(keys[i], tokens[-sequence_length:])
  tokens.append(yseq)
  print(decode(tokens).replace('\n\n', ''))
print()

reply: reply : @
reply : @ se
reply : @ seb
reply : @ sebb
reply : @ sebby
reply : @ sebby_
reply : @ sebby_b
reply : @ sebby_builds
reply : @ sebby_builds W
reply : @ sebby_builds WA
reply : @ sebby_builds WAS
reply : @ sebby_builds WASI
reply : @ sebby_builds WASIG
reply : @ sebby_builds WASIGM
reply : @ sebby_builds WASIGMA
reply : @ sebby_builds WASIGMAN
reply : @ sebby_builds WASIGMAN 🛑
reply : @ sebby_builds WASIGMAN 🛑 reply
reply : @ sebby_builds WASIGMAN 🛑 reply :
reply : @ sebby_builds WASIGMAN 🛑 reply : @
reply : @ sebby_builds WASIGMAN 🛑 reply : @ le
reply : @ sebby_builds WASIGMAN 🛑 reply : @ lelo
reply : @ sebby_builds WASIGMAN 🛑 reply : @ lelou
reply : @ sebby_builds WASIGMAN 🛑 reply : @ lelouch
reply : @ sebby_builds WASIGMAN 🛑 reply : @ lelouchd
reply : @ sebby_builds WASIGMAN 🛑 reply : @ lelouchda
reply : @ sebby_builds WASIGMAN 🛑 reply : @ lelouchdaily
reply : @ sebby_builds WASIGMAN 🛑 reply : @ lelouchdaily N
reply : @ sebby_builds WASIGMAN 🛑 reply : @ lelouchdaily N

In [21]:
jax.tree_util.tree_map(jnp.linalg.norm, grads)

[{'bC': Array(0.00382806, dtype=float32),
  'bF1': Array(0.00067422, dtype=float32),
  'bF2': Array(0.00351771, dtype=float32),
  'bO': Array(0.00102159, dtype=float32),
  'bU': Array(0.00061966, dtype=float32),
  'bY1': Array(0., dtype=float32),
  'bY2': Array(0., dtype=float32),
  'c0': Array(0.00768698, dtype=float32),
  'h0': Array(0.00609007, dtype=float32),
  'wC': Array(0.00712224, dtype=float32),
  'wF1': Array(0.00125441, dtype=float32),
  'wF2': Array(0.00654482, dtype=float32),
  'wO': Array(0.0019007, dtype=float32),
  'wU': Array(0.0011529, dtype=float32),
  'wY1': Array(0., dtype=float32),
  'wY2': Array(0., dtype=float32)},
 {'bC': Array(0.00296637, dtype=float32),
  'bF1': Array(0.00067843, dtype=float32),
  'bF2': Array(0.00304542, dtype=float32),
  'bO': Array(0.06423178, dtype=float32),
  'bU': Array(0.00065218, dtype=float32),
  'bY1': Array(0.08539081, dtype=float32),
  'bY2': Array(0., dtype=float32),
  'wC': Array(0.0122483, dtype=float32),
  'wF1': Array(0.00278

In [34]:
getsize = lambda s: s.size
sizes = jax.tree_util.tree_map(getsize, grads)
total_params = 0
for layer in sizes:
  for _, v in layer.items():
    total_params += v

print(f"TOTAL_PARAMS: {total_params}")
print(f"DTYPE: {grads[0]['bC'].dtype}")
print(f"TOTAL_MEGABYTES: {total_params*4/1_000_000}")

TOTAL_PARAMS: 1183262
DTYPE: float32
TOTAL_MEGABYTES: 4.733048
