In [None]:
]activate ..

In [None]:
using Gen

In [None]:
using WAV
include("../tools/plotting.jl")
include("../model/gammatonegram.jl");
include("../model/time_helpers.jl");
include("../model/extra_distributions.jl");

In [None]:
# Load the base parameter set, then override the sample rate `sr` and the
# gammatonegram's "dB_threshold" entry for this notebook.
source_params, steps, gtg_params, obs_noise = include("../params/base.jl")
sr = 2000.0
gtg_params["dB_threshold"] = 0.0
# only the first value returned by gtg_weights is needed here
wts = first(gtg_weights(sr, gtg_params));

# Model arguments: scene length 2.0 (presumably seconds — TODO confirm), the loaded
# `steps`, the sample rate and the filterbank weights / gammatonegram params.
scene_length = 2.0
args = (scene_length, steps, sr, wts, gtg_params)

In [None]:
using GenWorldModels

In [None]:
include("../main/worldmodel/model.jl")
include("worldmodel_utils.jl")

In [None]:
# Build the observed trace used throughout this notebook. `tones_with_noise` presumably
# comes from the included worldmodel utilities (meaning of the 10. argument: TODO confirm).
# The trailing `; nothing` keeps the notebook from displaying the trace.
trr = tones_with_noise(10.); nothing

In [None]:
# Presumably visualizes the observed trace and writes its audio to trr.wav (see vis_and_write_wave).
vis_and_write_wave(trr, "trr.wav")

### Birth/Death

In [None]:
# Proposal for the birth/death move.
#   do_birth – true to add a source, false to remove one.
#   idx      – for a birth, the insertion position in 1:n_tones+1;
#              for a death, the source to remove in 1:n_tones.
# With no sources there is nothing to remove (uniform_discrete(1, 0) is invalid), so a
# birth is proposed with probability 1; otherwise birth and death are equally likely.
@gen function birth_death_proposal(tr)
    n_tones = tr[:kernel => :n_tones]
    do_birth ~ bernoulli(n_tones == 0 ? 1.0 : 0.5)
    if do_birth
        idx ~ uniform_discrete(1, n_tones + 1)
    else
        idx ~ uniform_discrete(1, n_tones)
    end
end

In [None]:
# Involution for the birth/death move.
# Reads the proposed direction (`do_birth`), the index (`idx`) and the current source count.
#   birth: adds AudioSource(idx), increments :kernel => :n_tones, and regenerates that
#          source's :world => :waves choices.
#   death: removes AudioSource(idx), decrements the count, and marks its choices with
#          @save_for_reverse_regenerate (presumably so the reverse birth can score them).
# The backward proposal is the opposite direction at the same index.
@oupm_involution birth_death_inv (old_tr, fwd_prop_tr) to (new_tr, bwd_prop_tr) begin
    do_birth = @read(fwd_prop_tr[:do_birth], :disc)
    idx = @read(fwd_prop_tr[:idx], :disc)
    num = @read(old_tr[:kernel => :n_tones], :disc)
    if do_birth
        @birth(AudioSource(idx))
        @write(new_tr[:kernel => :n_tones], num + 1, :disc)
        @regenerate(:world => :waves => AudioSource(idx))
    else
        @death(AudioSource(idx))
        @write(new_tr[:kernel => :n_tones], num - 1, :disc)
        @save_for_reverse_regenerate(:world => :waves => AudioSource(idx))
    end
    # reverse move: flip the direction, keep the index
    @write(bwd_prop_tr[:do_birth], !do_birth, :disc)
    @write(bwd_prop_tr[:idx], idx, :disc)
end
# MH kernel pairing the birth/death proposal (no extra proposal args) with its involution.
birth_death_kernel = OUPMMHKernel(birth_death_proposal, (), birth_death_inv)

In [None]:
"""
    birth_death_iter(tr)

Run one inference sweep on `tr`: the moves from `generic_no_num_change_inference_iter`,
followed by a single MH step with `birth_death_kernel` (its acceptance flag is ignored).
Return the resulting trace.
"""
function birth_death_iter(tr)
    updated = generic_no_num_change_inference_iter(tr)
    (updated, _) = mh(updated, birth_death_kernel)
    return updated
end

In [None]:
# Start from the initial trace for `trr`, run 200 sweeps of birth_death_iter, and print
# the trace score before and after inference.
initial_tr = get_initial_tr(trr)
println("Initial trace score: $(get_score(initial_tr))")
inferred_tr = run_inference(initial_tr, birth_death_iter, 200)
println("Inferred trace score: $(get_score(inferred_tr))")

In [None]:
# Presumably visualizes the birth/death result and writes its audio to inferred_bd.wav.
vis_and_write_wave(inferred_tr, "inferred_bd.wav")

### Split/Merge

In [None]:
# Indices in 1:n_tones of sources whose :is_noise flag is false (tones) / true (noises),
# in increasing index order.
tone_indices(tr, n_tones) = filter(i -> !tr[:world => :waves => AudioSource(i) => :is_noise], 1:n_tones)
noise_indices(tr, n_tones) = filter(i -> tr[:world => :waves => AudioSource(i) => :is_noise], 1:n_tones)

# A merge of a given kind needs at least two sources of that kind.
tone_merge_possible(tr) = length(tone_indices(tr, tr[:kernel => :n_tones])) >= 2
noise_merge_possible(tr) = length(noise_indices(tr, tr[:kernel => :n_tones])) >= 2

# True when at least one kind (tone or noise) has two sources available to merge.
function merge_possible(tr)
    tr[:kernel => :n_tones] > 1 || return false
    return tone_merge_possible(tr) || noise_merge_possible(tr)
end

In [None]:
# Top-level proposal for the split/merge move. A merge is only considered when
# merge_possible(tr); then split and merge each have probability 0.5, otherwise a split
# is proposed with probability 1 (note: sample_split needs at least one source).
@gen function split_merge_proposal(tr)
  if merge_possible(tr)
    p_split = 0.5
  else
    p_split = 1.
  end
  do_split ~ bernoulli(p_split)

  # splice the chosen sub-proposal's choices into this proposal's top-level addresses
  if do_split
    {*} ~ sample_split(tr)
  else
    {*} ~ sample_merge(tr)
  end
end

In [None]:
# Involution for the split/merge move. Reads the direction and the index choices made by
# split_merge_proposal, delegates the trace manipulation to handle_split / handle_merge,
# and writes a backward proposal with the direction flipped and the same indices.
@oupm_involution splitmerge_inv (old_tr, fwd_prop_tr) to (new_tr, bwd_prop_tr) begin
  n_tones = @read(old_tr[:kernel => :n_tones], :disc)
  do_split = @read(fwd_prop_tr[:do_split], :disc)
  deuce_idx1 = @read(fwd_prop_tr[:deuce_idx1], :disc)
  deuce_idx2 = @read(fwd_prop_tr[:deuce_idx2], :disc)
  solo_idx = @read(fwd_prop_tr[:solo_idx], :disc)
  
  # solo_idx is the single source (split source / merge result);
  # deuce_idx1/2 are the pair (split results / merge inputs)
  if do_split
    @tcall handle_split(n_tones, solo_idx, deuce_idx1, deuce_idx2)
  else
    @tcall handle_merge(n_tones, solo_idx, deuce_idx1, deuce_idx2)
  end
  
  @write(bwd_prop_tr[:do_split], !do_split, :disc)
  @write(bwd_prop_tr[:solo_idx], solo_idx, :disc)
  @write(bwd_prop_tr[:deuce_idx1], deuce_idx1, :disc)
  @write(bwd_prop_tr[:deuce_idx2], deuce_idx2, :disc)
end

In [None]:
# Distribution over the elements of `list`: picks a uniformly random position and returns
# that element (repeated elements get proportionally more mass). `list` must be non-empty.
@dist uniform_from_list(list) = list[uniform_discrete(1, length(list))]

In [None]:
# Split proposal (also the reverse proposal of a merge).
#   solo_idx               – existing source to split, uniform over 1:n_tones.
#   deuce_idx1, deuce_idx2 – two distinct indices in 1:n_tones+1 for the resulting sources.
#   amp1/amp2 (noise) or erb1/erb2 (tone) – normal jitters around the source's value.
#   dur1, dur2             – durations of the two resulting sources.
@gen function sample_split(tr)
  n = tr[:kernel => :n_tones]
  solo_idx ~ uniform_discrete(1, n)
  deuce_idx1 ~ uniform_discrete(1, n + 1)
  deuce_idx2 ~ uniform_from_list(filter(i -> i != deuce_idx1, 1:n+1))

  src = get_submap(get_choices(tr), :world => :waves => AudioSource(solo_idx))
  if src[:is_noise]
    amp1 ~ normal(src[:amp], .5)
    amp2 ~ normal(src[:amp], .5)
  else
    erb1 ~ normal(src[:erb], .5)
    erb2 ~ normal(src[:erb], .5)
  end

  # upper bound is 70% of the source's duration, floored at 0.11 so it exceeds the 0.1 lower bound
  max_dur = max(.11, 0.7 * src[:duration])
  dur1 ~ uniform(0.1, max_dur)
  dur2 ~ uniform(0.1, max_dur)
end

In [None]:
# Involution body for the split direction of split/merge.
#   n_tones          – source count in the old trace (the new trace has n_tones + 1).
#   from_idx         – old source being split.
#   to_idx1, to_idx2 – indices of the two new sources.
# The first new source keeps the original onset; the second is placed so that it ends
# where the original ended (onset = old onset + old duration - dur2). Durations and
# amp/erb values come from the forward proposal; the original amp/erb and the kind
# (merge_tone) go to the backward merge proposal.
@oupm_involution handle_split(n_tones, from_idx, to_idx1, to_idx2) (old_tr, fwd_prop_tr) to (new_tr, bwd_prop_tr) begin
  @split(AudioSource(from_idx), to_idx1, to_idx2)
  @write(new_tr[:kernel => :n_tones], n_tones + 1, :disc)
  
  # address helpers: old source (o) and the two new sources (n1, n2)
  o(x) = :world => :waves => AudioSource(from_idx) => x
  n1(x) = :world => :waves => AudioSource(to_idx1) => x
  n2(x) = :world => :waves => AudioSource(to_idx2) => x
  
  # both new sources inherit the original's is_noise flag
  @copy(old_tr[o(:is_noise)], new_tr[n1(:is_noise)])
  @copy(old_tr[o(:is_noise)], new_tr[n2(:is_noise)])

  # handle start and end times
  @copy(old_tr[o(:onset)], new_tr[n1(:onset)])
  @copy(fwd_prop_tr[:dur1], new_tr[n1(:duration)])
  @copy(fwd_prop_tr[:dur2], new_tr[n2(:duration)])

  # second source ends exactly where the original ended (handle_merge relies on this)
  old_ons = @read(old_tr[o(:onset)], :cont)
  old_dur = @read(old_tr[o(:duration)], :cont)
  dur2 = @read(fwd_prop_tr[:dur2], :cont)
  @write(new_tr[n2(:onset)], old_ons + old_dur - dur2, :cont)
  
  # amp/erb: noises carry amp, tones carry erb
  if @read(old_tr[o(:is_noise)], :disc)
      @copy(fwd_prop_tr[:amp1], new_tr[n1(:amp)])
      @copy(fwd_prop_tr[:amp2], new_tr[n2(:amp)])
      @copy(old_tr[o(:amp)], bwd_prop_tr[:amp])
      @write(bwd_prop_tr[:merge_tone], false, :disc)
  else
      @copy(fwd_prop_tr[:erb1], new_tr[n1(:erb)])
      @copy(fwd_prop_tr[:erb2], new_tr[n2(:erb)])
      @copy(old_tr[o(:erb)], bwd_prop_tr[:erb])
      @write(bwd_prop_tr[:merge_tone], true, :disc)
  end
end

In [None]:
# Merge proposal (also the reverse proposal of a split).
#   solo_idx   – index of the merged source among the n_tones - 1 sources left afterwards.
#   merge_tone – true to merge two tones, false to merge two noises; an impossible kind
#                gets probability 0.
#   deuce_idx1, deuce_idx2 – the pair to merge; deuce_idx2 comes strictly after deuce_idx1
#                when the candidates are ordered by onset.
#   amp (noise) or erb (tone) – merged value, normal around the pair's mean.
@gen function sample_merge(tr)
  n = tr[:kernel => :n_tones]
  solo_idx ~ uniform_discrete(1, max(1, n - 1))

  p_tone = if !tone_merge_possible(tr)
    0.
  elseif noise_merge_possible(tr)
    0.5
  else
    1.
  end
  merge_tone ~ bernoulli(p_tone)

  candidates = merge_tone ? tone_indices(tr, n) : noise_indices(tr, n)
  by_onset = sort(candidates; by = i -> tr[:world => :waves => AudioSource(i) => :onset])
  deuce_idx1 ~ uniform_from_list(by_onset[1:end-1])
  pos1 = findfirst(==(deuce_idx1), by_onset)
  deuce_idx2 ~ uniform_from_list(by_onset[pos1+1:end])

  src1 = get_submap(get_choices(tr), :world => :waves => AudioSource(deuce_idx1))
  src2 = get_submap(get_choices(tr), :world => :waves => AudioSource(deuce_idx2))
  if merge_tone
    erb ~ normal((src1[:erb] + src2[:erb])/2, 0.5)
  else
    amp ~ normal((src1[:amp] + src2[:amp])/2, 0.5)
  end
end

In [None]:
# Involution body for the merge direction of split/merge.
#   n_tones              – source count in the old trace (the new trace has n_tones - 1).
#   to_idx               – index of the merged source in the new trace.
#   from_idx1, from_idx2 – the two old sources being merged; sample_merge picks them so
#                          that from_idx1 is no later than from_idx2 in onset order.
# The merged source starts at the first source's onset and ends where the second one ends.
# The original durations and amp/erb values go to the backward split proposal.
@oupm_involution handle_merge(n_tones, to_idx, from_idx1, from_idx2) (old_tr, fwd_prop_tr) to (new_tr, bwd_prop_tr) begin
  @merge(AudioSource(to_idx), from_idx1, from_idx2)
  @write(new_tr[:kernel => :n_tones], n_tones - 1, :disc)
  
  # address helpers: new merged source (n) and the two old sources (o1, o2)
  n(x) = :world => :waves => AudioSource(to_idx) => x
  o1(x) = :world => :waves => AudioSource(from_idx1) => x
  o2(x) = :world => :waves => AudioSource(from_idx2) => x

  # is_noise: sample_merge only pairs sources of the same kind, so take it from the first
  @copy(old_tr[o1(:is_noise)], new_tr[n(:is_noise)])

  # onset & duration
  start1 = @read(old_tr[o1(:onset)], :cont)
  dur1 = @read(old_tr[o1(:duration)], :cont)
  start2 = @read(old_tr[o2(:onset)], :cont)
  dur2 = @read(old_tr[o2(:duration)], :cont)
  end2 = start2 + dur2
  # Must be end2 - start1: handle_split places the second piece at onset + duration - dur2,
  # which then recovers start2 exactly. NOTE: if the first source ends after the second,
  # its tail lies outside the merged source.
  full_dur = (end2 - start1)
  
  @copy(old_tr[o1(:onset)], new_tr[n(:onset)])
  @write(new_tr[n(:duration)], full_dur, :cont)

  # the original durations are the dur1/dur2 the reverse split proposal would sample
  @write(bwd_prop_tr[:dur1], dur1, :cont)
  @write(bwd_prop_tr[:dur2], dur2, :cont)

  # noise / tone parameters: originals go backward, merged value comes from the proposal
  if @read(old_tr[o1(:is_noise)], :disc)
      @copy(old_tr[o1(:amp)], bwd_prop_tr[:amp1])
      @copy(old_tr[o2(:amp)], bwd_prop_tr[:amp2])
      @copy(fwd_prop_tr[:amp], new_tr[n(:amp)])
  else
      @copy(old_tr[o1(:erb)], bwd_prop_tr[:erb1])
      @copy(old_tr[o2(:erb)], bwd_prop_tr[:erb2])
      @copy(fwd_prop_tr[:erb], new_tr[n(:erb)])
  end
end

In [None]:
# MH kernel pairing the split/merge proposal (no extra proposal args) with its involution.
split_merge_kernel = OUPMMHKernel(split_merge_proposal, (), splitmerge_inv)

"""
    split_merge_iter(tr)

Run one inference sweep on `tr` and return the resulting trace. The sweep is one
birth/death MH step, then the moves from `generic_no_num_change_inference_iter`, then one
split/merge MH step. Accepted birth/death and split/merge moves are printed to stdout.

The split/merge step is skipped when the trace has no sources. In that case no merge is
possible, so the proposal would force a split, and `sample_split` cannot pick a source
(`uniform_discrete(1, 0)` fails). Split and merge never move the source count to or
from 0, so skipping the step there keeps the kernel valid.
"""
function split_merge_iter(tr)
  tr, acc = mh(tr, birth_death_kernel, check=false)
  acc && println("birthdeath accepted")
  tr = generic_no_num_change_inference_iter(tr)
  if tr[:kernel => :n_tones] > 0
    tr, acc = mh(tr, split_merge_kernel, check=false)
    acc && println("splitmerge accepted")
  end
  return tr
end

In [None]:
# Fresh starting trace for the split/merge run; the cell displays its score.
initial_tr = get_initial_tr(trr)
get_score(initial_tr)

In [None]:
# Run 200 sweeps of split_merge_iter from initial_tr, then print the scores before and after.
inferred_tr = run_inference(initial_tr, split_merge_iter, 200)
println("Initial trace score: $(get_score(initial_tr))")
println("Inferred trace score: $(get_score(inferred_tr))")

In [None]:
# Presumably visualizes the split/merge result and writes its audio to splitmerge_inf.wav.
vis_and_write_wave(inferred_tr, "splitmerge_inf.wav")