# Setup

In [None]:
using WAV
include("../tools/plotting.jl");
include("../model/model.jl");
include("../model/gammatonegram.jl");
include("../model/time_helpers.jl");

In [None]:
using GenWorldModels

In [None]:
# Unpack the four values returned by evaluating `params/base.jl`.
source_params, steps, gtg_params, obs_noise = include("../params/base.jl")
# Sample rate shared by every cell below (presumably in Hz — TODO confirm).
sr = 2000.0
# Set the gammatonegram dB threshold entry before computing the filter weights.
gtg_params["dB_threshold"] = 0.0
# Only the first value returned by `gtg_weights` is kept; the rest are discarded.
wts, = gtg_weights(sr, gtg_params);

# Model

### Embed a sound within a larger scene

In [None]:
"""
    embed_in_scene(scene_length, sr, wave, onset)

Return a zero-filled waveform of `floor(sr * scene_length)` samples with `wave`
copied in, starting at sample index `max(1, floor(onset * sr))`.

Samples of `wave` that would fall beyond the end of the scene are dropped. If the
start index lies beyond the end of the scene, the result is all zeros.
"""
function embed_in_scene(scene_length, sr, wave, onset)
  n_samples = floor(Int, sr * scene_length)
  scene_wave = zeros(n_samples)
  sample_start = max(1, floor(Int, onset * sr))
  # Inclusive index of the last scene sample to fill, clamped to the scene length so
  # that a wave running past the end is truncated exactly at the final sample.
  sample_finish = min(sample_start + length(wave) - 1, n_samples)
  # Zero when the start index is past the end of the scene or `wave` is empty.
  n_copied = max(0, sample_finish - sample_start + 1)
  scene_wave[sample_start:sample_finish] = wave[1:n_copied]
  return scene_wave
end

### Generate white noise sound

In [None]:
# Generative model for one noise burst.
# Samples `onset` uniformly over the scene, `duration` uniformly in [0.1, 1.0] and
# `amp` from normal(10.0, 8.0), synthesizes the burst with `generate_noise` from a
# grid filled with the constant `amp`, and returns it embedded in a scene-length
# waveform.
@gen function generate_single_noise(scene_length, steps, sr)
  onset ~ uniform(0, scene_length)
  duration ~ uniform(0.1, 1.0)
  amp ~ normal(10.0, 8.0)
  gp_points, time_axis, freq_axis =
    get_gp_spectrotemporal([onset, onset + duration], steps, sr)
  # Constant-amplitude grid: filled as (frequency, time), then transposed.
  flat_amps = fill(amp, length(gp_points))
  amp_grid = transpose(reshape(flat_amps, (length(freq_axis), length(time_axis))))
  burst = generate_noise(amp_grid, duration, steps, sr, 1e-6)
  return embed_in_scene(scene_length, sr, burst, onset)
end;

### Generate tone with pitch

In [None]:
# Generative model for one tone.
# `step_size` is indexed with the key "t" to obtain the step passed to the helpers.
# Samples `erb` uniformly in [0.4, 37.0], `onset` uniformly over the scene and
# `duration` uniformly in [0.1, 1.0]; the tone is synthesized by `generate_tone`
# and returned embedded in a scene-length waveform.
@gen function generate_single_tone(scene_length, step_size, sr)
  time_step = step_size["t"]
  erb ~ uniform(0.4, 37.0)
  onset ~ uniform(0.0, scene_length)
  duration ~ uniform(0.1, 1.0)
  gp_times = get_element_gp_times([onset, onset + duration], time_step)
  n_points = length(gp_times)
  # Both tracks are constant over the tone: the sampled `erb`, and a fixed 50.0
  # (presumably the tone level — confirm against `generate_tone`).
  erb_track = fill(erb, n_points)
  fixed_track = fill(50.0, n_points)
  tone_wave = generate_tone(erb_track, fixed_track, duration, time_step, sr, 1.0e-6)
  return embed_in_scene(scene_length, sr, tone_wave, onset)
end;

### Generate sound (noise or tone)

In [None]:
# Generative model for one sound source, looked up per index through the world.
# Reads the scene length, step sizes and sample rate from the world's args, then
# flips `is_noise` (probability 0.4) to decide between a noise burst and a tone.
# The chosen sub-model is spliced in with `{*}`, so its choices sit directly under
# this source's address. `source_idx` is not used in the body.
@gen function generate_single_sound(world, source_idx)
    scene_length ~ lookup_or_generate(world[:args][:scene_length])
    steps ~ lookup_or_generate(world[:args][:steps])
    sr ~ lookup_or_generate(world[:args][:sr])

    is_noise ~ bernoulli(0.4)
    if !is_noise
        tone_wave = {*} ~ generate_single_tone(scene_length, steps, sr)
        return tone_wave
    end

    noise_wave = {*} ~ generate_single_noise(scene_length, steps, sr)
    return noise_wave
end
#generate_sounds = Map(generate_single_sound);

### Generate scene with many sounds

In [None]:
# Static-DSL kernel that renders one scene from a mixture of sound sources.
# `n_tones` is the number of sources (1 to 4); each source may be a tone or a noise
# burst, as decided by `generate_single_sound`.
# The scene length and sample rate are read from the world's args, the waveform of
# source `i` is looked up at `world[:waves][i]`, and the waveforms are summed.
# The gammatonegram of the mixture is observed at address `:scene` through
# `noisy_matrix` (second argument 1.0 is presumably the noise scale — TODO confirm).
# Returns `(scene_gram, scene_wave, waves)`.
@gen (static) function _generate_scene(world, wts, gtg_params)
  n_tones ~ uniform_discrete(1, 4)
    
  scene_duration ~ lookup_or_generate(world[:args][:scene_length])
  audio_sr ~ lookup_or_generate(world[:args][:sr])

  # One world lookup per source index.
  waves ~ Map(lookup_or_generate)([world[:waves][i] for i=1:n_tones])
  n_samples = Int(floor(scene_duration * audio_sr))
  # Start the sum from silence of the full scene length.
  scene_wave = reduce(+, waves; init=zeros(n_samples))
  # Only the first value returned by `gammatonegram` is kept.
  scene_gram, = gammatonegram(scene_wave, wts, audio_sr, gtg_params)
  scene ~ noisy_matrix(scene_gram, 1.0)
  return scene_gram, scene_wave, waves
end;

# Wrap the kernel with `UsingWorld`: `generate_single_sound` is registered under the
# `:waves` address, and `:scene_length`, `:steps`, `:sr` are declared as world args.
# Callers in this notebook pass `(scene_length, steps, sr, wts, gtg_params)`.
generate_scene = UsingWorld(
    _generate_scene,
    :waves => generate_single_sound;
    world_args=(:scene_length, :steps, :sr)
)

In [None]:
# Run after the `@gen (static)` definition above and before the model is first used.
# NOTE(review): presumably required by the Gen version in use — confirm.
@load_generated_functions

# Trace visualization and playback

In [None]:
"""
    vis_and_write_wave(tr, title)

Write the scene waveform of trace `tr` to the WAV file `title` and plot the scene
gammatonegram.

The waveform is scaled so that its largest absolute sample is 1. A silent scene
(all zeros) is written unscaled instead of being divided by zero. The scene duration
and sample rate are taken from the first and third trace arguments, and the
gammatonegram and waveform from the first two elements of the trace's return value.

Returns whatever `plot_gtg` returns, so the plot is displayed when this call is the
last expression of a notebook cell.
"""
function vis_and_write_wave(tr, title)
  duration, _, sr, = get_args(tr)
  gram, scene_wave, = get_retval(tr)
  peak = maximum(abs, scene_wave)
  # A zero peak would turn every sample into NaN when normalizing.
  normalized_wave = peak > 0 ? scene_wave / peak : scene_wave
  wavwrite(normalized_wave, title, Fs=sr)
  # NOTE(review): 0 and 100 are presumably the plotted value range — confirm in plot_gtg.
  return plot_gtg(gram, duration, sr, 0, 100)
end

Some default arguments to `generate_scene`:

In [None]:
# Only `scene_length` gets a new value here; `steps` and `sr` are rebound to themselves.
scene_length, steps, sr = (2.0, steps, sr)
# World args (`scene_length`, `steps`, `sr`) first, then the kernel's `wts`, `gtg_params`.
args = (scene_length, steps, sr, wts, gtg_params)

Generate and visualize a trace:

In [None]:
# Unconstrained sample from the model. `tr` is reused by a later cell as the source
# of the observed `:scene`.
tr = simulate(generate_scene, args);
vis_and_write_wave(tr, "simulated_scene.wav")

Playback (uses `afplay`, which ships with macOS; substitute another audio player on other platforms):

In [None]:
; afplay simulated_scene.wav

# Inference

In [None]:
"""
    do_inference(tr, iters)

Run `iters` sweeps of `mh` moves over trace `tr` and return the final trace.

Each sweep visits every source `j` in `1:n_tones`: it first proposes the whole
source at `:world => :waves => j`, then its type-specific parameter (`:amp` when
`:is_noise` is true, `:erb` otherwise), then `:onset` and `:duration`. The sweep
ends with a move on `:kernel => :n_tones`.
"""
function do_inference(tr, iters)
  for _ in 1:iters
    for j in 1:tr[:kernel => :n_tones]
      tr, = mh(tr, select(:world => :waves => j))
      # The source type may have just changed, so read it after the move above.
      type_param = tr[:world => :waves => j => :is_noise] ? :amp : :erb
      for param in (type_param, :onset, :duration)
        tr, = mh(tr, select(:world => :waves => j => param))
      end
    end
    tr, = mh(tr, select(:kernel => :n_tones))
  end
  return tr
end

# Testing

Generate ground truth:

In [None]:
# Sample the scene whose `:scene` gammatonegram serves as the observation below.
ground_truth = simulate(generate_scene, args);
vis_and_write_wave(ground_truth, "ground_truth.wav")

In [None]:
; afplay ground_truth.wav

Generate trace with constraints:

In [None]:
# Initial trace with `:kernel => :scene` fixed to the ground-truth observation.
# Only the first value returned by `generate` is kept.
inferred_trace, = generate(generate_scene, args, choicemap((:kernel => :scene, ground_truth[:kernel => :scene])));

Run inference:

In [None]:
# 100 sweeps of `do_inference`; the cell displays the score of the resulting trace.
inferred_trace = do_inference(inferred_trace, 100);
get_score(inferred_trace)

In [None]:
# Plot the inferred scene and write its audio to inferred.wav.
vis_and_write_wave(inferred_trace, "inferred.wav")

In [None]:
; afplay inferred.wav

In [None]:
# Second experiment: observe the scene of the earlier simulated trace `tr`
# (not `ground_truth`).
observations = choicemap((:kernel => :scene, tr[:kernel => :scene]))
# NOTE(review): the argument tuple spells out the same values as `args`, provided
# `scene_length` is still 2.0.
inferred_tr, = generate(generate_scene, (2.0, steps, sr, wts, gtg_params), observations);

In [None]:
# 100 sweeps of `do_inference`; the cell displays the score of the resulting trace.
inferred_tr = do_inference(inferred_tr, 100);
get_score(inferred_tr)

In [None]:
# Overwrites the inferred.wav written by the previous experiment.
vis_and_write_wave(inferred_tr, "inferred.wav")

In [None]:
; afplay inferred.wav

# Auditory illusion setup

In [None]:
"""
    tones_with_noise(amp)

Generate a trace of `generate_scene` (using the global `args`) constrained to three
sources: two tones with `erb` 10.0 and duration 0.3 starting at 0.5 and 1.1, and a
noise burst of amplitude `amp` and duration 0.3 starting at 0.8, i.e. filling the
gap between the two tones.
"""
function tones_with_noise(amp)
    constraints = choicemap()
    constraints[:kernel => :n_tones] = 3

    # Sources 1 and 2: identical tones that differ only in their onset.
    for (idx, tone_onset) in ((1, 0.5), (2, 1.1))
        constraints[:world => :waves => idx => :is_noise] = false
        constraints[:world => :waves => idx => :erb] = 10.0
        constraints[:world => :waves => idx => :onset] = tone_onset
        constraints[:world => :waves => idx => :duration] = 0.3
    end

    # Source 3: the noise burst placed between the tones.
    constraints[:world => :waves => 3 => :is_noise] = true
    constraints[:world => :waves => 3 => :amp] = amp
    constraints[:world => :waves => 3 => :onset] = 0.8
    constraints[:world => :waves => 3 => :duration] = 0.3

    trace, = generate(generate_scene, args, constraints)
    return trace
end

In [None]:
# Noise amplitude 10.0 equals the mean of the noise model's `amp` prior, normal(10.0, 8.0).
trr = tones_with_noise(10.0);

In [None]:
# Plot the constructed tone/noise/tone scene and write its audio to trr.wav.
vis_and_write_wave(trr, "trr.wav")

In [None]:
; afplay trr.wav

In [None]:
# Observe only the scene of the constructed trace; the sources themselves are not
# constrained.
observations = choicemap((:kernel => :scene, trr[:kernel => :scene]))
inferred_tr,weight = generate(generate_scene, args, observations);
# Display the weight returned by `generate` for this initial trace.
weight

In [None]:
# Sanity check: the constrained trace should carry exactly the observed scene.
inferred_tr[:kernel => :scene] == trr[:kernel => :scene]

In [None]:
# Longer run than the earlier experiments: 1000 sweeps instead of 100.
inferred_tr = do_inference(inferred_tr, 1000);
get_score(inferred_tr)

In [None]:
# Overwrites the inferred.wav written by the earlier experiments.
vis_and_write_wave(inferred_tr, "inferred.wav")

In [None]:
# Show the inferred choices for all sources, to compare with the constraints set in
# `tones_with_noise`.
get_submap(get_choices(inferred_tr), :world => :waves)

In [None]:
; afplay inferred.wav