In [None]:
using HePPCAT, LinearAlgebra, Plots, Random, Statistics, StatsPlots, Printf
using Plots.PlotMeasures
include("petrels.jl")
include("grouse.jl")
include("shasta1.jl")
include("shasta2.jl")
include("plotting.jl")

### Incremental computation over batch (fully observed)

In [None]:
# Experiment presets: ambient dimension d, per-group sample counts n,
# per-group noise variances v, and signal (spike) variances λ.
# d, n, v, λ = 50, [200,800], [1,4], [4,2,1]
# d, n, v, λ = 50, [200,800], [0.1,1], [4,2,1]
# d, n, v, λ = 50, [200,800], [1e-4,1e-2], [4,2,1]

# d, n, v, λ = 50, [500,2000], [0.001,0.01], [4,2,1]
d, n, v, λ = 100, [500,2000], [1e-2,1e-1], [4,2,1]
# d, n, v, λ = 50, [500,2000], [1e-2,1e-2], [4,2,1]

# d, n, v, λ = 50, [100,1000], [1e-2,1], [4,2,1]
k, L = length(λ), length(v)   # true rank, number of noise groups
khat = 3                      # rank handed to every estimator below

num_trials = 10

"""
    ppca_model(Ys, nsamples, rank, ngroups)

Fit homoscedastic PPCA to the data matrices in `Ys` (columns are samples).

The pooled second-moment matrix `sum(Yl * Yl') / nsamples` (no mean centering)
is eigendecomposed. The top `rank` eigenpairs give the factor, and the mean of the
remaining eigenvalues is used as the shared noise variance, replicated for
`ngroups` groups so the result is comparable with the heteroscedastic models.
"""
function ppca_model(Ys, nsamples, rank, ngroups)
    S = sum(Yl * Yl' for Yl in Ys) / nsamples
    evals, evecs = eigen(Hermitian(S), sortby = -)   # descending eigenvalues
    noise = mean(evals[rank+1:end])
    return HePPCATModel(evecs[:, 1:rank], evals[1:rank] .- noise, I(rank), fill(noise, ngroups))
end

# Generate data. The order of RNG calls below determines every later random draw;
# keep it unchanged.
Random.seed!(0)
U = qr(rand(d,k)).Q[:,1:k]
F = U*sqrt(Diagonal(λ))
# Fraction of entries dropped at random. NOTE: this shadows `Base.missing` in Main;
# the name is kept because the plotting cell reads `missing`.
missing = 0.0
Ω = [(rand(d,n[l]) .> missing) for l in 1:L]
Ytrue = [F*randn(k,n[l]) + sqrt(v[l])*randn(d,n[l]) for l in 1:L]
Y = [Ω[l] .* Ytrue[l] for l in 1:L]

### HePPCAT (Batch)
# Uses `khat` like the PPCA baselines (an estimator cannot know the true rank k).
# The first call also pays compilation cost; the median over 20 runs absorbs it.
heppcat_time_log = Float64[]
for _ in 1:20
    push!(heppcat_time_log, @elapsed heppcat(Y, khat, 100))
end
tHeppcat = median(heppcat_time_log)
Mheppcat = heppcat(Y, khat, 100)

### True model
Mtrue = HePPCATModel(U,λ,I(k),v)

### Homoscedastic PPCA (all groups pooled)
Mppca = ppca_model(Y, sum(n), khat, L)

### PPCA-G1 (group 1 only)
Mppca1 = ppca_model((Y[1],), n[1], khat, L)

### PPCA-G2 (group 2 only)
Mppca2 = ppca_model((Y[2],), n[2], khat, L)


### Streaming-setup
Ymat_true = hcat(Ytrue...)
Ymat = hcat(Y...)
vmat = vcat([v[l] * ones(n[l]) for l in 1:L]...)     # per-column noise variance
groups = vcat([fill(l, n[l]) for l in 1:L]...)       # per-column group label
# Build the observation mask from Ω itself rather than from nonzero entries of Ymat,
# so an observed entry that happens to equal zero is not mistaken for a missing one.
ΩY = hcat(Ω...)

dataIdx = randperm(sum(n))
# dataIdx = 1:sum(n)
# Shuffle every per-column array with the same permutation so they stay aligned.
Ymat = Ymat[:,dataIdx]
Ymat_true = Ymat_true[:,dataIdx]
ΩY = ΩY[:,dataIdx]
groups = groups[dataIdx]
vmat = vmat[dataIdx]

# Log-likelihood gap relative to the true model, evaluated on the fully observed data.
Ltrue = loglikelihood(Mtrue,Ytrue)
stats_fcn(M) = loglikelihood(M,Ytrue) - Ltrue

### Streaming over each group
dataIdx1 = randperm(n[1])
dataIdx2 = randperm(n[2])
Y1mat = Y[1][:,dataIdx1]
Y2mat = Y[2][:,dataIdx2]
ΩY1 = Ω[1][:,dataIdx1]
ΩY2 = Ω[2][:,dataIdx2];

In [None]:
"""
    Fmeasure(M, Fref = F)

Return the relative subspace error between the estimated basis `M.U` and the
column space of the reference factor `Fref`:

    ‖Û Ûᵀ − U_F U_Fᵀ‖_F / ‖U_F U_Fᵀ‖_F

Here `U_F` holds the left singular vectors of `Fref` (thin SVD). `Fref` defaults
to the global ground-truth factor `F`, so one-argument calls such as those made by
the streaming algorithms keep working. Only the field `M.U` is read.
"""
function Fmeasure(M, Fref = F)
    Uf = svd(Fref).U          # thin SVD: one column per column of Fref (for d ≥ k)
    Pf = Uf * Uf'             # projector onto the reference subspace
    Uhat = M.U
    return norm(Uhat * Uhat' - Pf) / norm(Pf)
end

##### Run PETRELS (full data)

In [None]:
# PETRELS tuning. Kept under its own name so the ground-truth spike variances `λ`
# set in the data-generation cell are not overwritten.
# NOTE(review): presumably the PETRELS forgetting factor (1.0 = no forgetting);
# confirm against petrels.jl.
petrels_λ = 1.0
# petrels_λ = sqrt(2*d*pi*maximum(v))
δ = 1e-1
petrels_stats_trials = []
petrels_err_trials = []
petrels_time_trials = []

### Compile for dummy run (untimed, before reseeding, so trials are unaffected)
# Alternative initialization using the true noise variances:
#   M0 = HePPCATModel(Matrix(qr!(randn(d,khat)).Q),rand(khat),I(khat),v)
M0 = HePPCATModel(Matrix(qr!(randn(d,khat)).Q),rand(khat),I(khat),rand(L))
Mstream = deepcopy(M0)
Mstream, Yrec, stats_log, err_log, time_log = PETRELS(Mstream,Ymat,ΩY,petrels_λ,δ)

Random.seed!(0)
for _ in 1:num_trials
    M0 = HePPCATModel(Matrix(qr!(randn(d,khat)).Q),rand(khat),I(khat),rand(L))
    Mstream = deepcopy(M0)
    Mstream, Yrec, stats_log, err_log, time_log = PETRELS(Mstream,Ymat,ΩY,petrels_λ,δ)
    push!(petrels_stats_trials,stats_log)
    push!(petrels_err_trials,err_log)
    push!(petrels_time_trials,time_log)
end

##### PETRELS (Group 1)

In [None]:
# PETRELS on group-1 samples only. The tuning value is kept under its own name so
# the ground-truth spike variances `λ` from the data-generation cell survive.
# NOTE(review): presumably the PETRELS forgetting factor (1.0 = no forgetting);
# confirm against petrels.jl.
petrels_λ = 1.0
# petrels_λ = sqrt(2*d*pi*maximum(v))
δ = 1e-1
petrels_stats_trials_g1 = []
petrels_err_trials_g1 = []
petrels_time_trials_g1 = []

### Compile for dummy run (untimed, before reseeding, so trials are unaffected)
# Alternative initialization using the true noise variances:
#   M0 = HePPCATModel(Matrix(qr!(randn(d,khat)).Q),rand(khat),I(khat),v)
M0 = HePPCATModel(Matrix(qr!(randn(d,khat)).Q),rand(khat),I(khat),rand(L))
Mstream = deepcopy(M0)
Mstream, Yrec, stats_log, err_log, time_log = PETRELS(Mstream,Y1mat,ΩY1,petrels_λ,δ)

Random.seed!(0)
for _ in 1:num_trials
    M0 = HePPCATModel(Matrix(qr!(randn(d,khat)).Q),rand(khat),I(khat),rand(L))
    Mstream = deepcopy(M0)
    Mstream, Yrec, stats_log, err_log, time_log = PETRELS(Mstream,Y1mat,ΩY1,petrels_λ,δ)
    push!(petrels_stats_trials_g1,stats_log)
    push!(petrels_err_trials_g1,err_log)
    push!(petrels_time_trials_g1,time_log)
end

##### PETRELS (Group 2)

In [None]:
# PETRELS on group-2 samples only. The tuning value is kept under its own name so
# the ground-truth spike variances `λ` from the data-generation cell survive.
# NOTE(review): presumably the PETRELS forgetting factor (1.0 = no forgetting);
# confirm against petrels.jl.
petrels_λ = 1.0
# petrels_λ = sqrt(2*d*pi*maximum(v))
δ = 1e-1
petrels_stats_trials_g2 = []
petrels_err_trials_g2 = []
petrels_time_trials_g2 = []

### Compile for dummy run (untimed, before reseeding, so trials are unaffected)
# Alternative initialization using the true noise variances:
#   M0 = HePPCATModel(Matrix(qr!(randn(d,khat)).Q),rand(khat),I(khat),v)
M0 = HePPCATModel(Matrix(qr!(randn(d,khat)).Q),rand(khat),I(khat),rand(L))
Mstream = deepcopy(M0)
Mstream, Yrec, stats_log, err_log, time_log = PETRELS(Mstream,Y2mat,ΩY2,petrels_λ,δ)

Random.seed!(0)
for _ in 1:num_trials
    M0 = HePPCATModel(Matrix(qr!(randn(d,khat)).Q),rand(khat),I(khat),rand(L))
    Mstream = deepcopy(M0)
    Mstream, Yrec, stats_log, err_log, time_log = PETRELS(Mstream,Y2mat,ΩY2,petrels_λ,δ)
    push!(petrels_stats_trials_g2,stats_log)
    push!(petrels_err_trials_g2,err_log)
    push!(petrels_time_trials_g2,time_log)
end

##### Run GROUSE

In [None]:
# GROUSE on the shuffled full data stream; each trial's logs are appended below.
grouse_stats_trials = []
grouse_err_trials = []
grouse_time_trials = []

# Step size passed straight to GROUSE (another value tried: 0).
grouse_step = 0.01

# One untimed call up front so compilation cost stays out of the timed trials.
M0 = HePPCATModel(Matrix(qr!(randn(d, khat)).Q), rand(khat), I(khat), rand(L))
Mstream = deepcopy(M0)
Mstream, stats_log, err_log, time_log = GROUSE(Mstream, Ymat, ΩY, grouse_step, stats_fcn, Fmeasure)

# Reseed so every trial sequence is reproducible regardless of the warm-up draws.
Random.seed!(0)
for trial in 1:num_trials
    M0 = HePPCATModel(Matrix(qr!(randn(d, khat)).Q), rand(khat), I(khat), rand(L))
    Mstream = deepcopy(M0)
    Mstream, stats_log, err_log, time_log = GROUSE(Mstream, Ymat, ΩY, grouse_step, stats_fcn, Fmeasure)
    foreach(push!,
            (grouse_stats_trials, grouse_err_trials, grouse_time_trials),
            (stats_log, err_log, time_log))
end

##### Run SHASTA-PCA-2

In [None]:
"""
    LearningRateParams(w, cf, cv)

Bundle of learning-rate settings passed to `SHASTA_PCA2` in the SHASTA-PCA-2 cell.

NOTE(review): the field meanings below are inferred from names and the values used
in this notebook, not from `SHASTA_PCA2` itself; confirm against shasta2.jl.
- `w`:  presumably a weighting/forgetting parameter.
- `cf`: presumably the step-size constant for the factor updates.
- `cv`: presumably the step-size constant for the noise-variance updates.

Declared `mutable`; whether `SHASTA_PCA2` actually mutates it is not visible here.
"""
mutable struct LearningRateParams
    w::Float64
    cf::Float64
    cv::Float64
end

In [None]:
shasta_stats_trials = []
shasta_err_trials = []
shasta_time_trials = []
shasta_models = []

# Learning-rate settings (alternatives tried: cf = 1, cv = 1).
w = 1
cf = 0.1
cv = 0.1
lrparams = LearningRateParams(w,cf,cv)

L = length(Y)
δ = 0.1

### Compile the function on a dummy run
# Actually calls SHASTA_PCA2 so JIT compilation is excluded from the timed trials,
# matching the PETRELS and GROUSE cells. It runs before reseeding, so trial draws are
# unaffected, and it gets its own LearningRateParams so any in-place changes
# SHASTA_PCA2 might make to its parameters cannot leak into the trials.
M0 = HePPCATModel(Matrix(qr!(randn(d,khat)).Q),rand(khat),I(khat),rand(L))
Mstream = deepcopy(M0)
Mstream, Yrec, stats_log, err_log, time_log = SHASTA_PCA2(Mstream,Ymat,ΩY,groups,LearningRateParams(w,cf,cv),δ,Fmeasure,stats_fcn)

Random.seed!(0)
for _ in 1:num_trials
    M0 = HePPCATModel(Matrix(qr!(randn(d,khat)).Q),rand(khat),I(khat),rand(L))
    Mstream = deepcopy(M0)
    Mstream, Yrec, stats_log, err_log, time_log = SHASTA_PCA2(Mstream,Ymat,ΩY,groups,lrparams,δ,Fmeasure,stats_fcn)
    push!(shasta_stats_trials,stats_log)
    push!(shasta_err_trials,err_log)
    push!(shasta_time_trials,time_log)
    push!(shasta_models,Mstream)
end

##### Plot the results

In [None]:
# Per-algorithm trial logs. NOTE: `stats` and `times` cover only three algorithms,
# are not aligned with `labels`, and are not used by the plots in this cell.
stats = [petrels_stats_trials, grouse_stats_trials, shasta_stats_trials]
errs = [petrels_err_trials,petrels_err_trials_g1,petrels_err_trials_g2, grouse_err_trials, shasta_err_trials]
times = [petrels_time_trials,grouse_time_trials, shasta_time_trials]
labels = ["PETRELS","PETRELS-G1","PETRELS-G2","GROUSE","SHASTA-PCA-2"]
colors = Dict("PETRELS"=>:orange,"PETRELS-G1"=>:orange,"PETRELS-G2"=>:orange,"GROUSE"=>:purple,"SHASTA-PCA-2"=>:blue)
markers = Dict("PETRELS"=>:utriangle,"PETRELS-G1"=>:rect,"PETRELS-G2"=>:hexagon,"GROUSE"=>:circle,"SHASTA-PCA-2"=>:diamond)

# `missing` is the fraction of dropped entries, so despite its name this is the
# percentage of entries *observed* (name kept: it also appears in the file name).
missing_rate = round(Int, (1 - missing) * 100)
interval = 250
alpha = 0.05
fontsize = 10
figsize = (400,300)

# Subspace-error traces, with the batch estimators drawn as dashed reference lines.
p1 = plottraces(errs,labels,colors,markers,interval,alpha,figsize,fontsize,"Subspace Error: $missing_rate % observed","Iteration",:outerright,:log)
for (Mref, c, lbl) in ((Mheppcat, :red, "HePPCAT"), (Mppca, :black, "PPCA"),
                       (Mppca1, :green, "PPCA-G1"), (Mppca2, :lightgreen, "PPCA-G2"))
    hline!(p1, [Fmeasure(Mref)], linestyle=:dash, width=2, color=c, label=lbl)
end

# Log-likelihood traces for SHASTA-PCA-2 only; uses its own color map instead of
# overwriting `colors`.
shasta_colors = Dict("SHASTA-PCA-2"=>:blue)
p2 = plottraces([shasta_stats_trials],["SHASTA-PCA-2"],shasta_colors,markers,interval,alpha,figsize,fontsize,"Log-likelihood: $missing_rate % observed","Iteration",false,:linear)
hline!(p2,[stats_fcn(Mheppcat)],linestyle=:dash,width=2,color=:red,label="HePPCAT")
hline!(p2,[stats_fcn(Mppca)],linestyle=:dash,width=2,color=:black,label="PPCA")
ylims!(p2,(-1e5,1e4))

p = plot(p2,p1,layout = @layout([q1{0.38w} q2]),size=(750,300),bottom_margin=5mm)

# Output directory: override with ENV["SHASTA_RESULTS_DIR"]; the default is the
# original author's local path. Created if missing so `savefig` does not fail.
results_dir = get(ENV, "SHASTA_RESULTS_DIR",
                  "/Users/kgilman/Desktop/streaming-hppca/Streaming-Heteroscedastic-PPCA/shasta_results")
mkpath(results_dir)
# NOTE: `$n` interpolates the whole vector (e.g. "[500, 2000]") into the file name.
file = joinpath(results_dir,
                "online_algs_iteration-d_$d-n_$n-observedPercent_$missing_rate-v1_" * @sprintf("%.2E", v[1]) * "-v2_" * @sprintf("%.2E", v[2]) * ".png")
savefig(p,file)
plot(p)