# Imports

In [1]:
# Basic Imports
try:
    from tqdm.auto import tqdm
except ModuleNotFoundError:
    %pip install --quiet tqdm
    from tqdm.auto import tqdm

try:
    from datasets import load_dataset
    from tokenizers import Tokenizer
    from tokenizers.models import BPE
    from tokenizers.trainers import BpeTrainer
    from tokenizers.pre_tokenizers import Whitespace
except ModuleNotFoundError:
    %pip install --quiet datasets
    %pip install --quiet tokenizers
    from datasets import load_dataset
    from tokenizers import Tokenizer
    from tokenizers.models import BPE
    from tokenizers.trainers import BpeTrainer
    from tokenizers.pre_tokenizers import Whitespace

try:
    import flax
    from flax import linen as nn
    from flax.training import train_state
    from flax.core import freeze, unfreeze
    from flax.linen.initializers import zeros, normal
except ModuleNotFoundError:
    %pip install --quiet flax
    import flax
    from flax import linen as nn
    from flax.training import train_state
    from flax.core import freeze, unfreeze
    from flax.linen.initializers import zeros, normal

try:
    import optax
except ModuleNotFoundError:
    %pip install --quiet optax
    import optax

try:
    from pythomata import SimpleDFA
except ModuleNotFoundError:
    %pip install --quiet pythomata
    from pythomata import SimpleDFA
try:
    from ml_collections import config_dict
except ModuleNotFoundError:
    %pip install --quiet ml_collections
    from ml_collections import config_dict


import os
import math
import copy
import numpy as np
import time
import pickle
import abc
import re
from functools import partial
import jax
import torch
from collections import Counter
from jax import  random, numpy as jnp, lax, vmap, jit, grad, value_and_grad, lib
from jax import random, vmap, pmap, custom_vjp, config as cfg, nn as jax_nn
from typing import List, Optional, Callable, Tuple, Dict, Any, Set, overload
import matplotlib.pyplot as plt
from matplotlib.colors import TwoSlopeNorm
from typing import Dict, Tuple, Any, Callable, NamedTuple, Optional, Literal
import chex

jax.devices()

[CudaDevice(id=0)]

# Configs

In [2]:
def get_experiment_config(seeds:List[int]) -> config_dict.ConfigDict:
    '''Assemble the top-level experiment configuration.

    Args:
        seeds: Random seeds, stored unchanged under the `seeds` field.

    Returns:
        A ConfigDict with the fields `seeds`, `data`, `optim` and `experiment`.
    '''
    root = config_dict.ConfigDict()
    root.seeds = seeds

    # Every remaining field is produced by its dedicated configurator function.
    sub_config_builders = (
        ('data', data_configurator),              # Data Configurations
        ('optim', optim_configurator),            # Optimizer Configurations
        ('experiment', experiment_configurator),  # Specific Experiment Configurations
    )
    for field_name, build in sub_config_builders:
        setattr(root, field_name, build())
    return root

def get_model_config(use_depth: int, use_gla: bool) -> config_dict.ConfigDict:
    '''Returns the model configuration for the given model description.

    Args:
        use_depth: Stored as `use_depth` (deep vs. standard delta-rule setting
            consumed by the attention layers).
        use_gla: Stored as `use_gla` (whether the blocks should use gated linear
            attention). Previously this argument was silently ignored and the
            field was always `False`.

    Returns:
        A ConfigDict holding all model hyper-parameters.
    '''

    # you can easily either modify this config or just add new builder functions

    def _default_model_config(use_depth: int, use_gla: bool) -> config_dict.ConfigDict:
        '''Fully Observed Data Constructed Transformer (Interpolation)'''
        model_config = config_dict.ConfigDict()
        model_config.use_depth = use_depth
        model_config.is_discrete = True
        model_config.use_gla = use_gla  # honour the caller's choice instead of hard-coding False
        model_config.vocab_size = 20
        model_config.use_fwp=True       # use fast weight programmer implementation
        model_config.range_dfwp=1  # (DEPRECATED) How far to propagate kv pairs within model
        model_config.use_emb=True       # use linear embedding layer (default: False when training on constructed tokens)
        model_config.use_pe_emb=False    # concatenate PE to embeddings (default: False when training on constructed tokens)
        model_config.use_pe_kq=False    # concatenate PE to K and Q in attention (default: False when training on constructed tokens)
        model_config.hybrid_first_block=False   # This adds an additional first layer to the model with possibly different settings. Also: If you want to use pe-kq, you need to set this to True
        model_config.pe_dim=30          # positional encoding dimension (concatenated, if you want to change that, modify src/models/positional_encodings.py)
        model_config.out_dim=20         # output dimension
        # NOTE(review): the attention layers in this file index the initializer like
        # `initializer['attention'](layer_idx)[name]`, which a plain `normal(...)`
        # initializer does not support - confirm where this field gets converted.
        model_config.initializer=jax_nn.initializers.normal(stddev=0.02)  # initializer for weights, in our experiments it wasn't necessary to adapt for deeper layers but instead use this as default
        model_config.use_layernorm=True    # use layernorm in transformer layers
        model_config.use_bias=False     # use bias in transformer layers (default = False)
        model_config.use_mlp=True      # use mlp in transformer layers
        model_config.masked=True        # causal masking in self-attention
        model_config.use_clip=False      # forward activation clipping
        model_config.clip_range=3       # clipping value [-clip_val, clip_val]
        model_config.num_layers=3       # number of transformer layers (without the optional first hybrid layer)
        model_config.num_heads=4        # number of heads in self-attention
        model_config.embed_dim=64       # embedding dimension
        model_config.key_size=20        # key size in self-attention
        model_config.seq_len=1024         # sequence length of input sequences (unused field in TF, can be used for debugging)
        model_config.dim_feedforward_MLP=256    # dimension of hidden layer in MLP
        model_config.linear=True        # use linear self-attention
        model_config.use_schlagnorm=True   # use schlagnorm in transformer layers (normalize K,Q)
        model_config.schlagnorm_targets=True   # schlagnorm also for Values
        return model_config

    return _default_model_config(use_depth=use_depth, use_gla=use_gla)


def experiment_configurator() -> config_dict.ConfigDict:
    '''Create the experiment-specific configuration (currently an empty ConfigDict).'''
    return config_dict.ConfigDict()


def optim_configurator() -> config_dict.ConfigDict:
    '''Returns the optimizer configuration for the experiment.

    Returns:
        A ConfigDict with learning-rate, schedule, clipping and weight-decay settings.
    '''
    settings = {
        # Original author note (presumably about the learning rate):
        # "For nonlin-tf use 1e-3, for mesa 4e-4".
        'peak_lr': 1e-4,        # Peak learning rate (or fixed if no scheduling)
        'grad_clip': 1,         # Gradient clipping value
        'use_schedule': True,   # Use learning rate scheduling
        'warmup_steps': 1000,   # Warmup steps for learning rate scheduling
        'max_iters': 20000,     # Maximum number of iterations for scheduling
        'init_value': 0,        # Initial learning rate
        'end_value': 3e-5,      # Final learning rate (at max_iters train steps)
        'weight_decay': 0.05,   # Weight decay
    }

    optim = config_dict.ConfigDict()
    for field_name, value in settings.items():
        setattr(optim, field_name, value)
    return optim

def data_configurator() -> config_dict.ConfigDict:
    '''Build the data configuration for the experiment.

    Returns:
        A ConfigDict describing batch sizes, sequence/token dimensions,
        noise levels and the task type.
    '''
    settings = {
        'token_format': 'compact',
        'batch_size': 32,        # Training batch size
        'test_batch_size': 32,   # Test batch size
        'seq_len': 1024,         # Sequence length
        'data_dim': 10,          # Hidden Data dimension
        'vocab_size': 20,        # Vocabulary size
        'obs_dim': 20,           # Observed data dimension
        'noise_obs': 0,          # Observation noise level, only relevant for partially observable sequences
        'noise': 0.01,           # Noise level
        'eye_obs': True,         # Use identity matrix as observation matrix
        'data_clip': 10,         # Clip data to [-data_clip, data_clip] (relevant for contracting sequences)
        'range': 1,              # Range of initial values (U-distr.)
        'construction': True,    # Use constructed tokens
        # 'seq_lin_constr' for linear sequences with constructed tokens,
        # 'seq_lin' otherwise, 'regbench' for regbench.
        'ttype': 'regbench',
        # 'discrete' for regbench/language tasks, otherwise 'continous'
        # (spelling as in the original author note).
        'task_type': 'discrete',
        'embed_dim': 256,        # Embedding dimension (for constructed tokens)
    }

    data = config_dict.ConfigDict()
    for field_name, value in settings.items():
        setattr(data, field_name, value)
    return data

# Architecture and Components

## (Oldschool) Cosine-wave based Positional Encoding

In [3]:
class PositionalEncoding(nn.Module):
    '''Sinusoidal positional encoding for Transformer models.

    The table is computed once in `setup` and is either added to the input or
    concatenated to it along the feature axis.

    Fields:
        'pe_dim' (int): Dimension of the positional encoding
        'max_len' (int): Maximum length of input sequences
        'concat' (bool): If True, concatenate the encoding to the features
            instead of adding it
    '''
    pe_dim : int
    max_len : int = 10
    concat: bool = False

    def concat_single(self, x: chex.Array, pe: chex.Array) -> chex.Array:
        '''Concatenates the positional encoding to a single (unbatched) input along the last axis.'''
        return jnp.concatenate([x, pe], axis=-1)

    def concat_batch(self, x: chex.Array, pe: chex.Array) -> chex.Array:
        '''Concatenates the positional encoding to every element of a batch (leading axis of `x`).'''
        concat_one = partial(self.concat_single, pe=pe)
        return vmap(concat_one)(x)

    def setup(self):
        '''Precomputes the (max_len, pe_dim) sine/cosine table.'''
        enc_size = self.pe_dim
        pe = np.zeros((self.max_len, enc_size))
        position = np.arange(0, self.max_len, dtype=np.float32)[:,None]
        div_term = np.exp(np.arange(0, enc_size, 2) * (-math.log(10000.0) / enc_size))
        # Even feature columns carry the sine, odd columns the cosine.
        pe[:, 0::2] = np.sin(position * div_term)
        if enc_size % 2 == 1:
            # For odd sizes there is one fewer odd column than even columns.
            pe[:, 1::2] = np.cos(position * div_term)[:,:-1]
        else:
            pe[:, 1::2] = np.cos(position * div_term)
        self.pe = jax.device_put(pe)

    def __call__(self, x: chex.Array) -> chex.Array:
        '''Adds (or concatenates) the positional encoding to the input tensor.

        Args:
            x: Input whose second-to-last axis is the sequence axis. In concat
                mode the leading axis is treated as the batch axis.

        Returns:
            `x` with the encoding added, or with `pe_dim` extra features when
            `concat` is True.
        '''
        # Only use the first `seq_len` positions so that inputs shorter than
        # `max_len` work (previously they failed to broadcast, or silently
        # broadcast to `max_len` for length-1 inputs). Inputs longer than
        # `max_len` still raise a shape error.
        pe = self.pe[:x.shape[-2]]
        if self.concat:
            return self.concat_batch(x, pe)
        return x + pe

## Conv1d Layer as in [RecurrentGemma](https://github.com/google-deepmind/recurrentgemma/blob/main/recurrentgemma/jax/layers.py)

In [4]:
class Conv1D(nn.Module):
    """1D temporal convolution layer with caching support for autoregressive use.

    The convolution is applied per channel: every channel has its own
    `temporal_width` tap weights and there is no mixing across channels
    (the taps are multiplied element-wise with the shifted input).

    Attributes:
        width: Number of input/output channels
        temporal_width: Size of temporal receptive field
        w_init_variance_scale: Scale factor for weight initialization variance
    """
    width: int
    temporal_width: int
    w_init_variance_scale: float = 0.01

    def setup(self):
        """Create the tap weights `w` [temporal_width, width] and the bias `b` [width]."""
        self.w = self.param(
            "w",
            nn.initializers.variance_scaling(
                scale=self.w_init_variance_scale,
                mode="fan_in",
                distribution="normal",
            ),
            (self.temporal_width, self.width)
        )
        self.b = self.param("b", nn.initializers.zeros_init(), (self.width,))

    def __call__(
        self,
        x: chex.Array,
        segment_pos: chex.Array,
        cache: Optional[chex.Array] = None,
        return_cache: bool = False,
    ) -> Tuple[chex.Array, Optional[chex.Array]]:
        """Apply temporal convolution with optional caching for autoregressive generation.

        Args:
            x: Input of shape [batch_size, seq_len, width]
            segment_pos: Position indicators of shape [batch_size, seq_len]
                       (0 indicates start of new segment). Only read when
                       `cache` is None. Note that an all-zero `segment_pos`
                       marks every token as a segment start, which masks out
                       every look-back tap (only the current-step tap remains).
            cache: Optional cache of previous states for autoregressive generation
                  shape: [batch_size, temporal_width-1, width]
            return_cache: Whether to return the updated cache

        Returns:
            output: Convolved sequence of shape [batch_size, seq_len, width]
            new_cache: Updated cache if return_cache=True, else None
        """
        # Taken before the cache is prepended, so it counts only the new steps.
        output_len = x.shape[1]

        if cache is not None:
            # Autoregressive generation mode - use cached states
            x = self._concat_cache(x, cache)
            prompt_len = self.temporal_width - 1
            state_dtype = cache.dtype
        else:
            # Training mode - process full sequence
            prompt_len = 0
            state_dtype = x.dtype

        # Perform convolution with causal masking
        conv_out = 0.0
        # Never look back further than the number of available time steps.
        temporal_width = min(self.temporal_width, prompt_len + output_len)

        # `t` is the number of steps we look back; each iteration adds one tap.
        for t in range(temporal_width):
            start_idx, end_idx = self._get_window_indices(
                prompt_len=prompt_len,
                shift_back=t,
                output_len=output_len,
            )
            x_window = x[:, start_idx:end_idx]

            if cache is None:
                # Apply segment masking in training mode
                window_mask = self._get_segment_mask(
                    segment_pos=segment_pos,
                    start_idx=start_idx,
                    end_idx=end_idx,
                    max_look_ahead=t,
                )
                x_window *= window_mask[:, :, None]

            # Pad if needed and apply convolution weights
            x_window = self._pad_to_length(x_window, output_len)
            # The last row of `w` is the tap for the current step (t == 0).
            w_t = self.w[self.temporal_width - t - 1][None, None, :]
            conv_out += x_window * w_t

        conv_out += self.b[None, None]

        if not return_cache:
            return conv_out, None

        # Update cache for next step
        # Keeps the last `temporal_width - 1` time steps of `x`.
        # NOTE(review): for temporal_width == 1 the slice start is 0, so the
        # whole of `x` is kept rather than an empty cache - confirm intent.
        new_cache = x[:, 1 - self.temporal_width:].astype(state_dtype)
        new_cache = self._pad_cache(new_cache)

        return conv_out, new_cache

    def _concat_cache(self, x: chex.Array, cache: chex.Array) -> chex.Array:
        """Concatenate current input with cache for autoregressive generation.

        Asserts that `x` holds exactly one time step and that `cache` has shape
        [batch_size, temporal_width-1, width]; the cache is cast to `x.dtype`.
        """
        chex.assert_shape(cache, (x.shape[0], self.temporal_width - 1, self.width))
        chex.assert_shape(x, (None, 1, self.width))
        return jnp.concatenate([cache.astype(x.dtype), x], axis=1)

    def _get_window_indices(
        self, prompt_len: int, shift_back: int, output_len: int
    ) -> Tuple[int, int]:
        """Get start and end indices for convolution window.

        The window is the input slice that contributes to the outputs when
        looking `shift_back` steps into the past; the start is clamped at 0.
        """
        start_idx = max(prompt_len - shift_back, 0)
        end_idx = prompt_len + output_len - shift_back
        return start_idx, end_idx

    def _get_segment_mask(
        self, segment_pos: chex.Array, start_idx: int, end_idx: int, max_look_ahead: int
    ) -> chex.Array:
        """Create mask to prevent information flow between segments.

        A window position is kept (mask 1) only if none of the following
        `max_look_ahead` positions has `segment_pos == 0`, i.e. no segment
        start lies between the source step and the output step it feeds.
        """
        batch_size = segment_pos.shape[0]
        not_boundary = (segment_pos != 0).astype(jnp.int32)
        mask = jnp.ones((batch_size, end_idx - start_idx))

        for shift in range(1, max_look_ahead + 1):
            mask *= not_boundary[:, start_idx + shift:end_idx + shift]
        return mask

    def _pad_to_length(self, x: chex.Array, target_len: int) -> chex.Array:
        """Left-pad input to target length if needed."""
        pad_len = target_len - x.shape[1]
        if pad_len <= 0:
            return x
        padding = jnp.zeros((x.shape[0], pad_len, x.shape[2]), dtype=x.dtype)
        return jnp.concatenate([padding, x], axis=1)

    def _pad_cache(self, state: chex.Array) -> chex.Array:
        """Left-pad cache to temporal width if needed.

        The target length is `temporal_width - 1` time steps.
        """
        pad_len = self.temporal_width - state.shape[1] - 1
        if pad_len <= 0:
            return state
        padding = jnp.zeros((state.shape[0], pad_len, state.shape[2]), dtype=state.dtype)
        return jnp.concatenate([padding, state], axis=1)

    @classmethod
    def init_cache(
        cls, batch_size: int, width: int, dtype: jnp.dtype, temporal_width: int = 4
    ) -> chex.Array:
        """Initialize an empty cache for autoregressive generation.

        Returns zeros of shape [batch_size, temporal_width-1, width].
        """
        return jnp.zeros((batch_size, temporal_width - 1, width), dtype=dtype)

## Attention

In [5]:
class MultiHeadAttention(nn.Module):
    '''Multi-Head Attention Layer with a short temporal convolution on the keys.

    Fields:
        'masked' (bool): Flag whether to use masked attention
        'embed_dim' (int): Embedding dimension
        'num_heads' (int): Number of heads
        'use_softmax' (bool): Flag whether to use softmax
        'use_bias' (bool): Flag whether to use bias
        'key_size' (int): Key size
        'initializer' Any: Indexed as `initializer['attention'](layer_idx)[name]`
            to obtain the kernel initializer of each projection
        'seq_len' (int): Sequence length (used as `max_len` of the K/Q positional encoding)
        'layer_idx' (int): Index of this layer, forwarded to the initializer lookup
        'use_pe_kq' (bool): Flag whether to use (special) positional encoding in the query and key vectors
        'use_schlagnorm' (bool): Flag whether to L2-normalize queries and keys
        'schlagnorm_targets' (bool): Flag whether to also L2-normalize the values
    '''
    masked : bool
    embed_dim : int
    num_heads : int
    use_softmax : bool
    use_bias : bool
    key_size : int
    initializer : Any
    seq_len : int
    layer_idx: int
    use_pe_kq : bool = False
    use_schlagnorm : bool = False
    schlagnorm_targets : bool = False

    def scaled_dot_product_attention(
        self,
        q: chex.Array,
        k: chex.Array,
        v: chex.Array,
        use_softmax: bool,
        masked=True
    ) -> Tuple[chex.Array, chex.Array]:
        '''Scaled Dot-Product Attention.

        Args:
            'q' (chex.Array): Query projections [batch, heads, seq_len, key_size]
            'k' (chex.Array): Key projections [batch, heads, seq_len, key_size]
            'v' (chex.Array): Value projections [batch, heads, seq_len, key_size]
            'use_softmax' (bool): Flag whether to use softmax-attention
            'masked' (bool): Flag whether to use causally masked attention
        Returns:
            Tuple[chex.Array, chex.Array]: Attention values and the attention
            weights (post-softmax when `use_softmax`, raw masked logits otherwise).
        '''
        d_k = q.shape[-1]

        attn_logits = jnp.einsum('bhsk,bhtk->bhst', q, k)

        if use_softmax:
            # Scale whenever softmax is used. Previously the scaling was only
            # applied inside the masked branch, so unmasked softmax attention
            # ran on unscaled logits.
            attn_logits = attn_logits / math.sqrt(d_k)

        if masked:
            mask = jnp.tril(jnp.ones(shape=(attn_logits.shape[-2:])))[None, None, :, :]
            # Future positions get a large negative logit under softmax and
            # exactly zero weight for linear attention.
            masked_fill = -1e30 if use_softmax else 0
            attn_logits = jnp.where(mask == 0, masked_fill, attn_logits)

        if use_softmax:
            attn_logits = nn.softmax(attn_logits, axis=-1)

        values = jnp.einsum('bhst,bhtd->bhsd', attn_logits, v)

        return values, attn_logits

    def setup(self):
        '''Initializes the Multi-Head Attention Layer with the specified parameters.'''
        def _create_proj(
            name: str,
            num_heads: int,
            dim: int,
        ) -> nn.Module:
            '''Auxiliary function to create projection layers.

            `num_heads > 0` creates a per-head projection of the last axis;
            `num_heads == 0` creates the output projection that merges the
            (heads, features) axes into `dim` features.
            '''
            return nn.DenseGeneral(
                    features=(
                        (num_heads, dim) if num_heads > 0 else dim
                    ),
                    axis=(
                        -1 if num_heads > 0 else (-1, -2)
                    ),
                    use_bias=self.use_bias,
                    kernel_init=self.initializer['attention'](self.layer_idx)[name],
            )

        proj_dim = self.num_heads, self.key_size
        proj_specs = {
            'queries': proj_dim,
            'keys': proj_dim,
            'values': proj_dim,
            'outs': (0, self.embed_dim),
        }

        for proj_name, output_dim in proj_specs.items():
            setattr(self, proj_name, _create_proj(proj_name, *output_dim))

        self.key_conv = Conv1D(
            width=self.key_size,
            temporal_width=4, #TODO: Make field
            w_init_variance_scale=0.01, #TODO: Make field
            name='conv'
        )

        if self.use_pe_kq:
            self.pos_enc_kq = PositionalEncoding(pe_dim=self.embed_dim, max_len=self.seq_len, concat=True)

    def __call__(
        self, x: chex.Array, outs_prev: chex.Array
    ) -> Tuple[chex.Array, chex.Array, None, None]:
        '''Applies the Multi-Head Attention Layer to the input tensor.

        Args:
            x: Input of shape [batch, seq_len, features].
            outs_prev: Unused; kept so all attention layers share one call signature.

        Returns:
            A 4-tuple `(output, attention_map, None, None)`.
        '''
        # Keys and queries optionally see the positional encoding; values never do.
        kq_input = self.pos_enc_kq(x) if self.use_pe_kq else x
        k = self.keys(kq_input)
        q = self.queries(kq_input)
        v = self.values(x)

        # Run the temporal convolution over the keys of every (batch, head) pair.
        batch_size, seq_len = k.shape[0], k.shape[1]
        keys_reshaped = k.transpose(0, 2, 1, 3)  # [batch, heads, seq, key_size]
        keys_flat = keys_reshaped.reshape(-1, seq_len, self.key_size)  # [batch*heads, seq, key_size]
        # Conv1D treats `segment_pos == 0` as the start of a new segment and
        # masks every tap that would look across such a boundary. Passing all
        # zeros (as before) masked out every look-back tap, reducing the
        # convolution to a per-channel scale + bias. Real positions 0..seq-1
        # mark only the first token as a boundary.
        segment_pos = jnp.broadcast_to(
            jnp.arange(seq_len), (keys_flat.shape[0], seq_len)
        )
        keys_conv, _ = self.key_conv(keys_flat, segment_pos)
        keys = keys_conv.reshape(batch_size, self.num_heads, seq_len, self.key_size)
        k = keys.transpose(0, 2, 1, 3)  # [batch, seq, heads, key_size]

        q,k,v = [jnp.einsum('bshd->bhsd', _elem) for _elem in [q,k,v]]

        if self.use_schlagnorm:
            # L2-normalize along the feature axis; the epsilon avoids division by zero.
            q = q / (1e-16 + jnp.linalg.norm(q, axis=-1)[..., None])
            k = k / (1e-16 + jnp.linalg.norm(k, axis=-1)[..., None])
            if self.schlagnorm_targets:
                v = v / (1e-16 + jnp.linalg.norm(v, axis=-1)[..., None])

        values, attn_map = self.scaled_dot_product_attention(
            q=q, k=k, v=v, masked=self.masked, use_softmax=self.use_softmax
        )

        values = jnp.einsum('bhsd->bshd', values)
        o = self.outs(values)

        return o, attn_map, None, None

## Mesa-Layer

In [6]:
class DeepDeltaAttention(nn.Module):
    '''Linear Deep-Delta Fast Weight Programmer Attention Layer.

    NOTE(review): in this version of the file the fast-weight update itself is
    missing - `__call__` builds the projections, the key convolution and the
    gates and then returns the placeholder value `-1`.

    Fields:
        'use_depth' (bool) Deep or standard Delta Rule
        'masked' (bool): Flag whether to use masked attention
        'embed_dim' (int): Embedding dimension
        'num_heads' (int): Number of heads
        'use_softmax' (bool): Flag whether to use softmax (ignored in fast weights)
        'use_bias' (bool): Flag whether to use bias
        'key_size' (int): Key size
        'initializer' (Callable): Indexed as `initializer['attention'](layer_idx)[name]`
            to obtain the kernel initializer of each projection
        'seq_len' (int): Sequence length
        'range_dfwp' (int): Not read by this layer
        'layer_idx' (int): Index of this layer, forwarded to the initializer lookup
        'num_layers' (int) num of total model layers
        'use_pe_kq' (bool): Flag whether to use (special) positional encoding
        'use_schlagnorm' (bool): Flag for schlag normalization
        'schlagnorm_targets' (bool): Flag for schlag normalization targets
    '''
    use_depth: bool # Not read in the visible implementation
    masked: bool # Not read in the visible implementation
    embed_dim: int # Not read in the visible implementation
    num_heads: int
    use_softmax: bool # Unused
    use_bias: bool
    key_size: int
    initializer: Callable
    seq_len: int # Not read in the visible implementation
    range_dfwp: int # Unused
    layer_idx: int # Used for the initializer lookup in `_get_projections`
    num_layers: int # Unused
    use_pe_kq: bool = False # Unused
    use_schlagnorm: bool = False
    schlagnorm_targets: bool = False

    def _get_projections(self, inputs: chex.Array, name: str, dim: int) -> chex.Array:
        '''Returns the per-head projection `name` of the input tensor.

        The last axis of `inputs` is mapped to `(num_heads, dim)` features.
        '''
        return nn.DenseGeneral(
            features=(self.num_heads, dim),
            axis=-1,
            use_bias=self.use_bias,
            kernel_init=self.initializer['attention'](self.layer_idx)[name],
            name=name
        )(inputs)

    @nn.compact
    def __call__(self, x, outs_prev, mask=None, deterministic=None):
        """Builds the inputs of the deep-delta fast weight update.

        Args:
            x: Query input of shape [batch, seq_length, embed_dim]
            outs_prev: Prev-FW-Outputs-list of shape [batch, n_heads, seq_length, embed_dim]
                (not read in the visible implementation)
            mask: Unused (since exists by construction) attention mask
            deterministic: unused (kept for compatibility)

        Returns:
            A 4-tuple `(output, None, None, None)`. `output` is currently the
            placeholder `-1` (see the class docstring).
        """

        queries, keys, values = [
            self._get_projections(
                inputs=x,
                name=n,
                dim=self.key_size,
            ) for n in ['queries', 'keys', 'values']
        ]

        key_conv = Conv1D(
            width=self.key_size,
            temporal_width=4, #TODO: Make field
            w_init_variance_scale=0.01, #TODO: Make field
            name='key_conv'
        )

        # Run the temporal convolution over the keys of every (batch, head) pair.
        batch_size, seq_len_c = keys.shape[0], keys.shape[1]
        keys_reshaped = keys.transpose(0, 2, 1, 3)  # [batch, heads, seq, key_size]
        keys_flat = keys_reshaped.reshape(-1, seq_len_c, self.key_size)  # [batch*heads, seq, key_size]
        # Conv1D treats `segment_pos == 0` as the start of a new segment and
        # masks every tap that would look across such a boundary. Passing all
        # zeros (as before) masked out every look-back tap; real positions
        # 0..seq-1 mark only the first token as a boundary.
        segment_pos = jnp.broadcast_to(
            jnp.arange(seq_len_c), (keys_flat.shape[0], seq_len_c)
        )
        keys_conv, _ = key_conv(keys_flat, segment_pos)
        keys = keys_conv.reshape(batch_size, self.num_heads, seq_len_c, self.key_size)
        keys = keys.transpose(0, 2, 1, 3)  # [batch, seq, heads, key_size]

        # One sigmoid gate per head and time step: [batch, seq, heads, 1].
        gammas = jax.nn.sigmoid(self._get_projections(
                inputs=x,
                name='gammas',
                dim=1,
        ))

        if self.use_schlagnorm:
            # L2-normalize along the feature axis; the epsilon avoids division by zero.
            queries = queries / (1e-16 + jnp.linalg.norm(queries, axis=-1, keepdims=True))
            keys = keys / (1e-16 + jnp.linalg.norm(keys, axis=-1, keepdims=True))
            if self.schlagnorm_targets:
                values = values / (1e-16 + jnp.linalg.norm(values, axis=-1,  keepdims=True))

        # TODO(review): the delta-rule update that should combine `queries`,
        # `keys`, `values` and `gammas` is not present in this file. The
        # placeholder below is kept unchanged so existing callers behave as
        # before; it must be replaced before this layer is used for training.
        output = -1

        return output, None, None, None

## Gated Linear Attention

In [8]:
class GatedLinearAttention(nn.Module):
    '''Gated linear (fast weight) attention layer.

    Per head, a matrix state is updated recurrently as
    `S_t = beta_t * (S_{t-1} + v_t k_t^T)` and read out with `S_t q_t`,
    where `beta_t` is a sigmoid gate computed from the input.

    Fields:
        'masked' (bool): Flag whether to use masked attention (not read; the
            recurrence is causal by construction)
        'embed_dim' (int): Embedding dimension
        'num_heads' (int): Number of heads
        'use_softmax' (bool): Flag whether to use softmax (ignored in fast weights)
        'use_bias' (bool): Flag whether to use bias
        'key_size' (int): Key size
        'initializer' (Callable): Indexed as `initializer['attention'](layer_idx)[name]`
            to obtain the kernel initializer of each projection
        'seq_len' (int): Sequence length
        'range_dfwp' (int): Not read by this layer
        'layer_idx' (int): Index of this layer, forwarded to the initializer lookup
        'num_layers' (int): Not read by this layer
        'use_pe_kq' (bool): Flag whether to use (special) positional encoding
        'use_schlagnorm' (bool): Flag for schlag normalization
        'schlagnorm_targets' (bool): Flag for schlag normalization targets
    '''
    masked: bool
    embed_dim: int
    num_heads: int
    use_softmax: bool  # Not read by this layer
    use_bias: bool
    key_size: int
    initializer: Callable
    seq_len: int
    range_dfwp: int
    layer_idx: int
    num_layers: int
    use_pe_kq: bool = False
    use_schlagnorm: bool = False
    schlagnorm_targets: bool = False

    def _get_projections(
        self, inputs: chex.Array, name: str, dim: Optional[int] = None
    ) -> chex.Array:
        '''Returns the per-head projection `name` of the input tensor.

        Args:
            inputs: Tensor whose last axis is projected.
            name: Module name; also the key used for the initializer lookup.
            dim: Per-head output size. Defaults to `key_size`.
        '''
        return nn.DenseGeneral(
            features=(self.num_heads, self.key_size if dim is None else dim),
            axis=-1,
            use_bias=self.use_bias,
            kernel_init=self.initializer['attention'](self.layer_idx)[name],
            name=name
        )(inputs)

    @nn.compact
    def __call__(self, x, outs_prev=None, mask=None, deterministic=None):
        """Applies gated fast weight linear attention.

        Args:
            x: Query input of shape [batch, seq_length, embed_dim]
            outs_prev: Unused head-output-list input of shape [batch..., seq_length, embed_dim]
            mask: Unused (since exists by construction) attention mask
            deterministic: unused (kept for compatibility)

        Returns:
            A 3-tuple `(output, None, None)` with `output` of shape
            [batch, seq_length, embed_dim].
            NOTE(review): the other attention layers in this file return
            4-tuples - confirm what the caller unpacks.
        """

        # The projections were previously bound to `query`/`key`/`value` while
        # the rest of the method used the (undefined) plural names.
        queries, keys, values = [
            self._get_projections(inputs=x, name=n)
            for n in ['query', 'key', 'value']
        ]

        key_conv = Conv1D(
            width=self.key_size,
            temporal_width=4, #TODO: Make field
            w_init_variance_scale=0.01, #TODO: Make field
            name='key_conv'
        )

        # Run the temporal convolution over the keys of every (batch, head) pair.
        batch_size, seq_len_c = keys.shape[0], keys.shape[1]
        keys_reshaped = keys.transpose(0, 2, 1, 3)  # [batch, heads, seq, key_size]
        keys_flat = keys_reshaped.reshape(-1, seq_len_c, self.key_size)  # [batch*heads, seq, key_size]
        # Conv1D treats `segment_pos == 0` as the start of a new segment and
        # masks every tap that would look across such a boundary. All-zero
        # positions would mask out every look-back tap; real positions
        # 0..seq-1 mark only the first token as a boundary.
        segment_pos = jnp.broadcast_to(
            jnp.arange(seq_len_c), (keys_flat.shape[0], seq_len_c)
        )
        keys_conv, _ = key_conv(keys_flat, segment_pos)
        keys = keys_conv.reshape(batch_size, self.num_heads, seq_len_c, self.key_size)
        keys = keys.transpose(0, 2, 1, 3)  # [batch, seq, heads, key_size]

        # Normalize after the convolution, matching the other attention layers
        # in this file (normalizing first would be undone by the convolution).
        if self.use_schlagnorm:
            queries = queries / (1e-16 + jnp.linalg.norm(queries, axis=-1)[..., None])
            keys = keys / (1e-16 + jnp.linalg.norm(keys, axis=-1)[..., None])
            if self.schlagnorm_targets:
                values = values / (1e-16 + jnp.linalg.norm(values, axis=-1)[..., None])

        # One sigmoid gate per head and time step: [batch, seq, heads, 1].
        # NOTE(review): requires a 'betas' entry in the initializer lookup.
        betas = jax.nn.sigmoid(self._get_projections(
            inputs=x,
            name='betas',
            dim=1,
        ))

        def single_head_update(
            keys_h: chex.Array,
            values_h: chex.Array,
            queries_h: chex.Array,
            betas_h: chex.Array,
        ) -> chex.Array:
            "Runs the gated fast-weight recurrence for one head."

            # Per-head inputs are [batch, seq, features].
            chex.assert_rank([keys_h, values_h, queries_h, betas_h], 3)
            b, v, k = values_h.shape[0], values_h.shape[-1], keys_h.shape[-1]

            def step(carry, data):
                key_t, value_t, b_t = data
                # Gated write: beta_t * (v_t k_t^T).
                cv_ck = jnp.einsum("bd,bvk->bvk", b_t, jnp.einsum("bv,bk->bvk", value_t, key_t))
                # Gated decay of the previous state: beta_t * S_{t-1}.
                gated_carry = jnp.einsum("bd,bvk->bvk", b_t, carry)
                # The original additionally multiplied `gated_carry` by `b_t`,
                # which broadcasts [b, 1] against [b, v, k] and fails for
                # general shapes; the gate is applied exactly once here.
                new_state = gated_carry + cv_ck
                # The state is both the carry and the per-step output.
                return new_state, new_state

            # scan iterates over the leading axis, so move time to the front.
            _, gla = jax.lax.scan(
                step,
                jnp.zeros(shape=(b, v, k)),
                tuple(arr.swapaxes(0, 1) for arr in (keys_h, values_h, betas_h))
            )

            gla = gla.swapaxes(0, 1)  # [batch, seq, v, k]

            return jnp.einsum("bsvk,bsk->bsv", gla, queries_h)

        # Map over the heads axis (second to last) of every input; the head
        # outputs are stacked back on the same axis: [batch, seq, heads, v].
        all_head_preds = jax.vmap(single_head_update, in_axes=-2, out_axes=-2)(
            keys, values, queries, betas
        )

        # Merge (heads, v) into embed_dim features.
        output = nn.DenseGeneral(
            features=self.embed_dim,
            axis=(-2, -1),
            use_bias=self.use_bias,
            kernel_init=self.initializer['attention'](self.layer_idx)['out'],
            name='out'
        )(all_head_preds)

        return output, None, None

## RGLRU (Griffin)

Adapted from [dtunai/Griffin-Jax](https://github.com/dtunai/Griffin-Jax/blob/main/griffin_jax/griffin_jax.py)

In [9]:
class RGLRU(nn.Module):
    '''Real-Gated Linear Recurrent Unit (RG-LRU, Griffin) layer with multi-head support.

    NOTE(review): the gate weights have shape (dim * mult, dim) and are applied
    as `jnp.dot(x_t, W)` with `x_t` of size `dim`, and the gated input `i_t * x_t`
    is added to a state of size `dim * mult`; both only line up for `mult == 1`.

    Fields:
        'dim' (int): Input dimension
        'num_heads' (int): Number of attention heads
        'mult' (int): Hidden dimension multiplier per head
        'use_bias' (bool): Flag whether to use bias
        'initializer' (Callable): Indexed as `initializer['rglru'](layer_idx)[name]`
            to obtain the initializer of each weight
        'layer_idx' (int): Layer index in the network
        'c' (int): Timestep constant for gating mechanism
    '''
    dim: int
    num_heads: int
    mult: int
    use_bias: bool
    initializer: Callable
    layer_idx: int
    c: int = 8

    def _get_gate_params(self, name: str) -> Tuple[chex.Array, chex.Array]:
        '''Returns the weight matrix and bias vector for a gate.

        The parameters are registered as `W{name}` and `b{name}`; the bias is
        `None` when `use_bias` is False.
        '''
        weight = self.param(
            f"W{name}",
            self.initializer['rglru'](self.layer_idx)[name],
            (self.num_heads, self.dim * self.mult, self.dim)
        )
        bias = self.param(
            f"b{name}",
            nn.initializers.zeros,
            (self.num_heads, self.dim * self.mult)
        ) if self.use_bias else None
        return weight, bias

    def _get_lambda_param(self) -> chex.Array:
        '''Returns the lambda parameter for decay rate calculation.'''
        def custom_lambda_init(key, shape):
            # Initialize lambda to give decay rates between 0.9 and 0.999:
            # with a = sigmoid(Lambda), this draws a**c uniformly from that
            # range and inverts the sigmoid.
            a_c_values = random.uniform(key, shape, minval=0.9, maxval=0.999)
            return -jnp.log((1 / a_c_values) ** (1 / self.c) - 1)

        return self.param(
            "Lambda",
            custom_lambda_init,
            (self.num_heads, self.dim * self.mult)
        )

    def _single_head_rglru(
        self,
        x: chex.Array,
        Wa: chex.Array,
        Wx: chex.Array,
        ba: Optional[chex.Array],
        bx: Optional[chex.Array],
        Lambda: chex.Array,
    ) -> chex.Array:
        """Applies RGLRU transformation for a single head.

        Args:
            x: Input of shape [batch, seq_length, dim]
            Wa, Wx: Weight matrices for the gates
            ba, bx: Bias vectors for the gates (optional)
            Lambda: Decay rate parameter

        Returns:
            Hidden states for every time step, [batch, seq_length, dim * mult].
        """
        batch_size, _, _ = x.shape
        ht = jnp.zeros((batch_size, self.dim * self.mult))

        def step(carry, inputs):
            ht = carry
            xt = inputs

            # Calculate gates
            rt = jax.nn.sigmoid(jnp.dot(xt, Wa) + ba) if self.use_bias else jax.nn.sigmoid(jnp.dot(xt, Wa))
            it = jax.nn.sigmoid(jnp.dot(xt, Wx) + bx) if self.use_bias else jax.nn.sigmoid(jnp.dot(xt, Wx))

            # Calculate decay rate: equals sigmoid(Lambda) ** (c * rt),
            # computed in log space via log(sigmoid(L)) = -softplus(-L).
            a_t = jnp.exp(-self.c * jax.nn.softplus(-Lambda) * rt)

            # Update hidden state
            ht_new = a_t * ht + ((1 - a_t**2) ** 0.5) * (it * xt)

            return ht_new, ht_new

        # Scan over sequence
        _, outputs = jax.lax.scan(
            step,
            ht,
            x.swapaxes(0, 1)  # [seq_len, batch, dim]
        )

        return outputs.swapaxes(0, 1)  # [batch, seq_len, dim]

    @nn.compact
    def __call__(self, x, deterministic=None):
        """Applies multi-head RGLRU transformation.

        Args:
            x: Input of shape [batch, seq_length, dim]
            deterministic: unused (kept for compatibility)

        Returns:
            Output of shape [batch, seq_length, dim].
        """
        Wa, ba = self._get_gate_params('a')
        Wx, bx = self._get_gate_params('x')
        Lambda = self._get_lambda_param()

        # Process each head independently. The heads are stacked on the
        # second-to-last axis, giving [batch, seq, heads, hidden], so that the
        # projection below contracts (heads, hidden). With the previous
        # `out_axes=1` the layout was [batch, heads, seq, hidden] and the
        # projection contracted over the sequence axis instead.
        all_head_outputs = jax.vmap(
            self._single_head_rglru,
            in_axes=(None, 0, 0, 0 if self.use_bias else None, 0 if self.use_bias else None, 0),
            out_axes=-2
        )(x, Wa, Wx, ba, bx, Lambda)

        # Project multi-head output back to original dimension
        output = nn.DenseGeneral(
            features=self.dim,
            axis=(-2, -1),
            use_bias=self.use_bias,
            kernel_init=self.initializer['rglru'](self.layer_idx)['out'],
            name='out'
        )(all_head_outputs)

        return output

## Transformer

### Transformer-Block

In [10]:
class TransformerBlock(nn.Module):
    '''One residual block: attention followed by an optional two-layer MLP.

    The attention variant is picked in `setup` from `use_fwp` / `use_gla`;
    normalisation (RMSNorm) and the MLP are both optional.

    Fields:
        'use_gla' (bool): With `use_fwp`, select GatedLinearAttention
        'use_depth' (bool): Deep or standard Delta Model
        'embed_dim' (int): Dimension of the embeddings
        'layer_idx' (int): Index of this block, forwarded to the initializers
        'num_layers' (int): Total layer count, forwarded to the fast-weight attention modules
        'use_fwp' (bool): Flag to use fast weight programmer implementation
        'range_dfwp' (int): range of deep fast weight programmer (how far propagate kv-pairs)
        'use_pe_kq' (bool): Flag to use positional encoding in the query and key vectors
            (only forwarded to the fast-weight variants; standard attention gets False)
        'key_size' (int): Size of the key vectors
        'num_heads' (int): Number of attention heads
        'dim_feedforward' (int): Dimension of the feedforward network
        'use_layer_norm' (bool): Whether to normalise the sub-layer inputs (RMSNorm)
        'use_bias' (bool): Flag to use bias in the attention and MLP layers
        'use_softmax' (bool): Flag to use softmax in the attention layer
        'use_mlp' (bool): Flag to use the MLP
        'masked' (bool): Flag to use masking in the attention layer
        'initializer' (Any): Mapping of initializer factories; the 'mlp' entry is used here
        'seq_len' (int): Length of the input sequence
        'use_schlagnorm' (bool): Flag to use Schlag-Norm
        'schlagnorm_targets' (bool): Flag to use Schlag-Norm on the targets
    '''

    use_gla: bool
    use_depth: bool
    embed_dim: int
    layer_idx: int
    num_layers: int
    use_fwp: bool
    range_dfwp: int
    use_pe_kq: bool
    key_size : int
    num_heads : int
    dim_feedforward : int
    use_layer_norm : bool
    use_bias : bool
    use_softmax : bool
    use_mlp : bool
    masked : bool
    initializer : Any
    seq_len: int
    use_schlagnorm : bool
    schlagnorm_targets : bool

    def setup(self):
        '''Creates the attention module and, when enabled, the MLP and norms.'''
        print('Transformer Block:')

        # Keyword arguments every attention variant receives.
        shared_attn_args = dict(
            num_heads=self.num_heads,
            embed_dim=self.embed_dim,
            masked=self.masked,
            use_softmax=self.use_softmax,
            use_bias=self.use_bias,
            key_size=self.key_size,
            initializer=self.initializer,
            seq_len=self.seq_len,
            use_schlagnorm=self.use_schlagnorm,
            schlagnorm_targets=self.schlagnorm_targets,
            layer_idx=self.layer_idx,
        )
        # Additional arguments only the fast-weight-programmer variants take.
        fwp_only_args = dict(
            use_pe_kq=self.use_pe_kq,
            range_dfwp=self.range_dfwp,
            num_layers=self.num_layers,
        )

        if not self.use_fwp:
            print(f'Using standard-attention, using softmax: {self.use_softmax}')
            self.self_attn = MultiHeadAttention(
                use_pe_kq=False,
                **shared_attn_args,
            )
        elif self.use_gla:
            print('Using GLA Programmer Attention')
            self.self_attn = GatedLinearAttention(
                **shared_attn_args,
                **fwp_only_args,
            )
        else:
            print('Using Delta Rule Attention')
            print(f'Leveraging depth T/F: {self.use_depth}')
            self.self_attn = DeepDeltaAttention(
                use_depth=self.use_depth,
                **shared_attn_args,
                **fwp_only_args,
            )

        # Two-layer MLP: Dense -> GELU -> Dense
        if self.use_mlp:
            print('Using MLP')
            # The 'mlp' factory is asked for an initializer per projection,
            # identified by the kernel shape it is going to fill.
            expand_kernel_init = self.initializer['mlp'](self.layer_idx)(
                shape=(self.embed_dim, self.dim_feedforward),
                dtype=jnp.float32,
            )
            contract_kernel_init = self.initializer['mlp'](self.layer_idx)(
                shape=(self.dim_feedforward, self.embed_dim),
                dtype=jnp.float32,
            )
            self.linear = [
                nn.Dense(
                    features=self.dim_feedforward,
                    use_bias=self.use_bias,
                    kernel_init=expand_kernel_init,
                ),
                nn.gelu,
                nn.Dense(
                    features=self.embed_dim,
                    use_bias=self.use_bias,
                    kernel_init=contract_kernel_init,
                ),
            ]

        # Normalisation layers (one per sub-layer)
        if self.use_layer_norm:
            print('Using LayerNorm')
            self.norm1 = nn.RMSNorm()
            self.norm2 = nn.RMSNorm()

    def __call__(self, x: chex.Array, o_prev: chex.Array) -> chex.Array:
        '''Runs attention and the optional MLP, each with a residual connection.

        Args:
            x: Input to the block (the residual stream).
            o_prev: Passed unchanged to the attention module as its second argument.

        Returns:
            A 4-tuple `(x, attention_map, o, g)`: the updated residual stream
            followed by the three extra values returned by the attention module.
        '''
        attn_input = x
        if self.use_layer_norm:
            attn_input = self.norm1(x)
        attn_out, attention_map, o, g = self.self_attn(attn_input, o_prev)
        x = x + attn_out

        if not self.use_mlp:
            return x, attention_map, o, g

        hidden = self.norm2(x) if self.use_layer_norm else x
        for mlp_layer in self.linear:
            hidden = mlp_layer(hidden)
        x = x + hidden

        return x, attention_map, o, g

### Decoder

In [11]:
class Decoder(nn.Module):
    '''List of Transformer Blocks applied one after another.

    Fields:
        'use_gla' (bool): Use gated linear attention
        'use_depth' (bool): Deep or standard Delta Model
        'use_layernorm' (bool): Flag whether to use LayerNorm
        'use_fwp' (bool): Flag whether to use fast weight programmer implementation
        'index_offset' (int): Added to every block's `layer_idx` and to the
            `num_layers` value handed to the blocks
        'range_dfwp' (int): range of deep fast weight programmer (how far propagate kv-pairs)
        'use_pe_kq' (bool): Forwarded to every block
        'use_bias' (bool): Flag whether to use bias in the attention layer
        'use_mlp' (bool): Flag whether to use the MLP
        'masked' (bool): Flag whether to use masking in the attention layer
        'num_layers' (int): Number of Transformer Blocks
        'num_heads' (int): Number of attention heads
        'seq_len' (int): Length of the input sequence
        'embed_dim' (int): Dimension of the embeddings
        'key_size' (int): Size of the key vectors
        'dim_feedforward_MLP' (int): Dimension of the feedforward network
        'use_clip' (bool): Flag whether to clip the residual stream after each block
        'clip_range' (float): Range to clip the residual stream
        'linear' (bool): When True the blocks are built with `use_softmax=False`
        'initializer' Any: Initializer for the weights
        'use_schlagnorm' (bool): Flag whether to use Schlag-Norm (used in nonlinear experiments)
        'schlagnorm_targets' (bool): Flag whether to use Schlag-Norm on the targets (used in nonlinear experiments)
    '''

    use_gla: bool
    use_depth: bool
    use_layernorm: bool
    use_fwp: bool
    index_offset: int
    range_dfwp: int
    use_pe_kq: bool
    use_bias: bool
    use_mlp: bool
    masked: bool
    num_layers: int
    num_heads: int
    seq_len: int
    embed_dim: int
    key_size: int
    dim_feedforward_MLP: int
    use_clip: bool
    clip_range: float
    linear: bool
    initializer: Any
    use_schlagnorm: bool
    schlagnorm_targets: bool

    def setup(self):
        '''Initializes the list of Transformer Blocks.'''

        self.blocklist = [
            TransformerBlock(
                use_gla=self.use_gla,
                use_depth=self.use_depth,
                embed_dim=self.embed_dim,
                # Shift indices so blocks placed before this decoder are counted.
                layer_idx=(self.index_offset + layer_idx),
                num_layers=(self.index_offset + self.num_layers),
                use_fwp=self.use_fwp,
                range_dfwp=self.range_dfwp,
                use_pe_kq=self.use_pe_kq,
                key_size=self.key_size,
                num_heads=self.num_heads,
                dim_feedforward=self.dim_feedforward_MLP,
                use_layer_norm=self.use_layernorm,
                use_bias=self.use_bias,
                use_softmax=(not self.linear),
                use_mlp=self.use_mlp,
                masked=self.masked,
                initializer=self.initializer,
                seq_len=self.seq_len,
                use_schlagnorm=self.use_schlagnorm,
                schlagnorm_targets=self.schlagnorm_targets,
            )
            for layer_idx in range(self.num_layers)
        ]

    def __call__(self,
                 x: chex.Array,
                 out_prev: Optional[chex.Array],
        ) -> Tuple[chex.Array, List[Any], List[Any]]:
        ''' Applies the decoder to the input tensor.
        Args:
            x (chex.Array): Input tensor
            out_prev (chex.Array | None): Head-FW-Outputs from prev. layer.
                When None, zeros of shape (*x.shape[:2], num_heads, key_size)
                are used for the first block.
        Returns:
            Tuple of (final residual stream, attention map per layer,
            fourth block output per layer).
        '''

        residual_stream = x
        attention_maps = []
        gammas_list = []
        if out_prev is None:
            out_prev = jnp.zeros(
                shape=(*x.shape[:2], self.num_heads, self.key_size)
            )

        print('Number of Decoder-Layers: ', self.num_layers)
        print(f'Using deep delta Rule T/F: {self.use_depth}')

        for block in self.blocklist:
            # Each block's third output becomes the next block's `o_prev`.
            residual_stream, att_map, out_prev, g_l = block(
                residual_stream, out_prev,
            )
            if self.use_clip:
                residual_stream = jnp.clip(
                    residual_stream, -self.clip_range, self.clip_range
                )
            attention_maps.append(att_map)
            gammas_list.append(g_l)
        print('------')
        return residual_stream, attention_maps, gammas_list

### Full Transformer Model

In [12]:

class FullTransformerModel(nn.Module):
    '''
    Full Transformer Model: input embedding, optional softmax "hybrid" first
    block, decoder stack, optional final norm and output projection.

    Fields:
        'use_gla' (bool): Use gated linear attention
        'use_depth' (bool): Deep or standard Delta Model
        'use_emb' (bool): Flag whether to use an embedding layer
        'use_fwp' (bool): Flag whether to use fast weight programmer attention implementation.
        'range_dfwp' (int): range of deep fast weight programmer (how far propagate kv-pairs)
        'use_pe_kq' (bool): Forwarded to the blocks
        'use_pe_emb' (bool): Flag whether to use positional encoding in the initial embeddings
        'hybrid_first_block' (bool): Flag whether to use a softmax first block
        'pe_dim' (int): Dimension of the positional encoding
        'out_dim' (int): Dimension of the output data
        'initializer' Any: Initializer for the weights
        'use_layernorm' (bool): Flag whether to use LayerNorm
        'use_bias' (bool): Flag whether to use bias in the attention layer
        'use_mlp' (bool): Flag whether to use the MLP
        'masked' (bool): Flag whether to use masking in the attention layer
        'use_clip' (bool): Flag whether to clip the output
        'num_layers' (int): Number of Transformer Blocks in the decoder
        'num_heads' (int): Number of attention heads
        'embed_dim' (int): Dimension of the embeddings
        'key_size' (int): Size of the key vectors
        'seq_len' (int): Length of the input sequence
        'dim_feedforward_MLP' (int): Dimension of the feedforward network
        'clip_range' (float): Value to clip the output
        'linear' (bool): Flag whether to use a linear layer
        'use_schlagnorm' (bool): Flag whether to use Schlag-Norm (used in nonlinear experiments)
        'schlagnorm_targets' (bool): Flag whether to use Schlag-Norm on the targets (used in nonlinear experiments)
        'is_discrete' (bool): Inputs are token ids; use a token embedding and a vocab-sized output projection
        'vocab_size' (int | None): Vocabulary size, required when `is_discrete` is True
        'pad_token_id' (int | None): Not referenced inside this class
    '''

    use_gla: bool
    use_depth: bool
    use_emb: bool
    use_fwp: bool
    range_dfwp: int
    use_pe_kq: bool
    use_pe_emb: bool
    hybrid_first_block: bool
    pe_dim: int
    out_dim: int
    initializer: Any
    use_layernorm: bool
    use_bias: bool
    use_mlp: bool
    masked: bool
    use_clip: bool
    num_layers: int
    num_heads: int
    embed_dim: int
    key_size: int
    seq_len: int
    dim_feedforward_MLP: int
    clip_range: float
    linear: bool
    use_schlagnorm: bool
    schlagnorm_targets: bool

    # New fields for token tasks
    is_discrete: bool = False
    vocab_size: Optional[int] = None
    pad_token_id: Optional[int] = None

    def setup(self):
        '''Initializes the Full Transformer Model.'''
        if self.is_discrete:
            print(f'Using discrete input embeddings with num_embeddings (vocab): {self.vocab_size} and features: {self.embed_dim}')
            self.token_embedding = nn.Embed(
                num_embeddings=self.vocab_size,
                features=self.embed_dim,
                embedding_init=self.initializer['embedding_in'],
                name='token_embedding'
            )
        elif self.use_emb:  # Only for continuous case
            self.input_layer = nn.Dense(
                features=self.embed_dim,
                use_bias=self.use_bias,
                kernel_init=self.initializer['embedding_in'],
                name='input_embedding'
            )
        # Output projection to vocab size for discrete tasks
        if self.is_discrete:
            print(f'Using discrete output projection onto feat_dim (vocab): {self.vocab_size}')
            self.output_projection = nn.Dense(
                features=self.vocab_size,
                use_bias=self.use_bias,
                kernel_init=self.initializer['embedding_out'],
                name='output_projection'
            )
        elif self.use_emb:  # Continuous case with embedding
            self.output_layer = nn.Dense(
                features=self.out_dim,
                use_bias=self.use_bias,
                kernel_init=self.initializer['embedding_out'],
                name='output_embedding'
            )

        if self.use_pe_emb:
            self.pe = PositionalEncoding(pe_dim=self.pe_dim, max_len=self.seq_len, concat=False)
        if self.hybrid_first_block:
            self.hybrid_block = TransformerBlock(
                use_gla=False,
                use_depth=self.use_depth,
                embed_dim=self.embed_dim,
                # `layer_idx` is a required TransformerBlock field. The hybrid
                # block runs first, and the decoder below starts counting at
                # index_offset=1, so this block is layer 0.
                layer_idx=0,
                use_fwp=False,
                range_dfwp=self.range_dfwp,
                use_pe_kq=self.use_pe_kq,
                key_size=self.key_size,
                dim_feedforward=self.dim_feedforward_MLP,
                use_layer_norm=self.use_layernorm,
                use_bias=self.use_bias,
                num_heads=self.num_heads,
                num_layers=self.num_layers + 1,
                use_softmax=True,
                use_mlp=self.use_mlp,
                masked=self.masked,
                initializer=self.initializer,
                seq_len=self.seq_len,
                use_schlagnorm=self.use_schlagnorm,
                schlagnorm_targets=self.schlagnorm_targets,
            )

        self.tf_decoder = Decoder(
            use_gla=self.use_gla,
            use_depth=self.use_depth,
            use_layernorm=self.use_layernorm,
            use_fwp=self.use_fwp,
            index_offset=(1 if self.hybrid_first_block else 0),
            range_dfwp=self.range_dfwp,
            use_pe_kq=self.use_pe_kq,
            use_bias=self.use_bias,
            use_mlp=self.use_mlp,
            masked=self.masked,
            use_clip=self.use_clip,
            num_layers=self.num_layers,
            num_heads=self.num_heads,
            embed_dim=self.embed_dim,
            key_size=self.key_size,
            seq_len=self.seq_len,
            dim_feedforward_MLP=self.dim_feedforward_MLP,
            clip_range=self.clip_range,
            linear=self.linear,
            initializer=self.initializer,
            use_schlagnorm=self.use_schlagnorm,
            schlagnorm_targets=self.schlagnorm_targets)

        if self.use_layernorm:
            self.final_layernorm = nn.RMSNorm()

    def __call__(self,
                 x: chex.Array,
                 interpol_call: bool = False) -> Tuple[chex.Array, Tuple[List[Any], List[Any]]]:
        '''Applies the Full Transformer Model to the input tensor.

        Args:
            x: For continuous: input tensor [B, T, D]
               For discrete: input tensor of token ids [B, T]
            interpol_call: Not used by this method.
        Returns:
            Tuple `(output, (attention_maps, gammas_list))` where
            output is, for continuous: output tensor [B, T, D];
            for discrete: logits tensor [B, T, vocab_size].
            `attention_maps` holds one entry per block (hybrid block first,
            if present); `gammas_list` comes from the decoder only.
        '''

        if self.is_discrete:
            x = self.token_embedding(x)
        elif self.use_emb:
            x = self.input_layer(x)

        if self.use_pe_emb:
            x = self.pe(x)

        print('TF!! embedded input shape: ', x.shape)

        attention_maps = []
        prev_outs = None

        if self.hybrid_first_block:
            # TransformerBlock returns exactly four values:
            # (x, attention_map, o, g). Only the first two are used here.
            # NOTE(review): the hybrid block receives o_prev=None; confirm that
            # MultiHeadAttention accepts None for its second argument.
            x, att_map, _, _ = self.hybrid_block(x, prev_outs)
            if self.use_clip:
                x = jnp.clip(x, -self.clip_range, self.clip_range)
            attention_maps.append(att_map)

        # Transformer Blocks:
        decoder_output, decoder_attention_maps, gammas_list = self.tf_decoder(x, prev_outs)
        attention_maps += decoder_attention_maps

        if self.use_layernorm:
            decoder_output = self.final_layernorm(decoder_output)

        if self.is_discrete:
            # Project to vocab size; the logits are returned as-is.
            output = self.output_projection(decoder_output)
        elif self.use_emb:
            output = self.output_layer(decoder_output)
        else:
            output = decoder_output

        # Return the model output together with the per-layer diagnostics.
        return output, (attention_maps, gammas_list)

# Initializer

## Fast Weight Initializer

In [13]:
def get_fast_weight_init(num_layers: int, d_model: int, key_size: int, range_dfwp: int) -> Dict[str, Any]:
    """
    Create initialization functions using jax.nn.initializers

    Args:
        num_layers: Total number of transformer layers (accepted for API
            compatibility; not used by any of the initializers below)
        d_model: Model dimension (embed_dim)
        key_size: Per-head key dimension, used to scale the q/k/v initializers
        range_dfwp: Range of deep fast weight programmer

    Returns:
        Dict with four entries:
            'attention': callable `layer_idx -> dict` of initializers keyed by
                'queries', 'keys', 'values', 'outs', 'betas', 'gammas'
            'mlp': callable `layer_idx -> (shape, dtype) -> initializer`
            'embedding_in': initializer for the input embedding
            'embedding_out': initializer for the output projection
    """

    def _truncated_normal(scale: float, mode: str = 'fan_in') -> Callable:
        """Variance-scaling initializer with a truncated-normal distribution."""
        return jax.nn.initializers.variance_scaling(
            scale=scale,
            mode=mode,
            distribution='truncated_normal'
        )

    def _compute_layer_scale(layer_idx: int) -> float:
        """Layer-dependent scaling factor, adjusted for fast weights."""
        base_scale = 1.0 / math.sqrt(2.0 * (layer_idx + 1))
        fw_scale = 1.0 / math.sqrt(range_dfwp + 1)
        return base_scale * fw_scale

    def fast_weight_attention_init(layer_idx: int) -> Dict[str, Callable]:
        """Initialize weights for fast weight attention."""
        layer_scale = _compute_layer_scale(layer_idx)

        # QKV get same initialization
        qkv_init = _truncated_normal(layer_scale * (1.0 / key_size))

        # Output projection gets different scaling
        out_init = _truncated_normal(layer_scale * (1.0 / d_model))

        # Shared by the 'betas' and 'gammas' entries; layer-independent.
        lrs_init = _truncated_normal(0.1)

        return {
            'queries': qkv_init,
            'keys': qkv_init,
            'values': qkv_init,
            'outs': out_init,
            'betas': lrs_init,
            'gammas': lrs_init
        }

    def mlp_init(layer_idx: int) -> Callable:
        """Initialize MLP weights with layer-dependent scaling."""
        def init(shape, dtype=jnp.float32) -> Callable:
            # `dtype` is accepted for the caller's convenience but does not
            # influence which initializer is returned.
            layer_scale = _compute_layer_scale(layer_idx)
            if shape[0] == d_model and shape[1] > d_model:
                # Input projection (d_model -> wider hidden layer)
                return _truncated_normal(layer_scale * 2.0, mode='fan_in')
            # Output projection (every other shape)
            return _truncated_normal(layer_scale, mode='fan_out')
        return init

    return {
        'attention': fast_weight_attention_init,
        'mlp': mlp_init,
        'embedding_in': _truncated_normal(1.0, mode='fan_in'),
        'embedding_out': _truncated_normal(0.5, mode='fan_out')
    }

# Optimizer

In [14]:
class Optimizer:
    '''Optimizer class for the model training.

    Stores the hyper-parameters and builds an optax chain of update clipping
    followed by AdamW on request.
    '''
    def __init__(self,
                 grad_clip: float = 1.0,
                 peak_lr: float = 3e-4,
                 use_schedule: bool = True,
                 warmup_steps: int = 500,
                 max_iters: int = 40000,
                 init_value: float = 0.0,
                 end_value: float = 1e-5,
                 weight_decay: float = 0.05):
        '''Initializes the optimizer with the specified parameters.
        Args:
            'grad_clip' (float): Gradient clipping value.
            'peak_lr' (float): Peak learning rate (used as the constant
                learning rate when no schedule is used).
            'use_schedule' (bool): Flag whether to use the learning rate schedule.
            'warmup_steps' (int): Number of warmup steps.
            'max_iters' (int): Maximum number of training iterations.
            'init_value' (float): Initial learning rate value.
            'end_value' (float): Final learning rate value.
            'weight_decay' (float): Weight decay value.
        '''
        self.grad_clip = grad_clip
        self.peak_lr = peak_lr
        self.use_schedule = use_schedule
        self.warmup_steps = warmup_steps
        self.max_iters = max_iters
        self.init_value = init_value
        self.end_value = end_value
        self.weight_decay = weight_decay

    def get_optimizer(self):
        '''Returns the adamW-optimizer chain with the specified parameters.

        The chain is `optax.clip(grad_clip)` followed by `optax.adamw`.
        NOTE(review): `optax.clip` bounds each update element to
        [-grad_clip, grad_clip]; it is not a global-norm clip. Confirm that
        this is the intended behaviour.
        '''
        if self.use_schedule:
            # Only build the schedule when it is actually used, so that the
            # warmup/decay settings are irrelevant for a constant-LR run.
            learning_rate = optax.warmup_cosine_decay_schedule(
                init_value=self.init_value,
                peak_value=self.peak_lr,
                warmup_steps=self.warmup_steps,
                decay_steps=self.max_iters,
                end_value=self.end_value,
            )
        else:
            learning_rate = self.peak_lr

        return optax.chain(
            optax.clip(self.grad_clip),
            optax.adamw(learning_rate, weight_decay=self.weight_decay),
        )

# Data Generators

## Base-Classes

In [15]:
class DataGenerator(abc.ABC):
    '''Interface every data generator has to implement.'''

    @abc.abstractmethod
    def get_data(
        self,
        rng: random.PRNGKey,
        batch_size: int,
        **kwargs,
    ) -> Tuple[Tuple[chex.Array]]:
        '''Produce one batch of data; concrete subclasses must override this.'''
        raise NotImplementedError

    @abc.abstractmethod
    def get_data_info(self):
        '''Describe the generated data; concrete subclasses must override this.'''
        raise NotImplementedError

class SequenceDataGenerator(DataGenerator):
    '''Abstract base for generators that produce fixed-length sequences.

    Stores the sequence length, data dimension and `eye_obs` flag; all
    data-producing methods are abstract.
    '''

    def __init__(self, seq_len: int, data_dim: int, eye_obs: bool):
        self.seq_len = seq_len
        self.data_dim = data_dim
        # Always initialised to False here; not read anywhere in this class.
        self.constr = False
        self.eye_obs = eye_obs

    @abc.abstractmethod
    def get_data(self,
                 rng: random.PRNGKey,
                 batch_size: int) -> Tuple[Tuple[chex.Array, ...]]:
        '''Abstract method to get data batch'''
        raise NotImplementedError

    @abc.abstractmethod
    def get_data_info(self) -> Dict[str, Any]:
        '''Abstract method to get data info'''
        raise NotImplementedError

    @abc.abstractmethod
    def generate_sequence(self,
                          W: jnp.ndarray,
                          x_1: jnp.ndarray,
                          seq_length: int,
                          rng: random.PRNGKey) -> chex.Array:
        '''Abstract method to generate a sequence'''
        raise NotImplementedError

    @abc.abstractmethod
    def create_batch(self,
                     rng: random.PRNGKey,
                     batch_size: int,
                     data_dim: int,
                     seq_len: int) -> Tuple[Tuple[chex.Array, ...]]:
        '''Abstract method to create a batch of sequences'''
        raise NotImplementedError

## Wikitext

In [16]:
class WikiTextLoader:
    """Word-level WikiText loader producing fixed-length token sequences.

    The vocabulary consists of a fixed list of reserved tokens followed by the
    most frequent training words. Both splits are tokenised, terminated with
    `<eos>` after every non-blank line, and cut into rows of
    `sequence_length + 1` tokens so that inputs and next-token targets can be
    sliced from the same row.
    """

    # Reserved vocabulary entries; a token's position here is its id.
    _SPECIAL_TOKENS = (
        '<pad>', '<unk>', '<eos>', '<num>',
        '.', ',', '!', '?', '(', ')',
        '-', '"', "'", ':', ';',
    )
    # Words seen fewer times than this in the training split are dropped.
    _MIN_WORD_COUNT = 3
    # Compiled once: these run on every line of the corpus.
    _DIGITS_RE = re.compile(r'\d+')
    _PUNCT_RE = re.compile(r'([.,!?()])')

    def __init__(self,
                 sequence_length: int,
                 vocab_size: int = 30000,  # Increased vocab size
                 dataset_name: str = "wikitext-103-raw-v1"):
        """Load WikiText and prepare it for sequence sampling.

        Args:
            sequence_length: Number of input tokens per sequence; rows are
                stored with one extra token for the shifted targets.
            vocab_size: Upper bound on the vocabulary size, reserved tokens
                included.
            dataset_name: WikiText configuration passed to `load_dataset`.

        Raises:
            ValueError: If `vocab_size` cannot hold the reserved tokens.
        """
        num_special = len(self._SPECIAL_TOKENS)
        if vocab_size < num_special:
            # Without this check the slice bound below would be negative and
            # silently keep almost every word, exceeding `vocab_size`.
            raise ValueError(
                f"vocab_size must be at least {num_special} to hold the "
                f"reserved tokens, got {vocab_size}"
            )

        self.sequence_length = sequence_length
        self.vocab_size = vocab_size

        print(f"Loading {dataset_name}...")
        dataset = load_dataset("wikitext", dataset_name)

        # Build vocabulary from training data
        print("Building vocabulary...")
        word_counts = Counter()
        for text in dataset['train']['text']:
            if text.strip():
                word_counts.update(self._preprocess_text(text).split())

        special_tokens = list(self._SPECIAL_TOKENS)
        reserved = set(special_tokens)

        # Filter very rare words and anything already reserved
        common_words = [
            word for word, count in word_counts.most_common()
            if count >= self._MIN_WORD_COUNT and word not in reserved
        ]

        # Take most common words up to vocab_size
        vocab_words = common_words[:vocab_size - num_special]
        self.vocab = special_tokens + vocab_words
        self.word2idx = {word: idx for idx, word in enumerate(self.vocab)}

        # Coverage is measured over the words that actually made it into the
        # vocabulary (a Counter yields 0 for reserved tokens never seen).
        total_count = sum(word_counts.values())
        covered_count = sum(word_counts[word] for word in self.vocab)
        coverage = covered_count / total_count if total_count else 0.0

        print(f"Vocabulary size: {len(self.vocab)}")
        print(f"Coverage of training text: {coverage:.2%}")

        print("\nTokenizing dataset...")
        # Convert to JAX arrays
        self.raw_train = self._tokenize_split(dataset['train']['text'])
        self.raw_test = self._tokenize_split(dataset['test']['text'])

        # Calculate number of complete sequences (trailing tokens are dropped)
        row_len = sequence_length + 1
        self.n_train_seq = len(self.raw_train) // row_len
        self.n_test_seq = len(self.raw_test) // row_len

        # Reshape into sequences
        self.train_sequences = self.raw_train[:self.n_train_seq * row_len].reshape(-1, row_len)
        self.test_sequences = self.raw_test[:self.n_test_seq * row_len].reshape(-1, row_len)

        print(f"\nDataset prepared:")
        print(f"Train shape: {self.train_sequences.shape}")
        print(f"Test shape: {self.test_sequences.shape}")

    @classmethod
    def _preprocess_text(cls, text: str) -> str:
        """Replace digit runs with `<num>` and space out common punctuation."""
        # Keep numbers as special tokens
        text = cls._DIGITS_RE.sub(' <num> ', text)
        # Split common punctuation from words but keep as tokens
        return cls._PUNCT_RE.sub(r' \1 ', text)

    def _tokenize_split(self, texts) -> jnp.ndarray:
        """Turn an iterable of lines into one flat int32 token array."""
        unk_id = self.word2idx['<unk>']
        eos_id = self.word2idx['<eos>']
        lookup = self.word2idx.get

        token_ids = []
        for text in texts:
            if not text.strip():
                continue
            line_ids = [lookup(word, unk_id)
                        for word in self._preprocess_text(text).split()]
            if line_ids:
                token_ids.extend(line_ids)
                token_ids.append(eos_id)
        return jnp.array(token_ids, dtype=jnp.int32)

    def get_batch(self, rng: jnp.ndarray, batch_size: int, split: str = 'train') -> Tuple[jnp.ndarray, jnp.ndarray]:
        """Get a random batch of sequences.

        Args:
            rng: JAX PRNG key used to draw the row indices (with replacement).
            batch_size: Number of sequences to draw.
            split: 'train' selects the training rows; any other value selects
                the test rows.

        Returns:
            `(inputs, targets)`, where targets are the inputs shifted by one
            token.
        """
        sequences = self.train_sequences if split == 'train' else self.test_sequences
        n_sequences = len(sequences)

        idx = jax.random.randint(rng, (batch_size,), 0, n_sequences)
        batch_sequences = sequences[idx]

        inputs = batch_sequences[:, :-1]
        targets = batch_sequences[:, 1:]

        return inputs, targets

    def decode_sample(self, tokens: jnp.ndarray) -> str:
        """Decode a sequence of tokens back to space-separated text."""
        return ' '.join(self.vocab[idx] for idx in tokens)

In [17]:
class WikiTextDataGenerator(SequenceDataGenerator):
    """Sequence data generator backed by pre-tokenised WikiText rows.

    Batches are drawn by sampling rows (with replacement) from the sequences
    prepared by `WikiTextLoader`. Targets are the inputs shifted by one token,
    with selected token ids replaced by -100.
    """

    def __init__(self, seq_len: int, data_dim: int, eye_obs: bool = True):
        """Load the dataset and cache the ids of the tokens to mask.

        Args:
            seq_len: Number of input tokens per sequence.
            data_dim: Passed to the loader as its `vocab_size` and reported as
                this generator's vocabulary size.
            eye_obs: Stored on the base class and echoed by `get_data_info`.
        """
        super().__init__(seq_len=seq_len, data_dim=data_dim, eye_obs=eye_obs)

        # Load and prepare data
        print("Loading WikiText dataset...")
        self.loader = WikiTextLoader(sequence_length=seq_len, vocab_size=data_dim)
        self.train_sequences = self.loader.train_sequences
        self.test_sequences = self.loader.test_sequences
        self.vocab_size = data_dim

        # Store special token IDs
        self.special_tokens = {
            'pad': self.loader.word2idx['<pad>'],
            'unk': self.loader.word2idx['<unk>'],
            'eos': self.loader.word2idx['<eos>'],
            'num': self.loader.word2idx['<num>']
        }

        # Define which tokens to mask in loss (everything except normal words and eos)
        self.mask_tokens = {
            self.loader.word2idx[token]
            for token in ['<pad>', '<unk>', '<num>', '.', ',', '!', '?', '(', ')',
                         '-', '"', "'", ':', ';']
        }

        # Seeded NumPy generator used for test batches.
        self.np_rng = np.random.RandomState(42)

        print(f"WikiText generator initialized:")
        print(f"Train sequences: {self.train_sequences.shape}")
        print(f"Test sequences: {self.test_sequences.shape}")

    def _mask_special_tokens(self, targets):
        """Convert special tokens to -100 in targets."""
        # One vectorised membership test instead of one comparison per id.
        mask_ids = jnp.array(sorted(self.mask_tokens), dtype=targets.dtype)
        return jnp.where(jnp.isin(targets, mask_ids), -100, targets)

    def get_data(self,
                 rng: random.PRNGKey,
                 batch_size: int) -> Tuple[Tuple[chex.Array]]:
        """Sample a training batch.

        Returns:
            `((inputs, masked_targets, None), (None, None))`.
        """
        _, sample_rng = random.split(rng)
        indices = random.randint(
            sample_rng,
            shape=(batch_size,),
            minval=0,
            maxval=len(self.train_sequences)
        )

        sequences = self.train_sequences[indices]
        inputs = sequences[:, :-1]
        targets = sequences[:, 1:]

        # Mask special tokens in targets
        masked_targets = self._mask_special_tokens(targets)

        return (inputs, masked_targets, None), (None, None)

    def get_test_data(self,
                      batch_size: int,
                      rng: Optional[jax.random.PRNGKey] = None) -> Tuple[Tuple[chex.Array]]:
        """Sample a test batch using the generator's own seeded NumPy RNG.

        Args:
            batch_size: Number of sequences to draw.
            rng: Ignored; the indices come from `self.np_rng`.

        Returns:
            `((inputs, masked_targets, None), (inputs, masked_targets))`.
        """
        indices = self.np_rng.randint(0, len(self.test_sequences), size=batch_size)
        sequences = self.test_sequences[indices]
        inputs = sequences[:, :-1]
        targets = sequences[:, 1:]

        # Mask special tokens in targets
        masked_targets = self._mask_special_tokens(targets)

        return (inputs, masked_targets, None), (inputs, masked_targets)

    def get_data_info(self) -> Dict[str, Any]:
        """Returns the data information."""
        return {
            'seq_len': self.seq_len,
            'data_dim': self.data_dim,
            'eye_obs': self.eye_obs,
            'vocab_size': self.vocab_size,
            'obs_dim': self.vocab_size,
            'train_size': len(self.train_sequences),
            'test_size': len(self.test_sequences)
        }

    def generate_sequence(self,
                         W: jnp.ndarray,
                         x_1: jnp.ndarray,
                         seq_length: int,
                         rng: random.PRNGKey) -> chex.Array:
        """Not implemented for language data."""
        raise NotImplementedError("WikiText uses pre-generated sequences")

    def create_batch(self,
                    rng: random.PRNGKey,
                    batch_size: int,
                    data_dim: int,
                    seq_len: int) -> Tuple[Tuple[chex.Array]]:
        """Not implemented for language data."""
        raise NotImplementedError("WikiText uses pre-generated sequences")

In [18]:
#generator = WikiTextDataGenerator(
#    seq_len=256,
#    data_dim=30000,  # This will be our vocab size
#    eye_obs=True
#)

In [19]:
def measure_speed(generator, num_batches=1000, batch_size=32):
    """Measure batch generation speed.

    Times ``num_batches`` calls of ``generator.get_data`` and of
    ``generator.get_test_data`` and prints average / min / max latency and
    throughput for each.

    Args:
        generator: object exposing ``get_data(rng, batch_size=...)`` and
            ``get_test_data(batch_size=...)``.
        num_batches: number of batches timed per split; must be positive.
        batch_size: number of sequences requested per batch.

    Raises:
        ValueError: if ``num_batches`` is not positive (there would be nothing
            to average).
    """
    if num_batches <= 0:
        raise ValueError("num_batches must be a positive integer")

    def _report(title, durations):
        # Print the summary statistics for one split.
        mean_s = np.mean(durations)
        print(title)
        print(f"  Average time: {mean_s*1000:.2f} ms per batch")
        print(f"  Throughput: {batch_size/mean_s:.0f} sequences/second")
        print(f"  Min time: {np.min(durations)*1000:.2f} ms")
        print(f"  Max time: {np.max(durations)*1000:.2f} ms")

    # NOTE(review): only the call itself is timed. If the generator returns
    # JAX arrays, asynchronous dispatch may make these numbers optimistic
    # unless the results are blocked on -- confirm if exact device time matters.

    # Measure training batch speed
    rng = jax.random.PRNGKey(0)
    train_times = []
    print("Measuring training batch speed...")
    for _ in tqdm(range(num_batches)):
        rng, split_rng = jax.random.split(rng)
        # perf_counter is monotonic and high-resolution, unlike time.time().
        start = time.perf_counter()
        _ = generator.get_data(split_rng, batch_size=batch_size)
        train_times.append(time.perf_counter() - start)

    # Measure test batch speed
    test_times = []
    print("\nMeasuring test batch speed...")
    for _ in tqdm(range(num_batches)):
        start = time.perf_counter()
        _ = generator.get_test_data(batch_size=batch_size)
        test_times.append(time.perf_counter() - start)

    # Print statistics
    print("\nSpeed Statistics:")
    _report("Training batches:", train_times)
    _report("\nTest batches:", test_times)


In [20]:
#measure_speed(generator, num_batches=1000, batch_size=32)

## DFA-ICL

In [21]:
#@title Pre-Generated DFA Dataset from Akyürek Paper
""" Datasets for core experimental results """

import os
import pickle
from functools import partial
from pathlib import Path

import numpy as np
import torch
import torchvision
from einops import rearrange
from einops.layers.torch import Rearrange
from torch.nn import functional as F

def deprecated(cls_or_func):
    """Wrap a class or function so that every call prints a deprecation notice.

    Args:
        cls_or_func: the callable (function or class) being deprecated.

    Returns:
        A function that prints ``"<cls_or_func> is deprecated"`` and then
        forwards all arguments to ``cls_or_func``, returning its result. The
        wrapper keeps the wrapped object's ``__name__``, ``__qualname__``,
        ``__module__`` and ``__doc__``.
    """
    from functools import wraps

    # updated=() skips copying the wrapped object's __dict__ into the wrapper,
    # which would otherwise dump a class's whole namespace onto the function.
    @wraps(cls_or_func, updated=())
    def _deprecated(*args, **kwargs):
        print(f"{cls_or_func} is deprecated")
        return cls_or_func(*args, **kwargs)
    return _deprecated


class DefaultCollateMixin:
    """Mixin that builds a DataLoader with a configurable collate function.

    Loader keyword arguments whose names appear in ``collate_args`` are routed
    to the collate function; everything else is passed to the loader
    constructor. Subclasses customise behaviour by overriding the callback
    classmethods and by extending ``collate_args`` / ``_collate_arg_names``.
    ``_dataloader()`` ties these together: it looks up the loader class,
    binds the collate arguments into ``_collate_fn`` and instantiates the
    loader with the remaining arguments.
    """

    # Names given to the auxiliary items (everything after x and y) that a
    # dataset returns; consumed by _return_callback.
    _collate_arg_names = []

    # List of loader arguments to pass into collate_fn
    collate_args = []

    @classmethod
    def _collate_callback(cls, x, *args, **kwargs):
        """Hook applied to the stacked input tensor; the default is identity."""
        return x

    @classmethod
    def _return_callback(cls, return_value, *args, **kwargs):
        """Post-process the value returned by the collate function.

        Auxiliary items beyond the ``(x, y)`` pair are returned as a dict
        keyed by ``_collate_arg_names`` (one name per auxiliary item).
        """
        x, y, *extras = return_value
        assert len(extras) == len(cls._collate_arg_names), "Specify a name for each auxiliary data item returned by dataset"
        return x, y, dict(zip(cls._collate_arg_names, extras))

    @classmethod
    def _collate(cls, batch, *args, **kwargs):
        """Stack a sequence of samples into one tensor.

        Tensor samples are stacked along a new leading dimension and passed
        through ``_collate_callback``; anything else is handed to
        ``torch.tensor`` directly.
        """
        # From https://github.com/pytorch/pytorch/blob/master/torch/utils/data/_utils/collate.py
        first = batch[0]
        if not isinstance(first, torch.Tensor):
            return torch.tensor(batch)

        shared_out = None
        if torch.utils.data.get_worker_info() is not None:
            # Inside a DataLoader worker process: stack straight into a
            # shared-memory tensor to avoid an extra copy.
            total_elems = sum(item.numel() for item in batch)
            shared_storage = first.storage()._new_shared(total_elems)
            shared_out = first.new(shared_storage)
        stacked = torch.stack(batch, dim=0, out=shared_out)

        # Insert custom functionality into the collate_fn
        return cls._collate_callback(stacked, *args, **kwargs)

    @classmethod
    def _collate_fn(cls, batch, *args, **kwargs):
        """Default collate function handed to the torch DataLoader.

        Arguments:
            batch: list of (x, y, *aux) samples
            args, kwargs: extra arguments forwarded to _collate_callback
                (for x only) and to _return_callback
        """
        inputs, targets, *aux_columns = zip(*batch)

        collated_x = cls._collate(inputs, *args, **kwargs)
        collated_y = cls._collate(targets)
        collated_aux = [cls._collate(column) for column in aux_columns]

        return cls._return_callback((collated_x, collated_y, *collated_aux), *args, **kwargs)

    def _dataloader(self, dataset, **loader_args):
        """Instantiate a loader for ``dataset``.

        The loader class is looked up in ``loader_registry`` under the
        optional ``_name_`` argument (``None`` when absent).
        """
        collate_kwargs = {}
        loader_kwargs = {}
        for key, value in loader_args.items():
            target = collate_kwargs if key in self.collate_args else loader_kwargs
            target[key] = value

        loader_cls = loader_registry[loader_kwargs.pop("_name_", None)]
        bound_collate = partial(self._collate_fn, **collate_kwargs)
        return loader_cls(dataset=dataset, collate_fn=bound_collate, **loader_kwargs)


# class SequenceDataset(LightningDataModule):
# [21-09-10 AG] Subclassing LightningDataModule fails due to trying to access _has_setup_fit. No idea why. So we just provide our own class with the same core methods as LightningDataModule (e.g. setup)
class SequenceDataset(DefaultCollateMixin):
    """Base class for datasets exposing train / val / test dataloaders.

    Subclasses set ``_name_`` (they are registered under it in ``registry``),
    may override ``init_defaults`` to declare default attributes, and must
    implement ``setup()`` to populate ``dataset_train`` / ``dataset_val`` /
    ``dataset_test``.
    """

    registry = {}
    _name_ = NotImplementedError("Dataset must have shorthand name")

    # Subclasses do not define __init__ (this class handles it); instead they
    # can provide default arguments here, which become instance attributes.
    # TODO it might be possible to write this as a @dataclass, but it seems tricky to separate from the other features of this class such as the _name_ and d_input/d_output
    @property
    def init_defaults(self):
        return {}

    # https://www.python.org/dev/peps/pep-0487/#subclass-registration
    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        cls.registry[cls._name_] = cls

    def __init__(self, _name_, data_dir=None, **dataset_cfg):
        assert _name_ == self._name_
        if data_dir is None:
            self.data_dir = None
        else:
            self.data_dir = Path(data_dir).absolute()

        # Defaults first, then explicit config so that the config wins;
        # every resulting entry becomes an attribute.
        for attr_name, attr_value in {**self.init_defaults, **dataset_cfg}.items():
            setattr(self, attr_name, attr_value)

        # The train, val, test datasets must be set by `setup()`
        self.dataset_train = None
        self.dataset_val = None
        self.dataset_test = None

        self.init()

    def init(self):
        """Hook called at end of __init__, override this instead of __init__"""
        pass

    def setup(self):
        """This method should set self.dataset_train, self.dataset_val, and self.dataset_test."""
        raise NotImplementedError

    def split_train_val(self, val_split):
        """
        Randomly split self.dataset_train into a new (self.dataset_train, self.dataset_val) pair.

        The split is seeded with ``self.seed`` when present, otherwise 42.
        """
        total = len(self.dataset_train)
        train_len = int(total * (1.0 - val_split))
        # PL is supposed to have a way to handle seeds properly, but doesn't seem to work for us
        seeded_generator = torch.Generator().manual_seed(getattr(self, "seed", 42))
        self.dataset_train, self.dataset_val = torch.utils.data.random_split(
            self.dataset_train,
            (train_len, total - train_len),
            generator=seeded_generator,
        )

    def train_dataloader(self, **kwargs):
        return self._train_dataloader(self.dataset_train, **kwargs)

    def _train_dataloader(self, dataset, **kwargs):
        if dataset is None:
            return None
        # shuffle can't be True if we have a custom sampler
        kwargs['shuffle'] = 'sampler' not in kwargs
        return self._dataloader(dataset, **kwargs)

    def val_dataloader(self, **kwargs):
        return self._eval_dataloader(self.dataset_val, **kwargs)

    def test_dataloader(self, **kwargs):
        return self._eval_dataloader(self.dataset_test, **kwargs)

    def _eval_dataloader(self, dataset, **kwargs):
        if dataset is None:
            return None
        # No shuffle flag is set here; the loader's own default applies.
        return self._dataloader(dataset, **kwargs)

    def __str__(self):
        return self._name_



# Registry mapping a loader name to the dataloader class to instantiate.
# DefaultCollateMixin._dataloader looks up the (optional) "_name_" loader
# argument here and falls back to the None entry when it is absent.
loader_registry = {
    None: torch.utils.data.DataLoader, # default case
}

"""Synthetic datasets to test in-context learning ability."""
from typing import Tuple
import os
import torch
import dataclasses
from torch.utils.data import TensorDataset, Dataset, DataLoader
from typing import Dict
import numpy as np
from tqdm import tqdm
from collections import Counter
from pythomata import SimpleDFA


class DFA:
    """Represents a DFA in which every state is accepting.

    Wraps a ``SimpleDFA`` whose initial state is 0. Words are strings of
    whitespace-separated symbols. Two ``DFA`` objects compare equal (and hash
    equal) when their sorted transition tables match, so they de-duplicate
    correctly inside a ``set``.
    """

    def __init__(
        self,
        num_nodes: int,
        alphabet: Tuple[str],
        transitions: Tuple[dict],
        rng: np.random.Generator,
    ):
        """Build the automaton.

        Args:
            num_nodes: number of states; states are ``0 .. num_nodes - 1``.
            alphabet: symbols of the automaton.
            transitions: one ``{symbol: next_state}`` dict per state, indexed
                by state.
            rng: random generator used by ``sample``.
        """
        assert len(transitions) == num_nodes
        self.dfa = SimpleDFA(
            states=set(range(num_nodes)),
            alphabet=set(alphabet),
            initial_state=0,
            accepting_states=set(range(num_nodes)),
            transition_function=dict(enumerate(transitions)),
        )
        self.rng = rng

    def _sorted_transitions(self):
        """Return the transition table as a nested tuple in a fixed order."""
        table = self.dfa._transition_function
        # Per node, transitions are sorted by outgoing (destination) state.
        return tuple(
            tuple(sorted(table[node].items(), key=lambda item: item[1]))
            for node in sorted(table)
        )

    def _minimize(self):
        """Replace the wrapped automaton by its minimized form; returns self."""
        self.dfa = self.dfa.minimize()
        return self

    def _trim(self):
        """Replace the wrapped automaton by its trimmed form; returns self."""
        self.dfa = self.dfa.trim()
        return self

    def __hash__(self):
        # Here I assume the initial state is always the smallest node
        return hash(self._sorted_transitions())

    def __eq__(self, other):
        # Must agree with __hash__: without it, set membership falls back to
        # identity and structurally identical DFAs are never de-duplicated.
        if not isinstance(other, DFA):
            return NotImplemented
        return self._sorted_transitions() == other._sorted_transitions()

    def _run(self, word: str):
        """Walk the automaton over ``word``.

        Returns:
            ``(path, consumed)``: the list of visited states (starting with
            the initial state) and whether every symbol had a transition.
        """
        table = self.dfa._transition_function
        node = self.dfa._initial_state
        path = [node]
        for symbol in word.split():
            outgoing = table[node]
            if symbol not in outgoing:
                return path, False
            node = outgoing[symbol]
            path.append(node)
        return path, True

    def __call__(self, word: str):
        """Return True if every symbol of ``word`` has a valid transition."""
        return self._run(word)[1]

    def forward(self, word: str):
        """Return the state reached after reading ``word``, or None if stuck."""
        path, consumed = self._run(word)
        return path[-1] if consumed else None

    def trace(self, word: str):
        """Return the states visited while reading ``word``.

        The path stops at the first symbol without a transition.
        """
        return self._run(word)[0]

    def sample(self, length=1):
        """Samples a random word from the DFA

        Starting at the initial state, ``length`` outgoing symbols are drawn
        with ``self.rng`` and returned joined by single spaces.
        """
        table = self.dfa._transition_function
        node = self.dfa._initial_state
        symbols = []
        for _ in range(length):
            symbol = self.rng.choice(list(table[node].keys()))
            symbols.append(symbol)
            node = table[node][symbol]
        return " ".join(symbols)


class RandomDFASampler:
    """Draws random ``DFA`` instances for a fixed node count and alphabet."""

    num_nodes: int
    alphabet: Tuple[str]
    max_outgoing_edge: int
    rng: np.random.Generator = None

    def __init__(
        self,
        num_nodes: int,
        alphabet: Tuple[str],
        max_outgoing_edge: int,
        seed: int = 42,
    ):
        """Store the configuration and seed a private random generator.

        Args:
            num_nodes: number of states of every sampled DFA.
            alphabet: symbols transitions may be labelled with.
            max_outgoing_edge: exclusive upper bound on the number of
                outgoing transitions drawn per node.
            seed: seed for the sampler's ``np.random.Generator``.
        """
        self.rng = np.random.default_rng(seed)
        self.num_nodes = num_nodes
        self.alphabet = alphabet
        self.max_outgoing_edge = max_outgoing_edge

    def sample(self):
        """Sample one random DFA.

        For each node, a number of transitions is drawn, then that many
        distinct symbols and distinct destination nodes (never the node
        itself). The returned DFA receives its own generator, seeded from
        this sampler's generator.
        """
        node_ids = range(self.num_nodes)
        per_node = []
        for source in node_ids:
            fan_out = self.rng.integers(1, self.max_outgoing_edge)
            symbols = self.rng.choice(self.alphabet, size=fan_out, replace=False)
            # exclude self loops
            candidates = [target for target in node_ids if target != source]
            targets = self.rng.choice(candidates, size=fan_out, replace=False)
            per_node.append(dict(zip(symbols, targets)))

        child_seed = self.rng.integers(0, 2**32)
        return DFA(
            self.num_nodes,
            self.alphabet,
            tuple(per_node),
            np.random.default_rng(child_seed),
        )

class Vocab:
    """Custom vocab: consecutive characters starting at "a" plus special tokens."""

    def __init__(self, vocab_size: int, special_vocabs: Dict):
        """Build the token list and the token -> id lookup.

        Args:
            vocab_size: number of regular (non-special) symbols.
            special_vocabs: mapping from a role name (e.g. "seperator",
                "noop") to its token string.
        """
        # Special tokens hold the separator and the noop/pad token etc.
        self.special_vocabs = special_vocabs

        # Regular symbols are the code points "a", "b", ... in order.
        # NOTE: no check is made that a regular symbol differs from a special
        # token; the set union below silently merges such a collision.
        first_code = 97
        symbols = [chr(first_code + offset) for offset in range(vocab_size)]
        self.non_special_vocab = sorted(symbols)

        merged = set(symbols)
        merged.update(self.special_vocabs.values())
        self.vocab = sorted(merged)

        self.v2id = dict(zip(self.vocab, range(len(self.vocab))))
        self.vocab_size = len(self.vocab)

    @property
    def seperator(self):
        """Token stored under the "seperator" key."""
        return self.special_vocabs["seperator"]

    @property
    def noop(self):
        """Token stored under the "noop" key."""
        return self.special_vocabs["noop"]

    @property
    def special_tokens(self):
        """Set of all special token strings."""
        return set(self.special_vocabs.values())

    def get_id(self, token: str):
        """Return the integer id of ``token`` (KeyError if unknown)."""
        return self.v2id[token]

    def get_vocab(self, id: int):
        """Return the token string stored at position ``id``."""
        return self.vocab[id]

    def __len__(self):
        return len(self.vocab)


class Tokenizer:
    """Custom Tokenizer for our own vocab (whitespace-separated tokens)."""

    def __init__(self, vocab: Vocab):
        self.vocab = vocab

    def tokenize(
        self, text: str, return_tensor: bool = False, mask_input: bool = False
    ):
        """Turn ``text`` into next-token prediction inputs and labels.

        Args:
            text: whitespace-separated tokens, each known to the vocab.
            return_tensor: when True, both outputs are ``torch.LongTensor``s
                instead of Python lists.
            mask_input: accepted but not used by this method.

        Returns:
            Dict with "input_ids" (all ids but the last) and "labels" (all
            ids but the first).
        """
        token_ids = list(map(self.vocab.get_id, text.split()))
        input_ids, labels = token_ids[:-1], token_ids[1:]

        if return_tensor:
            input_ids, labels = torch.LongTensor(input_ids), torch.LongTensor(labels)

        return {"input_ids": input_ids, "labels": labels}

    def decode(self, ids: list):
        """Map ids back to tokens and join them with single spaces."""
        return " ".join(self.vocab.get_vocab(token_id) for token_id in ids)


class SimpleDataset(Dataset):
    """Dataset of (input, target, dfa) triples stored in parallel containers."""

    def __init__(self, examples, dfas, tokenizer):
        """Store the data.

        Args:
            examples: indexable whose item 0 holds the inputs and item 1 the
                targets.
            dfas: per-example DFA objects, indexed like the inputs.
            tokenizer: kept as an attribute; not used by this class.
        """
        super().__init__()
        inputs = examples[0]
        targets = examples[1]
        self.inputs = inputs
        self.targets = targets
        self.dfas = dfas
        self.tokenizer = tokenizer

    def __len__(self):
        # The dataset size is defined by the inputs container.
        return len(self.inputs)

    def __getitem__(self, idx):
        item = (self.inputs[idx], self.targets[idx], self.dfas[idx])
        return item


class ICLDFADataModule(SequenceDataset):
    """Data module producing in-context-learning examples from random DFAs.

    ``setup()`` samples random DFAs, splits them into train / test / val and,
    for every DFA, generates one token sequence made of several words sampled
    from that DFA and separated by "|". The resulting padded tensors are
    exposed through ``self.dataset`` and the ``*_dataloader`` methods.
    """

    _name_ = "icl_dfa"

    def __init__(
        self,
        num_examples: int,
        num_test_examples: int,
        vocab_size: int,
        max_num_nodes: int,
        max_num_in_context_examples: int,
        min_num_in_context_examples: int,
        max_outgoing_edges: int,
        max_len_per_example: int,
        number_duplicates_per_epoch: int = 0,
        input_seq_len: int = 1024,
        seed: int = 0,
        batch_size: int = 32,
        split_train_test: bool = False,
        data_dir: str = None,
        *args,
        **kwargs,
    ):
        """Store the configuration and build the vocab / tokenizer.

        Note that this does not call ``SequenceDataset.__init__``; extra
        ``*args`` / ``**kwargs`` are accepted and ignored.

        Args:
            num_examples: number of DFAs (one sequence each) in the train split.
            num_test_examples: number of held-out DFAs; ``setup()`` gives the
                first half to "test" and the rest to "val".
            vocab_size: total vocabulary size including the two special
                tokens ("|" and "."); ``vocab_size - 2`` regular symbols exist.
            max_num_nodes: inclusive upper bound on the number of DFA nodes.
            max_num_in_context_examples: exclusive upper bound on the number
                of words per sequence.
            min_num_in_context_examples: inclusive lower bound on the number
                of words per sequence.
            max_outgoing_edges: passed to ``RandomDFASampler``; also the lower
                bound for the sampled node count and alphabet size.
            max_len_per_example: exclusive upper bound on the length of each
                sampled word.
            number_duplicates_per_epoch: stored; not used in this class.
            input_seq_len: maximum length, in characters, of the example
                string before tokenization (see ``generate_example``).
            seed: seed of the generator created in ``setup()``.
            batch_size: batch size of the dataloaders.
            split_train_test: stored; not used in this class.
            data_dir: stored; not used in this class.
        """
        self.num_examples = num_examples
        self.num_test_examples = num_test_examples
        self.vocab_size = vocab_size
        self.number_duplicates_per_epoch = number_duplicates_per_epoch

        self.batch_size = batch_size
        self.split_train_test = (
            split_train_test  # let the same copy chars appear in train/test
        )
        self.data_dir = data_dir
        self.max_num_nodes = max_num_nodes
        self.max_num_in_context_examples = max_num_in_context_examples
        self.min_num_in_context_examples = min_num_in_context_examples
        self.max_outgoing_edges = max_outgoing_edges
        self.max_len_per_example = max_len_per_example
        self.input_seq_len = input_seq_len
        self.seed = seed

        # "|" separates the words of one example; "." is used as padding for
        # the inputs in setup().
        special_vocabs = {"seperator": "|", "noop": "."}
        self.special_vocabs = special_vocabs
        # Two vocabulary slots are reserved for the special tokens.
        self.vocab = Vocab(vocab_size - 2, special_vocabs=special_vocabs)
        self.tokenizer = Tokenizer(self.vocab)

    def generate_example(self, dfa: DFA, num_examples: int):
        """Sample ``num_examples`` words from ``dfa`` and tokenize them.

        The words are joined with " | ". The joined string is truncated to
        ``self.input_seq_len`` characters (not tokens) before tokenization.
        Requires ``self.rng``, which is created in ``setup()``.

        Returns:
            The dict returned by ``self.tokenizer.tokenize(...,
            return_tensor=True)`` ("input_ids" and "labels").
        """
        example = ""
        for _ in range(num_examples):
            length = self.rng.integers(1, self.max_len_per_example)
            word = dfa.sample(length=length)
            example += word + " | "
        # Drop the trailing " | " left by the loop.
        example = example[:-3]
        if len(example) > self.input_seq_len:
            example = example[: self.input_seq_len]
        # example = " ".join(list(example))  # separate chars with space

        return self.tokenizer.tokenize(example, return_tensor=True)

    def setup(self, stage=None):
        """Generate the DFAs and the train / test / val datasets.

        Does nothing if ``self.dataset`` already exists. ``stage`` is unused.
        """
        if hasattr(self, "dataset"):
            return

        self.rng = np.random.default_rng(self.seed)

        # Collect DFAs in a set; at most 10 * num_examples sampling attempts
        # are made to reach num_examples + num_test_examples entries.
        DFAs = set([])
        for _ in range(self.num_examples * 10):
            num_nodes = self.rng.integers(
                self.max_outgoing_edges, self.max_num_nodes + 1
            )
            num_alphabet = self.rng.integers(
                self.max_outgoing_edges, self.vocab_size - 2 + 1
            )
            alphabet = self.rng.choice(
                self.vocab_size - 2, size=num_alphabet, replace=False
            )
            # Same index -> character mapping as used by Vocab.
            alphabet = tuple((chr(a + 97) for a in alphabet))
            sampler = RandomDFASampler(
                num_nodes,
                alphabet,
                self.max_outgoing_edges,
            )
            # Replace the sampler's default-seeded generator by one derived
            # from this module's generator.
            sampler.rng = np.random.default_rng(self.rng.integers(0, 2**32))
            dfa = sampler.sample()
            dfa._minimize()._trim()
            DFAs.add(dfa)
            if len(DFAs) >= self.num_examples + self.num_test_examples:
                break

        DFAs = list(DFAs)
        self.rng.shuffle(DFAs)

        if len(DFAs) < self.num_examples + self.num_test_examples:
            print(
                "Warning: not enough unique DFAs generated. Using all generated DFAs."
            )
            # scale back, keeping the requested train / held-out proportion
            self.num_examples = (len(DFAs) * self.num_examples) // (
                self.num_examples + self.num_test_examples
            )
            self.num_test_examples = len(DFAs) - self.num_examples
            print(
                f"New num_examples: {self.num_examples}, new num_test_examples:"
                f" {self.num_test_examples}"
            )

        # Train gets the first num_examples DFAs; the held-out DFAs are split
        # in half between "test" and "val".
        DFAs = {
            "train": DFAs[: self.num_examples],
            "test": DFAs[
                self.num_examples : self.num_examples + self.num_test_examples // 2
            ],
            "val": DFAs[
                self.num_examples
                + self.num_test_examples // 2 : self.num_examples
                + self.num_test_examples
            ],
        }

        examples = {"train": [], "test": [], "val": []}

        for split, dfas in DFAs.items():
            split_examples = []
            for dfa in dfas:
                # One sequence per DFA, with a random number of words.
                num_samples = self.rng.integers(
                    self.min_num_in_context_examples,
                    self.max_num_in_context_examples,
                )
                example = self.generate_example(dfa, num_samples)
                input, output = example["input_ids"], example["labels"]

                split_examples.append((input, output))

            # pad examples to same length (inputs are padded with the noop id)
            example_inputs = torch.nn.utils.rnn.pad_sequence(
                [example[0] for example in split_examples],
                batch_first=True,
                padding_value=self.vocab.get_id(self.vocab.noop),
            )

            # Targets are padded with -100.
            # NOTE(review): presumably chosen to match the loss function's
            # ignore index -- confirm against the training code.
            example_outputs = torch.nn.utils.rnn.pad_sequence(
                [example[1] for example in split_examples],
                batch_first=True,
                padding_value=-100,
            )

            # Separator positions in the targets get the same -100 marker.
            example_outputs[example_outputs == self.vocab.get_id("|")] = -100

            examples[split] = (example_inputs, example_outputs)

        self.dataset = {
            "train": SimpleDataset(
                examples=examples["train"], dfas=DFAs["train"], tokenizer=self.tokenizer
            ),
            "test": SimpleDataset(
                examples=examples["test"], dfas=DFAs["test"], tokenizer=self.tokenizer
            ),
            "val": SimpleDataset(
                examples=examples["val"], dfas=DFAs["val"], tokenizer=self.tokenizer
            ),
        }

    def _collate_fn(self, batch):
        """Stack inputs and targets into tensors; keep the DFAs as a tuple."""
        xs, ys, dfas = zip(*batch)
        xs = torch.stack(xs)
        ys = torch.stack(ys)
        return xs, ys, dfas

    def train_dataloader(self, *args, **kwargs):
        """Shuffled loader over the train split (arguments are ignored)."""
        return self._data_loader(self.dataset["train"], shuffle=True)

    def val_dataloader(self, *args, **kwargs):
        """Unshuffled loader over the val split (arguments are ignored)."""
        return self._data_loader(self.dataset["val"], shuffle=False)

    def test_dataloader(self, *args, **kwargs):
        """Unshuffled loader over the test split (arguments are ignored)."""
        return self._data_loader(self.dataset["test"], shuffle=False)

    def _data_loader(self, dataset: Dataset, shuffle: bool = False) -> DataLoader:
        """Create a ``DataLoader`` with this module's batch size and collate fn.

        Uses two persistent worker processes.
        """
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=2,
            shuffle=shuffle,
            collate_fn=self._collate_fn,
            persistent_workers=True,
        )


def sample_usage():
    """Demo: sample one random 4-node DFA and print two 10-symbol words."""
    sampler = RandomDFASampler(4, ("a", "b", "c", "d"), 4, seed=2)
    automaton = sampler.sample()
    # Both words come from the same automaton, drawn one after the other.
    for _ in range(2):
        print(automaton.sample(length=10))


# Run the small demo; it prints two sampled words.
sample_usage()



# Build the DFA in-context-learning dataset used by the cells below:
# 1000 training DFAs plus 500 held-out DFAs (split between test and val in
# setup()), over a vocabulary of 20 tokens including the two special tokens.
data_module = ICLDFADataModule(
    num_examples=1000,
    num_test_examples=500,
    vocab_size=20,
    max_num_nodes=12,
    max_num_in_context_examples=20,
    min_num_in_context_examples=10,
    max_outgoing_edges=4,
    max_len_per_example=50,
    seed=42,
    batch_size=32,
    split_train_test=False,
    data_dir=None,
)

# Generate the DFAs and the padded tensors (populates data_module.dataset).
data_module.setup()

# PyTorch loaders over the train and test splits.
train_loader = data_module.train_dataloader()
test_loader = data_module.test_dataloader()


c c a b c a a b c c
c c a a a c c c c a


### DFA ICL generator

In [22]:
class DFADataGenerator(SequenceDataGenerator):
    """Serves random batches from a pre-generated DFA dataset as JAX arrays.

    The train and test splits of ``data_module.dataset`` are copied into
    ``jnp`` arrays once at construction; batches are then drawn by random
    indexing (with replacement) instead of going through a PyTorch
    ``DataLoader``.
    """

    def __init__(self,
                 data_module,
                 seq_len: int,
                 data_dim: int,
                 eye_obs: bool = True,
                 init_seed: int = 42):
        """Load both splits of ``data_module`` into JAX arrays.

        Args:
            data_module: object whose ``dataset`` dict has 'train' and 'test'
                entries exposing ``inputs``, ``targets`` and ``dfas``
                (``setup()`` must already have been run).
            seq_len: forwarded to the base class.
            data_dim: forwarded to the base class.
            eye_obs: forwarded to the base class.
            init_seed: seed of the PRNG key stored in ``self.rng``.
        """
        super().__init__(seq_len=seq_len, data_dim=data_dim, eye_obs=eye_obs)
        self.data_module = data_module

        self.train_dataset = self.data_module.dataset['train']
        self.train_inps = jnp.array(self.train_dataset.inputs)
        self.train_tags = jnp.array(self.train_dataset.targets)
        self.train_dfa_indices = jnp.array(range(len(self.train_dataset.dfas)))
        self.train_dfas = self.train_dataset.dfas
        print(f'train set size: {self.train_inps.shape}')

        self.test_dataset = self.data_module.dataset['test']
        self.test_inps = jnp.array(self.test_dataset.inputs)
        self.test_tags = jnp.array(self.test_dataset.targets)
        self.test_dfa_indices = jnp.array(range(len(self.test_dataset.dfas)))
        self.test_dfas = self.test_dataset.dfas
        print(f'test set size: {self.test_inps.shape}')

        self.len_test = self.test_inps.shape[0]
        self.len_train = self.train_inps.shape[0]

        self.rng = jax.random.PRNGKey(init_seed)

    @staticmethod
    def _sample_batch(rng, batch_size, size, inps, tags, dfas):
        """Draw ``batch_size`` random rows (with replacement) from one split."""
        b_idxs = jax.random.randint(rng, shape=(batch_size,), minval=0, maxval=int(size))
        # Convert the indices to Python ints in one transfer; iterating the
        # JAX array directly would create one device scalar per element.
        batch_dfas = [dfas[i] for i in np.asarray(b_idxs).tolist()]
        return inps[b_idxs], tags[b_idxs], batch_dfas

    def _quick_data_loader_train(self, batch_size: int, rng: jax.random.PRNGKey):
        """Faster replacement for the PyTorch dataloader (train split).

        Returns:
            ``(inputs, targets, dfas)`` for ``batch_size`` randomly chosen
            training examples; ``dfas`` is a Python list.
        """
        return self._sample_batch(rng, batch_size, self.len_train,
                                  self.train_inps, self.train_tags, self.train_dfas)

    def _quick_data_loader_test(self, batch_size: int, rng: jax.random.PRNGKey):
        """Faster replacement for the PyTorch dataloader (test split).

        Returns:
            ``(inputs, targets, dfas)`` for ``batch_size`` randomly chosen
            test examples; ``dfas`` is a Python list.
        """
        return self._sample_batch(rng, batch_size, self.len_test,
                                  self.test_inps, self.test_tags, self.test_dfas)

    def get_data(self,
                 rng: random.PRNGKey,
                 batch_size: int
    ) -> Any:
        """Gets a batch of training data.

        Returns:
            ``((inputs, targets, dfas), (None, None))``.
        """
        batch_inps, batch_tags, batch_dfas = self._quick_data_loader_train(batch_size, rng)
        return (batch_inps, batch_tags, batch_dfas), (None, None)

    def get_test_data(self,
                 rng: random.PRNGKey,
                 batch_size: int
    ) -> Any:
        """Gets a batch of test data.

        Returns:
            ``((inputs, targets, dfas), (None, None))``.
        """
        batch_inps, batch_tags, batch_dfas = self._quick_data_loader_test(batch_size, rng)
        return (batch_inps, batch_tags, batch_dfas), (None, None)

    def get_data_info(self) -> Dict[str, Any]:
        """Returns the data information."""
        return {
            'seq_len': self.seq_len,
            'data_dim': self.data_dim,
            'eye_obs': self.eye_obs,
            'vocab_size': self.data_module.vocab_size,
            'obs_dim': self.data_module.vocab_size,
            'train_size': len(self.data_module.dataset['train']),
            'test_size': len(self.data_module.dataset['test'])
        }

    def generate_sequence(self,
                          W: jnp.ndarray,
                          x_1: jnp.ndarray,
                          seq_length: int,
                          rng: random.PRNGKey) -> jnp.ndarray:
        """Not implemented as we use pre-generated sequences.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("DFA sequences are pre-generated")

    def create_batch(self,
                     rng: random.PRNGKey,
                     batch_size: int,
                     data_dim: int,
                     seq_len: int) -> Tuple[Tuple[jnp.ndarray]]:
        """Not implemented as we use pre-generated batches.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("DFA batches are pre-generated")

In [23]:
# Smoke test: wrap the data module built above (seq_len=100, data_dim=10)
# and draw one training batch of 32 examples with a fixed PRNG key.
testgen = DFADataGenerator(data_module, 100, 10)
b = testgen._quick_data_loader_train(batch_size=32, rng=jax.random.PRNGKey(0))

train set size: (1000, 511)
test set size: (250, 511)


Quick test to check that the DFA generator works

In [24]:
#dfa_generator = DFADataGenerator(
#    data_module=data_module,
#    seq_len=data_module.input_seq_len,
#    data_dim=data_module.vocab_size,
#    eye_obs=True
#)
#
#train_batch = dfa_generator.get_data(jax.random.PRNGKey(0), batch_size=32)
#test_batch = dfa_generator.get_test_data(batch_size=32)

## (Synthetic) ICL Generators

In [25]:
class ICLDataGenerator(DataGenerator):
    '''Abstract Base Class for ICL data generators'''
    def __init__(self, noise:float):
        '''
            Args:
                'noise' (float): scale of the gaussian noise added to the targets
        '''
        self.noise = noise
        self.constr = False

    @abc.abstractmethod
    def get_data(self,
                 rng:random.PRNGKey,
                 batch_size:int,
                 **kwargs) -> Tuple[Tuple[chex.Array, ...]]:
        '''Abstract method to get data batch'''
        raise NotImplementedError

    @abc.abstractmethod
    def get_data_info(self) -> Dict[str, Any]:
        '''Abstract method to get data info'''
        raise NotImplementedError

    def _multi_mult(self, w: chex.Array, X: chex.Array) -> chex.Array:
        '''
            Matrix multiplication for multiplication of every token in X with w from the left
            Args:
                'w' (chex.Array): weight matrix
                'X' (chex.Array): input matrix, mapped over its leading axis
            Returns:
                result of multiplication of every token in X with w from the left
        '''
        return vmap(jnp.matmul, in_axes=(None,0))(w,X)

    def _noisy_targets(self,
                       rng:random.PRNGKey,
                       w:chex.Array,
                       x:chex.Array) -> chex.Array:
        '''Targets w @ x_i for every token x_i, plus gaussian noise scaled by self.noise'''
        wx = self._multi_mult(w=w,X=x)
        return wx + self.noise * random.normal(rng, shape=wx.shape)

    def gen_one_seq_eos(self,
                        rng:random.PRNGKey,
                        w:chex.Array,
                        x:chex.Array,
                        sub_seq_length:int,
                        eos:chex.Array) -> chex.Array:
        '''
            Generate a sequence with sub_seq_length*3 length with x, f_x and eos tokens: [x1, f_x1, eos, x2, f_x2, eos, ...]
            Args:
                'rng' (random.PRNGKey): random key
                'w' (chex.Array): weight matrix
                'x' (chex.Array): input matrix
                'sub_seq_length' (int): length of the sequence
                'eos' (chex.Array): end of sequence token
            Returns:
                sequence with sub_seq_length*3 length with x, f_x and eos tokens
        '''
        f_x = self._noisy_targets(rng, w, x)
        # Repeat the eos token once per (x, f_x) pair.
        eos_tokens = jnp.ones(shape=(sub_seq_length,1)) @ (eos.T[None,...])
        result = jnp.zeros(shape=(sub_seq_length*3, x.shape[-1]))
        # Interleave: rows 0,3,6,... <- x; rows 1,4,7,... <- f_x; rows 2,5,8,... <- eos.
        for i, update in enumerate([x, f_x, eos_tokens]):
            result = result.at[i::3, :].set(update)
        return result

    def gen_one_seq(self,
                    rng:random.PRNGKey,
                    w:chex.Array,
                    x:chex.Array,
                    sub_seq_length:int) -> chex.Array:
        '''
            Generate a sequence with sub_seq_length*2 length with x and f_x tokens: [x1, f_x1, x2, f_x2, ...]
            Args:
                rng: random key
                w: weight matrix
                x: input matrix
                sub_seq_length: length of the sequence
            Returns:
                sequence with sub_seq_length*2 length with x and f_x tokens
        '''
        f_x = self._noisy_targets(rng, w, x)
        result = jnp.zeros(shape=(sub_seq_length*2, x.shape[-1]))
        # Interleave: even rows <- x, odd rows <- f_x.
        for i, update in enumerate([x, f_x]):
            result = result.at[i::2, :].set(update)
        return result

    @abc.abstractmethod
    def create_batch(self,
                     rng:random.PRNGKey,
                     batch_size:int,
                     data_dim:int,
                     **kwargs) -> Tuple[Tuple[chex.Array, ...]]:
        '''Abstract method to create a batch of sequences'''
        raise NotImplementedError

## Synthetic Autoregressive Sequence Generators

### Linear Sequence Gen.

In [26]:
class LinearSequenceDataGenerator(SequenceDataGenerator):
    '''
    Generates batches of noisy linear sequences x_{t+1} = clip(W x_t + noise),
    with one random orthogonal W per sequence, together with linear observations
    of the hidden states.
    '''

    def __init__(self,
                 seq_len: int,
                 data_dim: int,
                 range: float,
                 noise: float,
                 noise_obs: float,
                 data_clip: float,
                 obs_dim: int = 10,
                 eye_obs: bool = True):
        '''
        Initializes the generator.
        Args:
            'seq_len' (int): The length of the generated sequences.
            'data_dim' (int): The dimensionality of the hidden states.
            'range' (float): Initial states are drawn uniformly from [-range, range].
            'noise' (float): Scale of the gaussian noise added at every transition.
            'noise_obs' (float): Scale of the gaussian noise added to the observations.
            'data_clip' (float): Hidden states are clipped to [-data_clip, data_clip].
            'obs_dim' (int): The dimensionality of the observations.
            'eye_obs' (bool): Use an identity observation matrix instead of a random one.
        '''
        super().__init__(seq_len=seq_len, data_dim=data_dim, eye_obs=eye_obs)
        self.obs_dim = obs_dim
        self.range = range
        self.noise = noise
        self.noise_obs = noise_obs
        self.data_clip = data_clip

    def get_data(self,
                 rng:random.PRNGKey,
                 batch_size:int) -> Tuple[Tuple[jnp.ndarray]]:
        '''
        Gets a batch of data with resp. partial observations.
        Args:
            'batch_size' (int): The batch size.
            'rng' (random.PRNGKey): The random number generator key.
        Returns:
            Tuple[Tuple[jnp.ndarray]]: ((observed_data, observed_labels, None), (original_data, original_labels)),
            exactly as produced by `create_batch`.
        '''
        return self.create_batch(rng=rng,
                                 batch_size=batch_size,
                                 data_dim=self.data_dim,
                                 seq_len=self.seq_len,
                                 eye_obs=self.eye_obs)

    def get_data_info(self) -> Dict[str, Any]:
        '''
        Returns the data information as a dict: the instance attributes plus
        'vocab_size' (set to data_dim).
        '''
        # Work on a shallow copy: vars(self) is the live instance __dict__, so writing
        # 'vocab_size' into it would silently add an attribute to the generator and
        # expose the instance state to mutation by callers.
        info = dict(vars(self))
        info['vocab_size'] = self.data_dim
        return info

    def generate_sequence(self,
                          W: jnp.ndarray,
                          x_1: jnp.ndarray,
                          seq_length: int,
                          rng: random.PRNGKey) -> jnp.ndarray:
        '''
        Generates a sequence of tokens
        Args:
            'W' (ndarray): The weight matrix [D, D]
            'x_1' (ndarray): The initial input vector [D]
            'seq_length' (int): The length S of the sequence
            'rng' (PRNGKey): The random number generator key for added gaussian noise
        Returns:
            'ndarray': The generated sequence [S, D]
        '''
        seq_rng  = random.split(rng, seq_length)
        def step(prev_x, rng):
            # One noisy linear transition followed by clipping.
            next_x = jnp.matmul(W, prev_x) + self.noise * random.normal(rng, shape=x_1.shape)
            next_x = jnp.clip(next_x, -self.data_clip, self.data_clip)
            return next_x, next_x
        # Only S-1 transitions are needed because x_1 itself is the first element.
        _, sequence = lax.scan(step, x_1, seq_rng[:-1])
        # x_1 is prepended as given (it is not clipped).
        sequence = jnp.concatenate([jnp.expand_dims(x_1, 0), sequence], axis=0)
        return sequence

    def _obs_and_noise(self,
                       obs_mat: jnp.ndarray,
                       x: jnp.ndarray,
                       noise: jnp.ndarray) -> jnp.ndarray:
        '''
        Applies the per-sequence observation matrix `obs_mat` to every hidden state in `x`
        and adds `noise` to the result.
        Args:
            'obs_mat' (ndarray): The observation matrix [B, obs_dim, data_dim]
            'x' (ndarray): The hidden states [B, seq_len, data_dim]
            'noise' (ndarray): The noise added to the observations

        Returns:
            ndarray: Observed data with added gaussian noise
        '''
        # Outer vmap: over the batch (matrix and states); inner vmap: over the time steps
        # of one sequence while the observation matrix stays fixed.
        return vmap(vmap(jnp.matmul,in_axes=(None,0)), in_axes=(0,0))(obs_mat, x) + noise

    @partial(jit, static_argnums=(0,2,3,4,5))
    def create_batch(self,
                     rng: random.PRNGKey,
                     batch_size: int,
                     data_dim: int,
                     seq_len: int,
                     eye_obs: bool) -> Tuple[Tuple[jnp.ndarray]]:
        '''
        Creates a batch of linear sequences
        Args:
            'rng' (PRNGKey): The random number generator key
            'batch_size' (int): The batch size
            'data_dim' (int): The dimensionality of the data
            'seq_len' (int): The length of the sequence
            'eye_obs' (bool): Use raw hidden states as observations/inputs to model
        Returns:
            Tuple[Tuple[ndarray]]: (observed_data, observed_labels, None), (original_data, original_labels)
        '''
        rng, subkeyW, subkeyX, subkeyN1, subkeyN2, subkeyObs = random.split(rng, 6)
        batch_of_noise_keys = random.split(subkeyN1, batch_size)
        W = random.orthogonal(key=subkeyW,
                              n=data_dim,
                              shape=(batch_size,))
        X = random.uniform(key=subkeyX,
                           shape=(batch_size, data_dim),
                           minval=-self.range,
                           maxval=self.range)

        # Generate one extra step so inputs and next-step labels are shifted views of the same data.
        dataset = vmap(partial(self.generate_sequence, seq_length=seq_len+1))(W=W,x_1=X,rng=batch_of_noise_keys)
        original_data, original_labels = dataset[:,:-1,:], dataset[:,1:,:]

        # NOTE(review): the identity branch builds a square [obs_dim, obs_dim] matrix, so it
        # presumably requires obs_dim == data_dim - confirm against the callers.
        obs_mat = jnp.eye(self.obs_dim)[None, :, :].repeat(batch_size, axis=0) if eye_obs \
                        else 0.5*random.normal(subkeyObs, shape=(batch_size,self.obs_dim,data_dim))

        # Noise is drawn for all seq_len+1 steps and sliced like the data, so a time step
        # that appears both as an input and as a label gets the same observation noise.
        new_noise = self.noise_obs * random.normal(subkeyN2, shape=(original_data.shape[0],original_data.shape[1]+1,self.obs_dim))
        observed_data = self._obs_and_noise(obs_mat, original_data, new_noise[:,:-1,:])
        observed_labels = self._obs_and_noise(obs_mat, original_labels, new_noise[:,1:,:])

        return (observed_data, observed_labels, None), (original_data, original_labels)

### Constructed-Tokens (Wrapper)

In [27]:
class ConstructedPartObsGenerator(DataGenerator):
    '''Data generator for constructing data with partial observations'''
    def __init__(self,
                 data_generator: DataGenerator,
                 embed_dim: int):
        '''
        Initializes the ConstructedPartObsGenerator.
        Args:
            'data_generator' (DataGenerator): The data generator.
            'embed_dim' (int): The embedding dimension.
        '''

        self.data_generator = data_generator
        self.embed_dim = embed_dim
        data_info = data_generator.get_data_info()
        self.seq_len = data_info['seq_len']
        self.data_dim = data_info['data_dim']
        self.obs_dim = data_info['obs_dim']
        self.constr = True
        # Number of obs_dim-sized slots that fit into one constructed token.
        self.slots = self.embed_dim // self.obs_dim

    def get_data(self,
                 rng: random.PRNGKey,
                 batch_size: int) -> Tuple[Tuple[jnp.ndarray]]:
        '''
        Gets a batch of constructed data with partial observations.
        Args:
            'rng' (random.PRNGKey): The random number generator key.
            'batch_size' (int): The batch size.
        Returns:
            Tuple[Tuple[jnp.ndarray]]: A tuple containing the constructed data and the original data.
        '''
        return self.create_batch(rng=rng,
                                 batch_size=batch_size)

    def get_data_info(self) -> Dict[str, Any]:
        '''Returns the data information as a dict.'''
        return vars(self)

    @partial(jit, static_argnums=(0,2))
    def create_batch(self,
                     rng: random.PRNGKey,
                     batch_size: int) -> Tuple[Tuple[jnp.ndarray]]:
        '''
        Creates a batch of constructed data with partial observations.
        Slot 0 of every token holds the current observation and slot k (k >= 1) holds the
        observation from k steps earlier, zero-padded at the start of the sequence.
        Args:
            'rng' (random.PRNGKey): The random number generator key.
            'batch_size' (int): The batch size.
        Returns:
            Tuple[Tuple[jnp.ndarray]]: (constructed_data, observed_labels), (original_data, original_labels)
        '''
        observed, original = self.data_generator.get_data(rng=rng, batch_size=batch_size)
        # Index instead of unpacking exactly two items: wrapped generators may append extra
        # entries to the observed tuple (LinearSequenceDataGenerator returns a third, None).
        observed_data, observed_labels = observed[0], observed[1]
        original_data, original_labels = original[0], original[1]

        constructed_data = jnp.zeros(shape=(batch_size, self.seq_len, self.embed_dim))
        constructed_data = constructed_data.at[:,:,0:self.obs_dim].set(observed_data)
        for k in range(1, self.embed_dim // self.obs_dim):
            # Shift the observations k steps into the past; the first k positions have no history.
            shifted_data = jnp.concatenate((jnp.zeros(shape=(batch_size,(k),self.obs_dim)),observed_data[:,:-1*(k),:]),axis=1)
            constructed_data = constructed_data.at[:,:,k*self.obs_dim:(k+1)*self.obs_dim].set(shifted_data)
        return (constructed_data, observed_labels), (original_data, original_labels)


class ConstructedFullSeqGenerator(DataGenerator):
    '''Data generator for constructing data with fully observed sequences'''
    def __init__(self,
                 data_generator: DataGenerator,
                 embed_dim: int,
                 token_format):
        '''
        Initializes the ConstructedFullSeqGenerator.
        Args:
            'data_generator' (DataGenerator): The data generator.
            'embed_dim' (int): The embedding dimension.
            'token_format' (str): "compact" builds [x_t, x_{t-1}] tokens; any other value
                (conventionally "full") builds [0, 0, x_t, x_{t-1}] tokens.
        '''
        self.data_generator = data_generator
        self.embed_dim = embed_dim
        data_info = data_generator.get_data_info()
        self.seq_len = data_info['seq_len']
        self.data_dim = data_info['data_dim']
        self.constr = True
        self.slots = 4
        self.obs_dim = data_info['obs_dim']
        # Expected values: "full" or "compact" (see create_batch).
        self.token_format = token_format

    @partial(jit, static_argnums=(0,2))
    def get_data(self,
                 rng: random.PRNGKey,
                 batch_size: int):
        '''
        Gets a batch of constructed data with fully observed sequences.
        Args:
            'rng' (random.PRNGKey): The random number generator key.
            'batch_size' (int): The batch size.
        Returns:
            Tuple[Tuple[jnp.ndarray]]: A tuple containing the constructed data and the original data.
        '''
        return self.create_batch(rng=rng,
                                 batch_size=batch_size)

    def get_data_info(self) -> Dict[str, Any]:
        '''
        Gets the data information.
        Returns:
            Dict[str, Any]: The data information.
        '''
        return vars(self)

    @partial(jit, static_argnums=(0,2))
    def create_batch(
        self,
        rng: random.PRNGKey,
        batch_size: int,
    ) -> Tuple[Tuple[jnp.ndarray]]:
        '''
        Creates a batch of tokens in either format:
        - full: [0, 0, x_t, x_{t-1}]
        - compact: [x_t, x_{t-1}]

        Args:
            rng (random.PRNGKey): The random number generator key
            batch_size (int): The batch size

        Returns:
            Tuple[Tuple[jnp.ndarray]]: (constructed_data, labels), (original_data, labels)
        '''

        observed, original = self.data_generator.get_data(rng=rng, batch_size=batch_size)
        # Index instead of unpacking exactly two items: wrapped generators may append extra
        # entries to the observed tuple (LinearSequenceDataGenerator returns a third, None).
        observed_data, observed_labels = observed[0], observed[1]
        original_data, original_labels = original[0], original[1]

        # x_{t-1}: drop the last step and zero-pad the front of the time axis.
        shifted_data = jnp.pad(
            observed_data[:, :-1, :],
            ((0, 0), (1, 0), (0, 0)),
            mode='constant'
        )

        if self.token_format == "compact":
            # [x_t, x_{t-1}]
            constructed_data = jnp.concatenate(
                [observed_data, shifted_data],
                axis=-1
            )
        else:
            # [0, 0, x_t, x_{t-1}]
            constructed_data = jnp.concatenate([
                jnp.zeros((batch_size, self.seq_len, 2 * self.obs_dim)),
                observed_data,
                shifted_data
            ], axis=-1)

        return (constructed_data, observed_labels), (original_data, original_labels)


# Training Util

## Standard

In [28]:
def _compute_loss(preds: jnp.ndarray, targets: jnp.ndarray) -> jnp.ndarray:
        '''Computes the mean squared error (MSE) loss for a batch.'''
        assert preds.shape == targets.shape
        bs, sl, _ = preds.shape
        return (jnp.sum((targets -preds)**2)/(2*bs*sl))

def count_parameters(params) -> int:
    """
    Return how many scalar entries a parameter pytree holds.

    Args:
        params: Model parameters from state.params or model.init(...)['params']

    Returns:
        Sum of `.size` over all leaves of the pytree
    """
    total = 0
    for leaf in jax.tree_util.tree_leaves(params):
        total += leaf.size
    return total

@jit
def _compute_token_loss(logits: jnp.ndarray, targets: jnp.ndarray) -> jnp.ndarray:
    """Mean cross-entropy over all target positions that are not -100.

    The ignore marker is handled with a float mask rather than boolean indexing.
    `logits` must be rank 3; its last axis is the vocabulary axis.
    """
    _, _, vocab_size = logits.shape
    flat_logits = logits.reshape(-1, vocab_size)
    flat_targets = targets.reshape(-1)

    # 1.0 where the position counts, 0.0 where the target is the ignore marker
    keep = (flat_targets != -100).astype(jnp.float32)

    # Negative targets are clamped to class 0 so the lookup stays valid; the -100
    # positions are zeroed out by `keep` afterwards.
    per_token_ce = optax.softmax_cross_entropy_with_integer_labels(
        flat_logits,
        jnp.maximum(flat_targets, 0)
    )

    # The small epsilon keeps the division finite when nothing is kept.
    return jnp.sum(per_token_ce * keep) / (jnp.sum(keep) + 1e-8)

def compute_dfa_accuracy(preds, inputs, dfas, vocab, noop_token="."):
    """Fraction of scored predictions that the matching DFA accepts, over a whole batch.

    For every position t of the noop-stripped input, the word formed by the inputs up to t
    plus the prediction at t (restricted to the part after the last " | " separator) is
    handed to the DFA. Positions directly followed by a "|" in the input are not scored.
    Returns total_correct / (total_count + 1e-8).
    """
    lookup = vocab.get_vocab

    def score_sequence(pred_seq, input_seq, dfa):
        # Token IDs -> characters; noop tokens are removed from the inputs only
        pred_chars = [lookup(tok) for tok in pred_seq]
        input_chars = [lookup(tok) for tok in input_seq if lookup(tok) != noop_token]

        seen = 0.0
        hits = 0.0
        n_inputs = len(input_chars)
        for t in range(n_inputs):
            if n_inputs > t + 1:
                upcoming = input_chars[t + 1]
                if upcoming == "|":
                    continue
                if upcoming == noop_token:
                    break

            if len(pred_chars) <= t:
                continue

            candidate = input_chars[:t + 1] + [pred_chars[t]]
            word = " ".join(candidate).split(" | ")[-1]
            if not word:
                continue
            verdict = dfa(word)
            seen += 1
            hits += verdict

        return hits, seen

    # Accumulate over the batch; one DFA per sequence
    total_correct = 0
    total_count = 0
    for b, dfa in enumerate(dfas):
        hits, seen = score_sequence(preds[b], inputs[b], dfa)
        total_correct += hits
        total_count += seen

    return total_correct / (total_count + 1e-8)


def compute_icl_dfa_accuracy(self, predictions, inputs, dfas, vocab):
    """
    Accuracy for the in-context learning DFA task.
    Each sequence holds several examples separated by '|'. For every scored position the
    current example (everything after the last separator) extended by the predicted token
    is checked against the sequence's DFA.
    Returns total_correct / (total_count + 1e-8).
    """
    separator = "|"
    padding = "."

    # Move everything to numpy so the loops below run on the CPU
    pred_rows = np.array(predictions)
    input_rows = np.array(inputs)

    hits = 0
    scored = 0
    for b, dfa in enumerate(dfas):
        # Token IDs -> characters; padding is dropped from the inputs only
        pred_chars = [vocab.get_vocab(tok) for tok in pred_rows[b]]
        input_chars = [vocab.get_vocab(tok) for tok in input_rows[b]
                       if vocab.get_vocab(tok) != padding]

        n_inputs = len(input_chars)
        for t in range(n_inputs):
            if n_inputs > t + 1:
                upcoming = input_chars[t + 1]
                if upcoming == separator:  # the next token only closes the example
                    continue
                if upcoming == padding:  # end of sequence
                    break

            if len(pred_chars) <= t:
                continue

            # Context seen so far plus the model's prediction for the next token
            context = input_chars[:t + 1] + [pred_chars[t]]
            # Keep only the example after the last separator
            word = " ".join(context).split(" | ")[-1]
            if not word:
                continue
            verdict = int(dfa(word))
            scored += 1
            hits += verdict

    return hits / (scored + 1e-8)

class Training:
    '''Class for training the transformer model.'''
    def __init__(
        self,
        model: flax.linen.Module,
        optimizer: optax.GradientTransformation,
        data_generator: DataGenerator,
        batch_size: int,
        test_batch_size: int,
        task_type: str = 'continuous'  # 'continuous' or 'discrete'
    ):
        '''
        Initializes the training class with the specified parameters.
        Args:
            'model' (flax.linen.Module): The transformer model.
            'optimizer' (optax.GradientTransformation): The optimizer.
            'data_generator' (DataGenerator): The data generator.
            'batch_size' (int): The batch size for training.
            'test_batch_size' (int): The batch size for testing.
            'task_type' (str): The type of the task, either 'continuous' or 'discrete'.
        '''
        self.model = model
        self.optimizer = optimizer
        self.data_generator = data_generator
        self.batch_size = batch_size
        self.test_batch_size = test_batch_size
        self.task_type = task_type

        # Get data info based on task type
        data_info = data_generator.get_data_info()
        if task_type == 'continuous':
            self.obs_dim = data_info['obs_dim']
        else:  # discrete
            self.vocab_size = data_info['vocab_size']
            # 'obs_dim' is only needed by the continuous loss, so a discrete generator that
            # does not report it must not fail here; keep the value when it is available.
            self.obs_dim = data_info.get('obs_dim')

    def get_init_state(self,
                       rng: random.PRNGKey) -> Tuple[train_state.TrainState, random.PRNGKey]:
        '''
        Initializes the training state with the specified random number generator key.
        Args:
            'rng' (jax.random.PRNGKey): The random number generator key.
        Returns:
            Tuple[train_state.TrainState, random.PRNGKey]: The initial training state and the
            remaining random key.
        '''
        rng, ex_rng, init_rng = random.split(rng, 3)
        observed, _ = self.data_generator.get_data(rng=ex_rng, batch_size=self.batch_size)
        # Only the inputs are needed to trace the model; indexing works whether the generator
        # returns (inputs, labels) or (inputs, labels, extra) as its observed part.
        exmp_inp = observed[0]
        params = self.model.init({'params': init_rng}, exmp_inp)['params']
        state_init = train_state.TrainState.create(apply_fn=self.model.apply, params=params, tx=self.optimizer)
        return state_init, rng

    def batch_to_input(self, batch: Tuple[jnp.ndarray, jnp.ndarray]) -> jnp.ndarray:
        '''Extracts the input data from the batch.'''
        data, _ = batch
        return data

    def calculate_loss(
        self,
        params: flax.core.frozen_dict.FrozenDict,
        batch: Tuple[jnp.ndarray, jnp.ndarray],
        ) -> Tuple[jnp.ndarray, Tuple[Any]]:
        '''
        Calculates the differentiable loss function.
        Args:
            'params' (flax.core.frozen_dict.FrozenDict): The model parameters.
            'batch' (Tuple[jnp.ndarray, jnp.ndarray]): The (inputs, labels) pair.
        Returns:
            Tuple[jnp.ndarray, Tuple[Any]]: The computed loss and the auxiliary pair
            (logits, tf_data), where tf_data is the second output of the model.
        '''
        inp_data, labels = batch
        logits, tf_data = self.model.apply({'params': params}, inp_data)

        if self.task_type == 'continuous':
            # Only the first obs_dim output features are compared against the labels.
            preds = logits[:, :, :self.obs_dim]
            loss = _compute_loss(preds=preds, targets=labels)
        else:  # discrete
            loss = _compute_token_loss(logits=logits, targets=labels)

        return loss, (logits, tf_data)

    @partial(jit, static_argnums=(0))
    def fast_train_step(
        self,
        state: train_state.TrainState,
        batch: Tuple[jnp.ndarray, jnp.ndarray],
    ) -> Tuple[train_state.TrainState, jnp.ndarray, Tuple[Any]]:
        '''
        Performs a single training step.
        Args:
            'state' (flax.training.train_state.TrainState): The current training state.
            'batch' (Tuple[jnp.ndarray, jnp.ndarray]): The input batch.
        Returns:
            Tuple[train_state.TrainState, jnp.ndarray, Tuple[Any]]:
                - 'state' (flax.training.train_state.TrainState): The updated training state,
                - 'loss' (jnp.ndarray): The computed loss,
                - 'tf_data' (Tuple[Any]): The auxiliary output of `calculate_loss`, i.e. the
                  pair (logits, data from the transformer forward pass).
        '''
        loss_fn = lambda params: self.calculate_loss(params=params, batch=batch)
        (loss, (tf_data)), grads = value_and_grad(loss_fn, has_aux=True)(state.params)
        state = state.apply_gradients(grads=grads)
        return state, loss, tf_data

    @partial(jit, static_argnums=(0,))
    def _compute_step_outputs(
        self,
        params: flax.core.frozen_dict.FrozenDict,
        inputs: jnp.ndarray,
        targets: jnp.ndarray
    ):
        '''
        JIT-compiled part of the evaluation step.
        Returns:
            The token-level loss and the argmax predictions, clipped to [0, vocab_size - 1].
        '''
        logits, tf_data = self.model.apply({'params': params}, inputs)

        # Bound the logits to prevent overflow in the softmax
        logits = jnp.clip(logits, -1e7, 1e7)
        loss = _compute_token_loss(logits=logits, targets=targets)
        preds = jnp.argmax(logits, axis=-1)

        # Clip predictions to valid vocab range
        preds = jnp.clip(preds, 0, self.vocab_size - 1)

        return loss, preds

    def fast_eval_step(
        self,
        params: flax.core.frozen_dict.FrozenDict,
        batch: Tuple
    ):
        '''
        Evaluation step split into a jitted JAX part and a Python part.
        Args:
            'params' (flax.core.frozen_dict.FrozenDict): The model parameters.
            'batch' (Tuple): A batch whose first entry is (inputs, targets, dfas).
        Returns:
            (loss, accuracy); the accuracy is the DFA-based accuracy for discrete tasks
            and 0.0 otherwise.
        '''
        # NOTE(review): the loss below always uses the token-level path and self.vocab_size,
        # which __init__ only sets for discrete tasks - confirm whether continuous tasks are
        # expected to reach this method.
        (inputs, targets, dfas) = batch[0]  # Unpack from the first tuple

        # JAX computations
        loss, preds = self._compute_step_outputs(params, inputs, targets)

        if self.task_type == 'discrete':
            # Convert to numpy for DFA processing
            preds_np = np.array(preds)
            inputs_np = np.array(inputs)

            # Python-based DFA accuracy computation
            total_correct = 0
            total_count = 0
            vocab = self.data_generator.data_module.vocab

            for b in range(len(dfas)):
                # Best effort per sequence: a failure is reported and the sequence is skipped
                # so one bad sample does not abort the whole evaluation.
                try:
                    dfa = dfas[b]
                    # Add safety checks for token IDs
                    pred_chars = []
                    for token in preds_np[b]:
                        if 0 <= token < len(vocab.vocab):
                            pred_chars.append(vocab.get_vocab(token))
                        else:
                            print(f"Warning: Invalid prediction token ID: {token}")
                            pred_chars.append(vocab.noop)  # Use noop token for invalid predictions

                    input_chars = []
                    for token in inputs_np[b]:
                        if 0 <= token < len(vocab.vocab):
                            char = vocab.get_vocab(token)
                            if char != ".":
                                input_chars.append(char)
                        else:
                            print(f"Warning: Invalid input token ID: {token}")

                    for t in range(len(input_chars)):
                        if len(input_chars) > t + 1:
                            if input_chars[t + 1] == "|":
                                continue
                            if input_chars[t + 1] == ".":
                                break

                        if len(pred_chars) > t:
                            # Inputs up to t plus the prediction at t; only the part after
                            # the last separator is checked against the DFA.
                            current_chars = input_chars[:t + 1] + [pred_chars[t]]
                            current_word = " ".join(current_chars).split(" | ")[-1]
                            if current_word:
                                label = int(dfa(current_word))
                                total_count += 1
                                total_correct += label

                except Exception as e:
                    print(f"Error processing batch {b}: {e}")
                    print(f"Predictions shape: {preds_np.shape}, values: {preds_np[b]}")
                    print(f"Inputs shape: {inputs_np.shape}, values: {inputs_np[b]}")
                    print(f"Vocab size: {len(vocab.vocab)}")

            accuracy = total_correct / (total_count + 1e-8)
            return loss, accuracy

        return loss, 0.0


    @partial(jit, static_argnums=(0))
    def fast_pure_test_computation(self,
                                   params: flax.core.frozen_dict.FrozenDict,
                                   test_rng: random.PRNGKey) -> jnp.ndarray:
        '''
        Performs a full evaluation computation for performance evaluation of RevAlgs.
        Args:
            'params' (flax.core.frozen_dict.FrozenDict): The model parameters.
            'test_rng' (jax.random.PRNGKey): The random number generator key for testing.
        Returns:
            'test_loss' (jnp.ndarray): The test loss averaged over 10 freshly drawn batches.
        '''
        num_test_batches = 10
        test_loss = 0
        for _ in range(num_test_batches):
            test_rng, batch_rng = random.split(test_rng, 2)
            batch_TEST, _ = self.data_generator.get_data(rng=batch_rng, batch_size=self.test_batch_size)
            # Only (inputs, labels) go into the loss; generators may append further entries
            # to the observed tuple (same convention as in train_epoch).
            (step_loss, _) = self.calculate_loss(params=params, batch=batch_TEST[:2])
            test_loss += step_loss
        return test_loss/num_test_batches

    @partial(jit, static_argnums=(0))
    def fast_sensitivity(self,
                         batch: Tuple[jnp.ndarray, jnp.ndarray],
                         state: train_state.TrainState) -> Tuple[List[jnp.ndarray]]:
        '''
        Performs sensitivity analysis.
        Args:
            'batch' (Tuple[jnp.ndarray, jnp.ndarray]): The input batch.
            'state' (flax.training.train_state.TrainState): The current training state.
        Returns:
            Tuple[List[jnp.ndarray]] ('listsnd', 'listmid', 'listlast'): The sensitivity analysis results for the second token, mid and last token in a sequence.
        '''
        _,s,_ = batch[0].shape
        # Positions analysed: second token, middle token and last token.
        target_k = [1,(s-1)//2,s-1]
        res_list = []
        for k in target_k:
            # Gradient of feature l of the selected activation at position k w.r.t. the inputs,
            # computed per sequence in the batch.
            grad_of_output_l_wrt_x = lambda l: vmap(grad(lambda x: self.model.apply({'params': state.params}, x[None, ...])[1][0][1][0][k][l],         #1: second output, 0: activations, 1:layer,
                                                argnums=0))(batch[0])
            # Norm over the feature axis, mean over the batch, restricted to positions 0..k.
            grads = vmap(lambda t: jnp.mean(jnp.linalg.norm(grad_of_output_l_wrt_x(t), axis=(2)),axis=0)[:k+1])(jnp.arange(self.model.embed_dim))
            grads_norm = jnp.mean(jnp.array(grads), axis=0)
            res_list.append(grads_norm)
        return tuple(res_list)

    def train_epoch(
        self,
        epoch: int,
        rng: random.PRNGKey,
        test_rng: random.PRNGKey,
        state: train_state.TrainState,
        num_batches_train: int,
    ):
        '''
        Evaluates the current model on test batches, then trains for one epoch.
        Args:
            'epoch' (int): The epoch index (the parameter count is printed for epoch 0).
            'rng' (random.PRNGKey): The random key for the training batches.
            'test_rng' (random.PRNGKey): The random key for the test batches.
            'state' (train_state.TrainState): The current training state.
            'num_batches_train' (int): The number of training batches in this epoch.
        Returns:
            (state, rng, test_loss, test_acc); the test metrics are measured before this
            epoch's parameter updates.
        '''
        rng, tr_rng = random.split(rng, 2)

        if epoch == 0:
          num_params = count_parameters(state.params)
          print(f"Model has {num_params:,} parameters")

        # Evaluate
        test_loss = 0
        test_acc = 0
        num_test_batches = 10
        for _ in range(num_test_batches):
            test_rng, batch_rng = random.split(test_rng, 2)
            # Draw the batch with the freshly split key; the carried test_rng is split again
            # in the next iteration and must not be consumed here as well.
            batch_TEST = self.data_generator.get_test_data(batch_size=self.test_batch_size, rng=batch_rng)
            step_loss, step_acc = self.fast_eval_step(state.params, batch=batch_TEST)
            test_loss += step_loss
            test_acc += step_acc

        test_loss = test_loss / num_test_batches
        test_acc = test_acc / num_test_batches

        # Train
        for _ in range(num_batches_train):
            tr_rng, batch_rng = random.split(tr_rng, 2)
            batch = self.data_generator.get_data(rng=batch_rng, batch_size=self.batch_size)
            state, _, _ = self.fast_train_step(state, batch=batch[0][:2])  # Only pass inputs and targets for training

        if self.task_type == 'discrete':
            print(f'Epoch {epoch}: loss = {test_loss:.4f}, accuracy = {test_acc:.4f}')
        else:
            print(f'Epoch {epoch}: loss = {test_loss:.4f}')

        return state, rng, test_loss, test_acc



## Language (NOT Regbench)

In [29]:
# Disabled earlier variant of `_compute_token_loss_and_metrics`, kept for reference only.
# It sits inside a bare string literal, so nothing in it is defined or executed.
'''
@jit
def _compute_token_loss_and_metrics(logits: jnp.ndarray, targets: jnp.ndarray):
    """Compute cross entropy loss and perplexity with careful masking."""
    bs, sl, vocab_size = logits.shape
    logits = logits.reshape(-1, vocab_size)
    targets = targets.reshape(-1)

    # Create mask and ensure targets are valid indices
    valid_mask = (targets != -100)
    safe_targets = jnp.where(valid_mask, targets, 0)

    # Compute per-token losses
    ce_losses = optax.softmax_cross_entropy_with_integer_labels(logits, safe_targets)

    # Mask losses and get mean
    masked_losses = ce_losses * valid_mask
    num_valid = jnp.sum(valid_mask)
    loss = jnp.sum(masked_losses) / (num_valid + 1e-8)

    # Compute perplexity
    perplexity = jnp.exp(jnp.where(jnp.isfinite(loss), loss, 100.0))

    # Compute accuracy
    predictions = jnp.argmax(logits, axis=-1)
    correct = (predictions == targets) * valid_mask
    accuracy = jnp.sum(correct) / (num_valid + 1e-8)

    return loss, perplexity, accuracy
'''

@jit
def _compute_token_loss_and_metrics(logits: jnp.ndarray, targets: jnp.ndarray):
    """Compute masked cross-entropy loss, perplexity and accuracy for token predictions.

    Positions whose target equals -100 are ignored in both the loss and the accuracy.

    Args:
        logits: Rank-3 array whose last axis is the vocabulary axis.
        targets: Integer class indices with the leading shape of `logits`; -100 marks
            positions to ignore.

    Returns:
        (loss, perplexity, accuracy), where perplexity is exp(loss).
    """
    # Ensure numerical stability in the logits
    logits = jnp.clip(logits, -1e4, 1e4)

    # Debug info (under jit these prints run while tracing and show tracers, not values)
    print("Logits stats:", jnp.min(logits), jnp.max(logits), jnp.mean(logits))
    print("Targets range:", jnp.min(targets), jnp.max(targets))

    # Unpacking also rejects logits that are not rank 3
    bs, sl, vocab_size = logits.shape

    valid_mask = (targets != -100)
    # Swap the ignore marker for a valid class index before the label lookup; -100 is not a
    # valid index into the vocabulary axis, and multiplying by the mask afterwards cannot
    # undo a non-finite per-token loss.
    safe_targets = jnp.where(valid_mask, targets, 0)

    # Per-token cross entropy
    ce_loss = optax.softmax_cross_entropy_with_integer_labels(logits, safe_targets)
    print('LOGITS SHAPE: ', logits.shape)
    print('TARGETS SHAPE: ', targets.shape)
    masked_losses = ce_loss * valid_mask
    num_valid = jnp.sum(valid_mask)
    loss = jnp.sum(masked_losses) / (num_valid + 1e-8)

    # Accuracy over the same positions that contribute to the loss
    predictions = jnp.argmax(logits, axis=-1)
    correct = (predictions == targets) * valid_mask
    accuracy = jnp.sum(correct) / (num_valid + 1e-8)

    return loss, jnp.exp(loss), accuracy

class LanguageModelTraining(Training):
    """Training logic for language modelling (loss, perplexity, accuracy).

    The methods read ``self.model``, ``self.data_generator``,
    ``self.batch_size`` and ``self.test_batch_size``; none of them is assigned
    here, so they are presumably set up by ``Training.__init__`` -- verify
    against the base class.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @partial(jit, static_argnums=(0))
    def _compute_forward_pass(self, params, inp_data):
        """JIT-compiled forward pass."""
        return self.model.apply({'params': params}, inp_data)

    @partial(jit, static_argnums=(0))
    def calculate_loss(self,
                      params: flax.core.frozen_dict.FrozenDict,
                      batch: Tuple[jnp.ndarray, jnp.ndarray]) -> Tuple[jnp.ndarray, Tuple[Any, ...]]:
        """Calculate loss and metrics for language modeling.

        Args:
            params: Model parameters passed to ``self.model.apply``.
            batch: Pair ``(inputs, labels)``.

        Returns:
            ``(loss, (logits, perplexity, accuracy, tf_data))`` where ``tf_data``
            is the second value returned by the model's forward pass.
        """
        inp_data, labels = batch
        print('CALC LOSS FUN inp format: ', inp_data.shape)
        logits, tf_data = self._compute_forward_pass(params, inp_data)
        print('CALC LOSS FUN !!!!!!! logits shape: ', logits.shape)
        loss, perplexity, accuracy = _compute_token_loss_and_metrics(logits, labels)
        return loss, (logits, perplexity, accuracy, tf_data)

    @partial(jit, static_argnums=(0))
    def fast_train_step(self,
                       state: train_state.TrainState,
                       batch: Tuple[jnp.ndarray, jnp.ndarray]) -> Tuple[train_state.TrainState, jnp.ndarray, Tuple[Any, ...]]:
        """Single training step: differentiate the loss and apply the gradients.

        Returns:
            ``(new_state, loss, aux)`` with ``aux`` as returned by ``calculate_loss``.
        """
        loss_fn = lambda params: self.calculate_loss(params=params, batch=batch)
        (loss, aux), grads = value_and_grad(loss_fn, has_aux=True)(state.params)
        state = state.apply_gradients(grads=grads)
        return state, loss, aux

    @partial(jit, static_argnums=(0))
    def eval_step(self,
                 params: flax.core.frozen_dict.FrozenDict,
                 batch: Tuple[jnp.ndarray, jnp.ndarray]) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
        """Single evaluation step; returns ``(loss, perplexity, accuracy)``."""
        print('eval step targets shape: ', batch[1].shape)
        loss, (_, perplexity, accuracy, _) = self.calculate_loss(params, batch)
        return loss, perplexity, accuracy

    def train_epoch(self,
                   epoch: int,
                   rng: random.PRNGKey,
                   test_rng: random.PRNGKey,
                   state: train_state.TrainState,
                   num_batches_train: int):
        """Train for one epoch with language model specific metrics.

        The model is evaluated on 10 test batches *before* the training steps
        of this epoch are run, so the reported test metrics describe the state
        that was passed in.

        Args:
            epoch: Epoch index; the parameter count is printed when it is 0.
            rng: Key from which the per-batch training keys are derived.
            test_rng: Key from which the per-batch test keys are derived.
            state: Current train state.
            num_batches_train: Number of training steps to run.

        Returns:
            ``(state, rng, test_loss, test_accuracy, test_perplexity)``.
        """
        rng, tr_rng = random.split(rng, 2)

        if epoch == 0:
            num_params = count_parameters(state.params)
            print(f"Model has {num_params:,} parameters")

        # Evaluate
        test_metrics = {'loss': 0.0, 'perplexity': 0.0, 'accuracy': 0.0}
        num_test_batches = 10

        for _ in range(num_test_batches):
            test_rng, batch_rng = random.split(test_rng, 2)
            # Sample with the freshly split key; the carried key is only ever
            # split again (same discipline as the training loop below).
            batch_TEST = self.data_generator.get_test_data(batch_size=self.test_batch_size, rng=batch_rng)
            loss, perplexity, accuracy = self.eval_step(
                state.params,
                batch=batch_TEST[0][:2]  # Only take inputs and targets
            )
            test_metrics['loss'] += loss
            test_metrics['perplexity'] += perplexity
            test_metrics['accuracy'] += accuracy

        # Average metrics
        for k in test_metrics:
            test_metrics[k] /= num_test_batches

        # Train
        train_loss = 0.0
        for _ in range(num_batches_train):
            tr_rng, batch_rng = random.split(tr_rng, 2)
            batch = self.data_generator.get_data(rng=batch_rng, batch_size=self.batch_size)
            state, step_loss, _ = self.fast_train_step(state, batch=batch[0][:2])
            # Accumulate so the value printed below is the mean over the epoch.
            train_loss += step_loss

        # max(..., 1) keeps an epoch with zero training steps from dividing by zero.
        train_loss = train_loss / max(num_batches_train, 1)

        l = test_metrics['loss']
        a = test_metrics['accuracy']
        p = test_metrics['perplexity']

        print(f'Epoch {epoch}: loss = {l:.4f}, accuracy = {a:.4f}, perplexity = {p:.4f}')
        print('Train loss: ', train_loss)

        return state, rng, l, a, p

# Experiment Util

## Aux models (old code, be careful)

In [30]:

#################################################
#                                               #
#                                               #
#       Implementation of Matrix-Inv.           #
#                Approximators                  #
#                                               #
#                                               #
#################################################

def invert_matrix_neumann(A: jnp.ndarray,
                          steps: int,
                          norm: float) -> jnp.ndarray:
    '''
    Approximates the inverse of a matrix with a truncated Neumann series.
    The matrix is first divided by 'norm'; the partial sum
    I + (I - A/norm) + ... + (I - A/norm)^steps is then accumulated and
    finally divided by 'norm' again to undo the rescaling.
    Args:
        'A' (jnp.ndarray): The (square) matrix to be inverted.
        'steps' (int): How many series terms are added on top of the identity.
        'norm' (float): The scalar the matrix is divided by before the expansion.
    Returns:
        jnp.ndarray: The approximation of the inverted matrix.
    '''
    identity = jnp.eye(A.shape[0])
    residual = identity - A / norm
    power = identity
    partial_sum = identity
    for _ in range(steps):
        # power holds residual^k after the k-th pass
        power = power @ residual
        partial_sum = partial_sum + power
    return partial_sum / norm

def batched_neumann(steps: int,
                    norm: float) -> Callable[[jnp.ndarray], jnp.ndarray]:
    '''Returns invert_matrix_neumann with 'steps' and 'norm' bound, mapped over the leading axis of its input.'''
    single_inverse = partial(invert_matrix_neumann, steps=steps, norm=norm)
    return jax.vmap(single_inverse, in_axes=0)

def invert_matrix_newton(A: jnp.ndarray,
                         steps: int) -> jnp.ndarray:
    '''
    Approximates the inverse of a matrix with the iteration X <- X (2I - A X),
    started from X = I / trace(A).
    Args:
        'A' (jnp.ndarray): The (square) matrix to be inverted.
        'steps' (int): The number of iterations to run.
    Returns:
        jnp.ndarray: The approximation of the inverted matrix.
    '''
    identity = jnp.eye(A.shape[0])
    estimate = identity / jnp.trace(A)
    for _ in range(steps):
        correction = 2 * identity - jnp.dot(A, estimate)
        estimate = estimate @ correction
    return estimate

def batched_newton(steps: int) -> Callable[[jnp.ndarray], jnp.ndarray]:
    '''Returns invert_matrix_newton with 'steps' bound, mapped over the leading axis of its input.'''
    single_inverse = partial(invert_matrix_newton, steps=steps)
    return jax.vmap(single_inverse, in_axes=0)

def invert_matrix_chebyshev(A: jnp.ndarray,
                            steps: int,
                            alphas: jnp.ndarray,
                            betas: jnp.ndarray) -> jnp.ndarray:
    '''
    Approximates a matrix inverse with a parameterised series with momentum.
    The matrix is divided by jnp.linalg.norm(A); for every (alpha, beta) pair the
    running product is multiplied by (I - alpha*A), and the product plus
    beta times the previous change of the estimate is added to the estimate.
    The result is divided by the same norm at the end.
    Args:
        'A' (jnp.ndarray): The (square) matrix to be inverted.
        'steps' (int): Not read by this function; one iteration is run per
                       (alpha, beta) pair yielded by zip(alphas, betas).
        'alphas' (jnp.ndarray): Per-iteration scaling applied to A.
        'betas' (jnp.ndarray): Per-iteration momentum coefficients.
    Returns:
        jnp.ndarray: The approximation of the inverted matrix.
    '''
    scale = jnp.linalg.norm(A)
    identity = jnp.eye(A.shape[0])
    scaled = A / scale
    product = identity
    estimate = identity
    previous = identity
    for alpha, beta in zip(alphas, betas):
        product = product @ (identity - alpha * scaled)
        momentum = beta * (estimate - previous)
        # right-hand side is evaluated first, so 'previous' receives the old estimate
        previous, estimate = estimate, estimate + product + momentum
    return estimate / scale

def batched_chebyshev(steps: int,
                      alphas: List[float],
                      betas: List[float]) -> Callable[[jnp.ndarray], jnp.ndarray]:
    '''Returns invert_matrix_chebyshev with its parameters bound, mapped over the leading axis of its input.'''
    single_inverse = partial(invert_matrix_chebyshev,
                             steps=steps,
                             alphas=alphas,
                             betas=betas)
    return jax.vmap(single_inverse, in_axes=0)

def invert_matrix_richardson(A, omegas):
    '''
    Approximates a matrix inverse with a series whose factors are (I - omega*A).
    The matrix is divided by jnp.linalg.norm(A); for every omega the running
    product is multiplied by (I - omega*A) and added to the estimate, which is
    divided by the same norm at the end.
    Args:
        'A': The (square) matrix to be inverted.
        'omegas': Iterable of per-iteration scalings; its length sets the number of iterations.
    Returns:
        The approximation of the inverted matrix.
    '''
    scale = jnp.linalg.norm(A)
    scaled = A / scale
    identity = jnp.eye(scaled.shape[0])
    product = identity
    estimate = identity

    for omega in omegas:
        product = product @ (identity - omega * scaled)
        estimate = estimate + product
    return estimate / scale

def batched_richardson(omegas):
    '''Returns invert_matrix_richardson with 'omegas' bound, mapped over the leading axis of its input.'''
    single_inverse = partial(invert_matrix_richardson, omegas=omegas)
    return jax.vmap(single_inverse, in_axes=0)

#################################################
#                                               #
#                                               #
#             Learn parameters for              #
#            chebyshev-inverse-apx.             #
#                                               #
#                                               #
#################################################

def learn_parameters_chebyshev(num_steps: int,
                               train_len: int,
                               experiment_config: config_dict.ConfigDict,
                               data_generator: DataGenerator,
                               part_obs_constr: bool = False,
                               part_obs_embed_dim: int = 80,
                               seq_len: int = 50,
                               use_mlp: bool = False,
                               init_alphas: Optional[jnp.ndarray] = None,
                               init_betas: Optional[jnp.ndarray] = None) -> Tuple[jnp.ndarray, jnp.ndarray]:
    '''
    Training logic for GD to learn optimal parameters for Chebyshev, evaluated using Sequencesolver on autoregressive tasks of choice.
    Args:
        'num_steps' (int): Number of alpha/beta pairs to learn (length of the default initialisation).
        'train_len' (int): Number of optimisation steps.
        'experiment_config' (config_dict.ConfigDict): Only 'data.batch_size' is read here.
        'data_generator' (DataGenerator): Source of training batches via 'get_data';
                                          its '_mini_mlp' is used when 'use_mlp' is set.
        'part_obs_constr' (bool): Build tokens by concatenating the current and past observations.
        'part_obs_embed_dim' (int): Width of the constructed token for 'part_obs_constr'.
        'seq_len' (int): Number of tokens predicted per sequence.
        'use_mlp' (bool): Regress on normalised features produced by 'data_generator._mini_mlp'.
        'init_alphas' (jnp.ndarray): Optional start values; all ones when None.
        'init_betas' (jnp.ndarray): Optional start values; all zeros when None.
    Returns:
        Tuple[jnp.ndarray, jnp.ndarray]: The learned alphas and betas.
    '''
    def single_cheb_pred(params_s, seq_data_s, seq_labels_s, lamb_s):
        # For every position: regularised least squares on the prefix seen so far,
        # with the matrix inverse replaced by the Chebyshev approximation.
        preds = []
        for token in range(seq_len):
            inv_mat = seq_data_s[:token].T@seq_data_s[:token] + lamb_s*jnp.eye(seq_data_s[:token].shape[1])
            w_hat = invert_matrix_chebyshev(A=inv_mat,
                                            steps=num_steps,
                                            alphas=params_s['params']['alphas'],
                                            betas=params_s['params']['betas']) @ (seq_data_s[:token].T@seq_labels_s[:token])
            if use_mlp:
                test_token = data_generator._mini_mlp(seq_labels_s[token])
                test_token = test_token/(jnp.linalg.norm(test_token)+1e-16)
            else:
                test_token = seq_labels_s[token]
            seq_label_hat = jnp.matmul(test_token, w_hat)
            preds.append(seq_label_hat)
        return jnp.array(preds)

    def cheb_lsq_pred(params, seq_data, seq_labels, lamb):
        # Map the single-sequence predictor over the batch axis of data and labels.
        vectorized_cheb_pred = jax.vmap(single_cheb_pred, in_axes=(None, 0, 0, None))(params, seq_data, seq_labels, lamb)
        return vectorized_cheb_pred

    def get_features(batch: jnp.ndarray) -> jnp.ndarray:
        # Apply the generator's MLP to every token of every sequence.
        feature_seq_func = lambda seq : jax.vmap(data_generator._mini_mlp)(seq)
        feature_batch = jax.vmap(feature_seq_func, in_axes=(0))(batch)
        return feature_batch

    def cheb_loss(params, rng):
        rng, batch_rng = jax.random.split(rng)
        batch, _ = data_generator.get_data(rng=batch_rng, batch_size=experiment_config.data.batch_size)
        data, labels = batch
        if part_obs_constr:
            batch_size, seq_len, obs_dim = data.shape
            # Token k-th slot holds the observation from k steps earlier (zero padded at the start).
            constructed_data = jnp.zeros(shape=(batch_size, seq_len, part_obs_embed_dim))
            constructed_data = constructed_data.at[:,:,0:obs_dim].set(data)
            for k in range(1, part_obs_embed_dim//obs_dim):
                shifted_data = jnp.concatenate((jnp.zeros(shape=(batch_size,(k),obs_dim)),data[:,:-1*(k),:]),axis=1)
                constructed_data = constructed_data.at[:,:,k*obs_dim:(k+1)*obs_dim].set(shifted_data)
            shifted_data = jnp.concatenate([jnp.expand_dims(constructed_data[:, 0, :], 1)*0, constructed_data], axis=1)[:, :-1, :]
            data = constructed_data
            preds_chebyshev = cheb_lsq_pred(params, shifted_data, data, 0.001)[:,:,0:obs_dim]
        elif use_mlp:
            dat_feat = get_features(data)
            # Same epsilon as in single_cheb_pred: an all-zero vector must not turn the loss into NaN.
            dat_feat /= (jnp.linalg.norm(dat_feat,axis=-1)[...,None] + 1e-16)
            shifted_data = jnp.concatenate([jnp.expand_dims(dat_feat[:, 0, :], 1)*0, dat_feat], axis=1)[:, :-1, :]
            preds_chebyshev = cheb_lsq_pred(params, shifted_data, data/(jnp.linalg.norm(data, axis=-1)[...,None] + 1e-16), 0.001)
        else:
            # Inputs are the sequence delayed by one token, with a zero token in front.
            shifted_data = jnp.concatenate([jnp.expand_dims(data[:, 0, :], 1)*0, data], axis=1)[:, :-1, :]
            preds_chebyshev = cheb_lsq_pred(params, shifted_data, data, 0.001)
        loss = _compute_loss(preds=preds_chebyshev, targets=labels)
        return loss, rng

    def cheb_train_step(state, rng):
        loss_fn = lambda params: cheb_loss(params=params, rng=rng)
        (loss, rng), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
        state = state.apply_gradients(grads=grads)
        return state, (loss, rng)
    fast_cheb_train_step = jax.jit(cheb_train_step)

    def cheb_training(state):
        for i in range(train_len):
            # The key of every step is derived from the step index.
            rng=jax.random.PRNGKey(seed=i)
            state, (loss,_) = fast_cheb_train_step(state, rng)
            print(loss)
            if i % 1000 == 0:
                print(state.params)
        return state

    # Identity test, not '==': comparing an array against None with '==' is not a
    # reliable None check (it is element-wise or rejected, depending on the array type).
    init_params = {'params': {}}
    init_params['params']['alphas'] = jnp.ones(shape=(num_steps,)) if init_alphas is None else init_alphas
    init_params['params']['betas'] = jnp.zeros(shape=(num_steps,)) if init_betas is None else init_betas

    optimizer = Optimizer().get_optimizer()

    state_cheb = train_state.TrainState.create(apply_fn=cheb_lsq_pred, params=init_params, tx=optimizer)
    state_cheb = cheb_training(state_cheb)

    return state_cheb.params['params']['alphas'], state_cheb.params['params']['betas']


#################################################
#                                               #
#                                               #
#       Implementation of Aux-Models            #
#                                               #
#                                               #
#################################################

class AuxModel(metaclass=abc.ABCMeta):
    '''Interface shared by the auxiliary (non-transformer) sequence models.'''

    @abc.abstractmethod
    def predict(self, shifted_data, data):
        '''Produce predictions for a batch; concrete models must override this.'''
        raise NotImplementedError()

class LeastSquaresSequenceSolver(AuxModel):
    '''Class implementing a least squares solver for sequence data.'''
    def __init__(self,
                 approximator: str,
                 seq_len: int,
                 apx_steps: int = 6,
                 apx_norm: float = 70,
                 lamb: float = 0.001,
                 use_mlp: bool = False,
                 mlp_fn = None,
                 alphas = None,
                 betas = None):
        '''
        Initializes the LeastSquaresSequenceSolver.
        Args:
            'approximator' (str): The approximator to use for inverting the matrix. One of
                                  'neumann', 'newton', 'chebyshev', 'richardson'; None or the
                                  string 'None' selects the exact inverse (jnp.linalg.inv).
            'seq_len' (int): The sequence length.
            'apx_steps' (int): The number of steps for the approximator.
            'apx_norm' (float): The norm scalar for the matrix approximation (used by 'neumann').
            'lamb' (float): The lambda parameter for the least squares solver.
            'use_mlp' (bool): Pass the test token through 'mlp_fn' (and normalise it) before predicting.
            'mlp_fn': Feature function applied to single tokens when 'use_mlp' is set.
            'alphas': Parameters for 'chebyshev'; also used as the omegas of 'richardson'.
            'betas': Momentum parameters for 'chebyshev'.
        Raises:
            ValueError: If 'approximator' is neither None nor one of the supported names.
        '''
        supported = ['neumann', 'newton', 'chebyshev', 'richardson', 'None']
        if approximator is not None and approximator not in supported:
            raise ValueError(f"Approximator {approximator} not supported")

        if approximator == 'neumann':
            self.inverter = lambda A : invert_matrix_neumann(A, apx_steps, apx_norm)
        elif approximator == 'newton':
            self.inverter = lambda A : invert_matrix_newton(A, apx_steps)
        elif approximator == 'richardson':
            self.inverter = lambda A : invert_matrix_richardson(A, omegas=alphas)
        elif approximator == 'chebyshev':
            self.inverter = lambda A : invert_matrix_chebyshev(A, apx_steps, alphas=alphas, betas=betas)
        else:
            # None / 'None': no approximation, use the exact inverse.
            self.inverter = jnp.linalg.inv

        self.seq_len = seq_len
        self.apx_steps = apx_steps
        self.apx_norm = apx_norm
        self.lamb = lamb
        self.use_mlp = use_mlp
        self.mlp_fn = mlp_fn

    def predict(self,
                shifted_data: jnp.ndarray,
                data: jnp.ndarray) -> jnp.ndarray:
        '''
        Function to get predictions for a batch of data.
        Args:
            'shifted_data' (jnp.ndarray): The shifted data.
            'data' (jnp.ndarray): The original data.
        Returns:
            jnp.ndarray: The predictions.
        '''
        return self.all_preds(seq_data=shifted_data,
                              seq_labels=data,
                              seq_len=self.seq_len,
                              lamb=self.lamb)

    def least_squares_one_iter(self,
                               seq_data: jnp.ndarray,
                               seq_labels: jnp.ndarray,
                               lamb: float) -> jnp.ndarray:
        '''Function to perform one iteration of the least squares solver: inverter(X^T X + lamb*I) @ (X^T Y).'''
        return self.inverter(seq_data.T@seq_data + lamb*jnp.eye(seq_data.shape[1])) @ (seq_data.T@seq_labels)

    def least_squares_seq_pred_single_seq(self,
                               seq_data: jnp.ndarray,
                               seq_labels: jnp.ndarray,
                               seq_len: int,
                               lamb: float) -> jnp.ndarray:
        '''
        Function to get predictions for a single sequence using the least squares solver.
        For every position the solver is fitted on the tokens before that position only,
        and the fitted weights are applied to the label token at that position.
        Args:
            'seq_data' (jnp.ndarray): The sequence data.
            'seq_labels' (jnp.ndarray): The sequence labels.
            'seq_len' (int): The sequence length.
            'lamb' (float): The lambda parameter for the least squares solver.
        Returns:
            jnp.ndarray: The predictions.
        '''
        preds = []
        for token in range(seq_len):
            w_hat = self.least_squares_one_iter(seq_data[:token], seq_labels[:token], lamb=lamb)
            if self.use_mlp:
                test_token = self.mlp_fn(seq_labels[token])
                test_token = test_token/(jnp.linalg.norm(test_token)+1e-16)
            else:
                test_token = seq_labels[token]
            seq_label_hat = jnp.matmul(test_token, w_hat)
            preds.append(seq_label_hat)
        return jnp.array(preds)

    def all_preds(self,
                  seq_data: jnp.ndarray,
                  seq_labels: jnp.ndarray,
                  seq_len: int,
                  lamb: float) -> jnp.ndarray:
        '''Function to get predictions for all sequences using the least squares solver.'''
        return jax.vmap(self.least_squares_seq_pred_single_seq, in_axes=(0,0,None,None))(seq_data, seq_labels, seq_len, lamb)

    def get_features(self, batch: jnp.ndarray) -> jnp.ndarray:
        '''Function to apply 'mlp_fn' to every token of every sequence in the batch.'''
        feature_seq_func = lambda seq : jax.vmap(self.mlp_fn)(seq)
        feature_batch = jax.vmap(feature_seq_func, in_axes=(0))(batch)
        return feature_batch

    def opt_lamb(self,
                 minv: float,
                 maxv: float,
                 steps: int,
                 data_generator: DataGenerator,
                 loss_fn: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray],
                 part_obs_constr: bool,
                 embed_dim: int = 80,
                 constr: bool = False,
                 slots: int = 4) -> jnp.ndarray:
        '''
        Function to optimize lambda parameter for least squares solver via line-search.
        Sets the lambda parameter to the value that minimizes the loss function and returns it
        Args:
            'minv' (float): The minimum value for the lambda parameter.
            'maxv' (float): The maximum value for the lambda parameter.
            'steps' (int): The number of steps for the line-search.
            'data_generator' (DataGenerator): The data generator.
            'loss_fn' (Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]): The loss function.
            'part_obs_constr' (bool): Use part.-obs. construction (concat. past k tokens)
            'embed_dim' (int): Construction size for partobs
            'constr' (bool): Use Full-Obs. construction,
            'slots' (int): Token-'slots' in Full-Obs. construction
        Returns:
            jnp.ndarray: The optimized lambda parameter.
        '''
        min_score = float('inf')
        range_vals = jnp.linspace(minv, maxv, steps)
        rng = jax.random.PRNGKey(42)
        rng, test_rng = jax.random.split(rng)

        for lam in range_vals:
            # The same key is passed for every candidate value.
            (data, targets), _ = data_generator.get_data(rng=test_rng, batch_size=512)

            if part_obs_constr:
                batch_size, seq_len, obs_dim = data.shape
                # Slot k of the constructed token holds the observation from k steps earlier.
                constructed_data = jnp.zeros(shape=(batch_size, seq_len, embed_dim))
                constructed_data = constructed_data.at[:,:,0:obs_dim].set(data)
                for k in range(1, embed_dim//obs_dim):
                    shifted_data = jnp.concatenate((jnp.zeros(shape=(batch_size,(k),obs_dim)),data[:,:-1*(k),:]),axis=1)
                    constructed_data = constructed_data.at[:,:,k*obs_dim:(k+1)*obs_dim].set(shifted_data)
                shifted_data = jnp.concatenate([jnp.expand_dims(constructed_data[:, 0, :], 1)*0, constructed_data], axis=1)[:, :-1, :]
                data = constructed_data
                preds_lsq = self.all_preds(seq_data=shifted_data,
                                           seq_labels=data,
                                           seq_len=self.seq_len,
                                           lamb=lam)[:,:,0:obs_dim]
            elif self.use_mlp:
                dat_feat = self.get_features(data)
                # Same epsilon as for the labels below, so an all-zero feature vector cannot produce NaN.
                dat_feat /= (jnp.linalg.norm(dat_feat,axis=-1)[...,None] + 1e-16)
                shifted_data = jnp.concatenate([jnp.expand_dims(dat_feat[:, 0, :], 1)*0, dat_feat], axis=1)[:, :-1, :]
                preds_lsq = self.all_preds(seq_data=shifted_data,
                                           seq_labels=data/(jnp.linalg.norm(data, axis=-1)[...,None] + 1e-16),
                                           seq_len=self.seq_len,
                                           lamb=lam)
            else:
                if constr:
                    # Inputs and labels are read from different slots of the constructed token.
                    shifted_data = data[:,:,(slots-1)*targets.shape[-1]:]
                    data = data[:,:,(slots-2)*targets.shape[-1]:(slots-1)*targets.shape[-1]]
                else:
                    # Inputs are the sequence delayed by one token, with a zero token in front.
                    shifted_data = jnp.concatenate([jnp.expand_dims(data[:, 0, :], 1)*0, data], axis=1)[:, :-1, :]

                preds_lsq = self.all_preds(seq_data=shifted_data,
                                           seq_labels=data,
                                           seq_len=self.seq_len,
                                           lamb=lam)

            score = loss_fn(preds_lsq, targets)
            print(f"for lambda = {lam:.6f} lsq-loss: ", score)
            if score < min_score:
                min_score = score
                self.lamb = lam
        return self.lamb

class GDSequenceSolver:
    '''Class implementing a gradient descent solver for sequence data.'''
    def __init__(self,
                 eta: float,
                 lamb: float = 0,
                 seq_len: int = 50):
        '''
        Initializes the GDSequenceSolver.
        Args:
            'eta' (float): The learning rate.
            'lamb' (float): The (optional) lambda parameter.
            'seq_len' (int): The number of tokens predicted per sequence.
        '''
        self.eta = eta
        self.lamb = lamb
        self.seq_len = seq_len

    def predict(self,
                shifted_data: jnp.ndarray,
                data: jnp.ndarray) -> jnp.ndarray:
        '''
        Function to get predictions for a batch of data.
        Args:
            'shifted_data' (jnp.ndarray): The shifted data.
            'data' (jnp.ndarray): The original data.
        Returns:
            jnp.ndarray: The predictions.
        '''
        return self.all_preds(seq=data,
                              seq_shifted=shifted_data,
                              eta=self.eta)

    def _cumulative_outer_products(self,
                                   seq: jnp.ndarray,
                                   seq_shifted: jnp.ndarray) -> jnp.ndarray:
        '''Running sum over positions of the outer products seq_shifted[t] (x) seq[t].'''
        outer_productsGD = jnp.matmul(seq_shifted[:, :, None], seq[:, None, :])
        return jnp.cumsum(outer_productsGD, axis=0)

    def gd_delta(self,
           seq: jnp.ndarray,
           seq_shifted: jnp.ndarray,
           idx: int) -> jnp.ndarray:
        '''Function to get the delta for the gradient descent solver (cumulative outer products up to and including 'idx').'''
        return self._cumulative_outer_products(seq, seq_shifted)[idx]

    def one_step_gd(self,
                    seq: jnp.ndarray,
                    seq_shifted: jnp.ndarray,
                    eta: float,
                    lamb: float,
                    seq_len: int) -> jnp.ndarray:
        '''
        Function to perform one step of the gradient descent solver.
        Args:
            'seq' (jnp.ndarray): The sequence data.
            'seq_shifted' (jnp.ndarray): The shifted sequence data.
            'eta' (float): The learning rate.
            'lamb' (float): The lambda parameter.
            'seq_len' (int): The sequence length of the test sequences.
        Returns:
            jnp.ndarray: The gradient descent updates.
        '''
        # Build the cumulative tensor once; it does not depend on the position index.
        cumulative = self._cumulative_outer_products(seq, seq_shifted)
        result = []
        for idx in range(seq_len):
            deltaWi = cumulative[idx]
            deltaWi = deltaWi + lamb*deltaWi
            gd_update = eta * (seq[idx] @ deltaWi)
            result.append(gd_update)
        return jnp.array(result)

    def all_preds(self, seq: jnp.ndarray, seq_shifted: jnp.ndarray, eta: float) -> jnp.ndarray:
        '''Function to get predictions for all sequences in a batch using the gradient descent solver.'''
        return jax.vmap(self.one_step_gd, in_axes=(0,0,None,None,None))(seq, seq_shifted, eta, self.lamb, self.seq_len)

    def opt_eta(self,
                minv: float,
                maxv: float,
                steps: int,
                data_generator: "DataGenerator",
                loss_fn: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray],
                constr: bool = False,
                slots: int = 4) -> jnp.ndarray:
        '''
        Function to optimize eta parameter for GD sequence solver via line-search.
        Sets the eta parameter to the value that minimizes the loss function and returns it
        Args:
            'minv' (float): The minimum value for the eta parameter.
            'maxv' (float): The maximum value for the eta parameter.
            'steps' (int): The number of steps for the line-search.
            'data_generator' (DataGenerator): The data generator.
            'loss_fn' (Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]): The loss function.
            'constr' (bool): Use Full-Obs. construction,
            'slots' (int): Token-'slots' in Full-Obs. construction
        Returns:
            jnp.ndarray: The optimized eta parameter.
        '''
        min_score = float('inf')
        range_vals = jnp.linspace(minv, maxv, steps)
        rng = jax.random.PRNGKey(42)
        rng, test_rng = jax.random.split(rng)

        for eta in range_vals:
            # The same key is passed for every candidate value.
            (data, targets), _ = data_generator.get_data(rng=test_rng, batch_size=512)
            if constr:
                # Inputs and labels are read from different slots of the constructed token.
                shifted_data = data[:,:,(slots-1)*targets.shape[-1]:]
                data = data[:,:,(slots-2)*targets.shape[-1]:(slots-1)*targets.shape[-1]]
            else:
                # Inputs are the sequence delayed by one token, with a zero token in front.
                shifted_data = jnp.concatenate([jnp.expand_dims(data[:, 0, :], 1)*0, data], axis=1)[:, :-1, :]

            preds_gd = self.all_preds(seq=data,
                                      seq_shifted=shifted_data,
                                      eta=eta)

            score = loss_fn(preds_gd, targets)
            print(f"for eta = {eta:.6f} lsq-loss: ", score)
            if score < min_score:
                min_score = score
                self.eta = eta
        return self.eta

## Sequence performance evaluator (old code, be careful)

In [31]:
class EvalModel(abc.ABC):
    '''Interface shared by all evaluators: score a model on one batch.'''

    @abc.abstractmethod
    def evaluate(self,
                 data: jnp.ndarray,
                 targets: jnp.ndarray,
                 loss_fn: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]) -> jnp.ndarray:
        '''Compute the loss for 'data' against 'targets'; concrete evaluators must override this.'''
        return None

class TFEvaluator(EvalModel):
    '''Evaluator for models that are run through 'model.apply' with a train state.'''
    def __init__(self,
                 model: nn.Module,
                 state: train_state.TrainState,
                 constr: bool,
                 slots: int):
        '''
        Stores the model and its evaluation settings.
        Args:
            'model' (nn.Module): The model to evaluate.
            'state' (train_state.TrainState): The state whose 'params' are used for the forward pass.
            'constr' (bool): Whether the model is using constructed data (stored, not read by 'evaluate').
            'slots' (int): The number of slots (stored, not read by 'evaluate').
        '''
        self.model, self.state = model, state
        self.constr, self.slots = constr, slots

    def evaluate(self,
                 data:jnp.ndarray,
                 targets:jnp.ndarray,
                 loss_fn: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]) -> jnp.ndarray:
        '''
        Runs the model on 'data' and scores its output against 'targets'.
        Only the leading 'targets.shape[-1]' entries of the last logits axis are compared.
        Args:
            'data' (jnp.ndarray): The input data.
            'targets' (jnp.ndarray): The target data.
            'loss_fn' (Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]): The loss function.
        Returns:
            'jnp.ndarray': The loss.
        '''
        variables = {'params': self.state.params}
        logits, _ = self.model.apply(variables, data)
        target_dim = targets.shape[-1]
        preds = logits[:, :, :target_dim]

        for label, array in (('preds: ', preds), ('data: ', data), ('targets: ', targets)):
            print(label, array.shape)

        return loss_fn(preds, targets)

class AuxmodelEvaluator(EvalModel):
    '''Class for evaluating auxiliary models.'''
    def __init__(self,
                 model: AuxModel,
                 state: Any,
                 constr: bool,
                 slots: int,
                 part_obs: bool,
                 use_mlp: bool):
        '''
        Initializes the AuxmodelEvaluator.
        Args:
            'model' (AuxModel): The model to evaluate.
            'state' (Any): Used for Part.-Obs. Construction (holds embed_dim) or mlp (holds mlp_function)
            'constr' (bool): Whether the model is using constructed data.
            'slots' (int): The number of slots.
            'part_obs' (bool): Evaluate on constructed token of past partial observations
            'use_mlp' (bool): Evaluate on mlp features

        '''
        self.model = model
        self.state = state
        self.constr = constr
        self.slots = slots
        self.part_obs = part_obs
        self.use_mlp = use_mlp

    def get_features(self, batch: jnp.ndarray) -> jnp.ndarray:
        '''Applies 'self.state' (the mlp function in mlp mode) to every token of every sequence.'''
        feature_seq_func = lambda seq : vmap(self.state)(seq)
        feature_batch = vmap(feature_seq_func, in_axes=(0))(batch)
        return feature_batch

    def evaluate(self,
                 data:jnp.ndarray,
                 targets:jnp.ndarray,
                 loss_fn: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]) -> jnp.ndarray:
        '''
        Evaluates the model.
        Args:
            'data' (jnp.ndarray): The input data.
            'targets' (jnp.ndarray): The target data.
            'loss_fn' (Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]): The loss function.
        Returns:
            'jnp.ndarray': The loss.
        '''
        if self.part_obs:
            batch_size, seq_len, obs_dim = data.shape
            embed_dim = self.state
            # Slot k of the constructed token holds the observation from k steps earlier.
            constructed_data = jnp.zeros(shape=(batch_size, seq_len, embed_dim))
            constructed_data = constructed_data.at[:,:,0:obs_dim].set(data)
            for k in range(1, embed_dim // obs_dim):
                shifted_data = jnp.concatenate((jnp.zeros(shape=(batch_size,(k),obs_dim)),data[:,:-1*(k),:]),axis=1)
                constructed_data = constructed_data.at[:,:,k*obs_dim:(k+1)*obs_dim].set(shifted_data)
            shifted_data = jnp.concatenate([jnp.expand_dims(constructed_data[:, 0, :], 1)*0, constructed_data], axis=1)[:, :-1, :]
            data = constructed_data
            preds = self.model.predict(shifted_data=shifted_data, data=data)[:,:,0:obs_dim]
        elif self.use_mlp:
            dat_feat = self.get_features(data)
            # Epsilon keeps an all-zero vector from producing NaN; the solver's
            # own token normalisation uses the same constant.
            dat_feat /= (jnp.linalg.norm(dat_feat,axis=-1)[...,None] + 1e-16)
            shifted_data = jnp.concatenate([jnp.expand_dims(dat_feat[:, 0, :], 1)*0, dat_feat], axis=1)[:, :-1, :]
            preds = self.model.predict(shifted_data=shifted_data, data=data/(jnp.linalg.norm(data, axis=-1)[...,None] + 1e-16))
        else:
            if self.constr:
                # Inputs and labels are read from different slots of the constructed token.
                shifted_data = data[:,:,(self.slots-1)*targets.shape[-1]:]
                data = data[:,:,(self.slots-2)*targets.shape[-1]:(self.slots-1)*targets.shape[-1]]
            else:
                # Inputs are the sequence delayed by one token, with a zero token in front.
                shifted_data = jnp.concatenate([jnp.expand_dims(data[:, 0, :], 1)*0, data], axis=1)[:, :-1, :]
            preds = self.model.predict(shifted_data=shifted_data, data=data)
            print('lsqpreds: ', preds.shape)
            print('lsqdata: ', data.shape)
            print('lsqtargets: ', targets.shape)
        return loss_fn(preds, targets)

def get_evaluator(model_type:str, **kwargs) -> EvalModel:
    '''
    Returns an evaluator based on the model type.
    Args:
        'model_type' (str): Name of the model family; matched case-insensitively.
        '**kwargs': Forwarded unchanged to the evaluator constructor.
    Returns:
        'EvalModel': A TFEvaluator for the transformer-style names, otherwise an
            AuxmodelEvaluator with 'part_obs' / 'use_mlp' derived from the name.
    Raises:
        ValueError: If the model type is not one of the supported names.
    '''
    # Normalise once so the membership test and the derived flags agree; the
    # flags were previously computed from the raw, un-normalised string.
    kind = model_type.lower()
    if kind in ('transformer', 'mesa-transformer', 'fwp', 'dfwp', 'delta', 'deep-delta'):
        return TFEvaluator(**kwargs)
    if kind in ('lsq', 'gd', 'lsq_partobs', 'lsq_mlp'):
        return AuxmodelEvaluator(part_obs=(kind == 'lsq_partobs'), use_mlp=(kind == 'lsq_mlp'), **kwargs)
    raise ValueError(f"Unsupported model type: {model_type}")

class SequencePredictionEvaluator:
    '''Class for evaluating sequence prediction models across test sequences, per token.'''
    def __init__(self,
                 data_generator: DataGenerator,
                 test_batch_size: int,
                 seeds: List[int],
                 model_list: List[str],
                 models: List[nn.Module],
                 states: List[train_state.TrainState],
                 loss_fn: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]):
        '''
        Initializes the SequencePredictionEvaluator.
        Args:
            'data_generator' (DataGenerator): The data generator.
            'test_batch_size' (int): The batch size for testing.
            'seeds' (List[int]): The seeds for testing; one test batch is drawn per seed.
            'model_list' (List[str]): The list of model types (passed to get_evaluator).
            'models' (List[nn.Module]): The models to evaluate, aligned with 'model_list'.
            'states' (List[Any]): The states of the models, aligned with 'model_list'.
            'loss_fn' (Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]): The loss function for evaluation.
        '''
        self.data_generator = data_generator
        self.test_batch_size = test_batch_size
        self.seeds = seeds
        self.model_list = model_list
        self.models = models
        self.states = states
        self.loss_fn = loss_fn

        # Query the generator once; 'token_format' is only read for constructed data.
        data_info = data_generator.get_data_info()
        self.constr = data_info['constr']
        if self.constr:
            self.slots = 2 if data_info['token_format'] == 'compact' else 4
        else:
            self.slots = 0

        # One result list per model; run() appends one entry per seed.
        self.losses = [[] for _ in self.model_list]

        print(self.slots)

    def run(self) -> Dict[str, Any]:
        '''
        Runs the evaluation experiment.
        For every seed a test batch is drawn and each model is evaluated on it.
        Results accumulate in 'self.losses', so repeated calls keep appending.
        Returns:
            'Dict[str, Any]': {'losses': per-model list of per-seed results}.
        '''
        for seed in self.seeds:
            test_rng = random.PRNGKey(seed)
            (data, targets), _ = self.data_generator.get_data(rng=test_rng, batch_size=self.test_batch_size)
            # Plain Python indices: the lists being indexed are Python lists.
            for idx, modelname in enumerate(self.model_list):
                print(f'Evaluating model: {modelname}')
                evaluator = get_evaluator(model_type=modelname, model=self.models[idx], state=self.states[idx], constr=self.constr, slots=self.slots)
                result = evaluator.evaluate(data, targets, self.loss_fn)
                self.losses[idx].append(result)

        return {'losses': self.losses}

# Experiments

###  Setup

In [32]:
class TrainingInitializer:
    '''Factory that builds model, optimizer, data generator and training module from two configs.'''
    def __init__(self,
                 model_config: config_dict.ConfigDict,
                 experiment_config: config_dict.ConfigDict):
        '''
        Stores the configurations read by the factory methods.
        Args:
            model_config: Configuration for the model
            experiment_config: Configuration for the experiment
        '''
        self.experiment_config = experiment_config
        self.model_config = model_config

    def get_tf_model(self):
        '''Builds the FullTransformerModel described by the model configuration.'''
        mcfg = self.model_config
        data_cfg = self.experiment_config.data

        # The initializer is sized for the regular layers plus the optional hybrid first block.
        fast_weight_init = get_fast_weight_init(
            num_layers=mcfg.num_layers + mcfg.hybrid_first_block,
            d_model=mcfg.embed_dim,
            key_size=mcfg.key_size,
            range_dfwp=mcfg.range_dfwp,
        )

        return FullTransformerModel(
            use_gla=mcfg.use_gla,
            is_discrete=mcfg.is_discrete,
            vocab_size=mcfg.vocab_size,
            # Only the 'regbench' task supplies a padding token; it is taken from
            # the module-level data_module's vocabulary.
            pad_token_id=(data_module.vocab.get_id(data_module.vocab.noop)
                          if data_cfg.ttype == 'regbench' else None),
            use_depth=mcfg.use_depth,
            use_emb=mcfg.use_emb,
            use_fwp=mcfg.use_fwp,
            range_dfwp=mcfg.range_dfwp,
            use_pe_kq=mcfg.use_pe_kq,
            use_pe_emb=mcfg.use_pe_emb,
            hybrid_first_block=mcfg.hybrid_first_block,
            pe_dim=mcfg.pe_dim,
            out_dim=mcfg.out_dim,
            initializer=fast_weight_init,
            use_layernorm=mcfg.use_layernorm,
            use_bias=mcfg.use_bias,
            use_mlp=mcfg.use_mlp,
            masked=mcfg.masked,
            use_clip=mcfg.use_clip,
            clip_range=mcfg.clip_range,
            num_layers=mcfg.num_layers,
            num_heads=mcfg.num_heads,
            embed_dim=mcfg.embed_dim,
            key_size=mcfg.key_size,
            seq_len=data_cfg.seq_len,
            dim_feedforward_MLP=mcfg.dim_feedforward_MLP,
            linear=mcfg.linear,
            use_schlagnorm=mcfg.use_schlagnorm,
            schlagnorm_targets=mcfg.schlagnorm_targets,
        )

    def get_optimizer(self):
        '''Builds the optimizer from the 'optim' section of the experiment configuration.'''
        optim_cfg = self.experiment_config.optim
        optimizer_factory = Optimizer(
            grad_clip=optim_cfg.grad_clip,
            peak_lr=optim_cfg.peak_lr,
            use_schedule=optim_cfg.use_schedule,
            warmup_steps=optim_cfg.warmup_steps,
            max_iters=optim_cfg.max_iters,
            init_value=optim_cfg.init_value,
            end_value=optim_cfg.end_value,
            weight_decay=optim_cfg.weight_decay,
        )
        return optimizer_factory.get_optimizer()

    def get_data_generator(self, ttype:str):
        '''
        Returns the data generator for the requested task type.
        Args:
            ttype: Task identifier, e.g. 'seq_lin', 'seq_lin_constr', 'regbench', 'wikitext'.
        Raises:
            NotImplementedError: For 'fixed', 'seq_nonlin' and 'contracting'.
            ValueError: For any other unknown task type.
        '''
        # Task types that are recognised but have no generator yet.
        pending = {
            'fixed': 'Fixed-W Sequence Generation is not implemented yet.',
            'seq_nonlin': 'Nonlinear Sequence Generation is not implemented yet.',
            'contracting': 'Contracting Sequence Generation is not implemented yet.',
        }
        if ttype in pending:
            raise NotImplementedError(pending[ttype])

        if ttype == 'seq_lin':
            data_cfg = self.experiment_config.data
            return LinearSequenceDataGenerator(
                seq_len=data_cfg.seq_len,
                data_dim=data_cfg.data_dim,
                obs_dim=data_cfg.obs_dim,
                range=data_cfg.range,
                noise=data_cfg.noise,
                noise_obs=data_cfg.noise_obs,
                data_clip=data_cfg.data_clip,
                eye_obs=data_cfg.eye_obs,
            )

        if ttype == 'seq_lin_constr':
            # Wraps the plain linear generator.
            return ConstructedFullSeqGenerator(
                data_generator=self.get_data_generator('seq_lin'),
                embed_dim=self.experiment_config.data.embed_dim,
                token_format=self.experiment_config.data.token_format,
            )

        if ttype == 'seq_lin_constr_part':
            return ConstructedPartObsGenerator(
                data_generator=self.get_data_generator('seq_lin'),
                embed_dim=self.experiment_config.data.embed_dim,
            )

        if ttype == 'regbench':
            # Sequence length and vocabulary size come from the module-level data_module.
            return DFADataGenerator(
                data_module=data_module,
                seq_len=data_module.input_seq_len,
                data_dim=data_module.vocab_size,
                eye_obs=True,
            )

        if ttype == 'wikitext':
            # Hands back the module-level `generator` object instead of building
            # a new WikiText generator from the configuration.
            return generator

        raise ValueError('Data type not recognized')

    def get_train_module(self, model, optimizer, data_generator):
        '''Returns LanguageModelTraining for 'wikitext', otherwise Training, built with identical arguments.'''
        data_cfg = self.experiment_config.data
        trainer_cls = LanguageModelTraining if data_cfg.ttype == 'wikitext' else Training
        return trainer_cls(
            model=model,
            optimizer=optimizer,
            data_generator=data_generator,
            batch_size=data_cfg.batch_size,
            test_batch_size=data_cfg.test_batch_size,
            task_type=data_cfg.task_type,
        )

    def setup_components(self):
        '''Builds and returns (model, optimizer, data_generator, train_module).'''
        tf_model = self.get_tf_model()
        opt = self.get_optimizer()
        generator_obj = self.get_data_generator(self.experiment_config.data.ttype)
        trainer = self.get_train_module(tf_model, opt, generator_obj)
        return (tf_model, opt, generator_obj, trainer)

### Language Experiments

#### Runs

Standard Delta Rule

In [34]:
# Language run with use_depth=False, use_gla=False: build the components once,
# then train per seed while recording the loss and test accuracy of each epoch.
experiment_config = get_experiment_config(seeds=[42])
model_config = get_model_config(use_depth=False, use_gla=False)
training_initializer = TrainingInitializer(
    model_config=model_config,
    experiment_config=experiment_config,
)
(model_sd,
 optimizer,
 data_generator,
 train_module_sd) = training_initializer.setup_components()

# One entry per training seed. `pers` / `pers_seed_sd` are never filled in
# this cell and therefore stay empty.
losses_seed_sd, accs_seed_sd, pers_seed_sd = [], [], []
for training_seed in range(1):
    losses_run, accs, pers = [], [], []
    # Three keys from the seed: one for get_init_state, one handed over as
    # test_rng, one as the training rng.
    rng, test_rng, train_rng = jax.random.split(jax.random.PRNGKey(training_seed), 3)
    state_tf_sd, rng = train_module_sd.get_init_state(rng)
    for epoch_idx in range(30):
        state_tf_sd, train_rng, loss, test_acc = train_module_sd.train_epoch(
            epoch=epoch_idx,
            state=state_tf_sd,
            rng=train_rng,
            test_rng=test_rng,
            num_batches_train=300,
        )
        losses_run.append(loss)
        accs.append(test_acc)
    losses_seed_sd.append(losses_run)
    accs_seed_sd.append(accs)

train set size: (1000, 511)
test set size: (250, 511)
Using discrete input embeddings with num_embeddings (vocab): 20 and features: 64
Using discrete output projection onto feat_dim (vocab): 20
TF!! embedded input shape:  (32, 511, 64)
Number of Decoder-Layers:  3
Using deep delta Rule T/F: False
Transformer Block:
Using Delta Rule Attention
Leveraging depth T/F: False
Using MLP
Using LayerNorm
Transformer Block:
Using Delta Rule Attention
Leveraging depth T/F: False
Using MLP
Using LayerNorm
Transformer Block:
Using Delta Rule Attention
Leveraging depth T/F: False
Using MLP
Using LayerNorm
------
Model has 164,588 parameters
Using discrete input embeddings with num_embeddings (vocab): 20 and features: 64
Using discrete output projection onto feat_dim (vocab): 20
TF!! embedded input shape:  (32, 511, 64)
Number of Decoder-Layers:  3
Using deep delta Rule T/F: False
Transformer Block:
Using Delta Rule Attention
Leveraging depth T/F: False
Using MLP
Using LayerNorm
Transformer Block:
Usi

KeyboardInterrupt: 

In [None]:
# Persist the final parameters and the per-seed accuracy curves of the
# standard-delta run. 'pers_sd' is written as collected by the training cell.
a = {
    'params_sd': state_tf_sd.params,
    'accs_sd': accs_seed_sd,
    'pers_sd': pers_seed_sd,
}
with open('run_sd_singleseede256l4h8ex1000ep30.pkl', 'wb') as handle:
    pickle.dump(a, handle, protocol=pickle.HIGHEST_PROTOCOL)

Deep Delta Rule

In [35]:
# Language run with use_depth=True, use_gla=False: build the components once,
# then train per seed while recording the loss and test accuracy of each epoch.
experiment_config = get_experiment_config(seeds=[42])
model_config = get_model_config(use_depth=True, use_gla=False)
training_initializer = TrainingInitializer(
    model_config=model_config,
    experiment_config=experiment_config,
)
(model_dd,
 optimizer,
 data_generator,
 train_module_dd) = training_initializer.setup_components()

# One entry per training seed. `pers` / `pers_seed_dd` are never filled in
# this cell and therefore stay empty.
losses_seed_dd, accs_seed_dd, pers_seed_dd = [], [], []
for training_seed in range(1):
    losses_run, accs, pers = [], [], []
    # Three keys from the seed: one for get_init_state, one handed over as
    # test_rng, one as the training rng.
    rng, test_rng, train_rng = jax.random.split(jax.random.PRNGKey(training_seed), 3)
    state_tf_dd, rng = train_module_dd.get_init_state(rng)
    for epoch_idx in range(30):
        state_tf_dd, train_rng, loss, test_acc = train_module_dd.train_epoch(
            epoch=epoch_idx,
            state=state_tf_dd,
            rng=train_rng,
            test_rng=test_rng,
            num_batches_train=300,
        )
        losses_run.append(loss)
        accs.append(test_acc)
    losses_seed_dd.append(losses_run)
    accs_seed_dd.append(accs)

train set size: (1000, 511)
test set size: (250, 511)
Using discrete input embeddings with num_embeddings (vocab): 20 and features: 64
Using discrete output projection onto feat_dim (vocab): 20
TF!! embedded input shape:  (32, 511, 64)
Number of Decoder-Layers:  3
Using deep delta Rule T/F: True
Transformer Block:
Using Delta Rule Attention
Leveraging depth T/F: True
Using MLP
Using LayerNorm
Transformer Block:
Using Delta Rule Attention
Leveraging depth T/F: True
Using MLP
Using LayerNorm
Transformer Block:
Using Delta Rule Attention
Leveraging depth T/F: True
Using MLP
Using LayerNorm
------
Model has 164,588 parameters
Using discrete input embeddings with num_embeddings (vocab): 20 and features: 64
Using discrete output projection onto feat_dim (vocab): 20
TF!! embedded input shape:  (32, 511, 64)
Number of Decoder-Layers:  3
Using deep delta Rule T/F: True
Transformer Block:
Using Delta Rule Attention
Leveraging depth T/F: True
Using MLP
Using LayerNorm
Transformer Block:
Using Del

KeyboardInterrupt: 

In [None]:
# Persist the final parameters and the per-seed accuracy curves of the
# deep-delta run. 'pers_dd' is written as collected by the training cell.
b = {
    'params_dd': state_tf_dd.params,
    'accs_dd': accs_seed_dd,
    'pers_dd': pers_seed_dd,
}
with open('run_dd_singleseede256l4h8ex1000ep30.pkl', 'wb') as handle:
    pickle.dump(b, handle, protocol=pickle.HIGHEST_PROTOCOL)

Fast Weight Programmer

In [None]:
# Language run labelled "Fast Weight Programmer": three seeds, 70 epochs each.
# NOTE(review): this run sets use_fwp=False while the GLA run sets it to True;
# confirm the polarity of the flag against the model definition.
experiment_config = get_experiment_config(seeds=[42])
model_config = get_model_config(use_depth=False, use_gla=False)
model_config.use_fwp = False
training_initializer = TrainingInitializer(
    model_config=model_config,
    experiment_config=experiment_config,
)
(model_fwp,
 optimizer,
 data_generator,
 train_module_fwp) = training_initializer.setup_components()

# One entry per training seed. `pers` / `pers_seed_fwp` are never filled in
# this cell and therefore stay empty.
losses_seed_fwp, accs_seed_fwp, pers_seed_fwp = [], [], []
for training_seed in range(3):
    losses_run, accs, pers = [], [], []
    # Three keys from the seed: one for get_init_state, one handed over as
    # test_rng, one as the training rng.
    rng, test_rng, train_rng = jax.random.split(jax.random.PRNGKey(training_seed), 3)
    state_tf_fwp, rng = train_module_fwp.get_init_state(rng)
    for epoch_idx in range(70):
        state_tf_fwp, train_rng, loss, test_acc = train_module_fwp.train_epoch(
            epoch=epoch_idx,
            state=state_tf_fwp,
            rng=train_rng,
            test_rng=test_rng,
            num_batches_train=300,
        )
        losses_run.append(loss)
        accs.append(test_acc)
    losses_seed_fwp.append(losses_run)
    accs_seed_fwp.append(accs)

In [None]:
# Persist the final parameters and the per-seed accuracy curves of the FWP
# run. 'pers_fwp' is written as collected by the training cell.
a = {
    'params_fwp': state_tf_fwp.params,
    'accs_fwp': accs_seed_fwp,
    'pers_fwp': pers_seed_fwp,
}
with open('run_fwp_singleseede256l4h8.pkl', 'wb') as handle:
    pickle.dump(a, handle, protocol=pickle.HIGHEST_PROTOCOL)

Gated Linear Attention

In [None]:
# Language run with use_gla=True and use_fwp=True: three seeds, 100 epochs each.
experiment_config = get_experiment_config(seeds=[42])
model_config = get_model_config(use_depth=False, use_gla=True)
model_config.use_fwp = True
training_initializer = TrainingInitializer(
    model_config=model_config,
    experiment_config=experiment_config,
)
(model_gla,
 optimizer,
 data_generator,
 train_module_gla) = training_initializer.setup_components()

# One entry per training seed. `pers` / `pers_seed_gla` are never filled in
# this cell and therefore stay empty.
losses_seed_gla, accs_seed_gla, pers_seed_gla = [], [], []
for training_seed in range(3):
    losses_run, accs, pers = [], [], []
    # Three keys from the seed: one for get_init_state, one handed over as
    # test_rng, one as the training rng.
    rng, test_rng, train_rng = jax.random.split(jax.random.PRNGKey(training_seed), 3)
    state_tf_gla, rng = train_module_gla.get_init_state(rng)
    for epoch_idx in range(100):
        state_tf_gla, train_rng, loss, test_acc = train_module_gla.train_epoch(
            epoch=epoch_idx,
            state=state_tf_gla,
            rng=train_rng,
            test_rng=test_rng,
            num_batches_train=300,
        )
        losses_run.append(loss)
        accs.append(test_acc)
    losses_seed_gla.append(losses_run)
    accs_seed_gla.append(accs)

In [None]:
# Persist the final parameters and the per-seed accuracy curves of the GLA
# run. 'pers_gla' is written as collected by the training cell.
a = {
    'params_gla': state_tf_gla.params,
    'accs_gla': accs_seed_gla,
    'pers_gla': pers_seed_gla,
}
with open('run_gla_singleseede256l4h8.pkl', 'wb') as handle:
    pickle.dump(a, handle, protocol=pickle.HIGHEST_PROTOCOL)

#### Plots

In [None]:
# Colour palette shared by the curves below.
colors = ['#2978A0',   # Blue
         '#BE7A3C',    # Orange
         '#4E917A',    # Green
         '#A94964',    # Red
         '#7764D8',    # Purple
         '#4D8B31',    # Dark green
         '#C35DCF',    # Pink
         '#816C5B',    # Brown
         '#3778BF',    # Light blue
         '#E05263']    # Coral


plt.figure(figsize=(8, 6))

# Plot every seed that was actually trained for each model. The runs above do
# not all use the same number of seeds, so a fixed `range(3)` over all three
# lists raises IndexError for the shorter ones.
accuracy_curves = (
    ('deep-delta', accs_seed_dd, colors[0]),
    ('delta', accs_seed_sd, colors[1]),
    ('fwp/lin-att.', accs_seed_fwp, colors[2]),
)
for curve_label, seed_curves, curve_color in accuracy_curves:
    for seed, seed_curve in enumerate(seed_curves):
        # Only the first seed carries the label so each model has one legend entry.
        plt.plot(seed_curve, label=(curve_label if seed == 0 else None), color=curve_color)
plt.ylim(0.0,1)

# Customize appearance
plt.tick_params(axis='x', colors='dimgray')
plt.tick_params(axis='y', colors='dimgray')
plt.gca().spines['right'].set_visible(False)
plt.gca().spines['top'].set_visible(False)
plt.gca().spines['bottom'].set_color('dimgray')
plt.gca().spines['left'].set_color('dimgray')

# Set labels
plt.xlabel('Training steps (in $100$)')
plt.ylabel('Accuracy')
plt.legend(loc='lower right', fontsize='large')
plt.title('RegBench: Test Accuracy over training')

# Display the plot
plt.tight_layout()
plt.show()

### Synthetic Experiments

#### Runs

Standard Delta

In [None]:
# Synthetic run with use_depth=False: three seeds, 100 epochs each, recording
# the loss returned by every epoch.
experiment_config = get_experiment_config(seeds=[42])
model_config = get_model_config(use_depth=False)
training_initializer = TrainingInitializer(
    model_config=model_config,
    experiment_config=experiment_config,
)
(model_sd,
 optimizer,
 data_generator,
 train_module_sd) = training_initializer.setup_components()

losses_seed_sd = []  # one loss curve per training seed
for training_seed in range(3):
    losses_run = []
    # Three keys from the seed: one for get_init_state, one handed over as
    # test_rng, one as the training rng.
    rng, test_rng, train_rng = jax.random.split(jax.random.PRNGKey(training_seed), 3)
    state_tf_sd, rng = train_module_sd.get_init_state(rng)
    for epoch_idx in range(100):
        state_tf_sd, train_rng, loss = train_module_sd.train_epoch(
            epoch=epoch_idx,
            state=state_tf_sd,
            rng=train_rng,
            test_rng=test_rng,
            num_batches_train=100,
        )
        losses_run.append(loss)
    losses_seed_sd.append(losses_run)

Deep Delta

In [None]:
# Synthetic run with use_depth=True: three seeds, 100 epochs each, recording
# the loss returned by every epoch.
experiment_config = get_experiment_config(seeds=[42])
model_config = get_model_config(use_depth=True)
training_initializer = TrainingInitializer(
    model_config=model_config,
    experiment_config=experiment_config,
)
(model_dd,
 optimizer,
 data_generator,
 train_module_dd) = training_initializer.setup_components()

losses_seed_dd = []  # one loss curve per training seed
for training_seed in range(3):
    losses_run = []
    # Three keys from the seed: one for get_init_state, one handed over as
    # test_rng, one as the training rng.
    rng, test_rng, train_rng = jax.random.split(jax.random.PRNGKey(training_seed), 3)
    state_tf_dd, rng = train_module_dd.get_init_state(rng)
    for epoch_idx in range(100):
        state_tf_dd, train_rng, loss = train_module_dd.train_epoch(
            epoch=epoch_idx,
            state=state_tf_dd,
            rng=train_rng,
            test_rng=test_rng,
            num_batches_train=100,
        )
        losses_run.append(loss)
    losses_seed_dd.append(losses_run)

FWP (Lin.Att.)

In [None]:
# Synthetic run labelled "FWP (Lin.Att.)": three seeds, 100 epochs each.
# NOTE(review): use_fwp is set to False for the FWP-labelled run; confirm the
# polarity of the flag against the model definition.
experiment_config = get_experiment_config(seeds=[42])
model_config = get_model_config(use_depth=False)
model_config.use_fwp = False
training_initializer = TrainingInitializer(
    model_config=model_config,
    experiment_config=experiment_config,
)
(model_fwp,
 optimizer,
 data_generator,
 train_module_fwp) = training_initializer.setup_components()

losses_seed_fwp = []  # one loss curve per training seed
for training_seed in range(3):
    losses_run = []
    # Three keys from the seed: one for get_init_state, one handed over as
    # test_rng, one as the training rng.
    rng, test_rng, train_rng = jax.random.split(jax.random.PRNGKey(training_seed), 3)
    state_tf_fwp, rng = train_module_fwp.get_init_state(rng)
    for epoch_idx in range(100):
        state_tf_fwp, train_rng, loss = train_module_fwp.train_epoch(
            epoch=epoch_idx,
            state=state_tf_fwp,
            rng=train_rng,
            test_rng=test_rng,
            num_batches_train=100,
        )
        losses_run.append(loss)
    losses_seed_fwp.append(losses_run)

Training DFWP

Per-Token performance eval.

In [None]:
# Reference solver evaluated next to the three trained models.
lsq_solver = LeastSquaresSequenceSolver(
    approximator='None',
    seq_len=50,
    apx_steps=20,
    lamb=0.001,
)
model_names = ['delta','deep-delta','fwp','lsq']


def loss_fn(p, t):
    '''Squared error summed over axes 0 and 2, divided by 2 * p.shape[0], as a list.'''
    # assumes p and t are (batch, seq, dim), giving one value per position; TODO confirm
    squared_error = (p - t)**2
    return list((jax.numpy.sum(squared_error, axis=(0,2))/(2*p.shape[0])))


# models and states are aligned with model_names; the solver has no train state.
seq_evaluator = SequencePredictionEvaluator(
    data_generator=data_generator,
    test_batch_size=256,
    seeds=[1,2,3,4,5],
    model_list=model_names,
    models=[model_sd, model_dd, model_fwp, lsq_solver],
    states=[state_tf_sd, state_tf_dd, state_tf_fwp, None],
    loss_fn=loss_fn,
)

seq_loss_dict = seq_evaluator.run()

#### Plot Results

In [None]:
# Loss curves of the three synthetic runs, averaged over the training seeds.
for _seed_losses, _curve_name in (
    (losses_seed_dd, 'deep-delta'),
    (losses_seed_sd, 'delta'),
    (losses_seed_fwp, 'fwp/lin-att.'),
):
    plt.plot(jnp.mean(jnp.array(_seed_losses), axis=0), label=_curve_name)
plt.ylim(0.3,2)
plt.title('Synth. mesa task: Test loss over training')
plt.legend()
plt.show()

In [None]:
import matplotlib.pyplot as plt
import jax.numpy as jnp

# Colour palette; reused cyclically if there are more models than colours.
colors = [
    '#2978A0',  # Blue
    '#BE7A3C',  # Orange
    '#4E917A',  # Green
    '#A94964',  # Red
    '#7764D8',  # Purple
    '#4D8B31',  # Dark green
    '#C35DCF',  # Pink
    '#816C5B',  # Brown
    '#3778BF',  # Light blue
    '#E05263',  # Coral
]

# Mean and standard deviation per model, reduced over the first axis of each
# model's results (one entry per evaluation seed).
loss_list = seq_loss_dict['losses']
loss_arr = jnp.array(loss_list)
mean_losses = tuple(jnp.mean(loss_arr[idx], axis=0) for idx in range(len(model_names)))
std_losses = tuple(jnp.std(loss_arr[idx], axis=0) for idx in range(len(model_names)))

plt.figure(figsize=(8, 6))

# One mean curve plus a +/- one standard deviation band per model.
for idx in range(len(model_names)):
    plt.plot(
        mean_losses[idx],
        linewidth=2,
        label=model_names[idx],
        color=colors[idx % len(colors)],
    )
    plt.fill_between(
        range(len(mean_losses[0])),
        mean_losses[idx] - std_losses[idx],
        mean_losses[idx] + std_losses[idx],
        alpha=0.3,
        color=colors[idx % len(colors)],
    )

# Grey ticks, no top/right frame, grey bottom/left frame.
for _tick_axis in ('x', 'y'):
    plt.tick_params(axis=_tick_axis, colors='dimgray')
for _hidden_spine in ('right', 'top'):
    plt.gca().spines[_hidden_spine].set_visible(False)
for _grey_spine in ('bottom', 'left'):
    plt.gca().spines[_grey_spine].set_color('dimgray')

# Axis labels, legend and title.
plt.xlabel('Sequence length $t$')
plt.ylabel('Next-token prediction MSE')
plt.legend(loc='upper right', fontsize='large')
plt.title('Synth. mesa task: Loss over one batch of sequences')

# Render the figure.
plt.tight_layout()
plt.show()