In [None]:
import Random;
using Printf;
using JSON
include("./perfect_init.jl")

In [None]:
println("Loading in command-line args...")
# The demo itself is chosen here; the per-demo parameters (if any) are read from
# ARGS[2:end] in the branches below.
demo_name = "tougas_bregman_1A"

# `demoargs` are the extra positional arguments passed to `demofunc`;
# `demofolder` is the prefix used for this demo's output folders under ./sounds/.
demoargs = []; demofolder = ""
if demo_name == "tougas_bregman_1A"
    demofunc = tougas_bregman_1A
    demoargs = [] # this demo takes no extra arguments
    demofolder = string(demofolder, demo_name, "-")
elseif demo_name == "bregman_rudnicky"
    demofunc = bregman_rudnicky
    standard = ARGS[2]   # "up" or "down"
    comparison = ARGS[3] # "up" or "down"
    captor = ARGS[4]     # "none", "far", "mid", "near"
    demoargs = [standard, comparison, captor]
    demofolder = string(demofolder, demo_name, "-", standard, "_", comparison, "_", captor, "-")
elseif demo_name == "ABA"
    demofunc = ABA
    semitones = parse(Float64, ARGS[2]) # -15 to +15
    spacing = parse(Float64, ARGS[3])   # 0.040s to 0.800s
    demoargs = [semitones, spacing]
    demofolder = string(demofolder, demo_name, "-", spacing, "_", semitones, "-")
else
    # Fail here with a clear message instead of an UndefVarError for `demofunc` later.
    throw(ArgumentError(string("Unknown demo_name \"", demo_name,
        "\"; expected \"tougas_bregman_1A\", \"bregman_rudnicky\" or \"ABA\"")))
end
# perfect_initialization receives the demo arguments as a tuple.
demoargs = Tuple(demoargs)

println("Setting random seed...")
# Draw a non-negative seed directly. `abs(rand(Int))` is negative when the draw is
# typemin(Int) (abs overflows), and negative seeds can be rejected by Random.seed!.
random_seed = rand(0:typemax(Int))
Random.seed!(random_seed)

# Names of the entries in `path` whose name contains `key`.
searchdir(path, key) = filter(x -> occursin(key, x), readdir(path))
# Make sure the output root exists so readdir/mkdir below do not fail on a fresh checkout.
mkpath("./sounds/")
list_of_expts = searchdir("./sounds/", demofolder)
# Number this run after the existing ones, skipping any index whose folder already
# exists (e.g. when an earlier run folder was deleted), so mkdir cannot collide.
iter = let i = length(list_of_expts) + 1
    while ispath(string("./sounds/", demofolder, @sprintf("%02d", i), "/"))
        i += 1
    end
    i
end
itername = @sprintf("%02d", iter);
save_loc = string("./sounds/", demofolder, "$itername/");
print("Making folder: "); println(save_loc)
mkdir(save_loc)
# Record the seed next to the run's outputs so the run can be reproduced.
d = Dict("seed" => random_seed)
open(string(save_loc, "random_seed.json"), "w") do f
    JSON.print(f, d)
end

In [None]:
# Build the initial trace for the selected demo, unpack its return value into the
# scene/source/element outputs, and plot the sources into this run's folder.
trace = perfect_initialization(demofunc, demoargs);
(scene_gram, t, scene_wave, source_waves, element_waves) = get_retval(trace)
plot_sources(trace, scene_gram, 0; save_loc = save_loc)

In [None]:
# Show the current value of source 1's ERB GP `mu` choice.
let erb_mu = trace[:source => 1 => :erb => :mu]
    println(erb_mu)
end

In [None]:
# One HMC move on source 1's ERB GP `mu`; `accepted` reports whether it was taken.
trace, accepted = let erb_mu_selection = select(:source => 1 => :erb => :mu)
    hmc(trace, erb_mu_selection)
end

In [None]:
function make_source_latent_model(source_params, audio_sr, steps)

    # Returns a Gen generative function that closes over `source_params`, `audio_sr`
    # and `steps`. The model samples the requested source-level latents (GP
    # hyperparameters and/or temporal hyperparameters), a number of elements, and the
    # element-level values that depend on those latents.
    #
    # Arguments of the returned model:
    #   latents        -- which latents to sample, e.g. Dict(:tp => :wait) or
    #                     Dict(:gp => :amp, :source_type => "tone")
    #   scene_duration -- upper bound of the uniform prior on the first element's wait
    # Returns (tp_latents, gp_latents, tp_elems, gp_elems).
    @gen function source_latent_model(latents, scene_duration)

        ## Single function for generating data for amortized inference to propose source-level latents
        # Wait: can be generated on its own
        # Dur_minus_min: can be generated on its own
        # GPs:
        # Can have a separate amortized inference move for ERB, Amp-1D, Amp-2D
        # Wait and dur_minus_min must be sampled to define the time points for the GP
        #
        # Format of latents:
        # latents = Dict(:gp => :amp OR :tp => :wait, :source_type => "tone")

        ### SOURCE-LEVEL LATENTS
        ## Sample GP source-level latents if needed
        gp_latents = Dict()
        if haskey(latents, :gp)

            gp_params = source_params["gp"]
            gp_type = latents[:gp]; gp_latents[gp_type] = Dict();
            source_type = latents[:source_type]

            # ERB GPs use one hyperprior set; amplitude GPs use the "2D" set for noise and
            # harmonic sources and the "1D" set for tones.
            hyperpriors = gp_type == :erb ? gp_params["erb"] :
                ((source_type == "noise" || source_type == "harmonic") ? gp_params["amp"]["2D"] : gp_params["amp"]["1D"] )

            for latent in keys(hyperpriors)
                hyperprior = hyperpriors[latent]; syml = Symbol(latent)
                gp_latents[gp_type][syml] = @trace(hyperprior["dist"](hyperprior["args"]...), gp_type => syml)
            end

        end

        ## Sample temporal source-level latents
        tp_latents = Dict()
        if haskey(latents, :tp)
            tp_latents[latents[:tp]] = Dict()
        end
        if haskey(latents, :gp)
            # The GP sample points are built from every element's onset and duration, so
            # both temporal latents are required whenever a GP latent is requested, even
            # if a single :tp latent was also named.
            for tp_type in (:wait, :dur_minus_min)
                get!(tp_latents, tp_type, Dict())
            end
        end
        tp_params = source_params["tp"]
        for tp_type in keys(tp_latents)
            hyperpriors = tp_params[String(tp_type)]
            for latent in keys(hyperpriors)
                hyperprior = hyperpriors[latent]; syml = Symbol(latent)
                tp_latents[tp_type][syml] = @trace(hyperprior["dist"](hyperprior["args"]...), tp_type => syml)
            end
            # Arguments for the element-level gamma draws below: (a, mu/a), so that with a
            # (shape, scale) parameterization the mean is mu.
            tp_latents[tp_type][:args] = (tp_latents[tp_type][:a], tp_latents[tp_type][:mu]/tp_latents[tp_type][:a])
        end

        ## Sample a number of elements
        ne_params = source_params["n_elements"]
        n_elements = ne_params["type"] == "max" ?
            @trace(uniform_discrete(1, ne_params["val"]),:n_elements) :
            @trace(geometric(ne_params["val"]),:n_elements)

        ### ELEMENT-LEVEL LATENTS
        # Storage for the element-level values, keyed by latent type
        tp_elems = Dict(k => [] for k in keys(tp_latents))
        gp_elems = Dict(); x_elems = [];
        if haskey(latents, :gp)

            gp_type = latents[:gp]; source_type = latents[:source_type]

            gp_elems[gp_type] = []
            gp_elems[:t] = []
            if gp_type == :amp && (source_type == "noise" || source_type == "harmonic")
                gp_elems[:reshaped] = []
                gp_elems[:f] = []
                gp_elems[:tf] = []
            end

        end

        time_so_far = 0.0;
        for element_idx = 1:n_elements

            # The first element's wait is uniform over the scene; later waits are
            # gamma-distributed gaps after the previous element's offset.
            if haskey(tp_elems, :wait)
                wait = element_idx == 1 ? @trace(uniform(0, scene_duration), (:element,element_idx)=>:wait) :
                    @trace(gamma(tp_latents[:wait][:args]...), (:element,element_idx)=>:wait)
                push!(tp_elems[:wait], wait)
            end

            if haskey(tp_elems, :dur_minus_min)
                dur_minus_min = @trace(truncated_gamma(tp_latents[:dur_minus_min][:args]..., source_params["duration_limit"]), (:element,element_idx)=>:dur_minus_min);
                push!(tp_elems[:dur_minus_min], dur_minus_min)
            end

            if haskey(latents, :gp)

                # `wait` and `dur_minus_min` are always sampled above when a GP is requested.
                gp_type = latents[:gp]; source_type = latents[:source_type]
                duration = dur_minus_min + steps["min"]; onset = time_so_far + wait;
                time_so_far = onset + duration; element_timing = [onset, time_so_far]

                ## Define points at which the GPs should be sampled
                x = []; ts = [];
                if gp_type === :erb || (source_type == "tone" && gp_type === :amp)
                    x = get_element_gp_times(element_timing, steps["t"])
                elseif gp_type === :amp && (source_type == "noise"  || source_type == "harmonic")
                    x, ts, gp_elems[:f] = get_gp_spectrotemporal(element_timing, steps, audio_sr)
                end

                # Later elements are conditioned on all previously sampled GP values.
                mu, cov = element_idx == 1 ? get_mu_cov(x, gp_latents[gp_type]) :
                        get_cond_mu_cov(x, x_elems, gp_elems[gp_type], gp_latents[gp_type])
                element_gp = @trace(mvnormal(mu, cov), (:element, element_idx) => gp_type)

                ## Save the element data
                append!(x_elems, x)
                if gp_type === :erb || (source_type == "tone" && gp_type === :amp)
                    append!(gp_elems[:t], x)
                    append!(gp_elems[gp_type], element_gp)
                elseif gp_type === :amp && (source_type == "noise"  || source_type == "harmonic")
                    append!(gp_elems[:t],ts)
                    append!(gp_elems[:tf],x)
                    append!(gp_elems[gp_type], element_gp)
                    # Store the element as a (n_freqs, n_times) matrix and concatenate along time.
                    reshaped_elem = reshape(element_gp, (length(gp_elems[:f]), length(ts)))
                    if element_idx == 1
                        gp_elems[:reshaped] = reshaped_elem
                    else
                        gp_elems[:reshaped] = cat(gp_elems[:reshaped], reshaped_elem, dims=2)
                    end
                end

            end

        end

        return tp_latents, gp_latents, tp_elems, gp_elems

    end

    return source_latent_model

end
# Load the model parameters (the included file evaluates to a 4-tuple) and build the
# latent-sampling model at a 20 kHz audio sample rate.
(source_params, steps, gtg_params, obs_noise) = include("./base_params.jl")
audio_sr = 20_000;
source_latent_model = make_source_latent_model(source_params, audio_sr, steps);

In [None]:
# Request only the temporal `wait` latent and sample one trace for a 2 s scene,
# keeping the trace and discarding the returned weight.
latents = Dict{Any,Any}(:wait => Dict(:tp => :wait))
scene_duration = 2.0
trace = first(generate(source_latent_model, (latents[:wait], scene_duration)));

In [None]:
# Dump every random choice recorded in the trace.
get_choices(trace) |> println

In [None]:
# One HMC move on the `wait` hyperparameter `mu`, then report the outcome and the
# resulting choices.
trace, accepted = let wait_mu_selection = select(:wait => :mu)
    hmc(trace, wait_mu_selection)
end
println("Accepted? ", accepted)
get_choices(trace) |> println