In [109]:
using FFTW
using LinearAlgebra
using BenchmarkTools

┌ Info: Precompiling BenchmarkTools [6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf]
└ @ Base loading.jl:1273


In [75]:
function s_dot(Wl::Matrix, H::Matrix, lag)
    K, T = size(H)

    if (lag < 0)
        return Wl * H[:, 1-lag:T]

    else  # lag >= 0
        return Wl * H[:, 1:T-lag]
    end
end

function tensor_conv(W, H)
    L, N, K = size(W)
    T = size(H)[2]

    pred = zeros(N, T)
    for lag = 0:(L-1)
        pred[:, lag+1:T] += s_dot(W[lag+1, :, :], H, lag)
    end
    return pred
end

N, T, L, K = 30, 50, 7, 2
W = rand(L, N, K)
H = rand(K, T)
H[:, end-L:end] .= 0

X = tensor_conv(W, H);

In [85]:
# Zero pad H
Hpad = [zeros(K, L) H zeros(K, L)]

# Zero pad W
Wpad = vcat(zeros(L+T, N, K), W)

@show size(Wpad)
@show size(Hpad);

size(Wpad) = (64, 30, 2)
size(Hpad) = (2, 64)


In [86]:
# Take fft of H
Hpadhat1 = Array{Complex, 2}(undef, size(Hpad)...)
for k = 1:K
    Hpadhat1[k, :] = fft(Hpad[k, :])
end
Hpadhat2 = fft(Hpad, 2)

@show size(Hpadhat1)
@show size(Hpadhat2)
@show norm(Hpadhat1 - Hpadhat2)

ftH = Hpadhat2;

size(Hpadhat1) = (2, 64)
size(Hpadhat2) = (2, 64)
norm(Hpadhat1 - Hpadhat2) = 5.521894863136994e-15


In [87]:
# Take fft of W
ftW1 = Array{Complex, 3}(undef, size(Wpad)...)
for k = 1:K
    for n = 1:N
        ftW1[:, n, k] = fft(Wpad[:, n, k])
    end
end
ftW2 = fft(Wpad, 1)

@show size(ftW1)
@show size(ftW2)
@show norm(ftW1 - ftW2)

ftW = ftW2;

size(ftW1) = (64, 30, 2)
size(ftW2) = (64, 30, 2)
norm(ftW1 - ftW2) = 1.588150543073792e-14


In [88]:
# Compute tensor conv
ftX = Array{Complex{Float64}, 2}(undef, N, T+2*L)
ftX .= 0.0

for n = 1:N
    for k = 1:K
        ftX[n, :] .+= ftW[:, n, k] .* ftH[k, :]
    end
end

ifftX = real.(ifft(ftX, 2))[:, 1:end-2L]

@show norm(X - ifftX);
@show norm(fft([X zeros(N, 2L)], 2) - ftX)

;

norm(X - ifftX) = 2.2192279392499284e-14
norm(fft([X zeros(N, 2L)], 2) - ftX) = 2.0132894844227993e-13


In [131]:
function tensor_convft(W, H)
    ftW = fft(vcat(zeros(L+T, N, K), W), 1)
    ftH = copy(fft(Hpad, 2)')
    
    ftX = Array{Complex{Float64}, 2}(undef, N, T+2*L)
    ftX .= 0.0

    for k = 1:K
        @views ftX .+= (ftH[:, k] .* ftW[:, :, k])'
    end
    
    return real.(ifft(ftX, 2))[:, L+1:end-L]
end

@show norm(tensor_conv(W, H) - tensor_convft(W, H))

#@btime tensor_conv(W, H) samples=1
@btime tensor_convft(W, H) samples=1
;

norm(tensor_conv(W, H) - tensor_convft(W, H)) = 23.20900885154874
  446.559 μs (246 allocations: 342.89 KiB)


In [160]:
# Solve LS for H
# Set H = inv(S'S + 2I) * B
#  in the fourier domain
B = rand(K, T) .+ 3

ftW = fft(vcat(zeros(L+T, N, K), W), 1)
ftB = fft([zeros(K,L) B zeros(K,L)], 2)

# Solve for each column of H via KxK LS
ftH = Array{Complex{Float64}, 2}(undef, K, T+2*L)

for t = 1:T
    ftH[:, t] = (ftW[t, :, :]'ftW[t, :, :] .+ 2) \ ftB[:, t]
end

iftH = real.(ifft(ftH)[:, 1:end-2L])

2×50 Array{Float64,2}:
 0.00831528  0.000742991  -0.00264099  …  0.00786458  0.00955813   0.0171293
 0.00387465  0.00622225    0.0147462      0.0286908   0.000501934  0.0258675