In [2]:
%load_ext autoreload
%autoreload 2

In [153]:
import torch
import torch.nn as nn

from mamba import S4_with_shared_A, S4_base, MultiChannelS4
import ipytest
import pytest
ipytest.autoconfig()

In [None]:
%%ipytest

@pytest.mark.parametrize("s4_param", [(1, 3), (5, 3)], ids=["D1N3", "D5N3"])
def test_discretize(s4_param):
    """Discretized matrices of S4_with_shared_A have the expected shapes.

    ``s4_param`` is ``(channels, hidden_state)``. The discretized ``dA`` must be
    square in the hidden size, and ``dB`` must be ``(hidden_state, channels)``.
    """
    channels, hidden = s4_param
    layer = S4_with_shared_A(channels=channels, hidden_state=hidden)
    disc_A, disc_B = layer.discretize()
    assert disc_A.shape == (hidden, hidden)
    assert disc_B.shape == (hidden, channels)

[32m.[0m[32m.[0m[32m                                                                                           [100%][0m
[32m[32m[1m2 passed[0m[32m in 0.02s[0m[0m


In [None]:
%%ipytest

@pytest.mark.parametrize("seq_params", [(1, 3), (7, 7)], ids=["D1L3", "D7L7"])
def test_propagate_dimensional(seq_params):
    """RNN propagation preserves the (channels, length) shape of a 2-D input.

    ``seq_params`` is ``(D, L)``: the number of channels and the sequence
    length. Each id encodes both values, so a failing case can be identified
    from its id alone. The previous ids ("D1N3", "D5N3") did not match the
    parameter values.
    """
    D, L = seq_params
    model = S4_with_shared_A(channels=D, hidden_state=5)
    X = torch.rand((D, L))
    Y = model.propagate_RNN(X)
    assert Y.shape == (D, L)


def test_propagate_seq():
    """A 1-D sequence passed through S4_with_shared_A.propagate_RNN stays 1-D."""
    seq_len = 7
    layer = S4_with_shared_A(channels=1, hidden_state=5)
    seq = torch.rand(seq_len)
    out = layer.propagate_RNN(seq)
    # NOTE(review): only the rank is asserted. A stricter
    # ``out.shape == torch.Size([seq_len])`` check was left disabled in this
    # test. TODO: confirm whether S4_with_shared_A should satisfy it.
    assert out.dim() == 1

[32m.[0m[32m.[0m[32m.[0m[32m                                                                                          [100%][0m
[32m[32m[1m3 passed[0m[32m in 0.05s[0m[0m


In [None]:
%%ipytest

@pytest.mark.parametrize("params", [(1, 7), (5, 7)], ids=["D1", "D5"])
def test_con_filter_2D(params):
    """Direct convolution, FFT convolution and RNN agree on a (D, L) input.

    ``params`` is ``(D, L)``. All three propagation paths must return a
    ``(D, L)`` tensor. Asserting the RNN shape too means the value
    comparisons below are element-wise on identical layouts. Comparing
    flattened tensors would hide a mis-shaped result.
    """
    D, L = params
    model = S4_with_shared_A(channels=D, hidden_state=5, seed=42, kernel_max_size=L)
    X = torch.rand((D, L))
    Y_conv = model.propagate_convolution_filter(X, use_fourier=False)
    Y_fourier = model.propagate_convolution_filter(X, use_fourier=True)
    Y_rnn = model.propagate_RNN(X)
    assert Y_conv.shape == (D, L)
    assert Y_fourier.shape == (D, L)
    assert Y_rnn.shape == (D, L)
    assert torch.allclose(Y_rnn, Y_fourier)
    assert torch.allclose(Y_conv, Y_fourier)
    assert torch.allclose(Y_conv, Y_rnn)

[32m.[0m[32m.[0m[32m                                                                                           [100%][0m
[32m[32m[1m2 passed[0m[32m in 0.02s[0m[0m


In [None]:
%%ipytest

def test_con_filter_seq():
    """Direct convolution, FFT convolution and RNN agree on a 1-D sequence.

    A single-channel model with ``kernel_max_size`` equal to the sequence
    length is run through all three propagation paths. The two convolution
    paths must keep the ``(L,)`` shape, and all three outputs must match
    numerically.
    """
    D, L = 1, 7
    # Use D rather than a hard-coded 1 so the channel count is stated once.
    model = S4_with_shared_A(channels=D, hidden_state=5, seed=42, kernel_max_size=L)
    X = torch.rand(L)
    Y_conv = model.propagate_convolution_filter(X, use_fourier=False)
    Y_fourier = model.propagate_convolution_filter(X, use_fourier=True)
    Y_rnn = model.propagate_RNN(X)
    assert Y_conv.shape == (L,)
    assert Y_fourier.shape == (L,)
    # The RNN path's exact 1-D output shape is not asserted here, so compare
    # flattened values.
    assert torch.allclose(Y_rnn.ravel(), Y_fourier.ravel())
    assert torch.allclose(Y_conv.ravel(), Y_fourier.ravel())
    assert torch.allclose(Y_conv.ravel(), Y_rnn.ravel())

[32m.[0m[32m                                                                                            [100%][0m
[32m[32m[1m1 passed[0m[32m in 0.01s[0m[0m


In [None]:
%%ipytest
def test_con_filter_seq():
    """Calling S4_with_shared_A in "Conv" mode keeps a 1-D sequence's shape."""
    D, L = 1, 7
    # Use D rather than a hard-coded 1 so the channel count is stated once.
    model = S4_with_shared_A(channels=D, hidden_state=5, seed=42, kernel_max_size=L, mode="Conv")
    X = torch.rand(L)
    Y_conv = model(X)
    assert Y_conv.shape == (L,)

@pytest.mark.parametrize("params", [(1, 7), (5, 7)], ids=["D1", "D5"])
def test_con_filter_dim(params):
    """Calling S4_with_shared_A in "Conv" mode keeps a (D, L) input's shape.

    ``params`` is ``(D, L)``: the number of channels and the sequence length.
    """
    n_channels, seq_len = params
    layer = S4_with_shared_A(
        channels=n_channels,
        hidden_state=seq_len,  # NOTE(review): hidden size is tied to the sequence length here
        seed=42,
        kernel_max_size=seq_len,
        mode="Conv",
    )
    inputs = torch.rand((n_channels, seq_len))
    out = layer(inputs)
    assert out.shape == (n_channels, seq_len)

[32m.[0m[32m.[0m[32m.[0m[32m                                                                                          [100%][0m
[32m[32m[1m3 passed[0m[32m in 0.03s[0m[0m


In [152]:
%%ipytest

def test_discretize():
    """Discretized matrices of S4_base have (N, N) and (N, D) shapes."""
    n_channels, hidden = 1, 5
    disc_A, disc_B = S4_base(channels=n_channels, hidden_state=hidden).discretize()
    assert disc_A.shape == (hidden, hidden)
    assert disc_B.shape == (hidden, n_channels)


def test_con_filter_seq():
    """Calling S4_base in "Conv" mode keeps a 1-D sequence's shape."""
    D, L = 1, 7
    # Use D rather than a hard-coded 1 so the channel count is stated once.
    model = S4_base(channels=D, hidden_state=5, seed=42, kernel_max_size=L, mode="Conv")
    X = torch.rand(L)
    Y_conv = model(X)
    assert Y_conv.shape == (L,)


def test_propagate_seq():
    """S4_base.propagate_RNN maps a length-L 1-D sequence to a length-L 1-D output."""
    seq_len = 7
    layer = S4_base(channels=1, hidden_state=5)
    seq = torch.rand(seq_len)
    out = layer.propagate_RNN(seq)
    expected_shape = torch.Size([seq_len])
    assert out.dim() == 1
    assert out.shape == expected_shape



[32m.[0m[32m.[0m[32m.[0m[32m                                                                                          [100%][0m
[32m[32m[1m3 passed[0m[32m in 0.02s[0m[0m


In [196]:
%%ipytest
@pytest.mark.parametrize("params", [(1, 7), (5, 7)], ids=["D1", "D5"])
def test_con_filter_dim(params):
    """Calling MultiChannelS4 in "Conv" mode keeps a (D, L) input's shape.

    ``params`` is ``(D, L)``: the number of channels and the sequence length.
    """
    D, L = params
    model = MultiChannelS4(D=D, hidden_state=L, seed=42, kernel_max_size=L, mode="Conv")
    X = torch.rand((D, L))
    # Call the module itself rather than .forward() so nn.Module.__call__
    # runs any registered hooks, matching how the layer is used in practice.
    Y_conv = model(X)
    assert Y_conv.shape == (D, L)

[32m.[0m[32m.[0m[32m                                                                                           [100%][0m
[32m[32m[1m2 passed[0m[32m in 0.04s[0m[0m
