In [None]:
using WAV;
using Gen;
import Random
include("model.jl")
include("plotting.jl")
include("proposals.jl")
include("gammatonegram.jl")
include("time_helpers.jl")

### Samples from generative model

In [None]:
# Global settings for the prior-sampling cells. All six values below are passed,
# in the order (max_tones, duration, wts, mindur, tstep, audio_sr), as the
# argument tuple of generate_scene.
duration = 0.81    # scene length (presumably seconds — confirm in model.jl)
max_tones = 5      # upper bound on tones (per source, presumably — confirm in model.jl)
audio_sr = 20000   # audio sample rate; audio_sr/2 is handed to plot_gtg as the top frequency
tstep = 0.020      # time-grid step, in the same units as duration
mindur = 0.020     # minimum tone duration, in the same units as duration
# Gammatone filterbank weights for this sample rate; `f` is the second value
# returned by gtg_weights (presumably the channel frequencies — confirm).
wts, f = gtg_weights(audio_sr);

# Unused draft of per-variable kernel parameters, kept commented out.
# function get_kernel_params(var_type)
    
#     if var_type == :erb
#         sigma = 3
#         scale = 0.5
#         noise = 1
#     elseif var_type == :amp
#         sigma = 1
#         scale = 0.5
#         noise = 0.5
#     end
    
#     return sigma, scale, noise
    
# end

In [None]:
# Draw 12 independent scenes from generate_scene and show them in a 2x6 grid,
# each panel titled with its sampled number of sources and (rounded) noise level.
figure(figsize=(12, 4))
for i=1:12
    subplot(2, 6, i)
    trace = simulate(generate_scene, (max_tones, duration, wts, mindur, tstep, audio_sr))
    noise = round(trace[:noise], digits=2)
    n_sources = trace[:n_sources]
    # title() acts on the current axes, i.e. the subplot selected just above.
    # It replaces the deprecated PyCall indexing form ax[:set_title](...) and
    # matches how titles are set in the other cells of this notebook.
    title("N: $n_sources, Noise: $noise")
    # get_retval(trace) returns (scene_gram, t, scene_wave); scene_gram could be
    # plotted here instead of the traced :scene value.
    plot_gtg(trace[:scene], duration, audio_sr/2)
end
tight_layout()

Listen to a sample!

In [None]:
# Sample one scene from the prior, plot its gammatonegram and play it back.
trace = simulate(generate_scene, (max_tones, duration, wts, mindur, tstep, audio_sr))
scene_gram, t, scene_wave = get_retval(trace)
noise = round(trace[:noise], digits=2)
#could also plot trace[:scene]
figure(figsize=(4,2))
plot_gtg(scene_gram, duration, audio_sr/2)
# Normalise by the peak *absolute* amplitude so playback stays within [-1, 1]
# even when the largest excursion is negative. Dividing by maximum(scene_wave)
# could overshoot in that case, and would divide by zero or flip the sign when
# the maximum is <= 0. An all-zero wave is played unscaled.
peak = maximum(abs, scene_wave)
WAV.wavplay(peak > 0 ? scene_wave ./ peak : scene_wave, audio_sr)

## Sample a scene to use as an observation

In [None]:
# Fix the RNG so the same "observed" scene is produced on every run.
Random.seed!(24)

# generate a scene with the number of sources constrained to 3; every other
# choice is sampled by generate_scene. The returned weight is discarded.
ground_truth_trace, _ = generate(generate_scene, (max_tones, duration, wts, mindur, tstep, audio_sr), choicemap((:n_sources, 3)));
# The return value is (gammatonegram, t, waveform); the gram is what the
# inference cells below condition on. Note this reassigns the global `t`.
ground_truth_gram, t, ground_truth_wave = get_retval(ground_truth_trace)

#plot scene
figure(figsize=(4,2))
n_sources = ground_truth_trace[:n_sources]
title("ground truth scene, sources: $n_sources")
plot_gtg(ground_truth_gram, duration, audio_sr/2) 
tight_layout()

## MCMC plot & run functions

In [None]:
# Run `max_steps` sweeps of `inference_function`, starting from the last trace in
# `traces`, plotting the current sources against `ground_truth_gram` before the
# first sweep and after every sweep.
#
# Arguments:
#   inference_function - called as f(trace, accept_counts, totals); must return
#                        the updated (trace, accept_counts, totals)
#   traces             - vector of traces; inference resumes from traces[end]
#   ground_truth_gram  - observed gammatonegram, forwarded to plot_sources
#   max_steps          - number of sweeps to run
#
# Afterwards the per-proposal acceptance rate is scattered (green: every proposal
# accepted, blue: some accepted, red: none accepted or never attempted) and the
# choices of each source in the final trace are printed.
#
# Returns `traces` with the final trace appended. The input vector is mutated by
# push!, so the returned value and the argument are the same array.
function run_inference(inference_function, traces, ground_truth_gram, max_steps)

    trace = traces[end]
    plot_sources(trace, ground_truth_gram, 0)

    # One counter slot per proposal label; slot order matches the indices the
    # inference function increments.
    proposals=["ns","nt","MuE","MuA","Wait","Dur","GpE","GpA","Spl/Mrg","Swh","Noise", "GpEAll", "GpAAll"];
    accept_counts = zeros(length(proposals)); totals = zeros(length(proposals));
    for i=1:max_steps
        print("$i ")
        trace, accept_counts, totals = inference_function(trace, accept_counts, totals)
        plot_sources(trace, ground_truth_gram, i)
    end

    figure(figsize=(8,2));
    title("Percent of accepted proposals")
    # A proposal that was never attempted gives 0/0 = NaN here; NaN fails both
    # comparisons below and is therefore coloured red.
    ps = accept_counts ./ totals
    cs = [p == 1 ? "g" : (p > 0 ? "b" : "r") for p in ps]
    scatter(proposals, ps, color=cs);
    show()

    println("Last sample trace:")
    for t = 1:trace[:n_sources]
        println("Source $t:")
        last_sample = get_submap(get_choices(trace), :source => t)
        println(last_sample)
    end

    return push!(traces,trace)

end

In [None]:
# One full MCMC sweep over the scene trace.
#
# Arguments:
#   trace - current trace of generate_scene
#   a     - per-proposal acceptance counters (mutated in place)
#   tot   - per-proposal attempt counters (mutated in place)
#
# Counter slots, matching the label list in run_inference:
#   1 n_sources, 2 n_tones, 3 mu_erb, 4 mu_amp, 5 wait, 6 duration,
#   7 GP erb (blocks), 8 GP amp (blocks), 9 split/merge, 10 switch stream,
#   11 noise, 12 GP erb (whole tone), 13 GP amp (whole tone)
#
# Returns (trace, a, tot). The order of the moves below determines how the RNG
# is consumed, so reordering them changes the sampled chain.
function mcmc_update(trace, a, tot)

    #swap to the end
    #change by arbitrary number
    #trace, accepted = mh(trace, select(:n_sources))
    trace, accepted = mh(trace, change_n_sources, ())
    a[1] += Int(accepted); tot[1] += 1;
    
    # Source swap: asserted to always be accepted, and not tallied in a/tot.
    trace, accepted = mh(trace, swap_sources_randomness, (), swap_sources_involution)
    @assert accepted
    
    # n_sources is re-read here, after the moves above may have changed it.
    for source_id = 1:trace[:n_sources]
        #add or remove a tone
        trace, accepted = mh(trace, n_randomness, (source_id,), n_involution)
        a[2] += Int(accepted); tot[2] += 1;

        # Selection-based MH on the source-level erb and amp means.
        trace, accepted = mh(trace, select(:source => source_id => :mu_erb))
        a[3] += Int(accepted); tot[3] += 1;

        trace, accepted = mh(trace, select(:source => source_id => :mu_amp))
        a[4] += Int(accepted); tot[4] += 1;

        # Tone count is read once, after the add/remove move; the per-tone moves
        # below are iterated over this fixed count.
        nt = get_choices(trace)[:source => source_id => :n_tones]
        for tone_idx = 1:nt
            
            trace, accepted = mh(trace, wait_randomness, (tone_idx,source_id,), wait_involution)
            a[5] += Int(accepted); tot[5] += 1;
            
            trace, accepted = mh(trace, duration_randomness, (tone_idx,source_id,), duration_involution)
            a[6] += Int(accepted); tot[6] += 1;

            # Recompute this tone's sample grid after the timing moves above.
            # Of the unpacked args only mindur and tstep are used below.
            max_tones, scene_duration, wts, mindur, tstep, audio_sr = get_args(trace)
            source_trace = get_submap(get_choices(trace),:source=>source_id)
            old_abs_timings = absolute_timing(source_trace, mindur)
            onset = old_abs_timings[tone_idx][1]; offset = old_abs_timings[tone_idx][2]
            # Multiples of tstep, up to the last tone's offset, that fall inside
            # this tone's half-open interval [onset, offset).
            timings = [t for t in 0:tstep:old_abs_timings[end][2] if onset <= t < offset]
            k = length(timings); step_size = 4;
    
            # Update the erb and amp values in consecutive blocks of up to
            # step_size grid points (the last block may be shorter).
            for i = 1:step_size:k

                update_idxs = Int.(i : min(k, Int(i+step_size-1)))
                trace, accepted = mh(trace, gp_randomness, (tone_idx, update_idxs,:erb,source_id), gp_involution)
                a[7] += Int(accepted); tot[7] += 1;

                trace, accepted = mh(trace, gp_randomness, (tone_idx, update_idxs, :amp,source_id), gp_involution)
                a[8] += Int(accepted); tot[8] += 1;

            end
                        
            # Then propose all k grid points of the tone at once.
            trace, accepted = mh(trace, gp_randomness, (tone_idx, 1:k, :erb, source_id), gp_involution)
            a[12] += Int(accepted); tot[12] += 1
            trace, accepted = mh(trace, gp_randomness, (tone_idx, 1:k, :amp, source_id), gp_involution)
            a[13] += Int(accepted); tot[13] += 1
                        
        end

        #check if we should split or merge some random pair of consecutive tones
        trace, accepted = mh(trace, sm_randomness, (source_id,), sm_involution)
        a[9] += Int(accepted); tot[9] += 1;
                    
    end
    
    #check if we should move a random tone into a different stream
    trace, accepted = mh(trace, switch_randomness, (), switch_involution) 
    a[10] += Int(accepted); tot[10] += 1;                    
    # Selection-based MH on the scene-level noise choice.
    trace, accepted = mh(trace, select(:noise))
    a[11] += Int(accepted); tot[11] += 1;
    
    return trace, a, tot

end

In [None]:
Random.seed!(12)
# Condition only on the observed gammatonegram; every latent choice of the
# initial trace is sampled by generate_scene.
constraints = choicemap((:scene, ground_truth_gram))
new_init, _ = generate(
    generate_scene,
    (max_tones, duration, wts, mindur, tstep, audio_sr),
    constraints,
);
# 100 sweeps of mcmc_update starting from that initial trace.
tr = run_inference(mcmc_update, [new_init], ground_truth_gram, 100);

## Data-based initialization

In [None]:
# Load the demo recording and compute the gammatonegram that the remaining cells
# condition on. This cell overwrites the globals max_tones, audio_sr, wts, f and
# t that were set for the synthetic example above.
max_tones = 16; 
# Only the first two values returned by wavread are kept (samples, sample rate).
# NOTE(review): the sample rate is later wrapped in Int(...) before being passed
# to generate_scene, so it is presumably not an Int here — confirm against WAV.jl.
demo, audio_sr = wavread("Track17X.wav");
demo=demo[:,1];   # keep the first channel only
wts, f = gtg_weights(audio_sr)
demo_gram, t = gammatonegram(demo, wts, audio_sr)
scene_duration = length(demo)/audio_sr; 
plot_gtg(demo_gram, scene_duration, audio_sr/2)
title("Bouncing demo");

In [None]:
import JSON
# Read in the neural-network guide mixture distribution: slurp the file as a
# string, parse it, and bind the parsed result to the global `d`.
d = open("Track17X_guide.json", "r") do io
    JSON.parse(read(io, String))
end

In [None]:
# Counts-to-probabilities: divide every entry by the total so the result sums to 1.
c2p(counts) = counts / sum(counts)
    
# Data-driven proposal: sample a set of tone "elements" from the guide
# distribution `d`, then assign each kept element to a source.
#
# `d` is the parsed guide. The keys read here are:
#   d["elements"][i]["p_use"]["p"]             - probability of using element i
#   d["elements"][i]["onset"]["tone"]["ps"]    - counts over onset grid indices
#   d["elements"][i]["offset"]["tone"]["ps"]   - per-onset counts over offset indices
#   d["elements"][i]["f0"]["ps"][j]            - counts over f0 values at frame j
#   d["elements"][i]["filt1D"]["ps"][j]        - counts over amplitude values at frame j
#   d["t_vs"], d["f_vs"], d["a_vs"]            - value grids the sampled indices look up
#
# Returns a vector of Dicts with keys "onset", "offset", "erbf0", "filt", "ts"
# and "source".
#
# NOTE: `mindur` is not a parameter; it is read from global scope.
@gen function data_based_elements(d)
    elements = [];
    for i = 1:length(d["elements"])
        element = d["elements"][i]
        p_use = element["p_use"]["p"]
        # Elements with p_use above 0.80 are always used and make no random
        # choice; the rest are kept by a traced coin flip at (:use, i).
        use_element = p_use > 0.80 ? true : @trace(bernoulli(p_use), (:use, i))
        if use_element
            source_type = "tone" #we're just dealing with a single source type right now
            onset_ps = c2p(element["onset"][source_type]["ps"]); 
            onset_idx = @trace(categorical(onset_ps), (:onset, i))
            # The offset distribution is looked up by the sampled onset index.
            offset_ps = c2p(element["offset"][source_type]["ps"][onset_idx]); 
            offset_idx = @trace(categorical(offset_ps), (:offset, i))
            f0_idxs = []; filt_idxs = [];
            # One f0 index and one amplitude index per frame from onset to
            # offset (inclusive). The range is empty if offset_idx < onset_idx.
            for j = onset_idx:offset_idx
                f0_ps = c2p(element["f0"]["ps"][j]); 
                next_f0_idx = @trace(categorical(f0_ps), (:f0, i, j))
                push!(f0_idxs, next_f0_idx)
                filt_ps = c2p(element["filt1D"]["ps"][j]);
                next_filt_idx = @trace(categorical(filt_ps), (:filt, i, j))
                push!(filt_idxs, next_filt_idx)
            end
            
            # Turn the sampled indices into values via the guide's grids.
            onset = d["t_vs"][onset_idx] 
            offset = d["t_vs"][offset_idx]
            ts = d["t_vs"][onset_idx:offset_idx]
            erbf0 = freq_to_ERB(d["f_vs"][f0_idxs])
            filt = d["a_vs"][filt_idxs]
            
            # Drop elements shorter than the minimum tone duration.
            if offset - onset >= mindur
                push!(elements, Dict("onset" => onset, "offset" => offset, "erbf0" => erbf0, "filt" => filt, "ts" => ts) )
            end
        end
    end
    
    
    # Second pass: give every kept element a source label in 1:10.
    for i = 1:length(elements)
        
        if i == 1 
            # The first element always starts source 1 (no random choice).
            elements[i]["source"] = 1
        else
            # Unnormalised weight per candidate label j:
            #   0.1 if no element has been given label j yet,
            #   1   if this element fits in a gap of source j without overlap,
            #   0   otherwise.
            ps = []
            for j = 1:10 #max_sources
                source_elems = [element for element in elements if haskey(element,"source")]
                source_elems = [element for element in source_elems if element["source"] == j]
                if length(source_elems) == 0
                    push!(ps, 0.1)
                else
                    
                    source_onsets = [se["onset"] for se in source_elems]
                    sorted_onset_idx = sortperm(source_onsets)
                    source_elems = source_elems[sorted_onset_idx]
                    this_onset = elements[i]["onset"]
                    this_offset = elements[i]["offset"]
                                
                    # Try each slot: before the first element (k == 1), between
                    # consecutive elements, and after the last one. The last
                    # slot only checks that this onset follows the last offset.
                    fits = false;
                    for k = 1:length(source_elems) + 1
                        if k == 1
                            fits = 0 < this_onset < this_offset < source_elems[k]["onset"]
                        elseif 1 < k <= length(source_elems) 
                            fits =  source_elems[k-1]["offset"] < this_onset < this_offset < source_elems[k]["onset"]
                        else 
                            fits = source_elems[k - 1]["offset"] < this_onset
                        end
                        if fits
                            break
                        end
                    end
                    push!(ps, fits ? 1 : 0)
                end
            end
            # NOTE(review): if all 10 labels are already in use and the element
            # fits none of them, sum(ps) is 0 and the probabilities become NaN.
            source_idx = @trace(categorical(ps/sum(ps)),(:n, i))
            elements[i]["source"] = source_idx
        end
        
    end
                        
    return elements 
end;

In [None]:
using Dierckx;
Random.seed!(24)
trace = simulate(data_based_elements, (d,))
elements = get_retval(trace);

constraints = choicemap()

# Bucket the sampled elements by source label (labels run over 1:10), order each
# bucket by onset, and keep only the labels that were actually drawn.
sources = []
for label = 1:10
    bucket = [element for element in elements if element["source"] == label]
    if !isempty(bucket)
        onset_order = sortperm([member["onset"] for member in bucket])
        push!(sources, bucket[onset_order])
    end
end

constraints[:n_sources] = length(sources)
for (i, source) in enumerate(sources)
    constraints[:source => i => :n_tones] = length(source)

    prev_offset = 0
    for (j, tone) in enumerate(source)
        #if you don't add random component, the neural network grid exactly aligns with
        #the GP grid and floating point errors means there are errors in curr_t
        onset = tone["onset"] + 1e-4*rand()
        offset = tone["offset"] + 1e-4*rand()

        # Timing choices: the gap since the previous tone's offset, and the
        # duration in excess of the minimum.
        constraints[:source => i => (:tone, j) => :wait] = onset - prev_offset
        constraints[:source => i => (:tone, j) => :dur_minus_min] = offset - onset - mindur
        curr_t = get_tone_sample_times([onset, offset], tstep)

        # Linearly interpolate (k=1) the guide's trajectories onto the tone's
        # own sample times.
        erb_interp = Spline1D(tone["ts"], tone["erbf0"], k=1)
        amp_interp = Spline1D(tone["ts"], tone["filt"], k=1)

        constraints[:source => i => (:tone, j) => :erb] = erb_interp(curr_t)
        #note: need to align Gen & Webppl Amplitude versions
        constraints[:source => i => (:tone, j) => :amp] = amp_interp(curr_t) .- 2
        prev_offset = offset
    end
end
constraints[:scene] = demo_gram
args = (max_tones, float(scene_duration), wts, mindur, tstep, Int(audio_sr))
data_init_trace, _ = generate(generate_scene, args, constraints);
plot_sources(data_init_trace, demo_gram, 0)

In [None]:
# Short warm-up run (10 sweeps) from the data-based initial trace, with a fixed seed.
Random.seed!(7)
tr0 = run_inference(mcmc_update, [data_init_trace], demo_gram, 10);

In [None]:
# Continue for 100 more sweeps from the last trace in tr0. run_inference push!es
# onto the vector it is given, so tr1 and tr0 refer to the same array.
tr1 = run_inference(mcmc_update, tr0, demo_gram, 100);

In [None]:
# Continue for another 100 sweeps from the last trace in tr1. The trailing `;`
# keeps the notebook from echoing the whole vector of traces, as in the tr0/tr1 cells.
tr2 = run_inference(mcmc_update, tr1, demo_gram, 100);

In [None]:
# Compare two explanations of the demo: the "crossing" trace found by MCMC, and
# a hand-built "bouncing" trace in which the two sources exchange their tones
# from the 4th tone onward.
crossing_trace = tr2[end]

constraints = choicemap()
constraints[:n_sources] = crossing_trace[:n_sources]

# New source 1: tones 1-3 from crossing source 1, tones 4+ from crossing source 2.
# The same tone index j is used in both sources, so this presumes the two
# sources have matching tone counts — confirm for the trace at hand.
constraints[:source=>1=>:n_tones]=crossing_trace[:source => 1 => :n_tones]
all_erbs = []; all_amps = [];
for j = 1:constraints[:source=>1=>:n_tones]
    pull_from = j < 4 ? 1 : 2
    for a = [:erb, :amp, :dur_minus_min, :wait]
        if a == :wait && j == 4
            # Hand-picked wait for the first exchanged tone.
            constraints[:source=>1=>(:tone,j)=>a] = 0.003
        else
            constraints[:source=>1=>(:tone,j)=>a] = crossing_trace[:source => pull_from => (:tone, j)=>a]
        end
    end
    all_erbs = vcat(all_erbs, get_submap(get_choices(crossing_trace), :source => pull_from => (:tone, j))[:erb])
    all_amps = vcat(all_amps, get_submap(get_choices(crossing_trace), :source => pull_from => (:tone, j))[:amp])
end
# Source-level means are set to the average over all erb/amp values gathered above.
# NOTE(review): `mean` needs Statistics in scope; no visible cell imports it —
# presumably one of the included files does. Confirm.
constraints[:source=>1=>:mu_erb] = mean(all_erbs)
constraints[:source=>1=>:mu_amp] = mean(all_amps)


# New source 2: the mirror image — tones 1-3 from crossing source 2, tones 4+
# from crossing source 1.
constraints[:source=>2=>:n_tones]=crossing_trace[:source=>2=>:n_tones]
all_erbs = []; all_amps = [];
for j = 1:constraints[:source=>2=>:n_tones]
    pull_from = j >= 4 ? 1 : 2
    for a = [:erb, :amp, :dur_minus_min, :wait]
        # Only the wait of the first exchanged tone is shifted (by a hand-picked
        # 0.295); every other value is copied unchanged (z = 0).
        if a == :wait && j == 4
            z = 0.295
        else
            z = 0
        end
        constraints[:source=>2=>(:tone,j)=>a] = crossing_trace[:source => pull_from => (:tone, j)=>a] .+ z
    end
    all_erbs = vcat(all_erbs, get_submap(get_choices(crossing_trace), :source => pull_from => (:tone, j))[:erb])
    all_amps = vcat(all_amps, get_submap(get_choices(crossing_trace), :source => pull_from => (:tone, j))[:amp])
end
constraints[:source=>2=>:mu_erb] = mean(all_erbs)
constraints[:source=>2=>:mu_amp] = mean(all_amps)

constraints[:noise] = crossing_trace[:noise];

# :scene is not constrained here, so generate_scene samples it for the bouncing trace.
bouncing_trace, _ = generate(generate_scene, args, constraints)
plot_gtg(bouncing_trace[:scene], scene_duration, audio_sr/2)

#println(get_choices(bouncing_trace))

# Compare the total trace scores of the two explanations.
crossing_score = get_score(crossing_trace)
bouncing_score = get_score(bouncing_trace)
diff_score = crossing_score - bouncing_score
println("cross score: $crossing_score, bounce score: $bouncing_score, diff: $diff_score")


# Compare the noisy_matrix log-densities of each trace's :scene value.
# NOTE(review): `observation_trace` is not defined in any visible cell; this
# will raise UndefVarError unless it is defined elsewhere. Confirm which trace
# (or demo_gram) was intended.
cross_likelihood=logpdf(noisy_matrix,crossing_trace[:scene],observation_trace[:scene],constraints[:noise])
bounce_likelihood=logpdf(noisy_matrix,bouncing_trace[:scene],observation_trace[:scene],constraints[:noise])
diff_likelihood = cross_likelihood - bounce_likelihood
println("cross likelihood: $cross_likelihood, bounce likelihood: $bounce_likelihood, diff_likelihood: $diff_likelihood")

plot_gtg_unscaled(bouncing_trace[:scene]-crossing_trace[:scene], scene_duration, audio_sr/2)

## The problem is that, for the likelihoods to be equal,
## the fourth tone of the lower stream should have onset = 0.91542338441.
## However, the third tone of the lower stream has offset = 0.92599250102.
## Since they can't overlap, the likelihood is a bit off.
## However, it seems that if the difference in the likelihood were accounted for,
## then the bouncing explanation would win.

## Demo without data-based initialization

In [None]:
Random.seed!(15)
# Condition on the demo gammatonegram only; all latent choices of the initial
# trace are sampled by generate_scene.
constraints = choicemap((:scene, demo_gram))
rand_init, _ = generate(generate_scene, args, constraints);
plot_sources(rand_init, demo_gram, 0)

In [None]:
# 100 sweeps of mcmc_update on the demo, starting from the randomly initialised trace.
trace_from_rand = run_inference(mcmc_update, [rand_init], demo_gram, 100);