<a href="https://colab.research.google.com/github/eisbetterthanpi/ssm/blob/main/LRU.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Pytorch port of associative scan
#@title PyTorch associative/parallel scan
# Taken from https://github.com/i404788/s5-pytorch/blob/c74be7270fe2ec9dc13efcffcfd7f5355d884030/s5/jax_compat.py
import torch
from torch.utils._pytree import tree_flatten, tree_unflatten
from typing import overload, Callable, Iterable, List, TypeVar, Any,Tuple
from functools import partial

"""
JAX functions ported to PyTorch. Interfaces are mostly kept the same, but unsupported features are removed:
* Jax-Keyed RNGs are sampled from global RNG
* Canonical/Named shapes/dtypes/etc are now regular shapes,dtypes
"""

'''
T = TypeVar("T")
T1 = TypeVar("T1")
T2 = TypeVar("T2")
T3 = TypeVar("T3")
@overload
def safe_map(f: Callable[[T1], T], __arg1: Iterable[T1]) -> List[T]: ...
@overload
def safe_map(f: Callable[[T1, T2], T], __arg1: Iterable[T1], __arg2: Iterable[T2]) -> List[T]: ...
@overload
def safe_map(f: Callable[[T1, T2, T3], T], __arg1: Iterable[T1], __arg2: Iterable[T2], __arg3: Iterable[T3]) -> List[T]: ...
@overload
def safe_map(f: Callable[..., T], __arg1: Iterable[Any], __arg2: Iterable[Any], __arg3: Iterable[Any], __arg4: Iterable[Any], *args) -> List[T]: ...
'''
# def safe_map(f, *args):
#     args = list(map(list, args))
#     n = len(args[0])
#     for arg in args[1:]:
#         assert len(arg) == n, f'length mismatch: {list(map(len, args))}'
#     return list(map(f, *args))

def combine(tree, operator, a_flat, b_flat):
    """Apply a pytree-level binary ``operator`` to two flattened operands.

    ``a_flat`` and ``b_flat`` are lists of leaves. Each is rebuilt into a
    pytree using ``tree``, the two pytrees are passed to ``operator`` in that
    order, and the result is flattened back into a list of leaves.
    """
    lhs, rhs = (tree_unflatten(leaves, tree) for leaves in (a_flat, b_flat))
    flat_result, _ = tree_flatten(operator(lhs, rhs))
    return flat_result

def _scan(tree, operator, elems, axis):
    """Recursively scan the flattened leaves ``elems`` along ``axis``.

    Adjacent pairs are combined, the half-length result is scanned
    recursively (giving the outputs at odd positions), the outputs at even
    positions are derived from those, and both halves are interleaved.
    Inputs with fewer than two entries along ``axis`` are returned unchanged.
    """
    aten_slice = torch.ops.aten.slice
    length = elems[0].shape[axis]
    if length < 2:
        return elems

    # Combine adjacent pairs: (x0, x1), (x2, x3), ...
    pair_left = [aten_slice(leaf, axis, 0, -1, 2) for leaf in elems]
    pair_right = [aten_slice(leaf, axis, 1, None, 2) for leaf in elems]
    reduced = combine(tree, operator, pair_left, pair_right)

    # Scanning the pair results yields the outputs at odd positions.
    odd_elems = _scan(tree, operator, reduced, axis)

    # Outputs at even positions >= 2: an odd-position prefix combined with
    # the input at that even position. For an even length the last odd
    # prefix has no even successor and is dropped.
    if length % 2 == 0:
        prefixes = [aten_slice(leaf, axis, 0, -1) for leaf in odd_elems]
    else:
        prefixes = odd_elems
    even_inputs = [aten_slice(leaf, axis, 2, None, 2) for leaf in elems]
    even_tail = combine(tree, operator, prefixes, even_inputs)

    # The first element of a scan is the first element of the original
    # `elems`, so it is prepended to the even-position results.
    # Jax allows/ignores concat with 0-dim, Pytorch does not, hence the
    # explicit handling of empty operands.
    even_elems = []
    for leaf, tail in zip(elems, even_tail):
        if tail.shape.numel() == 0:
            even_elems.append(aten_slice(leaf, axis, 0, 1))
        elif leaf.shape[axis] > 0:
            head = aten_slice(leaf, axis, 0, 1)
            even_elems.append(torch.cat([head, tail], dim=axis))
        else:
            even_elems.append(tail)

    return [_interleave(even, odd, axis=axis) for even, odd in zip(even_elems, odd_elems)]


def associative_scan(operator: Callable, elems, axis = 0, reverse: bool = False):
    """Scan ``elems`` along ``axis`` with an associative binary ``operator``.

    Args:
        operator: callable taking two pytrees with the structure of ``elems``
            and returning a pytree with that same structure.
        elems: a tensor or pytree of tensors; all leaves must have the same
            size along ``axis``.
        axis: dimension to scan over. Negative values count from the last
            dimension of the first leaf.
        reverse: if True, the leaves are flipped along ``axis`` before the
            scan and the results are flipped back afterwards.

    Returns:
        A pytree with the same structure as ``elems`` holding the scan results.

    Raises:
        ValueError: if ``elems`` contains no tensors, if ``axis`` is out of
            bounds, or if the leaves disagree on their size along ``axis``.
    """
    elems_flat, tree = tree_flatten(elems)
    if not elems_flat:
        raise ValueError("associative_scan requires at least one tensor in `elems`")

    # Validate before `axis` is used anywhere. The helpers below compute
    # `axis + 1` and padding offsets, so they need a non-negative axis.
    ndim = elems_flat[0].ndim
    if not -ndim <= axis < ndim:
        raise ValueError(f"Axis should be within bounds of input (got axis={axis} for a {ndim}-dimensional input)")
    if axis < 0:
        axis += ndim

    if reverse: elems_flat = [torch.flip(elem, [axis]) for elem in elems_flat]
    num_elems = int(elems_flat[0].shape[axis])
    if any(int(elem.shape[axis]) != num_elems for elem in elems_flat[1:]):
        raise ValueError('Array inputs to associative_scan must have the same first dimension. (saw: {})'.format([elem.shape for elem in elems_flat]))
    scans = _scan(tree, operator, elems_flat, axis)
    if reverse: scans = [torch.flip(scanned, [axis]) for scanned in scans]
    return tree_unflatten(scans, tree)

def _interleave(a, b, axis):
    # https://stackoverflow.com/questions/60869537/how-can-i-interleave-5-pytorch-tensors
    if b_trunc := (a.shape[axis] == b.shape[axis] + 1):
        pad = [0, 0] * b.ndim
        pad[(b.ndim-axis-1)*2+1] = 1 # +1=always end of dim, pad-order is reversed so start is at end
        b = torch.nn.functional.pad(b, pad)

    stacked = torch.stack([a, b], dim=axis+1)
    interleaved = torch.flatten(stacked, start_dim=axis, end_dim=axis+1)
    if b_trunc:
        # TODO: find torch alternative for slice_along axis for torch.jit.script to work
        interleaved = torch.ops.aten.slice(interleaved, axis, 0, b.shape[axis]+a.shape[axis]-1)
    return interleaved

# Taken from https://github.com/i404788/s5-pytorch/blob/74e2fdae00b915a62c914bf3615c0b8a4279eb84/s5/s5_model.py
@torch.jit.script
def binary_operator_diag(q_i: Tuple[torch.Tensor, torch.Tensor], q_j: Tuple[torch.Tensor, torch.Tensor]):
    """Compose two steps of a diagonal linear recurrence for the parallel scan.

    Each operand is a pair ``(A, Bu)``: the diagonal of the transition and the
    input contribution at one position. Applying step ``i`` and then step
    ``j`` gives the pair ``(A_j * A_i, A_j * Bu_i + Bu_j)``.

    Args:
        q_i: tuple containing A_i and Bu_i at position i       (P,), (P,)
        q_j: tuple containing A_j and Bu_j at position j       (P,), (P,)
    Returns:
        new element ( A_out, Bu_out )
    """
    decay_first = q_i[0]
    drive_first = q_i[1]
    decay_second = q_j[0]
    drive_second = q_j[1]
    composed_decay = decay_second * decay_first
    # addcmul(x, a, b) evaluates x + a * b in one call.
    composed_drive = torch.addcmul(drive_second, decay_second, drive_first)
    return composed_decay, composed_drive


In [None]:
# @title forgi86/lru
# https://github.com/forgi86/sysid-pytorch-lru/blob/main/lru/linear.py
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class LRU(nn.Module):
    def __init__(self, in_dim, out_dim, d_model, rmin=0.0, rmax=1.0, max_phase=6.283):
        super().__init__()
        self.in_dim, self.out_dim, self.d_model = in_dim, out_dim, d_model
        B_re, B_im = torch.randn(d_model, in_dim) / math.sqrt(2*in_dim), torch.randn(d_model, in_dim) / math.sqrt(2*in_dim)
        self.B = nn.Parameter(torch.complex(B_re, B_im)) # N, U
        C_re, C_im = torch.randn(out_dim, d_model) / math.sqrt(d_model), torch.randn(out_dim, d_model) / math.sqrt(d_model)
        self.C = nn.Parameter(torch.complex(C_re, C_im)) # H, N
        self.D = nn.Parameter(torch.randn(out_dim, in_dim) / math.sqrt(in_dim))

        self.nu_log = nn.Parameter(torch.log(-.5*torch.log(torch.rand(d_model)*(rmax+rmin)*(rmax-rmin)+rmin**2)))
        self.theta_log = nn.Parameter(torch.log(max_phase*torch.rand(d_model)))

        lambda_abs = torch.exp(-torch.exp(self.nu_log))
        self.gamma_log = nn.Parameter(torch.log(torch.sqrt(torch.ones_like(lambda_abs) - torch.square(lambda_abs))))

    def ss_params(self):
        lambda_abs = torch.exp(-torch.exp(self.nu_log))
        lambda_phase = torch.exp(self.theta_log)
        lambda_re = lambda_abs * torch.cos(lambda_phase)
        lambda_im = lambda_abs * torch.sin(lambda_phase)
        lambdas = torch.complex(lambda_re, lambda_im)
        #lambdas = lambda_abs*torch.exp(1j*lambda_phase)
        gammas = torch.exp(self.gamma_log).unsqueeze(-1).to(self.B.device)
        B = gammas * self.B
        return lambdas, B, self.C, self.D


    def ss_real_matrices(self, to_numpy=True):
        lambdas, B, self.C, self.D = self.ss_params()
        lambdas_full = torch.zeros(2*self.d_model, device=lambdas.device, dtype=lambdas.dtype)
        lambdas_full[::2] = lambdas
        lambdas_full[1::2] = lambdas.conj()

        # First convert to complex conjugate system....
        A_full = torch.diag(lambdas_full)
        # B_full = torch.zeros((2*self.d_model, self.in_dim), device=lambdas.device, dtype=lambdas.dtype)
        B_full = torch.empty((2*self.d_model, self.in_dim), device=lambdas.device, dtype=lambdas.dtype)
        B_full[::2] = B
        B_full[1::2] = B.conj()
        # C_full = torch.zeros((self.out_dim, 2*self.d_model), device=lambdas.device, dtype=lambdas.dtype)
        C_full = torch.empty((self.out_dim, 2*self.d_model), device=lambdas.device, dtype=lambdas.dtype)
        C_full[:, ::2] = 0.5*self.C # we take the real part of the complex conjugate system as output...
        C_full[:, 1::2] = 0.5*self.C.conj()
        # D_full = self.D

        # Then apply transformation to real domain
        T_block = torch.tensor([[1, 1], [1j, -1j]], device=lambdas.device, dtype=lambdas.dtype)
        T_block_inv = torch.linalg.inv(T_block)
        T_full = torch.block_diag(*([T_block] * self.d_model))
        T_full_inv = torch.block_diag(*([T_block_inv] * self.d_model))

        A_real = (T_full @ A_full @ T_full_inv).real
        B_real = (T_full @ B_full).real
        C_real = (C_full @ T_full_inv).real
        # D_real = D_full

        # ss_real_params = [A_real, B_real, C_real, D_real]
        # if to_numpy: ss_real_params = [ss_real_param.detach().numpy() for ss_real_param in ss_real_params]
        # return (*ss_real_params, )
        return A_real, B_real, C_real, self.D


    def forward_loop(self, input, state=None): # Input size: (B, L, H)
        lambdas, B, C, D = self.ss_params()
        output = torch.empty([i for i in input.shape[:-1]] + [self.out_dim], device=self.B.device)
        states = []
        for u_step in input.split(1, dim=1): # 1 is the time dimension
            u_step = u_step.squeeze(1)
            state = lambdas * state + u_step.to(B.dtype) @ B.T
            states.append(state)
        states = torch.stack(states, 1)
        output = (states @ C.mT).real + input @ D.T
        return output

    @torch.compiler.disable
    def forward_scan(self, input, state=None): # (B, L, H)
        # Batched parallel scan, borrows heavily from https://colab.research.google.com/drive/1RgIv_3WAOW53CS0BnT7_782VKTYis9WG?usp=sharing
        # which in turn borrows from https://github.com/i404788/s5-pytorch
        lambdas, B, C, D = self.ss_params()
        lambda_elements = lambdas.tile(input.shape[1], 1) # [N]->[L,N]
        # Calculate B@u for each step u of each input sequence in the batch.
        # Bu_elements will have shape (B, L, N)
        Bu_elements = input.to(B.dtype) @ B.T
        if state is not None: Bu_elements[:, 0, :] = Bu_elements[:, 0, :] + lambdas * state
        # Vmap the associative scan since Bu_elements is a batch of B sequences.
        # Recall that Lambda_elements has been repeated L times to (L, N),
        # while Bu_seq has shape (B, L, N)
        inner_state_fn = lambda Bu_seq: associative_scan(binary_operator_diag, (lambda_elements, Bu_seq))[1]
        # inner_states will be of shape (B, L, N)
        inner_states = torch.vmap(inner_state_fn)(Bu_elements)
        #y = (inner_states @ self.C.T).real + input_sequences * self.D
        y = (inner_states @ C.T).real + input @ D.T
        return y

    def forward(self, input, state=None):
        if state is None: state = torch.view_as_complex(torch.zeros((self.d_model, 2), device=input.device)) # default initial state, size N
        y = self.forward_scan(input, state)
        y = self.forward_loop(input, state)
        return y


# Smoke test for the LRU layer defined above: run one random batch through
# LRU.forward and through the scan path, then print the output shape.
d_model = 40 #256 # N
H = 20 #512 # input/output dimension
L = 10_000 #2048 # input sequence length
b = 32
layer = LRU(in_dim=H, out_dim=H, d_model=d_model)
input_sequences = torch.randn(b, L, H) # multiple sequences
# Output of LRU.forward, compared below against the scan-only path.
output_sequences = layer(input_sequences)
output_sequences_scan = layer.forward_scan(input_sequences)
# The third positional argument of torch.allclose is rtol.
# NOTE(review): the boolean result is neither stored nor printed, so this
# line does not actually report whether the two paths agree.
torch.allclose(output_sequences_scan, output_sequences, 1e-2)

print(output_sequences_scan.shape)


torch.Size([32, 10000, 20])


In [None]:
# @title forgi86 architectures.py
# https://github.com/forgi86/sysid-pytorch-lru/blob/main/lru/architectures.py
import math
from dataclasses import dataclass
import torch
import torch.nn as nn

@dataclass
class DWNConfig:
    """Hyper-parameters for the DWN model and its blocks.

    Every attribute carries a type annotation so that the dataclass treats it
    as a field; un-annotated attributes are plain class attributes and cannot
    be set through the constructor.
    """
    # `bias` and `ff` are declared first so that positional construction,
    # DWNConfig(bias, ff), keeps its existing meaning.
    bias: bool = True  # passed as `bias=` to the Linear / LayerNorm layers
    ff: str = "GLU"  # name of the feed-forward variant
    d_model: int = 10  # feature width of the blocks
    d_state: int = 64  # state size handed to the LRU layer
    n_layers: int = 6  # number of DWNBlock layers
    dropout: float = 0.0  # dropout probability
    rmin: float = 0.0  # forwarded to the LRU layer
    rmax: float = 1.0  # forwarded to the LRU layer
    max_phase: float = 2*math.pi  # forwarded to the LRU layer

class MLP(nn.Module):
    """Feed-forward block: Linear -> GELU -> Linear -> dropout.

    The hidden layer is four times as wide as ``config.d_model``. Dropout is
    replaced by an identity module when ``config.dropout`` is not positive.
    """
    def __init__(self, config: DWNConfig):
        super().__init__()
        width, hidden = config.d_model, 4 * config.d_model
        self.c_fc = nn.Linear(width, hidden, bias=config.bias)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(hidden, width, bias=config.bias)
        if config.dropout > 0:
            self.dropout = nn.Dropout(config.dropout)
        else:
            self.dropout = nn.Identity()

    def forward(self, x):
        """Apply the block to the last dimension of ``x``."""
        return self.dropout(self.c_proj(self.gelu(self.c_fc(x))))


class GLU(nn.Module):
    """GELU -> dropout -> Linear to twice the width -> gated linear unit.

    The final ``nn.GLU`` halves the last dimension again, so the output has
    the same width as the input (``config.d_model``).
    """
    def __init__(self, config: DWNConfig):
        super().__init__()
        self.activation = nn.GELU()
        if config.dropout > 0:
            self.dropout = nn.Dropout(config.dropout)
        else:
            self.dropout = nn.Identity()
        expand = nn.Linear(config.d_model, 2 * config.d_model)  # nn.Conv1d(config.d_model, 2 * config.d_model, kernel_size=1)
        self.output_linear = nn.Sequential(expand, nn.GLU(dim=-1))

    def forward(self, x):
        """Apply activation, dropout and the gated projection to ``x``."""
        return self.output_linear(self.dropout(self.activation(x)))

class DWNBlock(nn.Module):
    """Residual block: LayerNorm -> LRU -> feed-forward -> dropout, plus skip.

    The feed-forward part is always ``MLP``; ``config.ff`` is not consulted
    here.
    """
    def __init__(self, config: DWNConfig):
        super().__init__()
        width = config.d_model
        self.ln = nn.LayerNorm(width, bias=config.bias)
        self.lru = LRU(
            width,
            width,
            config.d_state,
            rmin=config.rmin,
            rmax=config.rmax,
            max_phase=config.max_phase,
        )
        # self.ff = GLU(config)
        self.ff = MLP(config)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x, state=None, mode="scan"):
        """Return ``x`` plus the transformed branch.

        ``state`` and ``mode`` are passed positionally to the LRU layer.
        NOTE(review): this needs an LRU whose forward accepts a third
        ``mode`` argument - confirm against the LRU class in scope.
        """
        branch = self.lru(self.ln(x), state, mode)  # prenorm, then recurrence
        branch = self.dropout(self.ff(branch))
        return branch + x


class DWN(nn.Module):
    """Linear encoder, a stack of ``DWNBlock`` layers, and a linear decoder.

    Args:
        n_u: size of the last dimension of the input.
        n_y: size of the last dimension of the output.
        config: supplies ``d_model`` and ``n_layers`` plus the block settings.
    """
    def __init__(self, n_u, n_y, config: DWNConfig):
        super().__init__()
        self.encoder = nn.Linear(n_u, config.d_model)
        self.blocks = nn.ModuleList(DWNBlock(config) for _ in range(config.n_layers))
        self.decoder = nn.Linear(config.d_model, n_y)

    def forward(self, u, state=None, mode="scan"):
        """Map ``u`` to the output.

        When ``state`` is given it is indexed by layer number and the entry is
        handed to the corresponding block; otherwise every block gets None.
        """
        hidden = self.encoder(u)
        for idx, blk in enumerate(self.blocks):
            blk_state = None if state is None else state[idx]
            hidden = blk(hidden, state=blk_state, mode=mode)
        return self.decoder(hidden)


In [None]:
# @title hrpan torch_lru lru.py base
# https://github.com/hrpan/torch_lru/blob/main/lru.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

def parallel_lcse(log_input, log_coeff): # On Structured State-Space Duality oct 2025 https://www.arxiv.org/pdf/2510.04944
    """Evaluate a first-order linear recurrence in log space along dim 0.

    For ``h_k = coeff * h_{k-1} + input_k`` (with ``h_{-1} = 0``) this returns
    ``log(h_k)`` for every step, given ``log_input`` of shape [t, b, d] and
    ``log_coeff`` of shape [d].
    """
    num_steps, _, _ = log_input.shape
    step_index = torch.arange(num_steps, device=log_coeff.device)
    # offsets[k] = k * log_coeff, shaped [t, 1, d] to broadcast over the batch
    offsets = (step_index.unsqueeze(1) * log_coeff[None, :])[:, None]
    rescaled = log_input - offsets
    return offsets + torch.logcumsumexp(rescaled, dim=0)

# def conv(_input, log_coeff):
#     t, b, d = _input.shape
#     t_log_coeff = torch.arange(t-1,-1,-1, device=log_coeff.device)[:,None] * log_coeff[None,:]
#     kernel_transpose = torch.diag_embed(t_log_coeff.exp())
#     kernel = kernel_transpose.permute(1,2,0)
#     input_pad = F.pad(_input.permute(1,2,0), (t-1, 0, 0, 0, 0, 0)) # T B D -> B D T
#     return F.conv1d(input_pad, kernel.to(dtype=torch.complex64)).permute(2,0,1) # B D T -> T B D

class LRU(nn.Module):
    """Linear Recurrent Unit evaluated with a log-space cumulative sum.

    With ``lambda = exp(-exp(nu_log) + 1j * exp(theta_log))`` and
    ``gamma = exp(gamma_log)`` the layer computes, along dim 0 of the input:

        h_t = lambda * h_{t-1} + gamma * (B x_t)
        y_t = Re(C h_t) + d * x_t

    Args:
        in_dim: size of the last dimension of the input and of the output.
        dim: number of complex state variables; defaults to ``in_dim``.
        r_min, r_max: bounds used when sampling the initial |lambda|.
        max_phase: upper bound used when sampling the initial phase.
    """
    def __init__(self, in_dim, dim=None, r_min=.5, r_max=.95, max_phase=6.283):
        super().__init__()
        dim = dim or in_dim

        # The Linear layers are created first and their real-valued weights
        # are then replaced by complex ones.
        self.b_linear = nn.Linear(in_dim, dim, bias=False)
        self.b_linear.weight = nn.Parameter((torch.randn(dim, in_dim) + 1j * torch.randn(dim, in_dim)) / np.sqrt(2*in_dim))
        self.c_linear = nn.Linear(dim, in_dim, bias=False)
        self.c_linear.weight = nn.Parameter((torch.randn(in_dim, dim) + 1j * torch.randn(in_dim, dim)) / np.sqrt(dim))
        self.d = nn.Parameter(torch.randn(in_dim)) # element-wise skip weight

        self.nu_log = nn.Parameter(torch.log(-.5 * torch.log(torch.rand(dim) * (r_max+r_min)*(r_max-r_min) + r_min**2)))
        self.theta_log = nn.Parameter(torch.log(max_phase * torch.rand(dim)))

        # gamma is initialised to sqrt(1 - |lambda|^2)
        _lambda = torch.exp(-torch.exp(self.nu_log) + 1j * torch.exp(self.theta_log)).detach()
        self.gamma_log = nn.Parameter(torch.log(torch.sqrt(1 - torch.abs(_lambda)**2)))

    def forward(self, x, h=None, eps=1e-10):
        """Alias for :meth:`lcse`."""
        return self.lcse(x, h, eps)

    def lcse(self, x, h=None, eps=1e-10):
        """Run the recurrence over dim 0 of ``x``.

        Args:
            x: real tensor of shape [t, b, in_dim].
            h: optional initial state, either [b, dim] (the shape this method
                returns) or [1, b, dim].
            eps: small constant added before taking logarithms.

        Returns:
            ``(y, h_last)``: the outputs for every step and the final state.
        """
        x_complex = x.to(dtype=torch.complex64)
        bx = self.b_linear(x_complex) + eps
        log_lambda = -torch.exp(self.nu_log) + 1j * torch.exp(self.theta_log)
        x_in = self.gamma_log + bx.log()
        if h is None:
            ht = parallel_lcse(x_in, log_lambda).exp()
        else:
            # Treat the initial state as an extra leading step and drop it
            # from the result again. It gets the same eps as the inputs so a
            # zero state does not turn into log(0).
            log_h = (h.to(dtype=bx.dtype) + eps).log()
            if log_h.dim() == x_in.dim() - 1:
                log_h = log_h.unsqueeze(0)
            ht = parallel_lcse(torch.cat([log_h, x_in], dim=0), log_lambda)[1:].exp()
        y = self.c_linear(ht.to(dtype=torch.complex64)).real + self.d * x
        return y, ht[-1]

    # def conv(self, x, h=None):
    #     if h is None: h = torch.zeros_like(x[0])
    #     log_lambda = -torch.exp(self.nu_log) + 1j * torch.exp(self.theta_log)
    #     x_complex = x.to(dtype=torch.complex64)
    #     bx = self.gamma_log.exp() * self.b_linear(x_complex)
    #     ht = conv(bx, log_lambda)
    #     y = self.c_linear(ht.to(dtype=torch.complex64)).real + self.d * x
    #     return y, ht[-1]

    # def seq(self, x, h=None): # [b,t,d], [t,d]
    #     x_complex = x.to(dtype=torch.complex64)
    #     if h is None: h = torch.zeros_like(x[0])
    #     log_lambda = -torch.exp(self.nu_log) + 1j * torch.exp(self.theta_log)
    #     bx = self.gamma_log.exp() * self.b_linear(x_complex)
    #     ht = []
    #     _lambda = log_lambda.exp()
    #     for t in range(x.size(0)):
    #         ht.append(h * _lambda + bx[t])
    #         h = ht[-1]
    #     ht = torch.stack(ht)
    #     y = self.c_linear(ht.to(dtype=torch.complex64)).real + self.d * x
    #     return y, ht[-1]


class LRUBlock(nn.Module):
    """Pre-norm residual block: LayerNorm -> LRU -> GELU -> gated projection.

    Args:
        dim: feature width of the block; also used as the LRU state size.
    """
    def __init__(self, dim):
        super().__init__()
        self.prenorm = nn.LayerNorm(dim)
        # The LRU takes the input width and the state size separately.
        self.rnn = LRU(dim, dim)
        self.linear = nn.Linear(dim, 2*dim)

    def forward(self, x, h=None):
        """Return ``(x + block(x), h_last)``.

        ``h`` is an optional initial state for the LRU. A state without the
        leading step dimension (the shape this block returns) gets that
        dimension added before it is handed to the LRU.
        """
        z = self.prenorm(x)
        if h is not None and h.dim() == z.dim() - 1:
            h = h.unsqueeze(0)
        z, h = self.rnn(z, h)
        z = F.gelu(z)
        z1, z2 = self.linear(z).chunk(2, dim=-1)
        z = z1 * torch.sigmoid(z2)
        return z + x, h


# Demo for the LRU layer above: one forward/backward pass to collect
# gradients, then a timing loop over increasing sequence lengths.
# torch.set_default_device('cuda')
# torch.set_default_device('cpu')
# b, t = 32, 10000
b, t = 4, 100
in_dim = 20
dim=40
# Time-first layout [t, b, in_dim]: the recurrence runs along dim 0, as in
# the timing loop below.
x = torch.randn(t, b, in_dim, dtype=torch.float32)

layer = LRU(in_dim, dim)
y_lcse = layer.lcse(x)  # (outputs, last state)
# y_conv = layer.conv(x)
# y_seq = layer.seq(x)
# print('LCSE/SEQ output allclose:', torch.allclose(y_lcse[0], y_seq[0], atol=1e-4))
# print('CONV/SEQ output allclose:', torch.allclose(y_conv[0], y_seq[0], atol=1e-4))

# y_conv[0].sum().backward()
# conv_grad = {}
# for n, p in layer.named_parameters():
#     conv_grad[n] = p.grad
#     p.grad = None

# Collect the gradient of every parameter, then clear it on the layer.
y_lcse[0].sum().backward()
lcse_grad = {}
for n, p in layer.named_parameters():
    lcse_grad[n] = p.grad
    p.grad = None

# y_seq[0].sum().backward()
# seq_grad = {}
# for n, p in layer.named_parameters():
#     seq_grad[n] = p.grad
#     p.grad = None

# for k in conv_grad.keys():
#     print(f'CONV/SEQ {k} grad allclose:', torch.allclose(conv_grad[k], seq_grad[k], atol=1e-4))
#     print(f'LCSE/SEQ {k} grad allclose:', torch.allclose(lcse_grad[k], seq_grad[k], atol=1e-4))

# Timing: `repeats` forward passes per sequence length 2**2 ... 2**15.
# `y_lcse` is reused here as the list of measured durations; `y_conv` and
# `y_seq` are initialised but stay empty while their loops are commented out.
repeats = 2
log_x = np.arange(2, 16)
y_lcse = []
y_conv = []
y_seq = []

import time
for log_length in log_x:
    print('Seq length:', 2**log_length)
    x = torch.randn((2**log_length, 1, in_dim), dtype=torch.float32)

    t0 = time.time()
    for _ in range(repeats):
        layer.lcse(x)
    delta_t = time.time() - t0
    print('lcse:', delta_t)
    y_lcse.append(delta_t)

    # t0 = time.time()
    # for _ in range(repeats):
    #     layer.conv(x)
    # delta_t = time.time() - t0
    # print('conv:', delta_t)
    # y_conv.append(delta_t)

    # t0 = time.time()
    # for _ in range(repeats):
    #     layer.seq(x)
    # delta_t = time.time() - t0
    # print('seq:', time.time() - t0)
    # y_seq.append(delta_t)




In [None]:
# @title hrpan torch_lru lru.py
# https://github.com/hrpan/torch_lru/blob/main/lru.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
# torch.set_default_device('cuda')

def parallel_lcse(log_input, log_coeff): # On Structured State-Space Duality oct 2025 https://www.arxiv.org/pdf/2510.04944
    """Return ``log(h_k)`` for ``h_k = coeff * h_{k-1} + input_k`` along dim 0.

    ``log_input`` has shape [t, b, d] and ``log_coeff`` has shape [d]. The
    inputs are rescaled by ``coeff**k``, accumulated with a log-cumsum-exp,
    and the scale is added back afterwards.
    """
    num_steps, _, _ = log_input.shape
    step_index = torch.arange(num_steps, device=log_coeff.device)
    scale = step_index.unsqueeze(1) * log_coeff[None, :]  # [t, d]
    scale = scale[:, None]  # [t, 1, d]
    accumulated = torch.logcumsumexp(log_input - scale, dim=0)
    return scale + accumulated

def conv(_input, log_coeff):
    """Evaluate ``h_k = coeff * h_{k-1} + input_k`` as a causal 1-D convolution.

    ``_input`` has shape [t, b, d] and ``log_coeff`` has shape [d]. The kernel
    holds the powers ``coeff**(t-1) ... coeff**0`` on the channel diagonal and
    the input is left-padded with ``t - 1`` zeros, so the result has the same
    [t, b, d] shape as the input.
    """
    num_steps, _, _ = _input.shape
    exponents = torch.arange(num_steps - 1, -1, -1, device=log_coeff.device)
    powers = (exponents[:, None] * log_coeff[None, :]).exp()  # [t, d]
    # diag_embed -> [t, d, d]; conv1d expects [out_channels, in_channels, t]
    kernel = torch.diag_embed(powers).permute(1, 2, 0).to(dtype=torch.complex64)
    signal = _input.permute(1, 2, 0)  # T B D -> B D T
    signal = F.pad(signal, (num_steps - 1, 0, 0, 0, 0, 0))
    return F.conv1d(signal, kernel).permute(2, 0, 1)  # B D T -> T B D

class LRU(nn.Module):
    """Linear Recurrent Unit with three equivalent evaluation strategies.

    With ``lambda = exp(-exp(nu_log) + 1j * exp(theta_log))`` and
    ``gamma = exp(gamma_log)``, all methods evaluate along dim 0 of the input:

        h_t = lambda * h_{t-1} + gamma * (B x_t)     (state update)
        y_t = Re(C h_t) + d * x_t                    (read-out)

    Inputs are time-first: ``x`` is [t, b, in_dim], states are [b, d_model].

    Args:
        in_dim: size of the last dimension of the input.
        d_model: number of complex state variables; defaults to ``in_dim``.
        out_dim: size of the last dimension of the output; defaults to
            ``in_dim``.
        r_min, r_max: bounds used when sampling the initial |lambda|.
        max_phase: upper bound used when sampling the initial phase.

    NOTE(review): the skip term ``d * x`` multiplies a vector of length
    ``out_dim`` with an input of width ``in_dim``, so it only broadcasts when
    the two agree (or one of them is 1).
    """
    def __init__(self, in_dim, d_model=None, out_dim=None, r_min=.5, r_max=.95, max_phase=6.283):
        super().__init__()
        d_model, out_dim = d_model or in_dim, out_dim or in_dim
        # The Linear layers are created first and their real-valued weights
        # are then replaced by complex ones.
        self.b_linear = nn.Linear(in_dim, d_model, bias=False)
        self.b_linear.weight = nn.Parameter((torch.randn(d_model, in_dim) + 1j * torch.randn(d_model, in_dim)) / np.sqrt(2*in_dim))
        self.c_linear = nn.Linear(d_model, out_dim, bias=False)
        self.c_linear.weight = nn.Parameter((torch.randn(out_dim, d_model) + 1j * torch.randn(out_dim, d_model)) / np.sqrt(d_model))
        self.d = nn.Parameter(torch.randn(out_dim)) # element-wise skip weight

        self.nu_log = nn.Parameter(torch.log(-.5 * torch.log(torch.rand(d_model) * (r_max+r_min)*(r_max-r_min) + r_min**2)))
        self.theta_log = nn.Parameter(torch.log(max_phase * torch.rand(d_model)))
        # gamma is initialised to sqrt(1 - |lambda|^2)
        _lambda = torch.exp(-torch.exp(self.nu_log) + 1j * torch.exp(self.theta_log)).detach()
        self.gamma_log = nn.Parameter(torch.log(torch.sqrt(1 - torch.abs(_lambda)**2)))

    def forward(self, x, h=None, eps=1e-10):
        """Alias for :meth:`lcse`."""
        return self.lcse(x, h, eps)

    def lcse(self, x, h=None, eps=1e-10):
        """Evaluate the recurrence with a log-space cumulative sum.

        Args:
            x: real tensor of shape [t, b, in_dim].
            h: optional initial state, either [b, d_model] (the shape this
                method returns) or [1, b, d_model].
            eps: small constant added before taking logarithms.

        Returns:
            ``(y, h_last)``: the outputs for every step and the final state.
        """
        x_complex = x.to(dtype=torch.complex64)
        bx = self.b_linear(x_complex) + eps
        log_lambda = -torch.exp(self.nu_log) + 1j * torch.exp(self.theta_log)
        x_in = self.gamma_log + bx.log()
        if h is None:
            ht = parallel_lcse(x_in, log_lambda).exp()
        else:
            # Treat the initial state as an extra leading step and drop it
            # from the result again. It gets the same eps as the inputs so a
            # zero state does not turn into log(0).
            log_h = (h.to(dtype=bx.dtype) + eps).log()
            if log_h.dim() == x_in.dim() - 1:
                log_h = log_h.unsqueeze(0)
            ht = parallel_lcse(torch.cat([log_h, x_in], dim=0), log_lambda)[1:].exp()
        y = self.c_linear(ht.to(dtype=torch.complex64)).real + self.d * x
        return y, ht[-1]

    def conv(self, x, h=None):
        """Evaluate the recurrence as a causal convolution.

        Takes the same ``x`` and ``h`` as :meth:`seq` and returns the same
        ``(y, h_last)`` pair.
        """
        x_complex = x.to(dtype=torch.complex64)
        log_lambda = -torch.exp(self.nu_log) + 1j * torch.exp(self.theta_log)
        bx = self.gamma_log.exp() * self.b_linear(x_complex)
        ht = conv(bx, log_lambda)
        if h is not None:
            # The convolution only sees the inputs. An initial state decays
            # independently and contributes lambda**(k + 1) * h at step k.
            steps = torch.arange(1, ht.size(0) + 1, device=log_lambda.device)
            decay = (steps[:, None] * log_lambda[None, :]).exp().unsqueeze(1)  # [t, 1, d_model]
            ht = ht + decay * h
        y = self.c_linear(ht.to(dtype=torch.complex64)).real + self.d * x
        return y, ht[-1]

    def seq(self, x, h=None):
        """Evaluate the recurrence step by step (reference implementation).

        Args:
            x: real tensor of shape [t, b, in_dim].
            h: optional initial state of shape [b, d_model]; zeros when None.

        Returns:
            ``(y, h_last)``: the outputs for every step and the final state.
        """
        x_complex = x.to(dtype=torch.complex64)
        log_lambda = -torch.exp(self.nu_log) + 1j * torch.exp(self.theta_log)
        bx = self.gamma_log.exp() * self.b_linear(x_complex)
        # The state has the width of the projected input (d_model), not of x.
        if h is None: h = torch.zeros_like(bx[0])
        ht = []
        _lambda = log_lambda.exp()
        for t in range(x.size(0)):
            ht.append(h * _lambda + bx[t])
            h = ht[-1]
        ht = torch.stack(ht)
        y = self.c_linear(ht.to(dtype=torch.complex64)).real + self.d * x
        return y, ht[-1]

# class LRUBlock(nn.Module):
#     def __init__(self, dim):
#         super().__init__()
#         self.prenorm = nn.LayerNorm(dim)
#         self.rnn = LRU(dim)
#         self.linear = nn.Linear(dim, 2*dim)

#     def forward(self, x, h=None):
#         z = self.prenorm(x)
#         z, h = self.rnn(z)
#         z = F.gelu(z)
#         z1, z2 = self.linear(z).chunk(2, dim=-1)
#         z = z1 * torch.sigmoid(z2)
#         return z + x, h


# Demo for the LRU layer above: one forward/backward pass to collect
# gradients, then a timing loop over increasing sequence lengths.
# b, t = 32, 10000
b, t = 4, 100
in_dim = 20
d_model = 40
# Time-first input: [t, b, in_dim]
x = torch.randn(t,b,in_dim, dtype=torch.float32)

layer = LRU(in_dim, d_model)
y_lcse = layer.lcse(x)  # (outputs, last state)

# Collect the gradient of every parameter, then clear it on the layer.
y_lcse[0].sum().backward()
lcse_grad = {}
for n, p in layer.named_parameters():
    lcse_grad[n] = p.grad
    p.grad = None

# Timing: `repeats` forward passes per sequence length 2**2 ... 2**15.
# `y_lcse` is reused here as the list of measured durations; `y_conv` and
# `y_seq` are initialised but never filled in this cell.
repeats = 2
log_x = np.arange(2,16)
y_lcse = []
y_conv = []
y_seq = []

import time
for log_length in log_x:
    print('Seq length:', 2**log_length)
    x = torch.randn(2**log_length, 1, in_dim, dtype=torch.float32) # [t,b,in]

    # Wall-clock time of `repeats` calls; no torch.no_grad() is active, so
    # each call also records the autograd graph.
    t0 = time.time()
    for _ in range(repeats):
        layer.lcse(x)
    delta_t = time.time() - t0
    print('lcse:', delta_t)
    y_lcse.append(delta_t)



Seq length: 4
lcse: 0.0011658668518066406
Seq length: 8
lcse: 0.001047372817993164
Seq length: 16
lcse: 0.001398324966430664
Seq length: 32
lcse: 0.0019273757934570312
Seq length: 64
lcse: 0.0030863285064697266
Seq length: 128
lcse: 0.005485057830810547
Seq length: 256
lcse: 0.010121345520019531
Seq length: 512
lcse: 0.020756244659423828
Seq length: 1024
lcse: 0.03751349449157715
Seq length: 2048
lcse: 0.06968951225280762
Seq length: 4096
lcse: 0.1582951545715332
Seq length: 8192
lcse: 0.27878808975219727
Seq length: 16384
lcse: 0.6257624626159668
Seq length: 32768
lcse: 1.2754313945770264


In [None]:
# @title LRUBlock

def zero_module(module):
    """Set every parameter of ``module`` to zero in place and return it."""
    for param in module.parameters():
        # Working on the detached view keeps the in-place write out of autograd.
        param.detach().zero_()
    return module

class GLU(nn.Module): # https://arxiv.org/pdf/2002.05202
    """GELU followed by a zero-initialised, bias-free gated projection.

    The projection maps ``in_dim`` to ``2 * d_model``; one half is multiplied
    by the sigmoid of the other half, giving an output of width ``d_model``
    (``in_dim`` when ``d_model`` is not given).
    """
    def __init__(self, in_dim, d_model=None):
        super().__init__()
        if not d_model:
            d_model = in_dim
        # nn.LayerNorm(in_dim) and a plain nn.Linear(in_dim, d_model) were
        # tried here earlier; the active layers are GELU (SiLU is an
        # alternative) and the zeroed projection.
        activation = nn.GELU()
        gate_proj = zero_module(nn.Linear(in_dim, 2*d_model, bias=False))
        self.lin = nn.Sequential(activation, gate_proj)

    def forward(self, x): # [b,t,d]
        """Return ``value * sigmoid(gate)`` for the two halves of the projection."""
        value, gate = self.lin(x).chunk(2, dim=-1)
        return value * torch.sigmoid(gate)

class LRUBlock(nn.Module):
    """Pre-norm residual block: LayerNorm -> LRU -> GLU, added to the input."""
    def __init__(self, d_model):
        super().__init__()
        self.prenorm = nn.LayerNorm(d_model)
        self.rnn = LRU(d_model)
        self.glu = GLU(d_model)

    def forward(self, x, h=None):
        """Return ``(x + glu(rnn(prenorm(x))), h_last)``.

        ``h`` is passed to the LRU as its initial state and replaced by the
        state the LRU returns.
        """
        normed = self.prenorm(x)
        rnn_out, h = self.rnn(normed, h)
        return x + self.glu(rnn_out), h

# Smoke test: push one random time-first batch [t, b, d_model] through an
# LRUBlock and print the shapes of the output and of the returned state.
b, t, d_model = 4, 100, 40
lrublock = LRUBlock(d_model)
x = torch.randn(t,b,d_model, dtype=torch.float32)

y, h = lrublock(x)
print(y.shape, h.shape)


torch.Size([100, 4, 40]) torch.Size([4, 40])


In [None]:
# @title Gothos LRU.py
# https://github.com/Gothos/LRU-pytorch/blob/main/LRU_pytorch/LRU.py
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class LRU(nn.Module):
    """Linear Recurrent Unit evaluated step by step.

    With ``Lambda = exp(-exp(nu_log)) * exp(1j * exp(theta_log))`` and
    ``gamma = exp(gamma_log)`` the layer computes

        h_t = Lambda * h_{t-1} + (gamma * B) u_t
        y_t = Re(C h_t) + D u_t

    Args:
        in_features: size of the last dimension of the input.
        out_features: size of the last dimension of the output.
        state_features: number of complex state variables.
        rmin, rmax: bounds used when sampling the initial |Lambda|.
        max_phase: upper bound used when sampling the initial phase.
    """
    def __init__(self,in_features,out_features,state_features, rmin=0, rmax=1,max_phase=6.283):
        super().__init__()
        self.out_features=out_features
        self.D=nn.Parameter(torch.randn([out_features,in_features])/math.sqrt(in_features))
        u1=torch.rand(state_features)
        u2=torch.rand(state_features)
        self.nu_log= nn.Parameter(torch.log(-0.5*torch.log(u1*(rmax+rmin)*(rmax-rmin) + rmin**2)))
        self.theta_log= nn.Parameter(torch.log(max_phase*u2))
        # gamma is initialised to sqrt(1 - |Lambda|^2)
        Lambda_mod=torch.exp(-torch.exp(self.nu_log))
        self.gamma_log=nn.Parameter(torch.log(torch.sqrt(torch.ones_like(Lambda_mod)-torch.square(Lambda_mod))))
        B_re=torch.randn([state_features,in_features])/math.sqrt(2*in_features)
        B_im=torch.randn([state_features,in_features])/math.sqrt(2*in_features)
        self.B=nn.Parameter(torch.complex(B_re,B_im))
        C_re=torch.randn([out_features,state_features])/math.sqrt(state_features)
        C_im=torch.randn([out_features,state_features])/math.sqrt(state_features)
        self.C=nn.Parameter(torch.complex(C_re,C_im))
        # Default initial state (complex zeros). Registered as a
        # non-persistent buffer so it follows the module across devices
        # without adding a key to the state_dict.
        self.register_buffer(
            "state",
            torch.complex(torch.zeros(state_features),torch.zeros(state_features)),
            persistent=False,
        )

    def forward(self, input,state=None):
        """Run the recurrence over the sequence dimension of ``input``.

        Args:
            input: real tensor of shape (batch, seq_len, in_features) or
                (seq_len, in_features).
            state: optional initial state, broadcastable to
                (batch, state_features). Every sequence in the batch starts
                from it. Defaults to zeros.

        Returns:
            Real tensor with the shape of ``input`` except that the last
            dimension is ``out_features``.

        Raises:
            ValueError: if ``input`` is neither 2-D nor 3-D.
        """
        if input.dim() not in (2, 3):
            raise ValueError(
                f"LRU expects a 2-D (seq, features) or 3-D (batch, seq, features) input, got {input.dim()}-D"
            )

        initial = self.state if state is None else state
        initial = initial.to(device=self.B.device, dtype=self.B.dtype)

        Lambda_mod=torch.exp(-torch.exp(self.nu_log))
        phase=torch.exp(self.theta_log)
        Lambda=torch.complex(Lambda_mod*torch.cos(phase),Lambda_mod*torch.sin(phase))
        B_norm=torch.exp(self.gamma_log).unsqueeze(-1)*self.B  # (state, in)

        batched=input.dim()==3
        u=input if batched else input.unsqueeze(0)  # (batch, seq, in)
        # Input contribution for every step at once: (batch, seq, state)
        Bu=u.to(dtype=self.B.dtype)@B_norm.T

        # The time loop is the only sequential part; each step handles the
        # whole batch.
        h=initial
        states=[]
        for bu_step in Bu.unbind(dim=1):
            h=Lambda*h+bu_step
            states.append(h)
        # With an empty sequence there is nothing to stack; Bu already has
        # the right (batch, 0, state) shape.
        states=torch.stack(states,dim=1) if states else Bu

        output=(states@self.C.T).real + u@self.D.T
        return output if batched else output.squeeze(0)
