In [None]:
# Seq2seq inference using our own homemade transformer

import sys
sys.path.append('../')
import jax.numpy as jnp
import jax
from transformer import Seq2SeqTransformer

In [None]:
# Show the encoding names tiktoken reports, to pick one for the tokenizer.
from tiktoken import list_encoding_names

encoding_names = list_encoding_names()
encoding_names

In [None]:
# Build the tokenizer through the project helper and record its vocabulary size.

from utils import get_tokenizer

ENCODING_NAME = "r50k_base"

tokenizer = get_tokenizer(ENCODING_NAME)
VOCAB_SIZE = tokenizer.n_vocab  # reused below as the model's source vocab size
print(f"Vocab size: {VOCAB_SIZE}")

In [None]:
# Inspect the tokenizer's special tokens: the set of special-token strings,
# the end-of-text token id, and what that id decodes back to.
print(tokenizer.special_tokens_set)
print(tokenizer.eot_token)
# Decode the tokenizer's own end-of-text id instead of a hard-coded 50256, so
# this cell stays correct if a different encoding is selected above.
# NOTE(review): assumes eot_token is 50256 for "r50k_base" (output unchanged) - confirm.
print(tokenizer.decode([tokenizer.eot_token]))

In [None]:
# Small model configuration used for this inference walkthrough.
EMB_SIZE = 4
transformer_kwargs = dict(n_heads=2, n_layers=1, d_ff=2)

# Key used to initialise the model state.
rng = jax.random.PRNGKey(0)

model = Seq2SeqTransformer(
    src_vocab_size=VOCAB_SIZE,
    emb_size=EMB_SIZE,
    **transformer_kwargs,
)
state = model.init_state(rng)

In [None]:
# Smoke-test a forward pass using all-ones token ids as dummy input.

SRC_LEN, TGT_LEN, BATCH_SIZE = 2, 3, 1

src_shape = (SRC_LEN, BATCH_SIZE)
tgt_shape = (TGT_LEN, BATCH_SIZE)

src = jnp.ones(src_shape, dtype=jnp.int32)
tgt = jnp.ones(tgt_shape, dtype=jnp.int32)
print(src.shape)  # (SRC_LEN, BATCH_SIZE)

rng = jax.random.PRNGKey(1)
# Author's stated expectation for the result: (TGT_LEN, BATCH_SIZE, VOCAB_SIZE).
out = model(state, src, tgt, rng)
print(out.shape)

In [24]:
# Greedy decoding. The same tokenizer is used for input and output for
# simplicity; a real setup should use custom tokens. Pretend we're starting
# with token 0.

# Keep the raw prompt under its own name so `src` is only ever the token array
# (previously the same name was bound to a str and then to an array).
src_text = "Mumbo jumbo"
src = jnp.array(tokenizer.encode(src_text))
print(src)

rng = jax.random.PRNGKey(42069)
res = model.generate(state, src, rng, max_len=4)
# Flatten with reshape(-1) rather than squeeze(): squeeze() turns a result
# holding a single token into a 0-d array whose tolist() is a bare int, which
# decode() cannot iterate. reshape(-1) always yields a flat list.
# NOTE(review): assumes a single generated sequence (batch of one) - confirm.
res_list = res.reshape(-1).tolist()

print(tokenizer.decode(res_list))

[   44 29309   474 29309]
! sidelines TerritoriesClock deed
