In [1]:
import torch
from torch import nn
import numpy as np

In [2]:
def hippo_init(N):
    """Build the dense HiPPO-LegS state matrix ``A`` and input column ``b``.

    Entries (n = row, k = column):
        A[n, k] = -sqrt(2n + 1) * sqrt(2k + 1)   for n > k
        A[n, n] = -(n + 1)
        A[n, k] = 0                               for n < k
        b[n, 0] = sqrt(2n + 1)

    Args:
        N: state size.

    Returns:
        (A, b) with A of shape (N, N) (lower triangular) and b of shape (N, 1).
    """
    idx = torch.arange(N)
    b = torch.sqrt(2 * idx + 1).unsqueeze(-1)

    # Outer product sqrt(2n+1) * sqrt(2k+1), keeping the diagonal and below.
    lower = torch.tril(b @ b.T)

    # The diagonal of `lower` is 2n + 1; subtracting n gives n + 1, then negate everything.
    A = -(lower - torch.diag(idx))
    return A, b

In [3]:
# Dense reference HiPPO matrix (and its input column) for a 4-dimensional state.
A, b = hippo_init(N=4)

In [4]:
def hippo_init_dplr(N):
    """Diagonal-plus-low-rank (DPLR) form of the HiPPO-LegS matrix.

    With p[n] = sqrt(n + 1/2), the LegS matrix satisfies
        A + p pᵀ = -½·I + S,
    where S is real skew-symmetric. S is diagonalised as S = V diag(iλ) Vᴴ, giving
        A = V (diag(Lambda) - p̃ p̃ᴴ) Vᴴ,   Lambda = iλ - ½,   p̃ = Vᴴ p.

    Args:
        N: state size.

    Returns:
        (V, Lambda, p, b):
            V      -- (N, N) complex64 eigenvectors of S (from ``torch.linalg.eigh``),
            Lambda -- (N,) complex64 eigenvalues of A + p pᵀ (real part -½),
            p      -- (N, 1) complex64 low-rank vector expressed in the V basis,
            b      -- (N, 1) complex64 input column sqrt(2n + 1) expressed in the V basis.
    """
    n = torch.arange(N)[:, None]

    p = torch.sqrt(n + 0.5)

    # Skew-symmetric part S of A + p pᵀ (not of A itself):
    # -p_n p_k below the diagonal, +p_n p_k above it, zero on the diagonal.
    lower = torch.tril(p @ p.T)
    S = lower.T - lower

    # For real skew-symmetric S, -iS is Hermitian, so eigh applies and returns
    # real eigenvalues lam together with eigenvectors V.
    lam, V = torch.linalg.eigh(S * -1j)

    # S = i·(-iS), so the eigenvalues of S are i·lam, i.e. lam times (-1j)^-1 = 1j.
    # The -½·I term of the decomposition contributes the real part.
    Lambda = lam * 1j - 0.5

    # Change of basis for b and p into the eigenbasis of S.
    b = torch.sqrt(2 * n + 1)
    b = V.H @ b.to(torch.complex64)
    p = V.H @ p.to(torch.complex64)

    return V, Lambda, p, b



In [5]:
# Note: this rebinds `b` from the hippo_init cell to its complex64 counterpart in the V basis.
V, Lambda, p, b = hippo_init_dplr(N=4)

In [6]:
# The DPLR factors must reproduce the dense HiPPO matrix: A ≈ V (diag(Λ) - p pᴴ) Vᴴ.
assert torch.allclose(
    V @ (torch.diag(Lambda) - p @ p.H) @ V.H,
    A.to(torch.complex64),
    atol=1e-4,
)

In [18]:
def k_conv(A, b, c, L):
    """Materialise the length-L SSM convolution kernel K[l] = cᵀ A^l b.

    A^l is the *matrix* power (A^0 = I, so K[0] = cᵀ b). The previous version
    used ``A.pow(l)``, which is an elementwise power and produced a wrong kernel.

    Args:
        A: square state matrix.
        b: input vector; c.T @ A^l @ b must be a single element, e.g. (N, 1)
            columns as produced by hippo_init.
        c: output vector, same layout as b.
        L: number of kernel taps.

    Returns:
        1-D tensor of length L built from the Python scalars cᵀ A^l b.
    """
    taps = []
    state = b  # holds A^l b for the current l
    for _ in range(L):
        # Single-element tensor -> Python scalar (same output dtype as before).
        taps.append((c.T @ state).item())
        # Advance A^l b -> A^{l+1} b with one mat-vec instead of recomputing A^l.
        state = A @ state
    return torch.tensor(taps)

In [9]:
# Inspect the shape of `b`; expected to be an (N, 1) column.
b.size()

torch.Size([4, 1])

In [19]:
def k_gen_simple(A, b, c, L):
    """Return the truncated generating function z -> sum_{l<L} K[l] * z**l.

    The kernel K is computed once, up front, by ``k_conv(A, b, c, L)``; the
    returned closure only evaluates the polynomial at the given z.

    Args:
        A, b, c, L: forwarded unchanged to k_conv.

    Returns:
        A function ``gen(z)`` returning the 0-d tensor sum of K[l] * z**l.
    """
    kernel = k_conv(A, b, c, L)
    # Loop-invariant exponent vector 0, 1, ..., L-1, shared by every call of gen.
    exponents = torch.arange(L)

    def gen(z):
        return (kernel * z ** exponents).sum()

    return gen