In [None]:
using Gen
using GaussianProcesses
using Random
using Distributions
using JSON

using Statistics: mean, std, cor;
using LinearAlgebra: dot;
using StatsFuns: logsumexp, softplus;
using PyPlot
using SpecialFunctions: digamma,trigamma;
include("./time_helpers.jl")
include("./extra_distributions.jl")
include("./gaussian_helpers.jl")

In [None]:
"""
    make_source_latent_model(source_params, audio_sr, steps)

Build a Gen generative function that samples the source-level latents of a single source
(GP hyperparameters and/or temporal `wait` / `dur_minus_min` hyperparameters), a number of
elements, and the element-level draws. It is used to simulate training data for amortized
proposals.

# Arguments
- `source_params`: nested `Dict` of hyperpriors. Keys read here are `"gp"`, `"tp"`,
  `"n_elements"` and `"duration_limit"`.
- `audio_sr`: audio sampling rate, forwarded to `get_gp_spectrotemporal`.
- `steps`: `Dict` of step sizes. `"t"` is the GP time step and `"min"` is added to
  `dur_minus_min` to get an element's duration. The whole dict is also forwarded to
  `get_gp_spectrotemporal`.

# Returns
The `@gen` function `source_latent_model(latents, scene_duration)`. Its return value is
`(tp_latents, gp_latents, tp_elems, gp_elems)`.

# Throws
`ArgumentError` (when the model runs) if GP latents are requested but the temporal
latents do not include both `:wait` and `:dur_minus_min`, since element timing needs both.
"""
function make_source_latent_model(source_params, audio_sr, steps)

    @gen function source_latent_model(latents, scene_duration)

        # Generates data for amortized inference of source-level latents.
        # - `:wait` and `:dur_minus_min` hyperparameters can be generated on their own.
        # - GP hyperparameters (ERB, 1-D amplitude, 2-D amplitude) can each get a separate
        #   amortized move. Sampling GP elements needs `wait` and `dur_minus_min` to define
        #   the time points at which the GP is evaluated.
        #
        # Format of `latents`, for example:
        #   Dict(:tp => :wait)                                   # one temporal latent
        #   Dict(:tp => [:wait, :dur_minus_min])                 # several temporal latents
        #   Dict(:gp => :erb, :source_type => "tone")            # GP (adds both temporal latents)

        ### SOURCE-LEVEL LATENTS
        ## Sample GP source-level hyperparameters if requested
        gp_latents = Dict()
        if :gp in keys(latents)

            gp_params = source_params["gp"]
            gp_type = latents[:gp]; gp_latents[gp_type] = Dict();
            source_type = latents[:source_type]

            # ERB GPs have their own hyperprior set. Amplitude GPs use the 2-D set for
            # noise/harmonic sources and the 1-D set otherwise.
            hyperpriors = gp_type == :erb ? gp_params["erb"] :
                ((source_type == "noise" || source_type == "harmonic") ? gp_params["amp"]["2D"] : gp_params["amp"]["1D"] )

            for latent in keys(hyperpriors)
                hyperprior = hyperpriors[latent]; syml = Symbol(latent)
                gp_latents[gp_type][syml] = @trace(hyperprior["dist"](hyperprior["args"]...), gp_type => syml)
            end

        end

        ## Sample temporal source-level hyperparameters
        tp_latents = Dict()
        if :tp in keys(latents)
            if latents[:tp] isa Symbol
                tp_latents[latents[:tp]] = Dict()
            else
                for a in latents[:tp]
                    tp_latents[a] = Dict()
                end
            end
        elseif :gp in keys(latents)
            # GP element times are built from both temporal latents.
            tp_latents[:wait] = Dict()
            tp_latents[:dur_minus_min] = Dict()
        end
        tp_params = source_params["tp"]
        for tp_type in keys(tp_latents)
            hyperpriors = tp_params[String(tp_type)]
            for latent in keys(hyperpriors)
                hyperprior = hyperpriors[latent]; syml = Symbol(latent)
                tp_latents[tp_type][syml] = @trace(hyperprior["dist"](hyperprior["args"]...), tp_type => syml)
            end
            # Gamma arguments (a, b) = (a, mu / a), so that a * b == mu.
            tp_latents[tp_type][:args] = (tp_latents[tp_type][:a], tp_latents[tp_type][:mu]/tp_latents[tp_type][:a])
        end

        ## Sample the number of elements
        ne_params = source_params["n_elements"]
        n_elements = ne_params["type"] == "max" ?
            @trace(uniform_discrete(1, ne_params["val"]),:n_elements) :
            @trace(geometric(ne_params["val"]),:n_elements)

        ### ELEMENT-LEVEL LATENTS
        # Accumulators for the per-element draws that proposals take as input.
        tp_elems = Dict( [ k => [] for k in keys(tp_latents)]... )
        gp_elems = Dict(); x_elems = [];
        if :gp in keys(latents)

            gp_type = latents[:gp]; source_type = latents[:source_type]

            gp_elems[gp_type] = []
            gp_elems[:t] = []
            if gp_type == :amp && (source_type == "noise" || source_type == "harmonic")
                gp_elems[:reshaped] = []
                gp_elems[:f] = []
                gp_elems[:tf] = []
            end

        end

        time_so_far = 0.0;
        for element_idx = 1:n_elements

            if :wait in keys(tp_elems)
                # The first element's wait is uniform over the scene; later waits are gamma.
                wait = element_idx == 1 ? @trace(uniform(0, scene_duration-steps["t"]), (:element,element_idx)=>:wait) :
                    @trace(gamma(tp_latents[:wait][:args]...), (:element,element_idx)=>:wait)
                push!(tp_elems[:wait], wait)
            end

            if :dur_minus_min in keys(tp_elems)
                dur_minus_min = @trace(truncated_gamma(tp_latents[:dur_minus_min][:args]..., source_params["duration_limit"]), (:element,element_idx)=>:dur_minus_min);
                push!(tp_elems[:dur_minus_min], dur_minus_min)
            end

            if :gp in keys(latents)

                # Element timing needs both temporal draws from this iteration. Fail with a
                # clear message instead of an UndefVarError on `wait` / `dur_minus_min`.
                if !(haskey(tp_elems, :wait) && haskey(tp_elems, :dur_minus_min))
                    throw(ArgumentError("GP latents require both :wait and :dur_minus_min temporal latents; got :tp => $(get(latents, :tp, nothing))"))
                end

                gp_type = latents[:gp]; source_type = latents[:source_type]
                duration = dur_minus_min + steps["min"]; onset = time_so_far + wait;

                if onset > scene_duration
                    break
                end

                time_so_far = onset + duration; element_timing = [onset, time_so_far]

                ## Define points at which the GPs should be sampled
                x = []; ts = [];
                if gp_type === :erb || (source_type == "tone" && gp_type === :amp)
                    x = get_element_gp_times(element_timing, steps["t"])
                elseif gp_type === :amp && (source_type == "noise"  || source_type == "harmonic")
                    x, ts, gp_elems[:f] = get_gp_spectrotemporal(element_timing, steps, audio_sr)
                end

                # First element: unconditional GP. Later elements are conditioned on the
                # points and values sampled for all earlier elements of this source.
                mu, cov_mat = element_idx == 1 ? get_mu_cov(x, gp_latents[gp_type]) :
                        get_cond_mu_cov(x, x_elems, gp_elems[gp_type], gp_latents[gp_type])
                element_gp = @trace(mvnormal(mu, cov_mat), (:element, element_idx) => gp_type)

                ## Save the element data
                append!(x_elems, x)
                if gp_type === :erb || (source_type == "tone" && gp_type === :amp)
                    append!(gp_elems[:t], x)
                    append!(gp_elems[gp_type], element_gp)
                elseif gp_type === :amp && (source_type == "noise"  || source_type == "harmonic")
                    append!(gp_elems[:t],ts)
                    append!(gp_elems[:tf],x)
                    append!(gp_elems[gp_type], element_gp)
                    # (n_frequencies x n_times) matrix for this element, concatenated along dim 2.
                    reshaped_elem = reshape(element_gp, (length(gp_elems[:f]), length(ts)))
                    if element_idx == 1
                        gp_elems[:reshaped] = reshaped_elem
                    else
                        gp_elems[:reshaped] = cat(gp_elems[:reshaped], reshaped_elem, dims=2)
                    end
                end

                if time_so_far > scene_duration
                    break
                end

            end

        end

        return tp_latents, gp_latents, tp_elems, gp_elems

    end

    return source_latent_model

end

"""
    make_data_generator(source_latent_model, latents;
                        min_scene_duration=0.5, max_scene_duration=2.5)

Return a zero-argument closure that produces `(inputs, constraints)` training pairs for an
amortized proposal.

Each call draws a scene duration uniformly in `[min_scene_duration, max_scene_duration]`,
rounded to 3 decimals. It then simulates `source_latent_model(latents, scene_duration)` and
packages the result as follows:
- Temporal case (`latents[:tp]` is a single `Symbol`): inputs are `(tp_elems,)`. The sampled
  `:mu` and `:a` hyperparameters are constrained at the top-level addresses `:mu` and `:a`,
  which are the addresses the trainable temporal proposal traces.
- GP case: inputs are `(gp_elems, scene_duration)`. Every hyperparameter the model sampled
  for the GP type is constrained at `gp_type => name`.

Throws `ArgumentError` if `latents[:tp]` is present but is not a single `Symbol`.
"""
function make_data_generator(source_latent_model, latents;
                             min_scene_duration=0.5, max_scene_duration=2.5)

    if haskey(latents, :tp) && !(latents[:tp] isa Symbol)
        throw(ArgumentError("make_data_generator expects a single temporal latent Symbol, got $(latents[:tp])"))
    end

    function data_generator()

        scene_duration = round((max_scene_duration-min_scene_duration)*rand() + min_scene_duration,digits=3)
        trace = simulate(source_latent_model, (latents,scene_duration))
        tp_latents, gp_latents, tp_elems, gp_elems = get_retval(trace)

        constraints = choicemap()
        if haskey(latents, :tp)
            tp_latent = latents[:tp]
            # The trainable temporal proposal traces `:mu` / `:a` at its own top level, so
            # the constraints must use those addresses (not `tp_latent => ...`) to be visited.
            constraints[:mu] = trace[tp_latent => :mu]
            constraints[:a] = trace[tp_latent => :a]

        elseif haskey(latents, :gp)

            gp_type = latents[:gp]
            # Constrain exactly the hyperparameters the model sampled for this GP type.
            # This keeps the 1D/2D hyperprior choice consistent with the model.
            for name in keys(gp_latents[gp_type])
                constraints[gp_type => name] = trace[gp_type => name]
            end

        end
        inputs = haskey(latents, :tp) ? (tp_elems,) : (gp_elems,scene_duration,)

        return (inputs, constraints)

    end

    return data_generator

end;

"""
    create_trainable_tp_proposal(latent, source_params)

Build a trainable Gen proposal for the source-level hyperparameters `:mu` and `:a` of one
temporal latent (`:wait` or `:dur_minus_min`).

The returned `@gen` function takes `tp_elems`, a `Dict` mapping temporal latent symbols to
vectors of element-level draws. It traces:
- `:mu`: drawn from a Gamma when there is data to regress on. The Gamma's mean and shape
  are log-linear in summary statistics of the element draws. Otherwise `:mu` is drawn from
  the prior `source_params["tp"][string(latent)]["mu"]`.
- `:a`: always drawn from the prior `source_params["tp"][string(latent)]["a"]`.

Trainable parameters are `B_mean_mu` (length 2) and `B_shape_mu` (length 4).
"""
function create_trainable_tp_proposal(latent,source_params)

    @gen function trainable_tp_proposal(tp_elems)

        # B_mean_mu maps [log mean, 1] to the log proposal mean.
        # B_shape_mu maps [log mean, log std, log n_elements, 1] to the log proposal shape.
        @param B_mean_mu::Vector{Float64}
        @param B_shape_mu::Vector{Float64}

        sp=source_params["tp"][string(latent)]
        n_elements = length(tp_elems[latent])
        # Use the data-driven Gamma only when at least one usable draw exists. The first
        # wait of a source is drawn uniformly rather than from the gamma, so it is
        # dropped below. An empty element list has no statistics to regress on.
        use_data = (latent == :wait && n_elements > 1) ||
                   (latent == :dur_minus_min && n_elements > 0)
        if use_data

            # Summary statistics of the draws feed the regression.
            W = latent == :wait ? tp_elems[latent][2:end] : tp_elems[latent]
            m = mean(W)
            # Log std of the draws. NOTE(review): with a single draw this falls back to the
            # second argument of the `a` hyperprior, and with identical draws to log(0.001).
            # Confirm both choices.
            ls = (length(W) > 1 && !all(y->y==W[1],W)) ? log(std(W)) :
                        (length(W) == 1) ? sp["a"]["args"][2] : log(0.001)

            X_mu = [log(m), 1.0]
            X_shape = [log(m), ls, log(n_elements), 1.0]

            mean_mu_estimate = exp( dot(X_mu, B_mean_mu) )
            shape_mu = exp( dot(X_shape, B_shape_mu) )

            # Gamma(shape, scale) with scale = mean / shape.
            @trace(gamma(shape_mu, mean_mu_estimate/shape_mu), :mu)

        else

            @trace(sp["mu"]["dist"](sp["mu"]["args"]...), :mu)

        end
        # The shape hyperparameter `a` is always proposed from its prior.
        @trace(sp["a"]["dist"](sp["a"]["args"]...), :a)

        return nothing

    end

    return trainable_tp_proposal

end


In [None]:
"""
    create_hmc_proposal(hyperpriors; n_prior_samples=1000, n_hmc_samples=5000)

Build a Gen proposal over the hyperparameters of a 1-D GP with a constant mean and a
squared-exponential kernel.

For each of `"mu"`, `"scale"`, `"sigma"` and `"epsilon"`, draw `n_prior_samples` samples
from `hyperpriors[key]["dist"](hyperpriors[key]["args"]...)`. The `"scale"`, `"sigma"` and
`"epsilon"` samples are log-transformed first. Record `[mean, std]` of the samples and print
it. These moments become Normal priors on the GaussianProcesses.jl parameters. Other keys
in `hyperpriors` are ignored.

The returned `@gen` function `hmc_based_proposal(x, y)` does the following:
- Fits the GP to inputs `x` and observations `y`.
- Runs `GaussianProcesses.ess` for `n_hmc_samples` iterations with a fixed seed.
- Traces `:gpparams` from a multivariate normal matched to the chain's mean and covariance.
  The order of `:gpparams` is [log epsilon, mean, log scale, log sigma].

Throws `ArgumentError` if any of the four required hyperprior keys is missing.
"""
function create_hmc_proposal(hyperpriors; n_prior_samples=1000, n_hmc_samples=5000) #n_chains

    # Hyperparameters whose prior samples are summarized on the log scale.
    log_keys = ("epsilon", "scale", "sigma")
    required_keys = ("mu", log_keys...)
    absent = [k for k in required_keys if !haskey(hyperpriors, k)]
    isempty(absent) || throw(ArgumentError(
        "create_hmc_proposal needs hyperpriors for $(join(absent, ", ")); got keys $(collect(keys(hyperpriors)))"))

    prior_args = Dict();
    for sv in keys(hyperpriors)
        # Keys this 1-D SE-kernel model does not use are skipped.
        sv in required_keys || continue
        hp = hyperpriors[sv]
        samples = [ hp["dist"](hp["args"]...) for _ in 1:n_prior_samples ]
        x = sv in log_keys ? log.(samples) : samples
        prior_args[sv] = [ mean(x), std(x) ];
        println("$sv: ", prior_args[sv])
    end

    @gen function hmc_based_proposal(x, y)

        # GP with constant mean and SE kernel, with Normal priors from the moments above.
        mConstant = GaussianProcesses.MeanConst(prior_args["mu"][1])
        kern = GaussianProcesses.SE(0.0, 0.0)
        logObsNoise = -1.0
        gp_ess = GaussianProcesses.GP(Float64.(x),Float64.(y),mConstant,kern, logObsNoise)

        GaussianProcesses.set_priors!(gp_ess.mean, [Distributions.Normal(prior_args["mu"]...)])
        GaussianProcesses.set_priors!(gp_ess.kernel, [Distributions.Normal(prior_args["scale"]...), Distributions.Normal(prior_args["sigma"]...)])
        GaussianProcesses.set_priors!(gp_ess.logNoise, [Distributions.Normal(prior_args["epsilon"]...)])

        # The fixed seed makes the sampler chain, and hence this proposal's distribution,
        # deterministic for given (x, y).
        rng = MersenneTwister(2143)
        chain = GaussianProcesses.ess(rng, gp_ess, nIter=n_hmc_samples)

        # Rows of `chain` are parameters: [log epsilon, mean, SE log scale, SE log sigma].
        gp_params_mean = vec(mean(chain,dims=2))
        gp_params_cov = cov(chain,dims=2)
        # make_posdef! adjusts gp_params_cov in place; its return value is not needed.
        GaussianProcesses.make_posdef!(gp_params_cov)

        gp_params = @trace(mvnormal(gp_params_mean, gp_params_cov), :gpparams)
        return gp_params

    end

    return hmc_based_proposal
end


"""
    modtrain!(gen_fn, data_generator, update; num_epoch=1, epoch_size=1, num_minibatch=1,
              minibatch_size=1, evaluation_size=epoch_size,
              eval_inputs_and_constraints=[], verbose=false)

Train the trainable parameters of `gen_fn` with the same epoch/minibatch loop as
`Gen.train!`, optionally reusing a fixed evaluation set.

`data_generator()` must return `(inputs::Tuple, constraints::ChoiceMap)`. Each epoch:
- draws `epoch_size` fresh examples;
- takes `num_minibatch` random minibatches of `minibatch_size` of them, drawn without
  replacement;
- accumulates parameter gradients from `generate(gen_fn, inputs, constraints)` traces;
- applies `update` once per minibatch.

If `eval_inputs_and_constraints` is empty, `evaluation_size` examples are generated and
pushed into it, so a caller-supplied empty vector is filled in place. Otherwise its length
overrides `evaluation_size`.

Returns a `Vector{Float64}` of length `num_epoch`. Each entry is the mean `generate` weight
on the evaluation set after that epoch.

Throws `ArgumentError` if `minibatch_size > epoch_size` (with `num_minibatch > 0`). Raises
an error if the average evaluation score is NaN.
"""
function modtrain!(gen_fn::GenerativeFunction, data_generator::Function,
                update::ParamUpdate;
                num_epoch=1, epoch_size=1, num_minibatch=1, minibatch_size=1,
                evaluation_size=epoch_size, eval_inputs_and_constraints=[], verbose=false)

    # Minibatches are sampled without replacement from a single epoch's examples.
    if num_minibatch > 0 && minibatch_size > epoch_size
        throw(ArgumentError("minibatch_size ($minibatch_size) must not exceed epoch_size ($epoch_size)"))
    end

    history = Vector{Float64}(undef, num_epoch)

    if length(eval_inputs_and_constraints) == 0
        println("Making evaluation dataset:")
        for i = 1:evaluation_size
            print("$i ")
            ic = data_generator()
            push!(eval_inputs_and_constraints, ic)
        end
        println()
    else
        evaluation_size = length(eval_inputs_and_constraints)
        println("Using eval set of size $evaluation_size")
    end

    for epoch=1:num_epoch

        s1=time()
        # generate data for epoch
        if verbose
            println("epoch $epoch: generating $epoch_size training examples...")
        end
        epoch_inputs = Vector{Tuple}(undef, epoch_size)
        epoch_choice_maps = Vector{ChoiceMap}(undef, epoch_size)
        for i=1:epoch_size
            (epoch_inputs[i], epoch_choice_maps[i]) = data_generator()
        end
        s2=time()-s1

        # train on epoch data
        s1=time()
        if verbose
            println("Time to generate data: $s2")
            println("epoch $epoch: training using $num_minibatch minibatches of size $minibatch_size...")
        end
        for minibatch=1:num_minibatch
            permuted = Random.randperm(epoch_size)
            minibatch_idx = permuted[1:minibatch_size]
            minibatch_inputs = epoch_inputs[minibatch_idx]
            minibatch_choice_maps = epoch_choice_maps[minibatch_idx]
            for (inputs, constraints) in zip(minibatch_inputs, minibatch_choice_maps)
                (trace, _) = generate(gen_fn, inputs, constraints)
                retval_grad = accepts_output_grad(gen_fn) ? zero(get_retval(trace)) : nothing
                accumulate_param_gradients!(trace, retval_grad)
            end
            apply!(update)
        end
        s2=time()-s1

        # evaluate score on held out data
        s1=time()
        if verbose
            println("Time to train: $s2")
            println("epoch $epoch: evaluating on $evaluation_size examples...")
        end
        avg_score = 0.
        for i=1:evaluation_size
            (_, weight) = generate(gen_fn, eval_inputs_and_constraints[i][1], eval_inputs_and_constraints[i][2])
            avg_score += weight
        end
        avg_score /= evaluation_size
        # A real runtime check (unlike @assert, it cannot be compiled away).
        isnan(avg_score) && error("modtrain!: average evaluation score is NaN at epoch $epoch")
        history[epoch] = avg_score
        s2=time()-s1

        if verbose
            println("Time to evaluate: $s2")
            println("epoch $epoch: est. objective value: $avg_score")
        end

    end
    return history
end

"""
    train_tp_proposal(latent, source_params, audio_sr, steps; save_loc="./",
                      num_epoch=1, epoch_size=1, num_minibatch=1, minibatch_size=1,
                      evaluation_size=100, step_size=1e-4)

Train the trainable temporal proposal for `latent` (`:wait` or `:dur_minus_min`) on data
simulated from the single-source latent model, then plot the per-epoch evaluation scores.
The learned weights are written as JSON to `string(save_loc, latent, "proposal_weights.json")`.

The keywords `num_epoch`, `epoch_size`, `num_minibatch`, `minibatch_size` and
`evaluation_size` are forwarded to `modtrain!`. `step_size` is the fixed gradient-descent
step size.

Returns the trained generative function. Throws `ArgumentError` for any other `latent`.
"""
function train_tp_proposal(latent, source_params, audio_sr, steps; save_loc="./",
                           num_epoch=1, epoch_size=1, num_minibatch=1, minibatch_size=1,
                           evaluation_size=100, step_size=1e-4)

    latent in (:wait, :dur_minus_min) ||
        throw(ArgumentError("latent must be :wait or :dur_minus_min, got $(repr(latent))"))

    source_latent_model = make_source_latent_model(source_params, audio_sr, steps);
    latents = Dict(:tp => latent)

    println("Defining data generator.")
    data_generator = make_data_generator(source_latent_model, latents)
    evalset = [data_generator() for i = 1:evaluation_size];

    trainable_proposal = create_trainable_tp_proposal(latent,source_params)
    weights = [:B_mean_mu,:B_shape_mu]

    println("Initializing parameters.")
    Gen.init_param!(trainable_proposal, :B_mean_mu, [1.0, 0.0])
    Gen.init_param!(trainable_proposal, :B_shape_mu, [2.0, -2.0, 0.1, 0.0])

    update = Gen.ParamUpdate(Gen.FixedStepGradientDescent(step_size),
        trainable_proposal);

    # Train the proposal against the fixed evaluation set.
    scores = modtrain!(trainable_proposal, data_generator,
            update,
            num_epoch=num_epoch, epoch_size=epoch_size, num_minibatch=num_minibatch,
            minibatch_size=minibatch_size, evaluation_size=evaluation_size,
            eval_inputs_and_constraints=evalset,
            verbose=false);
    plot(scores)

    # Save the learned weights as {"B_mean_mu": [...], "B_shape_mu": [...]}.
    var_loc = string(save_loc,string(latent),"proposal_")
    trained_weights_dict = Dict(string(w) => Gen.get_param(trainable_proposal, w) for w in weights)
    open(string(var_loc, "weights.json"),"w") do f
        JSON.print(f, trained_weights_dict)
    end

    return trainable_proposal

end

In [None]:
# Load the shared parameter set and build the single-source latent model used below.
source_params, steps, gtg_params, obs_noise = include("./base_params.jl")
audio_sr = 20_000;
source_latent_model = make_source_latent_model(source_params, audio_sr, steps);
proposals = Dict{Any,Any}()

## Looking at proposal outputs

In [None]:
# Train one temporal proposal per latent and store each under its latent symbol.
for tp_latent in (:dur_minus_min, :wait)
    proposals[tp_latent] = train_tp_proposal(tp_latent, source_params, audio_sr, steps)
end
println(keys(proposals))

In [None]:
# Show the learned regression weights of each trained temporal proposal.
for tp_latent in (:wait, :dur_minus_min)
    trained_fn = proposals[tp_latent]
    learned = [string(name) => Gen.get_param(trained_fn, name) for name in (:B_mean_mu, :B_shape_mu)]
    println(tp_latent)
    println(learned)
end

In [None]:
# Compare the true `mu` of each temporal latent (la: actual) with the `mu` drawn by the
# trained proposal (ta: proposed), over 100 simulated 1-second scenes per latent.
la = Dict(:wait=>[], :dur_minus_min=>[])
ta = Dict(:wait=>[], :dur_minus_min=>[])
for latent in [:wait, :dur_minus_min]
    println(latent)
    for _ in 1:100
        latents = Dict(:tp => latent)
        scene_duration = 1.0
        data_trace = simulate(source_latent_model, (latents,scene_duration))
        tp_latents, gp_latents, tp_elems, gp_elems = get_retval(data_trace)
        push!(la[latent], tp_latents[latent][:mu])

        trace = first(Gen.generate(proposals[latent], (tp_elems,)))
        # Flag proposals whose mu collapsed to (almost) zero.
        trace[:mu] < 0.00001 && println("~~~~~~~~~~~~~~~~~~~~~~~~~~~~`")
        push!(ta[latent], trace[:mu])
    end
end

In [None]:
# Scatter actual vs. proposed `mu` for each temporal latent, with the identity line drawn.
figure(figsize=(8,4))
for (panel, (key, panel_title)) in enumerate(((:wait, "Wait mu"), (:dur_minus_min, "dur_minus_min mu")))
    subplot(1, 2, panel)
    scatter(la[key], ta[key])
    plot(0:0.1:3, 0:0.1:3)
    ylim([0, 0.1])
    xlim([0, 0.1])
    xlabel("Actual"); ylabel("Predicted")
    title(panel_title)
end

In [None]:
# Combined temporal proposal. Runs the trained `:dur_minus_min` and `:wait` proposals from
# the global `proposals` Dict on the same `tp_elems`. Their choices are namespaced under
# `:dest => :dur_minus_min` and `:dest => :wait`.
@gen function temporal_proposals(tp_elems)
    @trace(proposals[:dur_minus_min](tp_elems), :dest => :dur_minus_min)
    @trace(proposals[:wait](tp_elems), :dest => :wait)
    # NOTE(review): `:x` is an extra standard-normal choice with no visible consumer;
    # presumably a placeholder. Confirm before removing.
    @trace(normal(0,1),:x)
end

In [None]:
# Simulate a source with both temporal latents and run the combined temporal proposal on it.
latents = Dict(:tp => [:wait, :dur_minus_min])
scene_duration = 1.0
data_trace = simulate(source_latent_model, (latents, scene_duration))
tp_latents, gp_latents, tp_elems, gp_elems = get_retval(data_trace)
trace = first(Gen.generate(temporal_proposals, (tp_elems,)));
println("get_choices(trace)--\n")
println(get_choices(trace))

In [None]:
# Build GP-hyperparameter proposals for the ERB and 1-D amplitude GPs.
for (label, key, gp_hyperpriors) in (("Erb:", :erb, source_params["gp"]["erb"]),
                                     ("Amp:", :amp, source_params["gp"]["amp"]["1D"]))
    println(label)
    proposals[key] = create_hmc_proposal(gp_hyperpriors)
end

In [None]:
# Simulate a tone source with an ERB GP and run the ERB proposal on its element points.
latents = Dict(:tp => [:wait, :dur_minus_min], :gp => :erb, :source_type => "tone")
scene_duration = 1.0
data_trace = simulate(source_latent_model, (latents, scene_duration))
tp_latents, gp_latents, tp_elems, gp_elems = get_retval(data_trace)
println("gp_elems: ", gp_elems)
trace = first(Gen.generate(proposals[:erb], (gp_elems[:t], gp_elems[:erb])));
println("get_choices(trace)--\n")
println(get_choices(trace))

In [None]:
# Combined source-level proposal. Runs the trained temporal proposals on `tp_elems` and the
# ERB GP proposal on `(gp_elems[:t], gp_elems[:erb])`, all from the global `proposals`
# Dict. Each proposal's choices are namespaced under `:dest => <latent>`.
@gen function source_proposals(tp_elems,gp_elems)
    @trace(proposals[:dur_minus_min](tp_elems), :dest => :dur_minus_min)
    @trace(proposals[:wait](tp_elems), :dest => :wait)
    @trace(proposals[:erb](gp_elems[:t],gp_elems[:erb]), :dest => :erb)
    # NOTE(review): `:extra` is a standard-normal choice with no visible consumer;
    # presumably a placeholder. Confirm before removing.
    @trace(normal(0,1),:extra)
end

In [None]:
# Simulate a tone source with an ERB GP and run the combined source-level proposal on it.
latents = Dict(:tp => [:wait, :dur_minus_min], :gp => :erb, :source_type => "tone")
scene_duration = 1.0
data_trace = simulate(source_latent_model, (latents, scene_duration))
tp_latents, gp_latents, tp_elems, gp_elems = get_retval(data_trace)
trace = first(Gen.generate(source_proposals, (tp_elems, gp_elems)));
println("get_choices(trace)--\n")
println(get_choices(trace))

In [None]:
# Report the proposed ERB GP hyperparameters.
# `:gpparams` holds [log epsilon, mean, log scale, log sigma].
println("ERB params:")
let erb_gp = trace[:dest => :erb => :gpparams]
    println("mu: ", erb_gp[2])
    println("scale: ", exp(erb_gp[3]))
    println("sigma: ", exp(erb_gp[4]))
    println("epsilon: ", exp(erb_gp[1]))
end

## Switch proposal

In [None]:
# Gather the proposals used by the switch move: freshly trained temporal proposals plus
# GP-hyperparameter proposals for the ERB and 1-D amplitude GPs.
custom_proposals = Dict()
for tp_latent in (:wait, :dur_minus_min)
    custom_proposals[tp_latent] = train_tp_proposal(tp_latent, source_params, audio_sr, steps)
end
custom_proposals[:erb] = create_hmc_proposal(source_params["gp"]["erb"]);
custom_proposals[:amp1D] = create_hmc_proposal(source_params["gp"]["amp"]["1D"]);
custom_proposals

In [None]:
# Forward proposal for a "switch" move; it is paired with `rewrite_switch_involution`.
# It picks one element of an existing source (`:origin`, `:element_idx`). It then proposes
# moving that element into a gap of another existing source of the same source type, or
# into a new source (`:new_source`, `:destination`). Finally it re-proposes the source-level
# latents of the destination (under `:dest`) and, when the origin keeps at least one
# element, of the origin (under `:orig`).
#
# Arguments
# - `trace`: model trace whose args unpack as (source_params, scene_duration, wts, steps,
#   audio_sr, obs_noise, gtg_params). Only `source_params` and `steps` are used here.
# - `custom_proposals`: Dict of proposal generative functions. It is looked up with `:wait`,
#   `:dur_minus_min` and `:erb`, plus `:amp1D` for tone sources and `:amp2D` otherwise.
#
# Returns `"abort"` when the element fits in no other source and cannot start a new one.
# Otherwise returns `(which_spot, all_source_timings)`. `which_spot[i]` is the 1-based gap
# of source `i` the element would occupy. It is 0 when source `i` is the origin, has a
# different source type, or has no gap that fits.
@gen function rewrite_switch_randomness(trace, custom_proposals)

    # NOTE(review): max_elements, scene_duration, wts, audio_sr, obs_noise and gtg_params
    # are unpacked/computed but not used in this function.
    source_params, scene_duration, wts, steps, audio_sr, obs_noise, gtg_params = get_args(trace)
    max_elements = source_params["n_elements"]["val"]
    old_choices = get_choices(trace)
    #onset/offset information for each element in each source: all source timings
    all_source_timings = []
    old_n_sources = old_choices[:n_sources]
    for i = 1:old_n_sources
        #list of lists of times
        #[ (element 1)[onset, offset], (element2)[onset, offset], ... ]
        old_abs_timings = absolute_timing(get_submap(old_choices, :source => i), steps["min"])
        push!(all_source_timings,old_abs_timings)
    end


    # Uniformly pick the origin source and the element to move.
    origin = @trace(uniform_discrete(1, old_n_sources), :origin)
    old_n_elements = old_choices[:source => origin => :n_elements]
    element_idx = @trace(uniform_discrete(1,old_n_elements),:element_idx)
    onset = all_source_timings[origin][element_idx][1];
    offset = all_source_timings[origin][element_idx][2];

    #Find the sources into which an element can be switched.
    #source_switch[i] is 1 if the element fits in source i (else 0); which_spot[i] is the gap.
    source_switch = []; which_spot = [];
    for i = 1:old_choices[:n_sources]
        # Only can switch things between streams of the same source_type. may not want this to be true.
        if i == origin || (old_choices[:source => i => :source_type] != old_choices[:source => origin => :source_type])
            append!(source_switch, 0); append!(which_spot, 0)
        else
            source_nt = old_choices[:source => i => :n_elements]
            timings = all_source_timings[i];
            # Gap j lies just before element j. j == 1 runs from time 0 to the first onset;
            # j == source_nt + 1 runs after the last offset (no upper bound is checked).
            # The first gap that strictly contains [onset, offset] is taken.
            for j = 1:source_nt + 1

                #switch into spot before the first element
                if j == 1
                    fits = (0 < onset) && (offset < timings[j][1])
                elseif 1 < j <= source_nt
                    fits = (timings[j-1][2] < onset) && (offset < timings[j][1])
                elseif j == source_nt + 1
                    fits = (timings[j - 1][2] < onset)
                end

                if fits
                    append!(source_switch, 1)
                    append!(which_spot, j)
                    break
                elseif j == source_nt + 1
                    append!(source_switch, 0)
                    append!(which_spot, 0)
                end

            end
        end
    end

    #Decide whether to move the element into an existing source
    #Or make a new source, where it will be the only element.
    #A new source is only allowed if the origin keeps at least one element and the
    #n_sources cap has not been reached.
    switch_to_existing_source = sum(source_switch)
    switch_to_new_source = (old_n_elements > 1 && old_n_sources < source_params["n_sources"]["val"]) ? 1 : 0 ## currently hard coded that we're using a uniform distribution
    switch_weights = [switch_to_existing_source, switch_to_new_source]
    if sum(switch_weights) == 0
        return "abort"
    end
    # P(new source) = switch_to_new_source / (number of fitting sources + switch_to_new_source)
    ps = switch_weights./sum(switch_weights)
    new_source = @trace(bernoulli(ps[2]), :new_source)
    #Decide the idx of the destination source
    #If it's a new source, it can go before any of the old sources or at the end
    #If it's an old source, you need to choose from the ones in source_switch
    destination_ps = new_source ? fill(1/(old_choices[:n_sources] + 1), old_choices[:n_sources] + 1) : source_switch./sum(source_switch)
    destination = @trace(categorical(destination_ps), :destination)

    ## change the source level variables to increase probabilities of acceptance
    source_type = source_params["types"][old_choices[:source => origin => :source_type]] #because we retain the source_type for switches
    tp_types = collect(keys(source_params["tp"]))
    gp_types = source_type == "tone" || source_type == "harmonic" ? ["erb", "amp"] : ["amp"]
    source_vars = append!(tp_types, gp_types)

    #tp: wait & dur_minus_min --> a (variability), mu (mean)
    #gp: erb --> mu, scale, sigma, noise
    #    amp --> mu, scale, sigma, noise OR mu, scale_t, scale_f, sigma, noise

    # Element-level values of the destination source after the move, in time order, with the
    # moved element inserted at gap which_spot[destination].
    # NOTE(review): scalar waits are copied as-is, not recomputed for their new neighbours.
    dest_elements_list = Dict(:t=>[])
    for source_var in source_vars
        #Collecting the description of all the elements in the destination stream
        source_var_sym = Symbol(source_var)
        dest_elements_list[source_var_sym] = []
        compilefunc = source_var == "erb" || source_var == "amp" ? append! : push! #push for waits & dur_minus_min, they're scalars
        if ~new_source
            k = 1
            for j = 1:old_choices[:source => destination => :n_elements] + 1
                if j == which_spot[destination] ##Get the elements in order.
                    compilefunc(dest_elements_list[source_var_sym], old_choices[:source => origin => (:element, element_idx) => source_var_sym])
                else
                    compilefunc(dest_elements_list[source_var_sym], old_choices[:source => destination => (:element, k) => source_var_sym])
                    k += 1
                end
            end
        else
            compilefunc(dest_elements_list[source_var_sym], old_choices[:source => origin => (:element, element_idx) => source_var_sym])
        end
    end

    #Get a time vector (GP evaluation times for each element, in the same order as above)
    if ~new_source
        k = 1
        for j = 1:old_choices[:source => destination => :n_elements] + 1
            if j == which_spot[destination] ##Get the elements in order.
                append!(dest_elements_list[:t], get_element_gp_times(all_source_timings[origin][element_idx], steps["t"]))
            else
                append!(dest_elements_list[:t], get_element_gp_times(all_source_timings[destination][k], steps["t"]))
                k += 1
            end
        end
    else
        append!(dest_elements_list[:t], get_element_gp_times(all_source_timings[origin][element_idx], steps["t"]))
    end

    # Re-propose the destination's source-level latents from its post-move elements.
    for source_var in source_vars
        source_var_sym = Symbol(source_var)
        if source_var == "wait" || source_var == "dur_minus_min"
            @trace(custom_proposals[source_var_sym](dest_elements_list), :dest => source_var_sym)
        elseif source_var == "erb" || source_var == "amp"
            proposal_key = source_var == "erb" ? :erb : (source_type == "tone" ? :amp1D : :amp2D)
            @trace(custom_proposals[proposal_key](dest_elements_list[:t],dest_elements_list[source_var_sym]), :dest => source_var_sym)
        end
    end
    #Only works for tones right now...

    if old_choices[:source => origin => :n_elements] > 1
        #there must be an still existing origin source
        origin_elements_list = Dict(:t=>[])
        for source_var in source_vars
            #Collecting the description of all the remaining elements in the origin stream
            source_var_sym = Symbol(source_var)
            origin_elements_list[source_var_sym] = []
            compilefunc = source_var == "erb" || source_var == "amp" ? append! : push! #push for waits & dur_minus_min, they're scalars
            for j = [jj for jj in 1:old_choices[:source => origin => :n_elements] if jj != element_idx]
                compilefunc(origin_elements_list[source_var_sym], old_choices[:source => origin => (:element, j) => source_var_sym])
            end
        end
        for j = [jj for jj in 1:old_choices[:source => origin => :n_elements] if jj != element_idx]
            append!(origin_elements_list[:t], get_element_gp_times(all_source_timings[origin][j], steps["t"]))
        end

        #Only works for tones!!
        # Re-propose the origin's source-level latents from its remaining elements.
        for source_var in source_vars
            source_var_sym = Symbol(source_var)
            if source_var == "wait" || source_var == "dur_minus_min"
                @trace(custom_proposals[source_var_sym](origin_elements_list), :orig => source_var_sym)
            elseif source_var == "erb" || source_var == "amp"
                proposal_key = source_var == "erb" ? :erb : (source_type == "tone" ? :amp1D : :amp2D)
                @trace(custom_proposals[proposal_key](origin_elements_list[:t],origin_elements_list[source_var_sym]), :orig => source_var_sym)
            end
        end

    end

    return which_spot, all_source_timings

end

In [None]:
# Source-level latent names (per GP / temporal attribute) and per-element latent names for
# the source type stored at index `source_type_idx` of `source_params["types"]`.
function _switch_attribute_names(source_params, source_type_idx)
    source_type = source_params["types"][source_type_idx]
    gp_params = source_params["gp"]
    gp_hyperpriors = if source_type == "tone"
        Dict(:erb => gp_params["erb"], :amp => gp_params["amp"]["1D"])
    elseif source_type == "harmonic"
        Dict(:erb => gp_params["erb"], :amp => gp_params["amp"]["2D"])
    else # noise
        Dict(:amp => gp_params["amp"]["2D"])
    end
    gp_types = collect(keys(gp_hyperpriors))
    tp_types = [:wait, :dur_minus_min]

    source_attributes = Dict{Symbol,Vector{Symbol}}()
    for (k, v) in gp_hyperpriors
        source_attributes[k] = [Symbol(kv) for kv in keys(v)]
    end
    for k in tp_types
        source_attributes[k] = [Symbol(kv) for kv in keys(source_params["tp"][String(k)])]
    end

    # vcat builds fresh vectors; append! would mutate `tp_types` and alias the results to it.
    element_attributes = vcat(tp_types, gp_types)
    element_attributes_no_wait = vcat([:dur_minus_min], gp_types)
    return source_attributes, element_attributes, element_attributes_no_wait
end

# Write the source-level latents proposed under `prefix` (:orig or :dest) of `fwd_choices`
# onto source `idx` of `new_choices`. GP hyperparameters arrive packed as
# [log(epsilon), mu, log(scale), log(sigma)] (1D GPs only).
function _switch_write_proposed_latents!(new_choices, idx, fwd_choices, prefix, source_attributes)
    for ks in keys(source_attributes)
        if ks == :erb || ks == :amp
            gpparams = fwd_choices[prefix => ks => :gpparams]
            new_choices[:source => idx => ks => :mu] = gpparams[2]
            new_choices[:source => idx => ks => :sigma] = exp(gpparams[4])
            new_choices[:source => idx => ks => :epsilon] = exp(gpparams[1])
            new_choices[:source => idx => ks => :scale] = exp(gpparams[3])
        else # temporal attributes (wait, dur_minus_min)
            for a in source_attributes[ks]
                new_choices[:source => idx => ks => a] = fwd_choices[prefix => ks => a]
            end
        end
    end
    return new_choices
end

# Record the current source-level latents of old source `idx` in `bwd_choices` under
# `prefix`, packed exactly as `_switch_write_proposed_latents!` unpacks them.
function _switch_write_backward_latents!(bwd_choices, prefix, old_choices, idx, source_attributes)
    for ks in keys(source_attributes)
        if ks == :erb || ks == :amp
            bwd_choices[prefix => ks => :gpparams] = [
                log(old_choices[:source => idx => ks => :epsilon]),
                old_choices[:source => idx => ks => :mu],
                log(old_choices[:source => idx => ks => :scale]),
                log(old_choices[:source => idx => ks => :sigma]),
            ]
        else # temporal attributes (wait, dur_minus_min)
            for a in source_attributes[ks]
                bwd_choices[prefix => ks => a] = old_choices[:source => idx => ks => a]
            end
        end
    end
    return bwd_choices
end

# Rebuild a destination source (index `old_idx` in the old trace, `new_idx` in the new one)
# with `switch_element` inserted at position `gap`. The `old_nt` existing elements from `gap`
# on move up one index; the first of them gets wait `after_wait`. With `copy_prefix`, the
# elements before `gap` are copied explicitly, which is needed when the source changes index.
function _switch_insert_element!(new_choices, new_idx, old_choices, old_idx, gap, old_nt,
                                 switch_element, after_wait, element_attributes,
                                 element_attributes_no_wait; copy_prefix::Bool)
    for a in element_attributes
        new_choices[:source => new_idx => (:element, gap) => a] = switch_element[a]
    end
    if copy_prefix
        for j in 1:(gap - 1)
            set_submap!(new_choices, :source => new_idx => (:element, j),
                        get_submap(old_choices, :source => old_idx => (:element, j)))
        end
    end
    for new_j in (gap + 1):(old_nt + 1)
        old_j = new_j - 1
        shifted_wait = if new_j == gap + 1
            after_wait
        else
            old_choices[:source => old_idx => (:element, old_j) => :wait]
        end
        new_choices[:source => new_idx => (:element, new_j) => :wait] = shifted_wait
        for a in element_attributes_no_wait
            new_choices[:source => new_idx => (:element, new_j) => a] =
                old_choices[:source => old_idx => (:element, old_j) => a]
        end
    end
    return new_choices
end

# Rebuild an origin source (index `old_idx` in the old trace, `new_idx` in the new one) with
# element `removed` taken out. The later elements move down one index; the first of them
# gets wait `after_wait`. With `copy_prefix`, the elements before `removed` are copied
# explicitly, which is needed when the source changes index.
function _switch_remove_element!(new_choices, new_idx, old_choices, old_idx, removed, old_nt,
                                 after_wait, element_attributes_no_wait; copy_prefix::Bool)
    if copy_prefix
        for j in 1:(removed - 1)
            set_submap!(new_choices, :source => new_idx => (:element, j),
                        get_submap(old_choices, :source => old_idx => (:element, j)))
        end
    end
    for old_j in (removed + 1):old_nt
        new_j = old_j - 1
        shifted_wait = if old_j == removed + 1
            after_wait
        else
            old_choices[:source => old_idx => (:element, old_j) => :wait]
        end
        new_choices[:source => new_idx => (:element, new_j) => :wait] = shifted_wait
        for a in element_attributes_no_wait
            new_choices[:source => new_idx => (:element, new_j) => a] =
                old_choices[:source => old_idx => (:element, old_j) => a]
        end
    end
    return new_choices
end

"""
    rewrite_switch_involution(trace, fwd_choices, fwd_ret, proposal_args)

Involution for the "switch" MH move. The move takes element `fwd_choices[:element_idx]` out
of source `fwd_choices[:origin]` and puts it into source `fwd_choices[:destination]`, which
is either an existing source or, when `fwd_choices[:new_source]`, a brand-new source.

`fwd_ret` is the randomness' return value: `(which_gaps, all_source_timings)`, or `"abort"`.
`all_source_timings[s][j]` holds the absolute `(onset, offset)` of element `j` of source `s`.
`which_gaps[s]` is the position the element takes when inserted into source `s`.

Three cases are handled:
- The origin held a single element. The origin source is deleted, the sources after it
  shift down one index, and the element joins an existing destination.
- `new_source`. A one-element source is created at `destination_idx`, the sources from
  there on shift up one index, and the origin keeps its remaining elements.
- Otherwise, the element moves between two existing sources and no indices change.

The onset/offset of the moved element are kept fixed. Waits are recomputed for:
- the moved element;
- the element that followed it in the origin;
- the element that will follow it in the destination.

Returns `(new_trace, bwd_choices, weight)`. `bwd_choices` makes the randomness propose the
reverse move from `new_trace`, and `weight` is the `update` weight. An aborted proposal
returns the trace unchanged with weight `0.0`.

GP hyperparameters travel through the proposal packed as
`[log(epsilon), mu, log(scale), log(sigma)]`, which only matches 1D GPs.
"""
function rewrite_switch_involution(trace, fwd_choices, fwd_ret, proposal_args)

    if fwd_ret == "abort"
        return trace, fwd_choices, 0.0
    end

    new_choices = choicemap()  # constraints that turn `trace` into the proposed trace
    bwd_choices = choicemap()  # randomness choices that propose the reverse move
    which_gaps, all_source_timings = fwd_ret[1], fwd_ret[2]

    source_params = get_args(trace)[1]
    old_choices = get_choices(trace)
    old_n_sources = old_choices[:n_sources]

    ## indexes for moving an element from the origin source to the destination source
    origin_idx = fwd_choices[:origin]
    new_source = fwd_choices[:new_source]
    destination_idx = fwd_choices[:destination]
    element_switch_idx = fwd_choices[:element_idx]

    old_origin_nt = old_choices[:source => origin_idx => :n_elements]
    old_destination_nt = new_source ? 0 : old_choices[:source => destination_idx => :n_elements]
    # a new source holds only the switched element, so it goes first
    which_gap = new_source ? 1 : which_gaps[destination_idx]

    source_attributes, element_attributes, element_attributes_no_wait =
        _switch_attribute_names(source_params, old_choices[:source => origin_idx => :source_type])

    ## Properties of the switched element in its new source.
    # Absolute onset/offset stay the same, and so do the duration and GP draws.
    switch_timing = all_source_timings[origin_idx][element_switch_idx]
    switch_element = Dict{Symbol,Any}()
    switch_element[:onset] = switch_timing[1]
    switch_element[:offset] = switch_timing[2]
    for a in element_attributes_no_wait
        switch_element[a] = old_choices[:source => origin_idx => (:element, element_switch_idx) => a]
    end
    # Its wait is measured from the offset of whatever precedes it in the destination.
    dest_prev_offset = which_gap == 1 ? 0 : all_source_timings[destination_idx][which_gap - 1][2]
    switch_element[:wait] = switch_element[:onset] - dest_prev_offset

    ## New waits of the elements that follow the switch point.
    # In the destination: the element the switched element is inserted before.
    dest_after_wait = nothing
    if which_gap <= old_destination_nt
        dest_after_wait = all_source_timings[destination_idx][which_gap][1] - switch_element[:offset]
    end
    # In the origin: the element that followed the removed one.
    orig_after_wait = nothing
    if element_switch_idx < old_origin_nt
        orig_prev_offset = element_switch_idx == 1 ? 0 :
            all_source_timings[origin_idx][element_switch_idx - 1][2]
        orig_after_wait = all_source_timings[origin_idx][element_switch_idx + 1][1] - orig_prev_offset
    end

    if old_origin_nt == 1
        @debug "switch: single-element origin source is removed"
        # The emptied origin is deleted, so n_sources decreases by one and the element
        # goes to an existing destination.
        new_choices[:n_sources] = old_n_sources - 1
        # Old sources after the origin slide down one index; the destination is rebuilt below.
        for new_idx in 1:(old_n_sources - 1)
            old_idx = new_idx >= origin_idx ? new_idx + 1 : new_idx
            old_idx == destination_idx && continue
            set_submap!(new_choices, :source => new_idx, get_submap(old_choices, :source => old_idx))
        end

        new_destination_idx = destination_idx > origin_idx ? destination_idx - 1 : destination_idx
        old_nt = old_choices[:source => destination_idx => :n_elements]
        new_choices[:source => new_destination_idx => :n_elements] = old_nt + 1
        new_choices[:source => new_destination_idx => :source_type] =
            old_choices[:source => destination_idx => :source_type]
        _switch_write_proposed_latents!(new_choices, new_destination_idx, fwd_choices, :dest,
                                        source_attributes)
        _switch_insert_element!(new_choices, new_destination_idx, old_choices, destination_idx,
                                which_gap, old_nt, switch_element, dest_after_wait,
                                element_attributes, element_attributes_no_wait; copy_prefix=true)

        # Reverse: move the element back out into a new source at `origin_idx`.
        bwd_choices[:origin] = new_destination_idx
        bwd_choices[:new_source] = true
        bwd_choices[:destination] = origin_idx
        bwd_choices[:element_idx] = which_gap
        _switch_write_backward_latents!(bwd_choices, :dest, old_choices, origin_idx, source_attributes)
        _switch_write_backward_latents!(bwd_choices, :orig, old_choices, destination_idx, source_attributes)

    elseif new_source
        @debug "switch: element moves into a new source"
        # The origin is kept and a new source is inserted at destination_idx.
        new_choices[:n_sources] = old_n_sources + 1

        ## New destination source holding just the switched element.
        new_choices[:source => destination_idx => :source_type] =
            old_choices[:source => origin_idx => :source_type]
        new_choices[:source => destination_idx => :n_elements] = 1
        _switch_write_proposed_latents!(new_choices, destination_idx, fwd_choices, :dest,
                                        source_attributes)
        _switch_insert_element!(new_choices, destination_idx, old_choices, destination_idx,
                                1, 0, switch_element, dest_after_wait,
                                element_attributes, element_attributes_no_wait; copy_prefix=false)

        ## Origin source, which shifts up one index if it sits at/after the insertion point.
        new_origin_idx = origin_idx >= destination_idx ? origin_idx + 1 : origin_idx
        new_choices[:source => new_origin_idx => :n_elements] = old_origin_nt - 1
        new_choices[:source => new_origin_idx => :source_type] =
            old_choices[:source => origin_idx => :source_type]
        _switch_write_proposed_latents!(new_choices, new_origin_idx, fwd_choices, :orig,
                                        source_attributes)
        _switch_remove_element!(new_choices, new_origin_idx, old_choices, origin_idx,
                                element_switch_idx, old_origin_nt, orig_after_wait,
                                element_attributes_no_wait; copy_prefix=true)

        ## Every other source after the new one shifts up one index.
        for i in (destination_idx + 1):(old_n_sources + 1)
            i == new_origin_idx && continue
            set_submap!(new_choices, :source => i, get_submap(old_choices, :source => i - 1))
        end

        # Reverse: move the lone element back into the origin, deleting the new source.
        bwd_choices[:origin] = destination_idx
        bwd_choices[:new_source] = false
        bwd_choices[:destination] = new_origin_idx
        bwd_choices[:element_idx] = 1
        _switch_write_backward_latents!(bwd_choices, :dest, old_choices, origin_idx, source_attributes)

    else
        @debug "switch: element moves between existing sources"
        # No source changes index; only the two affected sources are rewritten.
        new_choices[:source => origin_idx => :n_elements] = old_origin_nt - 1
        _switch_write_proposed_latents!(new_choices, origin_idx, fwd_choices, :orig, source_attributes)
        _switch_remove_element!(new_choices, origin_idx, old_choices, origin_idx,
                                element_switch_idx, old_origin_nt, orig_after_wait,
                                element_attributes_no_wait; copy_prefix=false)

        new_choices[:source => destination_idx => :n_elements] = old_destination_nt + 1
        _switch_write_proposed_latents!(new_choices, destination_idx, fwd_choices, :dest,
                                        source_attributes)
        _switch_insert_element!(new_choices, destination_idx, old_choices, destination_idx,
                                which_gap, old_destination_nt, switch_element, dest_after_wait,
                                element_attributes, element_attributes_no_wait; copy_prefix=false)

        # Reverse: swap the roles of origin and destination.
        bwd_choices[:origin] = destination_idx
        bwd_choices[:destination] = origin_idx
        bwd_choices[:element_idx] = which_gap
        bwd_choices[:new_source] = false
        _switch_write_backward_latents!(bwd_choices, :dest, old_choices, origin_idx, source_attributes)
        _switch_write_backward_latents!(bwd_choices, :orig, old_choices, destination_idx, source_attributes)
    end

    # Model arguments are unchanged; Gen expects one argdiff per argument.
    args = get_args(trace)
    new_trace, weight = update(trace, args, map(_ -> NoChange(), args), new_choices)
    return new_trace, bwd_choices, weight

end

In [None]:
"""
    check_score(initial_trace, new_trace, observation_trace)

Debugging helper for the switch move.

First asserts that the `noisy_matrix` likelihood of the traced `:scene` against
`observation_trace[:scene]` is identical for `initial_trace` and `new_trace`. Then prints,
for both traces, the log-probability of each source and of every source-level and
element-level random choice.

`source_params` is read from each trace's own first argument, as `rewrite_switch_involution`
does, rather than from a global. The `:amp` hyperparameter keys come from the 1D amplitude
GP, so the breakdown assumes tone sources.
"""
function check_score(initial_trace, new_trace, observation_trace)
    println("Looking at each part of the score:");println("")
    noise_value = 1.0 # NOTE(review): hard-coded; the trace's own noise latent is not read
    # NOTE: using the scene from `get_retval` instead of the traced `:scene` value does NOT
    # give equal likelihoods.
    old_likelihood = Gen.logpdf(noisy_matrix, initial_trace[:scene], observation_trace[:scene], noise_value)
    new_likelihood = Gen.logpdf(noisy_matrix, new_trace[:scene], observation_trace[:scene], noise_value)
    println("Old likelihood: $old_likelihood, New likelihood: $new_likelihood")
    @assert old_likelihood == new_likelihood
    println("Likelihoods are equal.")

    println(""); println("Score of random choices:")
    for (trace_name, trace) in (("Initial", initial_trace), ("New", new_trace))
        # hyperprior tables come from the trace's own model arguments
        source_params = get_args(trace)[1]

        total = 0.0
        for i in 1:trace[:n_sources]
            single = project(trace, select(:source => i))
            total += single
            println("$trace_name Trace -- source $i: $single")

            for source_var in [:wait, :dur_minus_min, :erb, :amp]
                sp = if source_var == :erb
                    source_params["gp"]["erb"]
                elseif source_var == :amp
                    source_params["gp"]["amp"]["1D"]  # assumes tone sources
                else
                    source_params["tp"][String(source_var)]
                end
                println("\t$source_var")
                for a in keys(sp)
                    sa = Symbol(a)
                    v = round(trace[:source => i => source_var => sa], digits=4)
                    rp = round(project(trace, select(:source => i => source_var => sa)), digits=8)
                    println("\t\t $a=$v: ", rp)
                end
            end

            n_tones = trace[:source => i => :n_elements]
            println("\tn_tones $n_tones: ", project(trace, select(:source => i => :n_elements)))
            for j in 1:n_tones
                println("\tTone $j")
                for a in [:wait, :dur_minus_min, :erb, :amp]
                    value = trace[:source => i => (:element, j) => a]
                    # erb/amp element values are sequences; summarize them by their mean
                    mval = (a == :erb || a == :amp) ? mean(value) : value
                    v = round(mval, digits=4)
                    rp = round(project(trace, select(:source => i => (:element, j) => a)), digits=8)
                    println("\t\t", a, "=$v: ", rp)
                end
            end
        end
        println("$trace_name Trace -- total: $total");println("")
    end

end

In [None]:
"""
    check_choice_probabilities(new_or_init_trace, bwd_or_fwd_choices)

Debugging helper: replay `rewrite_switch_randomness` (with the global `custom_proposals`) on
`new_or_init_trace`, constrained to `bwd_or_fwd_choices`. Then print the resulting choice map
and the log-probability of each group of proposal choices.

Pair backward choices with the new trace (the backward move goes new -> initial). Pair
forward choices with the initial trace (the forward move goes initial -> new).
"""
function check_choice_probabilities(new_or_init_trace, bwd_or_fwd_choices)
    choice_trace, _ = generate(rewrite_switch_randomness,
                               (new_or_init_trace, custom_proposals),
                               bwd_or_fwd_choices)
    println("Choice trace:")
    println(get_choices(choice_trace))
    println()

    top_level = (("New source", :new_source), ("Element idx", :element_idx),
                 ("Origin", :origin), ("Destination", :destination))
    for (label, addr) in top_level
        println("$label: ", project(choice_trace, select(addr)))
    end

    for location in (:orig, :dest)
        println(location)
        for gp in (:erb, :amp)
            println("$gp: ", project(choice_trace, select(location => gp => :gpparams)))
        end
        for tp in (:wait, :dur_minus_min), a in (:a, :mu)
            println("$tp $a: ", project(choice_trace, select(location => tp => a)))
        end
    end
end

In [None]:
include("./model.jl")
include("./perfect_init.jl")

# Demo scene for the switch-move checks below, built by `perfect_initialization`.
demofunc = tougas_bregman_1A
demoargs = () # the demo function is called with no arguments

demo_trace = perfect_initialization(demofunc,demoargs;MLE=true);
demo_gram, _, _, _, _ = get_retval(demo_trace)

In [None]:
# Apply one switch move to the demo trace and print the pieces of its MH acceptance
# log-ratio: involution weight - forward score + backward score.
fwd_choices, fwd_score, fwd_ret = propose(rewrite_switch_randomness, (demo_trace, custom_proposals))
new_trace, bwd_choices, weight = rewrite_switch_involution(demo_trace, fwd_choices, fwd_ret, ())
bwd_score, _ = assess(rewrite_switch_randomness, (new_trace, custom_proposals), bwd_choices)

log_acceptance = weight - fwd_score + bwd_score
println(log_acceptance)
println("Weight: $(weight), Fwd: $(fwd_score), Bwd: $(bwd_score)")
println(fwd_choices)
println("Backward choices")
println(bwd_choices)

In [None]:
# Plot the sources of the demo trace and of the trace after one switch move, both against
# the demo gram (third argument presumably distinguishes the two plots — TODO confirm).
plot_sources(demo_trace, demo_gram, 0; save=false)
plot_sources(new_trace, demo_gram, 1; save=false)

In [None]:
# Score breakdown before/after the switch; the demo trace doubles as the observation trace.
check_score(demo_trace, new_trace, demo_trace)

In [None]:
# Backward choices are scored against the post-switch trace (they move new_trace -> demo_trace).
check_choice_probabilities(new_trace, bwd_choices)

In [None]:
# Forward choices are scored against the initial demo trace (they move demo_trace -> new_trace).
check_choice_probabilities(demo_trace, fwd_choices)

Are the backward choices correct, i.e., does the round trip recover the original trace?

In [None]:
"""
    test_involution_fixed_randomness_args(trace, randomness, randomness_args, involution)

Round-trip check for an involutive MH move on a single `trace`.

The check runs in four steps:
1. Propose from `randomness(trace, randomness_args...)`.
2. Apply `involution`.
3. Replay the randomness on the new trace under the backward choices.
4. Apply `involution` again.

It then asserts that the reverse weight is the negation of the forward weight. It also
asserts that the round-trip trace matches `trace` address by address, in both values and
per-address scores. The compared addresses are:
- `:scene` and `:n_sources`;
- for every source: `:n_elements`, `:source_type`, the `:erb`/`:amp` GP hyperparameters,
  the `:wait`/`:dur_minus_min` hyperparameters, and every element's latents.

Any mismatch throws an `AssertionError`.

Continuous scores and values are compared with `isapprox`. For the switch move, waits are
recomputed from onset/offset differences and GP hyperparameters pass through
`exp(log(x))`, so a bit-exact round trip is not expected.
"""
function test_involution_fixed_randomness_args(trace, randomness, randomness_args, involution)

    # forward move: sample the randomness, then run the involution
    (fwd_choices, fwd_score, fwd_ret) = propose(randomness, (trace, randomness_args...))
    (new_trace, bwd_choices, weight) = involution(trace, fwd_choices, fwd_ret, randomness_args)
    # replay the randomness on the new trace under the backward choices to get its retval
    # (could also check: get_choices(new_randomness_trace) == bwd_choices)
    (new_randomness_trace, _) = generate(randomness, (new_trace, randomness_args...), bwd_choices)
    new_fwd_ret = get_retval(new_randomness_trace)

    # reverse move: the involution applied to its own output should restore `trace`
    (trace_round_trip, _, reverse_weight) = involution(new_trace, bwd_choices, new_fwd_ret, randomness_args)

    @assert(isapprox(reverse_weight, -weight), "isapprox(reverse_weight, -weight): $reverse_weight =/= -$weight")

    @assert trace[:scene] == trace_round_trip[:scene]
    # The scene score depends on the recomputed latents, so compare it approximately like
    # the other continuous scores.
    s1 = project(trace, select(:scene));  s2 = project(trace_round_trip, select(:scene));
    @assert(isapprox(s1, s2), "Score-- scene $s1 =/= $s2")

    @assert(trace[:n_sources] == trace_round_trip[:n_sources],
            "n_sources: $(trace[:n_sources]) =/= $(trace_round_trip[:n_sources])")
    s1 = project(trace, select(:n_sources));  s2 = project(trace_round_trip, select(:n_sources));
    @assert(s1 == s2, "Score-- n_sources $s1 =/= $s2")

    for source_idx in 1:trace[:n_sources]

        source_trace = get_submap(get_choices(trace), :source => source_idx)
        source_trace_round_trip = get_submap(get_choices(trace_round_trip), :source => source_idx)

        @assert(source_trace[:n_elements] == source_trace_round_trip[:n_elements],
                "n_elements: $(source_trace[:n_elements]) =/= $(source_trace_round_trip[:n_elements])")
        s1 = project(trace, select(:source => source_idx => :n_elements))
        s2 = project(trace_round_trip, select(:source => source_idx => :n_elements))
        @assert(s1 == s2, "Score-- n_elements $s1 =/= $s2")

        @assert(source_trace[:source_type] == source_trace_round_trip[:source_type],
                "source_type: $(source_trace[:source_type]) =/= $(source_trace_round_trip[:source_type])")
        s1 = project(trace, select(:source => source_idx => :source_type))
        s2 = project(trace_round_trip, select(:source => source_idx => :source_type))
        @assert(s1 == s2, "Score-- source_type $s1 =/= $s2")

        for svar in [:erb, :amp], a in [:mu, :sigma, :scale, :epsilon]
            @assert(isapprox(source_trace[svar => a], source_trace_round_trip[svar => a]),
                    string("Value-- $svar $a: ", source_trace[svar => a], " not approx ", source_trace_round_trip[svar => a]))
            s1 = project(trace, select(:source => source_idx => svar => a))
            s2 = project(trace_round_trip, select(:source => source_idx => svar => a))
            @assert(isapprox(s1, s2), "Score-- $svar $a $s1 not approx $s2")
        end

        for svar in [:wait, :dur_minus_min], a in [:mu, :a]
            @assert(isapprox(source_trace[svar => a], source_trace_round_trip[svar => a]),
                    "$svar $a: $(source_trace[svar => a]) not approx $(source_trace_round_trip[svar => a])")
            s1 = project(trace, select(:source => source_idx => svar => a))
            s2 = project(trace_round_trip, select(:source => source_idx => svar => a))
            @assert(isapprox(s1, s2), "Score-- $svar $a $s1 not approx $s2")
        end

        element_as = [:wait, :dur_minus_min, :erb, :amp]
        for element_idx in 1:source_trace[:n_elements], a in element_as
            original_value = source_trace[(:element, element_idx) => a]
            round_trip_value = source_trace_round_trip[(:element, element_idx) => a]
            @assert(isapprox(original_value, round_trip_value),
                    "$a $element_idx: $original_value =/= $round_trip_value")
            s1 = project(trace, select(:source => source_idx => (:element, element_idx) => a))
            s2 = project(trace_round_trip, select(:source => source_idx => (:element, element_idx) => a))
            @assert(isapprox(s1, s2), "Score-- element $element_idx, $a : $s1 =/= $s2")
        end

    end

end


"""
    test_involution_move(model, randomness, involution, randomness_args_fn, args)

Simulate one trace of `model` with `args`. Then run `test_involution_fixed_randomness_args`
on it once for every randomness-argument tuple returned by `randomness_args_fn(trace)`.

Returns `"Successful test!"`; a failed check throws before returning.
"""
function test_involution_move(model, randomness, involution, randomness_args_fn, args)
    trace = simulate(model, args)
    foreach(randomness_args_fn(trace)) do randomness_args
        test_involution_fixed_randomness_args(trace, randomness, randomness_args, involution)
    end
    return "Successful test!"
end

In [None]:
args = get_args(demo_trace)
# The switch randomness takes one extra argument: the table of custom proposals.
cp_args_fn(trace) = [(custom_proposals,)]

println("Testing switch...")
for seed in 1:100 # NOTE: iteration 9 was noted to take a long time
    Random.seed!(seed)
    success = test_involution_move(generate_scene, rewrite_switch_randomness,
                                   rewrite_switch_involution, cp_args_fn, args)
    println("$seed switch: ", success)
end

Open question: which of the backward choices is so unlikely under the backward proposal?

In [None]:
# (trace::U, weight) = generate(gen_fn::GenerativeFunction{T,U}, args::Tuple,
#                                 constraints::ChoiceMap)
# No example of calling `generate` this way is available yet; presumably it works analogously to `assess`:
# (weight, retval) = assess(gen_fn::GenerativeFunction, args::Tuple, choices::ChoiceMap)
# (bwd_score, _) = assess(rewrite_switch_randomness, (new_trace,custom_proposals,), bwd_choices)
# weight = project(trace::U, selection::Selection)
# e.g. s2 = project(trace_round_trip, select(:source => source_idx => svar => a));



In [None]:
# Rebuild the forward-proposal trace from `fwd_choices` and print the score
# (via `project`) of each group of forward choices.
# `fwd_choices` must already be defined, e.g. by the `propose` call later in this notebook.
(fwd_trace, _) = generate(rewrite_switch_randomness, (demo_trace, custom_proposals),
                          fwd_choices);
println("Forward trace:")
println(get_choices(fwd_trace))
println()

for (label, addr) in (("New source", :new_source), ("Element idx", :element_idx),
                      ("Origin", :origin), ("Destination", :destination))
    println("$label: ", project(fwd_trace, select(addr)))
end

for location in (:orig, :dest)
    println(location)
    for gp in (:erb, :amp)
        println("$gp: ", project(fwd_trace, select(location => gp => :gpparams)))
    end
    # Outer iteration over `tp`, inner over `a` (same order as nested loops).
    for tp in (:wait, :dur_minus_min), a in (:a, :mu)
        println("$tp $a: ", project(fwd_trace, select(location => tp => a)))
    end
end

In [None]:
# Scratch check: print each 1-based position alongside its value.
for (i, v) in pairs(["a", "b", "c"])
    println("i: $i, v: $v")
end

In [None]:
# Run 50 switch-move MH steps starting from `demo_trace`, plotting the sources after
# every step (plot index 0 is the starting trace).
trace = demo_trace
plot_sources(demo_trace, scene_gram, 0; save=false)
for i = 1:50
    # `trace` is a global rebound on every step. Declaring it `global` makes the loop
    # behave the same under IJulia's soft scope and when run non-interactively, where
    # the assignment would otherwise create a fresh (undefined) local `trace`.
    global trace, accepted
    trace, accepted = mh(trace, rewrite_switch_randomness,
                         (custom_proposals,), rewrite_switch_involution)
    plot_sources(trace, scene_gram, i; save=false)
end


In [None]:
# Continue the chain in the global `trace` for steps 51..101, plotting after each step.
for i = 51:101
    # Explicit `global` so the rebinding of `trace` also works outside IJulia's
    # soft-scope rules (e.g. when this notebook is exported and `include`d).
    global trace, accepted
    trace, accepted = mh(trace, rewrite_switch_randomness,
                         (custom_proposals,), rewrite_switch_involution)
    plot_sources(trace, scene_gram, i; save=false)
end

In [None]:
# Continue the chain in the global `trace` for steps 102..1001. Unlike the earlier
# cells this passes `save=true` (presumably writes each plot to disk — confirm).
for i = 102:1001
    # Explicit `global` so the rebinding of `trace` also works outside IJulia's
    # soft-scope rules (e.g. when this notebook is exported and `include`d).
    global trace, accepted
    trace, accepted = mh(trace, rewrite_switch_randomness,
                         (custom_proposals,), rewrite_switch_involution)
    plot_sources(trace, scene_gram, i; save=true)
end

In [None]:
# Perform one switch proposal by hand (propose -> involution -> assess backward
# choices) so the individual acceptance terms can be inspected.
(fwd_choices, fwd_score, fwd_ret) = propose(rewrite_switch_randomness,
                                            (trace, custom_proposals))
(new_trace, bwd_choices, weight) = rewrite_switch_involution(trace, fwd_choices,
                                                             fwd_ret, ())
(bwd_score, _) = assess(rewrite_switch_randomness, (new_trace, custom_proposals),
                        bwd_choices)

# Presumably the log MH acceptance ratio: model weight - forward score + backward score.
println(weight - fwd_score + bwd_score)
println("Weight: $(weight), Fwd: $(fwd_score), Bwd: $(bwd_score)")
println(fwd_choices)
println("Backward choices")
println(bwd_choices)
plot_sources(new_trace, scene_gram, 51; save=false)


In [None]:
# Run `check_score` (defined elsewhere) on the current chain state `trace`, the
# proposed `new_trace`, and the initial `demo_trace` — presumably to compare their scores.
check_score(trace, new_trace, demo_trace)

In [None]:
# Gather the per-element `wait` and `dur_minus_min` values of source 1 in `new_trace`
# and print their mean and standard deviation. `x` and `d` are reused by the next cell.
# Typed comprehensions replace untyped `[]` + `append!`; `:1` in the old address was
# just the integer literal `1`.
x = [new_trace[:source => 1 => (:element, j) => :wait]
     for j in 1:new_trace[:source => 1 => :n_elements]]
println("wait")
println(mean(x))
println(std(x))
# NOTE(review): 0.0304 is a hard-coded reference wait value; this prints its z-score
# relative to the element waits above. Where the constant comes from is not shown — confirm.
println((0.0304 - mean(x)) / std(x))

d = [new_trace[:source => 1 => (:element, j) => :dur_minus_min]
     for j in 1:new_trace[:source => 1 => :n_elements]]
println("dur_minus_min")
println(mean(d))
println(std(d))

In [None]:
# Set source 1's `wait` and `dur_minus_min` `:mu` choices to the empirical element
# means `mean(x)` / `mean(d)` computed in the previous cell.
constraints = choicemap((:source => 1 => :wait => :mu, mean(x)),
                        (:source => 1 => :dur_minus_min => :mu, mean(d))
)

# `update` takes one argdiff per model argument. Reuse the trace's own arguments
# (unchanged, hence `NoChange()` for each) so only the constrained choices change.
(mu_trace, w, _, discard) = update(new_trace, get_args(new_trace),
                                   map(_ -> NoChange(), get_args(new_trace)),
                                   constraints)

In [None]:
# Run `check_score` (defined elsewhere) with the mean-constrained `mu_trace` as the
# second argument, alongside `trace` and the initial `demo_trace`.
check_score(trace, mu_trace, demo_trace)

In [None]:
# Take a single switch-move MH step from the initial `demo_trace` and print whether it
# was accepted. Note: this rebinds the global `trace`, replacing the chain state built
# by the MH loops above.
trace, accepted = mh(demo_trace, rewrite_switch_randomness, 
                        (custom_proposals,), rewrite_switch_involution) 
println(accepted)

In [None]:
# Scratch check (evaluates to `false`): `!` is Julia's logical negation (`~` is
# bitwise NOT), and `haskey` is the direct dictionary key-membership test.
!haskey(Dict("hello" => []), "hello")