In [1]:
import jax
import optax
import jax.numpy as jnp
import jax.random as random
from functools import partial
import random as rand
import time

from tokenizers import CharBPETokenizer

gpu_device = jax.devices('gpu')[0]
cpu_device = jax.devices('cpu')[0]

In [2]:
debug_log = print
debug_log = lambda *x: 0

In [3]:
# params
sequence_length = 2
max_vocab_size = 1000

experiment:
get my posts
input: nothing
output: a post

experiment:
video generator
input: blank transparent canvas
output: the next frame of the video
do a several convolution skip connections on each step (adds changes to each frame)

In [4]:
# process tweet data
import json

# Load the tweets.js file
with open('./tweets.js', 'r', encoding='utf-8') as file:
    # Skip the JavaScript assignment and load only the JSON part
    content = file.read()
    json_data = content.split('=', 1)[1].strip()  # Extract the JSON part after `=`
    json_data = json_data.rstrip(';')  # Remove trailing semicolon if present
    tweets_data = json.loads(json_data)

# Initialize lists for replies and posts
replies = []
posts = []

# Process each tweet in the dataset
for tweet_obj in tweets_data:
    tweet = tweet_obj["tweet"]
    
    if "in_reply_to_status_id_str" in tweet and tweet["in_reply_to_status_id_str"]:
        # It's a reply
        replies.append({
            "id": tweet["id_str"],
            "text": tweet["full_text"],
            "in_reply_to": tweet["in_reply_to_status_id_str"],
            "user": tweet.get("in_reply_to_screen_name", None),
            "created_at": tweet["created_at"]
        })
    else:
        # It's a standalone post
        posts.append({
            "id": tweet["id_str"],
            "text": tweet["full_text"],
            "created_at": tweet["created_at"]
        })

# Output the results
print("Replies:")
for reply in replies[:4]:
    print(reply["text"])

print("\nPosts:")
for post in posts[:4]:
    print(post["text"])


Replies:
@VictorTaelin ive found sonnet 3.5v2 to be surprisingly good at coding. upgraded my tools from v1 to v2 and all of a sudden i have to reprompt it like 1/3rd the time
@dgant &lt;script src="gifeditor.mp3" type="application/json"&gt;
@bozo10n 💪
@sunsettler experimentation games are the best

Posts:
compression always feels so satisfying https://t.co/mM5acJxydL
it only takes one line of code to make a gifboard btw https://t.co/hFlPyNTvRm
RT @calbch: @kuberdenis entrepreneurship is the ultimative vehicle for personal development
just two idiots playing a game of chess https://t.co/USjWySv3W9


In [5]:
# set up and process data
post_limit = 100

stop_char = "🛑"
all_posts = [
  "reply: " + reply["text"] + stop_char for reply in replies[:post_limit]
]
all_posts.extend(
  [
    "post: " + post["text"] + stop_char for post in posts[:post_limit]
  ]
)
rand.shuffle(all_posts)
print('gathered initial posts')

token_type = "char"
if token_type == "char":
  tokenizer = CharBPETokenizer()

  text = "".join(all_posts)
  # Initialize and train the tokenizer
  
  tokenizer.train_from_iterator([text], vocab_size = max_vocab_size)

  # Tokenize the text
  vocab_size = len(tokenizer.get_vocab())
  print("vocab size:", vocab_size)

  encode = lambda characters: tokenizer.encode(characters).ids
  decode = lambda token_ids: "".join(tokenizer.decode(token_ids))

  translation = dict([(v, k.replace("</w>", " ")) for k, v in tokenizer.get_vocab().items()])
  decode_to_tokens = lambda token_ids: [translation[int(token_id)] for token_id in token_ids]


post_token_sets = [encode(post) for post in all_posts]
print('converted posts to tokens')

from tqdm import tqdm
sequence_token_sets = []
for post in tqdm(post_token_sets):
  for i in range(max(0, len(post) - sequence_length + 1)):
    sequence_token_sets.append(post[i:i+sequence_length])


print(f"loaded {len(sequence_token_sets)} sequences of len {sequence_length}.")


gathered initial posts



vocab size: 1000
converted posts to tokens


100%|██████████| 200/200 [00:00<00:00, 43683.84it/s]

loaded 11089 sequences of len 2.





In [9]:
# train functions
# init rnn params
# inputs ()
def init_rnn_params(key, input_shape, hidden_shape, conveyor_shape, output_shape):
  keys = random.split(key, 4*sequence_length)
  hx_concat_length = hidden_shape[0] + input_shape[0] # todo incorporate batching
  conveyor_shape = conveyor_shape[0]
  output_size = output_shape[0]
  rnn_params = [
    # xW means shape of W is (input, output)
    # hW and xW should output the shape of h.
    {
      "Whxc" : random.normal(keys[3*n + 0], shape=(hx_concat_length, output_size)) * jnp.sqrt(2 / hx_concat_length),
      "Bhxc" : jnp.zeros(shape=(output_size,)),
      "Whxi" : random.normal(keys[3*n + 1], shape=(hx_concat_length, output_size)) * jnp.sqrt(2 / hx_concat_length),
      "Bhxi" : jnp.zeros(shape=(output_size,)),
      "Whxcupd" : random.normal(keys[3*n + 2], shape=(hx_concat_length, output_size)) * jnp.sqrt(2 / hx_concat_length),
      "Bhxcupd" : jnp.zeros(shape=(output_size,)),
      "Whxo" : random.normal(keys[3*n + 3], shape=(hx_concat_length, output_size)) * jnp.sqrt(2 / hx_concat_length),
      "Bhxo" : jnp.zeros(shape=(output_size,)),
    }
    for n in range(sequence_length) # +1 for the empty sequence at the start # WRONG
  ]
  return rnn_params

@partial(jax.jit, device=gpu_device)
def step_start(rnn_params, x):
  # todo incorporate batching
  hx_concat_length, conveyor_shape = rnn_params[0]["Whxc"].shape
  # x is B, C (T == 1 so its not in here)
  batch_size = x.shape[0]
  input_shape = x.shape[1]
  h_shape = hx_concat_length - input_shape
  h = jnp.ones((batch_size, h_shape)) # batch size 1 for now
  c = jnp.ones((batch_size, conveyor_shape))
  y, c, h = step(rnn_params, c, h, x, n=0)
  debug_log('y: ', y.shape, f"should be {('Batch', 'C(logits=vocab_size)')}")
  return y, c, h


@partial(jax.jit, static_argnames=["n"], device=gpu_device) # n is static (i think calculated at jit comptime each time a new val is passed)
def step(rnn_params, c, h, x, n):
  # dim=1 so that for batching when its (n, h+x) it concats h+x
  hxconcat = jax.lax.concatenate([h, x], dimension=1) # (batch, h+x)

  debug_log('x: ', x.shape, f"should be {('Batch', 'C(one hot encodings)')}")

  # input gate
  # (B, h+x) @ (h+x,c) => (B, c) (B = batch dimension, right now its just 1)
  debug_log(n)
  wxc = rnn_params[n]["Whxc"]
  xc = hxconcat @ wxc
  b = rnn_params[n]["Bhxc"]
  f = jax.nn.sigmoid(xc + b)
  c = c * f # (B, c) * (B, c) => (B, c)

  # update gate
  # (B, h+x) @ (h+x, c) => (B, c)
  i = jax.nn.sigmoid(hxconcat @ rnn_params[n]["Whxi"] + rnn_params[n]["Bhxi"])
  # (B, h+x) @ (h+x, c) => (B, c)
  c_upd = jax.nn.tanh(hxconcat @ rnn_params[n]["Whxcupd"] + rnn_params[n]["Bhxcupd"])
  c = c + i*c_upd # (B, c) + (B, c) * (B, c) => (B, c)

  # output gate
  o = jax.nn.sigmoid(hxconcat @ rnn_params[n]["Whxo"] + rnn_params[n]["Bhxo"])
  h = o * jax.nn.tanh(c) # delete gate

  y = h
  return y, c, h

@partial(jax.jit, device=gpu_device)
def forward(rnn_params, xbow):
  debug_log("forward() called")
  ys = []
  yi, c, h = step_start(rnn_params, xbow[:, 0]) # try to predict the first token of each x in the batch
  ys.append(yi)
  debug_log("steps_total:", xbow.shape[1], list(range(1, xbow.shape[1])))
  for n in range(1, xbow.shape[1]):
    debug_log("step n:", n)
    y, c, h = step(rnn_params, c, h, xbow[:, n], n=n) # try to predict i+1th token for each batch
    ys.append(y)
  
  ys_out = jnp.transpose(jnp.array(ys), (1, 0, 2)) # turn it from (T, B, C) to (B, T, C)
  debug_log("ys", ys_out.shape, f"should be {('batches', 3, 'C(logits=vocab_size)')}")

  return ys_out


# learned embeddings
embedding_type = "learned"
if embedding_type == "learned":
  @jax.jit
  def embed_tokens(embedding_params, tokens):
    # ts[:, None] turns it from [t, t, t, t] to [[t], [t], [t], [t]]. as it should be, a row vector. transpose but for 1d vec.
    # x = one-hot(x) -> embed(x)
    oh = jax.nn.one_hot(tokens, vocab_size, axis=-1) # B, T, => B, T, 1 => T, vocab_size
    debug_log(tokens.shape, oh.shape, embedding_params["layer_1"]["w"].shape)
    x = oh @ embedding_params["layer_1"]["w"] + embedding_params["layer_1"]["b"] # (T, vocab_size) # (vocab_size, model_dim) => (T, model_dim)
    return x
  
  def init_embedding_params(key, vocab_size, model_dim):
    keys = random.split(key, 10)
    embedding_params = {
      "layer_1" : {
        "w" : random.normal(keys[0], shape=(vocab_size, model_dim)), # T, vocab_size => T, model_dim
        "b" : jnp.zeros((model_dim,)),
        }
    }
    return embedding_params
elif embedding_type == "one-hot":
  # one hot embeddings
  # EXPENSIVE when vocab size is high
  @partial(jax.jit, device=gpu_device)
  def embed_tokens(tokens):
    # create channel dim and one hot. (B, T) => (B, T, 1) => (B, T, C)
    return jax.nn.one_hot(tokens, vocab_size, axis=-1)


def embed_chars(chars):
  tokens = encode(chars)
  return embed_tokens(tokens)


@partial(jax.jit, device=gpu_device)
def get_loss(rnn_params, xtokens, ytokens):
  debug_log("xtokens: ", xtokens.shape, "should be (B, T)")
  xbow = embed_tokens(rnn_params[-1], xtokens)
  debug_log("embeddings: ", xbow.shape, "should be", "(B, T, C(model_size)) =", ("b", 2, "model_size"))
  logits = forward(rnn_params, xbow)
  debug_log("logits:", logits.shape, "should be", "(B, T, C) =", ("b", 2, "vocab_size"))
  vocab_size = logits.shape[-1] # channel dimension
  ytokens_one_hot = jax.nn.one_hot(ytokens, vocab_size, axis=-1)
  debug_log("ytokens_one_hot", ytokens_one_hot.shape, "should be", "(B, T, C(one hot)) =", ("b", 2, "vocab_size"))
  cross_entropies = -jnp.sum(ytokens_one_hot * jax.nn.log_softmax(logits, axis=-1), axis=-1) # axis=-1 is along C in a (B,T,C)
  debug_log("cross_entropies", cross_entropies.shape, "should be", "(B, T, 1) =", ("b", 3, 1))
  net_cross_entropy_loss = jnp.mean(cross_entropies)
  return net_cross_entropy_loss



In [11]:
# train_setup
keys = random.split(random.PRNGKey(198123), 10)

if embedding_type == "learned":
  model_dim = 1024*8 # C
elif embedding_type == "one-hot":
  model_dim = vocab_size
input_shape = (model_dim,) # channel size
output_shape = (vocab_size,) # logits
hidden_shape = output_shape # small for now
conveyor_shape = output_shape # these HAVE TO BE THE SAME



rnn_params = init_rnn_params(keys[1], input_shape, hidden_shape, conveyor_shape, output_shape)
if embedding_type == "learned":
  embedding_params = init_embedding_params(keys[0], vocab_size, model_dim)
  rnn_params.append(embedding_params)
jax.device_put(rnn_params, device=gpu_device)



lr = 2e-1
minibatch_size = 1
minibatch_lr_scaling = jnp.sqrt(minibatch_size)
lr *= minibatch_lr_scaling
scheduler = optax.schedules.linear_onecycle_schedule(
  transition_steps=10000,
  peak_value=lr,
  pct_start = 0.1,
  pct_final = 0.9,
  div_factor = 25,
  final_div_factor=100000,
)
optimizer = optax.chain(
  optax.scale_by_adam(),
  optax.scale_by_schedule(scheduler),
  optax.scale(-1), # params += -learning_rate x grads
)

# https://stackoverflow.com/a/53046624

#optimizer = optax.adam(learning_rate=lr)
opt_state = optimizer.init(rnn_params)


@jax.jit
def train_step(rnn_params, xtokens, ytokens, opt_state):
  loss, grads = jax.value_and_grad(get_loss)(rnn_params, xtokens, ytokens)
  param_updates, updated_opt_state = optimizer.update(grads, opt_state, rnn_params)
  updated_params = optax.apply_updates(rnn_params, param_updates)
  return loss, updated_params, updated_opt_state, grads

In [12]:
# train
train_steps = 10000
base_sample_size = 3 # num of post samples to train on
sequences = jnp.array(sequence_token_sets[:base_sample_size + minibatch_size + 1])
sample_size = base_sample_size + minibatch_size
print_every = 1
print(f'Training on {sequences.shape[0]} posts')
last_time = time.time()
#with jax.profiler.trace('./tmp/run'):
for train_step_num in range(train_steps - 1):
  # train step on batch
  train_sample_idx = train_step_num % base_sample_size
  xtokens_minibatch = sequences[train_sample_idx:train_sample_idx+minibatch_size]
  ytokens_minibatch = sequences[train_sample_idx+1:train_sample_idx+minibatch_size+1]
  debug_log("xtokens_minibatch: ", xtokens_minibatch.shape)
  loss, rnn_params, opt_state, grads = train_step(rnn_params, xtokens_minibatch, ytokens_minibatch, opt_state)
  
  if train_sample_idx % print_every == 0:
    batch_time = time.time() - last_time
    last_time = time.time()
    
    show_output = True
    if show_output:
      logits = forward(rnn_params, embed_tokens(rnn_params[-1], xtokens_minibatch)) # B, T, C
      yhattokens = jnp.argmax(logits, axis=-1) # B, T
      debug_log("logits:", logits.shape)
      yhat_minibatch = jnp.argmax(logits, axis=-1)
      debug_log("yhat_minibatch:", yhat_minibatch.shape)
      prediction = decode(xtokens_minibatch[-1]).replace('\n', '')
      prediction_out = decode(yhat_minibatch[-1]).replace('\n', '')
      correct = decode(ytokens_minibatch[-1]).replace('\n', '')
      print(f'minibatch {train_step_num:4.0f}  loss/seq={loss}  samples/s={minibatch_size/batch_time:0.2f}  prediction: "{prediction}" => "{prediction_out}" || "{correct}"')
    else:
      print(f'minibatch {train_step_num:4.0f}  loss/seq={loss/xtokens_minibatch.shape[0]:3.4f}  tsteps/s={xtokens_minibatch.shape[0]/batch_time:0.2f}')


Training on 5 posts
hxconcat (1, 9192)
hxconcat (1, 9192)
minibatch    0  loss/seq=7.154355049133301  samples/s=0.37  prediction: "post :" => "? ti" || ": Im"
minibatch    1  loss/seq=6.786767959594727  samples/s=0.58  prediction: ": Im" => "UP Im" || "Im so"
minibatch    2  loss/seq=7.171241760253906  samples/s=26.54  prediction: "Im so" => "* so" || "so so"
minibatch    3  loss/seq=6.8018388748168945  samples/s=26.33  prediction: "post :" => "? ti" || ": Im"
minibatch    4  loss/seq=6.071624755859375  samples/s=26.67  prediction: ": Im" => "CCIm" || "Im so"
minibatch    5  loss/seq=6.530707359313965  samples/s=26.82  prediction: "Im so" => "? so" || "so so"
minibatch    6  loss/seq=6.801838397979736  samples/s=27.77  prediction: "post :" => "? ti" || ": Im"
minibatch    7  loss/seq=6.057343482971191  samples/s=26.48  prediction: ": Im" => "very Im" || "Im so"
minibatch    8  loss/seq=6.521933078765869  samples/s=26.90  prediction: "Im so" => "? so" || "so so"
minibatch    9  loss/seq

KeyboardInterrupt: 

In [57]:
def inference(xchars):
  xtokens = jnp.array(encode(xchars))
  debug_log(xtokens)
  xbow = embed_tokens(rnn_params[-1], xtokens)[None, :] # from T,C to B,T,C where B=1
  logits = forward(rnn_params, xbow) # 1, T, C
  yhatbow = jnp.argmax(logits, axis=-1)[0] #1, T => T
  yhatbow_chunks = decode_to_tokens(yhatbow) # T
  return yhatbow_chunks # T

text = 'post : '
print(text, end='')
for i in range(100):
  current_input = decode(encode(text)[-sequence_length:]) # final $seq_length chars
  debug_log("current input", len(encode(current_input)), f"|{current_input}|")
  next_chunk = inference(current_input)[-1]
  text += next_chunk
  print(next_chunk, end='')

post : 5AM 5555555555555555555555555555555555555555555555555555555555555555555555555555555555555555555555555

In [29]:
# print update norms
from pprint import pprint
norms = jax.tree_util.tree_map(lambda x: jnp.linalg.norm(x), grads)
pprint(norms)

[{'Bhxc': Array(0., dtype=float32),
  'Bhxcupd': Array(2.323032e-07, dtype=float32),
  'Bhxi': Array(3.941133e-07, dtype=float32),
  'Bhxo': Array(0.00017044, dtype=float32),
  'Whxc': Array(0., dtype=float32),
  'Whxcupd': Array(5.434752e-06, dtype=float32),
  'Whxi': Array(9.7011325e-06, dtype=float32),
  'Whxo': Array(0.00866771, dtype=float32)},
 {'Bhxc': Array(0., dtype=float32),
  'Bhxcupd': Array(0., dtype=float32),
  'Bhxi': Array(0., dtype=float32),
  'Bhxo': Array(0.00013062, dtype=float32),
  'Whxc': Array(0., dtype=float32),
  'Whxcupd': Array(0., dtype=float32),
  'Whxi': Array(0., dtype=float32),
  'Whxo': Array(0.00522402, dtype=float32)},
 {'layer_1': {'b': Array(0.00042082, dtype=float32),
              'w': Array(0.00042985, dtype=float32)}}]


In [52]:
# print param count
from pprint import pprint
norms = jax.tree_util.tree_map(lambda x: x.size, grads)
pprint(norms)

[{'Bhxc': 1000,
  'Bhxcupd': 1000,
  'Bhxi': 1000,
  'Bhxo': 1000,
  'Whxc': 2024000,
  'Whxcupd': 2024000,
  'Whxi': 2024000,
  'Whxo': 2024000},
 {'Bhxc': 1000,
  'Bhxcupd': 1000,
  'Bhxi': 1000,
  'Bhxo': 1000,
  'Whxc': 2024000,
  'Whxcupd': 2024000,
  'Whxi': 2024000,
  'Whxo': 2024000},
 {'layer_1': {'b': 1024, 'w': 1024000}}]


In [7]:
import jax
import jax.numpy as jnp
with jax.profiler.trace('./tmp/trace'):
  for i in range(1000):
    x = jax.nn.log_softmax(jnp.arange(100000000))
  y = x + 100


2024-12-02 16:14:12.800525: E external/xla/xla/python/profiler/internal/python_hooks.cc:400] Can't import tensorflow.python.profiler.trace
  out = np.array(c).astype(eqn.params['new_dtype'])
2024-12-02 16:14:22.047804: E external/xla/xla/python/profiler/internal/python_hooks.cc:400] Can't import tensorflow.python.profiler.trace


In [14]:

debug_log = lambda *x: print("[DEBUG]:", *x)

In [25]:
debug_log = lambda *x: 0

In [15]:
tokenizer.get_vocab()

{'wn</w>': 413,
 'TI': 967,
 'ter': 443,
 'd': 62,
 'Sh': 645,
 'reme': 837,
 'imo</w>': 892,
 'dys': 918,
 '\U0001fae1</w>': 163,
 'fee': 880,
 'that</w>': 309,
 'ol': 304,
 'ch': 317,
 'ive</w>': 440,
 'min': 395,
 'bi': 649,
 'ec': 356,
 'ku': 594,
 'has</w>': 430,
 'into</w>': 692,
 '`</w>': 141,
 'poten': 845,
 'potential</w>': 936,
 'OO</w>': 961,
 'sett': 617,
 'me': 259,
 'getting</w>': 716,
 'No': 779,
 'resources</w>': 934,
 'cad': 985,
 '🛑</w>': 127,
 'all': 699,
 'bt': 388,
 'pt</w>': 466,
 'ck</w>': 791,
 'ful</w>': 343,
 'dn': 986,
 'sel': 705,
 'loo': 618,
 'NI': 778,
 'fa': 804,
 'S</w>': 137,
 'buil': 429,
 'ges</w>': 661,
 'you': 267,
 'ML</w>': 434,
 'g': 65,
 'ssi': 315,
 'A': 31,
 'LETS</w>': 751,
 'der': 798,
 'now</w>': 353,
 'sually</w>': 862,
 'tter</w>': 416,
 '/': 15,
 'ho': 292,
 'lo': 229,
 'his</w>': 541,
 'times</w>': 851,
 'Pe': 963,
 'zo': 832,
 'ME</w>': 777,
 'I</w>': 152,
 'ni': 361,
 'glo': 999,
 'bo': 324,
 'thats</w>': 916,
 'mon</w>': 598,
 'gr':

In [40]:
tokenizer.decode([377])

'ther'