In [None]:
using WAV;
using Gen;
import Random;
include("model.jl");
include("plotting.jl");
include("proposals.jl");
include("gammatonegram.jl");
include("time_helpers.jl");
include("inference_helpers.jl");
include("initialization.jl");

# Seed Julia's global RNG so that random draws later in this notebook (e.g. the `rand()`
# used to pick a random scene duration) are reproducible across runs.
Random.seed!(1);

In [None]:
println("Loading in command-line args...")
# Experiment selection. Despite the log message above, these values are hard-coded here
# rather than parsed from the command line.
paper_name = "mcDermottOxenham2008/fig2/"
demo_name = string(paper_name, "2i");
# Read the paper's sound-generation parameters. `read(path, String)` opens and closes the
# file itself, so no handle is leaked if reading or parsing throws.
# NOTE(review): `JSON` is not imported in this file; presumably one of the included files
# brings it into scope — confirm.
sound_gen_params = JSON.parse(read(string("./sounds/", paper_name, "parameters.json"), String));

In [None]:
# Model configuration.
# NOTE(review): the "sources"/"elements" entries look like caps ("type" => "max") on the
# number of sources and of elements per source; their interpretation lives in model.jl — confirm.
n = Dict(
    "sources" => Dict("type" => "max", "val" => 5),
    "elements" => Dict("type" => "max", "val" => 5),
)
audio_sr = 20000
steps = Dict("t" => 0.020, "min" => 0.020, "f" => 4)
# Gammatonegram settings, passed both to load_sound and (via "dB_threshold") to plot_gtg.
gtg_params = Dict(
    "twin" => 0.025,
    "thop" => 0.010,
    "nfilts" => 64,
    "fmin" => 50,
    "width" => 1.0,
    "log_constant" => 1e-80,
    "dB_threshold" => 20,
    "ref" => 1e-4,
)

## Inference parameters
annealing_noise = -1

println("Loading in sound...")
# Load the observed sound. Note that this replaces the audio_sr set above with the value
# returned by load_sound.
demo_gram, wts, scene_duration, audio_sr = load_sound(demo_name, gtg_params)
# (Neural-network guide loading is currently disabled: demo_guide = read_guide(demo_name))
let db_lo = gtg_params["dB_threshold"]
    # presumably the lower/upper limits of the displayed dB range — see plot_gtg
    plot_gtg(demo_gram, scene_duration, audio_sr / 2, db_lo, db_lo + 60.0)
end;

args = (n, float(scene_duration), wts, steps, Int(audio_sr), annealing_noise, gtg_params)

In [None]:
# Show the gram returned by load_sound above as this cell's output
# (a bare expression on the last line of a cell is displayed).
demo_gram

In [None]:
# Replace the demo's duration with a random one, drawn uniformly from [0.5, 2.0) and then
# rounded to 3 decimal places, and rebuild the model arguments accordingly.
scene_duration = round(0.5 + 1.5 * rand(); digits = 3)
args = (n, float(scene_duration), wts, steps, Int(audio_sr), annealing_noise, gtg_params)

# Sample a full scene from the generative model. Gen's `generate` returns (trace, weight);
# only the trace is needed here.
trace = first(generate(generate_scene, args))
scene_gram, t, scene_wave, element_waves = get_retval(trace)
let db_lo = gtg_params["dB_threshold"]
    plot_gtg(scene_gram, scene_duration, audio_sr / 2, db_lo, db_lo + 70.0)
end;

In [None]:
### Get all per-element cochleagrams and a one-hot "dominant element" mask per
### time-frequency bin.

# Element grams are stacked along dim 1: (nElements, t, f). The empty placeholder fixes
# the trailing dimensions in case no element produces a gram at all.
element_grams = Array{Float64}(undef, 0, size(scene_gram, 2), size(scene_gram, 1)) #nElements t f
elements_to_keep = Vector{Vector{Int}}()
for i in eachindex(element_waves)
    source_elements_to_keep = Int[]
    for j in eachindex(element_waves[i])
        gtg, t = gammatonegram(element_waves[i][j], wts, audio_sr, gtg_params)
        # (d1, d2) -> (1, d2, d1) so grams can be concatenated along dim 1.
        gtg = permutedims(reshape(gtg, (size(gtg, 1), size(gtg, 2), 1)), [3, 2, 1])
        # The first gram replaces the empty placeholder, so this works even when
        # source 1 has no elements.
        element_grams = size(element_grams, 1) == 0 ? gtg : cat(element_grams, gtg; dims=1)
        # Flag elements whose gram sits entirely at dB_threshold (presumably the floor that
        # gammatonegram clips to, i.e. a silent element) so they can be dropped.
        push!(source_elements_to_keep, all(==(gtg_params["dB_threshold"]), gtg) ? 0 : 1)
    end
    push!(elements_to_keep, source_elements_to_keep)
end
tokeep(x) = x == 1
idx_to_keep = findall(tokeep, reduce(vcat, elements_to_keep; init=Int[]))
element_grams = element_grams[idx_to_keep, :, :]
s = length(idx_to_keep)  # number of kept elements

nelements = size(element_grams, 1)
if nelements == 0
    @warn "No element rises above dB_threshold; skipping mask computation."
else
    # r[1, t, f] is true where every kept element equals dB_threshold (a silent bin).
    element_thresh = element_grams .== gtg_params["dB_threshold"]
    r = trues(1, size(element_grams, 2), size(element_grams, 3))
    all!(r, element_thresh)
    # z is one-hot over elements: z[k, t, f] == 1 iff element k has the maximum value in
    # bin (t, f).
    z = zeros(size(element_grams))
    vals, idxs = findmax(element_grams, dims=1)
    for idx in idxs
        z[idx] = 1
    end
    # findmax returns the first index among ties, so in silent bins (all elements equal)
    # element 1 is spuriously marked as dominant; clear exactly those entries.
    z[1, :, :] = z[1, :, :] .* .!r[1, :, :]
    # Show one element's mask, transposed so dim 2 runs horizontally, with dim 3 reversed
    # so the last channel is drawn at the top. Clamp the index so that scenes with fewer
    # than 10 kept elements do not raise a BoundsError.
    plot_element = min(10, nelements)
    plt.imshow(permutedims(z[plot_element, :, end:-1:1], [2, 1]))
end