In [4]:
using NBInclude
@nbinclude("functions.ipynb")

All function imported


In [None]:
# ---- Setup: experiments with the MNIST dataset ----

# Problem dimensions.
x_dim = 784   # length of a signal vector
k = 20        # length of a code vector

# Solver settings handed to `estimated_code` in the experiment cell.
tolerance = 1e-7
max_iter = 3000
out_toggle = 0
trials = 20   # repetitions per (m, β) pair
opt = Flux.Optimise.ADAM(0.1)

# Checkpoint of the "identity" model; only `B` from this call is used further down
# (`decoder` is overwritten by the next load).
epoch_to_load = 16
dir = "trained_GNN/MNIST_identity"
encoder_μ, encoder_logvar, decoder, B = load_model_identity(dir, epoch_to_load)

# Decoder of the plain MNIST model; its third parameter array is the matrix `A`.
_, _, decoder = load_model("trained_GNN/MNIST", 20)
A = Flux.params(decoder)[3]

# Ground-truth code vector that every trial tries to recover.
z = randn(k)

# `dct` applied along dimension 2 of the x_dim × x_dim identity matrix.
F = dct(diagm(ones(x_dim)), 2)

# Sweep axes of the phase diagram.
β_list, α_list = get_β_α(F, A, B, 21)
m_list = 20:20:440

In [41]:
# Phase-transition experiment.
#
# For every number of measurements `m` in `m_list` and every mixing weight `β` in
# `β_list`, run `trials` independent recoveries and store the relative errors.
# Both result arrays are indexed as (trial, index into m_list, index into β_list).
phase_matrix_recovery = zeros(trials, length(m_list), length(β_list))
phase_matrix_recon = zeros(trials, length(m_list), length(β_list))

Threads.@threads for i in eachindex(m_list)
    m = m_list[i]
    for (j, β) in enumerate(β_list)
        # Generative network: the first two layers take their weights from `decoder`,
        # the last layer is the mix β * A + (1 - β) * B.
        G = Chain(
            Dense(20, 500, relu, bias = false; init = (out, in) -> Flux.params(decoder)[1]),
            Dense(500, 500, relu, bias = false; init = (out, in) -> Flux.params(decoder)[2]),
            Dense(500, 784, identity, bias = false; init = (out, in) -> β * A + (1 - β) * B),
        )

        # The loop index is `t`, not `trials`, so the global trial count is not shadowed.
        Threads.@threads for t in 1:trials
            F_sub = sample_fourier(m, x_dim)          # subsampled DCT used as measurement matrix
            F_sub = F_sub * sqrt(x_dim) / sqrt(m)     # normalization
            y = F_sub * G(z)                          # measurement vector

            z_est = randn(k)                          # initialization for the solver
            # Each trial gets its own copy of the optimiser so that whatever internal
            # state it keeps is never mutated by two threads at once.
            z_est = estimated_code(
                deepcopy(opt), G, y, F_sub, z_est; max_iter, tolerance, out_toggle
            )

            # Every (t, i, j) cell is written by exactly one iteration, so storing the
            # results directly is thread-safe (unlike `push!` onto a shared vector).
            phase_matrix_recovery[t, i, j] = relative_error(z, z_est)
            phase_matrix_recon[t, i, j] = relative_error(G(z), G(z_est))
        end
    end
    println("Done with m = $m")
    println("")
end

Done with m = 20



In [36]:
# Write both error arrays and the sweep axes needed to re-plot them into one JLD file.
save(
    "saved_data/Phase_plot_full_MNIST.jld",
    "phase_matrix_recovery", phase_matrix_recovery,
    "phase_matrix_recon", phase_matrix_recon,
    "α_list", α_list,
    "m_list", m_list,
    "trials", trials,
)

In [89]:
# Phase diagrams of the empirical success rate: for each (m, α) cell, the fraction of
# trials whose relative error is below `success_tol`. Rows are flipped with `reverse`
# so that the first entry of `m_list` ends up in the bottom row of the image.
success_tol = 1e-5

# Tick positions and labels shared by both panels. The y labels index `m_list` from
# the end because the image rows are reversed.
xticks = 1:5:length(α_list)
xlabels = ["$(round(α_list[i], digits = 2))" for i in xticks]
yticks = 2:5:length(m_list)
ylabels = ["$(m_list[length(m_list) + 1 - i])" for i in yticks]

# Panel 1: recovery of the code vector.
M = sum(phase_matrix_recovery .< success_tol, dims = 1)[1, :, :] / trials
p1 = plot(colorview(Gray, reverse(M, dims = 1)), bg =:black, title = "recovery of code z₀", framestyle = :box)
xticks!(xticks, xlabels)
yticks!(yticks, ylabels)

# Optional overlay of predicted curves, kept disabled:
# m_pred = 600*α_list.^10
# ff(x) = -21/420 * x +(1 + 21*440/420)
# plot!(1:length(α_list), ff.(m_pred), label =:false, linewidth = 2 )

# m_pred = 150*α_list.^3
# plot!(1:length(α_list), ff.(m_pred), label =:false, linewidth = 2 )

# Panel 2: recovery of the signal. This must be computed from the reconstruction
# errors (`phase_matrix_recon`), not from the code-recovery errors.
M = sum(phase_matrix_recon .< success_tol, dims = 1)[1, :, :] / trials
p2 = plot(colorview(Gray, reverse(M, dims = 1)), bg =:black, title = "recovery of signal G(z₀)",framestyle = :box)
xticks!(xticks, xlabels)
yticks!(yticks, ylabels)
# plot!(1:length(α_list), ff.(m_pred), label =:false, linewidth = 2 )

plot(p1, p2, layout = 2, xlabel = "incoherence upper bound", ylabel = "number of measurements")
savefig("figures/Phase_plot_Full_MNIST_percentage.pdf")



In [95]:


# Phase diagrams of the mean relative error over trials, clipped to at most 1 so the
# values can be shown as gray levels. Rows are flipped with `reverse` so that the
# first entry of `m_list` ends up in the bottom row of the image.

# Tick positions and labels shared by both panels. The y labels index `m_list` from
# the end because the image rows are reversed.
xticks = 1:5:length(α_list)
xlabels = ["$(round(α_list[i], digits = 2))" for i in xticks]
yticks = 2:5:length(m_list)
ylabels = ["$(m_list[length(m_list) + 1 - i])" for i in yticks]

# Panel 1: code recovery. The trial count is read from the first dimension of the
# array instead of being hard-coded.
M = min.(sum(phase_matrix_recovery, dims = 1)[1, :, :] / size(phase_matrix_recovery, 1), 1)
p1 = plot(colorview(Gray, reverse(M, dims = 1)), framestyle = :box, bg =:black, title = "recovery of code z₀")
xticks!(xticks, xlabels)
yticks!(yticks, ylabels)

# Panel 2: signal recovery.
M = min.(sum(phase_matrix_recon, dims = 1)[1, :, :] / size(phase_matrix_recon, 1), 1)
p2 = plot(colorview(Gray, reverse(M, dims = 1)), framestyle = :box, bg =:black, title = "recovery of signal G(z₀)")
xticks!(xticks, xlabels)
yticks!(yticks, ylabels)

plot(p1, p2, layout = 2,  bg =:white)
savefig("figures/Phase_plot_Full_MNIST_error.pdf")

