# MNIST check

In [None]:
using Pkg
using Plots
using Revise
using DelimitedFiles
using BenchmarkTools

In [None]:
Pkg.activate("../Sampling/")
using Sampling
Pkg.activate("../../GaussianEP/")
using GaussianEP

## (Pv,Ph) = (Bernoulli, Bernoulli)

In [None]:
using JLD2

# Empirical statistics of the MNIST visible layer, read as tab-separated Float64
# matrices: first moments from mean.txt and second moments from corr.txt
# (the diagonal of the latter is later combined with mean_data^2 to get variances).
mean_data, covv_data = map(f -> readdlm(f, '\t', Float64, '\n'),
                           ("../MNIST/mean.txt", "../MNIST/corr.txt"))

# RBM trained with EP; only its weight matrix is used below.
data_EP = load("../MNIST/res_BerRBM_rh8_w0bim_adagrad_SGD_epTRBL.jld2");
w_EP = getproperty(data_EP["res"], :weights);


In [None]:

# N visible units (rows of w_EP) and M hidden units (columns of w_EP).
N, M = size(w_EP)

# Homogeneous priors BinaryPrior(0.0, 1.0, ρ) on all visible / hidden units.
# NOTE(review): ρ appears to be the weight of the 0.0 state (the matching RBM bias
# used later is log((1-ρ)/ρ)) — confirm against GaussianEP's BinaryPrior.
rhov = 0.8677
rhoh = 0.8
Pv = BinaryPrior(0.0, 1.0, rhov)
Ph = BinaryPrior(0.0, 1.0, rhoh)
# Zero vector of length N+M handed to TermRBM together with the weights.
y = zeros(N + M)

# Visible priors first, then hidden priors, matching the (N+M) variable ordering.
P0 = [fill(Pv, N); fill(Ph, M)];
H = [GaussianEP.TermRBM(w_EP, y, 1.0)];

out_ep = expectation_propagation(H, P0; nprint = 1000, maxiter = 10000,
                                 damp = 0.99, epsconv = 1e-5);



In [None]:
using BoltzmannMachines
# Alternative parameter file, kept for reference:
# data_bm = load("../MNIST/winf_fields_BerRBM.jld2")
data_bm = load("../MNIST/bm_rbm_cdstep10_batchsize500.jld2")
w_BM = data_bm["rbm"].weights
# The priors and `y` below are sized with the EP model's N, M; make sure they agree.
size(w_BM) == (N, M) || throw(DimensionMismatch(
    "BM weights have size $(size(w_BM)), expected ($N, $M) to match the EP model"))

# Convert RBM biases into the ρ of BinaryPrior(0.0, 1.0, ρ):
# ρ = 1 / (1 + exp(bias)), i.e. the inverse of bias = log((1 - ρ) / ρ).
rhov_BM = 1.0 ./ (1.0 .+ exp.(data_bm["rbm"].visbias));
rhoh_BM = 1.0 ./ (1.0 .+ exp.(data_bm["rbm"].hidbias));

Pv = [BinaryPrior(0.0, 1.0, rhov_BM[i]) for i = 1:N]
Ph = [BinaryPrior(0.0, 1.0, rhoh_BM[i]) for i = 1:M]

P0_BM = vcat(Pv, Ph);

# The annealing loop rescales H_BM.w in place. Give TermRBM its own copy so that,
# if it stores the array it receives (not visible from here), the loaded weights
# w_BM (and data_bm["rbm"].weights) are not overwritten.
H_BM = GaussianEP.TermRBM(copy(w_BM), y, 1.0);

# Anneal the coupling strength linearly from beta_min to beta_max, warm-starting
# each EP run from the previous run's state. The last run is at beta_max, so after
# the loop H_BM.w == w_BM.
beta_min = 0.01
beta_max = 1.0
niter = 2
betas = niter == 1 ? [beta_max] : range(beta_min, beta_max; length = niter)
out_BM = nothing

for beta in betas
    H_BM.w .= beta .* w_BM
    # `global` keeps out_BM shared across iterations even outside a notebook's soft scope.
    prev_state = out_BM === nothing ? nothing : out_BM[1].state
    global out_BM = expectation_propagation([H_BM], P0_BM, nprint = 1, damp = 0.99,
                                            epsconv = 1e-4, maxiter = 100000,
                                            state = prev_state);
end

In [None]:
# Reload the BM weights from disk and compare them with the couplings held by H_BM
# (which the annealing loop rescales in place): histogram of the weights on top,
# scatter of H_BM.w against the reloaded weights below.
data_bm = load("../MNIST/bm_rbm_cdstep10_batchsize500.jld2")
w_BM = getproperty(data_bm["rbm"], :weights)
p1 = histogram(vec(w_BM); nbins = 200)
p2 = plot(vec(H_BM.w), vec(w_BM); seriestype = :scatter)
l = @layout [a; b]
plot(p1, p2; layout = l)

In [None]:
using Random

# Bernoulli RBM built from the EP weights, with one uniform bias per layer derived
# from the prior parameter: bias = log((1 - ρ) / ρ).
biasv_ep = log((1.0 - rhov) / rhov)
biash_ep = log((1.0 - rhoh) / rhoh)
rbm_bm = data_bm["rbm"]
rbm_ep = BernoulliRBM(w_EP, fill(biasv_ep, N), fill(biash_ep, M))

# One Gibbs chain (a single particle) per model.
particle_ep = BoltzmannMachines.initparticles(rbm_ep, 1, biased=true)
particle_bm = BoltzmannMachines.initparticles(rbm_bm, 1, biased=true)

Nconf = 10000   # recorded states per chain
Neq = 20000     # burn-in Gibbs steps before recording

# Row t holds the t-th recorded visible (samples_*) / hidden (hidden_*) state.
samples_ep = zeros(Nconf, N)
samples_bm = zeros(Nconf, N)
hidden_bm = zeros(Nconf, M)
hidden_ep = zeros(Nconf, M)

myseed = 19
Random.seed!(myseed)

# Burn-in, then record one state after every Gibbs step. Both chains presumably
# draw from the global RNG, so the call order (EP chain first) is kept fixed.
gibbssample!(particle_ep, rbm_ep, Neq)
gibbssample!(particle_bm, rbm_bm, Neq)
for t in 1:Nconf
    gibbssample!(particle_ep, rbm_ep, 1)
    gibbssample!(particle_bm, rbm_bm, 1)
    # particle[1] is the visible layer, particle[2] the hidden layer
    samples_ep[t, :] .= vec(particle_ep[1])
    samples_bm[t, :] .= vec(particle_bm[1])
    hidden_ep[t, :] .= vec(particle_ep[2])
    hidden_bm[t, :] .= vec(particle_bm[2])
end


In [None]:
# `mean` lives in Statistics; import it here so this cell works when the notebook
# is run top to bottom.
using Statistics

# Visible-unit means of the BM-trained RBM against the data means:
# EP estimate on top, Monte Carlo estimate (from the Gibbs samples) below.
mean_bm_mc = mean(samples_bm, dims=1)
p1 = plot(mean_data, out_BM[1].av[1:N], seriestype = :scatter,
          xlabel = "Data", ylabel = "EP (BM par)")
p2 = plot(mean_data, vec(mean_bm_mc), seriestype = :scatter,
          xlabel = "Data", ylabel = "MC (BM par)")
l = @layout [a; b]
plot(p1, p2, layout = l)

In [None]:
using Statistics

# Monte Carlo moments of the visible layer from the recorded Gibbs samples.
covv_bm_mc = cov(samples_bm)
covv_ep_mc = cov(samples_ep)
mean_bm_mc, mean_ep_mc = mean(samples_bm, dims=1), mean(samples_ep, dims=1)

# Top row: EP estimate vs. Monte Carlo estimate of the visible means, for the
# EP-trained and the BM-trained parameters. Bottom row: Monte Carlo means vs. data.
p1 = plot(out_ep[1].av[1:N], vec(mean_ep_mc);
          seriestype = :scatter, xlabel = "EP (EP)", ylabel = "EP (MC)")
p2 = plot(out_BM[1].av[1:N], vec(mean_bm_mc);
          seriestype = :scatter, xlabel = "BM (EP)", ylabel = "BM (MC)")
p3 = plot(mean_data, vec(mean_ep_mc);
          seriestype = :scatter, xlabel = "Data", ylabel = "EP (MC)")
p4 = plot(mean_data, vec(mean_bm_mc);
          seriestype = :scatter, xlabel = "Data", ylabel = "BM (MC)")
l = @layout [a b; c d]
plot(p1, p2, p3, p4; layout = l)

In [None]:
Nconf = 10000
Ndig = 10000   # visible samples averaged per hidden configuration
digits_ep = zeros(Ndig, N)
digits_bm = zeros(Ndig, N)
# Image side length; convert throws if N is not a perfect square.
L = convert(Int64, sqrt(N))

# For each of the first 100 recorded hidden states h, estimate the visible mean
# given h by averaging Ndig draws of samplevisible(rbm, h), and keep its L×L heatmap.
# Draws alternate EP / BM exactly as before so the random stream is unchanged.
p1 = Any[]
p2 = Any[]
for idx in 1:100
    for i in 1:Ndig
        digits_ep[i, :] = samplevisible(rbm_ep, hidden_ep[idx, :])
        digits_bm[i, :] = samplevisible(rbm_bm, hidden_bm[idx, :])
    end
    img_ep = reshape(mean(digits_ep, dims=1), L, L)
    img_bm = reshape(mean(digits_bm, dims=1), L, L)
    push!(p1, heatmap(img_ep, aspect_ratio = :equal, yflip = false))
    push!(p2, heatmap(img_bm, aspect_ratio = :equal, yflip = false))
end



In [None]:

# Display one of the 100 stored conditional-mean digits:
# EP-model heatmap first, BM-model heatmap second.
idx = 90
plot(p1[idx], p2[idx])

In [None]:
# Off-diagonal visible second moments <v_i v_j> = Σ_ij + av_i * av_j taken from an
# EP state, returned as a symmetric n×n matrix whose diagonal is left at zero.
function offdiag_second_moments(state, n)
    m = zeros(n, n)
    for i = 1:n, j = i+1:n
        m[i, j] = state.Σ[i, j] + state.av[i] * state.av[j]
        m[j, i] = m[i, j]
    end
    return m
end

# The EP result is indexed with [1] before reaching `.state` (as in the warm start
# `out_BM[1].state`); `out_ep.state` would not resolve on that result.
covv_ep = offdiag_second_moments(out_ep[1].state, N)

# Data variances from the second-moment diagonal, then mask the diagonal with NaN
# so only off-diagonal entries are compared.
# NOTE: covv_data is modified in place; re-running this cell without reloading
# covv_data makes va_data NaN.
va_data = zeros(N,)
for i = 1:N
    va_data[i] = covv_data[i,i] - mean_data[i] * mean_data[i]
    covv_data[i,i] = NaN
end

covv_bm = offdiag_second_moments(out_BM[1].state, N)


In [None]:
# Visible means: EP-trained model vs. data (p), BM-trained model vs. data (p1), and
# BM vs. EP (p2). The "x = y" series is the identity reference.
# The EP result is indexed with [1] before `.av`, as in the other cells.
p = Plots.plot(mean_data, mean_data, aspect_ratio = :equal, seriestype = :scatter,
               label = "x = y", legend = :topleft, xlabel= "Data av")
p = Plots.plot!(mean_data, out_ep[1].av[1:N], aspect_ratio = :equal, seriestype = :scatter,
                label = "EP means", legend = :topleft, xlabel= "Data av")
p1 = Plots.plot(mean_data, mean_data, aspect_ratio = :equal, seriestype = :scatter,
                label = "x = y", legend = :topleft, xlabel= "Data av")
p1 = Plots.plot!(mean_data, out_BM[1].av[1:N], seriestype = :scatter,
                 label = "BM means", legend = :topleft, xlabel= "Data av")
p2 = Plots.plot(out_BM[1].av[1:N], out_ep[1].av[1:N], seriestype = :scatter,
                label = "Av", legend = :topleft, xlabel = "BM", ylabel = "EP")
l = @layout [a{0.3h} b{0.3h} c{0.3h}]
Plots.plot(p, p1,p2, layout = l)

In [None]:
# Visible variances: EP-trained model vs. data (p), BM-trained model vs. data (p1),
# and BM vs. EP (p2). The "x = y" series is the identity reference.
# The EP result is indexed with [1] before `.va`, as in the other cells.
p = Plots.plot(va_data, va_data, aspect_ratio = :equal, seriestype = :scatter,
               label = "x = y", legend = :topleft, xlabel= "Data var")
p = Plots.plot!(va_data, out_ep[1].va[1:N], aspect_ratio = :equal, seriestype = :scatter,
                label = "EP vars", legend = :topleft, xlabel= "Data var")
p1 = Plots.plot(va_data, va_data, aspect_ratio = :equal, seriestype = :scatter,
                label = "x = y", legend = :topleft, xlabel= "Data var")
p1 = Plots.plot!(va_data, out_BM[1].va[1:N], seriestype = :scatter,
                 label = "BM vars", legend = :topleft, xlabel= "Data var")
p2 = Plots.plot(out_BM[1].va[1:N], out_ep[1].va[1:N], seriestype = :scatter,
                label = "Vars", legend = :topleft, xlabel = "BM", ylabel = "EP")
l = @layout [a{0.3h} b{0.3h} c{0.3h}]
Plots.plot(p, p1,p2, layout = l)

In [None]:
# Flatten the visible second-moment matrices for element-wise scatter plots.
vcovv_data, vcovv_ep, vcovv_bm = vec(covv_data), vec(covv_ep), vec(covv_bm)

# EP-trained model vs. data (p), BM-trained model vs. data (p1), BM vs. EP (p2);
# the "x = y" series is the identity reference.
p = Plots.plot(vcovv_data, vcovv_data;
               aspect_ratio = :equal, seriestype = :scatter, label = "x = y",
               legend = :topleft, xlabel= "Data cov(v,v)")
p = Plots.plot!(vcovv_data, vcovv_ep;
                aspect_ratio = :equal, seriestype = :scatter, label = "EP covv",
                legend = :topleft, xlabel= "Data cov(v,v)")
p1 = Plots.plot(vcovv_data, vcovv_data;
                aspect_ratio = :equal, seriestype = :scatter, label = "x = y",
                legend = :topleft, xlabel= "Data cov(v,v)")
p1 = Plots.plot!(vcovv_data, vcovv_bm;
                 seriestype = :scatter, label = "BM covv",
                 legend = :topleft, xlabel= "Data cov(v,v)")
p2 = Plots.plot(vcovv_bm, vcovv_ep;
                seriestype = :scatter, label = "Cov (v,v)",
                legend = :topleft, xlabel = "BM", ylabel = "EP")
l = @layout [a{0.3h} b{0.3h} c{0.3h}]
Plots.plot(p, p1, p2; layout = l)



In [None]:
# Save the flattened (column-major) second-moment matrices as text.
# Passing the path lets writedlm open, flush and close the file itself;
# writedlm(open(path, "w"), x) leaves the handle open and the data possibly unflushed.
writedlm("../MNIST/covv_bm.dat", vec(covv_bm))
writedlm("../MNIST/covv_ep.dat", vec(covv_ep))
writedlm("../MNIST/covv_data.dat", vec(covv_data))
size(vec(covv_bm))