In [None]:
%%capture
# Install the Hyena implementation that provides the `hyena` and `decoder` modules imported later.
# NOTE(review): this installs whatever is on the default branch; pin to a commit
# (e.g. `...hyena.git@<sha>`) so Restart & Run All reproduces the recorded outputs.
%pip install git+https://github.com/irhum/hyena.git

## Setup

### Dataset

In [None]:
import jax
import jax.numpy as jnp
import urllib.request

# Download the tiny-shakespeare corpus (a single plain-text file).
url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
# The context manager closes the HTTP connection once the body has been read.
with urllib.request.urlopen(url) as response:
    text = response.read().decode("utf-8")

# Character-level tokenisation: the vocabulary is every distinct character, sorted so
# that the id assigned to each character is deterministic across runs.
vocab = sorted(set(text))
char2idx = {c: i for i, c in enumerate(vocab)}
idx2char = dict(enumerate(vocab))
tokens = jnp.array([char2idx[c] for c in text])

# First 90% of the token stream is used for training, the last 10% is held out.
split_idx = int(0.9 * len(tokens))
train, test = tokens[:split_idx], tokens[split_idx:]


# function to generate a single batch of data
def batch_gen(key, data, seq_len, batch_size):
    """Sample `batch_size` distinct random windows of `seq_len` tokens from `data`.

    Args:
        key: JAX PRNG key used to choose the window start positions.
        data: token array, indexed along its first axis (1-D in this notebook).
        seq_len: number of tokens per window.
        batch_size: number of windows; start positions are drawn without replacement.

    Returns:
        (input_tokens, target_tokens), each of shape (batch_size, seq_len).
        target_tokens[b, t] == data[start_b + t + 1], i.e. the next-token labels.

    Raises:
        ValueError: if `data` does not contain `batch_size` distinct windows.
    """
    # A window starting at `s` reads data[s + seq_len] as its last target, so valid
    # starts are 0 .. len(data) - seq_len - 1: exactly len(data) - seq_len of them.
    n_starts = len(data) - seq_len
    if n_starts < batch_size:
        raise ValueError(
            f"need {batch_size} distinct windows of {seq_len} + 1 tokens, "
            f"but data has only {len(data)} tokens"
        )
    starts = jax.random.choice(key, n_starts, shape=(batch_size,), replace=False)
    tok_idxs = starts[:, jnp.newaxis] + jnp.arange(seq_len)
    return data[tok_idxs], data[tok_idxs + 1]


# function to generate batches up to a total number of tokens
def dataset(key, data, seq_len, batch_size, total_tokens):
    """Yield random (input, target) batches until `total_tokens` tokens have been produced.

    Each batch comes from `batch_gen` and holds `batch_size * seq_len` tokens. The budget
    is checked before each batch, so the last batch may push the count past
    `total_tokens`. The same `key` always yields the same sequence of batches.

    Raises:
        ValueError: if a batch would contain no tokens (the loop could never end).
    """
    tokens_per_batch = batch_size * seq_len
    if tokens_per_batch <= 0 and total_tokens > 0:
        raise ValueError("batch_size * seq_len must be positive")
    used_tokens = 0
    while used_tokens < total_tokens:
        # Split off a fresh subkey for this batch and carry the other half forward, so
        # no key is both consumed by sampling and split again.
        key, batch_key = jax.random.split(key)
        yield batch_gen(batch_key, data, seq_len, batch_size)
        used_tokens += tokens_per_batch

### Network Definition

In [None]:
from functools import partial

import flax.linen as nn

from hyena import hyena, decoder

In [None]:
# NETWORK
# Character-level decoder: embeddings of width `n_dim`, `n_layers` decoder layers that
# use a Hyena operator as their sequence mixer, and dropout 0.2.
vocab_size = len(vocab)  # derived from the data instead of hard-coding 65
n_dim = 128
n_layers = 6

# Siren MLP handed to the Hyena operator as its `filter_fn` (presumably it parameterises
# the implicit long-convolution filters — confirm in the hyena package).
siren = partial(hyena.Siren, hidden_features=64, num_layers=4, freq=300.0)
# Hyena sequence mixer. NOTE(review): `max_len=512` presumably bounds the supported
# sequence length (training and inference here use 256) — confirm in the hyena package.
mixer = partial(
    hyena.HyenaOperator,
    max_len=512,
    filter_fn=siren,
    modulation_fn=hyena.ExponentialModulation,
)
# One decoder layer with a 4x-wide hidden size and the Hyena mixer. `out_init` has std
# 0.02 / sqrt(2 * n_layers), a GPT-2-style depth-scaled init (presumably for the layer's
# output projection — confirm in decoder.DecoderLayer).
layer = partial(
    decoder.DecoderLayer,
    features=n_dim,
    hidden_features=n_dim * 4,
    mixer_fn=mixer,
    out_init=nn.initializers.normal(stddev=0.02 / jnp.sqrt(2 * n_layers)),
)
embed_fn = partial(
    nn.Embed,
    num_embeddings=vocab_size,
    features=n_dim,
    embedding_init=nn.initializers.normal(stddev=0.02),
)
m = decoder.Decoder(
    embedding=embed_fn(), block_fn=layer, num_layers=n_layers, dropout=0.2
)

key = jax.random.PRNGKey(2)
p_key, d_key = jax.random.split(key)
# Dummy (1, 256) int32 batch used only so `init` can trace shapes. Zeros avoid drawing
# random values from `key`, which has already been split into p_key / d_key.
x = jnp.zeros((1, 256), dtype=jnp.int32)
params = m.init({"params": p_key, "dropout": d_key}, x)

In [None]:
import optax
from flax import traverse_util
from flax.core import frozen_dict

# OPTIMIZER
# Weight-decay mask: decay only parameters whose name ends in "kernel" (so not biases).
def wd_mask(params):
    """Build the AdamW weight-decay mask for a parameter tree.

    Every leaf whose "/"-joined path ends in "kernel" maps to True (decayed). All other
    leaves map to False: biases, and anything else not named "...kernel", e.g. the
    `nn.Embed` table, whose leaf is named "embedding". The mask has the same nesting as
    `params` and is returned frozen.
    """
    flat_params = traverse_util.flatten_dict(params, sep="/")
    decay_flags = {}
    for path in flat_params:
        decay_flags[path] = path.endswith("kernel")
    nested_flags = traverse_util.unflatten_dict(decay_flags, sep="/")
    return frozen_dict.freeze(nested_flags)

# Learning-rate schedule: linear warmup from 1e-4 to a 1e-3 peak over 100 steps, then
# cosine decay towards 1e-4. In optax, `decay_steps` is the total schedule length and
# includes the warmup.
# NOTE(review): the training loop below runs roughly 2.4k steps (its log stops at
# iteration 2250), so this 5000-step cosine only gets about halfway down. Confirm that
# is intended.
sched = optax.warmup_cosine_decay_schedule(
    init_value=1e-4,
    peak_value=1e-3,
    warmup_steps=100,
    decay_steps=5000,
    end_value=1e-4,
)
# Clip the global gradient norm to 1.0 first, then apply AdamW (b2=0.99) with decoupled
# weight decay 0.1 restricted by `wd_mask`.
opt = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.adamw(sched, b2=0.99, weight_decay=0.1, mask=wd_mask),
)
opt_state = opt.init(params)

In [None]:
# Parameter count: total number of scalars across every array leaf in `params`.
# (`jax.tree_map` is deprecated; summing over `jax.tree_util.tree_leaves` is the direct form.)
sum(leaf.size for leaf in jax.tree_util.tree_leaves(params))

1510016

## Training Loop

In [None]:
@partial(jax.jit, static_argnums=(3,))
def loss_fn(params, batch, key, train):
    """Mean cross-entropy of model `m` on one (inputs, targets) batch.

    `train` is a static (Python bool) argument. When True, the model runs with
    `deterministic=False` and its "dropout" rng stream comes from `key`. When False,
    it runs deterministically and `key` is effectively unused.
    """
    inputs, targets = batch
    logits = m.apply(
        params,
        inputs,
        deterministic=not train,
        rngs={"dropout": key},
    )
    per_token = optax.softmax_cross_entropy_with_integer_labels(
        logits=logits, labels=targets
    )
    return jnp.mean(per_token)


@jax.jit
def update(params, opt_state, batch, key):
    """Run one optimisation step on `batch` with dropout enabled (rng from `key`).

    Returns (new_params, new_opt_state, loss), where `loss` is the training loss
    evaluated at the incoming (pre-update) params.
    """
    loss_and_grad = jax.value_and_grad(loss_fn)
    loss, grads = loss_and_grad(params, batch, key, True)
    updates, new_opt_state = opt.update(grads, opt_state, params)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_opt_state, loss


# Training configuration.
seq_len = 256
batch_size = 64
n_epochs = 40          # token budget, expressed as passes' worth of the training split
eval_every = 250       # steps between validation runs
eval_batch_size = 256

all_losses = []

data_key = jax.random.PRNGKey(0)
drop_key = jax.random.PRNGKey(1)

# Sample random windows until about n_epochs * len(train) tokens have been consumed.
# Windows are drawn independently, so this is a token budget, not exact epochs.
for i, batch in enumerate(dataset(data_key, train, seq_len, batch_size, len(train) * n_epochs)):
    # Derive a fresh dropout key for every step from the base dropout key and the step index.
    step_key = jax.random.fold_in(drop_key, i)
    params, opt_state, trn_loss = update(params, opt_state, batch, step_key)

    if i % eval_every == 0:
        # Reusing `data_key` makes every evaluation see the same validation windows, so
        # successive validation losses are directly comparable. Dropout is off here.
        val_losses = [
            float(loss_fn(params, test_batch, step_key, False))
            for test_batch in dataset(data_key, test, seq_len, eval_batch_size, len(test) * 2)
        ]
        val_loss = float(jnp.mean(jnp.array(val_losses)))
        all_losses.append([i, val_loss])
        print(trn_loss, f"Iteration {i}: Validation loss: {val_loss}")

4.1729574 Iteration 0: Validation loss: 4.153886795043945
2.5329142 Iteration 250: Validation loss: 2.529892683029175
1.8411844 Iteration 500: Validation loss: 1.8480111360549927
1.5812775 Iteration 750: Validation loss: 1.6493785381317139
1.4629028 Iteration 1000: Validation loss: 1.5552294254302979
1.3523409 Iteration 1250: Validation loss: 1.5148968696594238
1.347347 Iteration 1500: Validation loss: 1.4818370342254639
1.324385 Iteration 1750: Validation loss: 1.4734055995941162
1.3026986 Iteration 2000: Validation loss: 1.4473998546600342
1.2757719 Iteration 2250: Validation loss: 1.4485535621643066


## Inference

In [None]:
completion_length = 200  # number of characters to sample after the prompt
prompt = "And therefore, "

# Character ids of the prompt. `start_idx` is the position of its final character,
# which is the first token handed to the decoder.
prompt_ids = [char2idx[ch] for ch in prompt]
start_idx = len(prompt_ids) - 1
# Right-pad with id 0 to a fixed length of 256 and add a leading batch axis -> (1, 256).
prompt_tokens = jnp.pad(
    jnp.array(prompt_ids), (0, 256 - len(prompt_ids)), mode="constant"
)[None, :]

In [None]:
@jax.jit
def prefill(params, prompt_tokens, idxs):
    """Run the decoder in "prefill" mode over the padded prompt and return its cache.

    `idxs` is forwarded to the model unchanged. NOTE(review): presumably it marks the
    last prompt position to keep in the cache — confirm against the hyena decoder.
    """
    outputs = m.apply(
        params,
        prompt_tokens,
        idxs=idxs,
        deterministic=True,
        mode="prefill",
        mutable=["cache"],
    )
    # With `mutable=[...]`, apply returns (output, updated_collections); keep only the cache.
    updated_collections = outputs[1]
    return updated_collections["cache"]


@jax.jit
def sample_token(key, params, cache, current_token):
    """Decode one step from `current_token` and sample the next token.

    Returns (next_token cast to int32, updated cache).
    """
    variables = {"params": params["params"], "cache": cache}
    logits, updated_collections = m.apply(
        variables,
        current_token,
        deterministic=True,
        mode="decode",
        mutable=["cache"],
    )
    # Sample from softmax(logits) along the last axis (no temperature scaling).
    sampled = jax.random.categorical(key, logits)
    return sampled.astype(jnp.int32), updated_collections["cache"]

In [None]:
# Fill the cache from the padded prompt. NOTE(review): `start_idx - 1` presumably marks the
# last position prefilled, since the token at `start_idx` is fed to the first decode step.
cache = prefill(params, prompt_tokens, jnp.array([start_idx - 1]))

# Autoregressive decoding: feed the prompt's final token, then each sampled token in turn.
sampling_key = jax.random.PRNGKey(1)
current_token = prompt_tokens[:, start_idx : start_idx + 1]
generated_chars = []

for step in range(completion_length):
    sampling_key = jax.random.fold_in(sampling_key, step)
    current_token, cache = sample_token(sampling_key, params, cache, current_token)
    generated_chars.append(idx2char[int(current_token[0, 0])])

completion = prompt + "".join(generated_chars)
print(completion)

And therefore, bastard Marcius?

LEONTES:
Because the princes and leave the sun's hour shadows I
and a battle. What awe music?

COMINIUS:
Desay to your friend, I am a mile England.

ISABELLA:
The ginss of young of t
