🔸 https://github.com/BeeGass/HiPPO-Jax

In [1]:
! which python

/root/miniconda3/envs/py311/bin/python


In [2]:
! pip install -q --upgrade pip
! pip install -q einops
! pip install -q tqdm

! pip install jax 
! pip install -q jaxlib flax jaxtyping typing-extensions

[0m

In [2]:
# Capture the current PATH
CURR_PATH = !echo $PATH
CURR_PATH = CURR_PATH[0]  # Get the string value

# Set the new PATH
%env PATH=/usr/local/cuda-12.3/bin:{CURR_PATH}

# Capture the current LD_LIBRARY_PATH
CURR_LD_LIB_PATH = !echo $LD_LIBRARY_PATH
CURR_LD_LIB_PATH = CURR_LD_LIB_PATH[0] if CURR_LD_LIB_PATH else ""

# Set the new LD_LIBRARY_PATH
%env LD_LIBRARY_PATH=/usr/local/cuda-12.3/lib64:{CURR_LD_LIB_PATH}


env: PATH=/usr/local/cuda-12.3/bin:/root/miniconda3/envs/py311/bin:/root/miniconda3/condabin:/usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin
env: LD_LIBRARY_PATH=/usr/local/cuda-12.3/lib64:


In [3]:
! echo $PATH

/usr/local/cuda-12.3/bin:/root/miniconda3/envs/py311/bin:/root/miniconda3/condabin:/usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin


In [4]:
import os
import sys
import warnings
from pathlib import Path

from tqdm import tqdm

In [5]:
JIT = True  # Set to False to disable JIT

if JIT:
    # TODO: set JIT=True and inspect this logfile
    # takes > 30min to generate on my macbook
    %env XLA_FLAGS=--xla_dump_to=/tmp/why_is_this_slow.txt


env: XLA_FLAGS=--xla_dump_to=/tmp/why_is_this_slow.txt


In [6]:
# module_path = os.path.abspath(os.path.join("../../../../"))
# print(f"module_path: {module_path}")
# if module_path not in sys.path:
#     print(f"Adding {module_path} to sys.path")
#     sys.path.append(module_path)

In [7]:
warnings.filterwarnings("ignore")

In [8]:
# os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "False"
os.environ["TF_FORCE_UNIFIED_MEMORY"] = "1"

## import packages

In [9]:
import jax
import jax.numpy as jnp
import flax.linen as nn
from flax.linen.activation import sigmoid, tanh, relu
import numpy as np
# import torch
from flax.training import train_state  # , orbax_utils
import orbax.checkpoint as ocp
import optax
from jaxtyping import install_import_hook

In [10]:
from functools import partial

In [11]:
print(jax.devices())

[cuda(id=0)]


In [12]:
! nvcc --version

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2023 NVIDIA Corporation
Built on Fri_Jan__6_16:45:21_PST_2023
Cuda compilation tools, release 12.0, V12.0.140
Build cuda_12.0.r12.0/compiler.32267302_0


In [13]:
! nvidia-smi

Thu Dec 28 22:11:11 2023       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.129.03             Driver Version: 535.129.03   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA GeForce RTX 2080 Ti     On  | 00000000:28:00.0 Off |                  N/A |
| 60%   30C    P2              N/A /  N/A |    919MiB / 11264MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [14]:
! dpkg -l | grep cuda

ii  cuda-cccl-12-0                  12.0.140-1                        amd64        CUDA CCCL
ii  cuda-command-line-tools-12-0    12.0.1-1                          amd64        CUDA command-line tools
ii  cuda-compat-12-0                525.147.05-1                      amd64        CUDA Compatibility Platform
ii  cuda-compiler-12-0              12.0.1-1                          amd64        CUDA compiler
ii  cuda-cudart-12-0                12.0.146-1                        amd64        CUDA Runtime native Libraries
ii  cuda-cudart-dev-12-0            12.0.146-1                        amd64        CUDA Runtime native dev links, headers
ii  cuda-cuobjdump-12-0             12.0.140-1                        amd64        CUDA cuobjdump
ii  cuda-cupti-12-0                 12.0.146-1                        amd64        CUDA profiling tools runtime libs.
ii  cuda-cupti-dev-12-0             12.0.146-1                        amd64        CUDA profiling tools interface.
ii  cuda-cuxxfilt-12-0    

In [15]:
# Do this one-time if you want CUDA, then restart/rerun notebook
INSTALL_JAX_WITH_CUDA = False
if INSTALL_JAX_WITH_CUDA:
    # instruction from https://github.com/google/jax?tab=readme-ov-file#installation:
    ! pip install -U "jax[cuda12_pip]" -f  \
        https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

In [31]:
import jax

jax.devices()

[cuda(id=0)]

In [32]:
print(f"The Device: {jax.lib.xla_bridge.get_backend().platform}")

The Device: gpu


In [33]:
# print(f"MPS available: {torch.backends.mps.is_available()}")

In [34]:
# torch.set_printoptions(linewidth=150)
np.set_printoptions(linewidth=150)
jnp.set_printoptions(linewidth=150)

In [35]:
seed = 1701
key = jax.random.PRNGKey(seed)

In [36]:
num_copies = 20
subkeys = jax.random.split(key, num=num_copies)
key = subkeys[0]

In [37]:
from hippo_live import HiPPOCell, HiPPOLTI

from cells_live import LSTMCell, BatchedGatedHiPPOCell, CharRNN

from trans import initializer, legt, legs, lmu, lagt, fru, fout, foud

## Parameters For Generating Data

In [38]:
T = 1
freq = 10
step = 1 / (28 * 28)  # 1e-3
L = int(T / step)

Parameters For Training

In [39]:
batch_size = 64
data_size = L
input_size = 1
_block_size = 512

In [40]:
num_sequences = 100
epochs = 10
lr = 0.001

Parameters For HiPPO

In [41]:
N = 128

# Dataset (TinyShakespeare)

In [42]:
with open("../datasets/shakespeare.txt", "r", encoding="latin-1") as f:
    text = f.read()


In [43]:
print("length of dataset in characters: ", len(text))

length of dataset in characters:  1115394


In [44]:
print(text[:100])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You


get all the unique characters that occur in this text

In [45]:
chars = sorted(list(set(text)))
vocab_size = len(chars)
print("all the unique characters:", "".join(chars))
print(f"vocab size: {vocab_size:,}")

all the unique characters: 
 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
vocab size: 65


create a mapping from characters to integers

In [46]:
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}

In [47]:
def encode(s):
    int_list = []
    for c in s:
        if c not in stoi:
            raise ValueError(f"character {c} not in vocabulary")
        else:
            int_list.append(
                stoi[c]
            )  # encoder: take a string, output a list of integers
    return int_list

In [48]:
def decode(l):
    str_list = []
    for i in l:
        if i not in itos:
            raise ValueError(f"integer {i} not in the vocabulary")
        else:
            str_list.append(
                itos[i]
            )  # decoder: take a list of integers, output a list of characters
    return "".join(str_list)  # take a list of integers, output a string

In [49]:
txt = "The fox jumps over the lazy dog."
print(encode(txt))
print(decode(encode(txt)))

[32, 46, 43, 1, 44, 53, 62, 1, 48, 59, 51, 54, 57, 1, 53, 60, 43, 56, 1, 58, 46, 43, 1, 50, 39, 64, 63, 1, 42, 53, 45, 8]
The fox jumps over the lazy dog.


In [52]:
data = jnp.array(encode(text))
print(data.shape)
print(data[:100])

(1115394,)
[18 47 56 57 58  1 15 47 58 47 64 43 52 10  0 14 43 44 53 56 43  1 61 43  1 54 56 53 41 43 43 42  1 39 52 63  1 44 59 56 58 46 43 56  6  1 46 43 39
 56  1 51 43  1 57 54 43 39 49  8  0  0 13 50 50 10  0 31 54 43 39 49  6  1 57 54 43 39 49  8  0  0 18 47 56 57 58  1 15 47 58 47 64 43 52 10  0 37
 53 59]


Let's now split up the data into train and validation sets

In [53]:
n = int(0.9 * len(data))  # first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]
print(f"train has {len(train_data):,} tokens")
print(f"train has shape {train_data.shape}\n")
print(f"val has {len(val_data):,} tokens")
print(f"val has shape {val_data.shape}")

train has 1,003,854 tokens
train has shape (1003854,)

val has 111,540 tokens
val has shape (111540,)


In [54]:
block_size = 8
train_data[: block_size + 1]

Array([18, 47, 56, 57, 58,  1, 15, 47, 58], dtype=int32)

In [55]:
x = train_data[:block_size]
y = train_data[1 : block_size + 1]
for t in range(block_size):
    context = x[: t + 1]
    target = y[t]
    print(f"when input is {context} the target: {target}")

when input is [18] the target: 47
when input is [18 47] the target: 56
when input is [18 47 56] the target: 57
when input is [18 47 56 57] the target: 58
when input is [18 47 56 57 58] the target: 1
when input is [18 47 56 57 58  1] the target: 15
when input is [18 47 56 57 58  1 15] the target: 47
when input is [18 47 56 57 58  1 15 47] the target: 58


In [56]:
def batch_data(data, block_size=8):
    # Trim the data to ensure that it is divisible by block_size for proper reshaping
    num_complete_blocks = (data.size - 1) // block_size
    trimmed_data_size = num_complete_blocks * block_size + 1
    data = data[:trimmed_data_size]

    # Prepare the input data 'x'
    x = data[:-1].reshape(num_complete_blocks, block_size)

    # Prepare the target data 'y'
    # The target for each sequence in 'x' is the sequence shifted by one token
    y = data[1:].reshape(num_complete_blocks, block_size)
    return x, y

In [79]:
qq=batch_data(train_data, block_size=512)
len(qq[0])

1960

In [80]:
decode(qq[0][1].tolist())

'he patricians good.\nWhat authority surfeits on would relieve us: if they\nwould yield us but the superfluity, while it were\nwholesome, we might guess they relieved us humanely;\nbut they think we are too dear: the leanness that\nafflicts us, the object of our misery, is as an\ninventory to particularise their abundance; our\nsufferance is a gain to them Let us revenge this with\nour pikes, ere we become rakes: for the gods know I\nspeak this in hunger for bread, not in thirst for revenge.\n\nSecond Citizen:\nWould yo'

In [82]:
def data_target_relation(x_ds, y_ds, bs, block_size):
    data_size = len(x_ds)
    target_size = len(y_ds)
    steps_per_epoch = data_size // bs

    perms = jax.random.permutation(subkeys[0], data_size)
    perms = perms[: steps_per_epoch * bs]  # skip incomplete batch
    perms = perms.reshape((steps_per_epoch, bs))

    for i, perm in enumerate(perms):
        print(f"inputs: {i}")
        print(x_ds[perm, ...].shape)
        print(f"{x_ds[perm, ...]}\n")
        print(f"targets: {i}")
        print(y_ds[perm, ...].shape)
        print(y_ds[perm, ...])
        print("----")
        for b in range(bs):  # batch dimension
            for t in range(block_size):  # time dimension
                context = x_ds[perm, ...][b, : t + 1]
                target = y_ds[perm, ...][b, t]
                print(f"when input is {context.tolist()} the target: {target}")
        print("--------")
        if i == 2:
            break

In [73]:
block_size

8

In [83]:
x_ds, y_ds = batch_data(data=train_data, block_size=block_size)
print(f"x_ds shape: {x_ds.shape}")
print(f"y_ds shape: {y_ds.shape}")

x_ds shape: (125481, 8)
y_ds shape: (125481, 8)


In [76]:
decode(x_ds[30,:].tolist())

'e people'

In [85]:
# data_target_relation(x_ds, y_ds, bs=4, block_size=8)

In [43]:
def initialize_rnn(cell_list, method: str, alpha: float = 2.0):
    assert method in ["legs", "legt", "lmu", "lagt", "fru", "fout", "foud"]
    init_fn = {
        "legs": legs,
        "legt": legt,
        "lmu": lmu,
        "lagt": lagt,
        "fru": fru,
        "fout": fout,
        "foud": foud,
        # "chebt": chebt,
    }[method]

    cell_args = [
        {
            "hippo_cell": HiPPOLTI,  # HiPPOLTICell,
            "hippo_args": {
                "step_size": step,
                "basis_size": T,
                "alpha": alpha,
                # "recon": False,
                "A_init": init_fn,
                "B_init": init_fn,
            },
            "tau_args": {
                "features": N,
                "bias": True,
                "gate_fn": sigmoid,
                "activation_fn": tanh,
                "dtype": jnp.float32,
            },
            # "mlp_args": {
            #     "features": [1],
            #     "activation_fn": relu,
            #     "bias": False,
            #     "dtype": jnp.float32,
            # },
            "_tau": LSTMCell,
            "bias": True,
            "dtype": jnp.float32,
        }
        for i in range(len(cell_list))
    ]

    rnn = CharRNN(
        vocab_size=vocab_size,
        hidden_size=N,
        rnn_cells=cell_list,
        cell_args=cell_args,
    )
    return rnn

In [44]:
def test_shaping():
    # Define the model
    model = initialize_rnn(
        cell_list=[BatchedGatedHiPPOCell, BatchedGatedHiPPOCell],
        method="legs",
        alpha=2.0,
    )

    # Define the input
    input_data = jnp.ones((batch_size, block_size), dtype=jnp.int32)
    print(f"input data shape: {input_data.shape}")

    # Initialize the model carries
    carries = model.initialize_carries(
        rng=subkeys[7], batch_size=batch_size, hidden_sizes=[N, N]
    )
    print(f"hidden shape: {carries}")
    print(f"hidden shape: {carries[-1][0].shape}")

    # Initialize the model parameters
    params = model.init(
        subkeys[1],
        x=input_data,
        carry=carries,
        targets=None,
    )

    # Initialize the model carries
    carries = model.initialize_carries(
        rng=subkeys[8], batch_size=batch_size, hidden_sizes=[N, N]
    )
    print(f"hidden shape: {carries[-1][0].shape}")

    # Apply the model to the input
    output, new_carries = model.apply(
        params, x=input_data, carry=carries, targets=None
    )

    # Check the output shape
    print(f"output shape: {output.shape}")
    assert output.shape == (batch_size, block_size, vocab_size)

In [45]:
test_shaping()

input data shape: (64, 8)
hidden shape: [(Array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32), Array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)), (Array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32), Array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., .

## Training
The cells below will be dedicated towards training a character-level RNN on a given dataset

In [46]:
def apply_model(state, carries, input, target, vocab_size, train=True):
    def loss_fn(params):
        # Use targets for teacher forcing if training
        targets = target if train else None
        logits, new_carries = state.apply_fn(
            params, x=input, carry=carries, targets=targets
        )

        # One-hot encode the target with the correct vocabulary size
        one_hot = jax.nn.one_hot(target, vocab_size)
        loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
        return loss, (logits, new_carries)

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, (logits, new_carries)), grads = grad_fn(state.params)

    # Ensure accuracy calculation matches the target structure
    accuracy = jnp.mean(jnp.argmax(logits, -1) == target)
    return grads, loss, accuracy, jnp.argmax(logits, axis=-1)

if JIT:
    apply_model = partial(jax.jit, static_argnames=["vocab_size"])(apply_model)

In [47]:
def update_model(state, grad):
    return state.apply_gradients(grads=grad)

if JIT:
    update_model = jax.jit(update_model)

In [48]:
def train_epoch(rng, model, train_ds, state, vocab_size, batch_size=64):
    """Train for a single epoch."""
    train_x, train_y = train_ds
    rng, permute_rng = jax.random.split(rng, 2)
    data_size = len(train_x)
    steps_per_epoch = data_size // batch_size
    perms = jax.random.permutation(subkeys[0], data_size)
    perms = perms[: steps_per_epoch * batch_size]  # skip incomplete batch
    perms = perms.reshape((steps_per_epoch, batch_size))
    epoch_loss = []
    epoch_accuracy = []
    for perm in tqdm(perms):
        input = train_x[perm, ...]
        target = train_y[perm, ...]
        rng, carry_rng = jax.random.split(rng, 2)
        carries = model.initialize_carries(
            rng=carry_rng, batch_size=batch_size, hidden_sizes=[N, N]
        )
        # print('Apply model')
        grads, loss, accuracy, _ = apply_model(
            state=state,
            carries=carries,
            input=input,
            target=target,
            vocab_size=vocab_size,
        )
        # print('Update model')
        state = update_model(state, grads)
        epoch_loss.append(loss)
        epoch_accuracy.append(accuracy)
    train_loss = jnp.mean(jnp.array(epoch_loss))
    train_accuracy = jnp.mean(jnp.array(epoch_accuracy))
    return state, train_loss, train_accuracy

In [49]:
def train(
    model,
    epochs,
    train_data,
    test_data,
    block_size,
    vocab_size,
    learning_rate=0.001,
    batch_size=64,
):
    # print('Checkpoints')
    # # Define the directory where checkpoints will be stored
    # checkpoint_dir = Path("/tmp/my_checkpoints")
    # options = ocp.CheckpointManagerOptions(max_to_keep=3, save_interval_steps=2)
    # checkpoint_manager = ocp.CheckpointManager(
    #     checkpoint_dir, {"model_state": ocp.PyTreeCheckpointer()}, options=options
    # )

    # Create the dataset
    print('Creating dataset')
    train_x, train_y = batch_data(data=train_data, block_size=block_size)
    test_x, test_y = batch_data(data=test_data, block_size=block_size)
    train_x_size = len(train_x)
    steps_per_epoch = train_x_size // batch_size
    decay_steps = steps_per_epoch * epochs

    # Intialize the scheduler
    print('Initializing scheduler, optimizer, model state')
    schedule = optax.warmup_cosine_decay_schedule(
        init_value=2e-3,
        peak_value=2e-5,
        warmup_steps=10,
        decay_steps=decay_steps,
        end_value=0.0,
    )

    # Initialization of the optimizer
    optimizer = optax.adam(learning_rate)

    # Initialize the model state
    carries = model.initialize_carries(
        rng=subkeys[7], batch_size=batch_size, hidden_sizes=[N, N]
    )

    print('Checkpoints')
    state = None
    starting_epoch = 0
    if False:  # checkpoint_dir.exists():
        lastest_step = checkpoint_manager.latest_step()
        if lastest_step is not None:
            # Restore the latest checkpoint
            restored = checkpoint_manager.restore(lastest_step)
            if restored is not None:
                # Unpack the checkpointed state and set the starting epoch
                state, starting_epoch = restored["model_state"], restored["epoch"]
                print(f"Resuming training from epoch {starting_epoch}")
            else:
                raise ValueError(
                    f"restored is None valued despite the checkpoint directory existing"
                )

    else:
        # checkpoint_dir.mkdir(
        #     parents=True, exist_ok=True
        # )  # Ensure the checkpoint directory is created

        # Initialize the model parameters
        params = model.init(
            subkeys[1],
            x=jnp.ones((batch_size, block_size), dtype=jnp.int32),
            carry=carries,
            targets=None,
        )

        state = train_state.TrainState.create(
            apply_fn=model.apply, params=params, tx=optimizer
        )
        starting_epoch = 0
        print("Starting training from scratch")

    # state = None
    # starting_epoch = 0

    print('Epochs')
    rng = subkeys[7]
    for epoch in range(starting_epoch, epochs):
        print(f'Training epoch {epoch} / {epochs}')

        rng, input_rng, carry_rng, test_rng = jax.random.split(rng, 4)
        state, loss, accuracy = train_epoch(
            rng=input_rng,
            model=model,
            train_ds=(train_x, train_y),
            state=state,
            vocab_size=vocab_size,
            batch_size=batch_size,
        )

        print('Testing epoch {epoch} / {epochs}')

        indices = jax.random.randint(
            test_rng,
            shape=(batch_size,),
            minval=0,
            maxval=(test_x.shape[0] - 1),
        )
        test_input = test_x[indices, ...]
        test_target = test_y[indices, ...]
        carries = model.initialize_carries(
            rng=carry_rng, batch_size=batch_size, hidden_sizes=[N, N]
        )
        _, test_loss, test_accuracy, test_logits = apply_model(
            state=state,
            carries=carries,
            input=test_input,
            target=test_target,
            vocab_size=vocab_size,
        )
        print(
            f"Epoch: {epoch}\n\tTrain Loss: {loss}\n\tTrain Accuracy: {accuracy}\n\tTest Loss: {test_loss}\n\tTest Accuracy: {test_accuracy}\n\n"
        )
        print(f"Sample Model Text Output:")
        print(f"------------------------------------------------------------")
        print(f"{decode(test_logits[0].tolist())}")
        print(f"------------------------------------------------------------")
        # checkpoint_manager.save(epoch, {"model_state": state})

In [50]:
# hippo_list = [
#     {
#         "name": "LegS-LSI",
#         "use": False,
#         "val": "legs",
#     },
#     {
#         "name": "LegS-LTI",
#         "use": True,
#         "val": "legs",
#     },
#     {
#         "name": "LegT",
#         "use": False,
#         "val": "legt",
#     },
#     {
#         "name": "LMU",
#         "use": False,
#         "val": "lmu",
#     },
#     {
#         "name": "LagT",
#         "use": False,
#         "val": "lagt",
#     },
#     {
#         "name": "FRU",
#         "use": False,
#         "val": "fru",
#     },
#     {
#         "name": "FouT",
#         "use": False,
#         "val": "fout",
#     },
#     {
#         "name": "FouD",
#         "use": False,
#         "val": "foud",
#     },
#     {
#         "name": "ChebT",
#         "use": False,
#         "val": "chebt",
#     },
# ]

In [51]:
# discretization = [
#     {
#         "name": "Forward-Euler",
#         "use": False,
#         "val": 0.0,
#     },
#     {
#         "name": "Backward-Euler",
#         "use": False,
#         "val": 1.0,
#     },
#     {
#         "name": "Bilinear",
#         "use": False,
#         "val": 0.5,
#     },
#     {
#         "name": "Zero-Order Hold",
#         "use": True,
#         "val": 2.0,
#     },
# ]

In [52]:
cell_list = [BatchedGatedHiPPOCell, BatchedGatedHiPPOCell]

In [53]:
class Discretizations:
    FORWARD_EULER, BACKWARD_EULER, BILINEAR, ZOH = 0.0, 1.0, 0.5, 2.0

print('Initializing HiPPO-RNN with: LegS-LTI and ZOH...')
model = initialize_rnn(cell_list=cell_list, method='legs', alpha=Discretizations.ZOH)
print('... Done!')

Initializing HiPPO-RNN with: LegS-LTI and ZOH...
... Done!


In [None]:
train(
    model=model,
    epochs=epochs,
    train_data=train_data,
    test_data=val_data,
    block_size=_block_size,
    vocab_size=vocab_size,
    learning_rate=lr,
    batch_size=batch_size,
)


Creating dataset
Initializing scheduler, optimizer, model state
Checkpoints
Starting training from scratch
Epochs
Training epoch 0 / 10


  0%|                                                                                                                                                               | 0/30 [00:00<?, ?it/s]

In [None]:
# for hippo in hippo_list:
#     if hippo["use"]:
#         alpha = -1.0
#         for a in discretization:
#             if a["use"]:
#                 alpha = a["val"]
#         print(f"Running HiPPO-RNN with {hippo['name']} and {alpha}")
#         model = initialize_rnn(cell_list=cell_list, method=hippo["val"], alpha=alpha)
#         train(
#             model=model,
#             epochs=epochs,
#             train_data=train_data,
#             test_data=val_data,
#             block_size=_block_size,
#             vocab_size=vocab_size,
#             learning_rate=lr,
#             batch_size=batch_size,
#         )
#         break