In [6]:
# IMPORTANT: there is a colab where you can easily run this
# see FINISHED/LSTM/LSTM.md for the link
# or copy and paste this

In [43]:
#imports
import jax
import jax.numpy as jnp
import jax.random as random
import optax
import functools
import time

# colab
# pip install jax[gpu] optax numpy

gpu_device = jax.device_get('gpu')[0]
cpu_device = jax.device_get('cpu')[0]


In [45]:
## settings ##
# remove rare characters (that appear under 0.01% of the time)
# this help the model train better
vocab_frequency_threshold = 0.0001

# the path to your tweets. you only need to upload tweets.js. dataset.txt is created automatically
twitter_js_path = 'tweets.js' # the file with your posts
dataset_path = 'dataset.txt'

## if you want to train the LSTM on a .txt file instead of your tweets:
# 1) set this to True
use_text_file_instead_of_tweets = False 
# 2) upload the text file and name it 'dataset.txt'
#       - or, change the dataset_path to the name of your .txt file

lstm_layers = 2
model_size = 512
train_epochs = 50

# there are more in the 'set model/training params' block if you want to get more advanced

In [9]:
# clean/load/tokenize dataset


# code to load tweets.js, written by chatgpt
# because why tf would i write this
import json

# Load the tweets.js file
with open(twitter_js_path, 'r', encoding='utf-8') as file:
    # Skip the JavaScript assignment and load only the JSON part
    content = file.read()
    json_data = content.split('=', 1)[1].strip()  # Extract the JSON part after `=`
    json_data = json_data.rstrip(';')  # Remove trailing semicolon if present
    tweets_data = json.loads(json_data)

# Initialize lists for replies and posts
replies = []
posts = []

# Process each tweet in the dataset
for tweet_obj in tweets_data:
    tweet = tweet_obj["tweet"]
    
    if "in_reply_to_status_id_str" in tweet and tweet["in_reply_to_status_id_str"]:
        # It's a reply
        replies.append({
            "id": tweet["id_str"],
            "text": tweet["full_text"],
            "in_reply_to": tweet["in_reply_to_status_id_str"],
            "user": tweet.get("in_reply_to_screen_name", None),
            "created_at": tweet["created_at"]
        })
    else:
        # It's a standalone post
        posts.append({
            "id": tweet["id_str"],
            "text": tweet["full_text"],
            "created_at": tweet["created_at"]
        })

# Output the results
print("Replies:")
for reply in replies[:4]:
    print(reply["text"])

print("\nPosts:")
for post in posts[:4]:
    print(post["text"])



# code to clean the data a bit more, adding stop tokens and prefixes
import numpy.random as rand

stop_character = "🛑"
filename = dataset_path

all_posts = [
  "reply: " + reply["text"] + stop_character for reply in replies
]
all_posts.extend(
  [
    "post: " + post["text"] + stop_character for post in posts
  ]
)
rand.shuffle(all_posts)

print("posts:", len(all_posts))
dataset = "\n\n\n\n\n\n".join(all_posts)
print('chars:', len(dataset))


with open(filename, 'w') as file:
  file.write(dataset)



# code to load/tokenize dataset
with open(dataset_path, 'r') as file:
  dataset = file.read()

# remove chars w low frequency
removed_chars = []
frequencies = []
dataset_length = len(dataset)
for c in set(dataset):
  frequencies.append((dataset.count(c), c, c.isalnum()))
  if dataset.count(c)/dataset_length < vocab_frequency_threshold:
    removed_chars.append(c)
    dataset = dataset.replace(c, '')


# tokenize
vocab = sorted(list(set(dataset)))
print("vocab length:", len(vocab))

token_to_char = dict(enumerate(vocab))
char_to_token = dict([(v, k) for k, v in token_to_char.items()])
decode = lambda tokens: "".join([token_to_char[int(token)] for token in tokens])
encode = lambda chars: jnp.array([char_to_token[c] for c in chars])

dataset_tokens = encode(dataset)
split_ratio = 0.9
train_tokens = dataset_tokens[:int(len(dataset_tokens)*split_ratio)]
test_tokens = dataset_tokens[int(len(dataset_tokens)*split_ratio):]
del dataset
del dataset_tokens


print("removed:", "".join(removed_chars))
print("dog", encode("dog"), decode(encode("dog")))



Replies:
@VictorTaelin ive found sonnet 3.5v2 to be surprisingly good at coding. upgraded my tools from v1 to v2 and all of a sudden i have to reprompt it like 1/3rd the time
@dgant &lt;script src="gifeditor.mp3" type="application/json"&gt;
@bozo10n 💪
@sunsettler experimentation games are the best

Posts:
compression always feels so satisfying https://t.co/mM5acJxydL
it only takes one line of code to make a gifboard btw https://t.co/hFlPyNTvRm
RT @calbch: @kuberdenis entrepreneurship is the ultimative vehicle for personal development
just two idiots playing a game of chess https://t.co/USjWySv3W9
posts: 3916
chars: 529781
vocab length: 84
removed: 🚀𝗻𝗵[😎🌑”📈📉‍😢吧ɪ~}👀😭ᴏ$“ʜ𝗰😆^走`😉💪ᴡᴛ我𝗱🍰🤔ᴇ𝗪😁👌🎉#🤦ᴀ𝘀😤🤣ᴘ🤷♂☠🧠𝗯ɴ𝗼𝗿𝘂𝘁]ʟ🫡𝗲𝗶ʀ|️👍{ᴄ*們’🤯
dog [59 70 62] dog


In [10]:
# lstm network & other functions
def init_LSTM_params(key, lstm_layers, input_size, model_size, output_size):
  param_sets = 8 # manual, idc
  keys = random.split(key, param_sets*lstm_layers + 2)
  hxconcat_size = model_size + model_size
  he = lambda rkey, shape: random.normal(rkey, shape=shape) * jnp.sqrt(2 / shape[0])
  # supposedly xavier is better for networks using tanh
  xavier = lambda rkey, shape: random.normal(rkey, shape=shape) * jnp.sqrt(2 / (shape[0] + shape[1]))
  params = [
    {
      "wU" : xavier(keys[param_sets*i + 0], (hxconcat_size, model_size)),
      "bU" : jnp.zeros((model_size,)),
      "wC" : xavier(keys[param_sets*i + 6], (hxconcat_size, model_size)),
      "bC" : jnp.zeros((model_size,)),
      "wF": xavier(keys[param_sets*i + 1], (hxconcat_size, model_size)),
      "bF": jnp.zeros((model_size,)),
      "wO" : xavier(keys[param_sets*i + 3], (hxconcat_size, model_size)),
      "bO" : jnp.zeros((model_size,)),
      "h0" : jnp.zeros((model_size,)),
      "c0" : jnp.zeros((model_size,)),
      #"h0" : random.normal(keys[param_sets*i + 4], shape=(model_size)) * jnp.sqrt(2 / model_size),
      #"c0" : random.normal(keys[param_sets*i + 5], shape=(model_size)) * jnp.sqrt(2 / model_size),
    }
    for i in range(lstm_layers)
  ]
  params[0].update(
    {
    # then embedding table weight and bias
    "wEM" : xavier(keys[param_sets*(param_sets - 1) + 2], (input_size, model_size)),
    "bEM" : jnp.zeros((model_size,)),

  })
  params[-1].update(
    {
      # this is for the y layer, which i am probably imlementing wrong.
      "wY1" : xavier(keys[param_sets*(lstm_layers-1) + 4], (model_size, model_size)),
      "bY1" : jnp.zeros((model_size,)),
      "wY2" : xavier(keys[param_sets*(lstm_layers-1) + 5], (model_size, output_size)),
      "bY2" : jnp.zeros((output_size,)),
    }
  )
  return params


@functools.partial(jax.jit, static_argnames=[])
def dropout(dropout_key, original_tensor, dropout_rate):
  # generate random of same shape
  dropout_probs = random.uniform(dropout_key, shape=original_tensor.shape)
  # mask = random < dropout_rate
  mask = (dropout_probs > dropout_rate) / (1 - dropout_rate) # scale to keep avg the same
  return original_tensor * mask


@functools.partial(jax.jit, static_argnames=[]) # static dropout rate?
def lstm_step(step_dropout_key, lstm_layer_params, layer_h, layer_c, current_xt, dropout_rate):
  hxconcat = jax.lax.concatenate([layer_h, current_xt], dimension=1) #B, h ++ B, C => B, h+c
  # update gate
  forget_gate = jax.nn.sigmoid(hxconcat @ lstm_layer_params["wF"] + lstm_layer_params["bF"])
  #update = dropout(step_dropout_keys[0], update, dropout_rate)

  # forget
  layer_c = layer_c * forget_gate

  input_node = jax.nn.tanh(hxconcat @ lstm_layer_params["wC"] + lstm_layer_params["bC"])
  #candidate = dropout(step_dropout_keys[1], candidate, dropout_rate)
  update = jax.nn.sigmoid(
              hxconcat @ lstm_layer_params["wU"] + lstm_layer_params["bU"]
            )
  input_gate =  update * input_node

  # update
  layer_c = layer_c + input_gate

  # output
  layer_h = jax.nn.tanh(layer_c) * jax.nn.sigmoid(hxconcat @ lstm_layer_params["wO"] + lstm_layer_params["bO"]) # (B, model_size)

  next_layer_xt = dropout(step_dropout_key, layer_h, dropout_rate) # the next layer's input x is the current layer's hidden state
  # karpathy: dropout after EACH LAYER not several times in the block. lol.

  # i may also need to do dropout horizontally (i.e. dropout the hidden state memory each block)

  return (layer_h, layer_c), next_layer_xt


# LSTM forward
import functools
@functools.partial(jax.jit, static_argnames=[])
def lstm_forward(dropout_key, lstm_params, xembeds_batch, dropout_rate):
  batches = xembeds_batch.shape[0]
  lstm_layers = len(lstm_params)
  model_size = lstm_params[0]["h0"].size
  # initialize h and c as random/learnable params
  #h = jnp.tile(lstm_params[0]["h0"], (batches, lstm_layers, 1)) # B, lstm_layer, h_size
  #c = jnp.tile(lstm_params[0]["c0"], (batches, lstm_layers, 1)) # B, lstm_layer, c_size
  # wait.. these are the same for all of the layers.. maybe they shouldn't be
  T = xembeds_batch.shape[1]
  # take xembeds_batch and pass each xt through the same SINGULAR block. don't update the weight layer. there is only one layer.
  dropout_keys = random.split(dropout_key, lstm_layers)

  # for each layer:
    # scan over xt
    # carry : h, c
    # a: xt
    # b: h,c
    # f = lambda ((h, c), xt) : lstm_step(h, c, xt, everything else) => h, c
    # scans over xt
    # for next layer: xt = h of previous layer. h = h0 and c = c0
  
  current_embeddings_batch = jnp.transpose(xembeds_batch, (1, 0, 2)) # B, T, C => T, B, C
    # The reason for this is that jax.lax.scan only uses the leading dim. why? idk. its dumb, it needs an axis arg so i can scan over whatever

  for lstm_layer in range(lstm_layers):
    h = jnp.tile(lstm_params[lstm_layer]["h0"], (batches, 1))
    c = jnp.tile(lstm_params[lstm_layer]["c0"], (batches, 1))
    # zeroes makes the backprop faster
    #h = jnp.zeros((batches, model_size))
    #c = jnp.zeros((batches, model_size))
    layer_dropout_key = dropout_keys[lstm_layer] # it doesnt matter if this is the same across all layers
    # scan should be inexpensive since layer size is small while t size is usually LARGE
    # scan :: (c -> a -> (c, b)) -> c -> [a] -> (c, [b])
    # scan :: scanfunc -> h_and_c -> xs -> (h_and_c_final, hs_to_be_used_as_input_xt_in_next_layer)
    # scanfunc :: (c -> a -> (c, b))
    scanfunc = lambda hc, xt : lstm_step(layer_dropout_key, lstm_params[lstm_layer], hc[0], hc[1], xt, dropout_rate)
      # for xs: scan along the t dimension! it scans along B by default
      # to fix this, we transpose xs with jnp.transpose(current_embeddings_batch, (1, 0, 2))
    current_embeddings_batch = jax.lax.scan(scanfunc, (h, c), current_embeddings_batch)[1] # (c, [b]) => [b] ==> B, T, C
  

  # finally turn current_embeddings_batch into ys (logits)
  hs = jnp.transpose(current_embeddings_batch, (1, 0, 2)) # T, B, C => B, T, C
  ys = jax.nn.relu(hs @ lstm_params[-1]['wY1'] + lstm_params[-1]["bY1"]) # B, T, model_size => B, T, vocab_size
  ys = ys @ lstm_params[-1]['wY2'] + lstm_params[-1]["bY2"]
  return ys


@functools.partial(jax.jit, static_argnames=[])
def loss_func(dropout_key, lstm_params, xtokens_batch, ytokens_batch, dropout_rate):
  xembeds_batch = embed(lstm_params, xtokens_batch)
  logits = lstm_forward(dropout_key, lstm_params, xembeds_batch, dropout_rate)
  vocab_size = logits.shape[-1]
  ys_one_hot = jax.nn.one_hot(ytokens_batch, vocab_size, axis=-1)
  logprobs = jax.nn.log_softmax(logits, axis=-1)
  crossentropylosses = -jnp.sum(ys_one_hot * logprobs, axis=-1)
  crossentropyloss = jnp.mean(crossentropylosses)
  return crossentropyloss


@functools.partial(jax.jit, static_argnames=[])
def loss_and_value(dropout_key, lstm_params, xtokens_batch, ytokens_batch, dropout_rate):
  xembeds_batch = embed(lstm_params, xtokens_batch)
  logits = lstm_forward(dropout_key, lstm_params, xembeds_batch, dropout_rate)
  vocab_size = logits.shape[-1]
  ys_one_hot = jax.nn.one_hot(ytokens_batch, vocab_size, axis=-1)
  logprobs = jax.nn.log_softmax(logits, axis=-1)
  predictions = jnp.argmax(logprobs, axis=-1)
  crossentropylosses = -jnp.sum(ys_one_hot * logprobs, axis=-1)
  crossentropyloss = jnp.mean(crossentropylosses)
  return crossentropyloss, predictions


jitted_backwards_loss = jax.jit(jax.value_and_grad(loss_func, argnums=1), static_argnames=[])


@functools.partial(jax.jit, static_argnames=['vocab_size'])
def embed(lstm_params, xtokens, vocab_size=len(vocab)):
  xs_one_hot = jax.nn.one_hot(xtokens, vocab_size, axis=-1) #B, T, vocab_size
  activations = xs_one_hot @ lstm_params[0]["wEM"] + lstm_params[0]["bEM"]
  return activations

# make optimizer a static arg in jit or it breaks
@functools.partial(jax.jit, static_argnames=["optimizer"])
def train(dropout_key, lstm_params, xtokens_batch, ytokens_batch, opt_state, dropout_rate, optimizer):
  step_loss, grads = jitted_backwards_loss(dropout_key, lstm_params, xtokens_batch, ytokens_batch, dropout_rate)
  param_updates, updated_opt_state = optimizer.update(grads, opt_state, lstm_params)
  updated_lstm_params = optax.apply_updates(lstm_params, param_updates) 
  return updated_lstm_params, updated_opt_state, step_loss, grads




In [46]:
# set model and training parameters

# train params
# these initialy parameters seem to train well
lr = 0.0007
sequence_length = 100
dropout_rate = 0.25

decay_lr = False
decay = 0.98
decay_epochs = 3

epochs = train_epochs
print_every = 100_000 # print progress every n steps. set extremely high to only print every epoch
train_batch_size = 50
val_batch_size = 50

# needs to be 'true' initially
# set to False to continue training on your existing model
start_new_training_run = True 

In [47]:
# train the model


# init some variables
input_size = len(vocab) # just do one-hot for now
hidden_size = model_size
output_size = len(vocab) # logits => one-hot => tokens
keys = random.split(random.PRNGKey(123), 20)
losses = []
val_losses = []
val_accuracies = []
start = time.time()


# init optimizer and network state
if start_new_training_run:
  optimizer = optax.inject_hyperparams(optax.adam)(learning_rate=lr)
  lstm_params = init_LSTM_params(keys[0], lstm_layers, input_size, model_size, output_size)
  opt_state = optimizer.init(lstm_params)
else:
  opt_state.hyperparams['learning_rate'] = lr


# train
for epoch in range(epochs):
    if decay_lr and epoch != 0 and epoch % decay_epochs == 0:
       opt_state.hyperparams['learning_rate'] = opt_state.hyperparams['learning_rate'] * decay

    # train
    steps = (len(train_tokens) // ((sequence_length+1)*train_batch_size)) - 2
    for step in range(steps): # probably wrong but w/e
      # B, T where T = sequence_length
      train_data_idx = step*sequence_length*train_batch_size
      next_train_data_idx = (step+1)*sequence_length*train_batch_size
      xtokens_batch = train_tokens[train_data_idx:next_train_data_idx].reshape(-1, sequence_length) #(B, T)
      ytokens_batch = train_tokens[train_data_idx+1:next_train_data_idx+1].reshape(-1, sequence_length) # (B,)

      dropout_key = random.PRNGKey(epoch*steps + step) # unique for every step
      lstm_params, opt_state, step_loss, grads = train(dropout_key, lstm_params, xtokens_batch, ytokens_batch, opt_state, dropout_rate, optimizer)

      losses.append(step_loss)

      # val
      j = step % ((len(test_tokens) - 1)//((val_batch_size)*sequence_length))
      val_idx = j*val_batch_size*sequence_length
      next_val_idx = (j+1)*val_batch_size*sequence_length
      xtokens_val_batch = test_tokens[val_idx:next_val_idx].reshape(-1, sequence_length) # batches of sequences lstm block count size
      ytokens_val_batch = test_tokens[val_idx+1:next_val_idx+1].reshape(-1, sequence_length)
      
      val_loss, prediction_val_batch = loss_and_value(dropout_key, lstm_params, xtokens_val_batch, ytokens_val_batch, dropout_rate=0)
      val_accuracy = jnp.mean(prediction_val_batch == ytokens_val_batch)

      val_losses.append(val_loss)
      val_accuracies.append(val_accuracy)

      if (step == steps - 1):
        end = time.time()
        duration = end - start
        # train inference example (no dropout)
        xembeds_batch = embed(lstm_params, xtokens_batch[0][None, :]) # 1-batch - (1, T, C)
        last_logit_batch = lstm_forward(dropout_key, lstm_params, xembeds_batch, 0) # B, C
        prediction_batch = jnp.argmax(last_logit_batch, axis=-1) # B

        # print train status
        x = decode(xtokens_batch[0]).replace('\n', ' ')
        y = decode(ytokens_batch[0]).replace('\n', ' ')
        yhat = decode(prediction_batch[0]).replace('\n', ' ')
        #print(f'INPUT  ({len(x)}) | "{x}"')
        avg_loss = sum(losses)/len(losses)
        avg_val_loss = sum(val_losses)/len(val_losses)
        avg_val_acc = sum(val_accuracies)/len(val_accuracies)
        lines = [
          f'TARGET | "{y}"',
          f'PRED   | "{yhat}"',
          f"e:{epoch+1}/{epochs} s:{step+1}/{steps} || samples/sec: {train_batch_size*steps/(duration):0.0f} || "
          f"loss: {step_loss:1.4f} || val_loss: {avg_val_loss:1.4f} val_acc: {avg_val_acc:1.4f} || " 
          f"LR = {opt_state.hyperparams['learning_rate']:0.6f}",
        ]
        print("\n".join(lines))
        start = time.time()

    losses = []
    val_losses = []

TARGET | "nd put a texture on it  ill remember this if i ever do a larger opengl project, im at the point wher"
PRED   | "t  tooetooo      to to   te to e ee  toen to toooe etoetoooe    toe    toee  te toetoetoe toen  toe "
e:1/50 s:92/92 || samples/sec: 242 || loss: 2.7491 || val_loss: 3.1553 val_acc: 0.2019 || LR = 0.000700
TARGET | "nd put a texture on it  ill remember this if i ever do a larger opengl project, im at the point wher"
PRED   | "nd tos tnthr  r  tf tn  htlyteaerei  then tn tntreretomt tete   tf r  ytoont t  tn tn the trrng toe "
e:2/50 s:92/92 || samples/sec: 264 || loss: 2.2967 || val_loss: 2.4468 val_acc: 0.2734 || LR = 0.000700
TARGET | "nd put a texture on it  ill remember this if i ever do a larger opengl project, im at the point wher"
PRED   | "ld trs t shr  re tf tt  hsl teaarel  then in tttnerotowt tiye   tf r eetrowett  in t  the trrng the "
e:3/50 s:92/92 || samples/sec: 248 || loss: 2.0919 || val_loss: 2.1427 val_acc: 0.3197 || LR = 0.000700
TARGET | "nd put

In [51]:
# inference settings
# the first time you run the model, inference will be slow

temperature = 1   # from 0 to 2. 1 is normal.

reply_prompt = "reply: "
post_prompt = "post: "

prompt = post_prompt

# how many characters until the line wraps around
formatting_line_length = 50


In [53]:
# run the model!

def inference(key, chars, temperature):
  xtokens = encode(chars)[None, :]
  xembed = embed(lstm_params, xtokens) # artificial single batch
  logits = lstm_forward(key, lstm_params, xembed, 0)[0][-1] # logits of the first B and last T in the B T C. should be (C,)
  probs = jax.nn.softmax(logits/(temperature + 0.001))
  yhattokens = random.choice(key, a=logits.shape[0], p=probs) # no need for axis=-1 since logits are (C,)
  return yhattokens


steps = 1000
import time
seed = int(1000*time.time())
keys = random.split(random.PRNGKey(seed), steps)
text =  " "*50 + prompt
print(text.replace('  ', ''), end='')
for i in range(steps):
  next_token = inference(keys[i], text[-sequence_length:], temperature)
  next_char = decode([next_token])[-1]
  if next_char == '🛑':
    print(next_char, end='')
    break
  text += next_char
  if (len(text) - 50) % formatting_line_length == 0:
    print()
  print(next_char, end='')

post: 10, so far)

yeah i exploding download the 
context and look it takes data w a contack and sto
ring how it works suuuuper well

searching for cla
ssic psychological architecture:
https://t.co/wdQC
5NrLED🛑