In [None]:
using Gen;
using Random
using JSON

using Statistics: mean, std, cor;
using LinearAlgebra: dot;
using StatsFuns: logsumexp, softplus;
using PyPlot
using SpecialFunctions: digamma,trigamma,loggamma;
include("../model/time_helpers.jl")
include("../model/extra_distributions.jl")
include("../model/gaussian_helpers.jl")
include("../tools/plotting.jl")

include("../model/model.jl")
using PyPlot;

In [None]:
# Prior over a set of wait times: a normal-gamma prior over (`:mean`, `:precision`),
# a number of elements `:n_elements` drawn uniformly from 1..10 (inclusive), and one
# `log_normal` draw per element at address `(:w, i)`.
# Returns a Vector{Float64} holding the `n_elements` sampled waits.
@gen function generate_tp()
    # Reference hyperparameters (provenance not shown here; the constants below look
    # like rounded versions of them — NOTE(review): confirm):
    #a_0=2.64 b_0=0.79 k_0=0.51 m_0=-1.78
    alpha_0 = 2.5        # gamma shape
    beta_0 = 1.0 / 0.8   # gamma scale (Gen's `gamma` takes shape and scale)
    kappa_0 = 0.5        # multiplies `precision` in the std of the `:mean` prior
    mu_0 = -1.5          # prior location of `:mean`

    precision = @trace(gamma(alpha_0, beta_0), :precision)
    # Local is called `mu` (not `mean`) so it does not shadow the imported
    # `Statistics.mean`; the trace address stays `:mean` for compatibility.
    mu = @trace(normal(mu_0, 1.0 / sqrt(precision * kappa_0)), :mean)
    n_elements = @trace(uniform_discrete(1, 10), :n_elements)

    # Concretely typed so callers receive a Vector{Float64}, not a Vector{Any}.
    # `log_normal` is defined outside this view; presumably its log has location
    # `mu` and scale 1/sqrt(precision) — TODO confirm.
    ws = Float64[]
    for i in 1:n_elements
        push!(ws, @trace(log_normal(mu, 1.0 / sqrt(precision)), (:w, i)))
    end

    return ws
end
# Sample once from the prior and show the resulting choice map
# (`propose` returns (choices, weight, retval); only the choices are kept).
choices = first(propose(generate_tp, ()))
println(choices)

In [None]:
# Pool the waits from many independent prior samples into a single vector, sorted
# ascending, so that later cells can read empirical quantiles directly by index.
n_prior_samples = 1000
all_waits = Float64[]   # typed accumulator (an untyped `[]` is a Vector{Any})
for _ in 1:n_prior_samples
    # `simulate` is Gen's API for unconstrained forward sampling; `generate` with no
    # constraints draws the same distribution but also returns an unused weight.
    append!(all_waits, get_retval(simulate(generate_tp, ())))
end
sort!(all_waits);

In [None]:
# Summarise the pooled prior waits and plot their histogram, truncated at the 99th
# percentile. Assumes `all_waits` is sorted ascending (done by the sampling cell).
isempty(all_waits) &&
    throw(ArgumentError("all_waits is empty; run the sampling cell first"))

# Index of the empirical p-quantile in a sorted vector of length n. Clamped to 1:n so
# that small samples (n * p < 1) cannot produce index 0, which would be a BoundsError.
sorted_quantile_index(n, p) = clamp(floor(Int, n * p), 1, n)

n = length(all_waits)
percentile10_idx = sorted_quantile_index(n, 0.1)
# NOTE: despite its name, this is the 97.5th-percentile index (see the title label).
percentile90_idx = sorted_quantile_index(n, 0.975)
percentile99_idx = sorted_quantile_index(n, 0.99)

minw = round(minimum(all_waits), digits=2)
# The samples are continuous, so an empirical mode is not meaningful (and `mode` is
# not provided by Statistics); report the mean, which the title already referenced.
meanw = round(mean(all_waits), digits=2)
maxw = round(maximum(all_waits), digits=2)
q1 = round(all_waits[percentile10_idx], digits=2)
q2 = round(all_waits[percentile90_idx], digits=2)
q3 = round(all_waits[percentile99_idx], digits=2)

# Drop the top 1% so a few extreme draws do not squash the histogram.
hist(all_waits[1:percentile99_idx], bins=100)
xlim([0, q3])
title("Min:$minw, mean:$meanw, max:$maxw, 10%:$q1, 97.5%:$q2, 99%:$q3")

In [None]:
# Load the parameter file (it evaluates to a 4-tuple) and assemble the argument tuple
# that `generate_scene` is called with below.
source_params, steps, gtg_params, obs_noise = include("../params/gnprior.jl")
audio_sr = 20000

# `gtg_weights` returns two values; only `wts` is used in this notebook
# (`f` is presumably the filterbank frequencies — TODO confirm).
wts, f = gtg_weights(audio_sr, gtg_params)
scene_duration = 2.0

args = (
    source_params,
    float(scene_duration),
    wts,
    steps,
    Int(audio_sr),
    obs_noise,
    gtg_params,
);

In [None]:
# Draw one scene from the full model and plot its gammatonegram.
trace = first(generate(generate_scene, args))

scene_gram, t, scene_wave, source_waves, element_waves = get_retval(trace)
# Optional audio export:
# wavwrite(scene_wave/maximum(abs.(scene_wave)), "scene.wav", Fs=audio_sr)
plot_gtg(scene_gram, scene_duration, audio_sr, 20, 100)

# Title the figure with the number of elements sampled for source 1.
title(trace[:source => 1 => :n_elements])

In [None]:
# Dump every random choice recorded in the sampled scene trace.
get_choices(trace) |> println