In [1]:
import torch

In [2]:
SEED = 0
torch.manual_seed(SEED)  # makes the random A and X below (and the printed error norms) reproducible

# N independent sequences, T time steps, D features per step.
N, T, D = 20, 128, 378

In [3]:
# Y[:, 0] = X[:, 0]
# Y[:, t] = A[:, t-1] * Y[:, t-1] + X[:, t]

def pscan_fft(A, X):
    """Evaluate the linear recurrence above for all t at once using FFT convolutions.

    Unrolling the recurrence gives
        Y[:, t] = X[:, t] + sum_{i < t} (prod_{k=i}^{t-1} A[:, k]) * X[:, i].
    The products are formed in log space. With S[t] = sum_{k<=t} log A[:, k]
    (complex log, so a negative gate contributes a phase of pi),
        Z[:, t, i] = prod_{k=i}^{t} A[:, k] = exp(S[t] - S[i-1])   for i <= t,
    and Y[:, t] = (Z @ X)[:, t-1] + X[:, t]. Both S[t] (Z1) and S[i-1] (Z2) are
    obtained as circular convolutions of 0/1 masks with the zero-padded log-gates.

    Args:
        A: gates, [N x T]. A[:, T-1] is not used by the recurrence, but it still
            passes through the log.
        X: inputs, [N x T x D].

    Returns:
        Y, [N x T x D], in X's dtype.

    Notes:
        - Intermediates are dense [T x (2T - 1) x N] and [N x T x T] tensors, so
          memory grows quadratically in T.
        - An exact zero anywhere in A gives log(0) = -inf, which propagates as
          NaN/inf through the FFTs and the differences S[t] - S[i-1]. Exact zeros
          are therefore not supported.
        - float64 gates are processed in complex128; any other dtype uses complex64.
    """
    N, T, D = X.shape
    if T == 0:
        return X.clone()
    L = 2 * T - 1  # with T - 1 padding zeros, circular conv equals linear conv on [:T]

    # Keep float64 inputs in double precision instead of forcing complex64.
    cdtype = torch.complex128 if A.dtype == torch.float64 else torch.complex64
    rdtype = torch.float64 if cdtype == torch.complex128 else torch.float32

    # A_log_T \in [(2T - 1) x N]: log-gates along time, zero-padded on A's device and
    # dtype. A CPU float32 pad would break GPU inputs.
    A_log_T = torch.log(A.to(dtype=cdtype)).T
    A_log_T = torch.cat([A_log_T, A_log_T.new_zeros(T - 1, N)], dim=0)
    A_log_f = torch.fft.fft(A_log_T, dim=0)  # shared by both convolutions below

    pos = torch.arange(L, device=A.device)

    # Z1[n, t] = S[t] = sum_{k<=t} log A[n, k].
    # For T = 3, mask1 = [1, 1, 1, 0, 0] and its circulant matrix is
    #    [1, 0, 0, 1, 1],
    #    [1, 1, 0, 0, 1],
    #    [1, 1, 1, 0, 0],
    #    [0, 1, 1, 1, 0],
    #    [0, 0, 1, 1, 1],
    # Applied to the padded log-gates, its first T rows give prefix sums, because
    # the wrapped-around ones only hit padding zeros.
    mask1 = (pos <= T - 1).to(rdtype).unsqueeze(1)  # [(2T - 1) x 1]
    Z1_log = torch.fft.ifft(torch.fft.fft(mask1, dim=0) * A_log_f, n=L, dim=0)
    Z1_log = Z1_log[:T, :].T.unsqueeze(-1)  # [N x T x 1]

    # Z2[n, t, i] = S[i-1] = sum_{k<i} log A[n, k] for i <= t, and 0 above the diagonal.
    # mask2[t] has ones at positions 1..t. For T = 4 and t = 2,
    # mask2[2] = [0, 1, 1, 0, 0, 0, 0], and the first T rows of its circulant matrix
    # applied to [a0, a1, a2, a3, 0, 0, 0] are
    #    [0, 0, 0, 0, 0, 1, 1]  -> 0        (= S[-1])
    #    [1, 0, 0, 0, 0, 0, 1]  -> a0       (= S[0])
    #    [1, 1, 0, 0, 0, 0, 0]  -> a0 + a1  (= S[1])
    #    [0, 1, 1, 0, 0, 0, 0]  -> a1 + a2  (i = 3 > t: a sliding window, not a prefix sum)
    # Rows with i > t are not prefix sums, so tril zeroes them.
    steps = torch.arange(T, device=A.device).unsqueeze(1)  # [T x 1]
    mask2 = ((pos >= 1) & (pos <= steps)).to(rdtype).unsqueeze(-1)  # [T x (2T - 1) x 1]
    Z2_log = torch.fft.ifft(
        torch.fft.fft(mask2, dim=1) * A_log_f.unsqueeze(0),
        n=L,
        dim=1,
    )
    Z2_log = torch.tril(Z2_log[:, :T, :].permute(2, 0, 1))  # [N x T x T]

    # Z[n, t, i] = exp(S[t] - S[i-1]) = prod_{k=i}^{t} A[n, k] for i <= t, else 0.
    Z = torch.tril(torch.exp(Z1_log - Z2_log))
    # The imaginary parts of the exponents are multiples of pi, so exp(...) is real up
    # to rounding. Drop the residue and match X's dtype, because bmm does not promote.
    Z = Z.real.to(X.dtype)

    # Y_[n, t] = sum_{i<=t} prod_{k=i}^{t} A[n, k] X[n, i]. Shift it by one step so that
    # Y[n, t] = Y_[n, t-1] + X[n, t].
    Y_ = torch.bmm(Z, X)
    Y_ = torch.cat([Y_.new_zeros(N, 1, D), Y_[:, :-1, :]], dim=1)
    return Y_ + X

In [4]:
# Scale first, then enable grad, so A and X are leaf tensors whose .grad is populated by backward().
A = (torch.randn(N, T) / 10).requires_grad_()
X = (torch.randn(N, T, D) / 1000).requires_grad_()

In [5]:
def test_Y(A, X):
    """Compare pscan_fft against a straightforward sequential evaluation of the recurrence.

    Args:
        A: gates, [N x T].
        X: inputs, [N x T x D].

    Returns:
        L2 norm of pscan_fft(A, X) - Y_expected (should be ~0 up to float error).
    """
    Y_fft = pscan_fft(A, X)

    # Reference: Y[:, 0] = X[:, 0]; Y[:, k] = A[:, k-1] * Y[:, k-1] + X[:, k].
    # Shapes, dtype and device all come from X rather than from notebook globals.
    steps = [X[:, 0, :]]
    for k in range(1, X.shape[1]):
        steps.append(A[:, k - 1].unsqueeze(1) * steps[-1] + X[:, k, :])
    Y_expected = torch.stack(steps, dim=1)

    return torch.norm(Y_fft - Y_expected)

test_Y(A=A, X=X)

tensor(2.3222e-06, grad_fn=<LinalgVectorNormBackward0>)

In [6]:
def test_Z2(A, X=None):
    """Check the masked FFT convolution used for Z2 in pscan_fft against explicit matmuls.

    mask2[t] has ones at positions 1..t. Circularly convolving it with the zero-padded
    gates and applying tril should give Z2[n, t, i] = sum_{k<i} A[n, k] for i <= t
    (0 above the diagonal). This check runs on raw A, not log A; the convolution is
    linear, so the same identity holds for the log-gates in pscan_fft.

    Args:
        A: gates, [N x T].
        X: unused; accepted so the call signature matches the other checks.

    Returns:
        Frobenius norm of the difference (should be ~0 up to float error).
    """
    N, T = A.shape
    A_T = torch.cat([A.T, A.new_zeros(T - 1, N)], dim=0)  # [(2T - 1) x N]

    pos = torch.arange(2 * T - 1, device=A.device)
    steps = torch.arange(T, device=A.device).unsqueeze(1)  # [T x 1]
    mask2 = ((pos >= 1) & (pos <= steps)).to(A.dtype).unsqueeze(-1)  # [T x (2T - 1) x 1]

    Z2_ = torch.fft.irfft(
        torch.fft.rfft(mask2, dim=1) * torch.fft.rfft(A_T.unsqueeze(0), dim=1),
        n=2 * T - 1,
        dim=1,
    )
    Z2_fft = Z2_[:, :T, :]  # [T x T x N]

    def C(t):
        """[T x T] zeros with a strictly lower-triangular block of ones in the top-left t x t corner."""
        C_ = A.new_zeros(T, T)
        C_[:t, :t] = torch.tril(A.new_ones(t, t), diagonal=-1)
        return C_

    Z2_expected = A.new_zeros(T, T, N)
    for t in range(1, T + 1):
        Z2_expected[t - 1, :, :] = C(t) @ A.T

    return torch.norm(Z2_expected.permute(2, 0, 1) - torch.tril(Z2_fft.permute(2, 0, 1), diagonal=0))

test_Z2(A=A, X=X)


tensor(4.9501e-05, grad_fn=<LinalgVectorNormBackward0>)

In [7]:
def test_Z1(A, X):
    mask1 = torch.where(
        (torch.arange(2 * T - 1) <= T - 1),
        1, 
        0
    )
    mask1 = mask1.unsqueeze(1)
    
    A_T = A.T
    A_T = torch.cat([A_T, torch.zeros(T - 1, N)], dim=0)
    Z1_ = torch.fft.irfft(
        torch.fft.rfft(mask1, dim=0) * torch.fft.rfft(A_T, dim=0),
        n=2 * T - 1,
        dim=0
    )
    Z1_ = Z1_[:T, :]
    
    Z1 = torch.tril(torch.ones(T, T), diagonal=0) @ A.T

    return torch.norm(Z1_ - Z1)

# Z1 check: FFT-based prefix sum vs. explicit lower-triangular matmul.
test_Z1(A, X)

tensor(8.1853e-06, grad_fn=<LinalgVectorNormBackward0>)