In [6]:
import jax
import optax
import jax.numpy as jnp
import jax.random as random

In [66]:
# set up and process data
sequence_length = 4 #lmao -> maol => l
raw_train_data = "lmao" * 1000 # lmao

vocab = sorted(list(set(raw_train_data)))
vocab_size = len(vocab)


translations = {
  "encode" : dict([(c, t) for c, t in zip(vocab, range(len(vocab)))]),
  "decode" : dict([(t, c) for c, t in zip(vocab, range(len(vocab)))])
}

token = lambda c: translations["encode"][c] # char to token
char = lambda t: translations["decode"][int(t)] # token to char
encode = lambda cs: jnp.array([token(c) for c in cs])
decode = lambda ts: "".join([char(t) for t in ts])

train_data_tokens = encode(raw_train_data)
x_train_tokens = [jnp.array(train_data_tokens[i:i+sequence_length]) for i in range(len(train_data_tokens))] # 0,1,2,3
y_train_tokens = [jnp.array(train_data_tokens[i+1:i+sequence_length+1]) for i in range(len(train_data_tokens) - 1)] # 1,2,3,4
print(decode(x_train_tokens[0]), decode(y_train_tokens[0]))


lmao maol


experiment:
get my posts
input: nothing
output: a post

experiment:
video generator
input: blank transparent canvas
output: the next frame of the video
do a several convolution skip connections on each step (adds changes to each frame)

In [76]:
# init rnn params
# inputs ()
def init_rnn_params(key, input_shape, hidden_shape, output_shape):
  keys = random.split(key, 20)
  rnn_params = {
    # xW means shape of W is (input, output)
    # hW and xW should output the shape of h.
    "Whh" : random.normal(keys[0], shape=(hidden_shape[0], hidden_shape[0])) * jnp.sqrt(2 / hidden_shape[0]),
    "Whx" : random.normal(keys[1], shape=(input_shape[0], hidden_shape[0])) * jnp.sqrt(2 / input_shape[0]),
    "Why" : random.normal(keys[2], shape=(hidden_shape[0], output_shape[0])) * jnp.sqrt(2 / hidden_shape[0])
  }
  return rnn_params


def step_start(rnn_params, x):
  hidden_shape = rnn_params["Whh"].shape[0]
  h = jnp.zeros(hidden_shape)
  y, h = step(rnn_params, h, x)
  return y, h

def step(rnn_params, h, x):
  z = h @ rnn_params["Whh"] + x @ rnn_params["Whx"]
  h = jax.nn.tanh(z)

  y = h @ rnn_params["Why"]
  return y, h

# choose to get output if you want. update() always returns it but it may not be used.
#
# update(h, x0) ----- update(h, x0) ----- update(h) ----- update(h, x2)
#   |                        |                                   |
#   |                        |                                   |
# step(x0)               step(x1)           step()           step(x2)


def forward(rnn_params, xbow):
  # for now, recreate karpathy's example
  # lma -> mao
  y0, h = step_start(rnn_params, xbow[0])
  y1, h = step(rnn_params, h, xbow[1])
  y2, h = step(rnn_params, h, xbow[2])
  y3, h = step(rnn_params, h, xbow[3])
  return jnp.array([y0, y1, y2, y3])


def init_embedding_params(key, model_dim):
  keys = random.split(key, 10)
  embedding_params = {
    "layer_1" : {
      "w" : random.normal(keys[0], shape=(1, model_dim)),
      "b" : jnp.zeros((model_dim,)),
      }
  }
  return embedding_params

def embed_tokens(embedding_params, tokens):
  # ts[:, None] turns it from [t, t, t, t] to [[t], [t], [t], [t]]. as it should be, a row vector. transpose but for 1d vec.
  x = tokens[:, None] @ embedding_params["layer_1"]["w"] + embedding_params["layer_1"]["b"]
  x = jax.nn.relu(x)
  return x

def embed_chars(embedding_params, chars):
  tokens = encode(chars)
  return embed_tokens(embedding_params, tokens)

def get_loss(rnn_params, xtokens, ytokens):
  xbow = embed_tokens(rnn_params["embedding_params"], xtokens)
  logits = forward(rnn_params, xbow)
  ytokens_one_hot = jax.nn.one_hot(ytokens, len(logits[0]))
  cross_entropies = -jnp.sum(jax.nn.log_softmax(logits, axis=-1) * ytokens_one_hot, axis=-1)
  net_cross_entropy_loss = jnp.sum(cross_entropies)
  return net_cross_entropy_loss



def train_step(rnn_params, xtokens, ytokens, optimizer, opt_state):
  loss, grads = jax.value_and_grad(get_loss)(rnn_params, xtokens, ytokens)
  param_updates, updated_opt_state = optimizer.update(grads, opt_state)
  updated_params = optax.apply_updates(rnn_params, param_updates)
  return loss, updated_params, updated_opt_state



# setup
keys = random.split(random.PRNGKey(198123), 10)

model_dim = 16 # C
input_shape = (model_dim,)
hidden_shape = (20,)
output_shape = (vocab_size,) # logits

embedding_params = init_embedding_params(keys[0], model_dim)

rnn_params = init_rnn_params(keys[1], input_shape, hidden_shape, output_shape)
rnn_params.update({"embedding_params" : embedding_params})

lr = 0.01
optimizer = optax.adam(lr)
opt_state = optimizer.init(rnn_params)




xbow = embed_chars(rnn_params["embedding_params"], "lmao")
ytokens = encode("maoo")

indices = random.permutation(keys[3], len(x_train_tokens))

steps = 1000
print_every = 50
for i in range(steps):
  idx = indices[i]
  xtokens = x_train_tokens[idx]
  ytokens = y_train_tokens[idx]
  loss, rnn_params, opt_state = train_step(rnn_params, xtokens, ytokens, optimizer, opt_state)
  if i % print_every == 0:
    print(i, loss, decode(xtokens), decode(ytokens))

0 6.0532193 maol aolm
50 1.443141 maol aolm
100 0.12231689 olma lmao
150 0.77888703 lmao maol
200 0.24667618 olma lmao
250 0.18515052 lmao maol
300 0.29211754 maol aolm
350 0.08199807 olma lmao
400 0.061798505 lmao maol
450 0.08337262 olma lmao
500 1.4406747 aolm olma
550 0.02647544 lmao maol
600 0.08650149 maol aolm
650 0.02113651 lmao maol
700 0.04582143 maol aolm
750 1.4037068 aolm olma
800 1.4018327 aolm olma
850 1.3933158 aolm olma
900 0.0059581455 olma lmao
950 0.021258852 maol aolm


In [82]:
def inference(rnn_params, xchars):
  xbow = embed_chars(rnn_params["embedding_params"], xchars)
  logits = forward(rnn_params, xbow)
  yhatbow = jnp.argmax(logits, axis=-1)
  yhatbow_chars = decode(yhatbow)
  return yhatbow_chars

text = "lmao"
print(text, end='')
for i in range(100):
  current_input = text[-sequence_length:] # final $seq_length chars
  next_char = inference(rnn_params, current_input)[-1]
  text += next_char
  print(next_char, end='')

lmaolmaolmaolmaolmaolmaolmaolmaolmaolmaolmaolmaolmaolmaolmaolmaolmaolmaolmaolmaolmaolmaolmaolmaolmaolmao

: 