# Simple Bayesian Auditory Scene Analysis 

We are going to implement a simple model which can organize tones into streams, inferring the perceptual organization from the cochleagram. In order to do approximate posterior inference, we will implement a set of MCMC moves that can operate on this relatively complex trace (as compared to the previous models we looked at).

In [None]:
using Gen
using PyPlot
using Statistics: mean, std;

#We'll need a few extra functions, which are stored in Julia scripts:
#for generating sound waves from latent variables
include("./rendering.jl") 
#for our likelihood function & some special proposal distributions
include("./extra_distributions.jl")
#to generate our cochleagram
include("./gammatonegram.jl")
#to visualize the trace
include("./plotting.jl");

## Generative Model

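We'll write out the generative model one level of the hierarchy at a time, starting with a single tone. Write out the following priors (times are in seconds):
- `:gap ~ gamma(2, 0.1)`, the silence before the tone
- `:duration ~ gamma(2, 0.1) + 0.050`, the tone's length, with a 50 ms minimum
- `:f0 ~ normal(mu, 3.0)`, the tone's frequency, on the same scale as the source mean `mu`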
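To get a feel for the scale of these priors, here is a plain-Julia sketch (no Gen) that draws from them. It assumes `gamma(2, 0.1)` is shape 2, scale 0.1, and draws it as the sum of two Exponential(0.1) variables, which is valid only because the shape is the integer 2. The helper names and the value `mu = 15.0` are illustrative assumptions.

```julia
using Random, Statistics

# Exponential(θ) by inverse CDF; 1 - rand() lies in (0, 1], which avoids log(0)
exp_draw(rng, θ) = -θ * log(1 - rand(rng))
# Gamma(shape = 2, scale = θ) as the sum of two independent Exponential(θ) draws
gamma2_draw(rng, θ) = exp_draw(rng, θ) + exp_draw(rng, θ)

# Hypothetical helper mirroring the three tone-level priors
function sample_tone_prior(rng, mu; min_duration = 0.050)
    gap = gamma2_draw(rng, 0.1)                      # seconds of silence before the tone
    duration = gamma2_draw(rng, 0.1) + min_duration  # seconds, never below 50 ms
    f0 = mu + 3.0 * randn(rng)                       # normal(mu, 3.0)
    return gap, duration, f0
end

rng = MersenneTwister(1)
draws = [sample_tone_prior(rng, 15.0) for _ in 1:100_000]
mean_gap = mean(first.(draws))                 # theory: 2 * 0.1 = 0.2 s
mean_duration = mean(getindex.(draws, 2))      # theory: 0.2 + 0.05 = 0.25 s
min_duration_seen = minimum(getindex.(draws, 2))
```

So a typical tone lasts about a quarter of a second and is preceded by about 0.2 s of silence.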

In [None]:
#= 
We'll use a minimum duration for our tones so that they're always audible
"@dist" allows the user to concisely define a distribution, 
as long as that distribution can be expressed as a certain type 
of deterministic transformation of an existing distribution. 
=#
@dist dur_dist(shape, scale, min_duration) = gamma(shape, scale) + min_duration

@gen function tone_latents(mu::Float64)
    
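    # silence (s) between the previous tone's offset, or time 0, and this onset
    gap = @trace(gamma(2, 0.1), :gap)
    # tone length (s), never shorter than 50 ms
    tone_duration = @trace(dur_dist(2, 0.1, 0.050), :duration)
    # tone frequency, scattered around the source mean mu
    f0 = @trace(normal(mu, 3.0), :f0)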
    
    return gap, tone_duration, f0
    
end;

We'll write a function that samples an entire tone sequence. How should we sample from `tone_latents` if we want the addresses to look like `(:tone, 1) => :gap`?
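One way to answer this is to trace the call at a tuple key, `@trace(tone_latents(mu), (:tone, tone_idx))`, so that its `:gap`, `:duration` and `:f0` choices sit underneath that key. The plain-Julia snippet below only illustrates how such hierarchical addresses are built: `=>` constructs a `Pair` and is right-associative, so a full address nests to the right. The helper `tone_key` is hypothetical.

```julia
# `=>` is right-associative, so this parses as
# :source => (2 => ((:tone, 1) => :gap))
addr = :source => 2 => (:tone, 1) => :gap

# Hypothetical helper: the key under which tone `tone_idx` is traced
tone_key(tone_idx) = (:tone, tone_idx)
relative = tone_key(1) => :gap   # address of the gap relative to its source
```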

In [None]:
@gen function tone_sequence_latents(mu::Float64, n_tones::Int)
    
    ## Generate tones
    # First collect all the tone-level latent variables
    onsets = Vector{Float64}(undef, n_tones)
    offsets = Vector{Float64}(undef, n_tones)
    f0s = Vector{Float64}(undef, n_tones)
    for tone_idx=1:n_tones
        gap, tone_duration, f0s[tone_idx] = @trace(tone_latents(mu), (:tone, tone_idx))
        onsets[tone_idx] = (tone_idx == 1 ? 0 : offsets[tone_idx - 1]) + gap; 
        offsets[tone_idx] = onsets[tone_idx] + tone_duration; 
    end

    return onsets, offsets, f0s
    
end;

Write a generative function to sample the source-level latents. Write `part_dist` to create a geometric distribution that begins at 1 rather than 0.
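A geometric distribution that starts at 1 can be obtained by adding 1 to one that starts at 0. The dependency-free sketch below assumes the 0-based version counts failures before the first success, with success probability `p` (we take this to match Gen's `geometric(p)`, but check the Gen documentation). The function names are hypothetical.

```julia
using Random, Statistics

# P(N = n) = p (1 - p)^(n - 1) for n = 1, 2, 3, ...
shifted_geometric_pmf(n, p) = n >= 1 ? p * (1 - p)^(n - 1) : 0.0

function sample_shifted_geometric(rng, p)
    k = 0
    while rand(rng) >= p   # each trial succeeds with probability p
        k += 1             # count failures before the first success
    end
    return k + 1           # shift the support from {0, 1, ...} to {1, 2, ...}
end

rng = MersenneTwister(2)
n_tone_draws = [sample_shifted_geometric(rng, 0.2) for _ in 1:100_000]
mean_n_tones = mean(n_tone_draws)   # theory: 1 + (1 - 0.2) / 0.2 = 5
total_mass = sum(shifted_geometric_pmf(n, 0.6) for n in 1:200)
```

Under this reading, `part_dist(0.2)` gives about 5 tones per source on average, and `part_dist(0.6)` (used later for `n_sources`) gives about 1 + 0.4/0.6 ≈ 1.67 sources.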

In [None]:
@dist part_dist(lambda) = geometric(lambda) + 1  # geometric starts at 0, so shift the support by 1

@gen function source_latents(audio_sr::Int)
    
    # Source Parameters 
    mu = @trace(uniform(freq_to_ERB(20.0), freq_to_ERB(audio_sr/2.)), :mu)
    # Number of tones 
    n_tones = @trace(part_dist(0.2), :n_tones)
    
    return mu, n_tones
    
end;

Now we'll put this all together in a function that samples a source's latent variables (its mean frequency, number of tones, and melody) and then generates the sound produced by the source. First we sample all of the latents, and then we render the source waveform.
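The rendering step sums one waveform per tone. The real renderer, `generate_tone_wave`, lives in `rendering.jl` and is not shown, so the sketch below substitutes a hypothetical `toy_tone_wave` that produces a plain sine between onset and offset, taking its frequency in Hz. It also shows why `reduce` is given `init`: the sum is then well defined, and of the right length, even for an empty list of tones.

```julia
# Hypothetical stand-in for generate_tone_wave: a pure sine, gated on [onset, offset)
function toy_tone_wave(f_hz, onset, offset, audio_sr, scene_duration)
    n_samples = Int(floor(scene_duration * audio_sr))
    t = (0:n_samples-1) ./ audio_sr              # sample times in seconds
    gate = (onset .<= t) .& (t .< offset)        # true while the tone is on
    return gate .* sin.(2π * f_hz .* t)
end

audio_sr, scene_duration = 1000, 1.0
n_samples = Int(floor(scene_duration * audio_sr))
waves = [toy_tone_wave(100.0, 0.1, 0.3, audio_sr, scene_duration),
         toy_tone_wave(250.0, 0.5, 0.6, audio_sr, scene_duration)]

# Sum over tones, as in generate_source
source_wave = reduce(+, waves; init = zeros(n_samples))
# With `init`, an empty list of tones still yields a silent waveform of the right length
silent = reduce(+, Vector{Vector{Float64}}(); init = zeros(n_samples))
```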

In [None]:
@gen function generate_source(audio_sr::Int, scene_duration::Float64)
    
    ## Generate source-level latent variables 
    mu, n_tones = @trace(source_latents(audio_sr))
    ## Generate tone-level latent variables 
    onsets, offsets, f0s = @trace(tone_sequence_latents(mu, n_tones))
    
    # Then generate the actual sound wave for each element, calling out to the renderer
    element_waves = [generate_tone_wave(f0s[tone_idx], onsets[tone_idx], 
            offsets[tone_idx], audio_sr, scene_duration) for tone_idx = 1:n_tones]

    #Sum over all tones to produce source waveform 
    n_samples = Int(floor(scene_duration*audio_sr));
    source_wave = reduce(+, element_waves; init=zeros(n_samples))

    return source_wave
    
end;

The code above generates a single source. To generate multiple sources, we can use the generative function combinator `Map`. A generative function combinator takes one or more generative functions as input and returns a new generative function. In particular, `Map` returns a new generative function that applies its kernel (here `generate_source`) independently to each element of its vector arguments.

In [None]:
generate_sources = Map(generate_source);

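`generate_sources` takes vectors of arguments (one entry per source) and returns a vector of source waveforms of the same length. The choices made for source `i` are stored under the integer key `i`.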
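As a rough mental model (a plain-Julia analogy, not Gen's implementation), `Map` behaves like a comprehension over zipped argument vectors that also files each call's choices under its index. The kernel `toy_source` and its fake "choices" are hypothetical.

```julia
# Apply `kernel` to the i-th element of every argument vector, and keep the
# choices of call i under the integer key i (cf. :source => i => :n_tones)
function toy_map(kernel, arg_vectors...)
    results = [kernel(args...) for args in zip(arg_vectors...)]
    retvals = [r.retval for r in results]
    choices = Dict(i => r.choices for (i, r) in enumerate(results))
    return retvals, choices
end

# Hypothetical kernel: a silent waveform plus the "choices" it made
toy_source(audio_sr, scene_duration) =
    (retval = zeros(Int(floor(scene_duration * audio_sr))),
     choices = Dict(:n_tones => 1))

n_sources = 3
retvals, choices = toy_map(toy_source, fill(100, n_sources), fill(0.5, n_sources))
```

This mirrors how `generate_scene` calls `generate_sources(fill(audio_sr, n_sources), fill(scene_duration, n_sources))`.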

Lastly, we create the top level of the model. At this level, we sample the number of sources `n_sources` and generate a sound for each source (`source_waves`). Then we add up all of the source sounds into a single sound mixture, `scene_wave`. 

Our last step is to compute a gammatonegram from `scene_wave`. We compute our likelihood function `p(observed_sound | scene)` in the gammatonegram domain, so it is equal to `p(observed_cochleagram | scene_cochleagram) = N(observed_cochleagram | mu = scene_cochleagram, sigma = 1.0)`.
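A minimal sketch of this log-likelihood, assuming `noisy_matrix(scene_gram, 1.0)` (defined in `extra_distributions.jl`, not shown) adds independent Normal(0, sigma) noise to every time-frequency bin. The function name `gram_loglik` is hypothetical.

```julia
# log p(observed | predicted) under independent per-bin Gaussian noise
function gram_loglik(observed::AbstractMatrix, predicted::AbstractMatrix; sigma = 1.0)
    @assert size(observed) == size(predicted)
    residual = observed .- predicted
    return sum(@. -0.5 * log(2π * sigma^2) - residual^2 / (2 * sigma^2))
end

predicted = [10.0 20.0; 30.0 40.0]
perfect = gram_loglik(predicted, predicted)            # only the normalizing constant remains
off_by_one = gram_loglik(predicted .+ 1.0, predicted)  # each bin pays an extra 0.5
```

With `sigma = 1.0`, every unit of squared error in any bin costs 0.5 nats, so a scene whose gammatonegram is off everywhere is penalized heavily.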

In [None]:
@gen function generate_scene(scene_duration::Float64, audio_sr::Int, 
                        wts::Array{Float64,1}, gtg_params::GammatonegramParams)
    
    # how many sources are there
    n_sources = @trace(part_dist(0.6), :n_sources)
    source_waves = @trace(generate_sources(fill(audio_sr,n_sources),
            fill(scene_duration,n_sources)),:source)

    # sum over all sources to produce scene waveform
    n_samples = Int(floor(scene_duration*audio_sr))
    scene_wave = reduce(+, source_waves; init=zeros(n_samples))
    
    # generate gammatonegram from waveform
    scene_gram, t = gammatonegram(scene_wave, wts, audio_sr, gtg_params) 
    
    # add observation noise
    @trace(noisy_matrix(scene_gram, 1.0), :scene)
    
    return scene_gram, scene_wave, source_waves
    
end;

## Show simulated scenes

Let's now try generating some traces. First we prepare the arguments to our generative function, `generate_scene`. 

In [None]:
scene_duration = 2.0
audio_sr = 20000
gtg_params = GammatonegramParams(0.025,0.010, 20.0, 64, 0.50, 1e-6, 1e-80, 20.0)
wts, f = gtg_weights(audio_sr,gtg_params)
args = (scene_duration, audio_sr, wts, gtg_params);

Write code to obtain a scene trace.

In [None]:
trace = simulate(generate_scene, args)
scene_gram, scene_wave, source_waves = get_retval(trace);
println(get_submap(get_choices(trace),:source))

We can use the `plot_gtg` function to look at the scene as a gammatonegram.

In [None]:
 plot_gtg(scene_gram, scene_duration, audio_sr, 20.0, 100.0)

We can use `plot_sources` to look at the source structure of our trace.

In [None]:
plot_sources(trace, scene_gram, 0; save=false)

Now we'll generate several scenes to get a sense of the distribution over scenes that our probabilistic program describes.

In [None]:
figure(figsize=(12, 4))
for i=1:12
    subplot(2, 6, i)
    trace = simulate(generate_scene, args)
    scene_gram, scene_wave, source_waves = get_retval(trace);
    n_sources = trace[:n_sources]
    ax = gca()
    ax.set_title("n_sources: $n_sources")
    plot_gtg(scene_gram, scene_duration, audio_sr, 20.0, 100.0)
end
tight_layout()

## MCMC inference

As with the model, we'll write our inference moves one level of the model hierarchy at a time.

### Inference for tone level variables

Let's start with inference moves for the tone-level latents. For tones, the latent variables are the `gap`, `duration`, and frequency (`f0`).

We sampled our tones with `gap` and `duration` variables in order to ensure that they didn't overlap. However, because each tone's onset is the previous tone's offset plus its gap, changing a single `gap` shifts every later tone in the source. We'd rather not shift an entire set of tones if we're trying to change the timing of just a single tone.

So, first we'll write a function to retrieve the "absolute timings" of each tone (i.e. onset and offset) from the relative timings (gap and duration).
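The same conversion can be sketched on plain vectors, together with its inverse. This is an illustration with hypothetical function names, not a replacement for `absolute_timing`, which reads the values from a choice map. The last line shows the problem described above: changing only the first gap shifts every onset.

```julia
# Relative (gap, duration) -> absolute (onset, offset)
function to_absolute(gaps, durations)
    offsets = cumsum(gaps .+ durations)   # each tone ends after all earlier gaps and durations
    onsets = offsets .- durations
    return onsets, offsets
end

# Absolute (onset, offset) -> relative (gap, duration)
function to_relative(onsets, offsets)
    durations = offsets .- onsets
    prev_offsets = [0.0; offsets[1:end-1]]  # the first gap is measured from time 0
    gaps = onsets .- prev_offsets
    return gaps, durations
end

gaps, durations = [0.1, 0.2, 0.05], [0.3, 0.05, 0.4]
onsets, offsets = to_absolute(gaps, durations)   # onsets ≈ [0.1, 0.6, 0.7]
gaps2, durations2 = to_relative(onsets, offsets) # recovers gaps and durations

# Lengthening only the first gap by 0.1 s moves every later onset by 0.1 s too
shifted_onsets, _ = to_absolute([0.2, 0.2, 0.05], durations)
```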

In [None]:
function absolute_timing(source_choices)
    #= Get a list of onset/offset pairs for a source =#

    n_tones = source_choices[:n_tones];
    total_time = 0; t = []
    for i = 1:n_tones
        onset = total_time + source_choices[(:tone, i) => :gap]
        duration = source_choices[(:tone, i) => :duration] 
        offset = onset + duration
        push!(t, [onset, offset])
        total_time += source_choices[(:tone, i) => :gap] + source_choices[(:tone, i) => :duration] 
    end
    return t
    
end
absolute_timing(get_submap(get_choices(trace),:source=>1))

Because our model contains `gap` and `duration`, but for these moves we would rather work with `onset` and `offset`, we need a kind of inference move we haven't seen before: Metropolis-Hastings paired with an involution.

An involution move lets you make proposals **by sampling different variables, at different addresses from those in the** `trace`, and then deterministically assigning values to the model's trace addresses to produce a `new_trace`. 

There are three parts to writing an involution proposal:
1. a generative **randomness function** where you sample the randomness that is necessary for your proposal -- these are called the `fwd_choices`. 
2. a deterministic **involution function** where you transform the proposed random variables into the model trace addresses to make a `new_trace` -- these are called the `new_choices`. 
3. you need to define what the traced values in the **randomness function** would be, if you were using it to go from `new_trace` to `trace` -- these are called the `bwd_choices`.

An involution is a bijection that is its own inverse. Applying it twice therefore returns to the starting point: in the sketch below, `round_trip_trace` equals `trace` and `round_trip_fwd_choices` equals `fwd_choices`.

```
fwd_choices ~ randomness(trace)
new_trace, bwd_choices = involution(trace, fwd_choices)
round_trip_trace, round_trip_fwd_choices = involution(new_trace, bwd_choices)
```

Later we'll see how we can use involutions for inference moves that change dimensionality.
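A toy check of the round-trip property, using the arithmetic of the onset move written below. The state is `(gap, duration, u)`, where `u` is the value drawn by the randomness function. The map writes `u` into the gap, adjusts the duration so that `gap + duration` is unchanged, and hands back the old gap as the backward choice. The name `onset_map` is hypothetical.

```julia
# (old gap, old duration, proposed gap) -> (new gap, new duration, backward choice)
onset_map((gap, duration, u)) = (u, gap + duration - u, gap)

state = (0.3, 0.2, 0.1)
forward = onset_map(state)       # ≈ (0.1, 0.4, 0.3)
round_trip = onset_map(forward)  # applying the map again restores the state
```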

##### Gap (or onset) inference move

We still propose a new `gap` for the chosen tone, but we deterministically change its `duration` by the opposite amount so that its offset stays fixed. As a result, only the `onset` of the chosen tone changes; every other tone keeps its timing.

In [None]:
@gen function onset_randomness(trace, source_idx::Int, tone_idx::Int)
    
    source_choices = get_submap(get_choices(trace),:source=>source_idx)
    abs_t = absolute_timing(source_choices)
    
    old_gap = trace[:source => source_idx => (:tone, tone_idx) => :gap]
    old_duration = trace[:source => source_idx => (:tone, tone_idx) => :duration]
    
    last_offset = tone_idx == 1 ? 0 : abs_t[tone_idx-1][2]
    this_offset = abs_t[tone_idx][2]
    
    new_gap = @trace(uniform(0, old_gap + old_duration - 0.050), :proposed_gap)
    new_duration = old_duration + (old_gap - new_gap) 
    
    return old_gap, old_duration, new_gap, new_duration
    
end

function onset_involution(trace, fwd_choices, fwd_ret, proposal_args)

    source_idx, tone_idx = proposal_args
    old_gap, old_duration, new_gap, new_duration = fwd_ret
    
    bwd_choices = choicemap()
    bwd_choices[:proposed_gap] = old_gap
    
    new_choices = choicemap()
    new_choices[:source => source_idx => (:tone, tone_idx)=>:gap] = new_gap
    new_choices[:source => source_idx => (:tone, tone_idx)=>:duration] = new_duration
    
    new_trace, weight = update(trace, get_args(trace), (), new_choices)
    return new_trace, bwd_choices, weight
    
end
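Two properties of this move can be checked with a few lines of arithmetic. First, the tone's offset, measured from the previous offset as `gap + duration`, is unchanged. Second, the map `(gap, duration, u) -> (u, gap + duration - u, gap)` is linear with Jacobian determinant ±1, i.e. volume-preserving, which is consistent with the acceptance ratio computed in `run_involution` below containing no Jacobian term. The numbers are illustrative.

```julia
using LinearAlgebra

# Jacobian of the onset move (rows: outputs, columns: inputs gap, duration, u)
J = [0.0 0.0  1.0;   # new gap      = u
     1.0 1.0 -1.0;   # new duration = gap + duration - u
     1.0 0.0  0.0]   # backward u   = old gap
abs_det = abs(det(J))

old_gap, old_duration, u = 0.3, 0.2, 0.1
new_gap, new_duration = u, old_gap + old_duration - u
# u is drawn from uniform(0, old_gap + old_duration - 0.050), so new_duration >= 0.050
```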

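To see what the onset involution proposal does, we'll unpack the "behind-the-scenes" involution code. `run_involution` proposes forward choices, applies the involution, and scores the backward choices; the quantity `a = weight - fwd_score + bwd_score` it prints is the log Metropolis-Hastings acceptance ratio, but the function does not itself accept or reject the move: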

In [None]:
function run_involution(randomness, involution, randomness_args, init_trace; print_choices=false)

    (fwd_choices, fwd_score, fwd_ret) = propose(randomness, (init_trace, randomness_args...))
    (new_trace, bwd_choices, weight) = involution(init_trace, fwd_choices, fwd_ret, randomness_args)
    (bwd_score, _) = assess(randomness, (new_trace,randomness_args...), bwd_choices)
    a = weight - fwd_score + bwd_score
    println("Weight: $(weight) - Fwd: $(fwd_score) + Bwd: $(bwd_score) = $a")
    if print_choices
        println("Forward choices:")
        println(fwd_choices)
        println("Backward choices")
        println(bwd_choices)
    end
    return fwd_choices, bwd_choices, new_trace

end
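For reference, a Metropolis-Hastings step would accept the new trace with probability `min(1, exp(a))`. A minimal sketch of that decision, done in log space to avoid overflow; `mh_accept` is a hypothetical name, and this is only an approximation of what Gen's involutive `mh` does with this quantity.

```julia
using Random, Statistics

# Accept with probability min(1, exp(log_ratio))
mh_accept(rng, log_ratio) = log(rand(rng)) < log_ratio

rng = MersenneTwister(3)
always = all(mh_accept(rng, 0.0) for _ in 1:1000)     # exp(0) = 1: always accepted
never = any(mh_accept(rng, -Inf) for _ in 1:1000)     # e.g. a proposal outside the prior's support
rate = mean(mh_accept(rng, log(0.25)) for _ in 1:100_000)  # about 0.25
```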

In [None]:
# plot_sources(trace, scene_gram, 0; save=false)
fwd_choices, bwd_choices, new_trace = run_involution(onset_randomness, onset_involution, (1,1,), trace; print_choices=true)
plot_sources(new_trace, scene_gram, 1; save=false)

##### Duration (or offset) inference move

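The randomness function samples `:proposed_duration`: from the duration prior when the tone is the last one in its source, and otherwise uniformly between the 50 ms minimum and the time remaining before the next tone's onset. The involution records the old duration in `bwd_choices` and writes the new duration (and, when there is a following tone, that tone's adjusted gap) into `new_choices`.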
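The arithmetic for a tone that is followed by another tone, with illustrative numbers: the following tone's gap absorbs the change in duration, so its onset stays put, and bounding the proposed duration by `next_onset - this_onset` keeps that gap non-negative.

```julia
this_onset, old_duration, old_next_gap = 0.5, 0.2, 0.3
next_onset = this_onset + old_duration + old_next_gap   # 1.0

upper = next_onset - this_onset        # longest allowed duration (0.5)
u = 0.45                               # a proposed duration in (0.050, upper)
new_duration = u
new_next_gap = old_next_gap + (old_duration - new_duration)
new_next_onset = this_onset + new_duration + new_next_gap
# the backward choice is simply the old duration, 0.2
```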

In [None]:
@gen function duration_randomness(trace, source_idx::Int, tone_idx::Int)
    
    scene_duration, audio_sr, wts, gtg_params = get_args(trace)
    
    source_choices = get_submap(get_choices(trace),:source=>source_idx)
    n_tones = source_choices[:n_tones]
    abs_t = absolute_timing(source_choices)
    
    old_duration = trace[:source => source_idx => (:tone, tone_idx)=>:duration]
    
    if tone_idx == n_tones
        
        #if a tone is at the end of a sequence
        #no following tone constrains it, so sample from the duration prior
        new_duration = @trace(dur_dist(2, 0.1, 0.050), :proposed_duration)
    
        return false, old_duration, false, new_duration
    
    else
    
        old_gap = trace[:source => source_idx => (:tone, tone_idx + 1) => :gap]
        next_onset = abs_t[tone_idx + 1][1]
        this_onset = abs_t[tone_idx][1]
        #the new offset must not pass the next tone's onset
        new_duration = @trace(uniform(0.050, next_onset - this_onset), :proposed_duration)
        new_gap = old_gap + (old_duration - new_duration)
        
        return old_gap, old_duration, new_gap, new_duration 
        
    end
    
end

function duration_involution(trace, fwd_choices, fwd_ret, proposal_args)
    
    source_idx, tone_idx = proposal_args
    old_gap, old_duration, new_gap, new_duration = fwd_ret
    
    bwd_choices = choicemap()
    new_choices = choicemap()
    if old_gap == false 
    
        bwd_choices[:proposed_duration] = old_duration
        new_choices[:source => source_idx => (:tone, tone_idx) => :duration] = new_duration
        
    else
        
        bwd_choices[:proposed_duration] = old_duration

        new_choices[:source => source_idx => (:tone, tone_idx + 1) => :gap] = new_gap
        new_choices[:source => source_idx => (:tone, tone_idx) => :duration] = new_duration
        
    end
    
    new_trace, weight = update(trace, get_args(trace), (), new_choices)
    return new_trace, bwd_choices, weight
    
end

In [None]:
# plot_sources(trace, scene_gram, 0; save=false)
fwd_choices, bwd_choices, new_trace = run_involution(duration_randomness, duration_involution, (1,1,), trace; print_choices=true)
plot_sources(new_trace, scene_gram, 1; save=false)

##### Tone frequency inference move

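For the frequency of a tone, we'll write a Gaussian drift proposal. Because it traces its new value directly at the model address `:source => source_idx => (:tone, tone_idx) => :f0`, it can be used as an ordinary custom proposal to `mh`, without an involution.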
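Why a drift proposal is convenient: `normal(f, 1.0)` is symmetric, so the density of proposing `f_new` from `f_old` equals the density of proposing `f_old` from `f_new`. The proposal terms then cancel in the acceptance ratio, and acceptance depends only on how the model scores the two traces. A small check with illustrative values:

```julia
# Log density of normal(mu, sigma)
normal_logpdf(x, mu, sigma) = -0.5 * log(2π * sigma^2) - (x - mu)^2 / (2 * sigma^2)

f_old, f_new = 12.0, 12.7
forward = normal_logpdf(f_new, f_old, 1.0)   # propose f_new from f_old
backward = normal_logpdf(f_old, f_new, 1.0)  # propose f_old from f_new
```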

In [None]:
@gen function tonefreq_gaussian_drift(trace, source_idx::Int, tone_idx::Int)
    f = trace[:source => source_idx => (:tone, tone_idx) => :f0]
    @trace(normal(f, 1.0), :source => source_idx => (:tone, tone_idx) => :f0)
end;

In [None]:
(proposal_choices, _, _) = propose(tonefreq_gaussian_drift, (trace, 1, 1,))
(new_trace, _, _, _) = update(trace, args, (),proposal_choices)
plot_sources(new_trace, scene_gram, 1; save=false)

### Inference for source level variables

##### Proposal for mean frequency of a source

We can similarly write a drift proposal for the mean frequency of a source. We use a `truncated_normal` distribution (defined in `extra_distributions.jl`) because `:mu` is originally drawn from a `uniform` distribution that has zero probability density outside certain bounds:

In [None]:
@gen function mu_gaussian_drift(trace, source_idx::Int)
    scene_duration, audio_sr, wts, gtg_params = get_args(trace)
    f = trace[:source => source_idx => :mu]
    # same bounds as the uniform prior on :mu in source_latents
    @trace(truncated_normal(f, 4., freq_to_ERB(20.0), freq_to_ERB(audio_sr/2.)), 
        :source => source_idx => :mu)
end;
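Unlike the plain drift, a truncated normal is not symmetric near its bounds, because its normalizing constant depends on its mean. The dependency-free sketch below computes that constant with the trapezoid rule (a stand-in for the normal CDF) and shows the forward and backward log densities differing when the old value sits close to the lower bound. The bounds are arbitrary illustrative numbers, not the notebook's ERB limits. As far as we understand Gen's `mh` with a custom proposal, both directions are scored, so this asymmetry is accounted for rather than ignored.

```julia
normal_logpdf(x, mu, sigma) = -0.5 * log(2π * sigma^2) - (x - mu)^2 / (2 * sigma^2)

# Mass of normal(mu, sigma) inside [lower, upper], by the trapezoid rule
function mass_inside(mu, sigma, lower, upper; n = 20_001)
    xs = range(lower, upper; length = n)
    dens = exp.(normal_logpdf.(xs, mu, sigma))
    return step(xs) * (sum(dens) - (dens[1] + dens[end]) / 2)
end

truncated_logpdf(x, mu, sigma, lower, upper) =
    normal_logpdf(x, mu, sigma) - log(mass_inside(mu, sigma, lower, upper))

lower, upper, sigma = 0.0, 40.0, 4.0
mu_old, mu_new = 1.0, 5.0      # mu_old is close to the lower bound
forward = truncated_logpdf(mu_new, mu_old, sigma, lower, upper)
backward = truncated_logpdf(mu_old, mu_new, sigma, lower, upper)
```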

##### Proposal for number of tones in a source

For changing the number of tones in a source, we will also write an inference move. It would be possible to use an ordinary `mh` move, i.e.:

`trace, accepted = mh(trace, select(:source => source_idx => :n_tones))`

however, this would only ever remove or add tones at the end of a sequence.

In order to choose to remove or add a tone in the middle of a sequence, we need to write an involution move because it requires making a new random choice: where in the middle of the sequence should we add/remove a tone? (Note that we'll restrict the change in the number of tones to plus or minus 1).

In a sense, this is not purely a source-level inference move, because we also need to choose the tone-level variables of any tone that is added.
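The choice of where to add a tone can be sketched on its own. Following `ntones_randomness`, an existing gap can host a new tone only if it is longer than the 50 ms minimum duration, and appending after the last tone is always allowed; the allowed positions are then chosen uniformly. The function name is hypothetical.

```julia
# Probabilities over insertion positions 1..n_tones+1
function insertion_probs(gaps; min_duration = 0.050)
    allowed = [g > min_duration for g in gaps]  # room for a new tone inside gap i?
    push!(allowed, true)                        # the end of the sequence always has room
    weights = Float64.(allowed)
    return weights ./ sum(weights)
end

probs = insertion_probs([0.02, 0.3, 0.1])   # [0.0, 1/3, 1/3, 1/3]
```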

In [None]:
@gen function ntones_randomness(trace, source_idx::Int)

    source_choices = get_submap(get_choices(trace),:source=>source_idx)
    n_tones = source_choices[:n_tones]
    
    #To keep it simple, we'll never get rid of a source by getting rid of a tone
    p = n_tones == 1 ? 1.0 : 0.5
    add_tone = @trace(bernoulli(p),:add_tone)
    
    if add_tone
        
        #since we have a minimum value for the duration of a tone,
        #find a gap in the sequence where the tone can fit 
        enough_space = [source_choices[(:tone, i) => :gap] > 0.050 for i = 1:n_tones]
        #add an extra possibility to add a tone onto the end
        push!(enough_space, true)
        enough_space = Int.(enough_space)
        enough_space = enough_space ./ sum(enough_space)
        
        #sample the index of the tone that will be added
        add_idx = @trace(categorical(enough_space), :add_idx)
        
        if add_idx <= n_tones
            old_gap = source_choices[(:tone,add_idx)=>:gap]
            add_duration = @trace(uniform(0.050, old_gap), :add_duration)
            add_gap = @trace(uniform(0, old_gap-add_duration), :add_gap)
        elseif add_idx == n_tones + 1
            add_gap = @trace(gamma(2,0.1), :add_gap)
            add_duration = @trace(dur_dist(2,0.1,0.050), :add_duration)
        end
        add_frequency = @trace(normal(source_choices[:mu], 3.0), :add_frequency)
        
    else #delete tone
        
        delete_idx = @trace(uniform_discrete(1,n_tones), :delete_idx)
        
    end
    
    return add_tone
    
end

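The `involution` builds `new_choices` by shifting the indices of the tones without changing their latent variables (see the `set_submap!` calls in the add and delete branches). The `bwd_choices` mirror the forward choices: adding a tone at `add_idx` is undone by deleting the tone at `delete_idx = add_idx`, and deleting a tone is undone by re-adding it with its old gap, duration, and frequency.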

In [None]:
function ntones_involution(trace, fwd_choices, fwd_ret, proposal_args)
        
    source_idx, = proposal_args
    add_tone = fwd_ret
    old_choices = get_submap(get_choices(trace),:source => source_idx)
    abs_t = absolute_timing(old_choices)
    
    bwd_choices = choicemap()
    new_choices = choicemap()  
    bwd_choices[:add_tone] = ~add_tone
    
    if add_tone
        
        ##Add_element
        new_choices[:source => source_idx => :n_tones] = old_choices[:n_tones] + 1

        #Add in new timing info 
        add_idx = fwd_choices[:add_idx]
        new_choices[:source => source_idx => (:tone, add_idx) => :gap] = fwd_choices[:add_gap]
        new_choices[:source => source_idx => (:tone, add_idx) => :duration] = fwd_choices[:add_duration]
        new_choices[:source => source_idx => (:tone, add_idx) => :f0] = fwd_choices[:add_frequency]
         
        #If there are any elements after the added element: 
        if add_idx != old_choices[:n_tones] + 1
            #Need to change the "gap" of the tone immediately after the add_element, in order to retain the same onset
            new_choices[:source => source_idx => (:tone, add_idx + 1) => :gap] = old_choices[(:tone, add_idx) => :gap] - fwd_choices[:add_gap] - fwd_choices[:add_duration]
            new_choices[:source => source_idx => (:tone, add_idx +1) => :duration] = old_choices[(:tone, add_idx) => :duration]
            new_choices[:source => source_idx => (:tone, add_idx + 1) => :f0] = old_choices[(:tone, add_idx) => :f0]

            #All elements after (the element immediately following the added element) are simply shifted in index
            old_idxs = add_idx + 1 == old_choices[:n_tones] + 1 ? [] : add_idx+1:old_choices[:n_tones]
            for old_tone_idx = old_idxs
                set_submap!(new_choices, :source => source_idx => (:tone, old_tone_idx + 1), get_submap(old_choices,(:tone, old_tone_idx)))
            end
            
        end
                                                            
        bwd_choices[:delete_idx] = add_idx
        
        
    else
        
        new_choices[:source => source_idx => :n_tones] = old_choices[:n_tones] - 1
        delete_idx = fwd_choices[:delete_idx]
        
        #If the deleted tone is in the middle of a sequence:
        if delete_idx < old_choices[:n_tones]
            #Adjust the wait time of the element following the one that was removed
            last_offset = delete_idx > 1 ? abs_t[delete_idx - 1][2] : 0
            new_choices[:source => source_idx => (:tone, delete_idx) => :gap] = abs_t[delete_idx+1][1] - last_offset
            new_choices[:source => source_idx => (:tone, delete_idx) => :duration] = old_choices[(:tone, delete_idx + 1) => :duration]
            new_choices[:source => source_idx => (:tone, delete_idx) => :f0] = old_choices[(:tone, delete_idx + 1) => :f0]
        
            #All the elements after the one immediately following the deleted element must shift their indices 
            old_idxs = delete_idx == old_choices[:n_tones] - 1 ? [] : (delete_idx + 2):old_choices[:n_tones]
            for old_tone_idx = old_idxs  
                set_submap!(new_choices, :source => source_idx => (:tone, old_tone_idx-1), get_submap(old_choices,(:tone, old_tone_idx)))
            end
        end
                
        #Define the backward choices by putting the removed element back in
        bwd_choices[:add_idx] = delete_idx        
        bwd_choices[:add_gap] = old_choices[(:tone, delete_idx) => :gap]  
        bwd_choices[:add_duration] = old_choices[(:tone, delete_idx) => :duration] 
        bwd_choices[:add_frequency] = old_choices[(:tone, delete_idx) => :f0]   
        
    end
    
    new_trace, weight = update(trace, get_args(trace), (), new_choices)
    return new_trace, bwd_choices, weight
    
end
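The index bookkeeping can be checked on plain vectors. In this sketch (hypothetical `ToyTone`, `insert_tone`, `delete_tone`), inserting a tone at position `k` carves it out of the gap before the old tone `k`, whose onset is preserved, and deleting it merges the gaps again, so the two operations undo each other, just as the add and delete branches of `ntones_involution` do.

```julia
struct ToyTone
    gap::Float64
    duration::Float64
    f0::Float64
end

function insert_tone(tones, k, add_gap, add_duration, add_f0)
    added_tone = ToyTone(add_gap, add_duration, add_f0)
    k == length(tones) + 1 && return [tones; [added_tone]]   # appended at the end
    old = tones[k]
    # the old tone k keeps its onset: its gap shrinks by the space the new tone takes
    shifted = ToyTone(old.gap - add_gap - add_duration, old.duration, old.f0)
    return [tones[1:k-1]; [added_tone, shifted]; tones[k+1:end]]
end

function delete_tone(tones, k)
    k == length(tones) && return tones[1:end-1]
    removed, following = tones[k], tones[k+1]
    # the following tone keeps its onset: its gap absorbs the removed tone and its gap
    merged = ToyTone(removed.gap + removed.duration + following.gap, following.duration, following.f0)
    return [tones[1:k-1]; [merged]; tones[k+2:end]]
end

tones = [ToyTone(0.1, 0.2, 10.0), ToyTone(0.5, 0.1, 12.0)]
added = insert_tone(tones, 2, 0.1, 0.1, 11.0)   # the old second tone's gap becomes ≈ 0.3
restored = delete_tone(added, 2)                # back to the original two tones
```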

In [None]:
# plot_sources(trace, scene_gram, 0; save=false)
fwd_choices, bwd_choices, new_trace = run_involution(ntones_randomness, ntones_involution, (1,), trace; print_choices=false)
plot_sources(new_trace, scene_gram, 1; save=false)

##### Proposal to switch a tone from one source to another

There is no sampled latent variable that says "the `tone` with this {`gap`, `duration`, `frequency`} belongs to `source` 1". Instead, this assignment is implied by the structure of the generative function (and therefore of the `trace`). 

Therefore, to move a `tone` from one `source` to another, we need an `involution`. We'd also like to be able to move a tone into a brand-new `source`, which changes the dimensionality of the trace. 

Keeping the move its own inverse requires handling each way the dimensionality can change, both by restricting the random choices and by writing a separate branch of the involution for each case:
1. If the origin source has only one tone, the switched tone always moves into an existing source, and `n_sources` decreases by 1.
2. If the origin source has more than one tone, the tone can be moved to a new source, and `n_sources` increases by 1.
3. The tone is switched between two existing sources, and `n_sources` remains the same.
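The slot search in `switch_randomness` can be sketched as a standalone function: a tone `[onset, offset]` fits into another source before its first tone, between two consecutive tones, or after its last tone, and otherwise it would overlap. `find_slot` is a hypothetical name; returning slot 1 for an empty source mirrors how the involution places a tone that starts a new source.

```julia
# Insertion index for a tone [onset, offset] in a source with the given
# [onset, offset] timings, or `nothing` if it would overlap a tone there
function find_slot(timings, onset, offset)
    isempty(timings) && return 1   # a brand-new source: the tone is its only element
    n = length(timings)
    for j in 1:n+1
        fits = if j == 1
            0 < onset && offset < timings[1][1]                    # before the first tone
        elseif j <= n
            timings[j-1][2] < onset && offset < timings[j][1]      # between tones j-1 and j
        else
            timings[n][2] < onset                                  # after the last tone
        end
        fits && return j
    end
    return nothing
end

dest = [[0.1, 0.3], [0.6, 0.8]]   # destination source timings
```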

In [None]:
@gen function switch_randomness(trace)
    
    scene_duration, audio_sr, wts, gtg_params = get_args(trace)
    old_choices = get_choices(trace)
    
    #onset/offset information for each element in each source: all source timings
    all_source_timings = []
    old_n_sources = old_choices[:n_sources]
    for i = 1:old_n_sources
        #list of lists of times
        #[ (element 1)[onset, offset], (element2)[onset, offset], ... ]
        old_abs_timings = absolute_timing(get_submap(old_choices, :source => i))
        push!(all_source_timings,old_abs_timings) 
    end
    
    origin = @trace(uniform_discrete(1, old_n_sources), :origin)
    old_n_elements = old_choices[:source => origin => :n_tones]
    element_idx = @trace(uniform_discrete(1,old_n_elements),:element_idx)
    onset = all_source_timings[origin][element_idx][1];
    offset = all_source_timings[origin][element_idx][2];
    
    #Find the sources into which an element can be switched 
    source_switch = []; which_spot = [];
    for i = 1:old_choices[:n_sources]
        if i == origin 
            append!(source_switch, 0); append!(which_spot, 0)
        else
            source_nt = old_choices[:source => i => :n_tones]
            timings = all_source_timings[i];
            for j = 1:source_nt + 1

                #switch into spot before the first element
                if j == 1 
                    fits = (0 < onset) && (offset < timings[j][1])
                elseif 1 < j <= source_nt
                    fits = (timings[j-1][2] < onset) && (offset < timings[j][1])
                elseif j == source_nt + 1
                    fits = (timings[j - 1][2] < onset)
                end

                if fits
                    append!(source_switch, 1)
                    append!(which_spot, j)
                    break
                elseif j == source_nt + 1
                    append!(source_switch, 0)
                    append!(which_spot, 0)
                end

            end
        end
    end
    
    #Decide whether to move the element into an existing source
    #Or make a new source, where it will be the only element
    switch_to_existing_source = sum(source_switch)
    switch_to_new_source = old_n_elements > 1 ? 1 : 0     
    switch_weights = [switch_to_existing_source, switch_to_new_source]
    if sum(switch_weights) == 0
        return "abort"
    end

    ps = switch_weights./sum(switch_weights)
    new_source = @trace(bernoulli(ps[2]), :new_source)
    #Decide the idx of the destination source
    #If it's a new source, it can go before any of the old sources or at the end
    #If it's an old source, you need to choose from the ones in source_switch
    destination_ps = new_source ? fill(1/(old_choices[:n_sources] + 1), old_choices[:n_sources] + 1) : source_switch./sum(source_switch)
    destination = @trace(categorical(destination_ps), :destination)
    
    ## change the means of the sources, so that there's the best possibility for acceptance
    if ~new_source                                                                                            
        dest_f0s = [trace[:source => destination => (:tone, j) => :f0] for j = 1:old_choices[:source => destination => :n_tones]]
    else
        dest_f0s = []
    end
    push!(dest_f0s, old_choices[:source => origin => (:tone, element_idx) => :f0])
    m = mean(dest_f0s); s = length(dest_f0s) > 1 ? std(dest_f0s) : 3.0; l = freq_to_ERB(20.0); u = freq_to_ERB(audio_sr/2.);
    @trace(truncated_normal(m, s, l, u), :dest => :mu)                                                                                                               
                                                                                                                    
    if old_choices[:source => origin => :n_tones] > 1  
        orig_f0s = [trace[:source => origin => (:tone, j) => :f0] for j in 1:trace[:source => origin => :n_tones] if j != element_idx]
        m = mean(orig_f0s); s = length(orig_f0s) > 1 ? std(orig_f0s) : 3.0; 
        @trace(truncated_normal(m, s, l, u), :orig => :mu) 
    end 
                                                                                                                                                                                                                                      
    return which_spot, all_source_timings
               
end                                                                                                                                    
                                                    
function switch_involution(trace, fwd_choices, fwd_ret, proposal_args)
    
    if fwd_ret == "abort"
        return trace, fwd_choices, 0
    end
                    
    #we need to specify how to go backwards
    #and how to construct the new trace
    bwd_choices = choicemap()
    new_choices = choicemap()
    which_gaps = fwd_ret[1]; all_source_timings = fwd_ret[2];
    
    scene_duration, audio_sr, wts, gtg_params = get_args(trace)
    old_choices = get_choices(trace); 
    old_n_sources = old_choices[:n_sources]; 
    
    ## indexes for moving an element from origin to destination source
    origin_idx = fwd_choices[:origin]
    new_source = fwd_choices[:new_source]
    destination_idx = fwd_choices[:destination]
    element_switch_idx = fwd_choices[:element_idx]
                                                                                                                             
    old_origin_nt = old_choices[:source => origin_idx => :n_tones]
    old_destination_nt = new_source ? 0 : old_choices[:source => destination_idx => :n_tones]
    which_gap = new_source ? 1 : which_gaps[destination_idx]
    element_attributes = [:gap, :duration, :f0]
    element_attributes_no_wait = [:duration, :f0]

    ##Get all the properties of the switch element in the new source
    switch_element = Dict()
    #absolute onset and offset stay the same, as do duration and f0
    switch_element[:onset] = all_source_timings[origin_idx][element_switch_idx][1] 
    switch_element[:offset] = all_source_timings[origin_idx][element_switch_idx][2]         
    switch_element[:duration]= old_choices[:source => origin_idx => (:tone, element_switch_idx) => :duration]
    switch_element[:f0]=old_choices[:source => origin_idx => (:tone, element_switch_idx) => :f0]
    #gap depends on what is before the switch_element in the destination stream
    prev_offset = (which_gap == 1) ? 0 : all_source_timings[destination_idx][which_gap - 1][2] 
    switch_element[:gap] = switch_element[:onset] - prev_offset;

    ##Compute the new gaps of the elements following the switch element, in both destination and origin
    #inserting switch_element before an element in destination source
    if which_gap <= old_destination_nt
        dest_after_gap = all_source_timings[destination_idx][which_gap][1] - switch_element[:offset]
    end
    #removing switch_element before an element in origin source
    if element_switch_idx < old_origin_nt
        prev_offset = element_switch_idx == 1 ? 0 : all_source_timings[origin_idx][element_switch_idx - 1][2]
        orig_after_gap = all_source_timings[origin_idx][element_switch_idx + 1][1] - prev_offset
    end
    
    if old_origin_nt == 1 
        # If the origin stream had only one element in it, it should be removed
        # The switch_element will not be moved into a new stream,
        # so n_sources should always decrease by 1
        
        # new_source = false
        # destination_idx chooses an existing source 
        
        # if destination_idx is larger than origin_idx
        # the idx of the destination source needs to be shifted down one
        # and any sources after the destination source need to be shifted down one
        
        # if destination_idx is smaller than origin_idx
        # the idx of the destination source can remain the same, 
        # but others may be shifted down one 

        new_choices[:n_sources] = old_n_sources - 1
        # For every new source index, find the old index whose contents it receives
        new_idx = 1:(old_n_sources - 1)
        # Old sources after the origin move down one slot, so new index n reads from old index n + 1
        old_idx = [(n >= origin_idx ? (n + 1) : n) for n in new_idx] 
        # Sources before the origin keep their index (re-copying them is a harmless no-op);
        # the destination source is excluded here because it is rebuilt separately below
        matching_new_idx = [new_idx[i] for i in 1:length(new_idx) if (old_idx[i] != destination_idx)]
        old_idx = [o for o in old_idx if (o != destination_idx)]
        
        ## Copy every remaining source into its (possibly shifted) new index
        for i = 1:length(old_idx)
            set_submap!(new_choices, :source=>matching_new_idx[i], get_submap(old_choices,:source=>old_idx[i]))
        end
        
        ##Deal with destination source specifically
        old_destination_idx = destination_idx
        new_destination_idx = old_destination_idx > origin_idx ? old_destination_idx - 1 : old_destination_idx 
        old_nt = old_choices[:source=>old_destination_idx=>:n_tones]
                                                
        # Set the destination source's attributes: one more tone, and the proposed mu
        new_choices[:source => new_destination_idx => :n_tones] = old_nt + 1
        new_choices[:source => new_destination_idx => :mu] = fwd_choices[:dest => :mu]

        #Switch element
        for a in element_attributes
            new_choices[:source => new_destination_idx => (:tone, which_gap) => a] = switch_element[a]
        end
                                
        #All elements before the switch_element stay the same
        if which_gap > 1
            for j = 1:which_gap - 1
                set_submap!(new_choices, :source => new_destination_idx => (:tone, j), get_submap(old_choices,:source=>old_destination_idx=>(:tone,j)))
            end
        end
        #If there are any elements after the switch_element they must be increased in index by one
        if which_gap <= old_nt #comes before one of the old elements
            for new_element_idx = (which_gap + 1):(old_nt + 1)
                new_choices[:source => new_destination_idx => (:tone, new_element_idx) => :gap] = (new_element_idx == (which_gap + 1)) ? dest_after_gap : old_choices[:source => old_destination_idx => (:tone,new_element_idx-1) => :gap]
                for a in element_attributes_no_wait
                    new_choices[:source => new_destination_idx => (:tone, new_element_idx) => a] = old_choices[:source=>old_destination_idx=>(:tone,new_element_idx-1)=>a]
                end 
            end
        end

        bwd_choices[:origin] = new_destination_idx
        bwd_choices[:new_source] = true

        bwd_choices[:dest => :mu] = old_choices[:source => origin_idx => :mu]
        bwd_choices[:orig => :mu] = old_choices[:source => destination_idx => :mu]

        bwd_choices[:destination] = origin_idx
        bwd_choices[:element_idx] = which_gap
            
    elseif new_source
        # we put the element in a new stream 
        # we keep the origin stream as well
        # so n_sources increases by 1
        # need to shift all the sources after the destination_idx
                                
        new_choices[:n_sources] = old_n_sources + 1
        
        ##Create new destination source with a single element in it
        new_choices[:source => destination_idx => :n_tones] = 1
        new_choices[:source => destination_idx => :mu] = fwd_choices[:dest => :mu]

        for a in element_attributes
            new_choices[:source => destination_idx => (:tone, 1) => a] = switch_element[a]
        end

        ##in origin source, move all elements down one index if they're after the switch index
        old_origin_idx = origin_idx
        new_origin_idx = origin_idx >= destination_idx ? origin_idx + 1 : origin_idx
        old_nt = old_choices[:source => origin_idx => :n_tones]
        new_choices[:source => new_origin_idx => :n_tones] = old_nt - 1
        new_choices[:source => new_origin_idx => :mu] = fwd_choices[:orig => :mu]

        if element_switch_idx > 1
            for j = 1:element_switch_idx - 1
                set_submap!(new_choices, :source => new_origin_idx => (:tone, j), get_submap(old_choices,:source=>old_origin_idx=>(:tone,j)))
            end
        end                        
        if element_switch_idx < old_nt
            for old_element_idx = (element_switch_idx + 1):old_nt
                new_choices[:source => new_origin_idx => (:tone, old_element_idx-1)=>:gap] = (old_element_idx == (element_switch_idx + 1)) ? orig_after_gap : old_choices[:source => origin_idx => (:tone, old_element_idx)=> :gap]
                for a in element_attributes_no_wait
                    new_choices[:source => new_origin_idx => (:tone, old_element_idx-1) => a] = old_choices[:source=>origin_idx=>(:tone,old_element_idx)=>a]
                end 
            end
        end
                                
        ##shift all sources after destination_idx up one
        if destination_idx < new_choices[:n_sources]
            shift_idxs = [i for i in (destination_idx+1):new_choices[:n_sources] if i != new_origin_idx]
            for i in shift_idxs
                set_submap!(new_choices, :source=>i, get_submap(old_choices, :source=>i-1))
            end
        end
                                
        bwd_choices[:origin] = destination_idx
        bwd_choices[:new_source] = false  

        bwd_choices[:dest => :mu] = old_choices[:source => origin_idx => :mu]

        bwd_choices[:destination] = new_origin_idx
        bwd_choices[:element_idx] = 1
            
    else
        # we put the element in an old stream, and keep the origin stream
        # streams do not have to be shifted 
        # new_source = false
        ##in origin source, move all elements to earlier index if they're after the switch index
        old_nt = old_choices[:source => origin_idx => :n_tones]
        new_choices[:source => origin_idx => :n_tones] = old_nt - 1
        new_choices[:source => origin_idx => :mu] = fwd_choices[:orig => :mu]

        if element_switch_idx < old_nt
            for old_element_idx = (element_switch_idx + 1):old_nt
                new_choices[:source => origin_idx => (:tone, old_element_idx-1)=>:gap] = (old_element_idx == (element_switch_idx + 1)) ? orig_after_gap : old_choices[:source => origin_idx => (:tone, old_element_idx)=>:gap]
                for a in element_attributes_no_wait
                    new_choices[:source => origin_idx => (:tone, old_element_idx-1) => a] = old_choices[:source=>origin_idx=>(:tone,old_element_idx)=>a]
                end 
            end
        end
            
        ##in destination source, insert element and then shift elements to later index
        old_nt = old_choices[:source => destination_idx => :n_tones]
        new_choices[:source => destination_idx => :n_tones] = old_nt + 1
        new_choices[:source => destination_idx => :mu] = fwd_choices[:dest => :mu]

        #Switch element
        for a in element_attributes
            new_choices[:source => destination_idx => (:tone, which_gap) => a] = switch_element[a]
        end
        #elements after switch_element
        if which_gap <= old_nt
            for new_element_idx = (which_gap + 1):(old_nt + 1)
                new_choices[:source => destination_idx =>(:tone,new_element_idx)=>:gap] = (new_element_idx == (which_gap + 1)) ? dest_after_gap : old_choices[:source => destination_idx => (:tone, new_element_idx-1)=> :gap]
                for a in element_attributes_no_wait
                    new_choices[:source => destination_idx => (:tone, new_element_idx) => a] = old_choices[:source=>destination_idx=>(:tone,new_element_idx-1)=>a]
                end
            end
        end
            
        bwd_choices[:origin] = destination_idx
        bwd_choices[:destination] = origin_idx
        bwd_choices[:element_idx] = which_gap
        bwd_choices[:new_source] = false

        bwd_choices[:dest => :mu] = old_choices[:source => origin_idx => :mu]
        bwd_choices[:orig => :mu] = old_choices[:source => destination_idx => :mu]

                                            
              
    end
    new_trace, weight = update(trace, get_args(trace), (), new_choices)
    return new_trace, bwd_choices, weight

end
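
The index bookkeeping in the `old_origin_nt == 1` and `new_source` branches is the easiest part of this involution to get wrong. As a sanity check, here is a plain-Julia sketch (no Gen) of the same arithmetic applied to a vector of stand-in source labels. The helper names are hypothetical; they only mirror the expressions used for `old_idx`, `new_destination_idx`, `new_origin_idx`, and `shift_idxs`.

```julia
# Hypothetical helpers mirroring the index arithmetic in switch_involution.
# They work on plain integers and vectors, not on Gen choice maps.

# Origin source removed: at or after origin_idx, new index n copies from old index n + 1.
old_index_after_removal(n::Int, origin_idx::Int) = n >= origin_idx ? n + 1 : n
# Where an old source (other than the origin) ends up after the removal.
new_index_after_removal(o::Int, origin_idx::Int) = o > origin_idx ? o - 1 : o

# New singleton source inserted at destination_idx: later indices copy from one slot earlier.
old_index_after_insertion(i::Int, destination_idx::Int) = i > destination_idx ? i - 1 : i
# Where an old source ends up after the insertion.
new_index_after_insertion(o::Int, destination_idx::Int) = o >= destination_idx ? o + 1 : o

old_sources = ["A", "B", "C", "D"]   # stand-ins for the per-source submaps

# Removal of the origin source "B" (origin_idx = 2)
origin_idx = 2
after_removal = [old_sources[old_index_after_removal(n, origin_idx)] for n in 1:length(old_sources)-1]
println(after_removal)    # ["A", "C", "D"]

# Insertion of a new source "X" at destination_idx = 2
destination_idx = 2
slot_after_insertion(i) = i == destination_idx ? "X" : old_sources[old_index_after_insertion(i, destination_idx)]
after_insertion = [slot_after_insertion(i) for i in 1:length(old_sources)+1]
println(after_insertion)  # ["A", "X", "B", "C", "D"]
```

Each `new_index_*` helper is the inverse of the matching `old_index_*` helper on the sources that survive, which is what keeps the forward and backward moves consistent.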

In [None]:
plot_sources(trace, scene_gram, 0; save=false)
fwd_choices, bwd_choices, new_trace = run_involution(switch_randomness, switch_involution, (), trace; print_choices=false)
plot_sources(new_trace, scene_gram, 1; save=false)
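
The gap updates at the top of `switch_involution` (`dest_after_gap` and `orig_after_gap`) are designed so that every tone that is not moved keeps its absolute onset. Below is a plain-Julia sketch of that invariant. It assumes, as the code above implies, that a tone's `:gap` is measured from the previous tone's offset, or from time 0 for the first tone. The helpers are hypothetical, and the gap assigned to the inserted tone itself is also an assumption, since that value is computed elsewhere.

```julia
# Hypothetical helpers: tones are (gap, duration) pairs; a gap is measured from
# the previous tone's offset, or from t = 0 for the first tone (assumed convention).

# Absolute (onset, offset) of every tone in a sequence.
function onsets_offsets(tones::Vector{Tuple{Float64,Float64}})
    timings = Tuple{Float64,Float64}[]
    t = 0.0
    for (gap, dur) in tones
        onset = t + gap
        push!(timings, (onset, onset + dur))
        t = onset + dur
    end
    return timings
end

# Remove tone k; the following tone's gap absorbs the vacated time (like orig_after_gap).
function remove_tone(tones, k)
    timings = onsets_offsets(tones)
    out = copy(tones)
    if k < length(tones)
        prev_offset = k == 1 ? 0.0 : timings[k-1][2]
        out[k+1] = (timings[k+1][1] - prev_offset, tones[k+1][2])
    end
    deleteat!(out, k)
    return out
end

# Insert a tone with absolute (onset, offset) at position k; the tone it displaces
# now measures its gap from the inserted tone's offset (like dest_after_gap).
function insert_tone(tones, k, onset, offset)
    timings = onsets_offsets(tones)
    prev_offset = k == 1 ? 0.0 : timings[k-1][2]
    out = copy(tones)
    if k <= length(tones)
        out[k] = (timings[k][1] - offset, tones[k][2])
    end
    insert!(out, k, (onset - prev_offset, offset - onset))  # assumed gap for the new tone
    return out
end

tones = [(0.25, 0.5), (0.125, 0.25), (0.5, 0.25)]
before = onsets_offsets(tones)
after_remove = onsets_offsets(remove_tone(tones, 2))
after_insert = onsets_offsets(insert_tone(tones, 3, 1.25, 1.5))
println(after_remove == before[[1, 3]])                          # true
println(after_insert == insert!(copy(before), 3, (1.25, 1.5)))   # true
```

The example times are exact binary fractions, so the onsets can be compared with `==`; with arbitrary floats, use `isapprox`.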

##### Proposal to completely add or remove a source 

As with tones, we could add or remove a source using the simple `mh` function with a `Selection`:

`trace, accepted = mh(trace, select(:n_sources))`

However, this would always add or remove sources at the end (as we discussed earlier with regard to tones). Unlike tones, though, the indexing of the sources is arbitrary and meaningless. Rather than writing another involution that actually changes the number of sources, we'll write a `swap` proposal that exchanges one source with the highest-indexed source, so that any source can be removed by our black-box `mh` with a `Selection`.

This move should always be accepted, because its acceptance ratio equals one (up to floating-point rounding): the reverse move chooses the same index, so the forward and backward proposal probabilities cancel, and reordering the sources leaves the model's joint probability unchanged.
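
The following plain-Julia snippet is a toy check of that argument. The model is a stand-in, not the notebook's: each source is a single number with a standard-normal prior, and the scene is the sum of the sources observed with Gaussian noise. The helpers `log_joint` and `swap_last` are hypothetical. Because the forward and reverse proposal probabilities cancel, the log acceptance ratio reduces to the change in log joint. The check uses a tolerance because summing floats in a different order can change the last few bits.

```julia
# Toy stand-in model (not the notebook's): scalar sources with a standard-normal
# prior, and a scene equal to their sum observed with Gaussian noise (σ = 0.5).
normal_logpdf(x, μ, σ) = -0.5 * log(2π) - log(σ) - 0.5 * ((x - μ) / σ)^2

function log_joint(sources::Vector{Float64}, observed::Float64)
    prior = sum(normal_logpdf(s, 0.0, 1.0) for s in sources)
    likelihood = normal_logpdf(observed, sum(sources), 0.5)
    return prior + likelihood
end

# The swap move: exchange source n1 with the highest-indexed source.
function swap_last(sources::Vector{Float64}, n1::Int)
    out = copy(sources)
    out[n1], out[end] = out[end], out[n1]
    return out
end

sources = [0.3, -1.2, 0.8, 2.0]
observed = 1.5
for n1 in eachindex(sources)
    swapped = swap_last(sources, n1)
    # Proposal probabilities (1/n_sources in each direction) cancel, so the
    # log acceptance ratio is just the change in log joint.
    log_ratio = log_joint(swapped, observed) - log_joint(sources, observed)
    println(n1, ": ", isapprox(log_ratio, 0.0; atol=1e-12))
    # Applying the same swap again restores the original order (an involution).
    @assert swap_last(swapped, n1) == sources
end
```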

In [None]:
@gen function swap_sources_randomness(trace)
    n1 = @trace(uniform_discrete(1, trace[:n_sources]), :to_move)
    return n1
end

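function swap_sources_involution(trace, fwd_choices, fwd_ret, proposal_args)
    new_choices = choicemap()
    n1 = fwd_ret  # index sampled by swap_sources_randomness
    n_sources = trace[:n_sources]
    # Exchange the entire subtree of source n1 with that of the highest-indexed source
    set_submap!(new_choices, :source => n1, get_submap(get_choices(trace), :source => n_sources))
    set_submap!(new_choices, :source => n_sources, get_submap(get_choices(trace), :source => n1))
    new_trace, = update(trace, get_args(trace), (), new_choices)
    # Swapping back requires choosing the same index, so fwd_choices also serve as the backward choices;
    # the weight is 0 because the joint probability does not depend on the order of the sources
    return new_trace, fwd_choices, 0
end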
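
To see how the swap lets the black-box `mh(trace, select(:n_sources))` delete any source, note that decreasing `:n_sources` by one discards the highest-indexed source. Here is a plain-Julia sketch on a vector of made-up source labels. `swap_to_end` and `drop_last` are hypothetical helpers, not part of the notebook.

```julia
# Hypothetical helpers on plain vectors of labels (not Gen traces).

# Exchange source n1 with the highest-indexed source, as swap_sources_involution does.
function swap_to_end(sources::Vector{String}, n1::Int)
    out = copy(sources)
    out[n1], out[end] = out[end], out[n1]
    return out
end

# Decreasing :n_sources by one keeps only the first n_sources - 1 sources.
drop_last(sources::Vector{String}) = sources[1:end-1]

sources = ["low tones", "high tones", "noise burst", "melody"]
remaining = drop_last(swap_to_end(sources, 2))
println(remaining)   # ["low tones", "melody", "noise burst"]
```

The surviving sources end up in a different order. This doesn't matter, because source indices carry no meaning in the model.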

## Putting it all together

We'll now combine our proposals into a single inference block. Note that the same basic `mh` operator is called with different inputs (a `Selection`, a proposal, or a proposal plus an involution) to apply each of the inference moves we've written here.
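
Each `mh` call is a Metropolis-Hastings kernel that leaves the posterior invariant, so cycling through several of them in a fixed schedule leaves it invariant too. The sketch below illustrates that idea in plain Julia on a toy one-dimensional target (a standard normal). Two hypothetical symmetric random-walk proposals with different step sizes stand in for the notebook's moves. This is not Gen's `mh` implementation.

```julia
using Random
using Statistics: mean, var

# Log density of the toy target (standard normal), up to a constant.
logp(x) = -0.5 * x^2

# One Metropolis-Hastings step with a symmetric Gaussian proposal: the proposal
# densities cancel, so accept with probability min(1, p(x_new) / p(x)).
function mh_step(rng, x, step_size)
    x_new = x + step_size * randn(rng)
    if log(rand(rng)) < logp(x_new) - logp(x)
        return x_new, true
    end
    return x, false
end

function run_schedule(n_blocks; seed=1)
    rng = MersenneTwister(seed)
    x = 3.0                      # deliberately poor initialization
    samples = Float64[]
    for _ in 1:n_blocks
        # As in run_proposals: a fixed schedule of different kernels,
        # each of which leaves the target distribution invariant.
        for _ in 1:5
            x, _ = mh_step(rng, x, 0.2)   # small local moves
        end
        x, _ = mh_step(rng, x, 2.0)       # occasional larger jump
        push!(samples, x)
    end
    return samples
end

samples = run_schedule(20_000)
println(round(mean(samples); digits=2), " ", round(var(samples); digits=2))
```

Mixing small and large moves is the same design choice as pairing local drift moves with structural moves such as `switch` and `swap`. The local moves refine the current explanation, while the larger ones help the chain leave it.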

In [None]:
function run_proposals(initial_trace, obs_gram, n_blocks)

    trace = initial_trace
    for block_idx = 1:n_blocks
        for source_idx = 1:trace[:n_sources]
            
            n_tones = trace[:source => source_idx => :n_tones]
            for tone_idx = 1:n_tones
                trace, accepted = mh(trace, onset_randomness, (source_idx, tone_idx,), onset_involution)
                trace, accepted = mh(trace, duration_randomness, (source_idx,tone_idx,), duration_involution)
                trace, accepted = mh(trace, tonefreq_gaussian_drift, (source_idx, tone_idx,)) 
            end
            
            for rep_idx = 1:10
                trace, accepted = mh(trace, mu_gaussian_drift, (source_idx,))
                trace, accepted = mh(trace, ntones_randomness, (source_idx,), ntones_involution)
            end
            
        end
        
        for rep_idx = 1:10
            trace, accepted = mh(trace, switch_randomness, (), switch_involution)
        end
        
        for rep_idx = 1:2
            trace, accepted = mh(trace, swap_sources_randomness, (), swap_sources_involution)
            @assert accepted
            trace, accepted = mh(trace, select(:n_sources))
        end
        
        plot_sources(trace, obs_gram, block_idx; save=false)
        
    end
    
    return trace
    
end

function run_inference(obs_gram, args; n_blocks=10)
    
    #= To evaluate our likelihood function,
    we constrain the scene to be equal to obs_gram;
    this means that we evaluate the probability of obs_gram
    under the normal distribution defined by our model
    when it samples scene_gram =#
    constraints = choicemap((:scene, obs_gram))
    # We randomly generate an initial trace
    initial_trace, _ = generate(generate_scene, args, constraints)
    # Then we make proposals for all the latent, non-observed variables
    final_trace = run_proposals(initial_trace, obs_gram, n_blocks)
    return final_trace
    
end
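
The comment in `run_inference` says that constraining `:scene` to `obs_gram` makes the model score the observation under a normal distribution. Suppose that distribution is an independent normal on each time-frequency bin with a shared, fixed standard deviation. The actual form lives in `extra_distributions.jl` and may differ. Under that assumption, the log-likelihood reduces to a sum over bins. `gram_loglik` and its `σ` argument are hypothetical.

```julia
# Log-likelihood of an observed cochleagram given the model's rendered one,
# assuming i.i.d. Gaussian noise with standard deviation σ on every
# time-frequency bin (an assumption; see above).
function gram_loglik(obs::AbstractMatrix{<:Real}, pred::AbstractMatrix{<:Real}, σ::Real)
    size(obs) == size(pred) || throw(DimensionMismatch("cochleagrams must have equal size"))
    n = length(obs)
    sq_err = sum(abs2, obs .- pred)
    return -0.5 * n * log(2π * σ^2) - sq_err / (2σ^2)
end

obs = [10.0 20.0; 30.0 40.0]
println(gram_loglik(obs, obs, 1.0))          # perfect match: -2 log(2π)
println(gram_loglik(obs, obs .+ 5.0, 1.0))   # every bin off by 5: 50 lower
```

Because every bin contributes a term, the likelihood of a full-size cochleagram can easily dominate the prior terms, which is one reason inference can get stuck in sharp local optima.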

Let's generate a random trace from our prior to use as our observation:

In [None]:
trace = simulate(generate_scene, args)
scene_gram, scene_wave, source_waves = get_retval(trace);
plot_gtg(scene_gram, scene_duration, audio_sr, 20.0, 100.0)

In [None]:
run_inference(scene_gram, args, n_blocks=50);

### To ponder

Does anything in the result look like a local optimum, where a better explanation exists than the one inference has settled on? What kind of inference move could we design to escape that local optimum?

### Tone sequence demonstration

Let's see if our model can recover the sources that we hear in a classic auditory scene analysis demonstration. 

In [None]:
using WAV;

In [None]:
gtg_params = GammatonegramParams(0.025, 0.010, 20.0, 64, 0.50, 1e-6, 1e-80, 20.0)

sound_name = "./sounds/bouncing60dB.wav"
demo, audio_sr = wavread(sound_name);
audio_sr = Int(audio_sr)
demo = demo[:,1];
wts, gtg_freqs = gtg_weights(audio_sr, gtg_params)
demo_gram, t = gammatonegram(demo, wts, audio_sr, gtg_params)
scene_duration = length(demo)/audio_sr; 

args = (scene_duration, audio_sr, wts, gtg_params);

In [None]:
plot_gtg(demo_gram, scene_duration, audio_sr, 20.0, 100.0)

In [None]:
run_inference(demo_gram, args, n_blocks=100);

We could then carry out any of the procedures from the first notebook: for example, predicting what sound each source will make next, or predicting what would happen behind a masker.