In [2]:
using FFTW
using LinearAlgebra
using BenchmarkTools

## Objective and stuff

$$ \min \| M \cdot ( X - \hat X) \|^2 $$


$$ \hat X = \sum_\ell W_\ell H S_{\ell-1} $$

$$ \hat X = \sum_k W_k * h_k^T $$

$$ (W * h^T)_{nt} = \sum_{\ell=1}^L W_{n \ell} h_{t +1 - \ell} $$

In [3]:
function s_dot(Wl::Matrix, H::Matrix, lag)
    K, T = size(H)

    if (lag < 0)
        return Wl * H[:, 1-lag:T]

    else  # lag >= 0
        return Wl * H[:, 1:T-lag]
    end
end

function tensor_conv(W, H)
    L, N, K = size(W)
    T = size(H)[2]

    pred = zeros(N, T)
    for lag = 0:(L-1)
        pred[:, lag+1:T] += s_dot(W[lag+1, :, :], H, lag)
    end
    return pred
end

N, T, L, K = 30, 50, 7, 2
W = rand(L, N, K)
H = rand(K, T)
#H[:, end-L:end] .= 0

X = tensor_conv(W, H);

In [4]:
# Zero pad H
Hpad = [H zeros(K, L)]

# Zero pad W
Wpad = vcat(W, zeros(T, N, K))  # Potentially zero back

@show size(Wpad)
@show size(Hpad);

size(Wpad) = (57, 30, 2)
size(Hpad) = (2, 57)


In [5]:
# Take fft of H
Hpadhat1 = Array{Complex, 2}(undef, size(Hpad)...)
for k = 1:K
    Hpadhat1[k, :] = fft(Hpad[k, :])
end
Hpadhat2 = fft(Hpad, 2)

@show size(Hpadhat1)
@show size(Hpadhat2)
@show norm(Hpadhat1 - Hpadhat2)

ftH = Hpadhat2;

size(Hpadhat1) = (2, 57)
size(Hpadhat2) = (2, 57)
norm(Hpadhat1 - Hpadhat2) = 0.0


In [6]:
# Take fft of W
ftW1 = Array{Complex, 3}(undef, size(Wpad)...)
for k = 1:K
    for n = 1:N
        ftW1[:, n, k] = fft(Wpad[:, n, k])
    end
end
ftW2 = fft(Wpad, 1)

@show size(ftW1)
@show size(ftW2)
@show norm(ftW1 - ftW2)

ftW = ftW2;

size(ftW1) = (57, 30, 2)
size(ftW2) = (57, 30, 2)
norm(ftW1 - ftW2) = 0.0


In [7]:
# Compute tensor conv
ftX = Array{Complex{Float64}, 2}(undef, N, T+L)
ftX .= 0.0

for n = 1:N
    for k = 1:K
        ftX[n, :] .+= ftW[:, n, k] .* ftH[k, :]
    end
end

ifftX = real.(ifft(ftX, 2))[:, 1:end-L]

@show norm(X - ifftX);
#@show norm(fft([X zeros(N, L)], 2) - ftX)


#@show X[1, 1:20]
#println()
#@show ifftX[1, 1:20]
;

norm(X - ifftX) = 2.8061100569805064e-14


In [8]:
function tensor_convft(W, H)
    
    ftW = fft(vcat(W, zeros(T, N, K)), 1)
    ftH = fft([H zeros(K, L)], 2)
    
    ftX = Array{Complex{Float64}, 2}(undef, N, T+L)
    ftX .= 0.0

    for n = 1:N
        for k = 1:K
            @. ftX[n, :] += ftH[k, :] * ftW[:, n, k]
        end
    end
    
    return real.(ifft(ftX, 2))[:, 1:end-L]
end

@show norm(tensor_conv(W, H) - tensor_convft(W, H))

@btime tensor_conv(W, H) samples=1
@btime tensor_convft(W, H) samples=1
;

norm(tensor_conv(W, H) - tensor_convft(W, H)) = 2.8061100569805064e-14
  165.843 μs (36 allocations: 256.08 KiB)
  2.513 ms (1295 allocations: 452.69 KiB)


Say we want to find some $h$ to minimize

$$ \| x - T(w) h \|^2 + \| I h \|^2$$


$$ \| \hat x - \hat w \cdot \hat h \|^2 $$


$$ \hat h = \frac{\hat w}{\hat x}$$


when $$\hat x \neq 0$$.

In [9]:
# Solve LS for H
# Set H = inv(W'W + 2I) * B
#  in the fourier domain

B = rand(K, T) .+ 3

ftW = fft(vcat(zeros(L+T, N, K), W), 1)
ftB = fft([zeros(K,L) B zeros(K,L)], 2)

# Solve for each column of H via KxK LS
ftH = Array{Complex{Float64}, 2}(undef, K, T+2*L)

for t = 1:T
    ftH[:, t] = (ftW[t, :, :]'ftW[t, :, :] .+ 2) \ ftB[:, t]
end

iftH = real.(ifft(ftH)[:, 1:end-2L])

2×50 Array{Float64,2}:
 0.00613676  NaN  NaN  NaN  NaN  NaN  …  NaN  NaN  NaN  -0.000666076  NaN
 0.00535508  NaN  NaN  NaN  NaN  NaN     NaN  NaN  NaN  -0.000992962  NaN