In [67]:
using LinearAlgebra
using Plots
using Distributions
using LaTeXStrings
using Random
using TSVD
include("../../sesh3/src/functions.jl")
include("../../sesh2/src/functions.jl")

rank_7_CPD

# Exercise 1

In [135]:
"""
    hosvd(C, r)

Compress the tensor `C` to multilinear ranks `r` (one rank per mode) by a
sequence of truncated SVDs (`tsvd`), one per mode `k = 1, …, d`.

Returns `(V, S0, singular_values, errors)`:
- `V[k]`: the `n[k] × r[k]` right singular vectors computed at step `k`;
- `S0`: the resulting core, of size `(r[1], …, r[d])`;
- `singular_values[k]`: the `r[k]` singular values kept at step `k`;
- `errors[k]`: relative Frobenius error `norm(tucker_eval(S0, V) - C) / norm(C)`
  of the partial decomposition available after step `k`.
"""
function hosvd(C, r)
    S0 = copy(C)
    d = ndims(C)
    n = size(C)
    V = Matrix{Float64}[]
    singular_values = Vector{Float64}[]
    errors = Float64[]
    for k in 1:d
        # Rows combine the already-compressed modes 1..k-1 and the untouched modes
        # k+1..d; `prod` of an empty collection is 1, so d = 1 also works.
        n_rows = prod(r[1:k-1]) * prod(n[k+1:end])
        # NOTE(review): this is a plain `reshape`, which in Julia's column-major layout
        # equals the (transposed) mode-k unfolding only when mode k is the last
        # dimension of S0. It is self-consistent only with a `mode_k_dot` that uses
        # the same reshape convention — confirm against its definition.
        B = reshape(S0, (n_rows, n[k]))
        _, Sig_hat, V_hat = tsvd(convert(Matrix{Float64}, B), r[k])
        # Project onto the leading r[k] right singular vectors.
        W_hat = B * V_hat
        S0 = reshape(W_hat, (r[1:k]..., n[k+1:end]...))

        push!(V, V_hat)
        push!(singular_values, Sig_hat)
        push!(errors, norm(tucker_eval(S0, V) - C) / norm(C))
    end
    return V, S0, singular_values, errors
end


"""
    tucker_eval(S, V)

Expand a Tucker representation: starting from a copy of the core `S`, apply
`mode_k_dot(·, V[k], k)` for every factor, from the last mode down to the first.
`mode_k_dot` is not defined in this notebook (presumably it comes from one of the
included `functions.jl` files and computes the mode-k product — confirm).
With an empty `V`, a copy of `S` is returned.
"""
function tucker_eval(S, V)
    result = copy(S)
    for mode in length(V):-1:1
        result = mode_k_dot(result, V[mode], mode)
    end
    return result
end

tucker_eval (generic function with 1 method)

In [136]:
# Random Tucker tensor with mode sizes 21..24 and multilinear ranks 2, 4, 6, 8,
# then compress it again with `hosvd`.
d = 4;
n = collect(21:20+d);
r = collect(2:2:2d);
V = map(k -> rand(Uniform(-1, 1), n[k], r[k]), 1:d);
S = rand(Uniform(-1, 1), r...);
A = tucker_eval(S, V);
# Note: the HOSVD outputs overwrite the generating factors `V` and core `S` above.
V, S, singular_values, errors = hosvd(A, r);

# Exercise 2 

In [70]:
"""
    f(x)

Evaluate `f(x) = 1 / (1 + Σ_k x_k^2 / 8^(k-1))`.

`x` is iterated component by component. All operations are broadcast, so each
component may be a scalar or an array (equally sized), which lets `f` be
evaluated on a whole grid at once. An empty `x` gives `1.0`.
"""
function f(x)
    s = 0
    for (k, xk) in enumerate(x)
        # Float base on purpose: the Int power 8^(k-1) silently overflows for k ≥ 22.
        s = s .+ xk .^ 2 ./ 8.0^(k - 1)
    end
    return inv.(1 .+ s)
end

"""
    g(x)

Evaluate
`g(x) = sqrt(Σ_k x_k^2 / 8^(k-1)) * (1 + cos(Σ_k 4π x_k / 4^(k-1)) / 2)`.

`x` is iterated component by component. All operations are broadcast, so each
component may be a scalar or an array (equally sized).
"""
function g(x)
    s1 = 0
    s2 = 0
    for (k, xk) in enumerate(x)
        # Float bases on purpose: the Int powers 8^(k-1) and 4^(k-1) overflow for
        # k ≥ 22 and k ≥ 33, which produced Inf/NaN and a `cos(Inf)` DomainError.
        s1 = s1 .+ xk .^ 2 ./ 8.0^(k - 1)
        s2 = s2 .+ (4 * pi * xk) / 4.0^(k - 1)
    end
    return sqrt.(s1) .* (1 .+ 1/2 * cos.(s2))
end

"""
    unfold(A, k)

Return the mode-`k` matricization of `A`: a `size(A, k) × prod(other sizes)`
matrix whose columns run over the remaining modes in increasing order, with
the lowest remaining mode varying fastest (Julia's column-major layout).
"""
function unfold(A, k)
    others = setdiff(1:ndims(A), k)
    B = permutedims(A, vcat(k, others))
    # Column count from the other mode sizes (empty product = 1), so a
    # zero-length mode k no longer divides by zero.
    n_cols = prod(Int[size(A, j) for j in others])
    return reshape(B, size(A, k), n_cols)
end

# `n` equispaced nodes on [-1, 1], both endpoints included; for n == 1 the formula
# evaluates 0/0 and yields [NaN].
t(n) = map(i -> 2 * (i - 1) / (n - 1) - 1, 1:n)



# Exercise 3

In [80]:
# Default tolerances 1e-2, 1e-4, …, 1e-12.
eps_j = [1/(10^(j*2)) for j=1:6]

"""
    rank_approx(C, eps_j=eps_j)

For every mode `k` of `C` and every tolerance `eps_j[j]`, find the smallest rank
`r ≥ 1` whose best rank-`r` approximation of the mode-`k` matricization
`tenmat(C, k)` has relative Frobenius error `<= eps_j[j]`.

Returns `(r_jk, eps_jk, singular_vec)`:
- `r_jk[j, k]`: that smallest rank;
- `eps_jk[j, k]`: the relative error actually attained at that rank;
- `singular_vec[k]`: the singular values of the mode-`k` matricization.

Throws `ArgumentError` if a tolerance cannot be met (i.e. it is negative).
"""
function rank_approx(C, eps_j=eps_j)
    d = ndims(C)
    r_jk = zeros(Int, length(eps_j), d)
    eps_jk = zeros(length(eps_j), d)
    singular_vec = Vector{Vector{Float64}}(undef, d)
    for k in 1:d
        σ = svdvals(tenmat(C, k))
        singular_vec[k] = σ
        nrm = norm(σ)
        # Eckart–Young: the best rank-r error in the Frobenius norm is the norm of the
        # discarded singular values. The search therefore runs over every possible
        # rank of the unfolding (not just 1:d) and needs no further factorizations.
        # The full-rank entry is exactly 0, so any non-negative tolerance is met.
        rel_err = [iszero(nrm) ? 0.0 : norm(view(σ, r+1:length(σ))) / nrm
                   for r in eachindex(σ)]
        for (j, e_j) in enumerate(eps_j)
            r = findfirst(<=(e_j), rel_err)
            r === nothing && throw(ArgumentError(
                "tolerance $(e_j) cannot be met; tolerances must be non-negative"))
            r_jk[j, k] = r
            eps_jk[j, k] = rel_err[r]
        end
    end
    return r_jk, eps_jk, singular_vec
end

rank_approx (generic function with 2 methods)

In [103]:
# Sample f and g on the tensor grid t(d)^4 (here d = 4 is also the number of nodes).
d = 4
A, B = let nodes = t(d)
    F = [f([x1, x2, x3, x4]) for x1 in nodes, x2 in nodes, x3 in nodes, x4 in nodes]
    G = [g([x1, x2, x3, x4]) for x1 in nodes, x2 in nodes, x3 in nodes, x4 in nodes]
    F, G
end;

false

# Exercises 4 and 5

In [371]:
"""
    ttmps_eval(U, n)

Contract the tensor-train (TT/MPS) cores `U` into a full array of size `n`.
`n` may be a tuple or a vector of mode sizes.

Core `U[k]` is expected to have size `(r_{k-1}, n_k, r_k)`. Two 2-D shortcuts are
accepted:
- a first core of size `(n_1, r_1)` (implicit `r_0 = 1`);
- any later core of size `(r_{k-1}, n_k)` (implicit `r_k = 1`), e.g. a 2-D last core.

For two plain matrices this reduces to `reshape(U[1] * U[2], n)`.
"""
function ttmps_eval(U, n)
    first_core = U[1]
    # A is kept as a matrix: rows run over (i_1, …, i_k) with i_1 fastest
    # (column-major), columns over the current trailing TT rank r_k.
    A = if ndims(first_core) == 3
        reshape(first_core, :, size(first_core, 3))
    else
        reshape(first_core, size(first_core, 1), :)
    end
    for U_k in U[2:end]
        # (n_1⋯n_{k-1}) × r_{k-1} times r_{k-1} × (n_k r_k), then fold n_k into
        # the rows. `size(U_k, 3)` is 1 for a 2-D core.
        A = reshape(A * reshape(U_k, size(U_k, 1), :), :, size(U_k, 3))
    end
    return reshape(A, Tuple(n))
end

"""
    tt_svd(A, n, r, d)

Tensor-train (TT/MPS) approximation of the `d`-way array `A` with mode sizes `n`
by successive truncated SVDs (`tsvd`).

`r[k]` is the TT rank between cores `k` and `k + 1` (`k = 1, …, d-1`). The leading
rank `r_0 = 1` is implicit, and `r[d]`, if present, is not read.

Returns `(C, singular_val, errors)`:
- `C[k]` for `k < d`: core of size `(r_{k-1}, n[k], r_k)`; `C[d]`: 2-D last core of
  size `(r_{d-1}, n[d])`;
- `singular_val[k]`: singular values kept at step `k` (there are `d - 1` steps);
- `errors[k]`: relative Frobenius error `‖A_hat - A‖ / ‖A‖` of the partial
  decomposition (cores so far times the remaining factor) after step `k`.
"""
function tt_svd(A, n, r, d)
    r_new = [1, r...]
    S_0_hat = copy(A)
    C = Any[]
    singular_val = Vector{Float64}[]
    errors = Float64[]
    # Contraction of the cores built so far, as an (n_1⋯n_{k-1}) × r_{k-1} matrix;
    # it starts as the 1 × 1 "empty" contraction for r_0 = 1.
    left = ones(1, 1)
    norm_A = norm(A)
    for k in 2:d
        B_k = reshape(S_0_hat, (r_new[k-1] * n[k-1], prod(n[k:d])))
        U_hat, Sig_hat, V_hat = tsvd(convert(Matrix{Float64}, B_k), r_new[k])
        C_k = reshape(U_hat, (r_new[k-1], n[k-1], r_new[k]))
        W_k_hat = Diagonal(Sig_hat) * transpose(V_hat)
        S_0_hat = reshape(W_k_hat, (r_new[k], n[k:d]...))
        push!(C, C_k)
        push!(singular_val, Sig_hat)

        # Append the new core to the left contraction and rebuild the current
        # approximation A_hat = left * W_k_hat (column-major matches A's layout).
        left = reshape(left * reshape(U_hat, r_new[k-1], :), :, r_new[k])
        A_hat = reshape(left * W_k_hat, size(A))
        push!(errors, norm(A_hat - A) / norm_A)
    end
    push!(C, S_0_hat)
    return C, singular_val, errors
end

tt_svd (generic function with 1 method)

# Exercise 6

In [372]:
# Random Tucker tensor with mode sizes 21..24 and ranks (2, 4, 6, 1); the same `r`
# is then passed to `tt_svd` as TT ranks (it reads only r[1:d-1]).
d = 4;
n = collect(21:20+d);
r = vcat(2:2:2(d-1), 1);
V = map(k -> rand(Uniform(-1, 1), n[k], r[k]), 1:d);
S = rand(Uniform(-1, 1), r...);
A = tucker_eval(S, V);
C, singular_val, errors = tt_svd(A, n, r, d);
