In [None]:
]activate ..

In [None]:
include("main.jl")

In [None]:
using .AudioInference

In [None]:
# Sample the scene used throughout the notebook; the trailing `nothing` suppresses the cell's display.
trr = tones_with_noise(10.0)
nothing

In [None]:
# Presumably plots trr and writes its audio to "trr.wav" — behavior lives in
# vis_and_write_wave (brought in via main.jl / AudioInference); confirm there.
vis_and_write_wave(trr, "trr.wav")

In [None]:
# Save the current figure as truth.pdf. NOTE(review): AudioInference.savefig is
# presumably PyPlot's savefig re-exported by the module — confirm.
AudioInference.savefig("truth.pdf")

In [None]:
# Build the starting trace for all inference runs below from the sampled scene.
initial_tr, weight = generate_initial_tr(trr)
weight  # shown as the cell output

In [None]:
using PyPlot
using Dates

In [None]:
"""
    get_avg_likelihoods(initial_trs, run_inf!, iters; warmup_iters=20)

Run the inference procedure `run_inf!` once from each trace in `initial_trs` and
average the curves recorded by
`AudioInference.get_worldmodel_likelihood_time_tracker_and_recorder()`.

`run_inf!` is called as `run_inf!(tr, n_iters, record!)`. Before the trials, one
warm-up run of `warmup_iters` iterations is made from the first trace with a no-op
callback, so that JIT compilation is not charged to the first trial's timings.

Returns `(times, likelihoods)`: element-wise means over all trials of the recorded
time and likelihood vectors (each expected to have length `iters`).

Throws `ArgumentError` if `initial_trs` is empty.
"""
function get_avg_likelihoods(initial_trs, run_inf!, iters; warmup_iters=20)
    isempty(initial_trs) &&
        throw(ArgumentError("initial_trs must contain at least one trace"))
    likelihoods = zeros(Float64, iters)
    times = zeros(Float64, iters)
    starttime = Dates.now()
    # Compilation run: use a trace that was actually passed in, not a global.
    run_inf!(first(initial_trs), warmup_iters, (tr,) -> nothing)
    for (i, initial_tr) in enumerate(initial_trs)
        print("Running trial $i...;")
        # Dates.value strips the unit so the message doesn't read "N milliseconds ms".
        println(" $(Dates.value(Dates.now() - starttime)) ms elapsed in total")
        (l, t, record!) = AudioInference.get_worldmodel_likelihood_time_tracker_and_recorder()
        run_inf!(initial_tr, iters, record!)
        # In-place accumulation; still throws DimensionMismatch on length mismatch.
        likelihoods .+= l
        times .+= t
    end
    ntrials = length(initial_trs)
    likelihoods ./= ntrials
    times ./= ntrials
    return (times, likelihoods)
end

In [None]:
#initial_trs = [AudioInference.simulate(AudioInference.generate_scene, AudioInference.args) for _=1:5]

In [None]:
# Average 5 generic-inference runs (180 iterations each) from the same starting trace.
generic_times, generic_likelihoods =
    get_avg_likelihoods(fill(initial_tr, 5), AudioInference.do_generic_inference, 180)
plot(generic_times, generic_likelihoods)

In [None]:
# Average 5 birth/death-inference runs (200 iterations each) from the same starting trace.
bd_times, bd_likelihoods =
    get_avg_likelihoods(fill(initial_tr, 5), AudioInference.do_birth_death_inference, 200)
plot(bd_times, bd_likelihoods)

In [None]:
# Average 5 split/merge-inference runs (100 iterations each) from the same starting trace.
sm_times, sm_likelihoods =
    get_avg_likelihoods(fill(initial_tr, 5), AudioInference.do_split_merge_inference, 100)
plot(sm_times, sm_likelihoods)

In [None]:
TIME_CAP = 7.0
# Indices of the recorded samples whose time is below TIME_CAP, per inference method.
g_indices = findall(t -> t < TIME_CAP, generic_times)
bd_indices = findall(t -> t < TIME_CAP, bd_times)
sm_indices = findall(t -> t < TIME_CAP, sm_times)
nothing

In [None]:
POINT_SIZE = 8
# One scatter series per inference method, restricted to the samples kept by the TIME_CAP cell.
for (ts, ls, keep, name) in ((generic_times, generic_likelihoods, g_indices, "generic"),
                             (bd_times, bd_likelihoods, bd_indices, "birth/death"),
                             (sm_times, sm_likelihoods, sm_indices, "split/merge"))
    scatter(ts[keep], ls[keep], label=name, s=POINT_SIZE)
end
xlabel("time (s)")
ylabel("log likelihood of observed sound given inferred waves")
title("Quality of inferred waveforms over time")
legend(loc="lower right")

In [None]:
# Longer experiment: 20 generic-inference runs of 540 iterations each.
generic_times2, generic_likelihoods2 =
    get_avg_likelihoods(fill(initial_tr, 20), AudioInference.do_generic_inference, 540)
plot(generic_times2, generic_likelihoods2)

In [None]:
# Longer experiment: 20 birth/death-inference runs of 600 iterations each.
bd_times2, bd_likelihoods2 =
    get_avg_likelihoods(fill(initial_tr, 20), AudioInference.do_birth_death_inference, 600)
plot(bd_times2, bd_likelihoods2)

In [None]:
# Longer experiment: 20 split/merge-inference runs of 400 iterations each.
sm_times2, sm_likelihoods2 =
    get_avg_likelihoods(fill(initial_tr, 20), AudioInference.do_split_merge_inference, 400)
plot(sm_times2, sm_likelihoods2)

In [None]:
TIME_CAP = 30.0
# Indices of the recorded samples whose time is below TIME_CAP, per inference method.
g_indices2 = findall(t -> t < TIME_CAP, generic_times2)
bd_indices2 = findall(t -> t < TIME_CAP, bd_times2)
sm_indices2 = findall(t -> t < TIME_CAP, sm_times2)
nothing

In [None]:
ax = gca()
POINT_SIZE = 1
scatter(generic_times2[g_indices2], generic_likelihoods2[g_indices2], label="generic", s=POINT_SIZE)
scatter(bd_times2[bd_indices2], bd_likelihoods2[bd_indices2], label="birth/death", s=POINT_SIZE)
scatter(sm_times2[sm_indices2], sm_likelihoods2[sm_indices2], label="split/merge", s=POINT_SIZE)
# Clip the y-range to hide the very low early likelihoods. Property access (as used in
# plot_gtg) replaces PyCall's deprecated `ax[:set_ylim]` indexing syntax.
ax.set_ylim([-400000, 0])
xlabel("time (s)")
ylabel("log likelihood of observed sound given inferred waves")
title("Quality of inferred waveforms over time")
legend(loc="lower right")

In [None]:
"""
    get_avg_likelihoods_and_counts(initial_trs, run_inf!, iters)

Run `run_inf!(tr, iters, record!)` once from each trace in `initial_trs`. The recorder
comes from `AudioInference.get_worldmodel_likelihood_tracker_and_recorder()`.

Returns `(likelihoods, counts)`:
- `likelihoods` is the element-wise mean over trials of the recorded likelihood vectors
  (each expected to have length `iters`).
- `counts[i]` is the value at address `:kernel => :n_tones` in the trace returned by
  trial `i`.

Throws `ArgumentError` if `initial_trs` is empty.
"""
function get_avg_likelihoods_and_counts(initial_trs, run_inf!, iters)
    isempty(initial_trs) &&
        throw(ArgumentError("initial_trs must contain at least one trace"))
    likelihoods = zeros(Float64, iters)
    counts = Any[]  # element type depends on the trace; kept as Vector{Any}
    starttime = Dates.now()
    for (i, initial_tr) in enumerate(initial_trs)
        print("Running trial $i...;")
        # Dates.value strips the unit so the message doesn't read "N milliseconds ms".
        println(" $(Dates.value(Dates.now() - starttime)) ms elapsed in total")
        (l, record!) = AudioInference.get_worldmodel_likelihood_tracker_and_recorder()
        tr = run_inf!(initial_tr, iters, record!)
        push!(counts, tr[:kernel => :n_tones])
        likelihoods .+= l
    end
    likelihoods ./= length(initial_trs)
    return (likelihoods, counts)
end

In [None]:
# Two birth/death runs of 600 iterations; show each final trace's n_tones.
l, c = get_avg_likelihoods_and_counts(fill(initial_tr, 2), AudioInference.do_birth_death_inference, 600)
c

In [None]:
# Two split/merge runs of 600 iterations; show each final trace's n_tones.
l, c = get_avg_likelihoods_and_counts(fill(initial_tr, 2), AudioInference.do_split_merge_inference, 600)
c

In [None]:
# One-off environment setup: adds ProfileView to the currently active project.
# Only needed the first time; later runs can skip this cell.
using Pkg; Pkg.add("ProfileView")

In [None]:
using Profile; using ProfileView;

In [None]:
# The profiler's sample buffer accumulates across @profile calls; clear it so the view
# below reflects only this run, even if the cell is executed more than once.
Profile.clear()
@profile get_avg_likelihoods_and_counts(fill(initial_tr, 5), AudioInference.do_split_merge_inference, 600)

In [None]:
# Display the samples collected by @profile as an interactive flame graph.
ProfileView.view()

In [None]:
"""
    plot_gtg(gtg, duration, audio_sr, vmin, vmax; colors="Blues", plot_colorbar=false, aspect=1/1300)

Draw the matrix `gtg` with `imshow`. The x-axis spans `0..duration` and the y-axis
spans `0..audio_sr/2`. Colors are clamped to `[vmin, vmax]` using colormap `colors`.

The y-ticks are placed evenly over the full height. Each tick is labelled with a
frequency spaced evenly on the ERB scale between 1 Hz and `audio_sr/2`. This assumes
the rows of `gtg` are ERB-spaced over that range — TODO confirm against the
gram-producing code.

If `plot_colorbar` is true, a colorbar is added and returned. Otherwise returns
`nothing`.
"""
function plot_gtg(gtg, duration, audio_sr, vmin, vmax;
                  colors="Blues", plot_colorbar=false, aspect=1/1300)
    max_freq = audio_sr / 2
    imshow(gtg, cmap=colors, origin="lower", extent=(0, duration, 0, max_freq),
           vmin=vmin, vmax=vmax, aspect=aspect)
    # Keep matplotlib's default number of ticks, but pin them evenly over [0, max_freq].
    # The ERB-spaced labels only line up if tick k sits at fraction (k-1)/(n-1) of the
    # height; auto-chosen tick locations need not start at 0, end at max_freq, or even
    # stay inside the axes.
    n_ticks = length(yticks()[1])
    tick_locs = collect(range(0, stop=max_freq, length=n_ticks))
    lowlim = AudioInference.freq_to_ERB(1.)
    hilim = AudioInference.freq_to_ERB(max_freq)
    tick_freqs = floor.(Int, AudioInference.ERB_to_freq(range(lowlim, stop=hilim, length=n_ticks)))
    # Setting locations and labels together avoids matplotlib's FixedFormatter warning.
    yticks(tick_locs, tick_freqs, fontsize="small")
    if plot_colorbar
        return plt.colorbar()
    end
    return nothing
end

In [None]:
"""
    vis(tr)

Plot the gram held in trace `tr` with `plot_gtg`.

The gram is the first element of `AudioInference.get_retval(tr)`. The duration and
sample rate are the first and third elements of `AudioInference.get_args(tr)`. The
color range is fixed to 0–100.
"""
function vis(tr)
    dur, _, rate = AudioInference.get_args(tr)
    gram, _ = AudioInference.get_retval(tr)
    return plot_gtg(gram, dur, rate, 0, 100)
end

In [None]:
# Draw a fresh scene with the same argument as before; this overwrites the earlier `trr`.
trr = tones_with_noise(10.0)
nothing

In [None]:
# Plot the gram of the newly sampled scene.
vis(trr)

In [None]:
# Run birth/death inference from initial_tr with budget 500 and a no-op recorder callback.
tr = AudioInference.do_birth_death_inference(initial_tr, 500, _ -> ())

In [None]:
# Plot the gram of the trace returned by birth/death inference, for comparison with the sampled scene.
vis(tr)