🔸 https://github.com/BeeGass/HiPPO-Jax

In [1]:
! pip install -q --upgrade pip
! pip install -q einops jax jaxlib flax jaxtyping typing-extensions
! pip install -q tqdm

In [2]:
import os
import sys
import warnings
from pathlib import Path

from tqdm import tqdm

In [3]:
JIT = False  # Set to False to disable JIT

if JIT:
    # TODO: set JIT=True and inspect this logfile
    # takes > 30min to generate on my macbook
    %env XLA_FLAGS=--xla_dump_to=/tmp/why_is_this_slow.txt


In [4]:
# module_path = os.path.abspath(os.path.join("../../../../"))
# print(f"module_path: {module_path}")
# if module_path not in sys.path:
#     print(f"Adding {module_path} to sys.path")
#     sys.path.append(module_path)

In [5]:
warnings.filterwarnings("ignore")

In [6]:
# os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "False"
os.environ["TF_FORCE_UNIFIED_MEMORY"] = "1"

## import packages

In [7]:
import jax
import jax.numpy as jnp
import flax.linen as nn
from flax.linen.activation import sigmoid, tanh, relu
import numpy as np
# import torch
from flax.training import train_state  # , orbax_utils
import orbax.checkpoint as ocp
import optax
from jaxtyping import install_import_hook

In [8]:
from functools import partial

In [9]:
from hippo_live import HiPPOCell, HiPPOLTI

from cells_live import LSTMCell, BatchedGatedHiPPOCell, CharRNN

from trans import initializer, legt, legs, lmu, lagt, fru, fout, foud

In [10]:
print(jax.devices())
print(f"The Device: {jax.lib.xla_bridge.get_backend().platform}")

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


[CpuDevice(id=0)]
The Device: cpu


In [11]:
# print(f"MPS available: {torch.backends.mps.is_available()}")

In [12]:
# torch.set_printoptions(linewidth=150)
np.set_printoptions(linewidth=150)
jnp.set_printoptions(linewidth=150)

In [13]:
seed = 1701
key = jax.random.PRNGKey(seed)

In [14]:
num_copies = 20
subkeys = jax.random.split(key, num=num_copies)
key = subkeys[0]

## Parameters For Generating Data

In [15]:
T = 1
freq = 10
step = 1 / (28 * 28)  # 1e-3
L = int(T / step)

Parameters For Training

In [16]:
batch_size = 64
data_size = L
input_size = 1
_block_size = 512

In [17]:
num_sequences = 100
epochs = 3
lr = 0.001

Parameters For HiPPO

In [18]:
N = 128

# Dataset (TinyShakespeare)

In [19]:
with open("../datasets/shakespeare.txt", "r", encoding="latin-1") as f:
    text = f.read()


In [20]:
print("length of dataset in characters: ", len(text))

length of dataset in characters:  1115394


In [21]:
print(text[:100])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You


get all the unique characters that occur in this text

In [22]:
chars = sorted(list(set(text)))
vocab_size = len(chars)
print("all the unique characters:", "".join(chars))
print(f"vocab size: {vocab_size:,}")

all the unique characters: 
 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
vocab size: 65


create a mapping from characters to integers

In [23]:
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}

In [24]:
def encode(s):
    int_list = []
    for c in s:
        if c not in stoi:
            raise ValueError(f"character {c} not in vocabulary")
        else:
            int_list.append(
                stoi[c]
            )  # encoder: take a string, output a list of integers
    return int_list

In [25]:
def decode(l):
    str_list = []
    for i in l:
        if i not in itos:
            raise ValueError(f"integer {i} not in the vocabulary")
        else:
            str_list.append(
                itos[i]
            )  # decoder: take a list of integers, output a list of characters
    return "".join(str_list)  # take a list of integers, output a string

In [26]:
txt = "The fox jumps over the lazy dog."
print(encode(txt))
print(decode(encode(txt)))

[32, 46, 43, 1, 44, 53, 62, 1, 48, 59, 51, 54, 57, 1, 53, 60, 43, 56, 1, 58, 46, 43, 1, 50, 39, 64, 63, 1, 42, 53, 45, 8]
The fox jumps over the lazy dog.


In [27]:
data = jnp.array(encode(text))
print(data.shape)
print(data[:100])

(1115394,)
[18 47 56 57 58  1 15 47 58 47 64 43 52 10  0 14 43 44 53 56 43  1 61 43  1 54 56 53 41 43 43 42  1 39 52 63  1 44 59 56 58 46 43 56  6  1 46 43 39
 56  1 51 43  1 57 54 43 39 49  8  0  0 13 50 50 10  0 31 54 43 39 49  6  1 57 54 43 39 49  8  0  0 18 47 56 57 58  1 15 47 58 47 64 43 52 10  0 37
 53 59]


Let's now split up the data into train and validation sets

In [28]:
n = int(0.9 * len(data))  # first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]
print(f"train has {len(train_data):,} tokens")
print(f"train has shape {train_data.shape}\n")
print(f"val has {len(val_data):,} tokens")
print(f"val has shape {val_data.shape}")

train has 1,003,854 tokens
train has shape (1003854,)

val has 111,540 tokens
val has shape (111540,)


In [29]:
block_size = 8
train_data[: block_size + 1]

Array([18, 47, 56, 57, 58,  1, 15, 47, 58], dtype=int32)

In [30]:
x = train_data[:block_size]
y = train_data[1 : block_size + 1]
for t in range(block_size):
    context = x[: t + 1]
    target = y[t]
    print(f"when input is {context} the target: {target}")

when input is [18] the target: 47
when input is [18 47] the target: 56
when input is [18 47 56] the target: 57
when input is [18 47 56 57] the target: 58
when input is [18 47 56 57 58] the target: 1
when input is [18 47 56 57 58  1] the target: 15
when input is [18 47 56 57 58  1 15] the target: 47
when input is [18 47 56 57 58  1 15 47] the target: 58


In [31]:
def batch_data(data, block_size=8):
    # Trim the data to ensure that it is divisible by block_size for proper reshaping
    num_complete_blocks = (data.size - 1) // block_size
    trimmed_data_size = num_complete_blocks * block_size + 1
    data = data[:trimmed_data_size]

    # Prepare the input data 'x'
    x = data[:-1].reshape(num_complete_blocks, block_size)

    # Prepare the target data 'y'
    # The target for each sequence in 'x' is the sequence shifted by one token
    y = data[1:].reshape(num_complete_blocks, block_size)
    return x, y

In [32]:
def data_target_relation(x_ds, y_ds, bs, block_size):
    data_size = len(x_ds)
    target_size = len(y_ds)
    steps_per_epoch = data_size // bs

    perms = jax.random.permutation(subkeys[0], data_size)
    perms = perms[: steps_per_epoch * bs]  # skip incomplete batch
    perms = perms.reshape((steps_per_epoch, bs))

    for i, perm in enumerate(perms):
        print(f"inputs: {i}")
        print(x_ds[perm, ...].shape)
        print(f"{x_ds[perm, ...]}\n")
        print(f"targets: {i}")
        print(y_ds[perm, ...].shape)
        print(y_ds[perm, ...])
        print("----")
        for b in range(bs):  # batch dimension
            for t in range(block_size):  # time dimension
                context = x_ds[perm, ...][b, : t + 1]
                target = y_ds[perm, ...][b, t]
                print(f"when input is {context.tolist()} the target: {target}")
        print("--------")
        if i == 2:
            break

In [33]:
x_ds, y_ds = batch_data(data=train_data, block_size=block_size)
print(f"x_ds shape: {x_ds.shape}")
print(f"y_ds shape: {y_ds.shape}")

x_ds shape: (125481, 8)
y_ds shape: (125481, 8)


In [34]:
# data_target_relation(x_ds, y_ds, bs=4, block_size=8)

In [35]:
def initialize_rnn(cell_list, method: str, alpha: float = 2.0):
    assert method in ["legs", "legt", "lmu", "lagt", "fru", "fout", "foud"]
    init_fn = {
        "legs": legs,
        "legt": legt,
        "lmu": lmu,
        "lagt": lagt,
        "fru": fru,
        "fout": fout,
        "foud": foud,
        # "chebt": chebt,
    }[method]

    cell_args = [
        {
            "hippo_cell": HiPPOLTI,  # HiPPOLTICell,
            "hippo_args": {
                "step_size": step,
                "basis_size": T,
                "alpha": alpha,
                # "recon": False,
                "A_init": init_fn,
                "B_init": init_fn,
            },
            "tau_args": {
                "features": N,
                "bias": True,
                "gate_fn": sigmoid,
                "activation_fn": tanh,
                "dtype": jnp.float32,
            },
            # "mlp_args": {
            #     "features": [1],
            #     "activation_fn": relu,
            #     "bias": False,
            #     "dtype": jnp.float32,
            # },
            "_tau": LSTMCell,
            "bias": True,
            "dtype": jnp.float32,
        }
        for i in range(len(cell_list))
    ]

    rnn = CharRNN(
        vocab_size=vocab_size,
        hidden_size=N,
        rnn_cells=cell_list,
        cell_args=cell_args,
    )
    return rnn

In [36]:
def test_shaping():
    # Define the model
    model = initialize_rnn(
        cell_list=[BatchedGatedHiPPOCell, BatchedGatedHiPPOCell],
        method="legs",
        alpha=2.0,
    )

    # Define the input
    input_data = jnp.ones((batch_size, block_size), dtype=jnp.int32)
    print(f"input data shape: {input_data.shape}")

    # Initialize the model carries
    carries = model.initialize_carries(
        rng=subkeys[7], batch_size=batch_size, hidden_sizes=[N, N]
    )
    print(f"hidden shape: {carries}")
    print(f"hidden shape: {carries[-1][0].shape}")

    # Initialize the model parameters
    params = model.init(
        subkeys[1],
        x=input_data,
        carry=carries,
        targets=None,
    )

    # Initialize the model carries
    carries = model.initialize_carries(
        rng=subkeys[8], batch_size=batch_size, hidden_sizes=[N, N]
    )
    print(f"hidden shape: {carries[-1][0].shape}")

    # Apply the model to the input
    output, new_carries = model.apply(
        params, x=input_data, carry=carries, targets=None
    )

    # Check the output shape
    print(f"output shape: {output.shape}")
    assert output.shape == (batch_size, block_size, vocab_size)

In [37]:
test_shaping()

input data shape: (64, 8)
hidden shape: [(Array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32), Array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)), (Array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32), Array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., .

## Training
The cells below will be dedicated towards training a character-level RNN on a given dataset

In [38]:
def apply_model(state, carries, input, target, vocab_size, train=True):
    def loss_fn(params):
        # Use targets for teacher forcing if training
        targets = target if train else None
        logits, new_carries = state.apply_fn(
            params, x=input, carry=carries, targets=targets
        )

        # One-hot encode the target with the correct vocabulary size
        one_hot = jax.nn.one_hot(target, vocab_size)
        loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
        return loss, (logits, new_carries)

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, (logits, new_carries)), grads = grad_fn(state.params)

    # Ensure accuracy calculation matches the target structure
    accuracy = jnp.mean(jnp.argmax(logits, -1) == target)
    return grads, loss, accuracy, jnp.argmax(logits, axis=-1)

if JIT:
    apply_model = partial(jax.jit, static_argnames=["vocab_size"])(apply_model)

In [39]:
def update_model(state, grad):
    return state.apply_gradients(grads=grad)

if JIT:
    update_model = jax.jit(update_model)

In [40]:
def train_epoch(rng, model, train_ds, state, vocab_size, batch_size=64):
    """Train for a single epoch."""
    train_x, train_y = train_ds
    rng, permute_rng = jax.random.split(rng, 2)
    data_size = len(train_x)
    steps_per_epoch = data_size // batch_size
    perms = jax.random.permutation(subkeys[0], data_size)
    perms = perms[: steps_per_epoch * batch_size]  # skip incomplete batch
    perms = perms.reshape((steps_per_epoch, batch_size))
    epoch_loss = []
    epoch_accuracy = []
    for perm in tqdm(perms):
        input = train_x[perm, ...]
        target = train_y[perm, ...]
        rng, carry_rng = jax.random.split(rng, 2)
        carries = model.initialize_carries(
            rng=carry_rng, batch_size=batch_size, hidden_sizes=[N, N]
        )
        # print('Apply model')
        grads, loss, accuracy, _ = apply_model(
            state=state,
            carries=carries,
            input=input,
            target=target,
            vocab_size=vocab_size,
        )
        # print('Update model')
        state = update_model(state, grads)
        epoch_loss.append(loss)
        epoch_accuracy.append(accuracy)
    train_loss = jnp.mean(jnp.array(epoch_loss))
    train_accuracy = jnp.mean(jnp.array(epoch_accuracy))
    return state, train_loss, train_accuracy

In [41]:
def train(
    model,
    epochs,
    train_data,
    test_data,
    block_size,
    vocab_size,
    learning_rate=0.001,
    batch_size=64,
):
    print('Checkpoints')
    # Define the directory where checkpoints will be stored
    checkpoint_dir = Path("/tmp/my_checkpoints")
    options = ocp.CheckpointManagerOptions(max_to_keep=3, save_interval_steps=2)
    checkpoint_manager = ocp.CheckpointManager(
        checkpoint_dir, {"model_state": ocp.PyTreeCheckpointer()}, options=options
    )

    # Create the dataset
    print('Creating dataset')
    train_x, train_y = batch_data(data=train_data, block_size=block_size)
    test_x, test_y = batch_data(data=test_data, block_size=block_size)
    train_x_size = len(train_x)
    steps_per_epoch = train_x_size // batch_size
    decay_steps = steps_per_epoch * epochs

    # Intialize the scheduler
    print('Initializing scheduler, optimizer, model state')
    schedule = optax.warmup_cosine_decay_schedule(
        init_value=2e-3,
        peak_value=2e-5,
        warmup_steps=10,
        decay_steps=decay_steps,
        end_value=0.0,
    )

    # Initialization of the optimizer
    optimizer = optax.adam(learning_rate)

    # Initialize the model state
    carries = model.initialize_carries(
        rng=subkeys[7], batch_size=batch_size, hidden_sizes=[N, N]
    )

    print('Checkpoints')
    state = None
    starting_epoch = 0
    if checkpoint_dir.exists():
        lastest_step = checkpoint_manager.latest_step()
        if lastest_step is not None:
            # Restore the latest checkpoint
            restored = checkpoint_manager.restore(lastest_step)
            if restored is not None:
                # Unpack the checkpointed state and set the starting epoch
                state, starting_epoch = restored["model_state"], restored["epoch"]
                print(f"Resuming training from epoch {starting_epoch}")
            else:
                raise ValueError(
                    f"restored is None valued despite the checkpoint directory existing"
                )

        else:
            checkpoint_dir.mkdir(
                parents=True, exist_ok=True
            )  # Ensure the checkpoint directory is created

            # Initialize the model parameters
            params = model.init(
                subkeys[1],
                x=jnp.ones((batch_size, block_size), dtype=jnp.int32),
                carry=carries,
                targets=None,
            )

            state = train_state.TrainState.create(
                apply_fn=model.apply, params=params, tx=optimizer
            )
            starting_epoch = 0
            print("Starting training from scratch")

    print('Epochs')
    rng = subkeys[7]
    for epoch in range(starting_epoch, epochs):
        print(f'Training epoch {epoch} / {epochs}')

        rng, input_rng, carry_rng, test_rng = jax.random.split(rng, 4)
        state, loss, accuracy = train_epoch(
            rng=input_rng,
            model=model,
            train_ds=(train_x, train_y),
            state=state,
            vocab_size=vocab_size,
            batch_size=batch_size,
        )

        print('Testing epoch {epoch} / {epochs}')

        indices = jax.random.randint(
            test_rng,
            shape=(batch_size,),
            minval=0,
            maxval=(test_x.shape[0] - 1),
        )
        test_input = test_x[indices, ...]
        test_target = test_y[indices, ...]
        carries = model.initialize_carries(
            rng=carry_rng, batch_size=batch_size, hidden_sizes=[N, N]
        )
        _, test_loss, test_accuracy, test_logits = apply_model(
            state=state,
            carries=carries,
            input=test_input,
            target=test_target,
            vocab_size=vocab_size,
        )
        print(
            f"Epoch: {epoch}\n\tTrain Loss: {loss}\n\tTrain Accuracy: {accuracy}\n\tTest Loss: {test_loss}\n\tTest Accuracy: {test_accuracy}\n\n"
        )
        print(f"Sample Model Text Output:")
        print(f"------------------------------------------------------------")
        print(f"{decode(test_logits[0].tolist())}")
        print(f"------------------------------------------------------------")
        checkpoint_manager.save(epoch, {"model_state": state})

In [42]:
# hippo_list = [
#     {
#         "name": "LegS-LSI",
#         "use": False,
#         "val": "legs",
#     },
#     {
#         "name": "LegS-LTI",
#         "use": True,
#         "val": "legs",
#     },
#     {
#         "name": "LegT",
#         "use": False,
#         "val": "legt",
#     },
#     {
#         "name": "LMU",
#         "use": False,
#         "val": "lmu",
#     },
#     {
#         "name": "LagT",
#         "use": False,
#         "val": "lagt",
#     },
#     {
#         "name": "FRU",
#         "use": False,
#         "val": "fru",
#     },
#     {
#         "name": "FouT",
#         "use": False,
#         "val": "fout",
#     },
#     {
#         "name": "FouD",
#         "use": False,
#         "val": "foud",
#     },
#     {
#         "name": "ChebT",
#         "use": False,
#         "val": "chebt",
#     },
# ]

In [43]:
# discretization = [
#     {
#         "name": "Forward-Euler",
#         "use": False,
#         "val": 0.0,
#     },
#     {
#         "name": "Backward-Euler",
#         "use": False,
#         "val": 1.0,
#     },
#     {
#         "name": "Bilinear",
#         "use": False,
#         "val": 0.5,
#     },
#     {
#         "name": "Zero-Order Hold",
#         "use": True,
#         "val": 2.0,
#     },
# ]

In [44]:
cell_list = [BatchedGatedHiPPOCell, BatchedGatedHiPPOCell]

In [45]:
class Discretizations:
    FORWARD_EULER, BACKWARD_EULER, BILINEAR, ZOH = 0.0, 1.0, 0.5, 2.0

print('Initializing HiPPO-RNN with: LegS-LTI and ZOH...')
model = initialize_rnn(cell_list=cell_list, method='legs', alpha=Discretizations.ZOH)
print('... Done!')

Initializing HiPPO-RNN with: LegS-LTI and ZOH...
... Done!


In [46]:
train(
    model=model,
    epochs=epochs,
    train_data=train_data,
    test_data=val_data,
    block_size=_block_size,
    vocab_size=vocab_size,
    learning_rate=lr,
    batch_size=batch_size,
)


Checkpoints
Creating dataset
Initializing scheduler, optimizer, model state
Checkpoints
Starting training from scratch
Epochs
Training epoch 0 / 3


100%|██████████| 30/30 [48:32<00:00, 97.09s/it] 


Testing epoch {epoch} / {epochs}
Epoch: 0
	Train Loss: 3.9151813983917236
	Train Accuracy: 0.15810446441173553
	Test Loss: 3.3927512168884277
	Test Accuracy: 0.149566650390625


Sample Model Text Output:
------------------------------------------------------------
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                
------------------------------------------------------------
Training epoch 1 / 3


100%|██████████| 30/30 [47:36<00:00, 95.23s/it]


Testing epoch {epoch} / {epochs}
Epoch: 1
	Train Loss: 3.1012654304504395
	Train Accuracy: 0.20543111860752106
	Test Loss: 2.8244597911834717
	Test Accuracy: 0.257843017578125


Sample Model Text Output:
------------------------------------------------------------
 te e e  t et tee ee te   t  e ee te et eee  te t t  te   t   te eee e   ee  t e eeee  eee e ee ee   t  e eeeeeeei iiiix ti  te te t et e  t e eeee eeeeeiix ii t  e t  t    eeeeee tiixxx t  e t  t    eeeee   te t  e t  e ee te e  te e t  t    eee t t e t  te e  t  e ee ee   e  ee eeeeeeiix ii te t    e  t e eeeeee iiiix ti  te teeeee tiixxx te tet ee teee t  eeeeeee iiiix t    t e   teee eeeeeixx ii t  et ee   eeee e ee   e   e te  t    eeet e e t  e t  eeeeeeeeiiixxx ti t eeeeee ee e te  e te   te t    te 
------------------------------------------------------------
Training epoch 2 / 3


100%|██████████| 30/30 [47:07<00:00, 94.26s/it]


Testing epoch {epoch} / {epochs}
Epoch: 2
	Train Loss: 2.4194655418395996
	Train Accuracy: 0.36268311738967896
	Test Loss: 1.9905269145965576
	Test Accuracy: 0.556060791015625


Sample Model Text Output:
------------------------------------------------------------
 she a ordinin enini   where too  aaeni we es weet toaethers hen  o wons oe the theni that weens theer w  oen horih  itt e were a oos a eet weth  itt e wen iniet eit ene a sth we   w oo ort were and a  anio i to her and so she wee  s to heinio  i an  orih and woo not  iee a wasee

iii ii ie
 e   wanst thor wooi and haeen we tha saeenini t we thor aroi  wor sone  nhaeen wordsin
ii iiiiiieni i to the a ooai as worntaens are wor wen sen het shaie note thorah then w oo aeriet a  oin
iii ii ie
ior  ooi wo w ien 
------------------------------------------------------------


In [47]:
# for hippo in hippo_list:
#     if hippo["use"]:
#         alpha = -1.0
#         for a in discretization:
#             if a["use"]:
#                 alpha = a["val"]
#         print(f"Running HiPPO-RNN with {hippo['name']} and {alpha}")
#         model = initialize_rnn(cell_list=cell_list, method=hippo["val"], alpha=alpha)
#         train(
#             model=model,
#             epochs=epochs,
#             train_data=train_data,
#             test_data=val_data,
#             block_size=_block_size,
#             vocab_size=vocab_size,
#             learning_rate=lr,
#             batch_size=batch_size,
#         )
#         break