In [89]:
using LinearAlgebra
using BenchmarkTools
using Random
using Profile
using IterativeSolvers

## Efficient convolution

In [87]:
function sdot(Wl::Matrix{Float64}, H::Matrix{Float64}, lag::Int64)
    K, T = size(H)

    if (lag < 0)
        return Wl * H[:, 1-lag:T]

    else  # lag >= 0
        return Wl * H[:, 1:T-lag]
    end
end

function shift_cols(X::Matrix{Float64}, lag::Int64)
    T = size(X)[2]
    
    if (lag <= 0)
        return X[:, 1-lag:T]

    else  # lag > 0
        return X[:, 1:T-lag]
    end
end

function shift_and_stack(H::Matrix{Float64}, L::Int64)
    K, T = size(H)

    H_stacked = zeros(L*K, T)
    for lag = 0:(L-1)
        H_stacked[1+K*lag:K*(lag+1), 1+lag:T] = shift_cols(H, lag)
    end

    return H_stacked
end

function unfold_W(W::Array{Float64, 3})
    L, N, K = size(W)
    return reshape(permutedims(W, [2, 3, 1]), N, L*K)
end
    
"""Current implementation"""
function tconv(W::Array{Float64, 3}, H::Matrix{Float64})
    L, N, K = size(W)
    T = size(H, 2)

    pred = zeros(N, T)
    for lag = 0:(L-1)
        pred[:, lag+1:T] .+= sdot(W[lag+1, :, :], H, lag)
    end
    return pred
end

"""Row major ordering"""
function tconv2(W::Array{Float64, 3}, H::Matrix{Float64})
    N, K, L = size(W)
    T = size(H, 2)
    
    pred = zeros(N, T)
    @simd for lag = 0:(L-1)
        @views @inbounds pred[:, lag+1:T] .+= W[:, :, lag+1] * H[:, 1:T-lag]
    end
    return pred
end

"""Block everything"""
function tconv3(W::Array{Float64, 3}, H::Matrix{Float64}, L::Int64, K::Int64)
    N, LK = size(W)
    T = size(H, 2)
    
    return W * shift_and_stack(H, L)
end

"""Elementwise. (how to order loops?)"""
function tconv4(W, H)
    L, N, K = size(W)
    T = size(H, 2)
    
    X = zeros(N, T)
    for t = 1:T
        for k = 1:K
            for n = 1:N
                for l = 1:min(L, t)
                    X[n, t] += W[l, n, k] * H[k, t+1-l]
                end
            end
        end
    end
    
    return X
end

#"""TODO outer products"""
Random.seed!(1234)

N, T = 150, 1000
K, L = 15, 100

H = rand(K, T)

Wlnk = rand(L, N, K)
Wnkl = zeros(N, K, L)
for n = 1:N
    for k = 1:K
        for l=1:L
            Wnkl[n, k, l] = Wlnk[l, n, k]
        end
    end
end

Wstack = deepcopy(unfold_W(Wlnk));

In [88]:
@btime tconv($Wlnk, $H) samples=20
@btime tconv2($Wnkl, $H) samples=20
@btime tconv3($Wstack, $H, $L, $K) samples=20
#@btime tconv4($Wlnk, $H) samples=20
;

  97.347 ms (802 allocations: 231.33 MiB)
  57.672 ms (602 allocations: 109.95 MiB)
  20.342 ms (204 allocations: 23.48 MiB)


## Efficient transpose-convolution

In [85]:
function revshift_and_stack(X, L)
    N, T = size(X)

    Xstacked = zeros(L*N, T)
    @simd for lag = 0:(L-1)
        @views @inbounds Xstacked[1+N*lag:N*(lag+1), 1:T-lag] .= shift_cols(X, -lag)
    end

    return Xstacked
end

"""Block everything"""
function tconvT(W, X, L, K)
    N, LK = size(W)
    T = size(H, 2)
    
    return W * revshift_and_stack(X, L)
end

"""Generic"""
function tconvT2(W, X)
    L, N, K = size(W)
    T = size(X)[2]

    result = zeros(K, T)
    @simd for lag = 0:(L-1)
        @inbounds result[:, 1:T-lag] .+= W[lag+1, :, :]' * shift_cols(X, -lag)
    end

    return result
end

"""Column major"""
function tconvT3(W::Array{Float64, 3}, X::Matrix{Float64})
    K, N, L = size(W)
    T = size(X)[2]
    
    res = zeros(K, T)
    @simd for lag = 0:(L-1)
        @views @inbounds res[:, 1:T-lag] .+= W[:, :, lag+1] * X[:, 1+lag:T]
    end
    
    return res
end

"""Precompute shift and stack"""
function tconvT4(Wstack, Xstack)
    return Wstack * Xstack
end

Random.seed!(1234)

N, T = 150, 1000
K, L = 15, 100

X = rand(N, T)
W = rand(L, N, K)
Wc = deepcopy(reshape(W, N, K, L))
Wcc = deepcopy(reshape(W, K, N, L))
Wrevstack = deepcopy(reshape(W, K, N*L))
Xstack = revshift_and_stack(X, L)
result = zeros(K, T)
;

In [86]:
@btime tconvT($Wrevstack, $X, $L, $K) samples=20
@btime tconvT2($W, $X) samples=20
@btime tconvT3($Wcc, $X) samples=20
@btime tconvT4($Wrevstack, $Xstack) samples=20
;

  151.428 ms (305 allocations: 223.35 MiB)
  67.270 ms (802 allocations: 132.41 MiB)
  19.023 ms (602 allocations: 11.03 MiB)
  18.364 ms (2 allocations: 117.27 KiB)


## Efficient Least Squares

In [182]:
"""
----------
Updating W
----------
"""

"""Generic backslash"""
function lsqW(H, X, L)
    A = shift_and_stack(H, L)'
    B = X'
    
    return (A \ B)'
end

"""Normal eqn backslash"""
function lsqW_normal(H, X, L)
    A = shift_and_stack(H, L)'
    B = X'
    AtA = A'A
    AtB = A'B
        
    return (AtA \ AtB)'
end


"""Conjugate gradient"""
function lsqW_cg(H, X, L)
    A = shift_and_stack(H, L)'
    B = X'
    AtA = A'A
    AtB = A'B
        
    res = zeros(size(AtB))
    for k = 1:K
        res[:, k] .= cg(AtA, AtB[:, k], maxiter=50)
    end
    
    return res'
end

#"""TODO Gradient descent"""
#"""TODO Sketching"""
Random.seed!(1234)

N, T = 140, 1000
K, L = 15, 99

X = rand(N, T)
H = rand(K, T)
;

In [179]:
A = rand(100, 10)
b = rand(100)

norm((A \ b) - (A'A \ A'b))
cond(A)

7.17804385340086

In [183]:
W1 = lsqW(H, X, L)
W2 = lsqW_normal(H, X, L)
W3 = lsqW_cg(H, X, L)

println("Backslash resid: ", norm(W1 * shift_and_stack(H, L) - X))
println("Normal eqn resid: ", norm(W2 * shift_and_stack(H, L) - X))
println("Conj grad resid: ", norm(W3 * shift_and_stack(H, L) - X))

# Problem is ill conditioned!

Backslash resid: 4.355072103562693e-13
Normal eqn resid: 1.0670594419677653e-9
Conj grad resid: 203.80252887809687


In [184]:
@btime lsqW($H, $X, $L) samples=1
@btime lsqW_normal($H, $X, $L) samples=1
#@btime lsqW_cg($H, $X, $L) samples=1
;

  980.552 ms (8247 allocations: 57.08 MiB)
  240.199 ms (211 allocations: 58.95 MiB)
