In [1]:
using Distributions, TimeIt, ProgressMeter, PyPlot, JLD
include("zz_structures_DuLuSuSe.jl")
include("mbsampler.jl")
include("polyagamma.jl");

#### Load data:

In [2]:
# Simulate a sparse logistic-regression data set: X holds one observation per
# column (d coordinates × Nobs observations), ξ_true is a sparse coefficient
# vector, and y collects one Bernoulli response per observation.
d, Nobs = 10_000, 12_000
X = sprandn(d, Nobs, 1e-2)
ξ_true = 3*sprandn(d, 1e-2)
# The first coefficient is always assigned, whatever the sparsity draw gave.
ξ_true[1] = 3*randn()
y = map(1:Nobs) do n
    η = X[:,n]'ξ_true               # linear predictor of observation n
    rand(Binomial(1, 1/(1+exp(-η))))
end;

#### Define prior:

In [3]:
# Scalar hyper-parameter handed to the prior constructor below
# (presumably the prior variance σ₀² — confirm in the included files).
σ02 = 1
# Prior over the d coefficients. SS_prior is defined in the included project
# files; NOTE(review): the name suggests a spike-and-slab prior — confirm.
prior = SS_prior(d, σ02);
# Alternative prior, kept switchable by (un)commenting:
# prior = gaussian_prior_nh(d, σ02);

#### Define model:

In [4]:
# Likelihood object built from the sparse design matrix X and the responses y
# (ll_logistic_sp is defined in the included project files).
my_ll = ll_logistic_sp(X,y);
# Bundle likelihood and prior; later cells read them back as
# my_model.ll and my_model.pr.
my_model = model(my_ll, prior);

#### Define minibatch sampler:

In [5]:
mb_size = 200

"""
    row_normalized_abs(A)

Return a sparse matrix with the same size as `A` whose stored entry `(i, j)`
is `abs(A[i, j])` divided by the sum of absolute values of row `i`, so that
every row with stored entries sums to one. Rows without stored entries stay
empty.

The matrix is assembled in a single pass over the stored entries. Slicing rows
out of a column-compressed sparse matrix and inserting into it row by row costs
far more and becomes impractical at this problem size. Row sums are accumulated
sequentially, so results can differ from a per-row `sum` in the last few bits.
"""
function row_normalized_abs(A)
    rows, cols, vals = findnz(A)
    absvals = abs.(vals)
    rowsums = zeros(eltype(absvals), size(A, 1))
    for k in eachindex(rows)
        rowsums[rows[k]] += absvals[k]
    end
    return sparse(rows, cols, absvals ./ rowsums[rows], size(A, 1), size(A, 2))
end

# Per-coordinate sampling weights over observations (d × Nobs).
weights_het = row_normalized_abs(X)

prob_het = 0.98

# Row slices of a column-compressed sparse matrix are expensive, and one is
# needed per coordinate. Build the transposed matrix (Nobs × d) once so that
# each coordinate's weight vector is a cheap column slice with the same contents.
weights_by_coord = let
    rows, cols, vals = findnz(weights_het)
    sparse(cols, rows, vals, Nobs, d)
end

# One mini-batch sampler per coordinate: the first coordinate gets umbsampler,
# every other coordinate gets spwumbsampler with its own weight vector.
gs = mbsampler[umbsampler(Nobs, mb_size)]
sizehint!(gs, d)
@showprogress for i in 2:d 
    push!(gs, spwumbsampler(Nobs, mb_size, weights_by_coord[:, i], prob_het))
end
gs_list = mbsampler_list(d,gs);

LoadError: [91mInterruptException:[39m

#### Define output timer:

In [6]:
# Run option object; maxa_opt is defined in the included project files.
# NOTE(review): presumably caps the number of attempts at 10^5 (the run log
# below reports "% attempts") — confirm.
opt = maxa_opt(10^5)
# Bound object built from the model's likelihood, its prior and the
# per-coordinate mini-batch samplers.
bb = linear_bound(my_model.ll, my_model.pr, gs_list)
# Sampler state for the d coordinates.
mstate = zz_state(d)
# Presumably refreshes bb for the current state (verify in update_bound);
# the trailing `;` suppresses the notebook output.
update_bound(bb, my_ll, prior, gs_list, mstate);

In [7]:
# h is shared by both sample containers below.
# NOTE(review): presumably the time step at which samples are recorded —
# confirm in zz_samples / hyp_samples.
h = 1e-4
# Container for the coefficient samples; later cells read xi_samples.samples.
xi_samples = zz_samples(mstate, h) 
# Container for the prior hyper-parameter samples.
pr_samples = hyp_samples(prior, h);

#### Define mbsampler + block Gibbs sampler list:

In [8]:
# Build the two block samplers and collect them, zig-zag sampler first and
# hyper-parameter Gibbs sampler second, in a vector typed by msampler.
adapt_speed = "by_var" 
L = 1
my_zz_sampler = zz_sampler(0, gs_list, bb, L, adapt_speed)
hyper_sampler = block_gibbs_sampler(1e4)
blocksampler = msampler[my_zz_sampler, hyper_sampler];

#### Run sampler:

In [9]:
# Run the block sampler. The return value is not captured: later cells read the
# results from xi_samples.samples, mstate.n_bounces and hyper_sampler.nbounces,
# so the call is presumably relied on to update its arguments in place (verify).
ZZ_block_sample_discrete(my_model, opt, blocksampler, mstate, xi_samples, pr_samples)

10% attempts in 0.68 mins, zz bounces = 1978, hyp bounces = 82, samples extracted = 81
20% attempts in 1.87 mins, zz bounces = 4599, hyp bounces = 156, samples extracted = 152
30% attempts in 3.46 mins, zz bounces = 6815, hyp bounces = 215, samples extracted = 223
40% attempts in 5.58 mins, zz bounces = 8590, hyp bounces = 284, samples extracted = 296
50% attempts in 8.22 mins, zz bounces = 10130, hyp bounces = 357, samples extracted = 369
60% attempts in 11.28 mins, zz bounces = 11453, hyp bounces = 434, samples extracted = 443
70% attempts in 14.63 mins, zz bounces = 12742, hyp bounces = 518, samples extracted = 517
80% attempts in 17.45 mins, zz bounces = 14109, hyp bounces = 584, samples extracted = 591
90% attempts in 20.57 mins, zz bounces = 15557, hyp bounces = 660, samples extracted = 667
100% attempts in 24.01 mins, zz bounces = 16992, hyp bounces = 737, samples extracted = 741


* Number of bounces:

In [19]:
# Report the bounce counters of both samplers, one per line, followed by the
# number of stored coefficient samples (columns of the sample matrix).
for (count, label) in ((sum(mstate.n_bounces), " zz bounces \n"),
                       (hyper_sampler.nbounces, " hyper bounces \n"))
    print(count, label)
end
print("Number of xi samples drawn = ", size(xi_samples.samples,2))

150304 zz bounces 
11134 hyper bounces 
Number of xi samples drawn = 2219

## Note: the summaries below use every stored draw; discard an initial burn-in portion of the samples first.

#### Posterior coverage:

In [None]:
# Equal-tailed 95% intervals per coordinate (2.5th and 97.5th percentiles of
# the draws) and an indicator of whether each interval strictly contains the
# true coefficient. The final line prints the coverage as a percentage.
ci_gzz = zeros(d,2)
cover_gzz = zeros(d)
for i in 1:d
    lo, hi = percentile(xi_samples.samples[i,:], [2.5, 97.5])
    ci_gzz[i,1], ci_gzz[i,2] = lo, hi
    cover_gzz[i] = lo < ξ_true[i] < hi
end
print(100*mean(cover_gzz))

In [None]:
# Per-coordinate posterior mean and median, each reduced along the second
# (sample) dimension of the sample matrix.
post_mean_gzz, post_median_gzz = let draws = xi_samples.samples
    (mean(draws,2), median(draws,2))
end;

In [None]:
# Compare sorted magnitudes: all stored entries of the true coefficient vector
# against the 50 largest posterior means and the 50 largest posterior medians.
fig = figure(figsize=(12,3))
plot(sort(abs.(ξ_true.nzval), rev=true), "o-", markersize=3, label="true signal")
for (estimate, lbl) in ((post_mean_gzz, "posterior mean"),
                        (post_median_gzz, "posterior median"))
    magnitudes = sort(vec(abs.(estimate)), rev=true)
    plot(magnitudes[1:50], "o-", markersize=3, label=lbl)
end
grid(true)
title("Gibbs zig-zag sampler")
legend(ncol=3);