##### 1.定义连续SSM与离散化后的SSM的参数

In [1]:
import numpy as np
# Fix NumPy's global RNG seed so every np.random.rand draw is reproducible.
np.random.seed(1)

#Continuous SSM
def Random_SSM(N):
    """Draw a random continuous-time SSM with state size ``N``.

    All entries are sampled with ``np.random.rand`` (uniform on [0, 1),
    NumPy's global RNG), drawing A first, then B, then C.

    Args:
        N: State dimension.

    Returns:
        Tuple ``(A, B, C)`` with shapes ``(N, N)``, ``(N, 1)`` and ``(1, N)``.
    """
    shapes = ((N, N), (N, 1), (1, N))
    # The generator is consumed left to right, so the draw order is A, B, C.
    A, B, C = (np.random.rand(*shape) for shape in shapes)
    return A, B, C

#Discrete SSM
def discretize(A, B, C, step):
    """Discretize a continuous SSM ``(A, B, C)`` with step size ``step``.

    Uses the bilinear formulas

        A_bar = (I - step/2 * A)^(-1) @ (I + step/2 * A)
        B_bar = (I - step/2 * A)^(-1) * step @ B

    Args:
        A: Square state matrix.
        B: Input matrix with ``A.shape[0]`` rows.
        C: Output matrix; passed through untouched.
        step: Discretization step size.

    Returns:
        Tuple ``(A_bar, B_bar, C)``.
    """
    identity = np.eye(A.shape[0])
    half_step_A = (step / 2.0) * A
    # Shared left factor (I - step/2 * A)^(-1) of both formulas.
    left_inverse = np.linalg.inv(identity - half_step_A)
    A_bar = left_inverse @ (identity + half_step_A)
    B_bar = (left_inverse * step) @ B
    return A_bar, B_bar, C

##### 2.SSM RNN Representation

In [2]:
def scan_SSM(Ab, Bb, Cb, u, x0):
    """Advance the discrete SSM by one step.

        x_next = A_bar @ x0 + B_bar * u
        y      = C_bar @ x_next

    Args:
        Ab: Discrete state matrix.
        Bb: Discrete input matrix.
        Cb: Output matrix.
        u: Input for this step.
        x0: State before the step.

    Returns:
        Tuple ``(x_next, y)``: the updated state and the output computed
        from that updated state.
    """
    next_state = Ab @ x0 + Bb * u
    return next_state, Cb @ next_state

#Demo: Run SSM
def run_SSM(A, B, C, u):
    """Run the SSM recurrence over the whole input and return the last output.

    The continuous system is discretized with ``step = 1 / len(u)``, the
    state starts at zero, and ``scan_SSM`` is applied once per input sample.

    Args:
        A, B, C: Continuous SSM matrices.
        u: Input sequence; its first axis is the time axis.

    Returns:
        The output ``y`` produced by the final recurrence step.
    """
    seq_len = u.shape[0]
    Ab, Bb, Cb = discretize(A, B, C, step=1.0 / seq_len)

    # Zero initial state, stored as a column vector.
    state = np.zeros((A.shape[0], 1))
    for sample in u:
        state, y = scan_SSM(Ab, Bb, Cb, sample, state)
    return y

##### 3.utils

In [3]:
# Mat Power
def matmul_n_times(A, n_times):
    """Return the matrix power ``A ** n_times``.

    Args:
        A: Square matrix.
        n_times: Integer exponent. ``0`` yields the identity, positive
            values the repeated product, and negative values the
            corresponding power of the inverse of ``A``.

    Returns:
        The ``n_times``-th matrix power of ``A``.

    Raises:
        numpy.linalg.LinAlgError: If ``A`` is not square, or if ``n_times``
            is negative and ``A`` is singular.
    """
    if n_times == 0:
        # Always a float identity, whatever the dtype of A.
        return np.eye(A.shape[0])
    # matrix_power uses repeated squaring and inverts A for negative exponents.
    return np.linalg.matrix_power(A, n_times)

# Get Conv Kernel
def K_conv(Ab, Bb, Cb, L):
    """Build the length-``L`` convolution kernel of a discrete SSM.

        K[i] = C_bar @ A_bar**i @ B_bar,   i = 0 .. L-1

    Args:
        Ab: Discrete state matrix (square).
        Bb: Discrete input matrix.
        Cb: Output matrix.
        L: Number of kernel taps.

    Returns:
        Array whose first axis has length ``L``. Singleton trailing axes are
        removed, so ``(N, 1)`` / ``(1, N)`` shaped ``Bb`` / ``Cb`` give a 1-D
        kernel of shape ``(L,)``, including when ``L == 1``.
    """
    power = np.eye(Ab.shape[0])  # A_bar ** 0
    taps = []
    for _ in range(L):
        taps.append(Cb @ power @ Bb)
        # Reuse the previous power instead of recomputing A_bar**i each tap.
        power = power @ Ab
    K = np.array(taps)
    # Drop singleton axes but never the leading (length) axis; a plain
    # squeeze() would collapse a length-1 kernel to a 0-d array.
    return K.reshape([K.shape[0]] + [dim for dim in K.shape[1:] if dim != 1])

#Convolution
def causal_convolution(u, K, nofft = False):
    """Return the final output sample of the convolution of ``u`` with ``K``.

    For 1-D ``u`` and ``K`` of equal length ``L`` both branches evaluate
    ``sum_j K[L-1-j] * u[j]``.

    Args:
        u: Input sequence.
        K: Convolution kernel.
        nofft: If true, compute the sum directly; otherwise go through the
            FFT.

    Returns:
        The convolution output at index ``len(u) - 1``.

    Raises:
        AssertionError: In FFT mode, if ``K`` and ``u`` differ in length.
    """
    if nofft:
        # Direct form: dot product of the time-reversed kernel with the input.
        reversed_kernel = K[::-1]
        return reversed_kernel @ np.transpose(u)

    # FFT form: zero-pad both signals so the circular product equals the
    # linear convolution, then pick out the last causal sample.
    assert K.shape[0] == u.shape[0]
    n = u.shape[0]
    u_freq = np.fft.rfft(np.pad(u, (0, K.shape[0])))
    K_freq = np.fft.rfft(np.pad(K, (0, n)))
    full_conv = np.fft.irfft(u_freq * K_freq)
    return full_conv[n - 1]

In [4]:
# Demo in this stage: draw a random 3-state SSM, discretize it with
# step = 1 / L, and build its length-L convolution kernel.
L = 4
step = 1.0 / L
A, B, C = Random_SSM(3)
Ab, Bb,Cb = discretize(A, B, C, step)
Origin_Kernel = K_conv(Ab, Bb, Cb, L)
Origin_Kernel  # bare expression: displayed as the notebook cell's output

array([0.13734084, 0.16658424, 0.20268662, 0.24721982])

In [6]:
def test_cnn_is_rnn(N = 3, L = 5, step = 1.0/5):
    """Check that the recurrent and convolutional SSM views agree.

    A random SSM is run over the input ``u = [-1, -2, ..., -L]`` once as a
    recurrence (``run_SSM``) and once as a convolution with the kernel from
    ``K_conv``, evaluated both directly and through the FFT. The three final
    outputs are printed.

    Args:
        N: State dimension of the random SSM.
        L: Sequence length; also the kernel length.
        step: Discretization step for the convolutional path. ``run_SSM``
            always uses ``1 / L``, so the results are only expected to match
            when ``step == 1 / L``.

    Returns:
        True-valued if both convolution results are within ``1e-6`` of the
        recurrent result, false-valued otherwise.
    """
    tolerance = 1e-6

    ssm = Random_SSM(N)
    # Input follows L, so the kernel and the input always have equal length.
    u = -np.arange(1, L + 1)

    # RNN results
    rec = run_SSM(*ssm, u)

    # CNN results
    ssmb = discretize(*ssm, step=step)
    # Get Conv Kernel K
    K = K_conv(*ssmb, L)
    # Calculate K * u (nofft=True -> direct sum, nofft=False -> FFT)
    conv = causal_convolution(u, K, True)
    conv2 = causal_convolution(u, K, False)

    rnn_out = rec.ravel()[0]
    direct_out = conv.ravel()[0]
    fft_out = conv2.ravel()[0]

    # Check results (backslashes are escaped; the printed text is unchanged)
    print()
    print("RNN result is :", rnn_out)
    print("CNN(w\\o FFT) result is : ", direct_out)
    print("CNN(w\\ FFT) result is : ", fft_out)

    return (np.abs(rnn_out - direct_out) < tolerance
            and np.abs(rnn_out - fft_out) < tolerance)
# Run the RNN-vs-CNN check with its default arguments (N=3, L=5, step=1/5).
test_cnn_is_rnn()
    
    


RNN result is : -2.9878612423736817
CNN(w\o FFT) result is :  -2.987861242373681
CNN(w\ FFT) result is :  -2.9878612423736812


True

In [9]:
# Define HiPPO Matrix
def make_HiPPO(N):
    """Build the ``N x N`` HiPPO matrix.

    With ``p[k] = sqrt(2k + 1)`` the returned matrix has entries

        -p[i] * p[j]   for i > j  (strictly lower triangle)
        -(i + 1)       for i == j (diagonal: p[i]**2 - i, negated)
        zero           for i < j  (upper triangle)

    Args:
        N: Matrix size.

    Returns:
        Array of shape ``(N, N)``.
    """
    scale = np.sqrt(2 * np.arange(N) + 1)
    lower = np.tril(np.outer(scale, scale))
    return -(lower - np.diag(np.arange(N)))

In [None]:
# Truncated generating function K(z) = sum_i K[i] * z**i, naive method (builds the kernel explicitly)
def K_gen_simple(Ab, Bb, Cb, L):
    """Return a function evaluating ``sum_i K[i] * z**i`` for ``i < L``.

    The kernel ``K`` is computed once with ``K_conv`` and captured by the
    returned closure.

    Args:
        Ab, Bb, Cb: Discrete SSM matrices, forwarded to ``K_conv``.
        L: Kernel length, i.e. the number of terms in the sum.

    Returns:
        Callable ``gen(z)`` giving the value of the truncated sum at ``z``.
    """
    kernel = K_conv(Ab, Bb, Cb, L)

    def gen(z):
        # z**0, z**1, ..., z**(L-1), weighted by the kernel taps.
        powers_of_z = z ** np.arange(L)
        return np.sum(kernel * powers_of_z)

    return gen

# Truncated generating function K(z), closed form via a matrix inverse (generating-function method)
def K_gen_inversr(Ab, Bb, Cb, L):
    """Return a function evaluating ``C_tilde @ (I - A_bar * z)^(-1) @ B_bar``.

    ``C_tilde = C_bar @ (I - A_bar**L)`` is computed once and captured by
    the returned closure; each call then costs one matrix inverse.

    Args:
        Ab: Discrete state matrix (square).
        Bb: Discrete input matrix.
        Cb: Output matrix.
        L: Power of ``A_bar`` used in ``C_tilde``.

    Returns:
        Callable ``gen(z)`` returning the matrix product above.
    """
    identity = np.eye(Ab.shape[0])
    # C_~ = C_bar * (I - A_bar^L)
    C_tilde = Cb @ (identity - matmul_n_times(Ab, L))

    def gen(z):
        inverse = np.linalg.inv(identity - Ab * z)
        return C_tilde @ inverse @ Bb

    return gen

def conv_from_gen(gen, L):
    
