## Setup

In [None]:
]activate ..

In [None]:
include("main.jl")

In [None]:
using .AudioInference: gammatonegram, AudioSource
using .AudioInference
AI = AudioInference

In [None]:
# Build the test scene used throughout this notebook. `tones_with_noise` is defined
# outside this notebook (presumably exported by AudioInference — TODO confirm).
trr = tones_with_noise(10.); nothing

In [None]:
# Presumably visualizes `trr` and writes its audio to "trr.wav" (helper defined elsewhere).
vis_and_write_wave(trr, "trr.wav")

In [None]:
using Dates

In [None]:
using PyPlot

In [None]:
include("worldmodel/inference2.jl")

## Functions For Likelihood Recording & Graphing

In [None]:
"""
    get_times_and_likelihoods(initial_trs, run_inf!, iters; warmup_iters=20)

Run `run_inf!(initial_tr, iters, record!)` once per trace in `initial_trs`. Each run's
per-iteration times and log likelihoods are collected through a fresh recorder.

Returns `(times, likelihoods)`. These are two `iters × length(initial_trs)` `Float64`
matrices: column `i` holds the run started from `initial_trs[i]`, and row `j` holds
the `j`-th recorded iteration of each run.

Before the recorded runs, one warm-up run of `warmup_iters` iterations is made from
the first trace. This keeps JIT compilation out of the first trial's timings. When
`initial_trs` is empty, two `iters × 0` matrices are returned and nothing is run.
"""
function get_times_and_likelihoods(initial_trs, run_inf!, iters; warmup_iters=20)
    ntrials = length(initial_trs)
    likelihoods = zeros(Float64, iters, ntrials)
    times = zeros(Float64, iters, ntrials)
    # No trials: there is also no trace to warm up on.
    ntrials == 0 && return (times, likelihoods)
    starttime = Dates.now()
    run_inf!(first(initial_trs), warmup_iters, (tr,) -> nothing) # compilation run
    for (i, initial_tr) in enumerate(initial_trs)
        print("Running trial $i...;")
        println(" $(Dates.now() - starttime) ms ellapsed in total")
        (l, t, record!) = AudioInference.get_worldmodel_likelihood_time_tracker_and_recorder()
        run_inf!(initial_tr, iters, record!)
        # Assumes `record!` filled `l` and `t` with exactly `iters` entries — TODO confirm;
        # otherwise these column assignments throw.
        likelihoods[:, i] = l
        times[:, i] = t
    end
    return (times, likelihoods)
end

In [None]:
# Inference procedures to benchmark. Each entry is `label => (run_inf!, multiplier)`;
# `perform_runs` gives each procedure floor(multiplier * num_generic_iters) iterations
# (presumably to roughly equalise runtime budgets across kernels — TODO confirm).
# `drift_smartsmbd_inference` / `drift_smartbd_inference` are not AudioInference-qualified;
# presumably they are defined by worldmodel/inference2.jl. Dict iteration order (and
# hence run order) is not the insertion order.
run_specs = Dict(
    :generic => (AudioInference.do_generic_inference, 1),
    :dumb_bd => (AudioInference.do_birth_death_inference, 1.2),
   # :dumb_sm => (AudioInference.do_split_merge_inference, .6),
    :old_smbd => (AudioInference.do_smart_smbd_inference, .7),
    :new_smbd => (drift_smartsmbd_inference, 0.7),
    :new_bd => (drift_smartbd_inference, 0.75)
)
nothing

In [None]:
"""
    perform_runs(run_specs, initial_trs, num_generic_iters)

Benchmark every `label => (run_inf!, multiplier)` entry of `run_specs`. Each
procedure is run from every trace in `initial_trs` for
`floor(multiplier * num_generic_iters)` iterations.

Returns `(times, likelihoods)`. These are two `Dict`s keyed by label, holding the
matrices produced by `get_times_and_likelihoods`.
"""
function perform_runs(run_specs, initial_trs, num_generic_iters)
    times, likelihoods = Dict(), Dict()
    for (label, spec) in run_specs
        inference_fn, multiplier = spec
        iteration_count = floor(Int, multiplier * num_generic_iters)
        println("Running ", label, " for ", iteration_count, " iterations per initial trace.")
        times[label], likelihoods[label] =
            get_times_and_likelihoods(initial_trs, inference_fn, iteration_count)
    end
    return (times, likelihoods)
end

In [None]:
"""
    plot_avg_times_and_likelihoods(times, likelihoods; POINT_SIZE=3, order=nothing,
                                   names=nothing, miny=nothing, maxx=nothing)

Scatter-plot each inference label's across-run average log likelihood per iteration
against the across-run average time of that iteration.

`times[label]` and `likelihoods[label]` are matrices with one column per run and one
row per iteration, as produced by `perform_runs`.

# Keywords
- `POINT_SIZE`: marker size.
- `order`: which labels to plot, and in what order (default: all keys of `times`).
- `names`: maps labels to legend text (default: `string(label)`).
- `miny`, `maxx`: if given, the bottom y-limit and the right x-limit.
"""
function plot_avg_times_and_likelihoods(times, likelihoods; POINT_SIZE=3, order=nothing,
                                        names=nothing, miny=nothing, maxx=nothing)
    key_itr = order === nothing ? keys(times) : order
    for label in key_itr
        t = times[label]
        l = likelihoods[label]
        # Average over runs (columns): one point per iteration (row).
        avg_t = vec(sum(t; dims=2)) ./ size(t, 2)
        avg_l = vec(sum(l; dims=2)) ./ size(l, 2)
        name = names === nothing ? string(label) : names[label]
        scatter(avg_t, avg_l, label=name, s=POINT_SIZE)
    end
    if miny !== nothing
        ylim(bottom=miny)
    end
    if maxx !== nothing
        xlim(right=maxx)
    end
    xlabel("time (s)")
    ylabel("log likelihood of observed sound given inferred sources")
    title("Predictive log-likelihood of inferred audio sources over time")
    legend(loc="lower right")
end

In [None]:
"""
    merge_in_runs!(running_times, running_likelihoods, new_times, new_likelihoods)

For every key of `running_times`, append the new runs to the accumulated ones.
`running_times[key]` is replaced by `hcat(running_times[key], new_times[key])`, and
`running_likelihoods[key]` likewise with `new_likelihoods[key]`.

# Assumptions
- Every key of `running_times` is also present in the other three dicts; a missing key
  raises `KeyError`.
- Matrices under a key have matching row counts; otherwise `hcat` throws.
- Keys that appear only in the new dicts are ignored.

Returns `nothing`.
"""
function merge_in_runs!(running_times, running_likelihoods, new_times, new_likelihoods)
    for label in keys(running_times)
        merged_times = hcat(running_times[label], new_times[label])
        running_times[label] = merged_times
        merged_likelihoods = hcat(running_likelihoods[label], new_likelihoods[label])
        running_likelihoods[label] = merged_likelihoods
    end
    return nothing
end

### Fetch serialized data

In [None]:
using Serialization

In [None]:
# Load previously recorded illusion runs. Destructuring works whether the file holds a
# plain `(times, likelihoods)` tuple or a `(times=..., likelihoods=...)` named tuple.
(rt, rl) = deserialize("illusion_from_0.data")
# (iterations, runs) recorded for the generic kernel.
size(rl[:generic])

In [None]:
plot_avg_times_and_likelihoods(rt, rl, POINT_SIZE=1, miny=-300000)

In [None]:
"""
    record_likelihoods!(generate_initial_trs, run_specs, NUM_ITERS;
                        filename=nothing, running_times=nothing,
                        running_likelihoods=nothing, num_cycles=1)

Accumulate `num_cycles` batches of `perform_runs(run_specs, generate_initial_trs(), NUM_ITERS)`.

Existing or fresh accumulators:
- If `running_times` and `running_likelihoods` are given (both or neither), every
  batch is merged into them in place with `merge_in_runs!`.
- Otherwise the first batch initialises them.

Checkpointing: when `filename` is given, the accumulated runs are serialized there as
`(times=..., likelihoods=...)`. This happens before each merged batch and once at the
end.

Failures: a merged batch that throws is reported with `@warn` and skipped, so one
failure does not lose the runs collected so far. An interrupt (Ctrl-C) still stops
the recording.

Returns `(running_times, running_likelihoods)`.

# Throws
- `ArgumentError` in any of these cases:
  - `num_cycles < 1`;
  - only one of `running_times` / `running_likelihoods` is given;
  - the existing `:generic` runs do not have `NUM_ITERS` rows.
"""
function record_likelihoods!(generate_initial_trs, run_specs, NUM_ITERS;
        filename=nothing,
        running_times=nothing,
        running_likelihoods=nothing,
        num_cycles=1
    )
    num_cycles >= 1 ||
        throw(ArgumentError("num_cycles must be at least 1, got $num_cycles"))
    if running_times === nothing || running_likelihoods === nothing
        (running_times === nothing && running_likelihoods === nothing) || throw(ArgumentError(
            "Cannot provide only 1 of running_times and running_likelihoods"))
        initial_trs = generate_initial_trs()
        (running_times, running_likelihoods) = perform_runs(run_specs, initial_trs, NUM_ITERS)
        num_cycles -= 1
    elseif haskey(running_times, :generic) && size(running_times[:generic], 1) != NUM_ITERS
        # `:generic` has multiplier 1, so its row count must equal NUM_ITERS for the new
        # batches to be hcat-able with the existing ones.
        throw(ArgumentError("existing :generic runs have $(size(running_times[:generic], 1)) " *
                            "iterations, expected NUM_ITERS = $NUM_ITERS"))
    end

    for cycle in 1:num_cycles
        _checkpoint_runs(filename, running_times, running_likelihoods)
        try
            initial_trs = generate_initial_trs()
            (t, l) = perform_runs(run_specs, initial_trs, NUM_ITERS)
            merge_in_runs!(running_times, running_likelihoods, t, l)
        catch e
            e isa InterruptException && rethrow()
            @warn "Skipping failed batch $cycle of $num_cycles" exception=(e, catch_backtrace())
        end
    end
    _checkpoint_runs(filename, running_times, running_likelihoods)

    return (running_times, running_likelihoods)
end

# Serialize the accumulated runs to `filename` as a `(times=..., likelihoods=...)` named
# tuple; does nothing when `filename === nothing`.
function _checkpoint_runs(filename, running_times, running_likelihoods)
    filename === nothing && return nothing
    serialize(filename, (times=running_times, likelihoods=running_likelihoods))
    return nothing
end

In [None]:
# Number of initial traces generated per batch (each procedure runs once from each).
NUM_RUNS_PER_EXPERIMENT = 4

In [None]:
## record likelihoods on illusion:
# Extend the loaded `rt`/`rl` in place with 6 more batches of 500-iteration runs, each
# from traces generated with `num_sources=0`, checkpointing to the given file.
record_likelihoods!(
    () -> [generate_initial_tr(trr, num_sources=0)[1] for _=1:NUM_RUNS_PER_EXPERIMENT],
    run_specs, 500;
    running_times=rt, running_likelihoods=rl, num_cycles=6,
    filename="illusion_from_0_1-4-2021--1-25pm.data"
)

In [None]:
# (iterations, runs) now accumulated for the generic kernel.
size(rl[:generic])

In [None]:
# Size check of an earlier checkpoint file.
deserialize("illusion_from_0_1-4-2021--7-53am.data").times[:generic] |> size

In [None]:
## record likelihoods on synthetic:
# Resume into `synth_t`/`synth_l` if an earlier execution of this cell defined them.
# Otherwise start fresh: passing the not-yet-defined globals raised `UndefVarError`
# on a first run.
synth_t, synth_l = record_likelihoods!(
    () -> [generate_initial_tr(
                generate(AI.generate_scene, AI.args, choicemap((:kernel => :n_tones, 3)))[1],
                num_sources=0
            )[1]
        for _=1:NUM_RUNS_PER_EXPERIMENT
        ],
    run_specs, 500;
    running_times=(@isdefined(synth_t) ? synth_t : nothing),
    running_likelihoods=(@isdefined(synth_l) ? synth_l : nothing),
    num_cycles=4,
    filename="synthetic_from_0_1-4-2021--9-45am.data"
)

In [None]:
# Re-plot the accumulated illusion runs.
plot_avg_times_and_likelihoods(rt, rl, POINT_SIZE=1, miny=-300000)

In [None]:
# (iterations, runs) stored in the synthetic-data checkpoint file.
deserialize("synthetic_from_0_1-4-2021--9-45am.data").likelihoods[:generic] |> size

In [None]:
plot_avg_times_and_likelihoods(synth_t, synth_l, POINT_SIZE=1, miny=-300000)

## Record Likelihoods

In [None]:
NUM_RUNS_PER_EXPERIMENT = 4
# One initial trace per run, each generated with `num_sources=0`.
initial_trs = [generate_initial_tr(trr, num_sources=0)[1] for _=1:NUM_RUNS_PER_EXPERIMENT]
nothing


In [None]:
NUM_ITERS = 500
(t, l) = perform_runs(run_specs, initial_trs, NUM_ITERS)
nothing


In [None]:
# Start the running totals from this first batch. NOTE(review): these are the same Dict
# objects as `t`/`l`, so merging into them also changes `t`/`l` until those are rebound.
running_times, running_likelihoods = (t, l)
nothing

In [None]:
NUM_ITERS = 500
# Append 8 more batches to the running totals. Each batch starts from fresh 0-source
# traces, so the merged runs share the initial conditions of the first batch
# ("illusion_from_0"). A failing batch is reported and skipped; an interrupt still
# stops the loop.
for batch in 1:8
    try
        initial_trs = [generate_initial_tr(trr, num_sources=0)[1] for _=1:NUM_RUNS_PER_EXPERIMENT]
        (t, l) = perform_runs(run_specs, initial_trs, NUM_ITERS)
        merge_in_runs!(running_times, running_likelihoods, t, l)
    catch e
        e isa InterruptException && rethrow()
        @warn "Skipping failed batch $batch" exception=(e, catch_backtrace())
    end
end

In [None]:
# Checkpoint the accumulated illusion runs as a `(times=..., likelihoods=...)` named tuple.
serialize("illusion_from_0_take2.data", (times=running_times, likelihoods=running_likelihoods))

## Plot Likelihoods

In [None]:
# Most recent batch only (`t`, `l` are rebound by each batch).
plot_avg_times_and_likelihoods(t, l, POINT_SIZE=1, miny=-400000)

In [None]:
plot_avg_times_and_likelihoods(running_times, running_likelihoods, miny=-300000)

In [None]:
# Log likelihood at iteration 100 of each of the first 8 :new_smbd runs.
running_likelihoods[:new_smbd][100, 1:8]

In [None]:
# Keep only the labels of interest.
wanted_tags = [:new_bd, :generic, :dumb_bd, :new_smbd, :old_smbd]
filt_t = Dict(tag => running_times[tag] for tag in wanted_tags)
filt_l = Dict(tag => running_likelihoods[tag] for tag in wanted_tags)

In [None]:
plot_avg_times_and_likelihoods(filt_t, filt_l)

In [None]:
# Legend text per label. `:old_smbd` is commented out here and in `order` below, so it
# is not plotted.
names = Dict(
    :generic => "Ancestral resampling",
    :dumb_bd => "Generic birth/death",
    :new_bd => "Data-driven custom birth/death",
    :new_smbd => "Data-driven birth/death & split-merge"
   # :old_smbd => "AABI Submission best performer"
)

In [None]:
# Plot order; also selects which labels are plotted.
order = [:generic, :dumb_bd, :new_bd, :new_smbd]#, :old_smbd]

In [None]:
plot_avg_times_and_likelihoods(rt, rl; order=order, names=names, maxx=20, miny=-300000)

In [None]:
rl[:generic] |> size

### Save the illusion runs

In [None]:
using Serialization
# Save as a `(times=..., likelihoods=...)` named tuple, like the other checkpoints, so
# `deserialize(...).times` works on this file too. `(rt, rl) = deserialize(...)` still
# works because named tuples iterate over their values.
# NOTE(review): this overwrites "illusion_from_0.data", the file `rt`/`rl` were loaded from.
serialize("illusion_from_0.data", (times=running_times, likelihoods=running_likelihoods))

## Likelihoods on synthetic data

In [None]:
# Synthetic benchmark setup:
# - sample NUM_RUNS ground-truth scenes from `AI.generate_scene`, with the choice at
#   `:kernel => :n_tones` constrained to 3;
# - build one initial trace per scene, generated with `num_sources=0`.
NUM_RUNS = 4
ground_truth_trs = [
    generate(AI.generate_scene, AI.args, choicemap((:kernel => :n_tones, 3)))[1]
    for i=1:NUM_RUNS
]
initial_trs = [
    AI.generate_initial_tr(ground_truth, num_sources=0)[1] for ground_truth in ground_truth_trs
]
nothing

In [None]:
NUM_ITERS = 650

In [None]:
(tms, lhs) = perform_runs(run_specs, initial_trs, NUM_ITERS)
nothing

In [None]:
plot_avg_times_and_likelihoods(tms, lhs, miny=-300000)

In [None]:
# Start the synthetic running totals from this batch (same Dict objects as `tms`/`lhs`).
(running_times2, running_likelihoods2) = (tms, lhs)
nothing 

In [None]:
# Append 5 more synthetic batches, each with freshly sampled ground-truth scenes.
# A failing batch is reported and skipped; an interrupt still stops the loop.
for batch in 1:5
    try
        ground_truth_trs = [
            generate(AI.generate_scene, AI.args, choicemap((:kernel => :n_tones, 3)))[1]
            for _=1:NUM_RUNS
        ]
        initial_trs = [
            AI.generate_initial_tr(ground_truth, num_sources=0)[1] for ground_truth in ground_truth_trs
        ]
        (t2, l2) = perform_runs(run_specs, initial_trs, NUM_ITERS)
        merge_in_runs!(running_times2, running_likelihoods2, t2, l2)
    catch e
        e isa InterruptException && rethrow()
        @warn "Skipping failed synthetic batch $batch" exception=(e, catch_backtrace())
    end
end

In [None]:
plot_avg_times_and_likelihoods(running_times2, running_likelihoods2, miny=-600000)

In [None]:
serialize("synthetic_from_prior_from_0.data", (times = running_times2, likelihoods = running_likelihoods2))

In [None]:
# Per-iteration log likelihood of the 8th :new_smbd illusion run. Summing a one-column
# slice over dims=2 just returns that column, as a matrix.
plot(sum(running_likelihoods[:new_smbd][:,8:8], dims=2))

In [None]:
# Serialization is a Julia standard library, so there is nothing to install.
# (This cell used to run `]add Serialize`, which is not the name of that stdlib.)

In [None]:
using Serialization

In [None]:
# Save both benchmark sets under descriptive names, as `(times=..., likelihoods=...)`.
serialize("runs_on_audio_illusion.data", (times=running_times, likelihoods=running_likelihoods))

In [None]:
serialize("runs_on_samples_from_prior.data", (times=running_times2, likelihoods=running_likelihoods2))

In [None]:
# Sanity check: read the illusion runs back.
x = deserialize("runs_on_audio_illusion.data")

## Functions for time vs count recording

In [None]:
"""
    get_worldmodel_count_likelihood_time_tracker_and_recorder()

Create a per-iteration recorder for inference runs.

Returns `(counts, times, record!)`. Each call `record!(tr)`:
- appends `tr[:kernel => :n_tones]` to `counts`;
- appends the seconds elapsed since this function was called to `times`, at
  millisecond resolution.

Despite its name, nothing likelihood-related is recorded.
"""
function get_worldmodel_count_likelihood_time_tracker_and_recorder()
    tone_counts = Int[]
    elapsed_secs = Float64[]
    t0 = Dates.now()
    record! = function (tr)
        push!(tone_counts, tr[:kernel => :n_tones])
        elapsed_ms = Dates.value(Dates.now() - t0)
        push!(elapsed_secs, elapsed_ms / 1000)
    end
    return (tone_counts, elapsed_secs, record!)
end

In [None]:
"""
    get_times_and_counts(initial_trs, run_inf!, iters; warmup_iters=20)

Run `run_inf!(initial_tr, iters, record!)` once per trace in `initial_trs`. Each
trial's per-iteration source counts and times are collected through a fresh recorder.

Returns `(times, counts)`, two `iters × length(initial_trs)` matrices:
- `times` (`Float64`): the recorder's timestamps;
- `counts` (`Int`): `tr[:kernel => :n_tones]` at each recorded iteration.

Column `i` holds the run started from `initial_trs[i]`, and row `j` holds its `j`-th
recorded iteration.

Before the recorded runs, one warm-up run of `warmup_iters` iterations is made from
the first trace, keeping JIT compilation out of the timings. When `initial_trs` is
empty, two `iters × 0` matrices are returned and nothing is run.
"""
function get_times_and_counts(initial_trs, run_inf!, iters; warmup_iters=20)
    ntrials = length(initial_trs)
    counts = zeros(Int, iters, ntrials)
    times = zeros(Float64, iters, ntrials)
    # No trials: there is also no trace to warm up on.
    ntrials == 0 && return (times, counts)
    starttime = Dates.now()
    run_inf!(first(initial_trs), warmup_iters, (tr,) -> nothing) # compilation run
    for (i, initial_tr) in enumerate(initial_trs)
        print("Running trial $i...;")
        println(" $(Dates.now() - starttime) ms ellapsed in total")
        (c, t, record!) = get_worldmodel_count_likelihood_time_tracker_and_recorder()
        run_inf!(initial_tr, iters, record!)
        # Assumes `run_inf!` calls `record!` exactly `iters` times — TODO confirm;
        # otherwise these column assignments throw.
        counts[:, i] = c
        times[:, i] = t
    end
    return (times, counts)
end

In [None]:
"""
    perform_count_runs(run_specs, initial_trs, num_generic_iters)

Record source counts and times for every `label => (run_inf!, multiplier)` entry of
`run_specs`, via `get_times_and_counts`. Each procedure is run from every trace in
`initial_trs` for `floor(multiplier * num_generic_iters)` iterations.

Returns `(times, counts)`: two `Dict`s keyed by label.
"""
function perform_count_runs(run_specs, initial_trs, num_generic_iters)
    times, counts = Dict(), Dict()
    for (label, spec) in run_specs
        inference_fn, multiplier = spec
        iteration_count = floor(Int, multiplier * num_generic_iters)
        println("Running ", label, " for ", iteration_count, " iterations per initial trace.")
        times[label], counts[label] =
            get_times_and_counts(initial_trs, inference_fn, iteration_count)
    end
    return (times, counts)
end

## Record time vs counts

In [None]:
# Benchmark a subset of `run_specs`. `:new_bd` / `:new_smbd` are the data-driven
# birth/death and birth/death + split-merge kernels; older cells call them
# `:smart_bd` / `:smart_smbd` and give them the same legend names. `run_specs` has no
# `:smart_*` keys, so the old labels raised `KeyError`.
filtered_specs = Dict(i => run_specs[i] for i in (:generic, :new_bd, :dumb_bd, :new_smbd))

In [None]:
# 200 initial traces; each procedure runs floor(20 * multiplier) iterations per trace.
initial_trs = [AudioInference.generate_initial_tr(trr)[1] for _=1:200]
(t, c) = perform_count_runs(filtered_specs, initial_trs, 20)

In [None]:
"""
    get_count_to_time(times, counts)

Group per-iteration durations by the number of sources in the trace.

`times` and `counts` are matrices with one column per run and one row per iteration:
- `times[i, r]` is when iteration `i` of run `r` was recorded;
- `counts[i, r]` is the number of sources at that iteration.

For every iteration after the first of each run, the duration
`times[i, r] - times[i-1, r]` is appended to `result[counts[i, r]]`.

Returns a `Vector{Vector{Float64}}` of length `maximum(counts)` (0 for empty input).

Durations are stored in vectors rather than sets, so repeated values are all kept.
Repeats are common with millisecond-resolution timestamps, and a set would skew any
average computed from the groups. Iterations whose count is below 1 have no slot
and are skipped.
"""
function get_count_to_time(times, counts)
    max_count = isempty(counts) ? 0 : maximum(counts)
    counts_to_time = [Float64[] for _ in 1:max_count]
    for run in axes(times, 2)
        t = view(times, :, run)
        c = view(counts, :, run)
        for i in 2:length(t)
            c[i] >= 1 || continue
            push!(counts_to_time[c[i]], t[i] - t[i-1])
        end
    end
    return counts_to_time
end

In [None]:
# Durations grouped by source count for the drift-based birth/death + split-merge
# kernel. In `run_specs` (and hence in `t`/`c`) it is labelled `:new_smbd`; there is no
# `:smart_smbd_drift` key, so that label raised `KeyError`.
label = :new_smbd
ctot = get_count_to_time(t[label], c[label])

In [None]:
# Total time spent in iterations whose trace had exactly 1 source.
reduce(+, ctot[1]; init=0.)

In [None]:
"""
    plot_count_to_avg_time(times, counts; POINT_SIZE=5, order=nothing, names=nothing)

For each inference label, plot the average duration of an inference iteration against
the number of sources in the trace.

`times[label]` and `counts[label]` are per-iteration matrices (one column per run), as
produced by `perform_count_runs`.

# Keywords
- `order`: which labels to plot, and in what order (default: all keys of `times`).
- `names`: maps labels to legend text (default: `string(label)`).
- `POINT_SIZE`: marker size.

Source counts with no recorded iteration average to `NaN` and show as gaps.
"""
function plot_count_to_avg_time(times, counts; POINT_SIZE=5, order=nothing, names=nothing)
    key_itr = order === nothing ? keys(times) : order
    for label in key_itr
        count_to_time = get_count_to_time(times[label], counts[label])
        count_to_avg_time = map(v -> reduce(+, v; init=0.) / length(v), count_to_time)
        name = names === nothing ? string(label) : names[label]
        plot(1:length(count_to_avg_time), count_to_avg_time;
             label=name, marker="o", markersize=POINT_SIZE)
    end
    ylabel("time (s) for inference pass")
    xlabel("num sources in inferred trace")
    title("Runtime/iter for different numbers of inferred sources")
    legend(loc="lower right")
end

In [None]:
# Average iteration time vs. source count for each benchmarked procedure.
plot_count_to_avg_time(t, c)

# Timing individual kernel runs vs "active set"

Here, I want to time how long it takes to run each kernel in traces with different numbers of components.

In [None]:
# Legend text keyed by the labels actually present in `filt_t`/`filt_l`, i.e. the
# `run_specs` labels. These previously used `:smart_bd`/`:smart_smbd`, which `filt_t`
# does not contain, so the plot below raised `KeyError`.
names = Dict(
    :generic => "Ancestral resampling",
    :dumb_bd => "Generic birth/death",
    :new_bd => "Data-driven custom birth/death",
    :new_smbd => "Data-driven birth/death & split-merge"
)

In [None]:
order = [:generic, :dumb_bd, :new_bd, :new_smbd]

In [None]:
plot_avg_times_and_likelihoods(filt_t, filt_l; order=order, names=names, maxx=18, miny=-10000)

In [None]:
using Serialization

In [None]:
# NOTE(review): saves only `running_times` (no likelihoods), and in Serialization's
# binary format despite the .txt extension.
serialize("save.txt", running_times)

In [None]:
# Quick smoke test: 2 initial traces, floor(10 * multiplier) iterations per procedure.
initial_trs = [generate_initial_tr(trr)[1] for _=1:2]
(t, l) = perform_runs(run_specs, initial_trs, 10)
nothing

In [None]:
# Plot the likelihood runs just recorded. This was `plot_count_to_avg_time(t, c)`,
# which paired these times with the unrelated count runs in `c` (different labels and
# iteration counts) from the count benchmark.
plot_avg_times_and_likelihoods(t, l)

In [None]:
# NOTE(review): rebinds `rt`/`rl` (previously the runs loaded from disk) to the
# smoke-test dicts. They are the same objects as `t`/`l` until those are rebound.
rt, rl = t, l

In [None]:
initial_trs = [generate_initial_tr(trr)[1] for _=1:2]
(t, l) = perform_runs(run_specs, initial_trs, 10)

In [None]:
# Append the new runs' columns to `rt`/`rl` in place.
merge_in_runs!(rt, rl, t, l)

In [None]:
# Two more 2-trace batches, merged into `rt`/`rl`.
for _=1:2
    initial_trs = [generate_initial_tr(trr)[1] for _=1:2]
    (t, l) = perform_runs(run_specs, initial_trs, 10)
    merge_in_runs!(rt, rl, t, l)
end

In [None]:
# NOTE(review): these cells use a global `initial_tr` that no visible cell defines, and
# call `get_avg_likelihoods`, which is defined further down (run that cell first).
# `fill(initial_tr, 20)` repeats the same trace object 20 times.
(generic_times, generic_likelihoods) = get_avg_likelihoods(fill(initial_tr,20), AudioInference.do_generic_inference, 800)
plot(generic_times, generic_likelihoods)

In [None]:
(bd_times, bd_likelihoods) = get_avg_likelihoods(fill(initial_tr,20), AudioInference.do_birth_death_inference, 800)
plot(bd_times, bd_likelihoods)

In [None]:
(bd_times3, bd_likelihoods3) = get_avg_likelihoods(fill(initial_tr,20), AudioInference.do_birth_death_inference, 1200)
plot(bd_times3, bd_likelihoods3)

In [None]:
# Split/merge inference with 3 split/merge moves per iteration (per the keyword name).
do_sm = (args...,) -> AudioInference.do_split_merge_inference(args...; num_sm_per_iter=3)
(sm_times, sm_likelihoods) = get_avg_likelihoods(fill(initial_tr,20), do_sm, 600)
plot(sm_times, sm_likelihoods)

In [None]:
# Generic inference averaged over `initial_trs`, 180 iterations. This removes a stray
# `)` after `initial_trs` that made this cell a syntax error.
(generic_times, generic_likelihoods) = get_avg_likelihoods(initial_trs, AudioInference.do_generic_inference, 180)
plot(generic_times, generic_likelihoods)

In [None]:
# NOTE(review): opening with "w" creates or truncates real_data.txt, and no visible cell
# writes to or closes `f`. Consider `open(...) do io ... end`, or removing this cell.
f = open("real_data.txt", "w")

In [None]:
# Overwrites `bd_times`/`bd_likelihoods` from the earlier 800-iteration run.
(bd_times, bd_likelihoods) = get_avg_likelihoods(initial_trs, AudioInference.do_birth_death_inference, 200)
plot(bd_times, bd_likelihoods)

In [None]:
(sm_times, sm_likelihoods) = get_avg_likelihoods(initial_trs, AudioInference.do_split_merge_inference, 100)
plot(sm_times, sm_likelihoods)

In [None]:
# NOTE(review): named `smart_bd_*`, but runs `do_split_merge_inference`, the same kernel
# as the previous cell. Confirm which data-driven birth/death kernel was intended.
(smart_bd_times, smart_bd_likelihoods) = get_avg_likelihoods(initial_trs, AudioInference.do_split_merge_inference, 100)

In [None]:
TIME_CAP = 7.0
# For each series, the indices of iterations whose averaged time is below TIME_CAP.
# Each index set is computed from the same series it indexes in the plot below. The
# birth/death indices used to come from `bd_times`, while the plot indexed
# `bd_times3`; those series have different lengths (800 or 200 vs. 1200 iterations).
g_indices = findall(t -> t < TIME_CAP, generic_times)
bd_indices = findall(t -> t < TIME_CAP, bd_times3)
sm_indices = findall(t -> t < TIME_CAP, sm_times)
nothing

In [None]:
ax = gca()
# Property-style PyCall access; `ax[:set_ylim]` is PyCall's deprecated getindex syntax.
ax.set_ylim([-200000, 0])
POINT_SIZE = 2
scatter(generic_times[g_indices], generic_likelihoods[g_indices], label="ancestral resampling (generic MCMC)", s=POINT_SIZE, color="g")
scatter(bd_times3[bd_indices], bd_likelihoods3[bd_indices], label="birth/death", s=POINT_SIZE, color="darkorange")
scatter(sm_times[sm_indices], sm_likelihoods[sm_indices], label="birth/death + split/merge", s=POINT_SIZE)
xlabel("time (s)")
ylabel("log likelihood of observed sound given inferred waves")
title("Quality of inferred waveforms over time")
legend(loc="lower right")

In [None]:
# Second, longer comparison (20 copies of the global `initial_tr` per kernel).
(generic_times2, generic_likelihoods2) = get_avg_likelihoods(fill(initial_tr,20), AudioInference.do_generic_inference, 540)
plot(generic_times2, generic_likelihoods2)

In [None]:
(bd_times2, bd_likelihoods2) = get_avg_likelihoods(fill(initial_tr,20), AudioInference.do_birth_death_inference, 600)
plot(bd_times2, bd_likelihoods2)

In [None]:
(sm_times2, sm_likelihoods2) = get_avg_likelihoods(fill(initial_tr,20), AudioInference.do_split_merge_inference, 400)
plot(sm_times2, sm_likelihoods2)

In [None]:
TIME_CAP = 30.0
# Per series, the indices of iterations whose averaged time is below TIME_CAP.
g_indices2 = filter(i -> generic_times2[i] < TIME_CAP, 1:length(generic_times2))
bd_indices2 = filter(i -> bd_times2[i] < TIME_CAP, 1:length(bd_times2))
sm_indices2 = filter(i -> sm_times2[i] < TIME_CAP, 1:length(sm_times2))
nothing

In [None]:
ax = gca()
# Property-style PyCall access; `ax[:set_ylim]` is PyCall's deprecated getindex syntax.
ax.set_ylim([-400000, 0])
POINT_SIZE = 1
scatter(generic_times2[g_indices2], generic_likelihoods2[g_indices2], label="generic", s=POINT_SIZE)
scatter(bd_times2[bd_indices2], bd_likelihoods2[bd_indices2], label="birth/death", s=POINT_SIZE)
scatter(sm_times2[sm_indices2], sm_likelihoods2[sm_indices2], label="split/merge", s=POINT_SIZE)
xlabel("time (s)")
ylabel("log likelihood of observed sound given inferred waves")
title("Quality of inferred waveforms over time")
legend(loc="lower right")

In [None]:
# (Stray cell: it evaluated `d`, which no cell defines, raising `UndefVarError`.)
nothing

In [None]:
"""
    get_avg_likelihoods(initial_trs, run_inf!, iters)

Run `run_inf!(initial_tr, iters, record!)` from each trace in `initial_trs`.

Returns `(times, likelihoods)`. These are length-`iters` vectors holding each
iteration's recorded time and log likelihood, averaged over all runs.

A warm-up run of 20 iterations is first made from `first(initial_trs)`, keeping JIT
compilation out of the timings.

# Throws
- `ArgumentError` if `initial_trs` is empty (there is nothing to average).
"""
function get_avg_likelihoods(initial_trs, run_inf!, iters)
    isempty(initial_trs) && throw(ArgumentError("initial_trs must not be empty"))
    likelihoods = zeros(Float64, iters)
    times = zeros(Float64, iters)
    starttime = Dates.now()
    # Warm up on the first trace passed in; this previously read a global `initial_tr`.
    run_inf!(first(initial_trs), 20, (tr,) -> nothing) # compilation run
    for (i, initial_tr) in enumerate(initial_trs)
        print("Running trial $i...;")
        println(" $(Dates.now() - starttime) ms ellapsed in total")
        (l, t, record!) = AudioInference.get_worldmodel_likelihood_time_tracker_and_recorder()
        run_inf!(initial_tr, iters, record!)
        # Accumulate in place; assumes `l`/`t` have length `iters` (else DimensionMismatch).
        likelihoods .+= l
        times .+= t
    end
    nruns = length(initial_trs)
    likelihoods ./= nruns
    times ./= nruns
    return (times, likelihoods)
end

In [None]:
"""
    get_avg_likelihoods_and_counts(initial_trs, run_inf!, iters)

Run `run_inf!(initial_tr, iters, record!)` from each trace in `initial_trs`.

Returns `(likelihoods, counts)`:
- `likelihoods`: the length-`iters` per-iteration log likelihood averaged over the
  runs (all `NaN` when `initial_trs` is empty);
- `counts`: for each run, `tr[:kernel => :n_tones]` of the trace that `run_inf!`
  returned.
"""
function get_avg_likelihoods_and_counts(initial_trs, run_inf!, iters)
    summed = zeros(Float64, iters)
    final_counts = []
    starttime = Dates.now()
    for (trial, tr0) in enumerate(initial_trs)
        print("Running trial $trial...;")
        println(" $(Dates.now() - starttime) ms ellapsed in total")
        (trial_likelihoods, record!) = AudioInference.get_worldmodel_likelihood_tracker_and_recorder()
        final_tr = run_inf!(tr0, iters, record!)
        push!(final_counts, final_tr[:kernel => :n_tones])
        summed += trial_likelihoods
    end
    summed /= length(initial_trs)
    return (summed, final_counts)
end

In [None]:
# Final tone counts of 2 birth/death runs from the global `initial_tr`.
# NOTE(review): overwrites the globals `l` and `c` used by earlier cells.
(l, c) = get_avg_likelihoods_and_counts(fill(initial_tr,2), AudioInference.do_birth_death_inference, 600)
c

In [None]:
(l, c) = get_avg_likelihoods_and_counts(fill(initial_tr,2), AudioInference.do_split_merge_inference, 600)
c

In [None]:
using Pkg; Pkg.add("ProfileView")

In [None]:
using Profile; using ProfileView;

In [None]:
# Discard samples left over from any earlier @profile run, so the viewer shows only
# this one.
# NOTE(review): if this kernel has not been compiled yet, compilation is profiled too.
Profile.clear()
@profile get_avg_likelihoods_and_counts(fill(initial_tr,5), AudioInference.do_split_merge_inference, 600)

In [None]:
ProfileView.view()

In [None]:
"""
    plot_gtg(gtg, duration, audio_sr, vmin, vmax; colors="Blues", plot_colorbar=false)

Draw the gammatonegram `gtg` as an image.

- The image spans `0..duration` on x and `0..audio_sr/2` on y, with colour limits
  `vmin`/`vmax`.
- The existing y tick labels are replaced by frequencies spaced evenly on the ERB
  scale between 1 Hz and `audio_sr/2`, one per tick.
- A colour bar is added when `plot_colorbar` is true.
"""
function plot_gtg(gtg, duration, audio_sr, vmin, vmax; colors="Blues", plot_colorbar=false)
    nyquist = audio_sr / 2
    imshow(gtg; cmap=colors, origin="lower", extent=(0, duration, 0, nyquist),
           vmin=vmin, vmax=vmax, aspect=1/1300)
    tick_positions, tick_labels = yticks()
    erb_lo = AudioInference.freq_to_ERB(1.)
    erb_hi = AudioInference.freq_to_ERB(nyquist)
    erb_points = range(erb_lo, stop=erb_hi, length=length(tick_positions))
    tick_freqs = Int.(floor.(AudioInference.ERB_to_freq(erb_points)))
    setp(gca().set_yticklabels(tick_freqs), fontsize="small")
    if plot_colorbar
        plt.colorbar()
    end
end

In [None]:
"""
    vis(tr)

Plot the gammatonegram of trace `tr` with `plot_gtg`, using a fixed colour range of
0–100.

The gram is the first element of the trace's return value. The duration and sample
rate are the first and third elements of its arguments.
"""
function vis(tr)
    duration, _, sr = AudioInference.get_args(tr)
    gram, _ = AudioInference.get_retval(tr)
    return plot_gtg(gram, duration, sr, 0, 100)
end

In [None]:
# Rebuild the test scene (same call as in the setup section) and plot its gammatonegram.
trr = tones_with_noise(10.); nothing

In [None]:
vis(trr)

In [None]:
# 500 birth/death iterations from the global `initial_tr`, with a no-op recorder.
tr = AudioInference.do_birth_death_inference(initial_tr, 500, (tr,) -> ())

In [None]:
vis(tr)