In [1]:
from s4 import S4D
from new_s4 import S4D as S4D_New
import torch

CUDA extension for structured kernels (Cauchy and Vandermonde multiplication) not found. Install by going to extensions/kernels/ and running `python setup.py install`, for improved speed and memory efficiency. Note that the kernel changed for state-spaces 4.0 and must be recompiled.
Falling back on slow Cauchy and Vandermonde kernel. Install at least one of pykeops or the CUDA extension for better speed and memory efficiency.


In [2]:

old = S4D(d_model=10,
        d_state=64,
        s4d_exp=12,
        dropout=0.0,
        is_real=False,
        transposed=True,
        bottleneck=None,
        skip = False,
        quantize = False,
        final_act = None,
        activation = "relu")
new = S4D_New(d_model= 10, d_state=64 // 2)


In [3]:
assert torch.all(old.layer.kernel.A_real == new.A_real)
assert torch.all(old.layer.kernel.A_imag == new.A_imag)
assert torch.all(old.layer.kernel.B == new.B)
assert torch.all(old.layer.kernel.C == new.C)
assert torch.all(old.layer.kernel.inv_dt == new.inv_dt)

In [4]:
inp = torch.rand(1, 10, 100)

out_new_conv = new(inp)
out_old_conv = old(inp)

In [5]:
torch.all(out_old_conv == out_new_conv)

tensor(True)

In [6]:
out_old_conv.sum().backward()
out_new_conv.sum().backward()

In [7]:

assert torch.all(old.layer.kernel.B.grad == new.B.grad)
assert torch.all(old.layer.kernel.C.grad == new.C.grad)
assert torch.all(old.layer.kernel.A_real.grad == new.A_real.grad)
assert torch.all(old.layer.kernel.A_imag.grad == new.A_imag.grad)
assert torch.all(old.layer.kernel.inv_dt.grad == new.inv_dt.grad)


In [8]:

old.layer.setup_step()
old_state = old.layer.default_state(1)

In [9]:
new.setup_step()
new_state = new.default_state(1)

In [10]:
assert torch.all(old_state == new_state)

In [11]:
for t in range(inp.shape[2]):
    out_old, old_state = old.layer.step(inp[:, :, t], old_state)
    out_new, new_state = new.step(inp[:, :, t], new_state)

    assert torch.all(old_state == new_state)
    assert torch.all(out_old == out_new)

AssertionError: 

In [None]:
out_old

tensor([[0.0000, 0.0000, 1.0238, 0.0000, 0.1436, 0.7919, 0.1387, 0.0000, 0.2731,
         0.2578]], grad_fn=<ReluBackward0>)

In [None]:
out_new

tensor([[0.0000, 0.0000, 1.0238, 0.0000, 0.1436, 0.7919, 0.1387, 0.0000, 0.2731,
         0.2578]], grad_fn=<ReluBackward0>)

In [None]:
out_new.shape

torch.Size([1, 10])

In [None]:
out_new_conv[..., -1]

tensor([[0.0000, 0.0000, 1.0238, 0.0000, 0.1436, 0.7919, 0.1387, 0.0000, 0.2731,
         0.2578]], grad_fn=<SelectBackward0>)

In [None]:
out_old_conv[..., -1]

tensor([[0.0000, 0.0000, 1.0238, 0.0000, 0.1436, 0.7919, 0.1387, 0.0000, 0.2731,
         0.2578]], grad_fn=<SelectBackward0>)

In [13]:
out_new - out_new_conv[..., -1]

tensor([[ 0.0000e+00,  0.0000e+00, -5.9605e-07,  0.0000e+00, -1.1921e-07,
         -9.5367e-07, -4.0233e-07,  0.0000e+00, -8.3447e-07,  8.0466e-07]],
       grad_fn=<SubBackward0>)