In [None]:
using Gen;
using Dierckx;
include("./proposals.jl")
include("./rendering.jl")
include("./model.jl")
include("./time_helpers.jl")
include("./gaussian_helpers.jl")
include("./plotting.jl")
include("./inference_helpers.jl")
include("./extra_distributions.jl")

In [None]:
"""
    initialize_tone_sequences(elements, pad, steps)

Build a Gen choice map that pins a scene to a single tone source containing `elements`.

Each entry of `elements` is indexed with the keys `"g"` (gap before the element), `"d"`
(duration), `"f"` (value written to `:erb`) and `"l"` (value written to `:amp`). A tiny
random jitter (`1e-4*rand()`) is added to every gap and duration. The first element's wait
also includes the leading pad `pad[1]`; `pad[2]` is appended after the last element.
`steps["min"]` is subtracted from each duration to form `:dur_minus_min`, and `steps["t"]`
is passed to `get_element_gp_times` to pick the times at which `:erb`/`:amp` are set.

Returns `(constraints, scene_duration)`.
"""
function initialize_tone_sequences(elements, pad, steps)

    constraints = choicemap()
    constraints[:n_sources] = 1
    constraints[:source => 1 => :source_type] = 1 # tone
    constraints[:source => 1 => :n_elements] = length(elements)

    # End time of the previous element; seeded with the leading pad so the first wait absorbs it.
    prev_end = pad[1]
    for (idx, el) in enumerate(elements)

        # Jitter is drawn for the wait first, then for the duration (keeps RNG order stable).
        wait = el["g"] + 1e-4*rand() + (idx == 1 ? prev_end : 0)
        dur = el["d"] + 1e-4*rand()
        constraints[:source => 1 => (:element, idx) => :wait] = wait
        constraints[:source => 1 => (:element, idx) => :dur_minus_min] = dur - steps["min"]

        t_on = idx == 1 ? wait : prev_end + wait
        t_off = t_on + dur
        gp_times = get_element_gp_times([t_on, t_off], steps["t"])

        # Constant-valued linear splines between onset and offset, sampled at gp_times.
        freq_val, level_val = el["f"], el["l"]
        erb_track = Spline1D([t_on, t_off], [freq_val, freq_val], k=1)(gp_times)
        amp_track = Spline1D([t_on, t_off], [level_val, level_val], k=1)(gp_times)
        constraints[:source => 1 => (:element, idx) => :erb] = erb_track
        constraints[:source => 1 => (:element, idx) => :amp] = amp_track

        prev_end = t_off
    end

    return constraints, prev_end + pad[2]

end


In [None]:
"""
    tougas_bregman_1A()

Return the 13-element tone list for the Tougas & Bregman condition 1A demo.

Odd positions (1, 3, ..., 13) step through `frequencies` in listed order (1600 down to 400).
Even positions (2, 4, ..., 12) step through it in reverse order (400 up to 1270). Every
element has duration 0.099 and gap 0.001. Its level is looked up from `levels` by the
frequency's string form, and its `"f"` value is `freq_to_ERB` of the frequency.
"""
function tougas_bregman_1A()

    frequencies = [1600, 1270, 1008, 800, 635, 504, 400]
    levels = Dict("1600"=>75, "1270"=>73, "1008"=>70, "800"=> 70, "635"=> 70, "504"=> 72, "400"=>77)
    duration = 0.099
    gap = 0.001
    reversed_freqs = reverse(frequencies)

    elements = []
    for pos in 1:13
        # Position -> index within its own (forward or reversed) frequency list.
        hz = isodd(pos) ? frequencies[(pos + 1) ÷ 2] : reversed_freqs[pos ÷ 2]
        push!(elements, Dict("f"=>freq_to_ERB(hz), "l"=>levels[string(hz)],
                             "d"=>duration, "g"=>gap))
    end

    return elements

end

"""
    bregman_rudnicky(standard, comparison, captor)

Return the tone list for the Bregman & Rudnicky captor demo: (0.1s sil)AB(1s sil)CCCDABDCC.

- `standard == "up"` gives the standard pair A,B; any other value gives B,A.
- `comparison == "up"` gives D,A,B,D; any other value gives D,B,A,D.
- `captor` is `"far"`, `"mid"`, `"near"` (selects the C frequency/level) or `"none"`.

With `"none"`, the captor tones are omitted. Their total span, `3*(duration + gap_captors)`,
is added to the 1 s separating silence so the comparison starts at the same time as with
captors. Each element is a `Dict` with `"f"` (`freq_to_ERB` of the tone frequency), `"l"`,
`"d"` and `"g"` (gap before the tone).
"""
function bregman_rudnicky(standard, comparison, captor)

    #(0.1s sil)AB(1s sil)CCCDABDCC
    A = Dict("f" => 2200, "l" => 60, "t"=>"A")
    B = Dict("f" => 2400, "l" => 60, "t"=>"B")
    D = Dict("f" => 1460, "l" => 65, "t"=>"D")
    duration = 0.057
    gap_target = 0.008

    C = Dict( "far" => Dict("f" => 590, "l" => 63, "t"=>"C"),
              "mid" => Dict("f" => 1030, "l" => 60, "t"=>"C"),
              "near" => Dict("f" => 1460, "l" => 65, "t"=>"C"))
    gap_captors = 0.130

    no_captors = captor == "none"
    standard_tones = standard == "up" ? [A, B] : [B, A]
    captor_tones_before = no_captors ? [] : [C[captor], C[captor], C[captor]]
    comparison_tones = comparison == "up" ? [D, A, B, D] : [D, B, A, D]
    captor_tones_after = no_captors ? [] : [C[captor], C[captor]]

    # Turn a tone spec into an element preceded by `gap` seconds of silence.
    as_element(tone, gap) = Dict("f"=>freq_to_ERB(tone["f"]), "l"=>tone["l"],
                                 "d"=>duration, "g"=>gap)

    elements = []
    foreach(tone -> push!(elements, as_element(tone, gap_target)), standard_tones)

    separating_silence = 1.0
    if no_captors
        # Stand in for the missing CCC so the comparison onset is unchanged.
        separating_silence += 3*(duration + gap_captors)
    end

    # Empty when there are no captors.
    for (k, tone) in enumerate(captor_tones_before)
        push!(elements, as_element(tone, k == 1 ? separating_silence : gap_captors))
    end

    first_comparison_gap = no_captors ? separating_silence : gap_captors
    for (k, tone) in enumerate(comparison_tones)
        push!(elements, as_element(tone, k == 1 ? first_comparison_gap : gap_target))
    end

    # Empty when there are no captors.
    foreach(tone -> push!(elements, as_element(tone, gap_captors)), captor_tones_after)

    return elements

end

"""
    ABA(semitones, onset_difference; A_freq=1000, duration=0.040, level=70, reps=5)

Return the element list for `reps` repetitions of an ABA_ triplet.

Each triplet is A, B, A, with tones of length `duration` and consecutive onsets
`onset_difference` apart. The first A is preceded by a gap of `onset_difference - duration`.
Every later triplet's first A is preceded by one extra `onset_difference` of silence (the
"_" slot). B lies `semitones` semitones from A: `B = A_freq * 2^(semitones/12)`.

Each element is a `Dict` with keys `"f"` (`freq_to_ERB` of the tone frequency), `"l"`
(level), `"g"` (gap before the tone) and `"d"` (duration).

# Arguments
- `semitones`: A-B separation in semitones (original notes: -15 to +15).
- `onset_difference`: onset-to-onset interval, in the same time unit as `duration`
  (original notes: 60-800 ms, i.e. 0.06-0.8 alongside the default `duration` of 0.040).

# Keywords
- `A_freq=1000`: frequency of the A tone.
- `duration=0.040`: length of every tone.
- `level=70`: level stored under `"l"` for every tone.
- `reps=5`: number of ABA_ triplets.

# Throws
- `ArgumentError` if `onset_difference < duration`, which would produce negative gaps.
"""
function ABA(semitones, onset_difference; A_freq=1000, duration=0.040, level=70, reps=5)

    # Silence between one tone's offset and the next tone's onset within a triplet.
    g = onset_difference - duration
    g < 0 && throw(ArgumentError(
        "onset_difference ($onset_difference) must be at least the tone duration ($duration)"))

    B_freq = A_freq*(2. ^(semitones/12.))

    elements = []
    for r = 1:reps
        # Triplets after the first skip one onset slot (the silent "_").
        A_gap = g + ( r == 1 ? 0 : onset_difference )
        push!(elements, Dict("f"=>freq_to_ERB(A_freq), "l"=>level,
                             "g"=>A_gap, "d"=>duration))
        push!(elements, Dict("f"=>freq_to_ERB(B_freq), "l"=>level,
                             "g"=>g, "d"=>duration))
        push!(elements, Dict("f"=>freq_to_ERB(A_freq), "l"=>level,
                             "g"=>g, "d"=>duration))
    end

    return elements

end

In [None]:
"""
    perfect_initialization(demofunc, demoargs; plot_trace=false, pad=[0.051, 0.050],
                           audio_sr=20000, n_switches=50)

Generate a trace of `generate_scene` whose single tone source is pinned to the elements
returned by `demofunc(demoargs...)`, then scramble source membership with switch moves.

`./base_params.jl` is `include`d and destructured as
`(source_params, steps, gtg_params, obs_noise)`. The elements are turned into constraints by
`initialize_tone_sequences` with padding `pad`, and `gtg_weights` is computed at `audio_sr`.
The switch moves (`switch_randomness` + `switch_involution`) are applied `n_switches` times.
Each resulting trace is kept whenever its score is finite. This is deliberately not an MH
step: there is no accept/reject. If `plot_trace` is true, the final trace is plotted with
`plot_sources`.

Returns `(trace, constraints, args)`, where `args` is the argument tuple for `generate_scene`.
"""
function perfect_initialization(demofunc, demoargs; plot_trace=false,
                                pad=[0.051, 0.050], audio_sr=20000, n_switches=50)

    source_params, steps, gtg_params, obs_noise = include("./base_params.jl")
    constraints, scene_duration = initialize_tone_sequences(demofunc(demoargs...), pad, steps)

    # Use the same integer sample rate for the filterbank weights and the model arguments.
    sr = Int(audio_sr)
    wts, _ = gtg_weights(sr, gtg_params)
    args = (source_params, float(scene_duration), wts, steps, sr, obs_noise, gtg_params)
    # No scene constraint yet: :scene is sampled by the model.
    trace, _ = generate(generate_scene, args, constraints)

    # Only switch proposals, to mix up the sources.
    for _ in 1:n_switches
        fwd_choices, _, fwd_ret = propose(switch_randomness, (trace,))
        new_trace, _, _ = switch_involution(trace, fwd_choices, fwd_ret, ())
        # isfinite rejects -Inf (impossible) traces and also NaN scores, which `~isinf` let through.
        if isfinite(get_score(new_trace))
            trace = new_trace
        end
    end

    if plot_trace
        scene_gram = first(get_retval(trace))
        plot_sources(trace, scene_gram, 0)
    end

    return trace, constraints, args

end

In [None]:
# ABA_ demo: 1 semitone A-B separation, 0.1 onset-to-onset interval.
trace, constraints, args = perfect_initialization(ABA, (1, 0.1));

In [None]:
# The trace produced by perfect_initialization serves as the observation.
observation_trace = trace

"""
    get_initial_trace(observation_trace, constraints, args)

Regenerate a `generate_scene` trace from `args` and `constraints`, additionally pinning
`:scene` to `observation_trace[:scene]`, and plot it with `plot_sources`.

Note: this mutates the caller's `constraints` by setting `constraints[:scene]`.
Returns the new trace.
"""
function get_initial_trace(observation_trace, constraints, args)
    observed_scene = observation_trace[:scene]
    # Add the scene constraint to be what the initial demo specification gave.
    constraints[:scene] = observed_scene
    regenerated, _ = generate(generate_scene, args, constraints)
    plot_sources(regenerated, observed_scene, 0, save=false)
    return regenerated
end
initial_trace = get_initial_trace(observation_trace, constraints, args);

In [None]:
"""
    make_switch_proposal(initial_trace, observation_trace, show_info)

Propose one switch move from `initial_trace` and return `(new_trace, alpha)`.

`alpha = weight - fwd_score + bwd_score` combines the involution weight with the forward
proposal score and the backward score obtained by `assess`ing `switch_randomness` on
`new_trace`. When `show_info` is true, it prints `alpha`, its components, the forward and
backward choices, and plots `new_trace` against `observation_trace[:scene]`.
"""
function make_switch_proposal(initial_trace, observation_trace, show_info)
    fwd_choices, fwd_score, fwd_ret = propose(switch_randomness, (initial_trace,))
    new_trace, bwd_choices, weight = switch_involution(initial_trace, fwd_choices, fwd_ret, ())
    bwd_score, _ = assess(switch_randomness, (new_trace,), bwd_choices)
    alpha = weight - fwd_score + bwd_score

    if !show_info
        return new_trace, alpha
    end

    println(alpha)
    println("Weight: $(weight), Fwd: $(fwd_score), Bwd: $(bwd_score)")
    plot_sources(new_trace, observation_trace[:scene], 1, save=false)
    println(fwd_choices)
    println("Backward choices")
    println(bwd_choices)
    return new_trace, alpha
end
new_trace, alpha = make_switch_proposal(initial_trace, observation_trace, true);

In [None]:
"""
    check_score(initial_trace, new_trace, observation_trace)

Debugging helper that prints a breakdown of the scores of `initial_trace` and `new_trace`.

First, it computes the `noisy_matrix` log-density of each trace's `:scene` value against
`observation_trace[:scene]`, using `obs_noise["val"]` from `initial_trace`'s arguments, and
`@assert`s that the two are exactly equal. Then, for each trace, it prints the `project`
score of:
- every source;
- the source's `:erb`/`:amp` hyperparameters and `:wait`/`:dur_minus_min` parameters;
- `:n_elements`;
- each element's `:wait`, `:dur_minus_min`, `:erb` and `:amp`.

Finally, it prints the sum of the per-source scores.
"""
function check_score(initial_trace, new_trace, observation_trace)
    println("Looking at each part of the score:");println("")

    # Only the observation-noise setting is needed from the model arguments.
    _, _, _, _, _, obs_noise, _ = get_args(initial_trace)
    noise_value = obs_noise["val"]
    ## NOTE: if you use the scene from get_retval instead of the value stored in the trace,
    ## the likelihood does NOT come out equal!!
    observed = observation_trace[:scene]
    old_likelihood = Gen.logpdf(noisy_matrix, initial_trace[:scene], observed, noise_value)
    new_likelihood = Gen.logpdf(noisy_matrix, new_trace[:scene], observed, noise_value)
    println("Old likelihood: $old_likelihood, New likelihood: $new_likelihood")
    @assert old_likelihood == new_likelihood
    println("Likelihoods are equal.")

    println(""); println("Score of random choices:")

    # Per-source parameter groups, printed in this order with a blank line after each group.
    param_groups = ((:erb, (:mu, :scale, :sigma, :noise)),
                    (:amp, (:mu, :scale, :sigma, :noise)),
                    (:wait, (:a, :mu)),
                    (:dur_minus_min, (:a, :mu)))

    for (trace_name, tr) in (("Initial", initial_trace), ("New", new_trace))
        total = 0
        for i = 1:tr[:n_sources]
            single = project(tr, select(:source => i))
            total += single
            println("$trace_name Trace -- source $i: $single")

            for (group, params) in param_groups
                for p in params
                    addr = :source => i => group => p
                    va = round(tr[addr], digits=2)
                    println("\t$group $p $va: ", project(tr, select(addr)))
                end
                println()
            end

            n_elements = tr[:source => i => :n_elements]
            println("\tn_elements $n_elements: ", project(tr, select(:source => i => :n_elements)))
            for j = 1:n_elements
                println("\tTone $j")
                for a in (:wait, :dur_minus_min, :erb, :amp)
                    addr = :source => i => (:element, j) => a
                    aval = round.(tr[addr], digits=3)
                    println("\t\t", a, "= $aval: ", project(tr, select(addr)))
                end
            end
        end
        println("$trace_name Trace -- total: $total");println("")
    end
end
check_score(initial_trace, new_trace, observation_trace)

In [None]:
# Short labels for the MCMC proposal types; one accept/propose counter per label.
proposals = ["ns","nt","gpParam","tpParam","Wait","Dur","GpVal", "GpAll", "Spl/Mrg","Swh"];
n_proposals = length(proposals)
accept_counts = zeros(n_proposals); proposal_counts = zeros(n_proposals);

# Run 10 MCMC sweeps from the current `trace`, plotting the sources after each one.
for i = 1:10
    # Explicitly rebind the top-level variables. Without this, the loop only works under
    # notebook soft-scope rules: run as a script, Julia would treat these as new loop
    # locals and fail with UndefVarError on the right-hand side.
    global trace, accept_counts, proposal_counts
    print("$i ")
    trace, accept_counts, proposal_counts = mcmc_update(trace, accept_counts, proposal_counts)
    scene_gram = first(get_retval(trace))
    plot_sources(trace, scene_gram, i)
end