# Amortized inference for source-level variables

In [None]:
using Gen;
using Random;
using Statistics: mean, std;
using LinearAlgebra: dot;
using StatsFuns: logsumexp;
using PyPlot
include("./time_helpers.jl")
include("./extra_distributions.jl")
include("./gaussian_helpers.jl")

In [None]:
# base_params.jl evaluates to a collection that is destructured into four parameter bundles.
(source_params, steps, gtg_params, obs_noise) = include("./base_params.jl")
# Scene length; the first element's onset wait is drawn from uniform(0, scene_duration).
scene_duration = 2.0
println(steps)

### Wait proposal

In [None]:
@gen function wait_model()
    # Generative model restricted to element onset waits: samples the :wait
    # hyperparameters, then the number of elements, then one wait per element.
    # Returns the vector of sampled waits.

    # Hyperparameters of the inter-element wait, traced at :wait => <latent name>.
    wait_params = Dict(:wait => Dict())
    timing_priors = source_params["tp"]
    for kind in keys(wait_params)
        priors_for_kind = timing_priors[String(kind)]
        for name in keys(priors_for_kind)
            prior = priors_for_kind[name]
            address = kind => Symbol(name)
            wait_params[kind][Symbol(name)] = @trace(prior["dist"](prior["args"]...), address)
        end
        # Gamma arguments (shape a, scale mu / a).
        shape = wait_params[kind][:a]
        wait_params[kind][:args] = (shape, wait_params[kind][:mu] / shape)
    end

    # Element count: uniform up to a maximum, or geometric.
    count_cfg = source_params["n_elements"]
    if count_cfg["type"] == "max"
        n_elements = @trace(uniform_discrete(1, count_cfg["val"]), :n_elements)
    else
        n_elements = @trace(geometric(count_cfg["val"]), :n_elements)
    end

    # First wait is uniform over the scene; subsequent waits are gamma-distributed.
    waits = []
    for idx in 1:n_elements
        addr = (:element, idx) => :wait
        if idx == 1
            w = @trace(uniform(0, scene_duration), addr)
        else
            w = @trace(gamma(wait_params[:wait][:args]...), addr)
        end
        push!(waits, w)
    end

    return waits
end


function data_generator()
    # One training example for the wait proposal: the proposal's input is the vector
    # of sampled waits, and its targets are the :wait hyperparameters that produced them.
    sim = simulate(wait_model, ())
    observed = get_retval(sim)

    targets = choicemap()
    for latent in (:mu, :a)
        targets[:wait => latent] = sim[:wait => latent]
    end

    return ((observed,), targets)
end


@gen function custom_dest_proposal_trainable(waits)
    # Trainable proposal for the :wait hyperparameters given observed waits.
    # Features (mean, std, count, bias) feed four linear maps that give the log-normal
    # location and log-scale for :wait => :mu and :wait => :a.

    @param B_mu_mu::Vector{Float64}
    @param B_logsigma_mu::Vector{Float64}
    @param B_mu_alpha::Vector{Float64}
    @param B_logsigma_alpha::Vector{Float64}
    @param log_std_replace::Float64   # log of the std feature used for a single wait

    n_obs = length(waits)
    avg = mean(waits)
    spread = if n_obs > 1
        std(waits)
    else
        exp(log_std_replace)
    end
    features = [avg, spread, n_obs, 1]

    loc_mu = dot(features, B_mu_mu)
    logscale_mu = dot(features, B_logsigma_mu)
    loc_a = dot(features, B_mu_alpha)
    logscale_a = dot(features, B_logsigma_alpha)

    @trace(log_normal(loc_mu, exp(logscale_mu)), :wait => :mu)
    @trace(log_normal(loc_a, exp(logscale_a)), :wait => :a)

    return nothing
end

# Zero-initialise the four 4-feature regression weight vectors and the std fallback.
foreach((:B_mu_mu, :B_logsigma_mu, :B_mu_alpha, :B_logsigma_alpha)) do name
    Gen.init_param!(custom_dest_proposal_trainable, name, zeros(4))
end
Gen.init_param!(custom_dest_proposal_trainable, :log_std_replace, 0.0)

# Plain fixed-step SGD; Gen.ADAM(0.001, 0.9, 0.999, 1e-08) is an alternative.
update = Gen.ParamUpdate(Gen.FixedStepGradientDescent(1e-5), custom_dest_proposal_trainable);
scores = Gen.train!(custom_dest_proposal_trainable, data_generator, update;
    num_epoch=100, epoch_size=1000, num_minibatch=1,
    minibatch_size=1000, evaluation_size=100, verbose=true);

plot(scores)
xlabel("Iterations of stochastic gradient descent")
ylabel("Estimate of expected conditional log probability density");
title("Wait")

# Report the trained wait-proposal parameters.
println("Wait Q params:")
println("B_mu_mu: ", Gen.get_param(custom_dest_proposal_trainable, :B_mu_mu))
println("B_logsigma_mu: ", Gen.get_param(custom_dest_proposal_trainable, :B_logsigma_mu))
println("B_mu_alpha: ", Gen.get_param(custom_dest_proposal_trainable, :B_mu_alpha))
println("B_logsigma_alpha: ", Gen.get_param(custom_dest_proposal_trainable, :B_logsigma_alpha))
println("log_std_replace: ", Gen.get_param(custom_dest_proposal_trainable, :log_std_replace))

### Duration proposal

In [None]:
@gen function durminusmin_model()
    # Generative model restricted to element durations: samples the :dur_minus_min
    # hyperparameters, the number of elements, then one duration-minus-minimum per
    # element. Returns the vector of sampled values.

    # Hyperparameters of the duration distribution, traced at :dur_minus_min => <latent>.
    dur_latents = Dict(:dur_minus_min => Dict())
    timing_priors = source_params["tp"]
    for kind in keys(dur_latents)
        priors_for_kind = timing_priors[String(kind)]
        for name in keys(priors_for_kind)
            prior = priors_for_kind[name]
            dur_latents[kind][Symbol(name)] = @trace(prior["dist"](prior["args"]...), kind => Symbol(name))
        end
        # Gamma arguments (shape a, scale mu / a).
        shape = dur_latents[kind][:a]
        dur_latents[kind][:args] = (shape, dur_latents[kind][:mu] / shape)
    end

    # Element count: uniform up to a maximum, or geometric.
    count_cfg = source_params["n_elements"]
    if count_cfg["type"] == "max"
        n_elements = @trace(uniform_discrete(1, count_cfg["val"]), :n_elements)
    else
        n_elements = @trace(geometric(count_cfg["val"]), :n_elements)
    end

    # truncated_gamma is defined outside this block (presumably extra_distributions.jl);
    # it receives the gamma args plus source_params["duration_limit"].
    dur_minus_mins = []
    for idx in 1:n_elements
        value = @trace(truncated_gamma(dur_latents[:dur_minus_min][:args]..., source_params["duration_limit"]),
                       (:element, idx) => :dur_minus_min)
        push!(dur_minus_mins, value)
    end

    return dur_minus_mins
end


function data_generator()
    # One training example for the duration proposal: the input is the sampled
    # dur_minus_min values, and the targets are the :dur_minus_min hyperparameters.
    sim = simulate(durminusmin_model, ())
    observed = get_retval(sim)

    targets = choicemap()
    for latent in (:mu, :a)
        targets[:dur_minus_min => latent] = sim[:dur_minus_min => latent]
    end

    return ((observed,), targets)
end

@gen function trainable_durminusmin_proposal(dur_minus_mins)
    # Trainable proposal for the :dur_minus_min hyperparameters given observed values.
    # Features (mean, std, count, bias) feed four linear maps that give the log-normal
    # location and log-scale for :dur_minus_min => :mu and :dur_minus_min => :a.

    @param B_mu_mu::Vector{Float64}
    @param B_logsigma_mu::Vector{Float64}
    @param B_mu_alpha::Vector{Float64}
    @param B_logsigma_alpha::Vector{Float64}
    @param log_std_replace::Float64   # log of the std feature used for a single value

    n_obs = length(dur_minus_mins)
    avg = mean(dur_minus_mins)
    spread = if n_obs > 1
        std(dur_minus_mins)
    else
        exp(log_std_replace)
    end
    features = [avg, spread, n_obs, 1]

    loc_mu = dot(features, B_mu_mu)
    logscale_mu = dot(features, B_logsigma_mu)
    loc_a = dot(features, B_mu_alpha)
    logscale_a = dot(features, B_logsigma_alpha)

    @trace(log_normal(loc_mu, exp(logscale_mu)), :dur_minus_min => :mu)
    @trace(log_normal(loc_a, exp(logscale_a)), :dur_minus_min => :a)

    return nothing
end

# Zero-initialise the four 4-feature regression weight vectors and the std fallback.
for p in [:B_mu_mu, :B_logsigma_mu, :B_mu_alpha, :B_logsigma_alpha]
    Gen.init_param!(trainable_durminusmin_proposal, p, zeros(4))
end
Gen.init_param!(trainable_durminusmin_proposal, :log_std_replace, 0.0)
# Plain fixed-step SGD; Gen.ADAM(0.001, 0.9, 0.999, 1e-08) is an alternative.
update = Gen.ParamUpdate(Gen.FixedStepGradientDescent(0.00001), trainable_durminusmin_proposal);
scores = Gen.train!(trainable_durminusmin_proposal, data_generator, update,
    num_epoch=100, epoch_size=1000, num_minibatch=1,
    minibatch_size=1000, evaluation_size=100, verbose=false);

plot(scores)
xlabel("Iterations of stochastic gradient descent")
ylabel("Estimate of expected conditional log probability density");
title("Dur minus min")
# Report the parameters of the proposal trained in this cell (the duration proposal).
println("Dur_minus_min Q params:")
println("B_mu_mu: ", Gen.get_param(trainable_durminusmin_proposal, :B_mu_mu))
println("B_logsigma_mu: ", Gen.get_param(trainable_durminusmin_proposal, :B_logsigma_mu))
println("B_mu_alpha: ", Gen.get_param(trainable_durminusmin_proposal, :B_mu_alpha))
println("B_logsigma_alpha: ", Gen.get_param(trainable_durminusmin_proposal, :B_logsigma_alpha))
println("log_std_replace: ", Gen.get_param(trainable_durminusmin_proposal, :log_std_replace))

In [None]:
@gen function gp_model(source_type, gp_type)
    # Generative model for one source's element timing plus a single GP track of type
    # `gp_type` (the branches below handle :amp and :erb) sampled across the elements.
    # `source_type` selects the sampling layout ("tone" -> 1D; "noise"/"harmonic" with
    # :amp -> spectrotemporal 2D).
    # Returns (waits, dur_minus_mins, prev_gps). prev_gps accumulates the GP sample
    # locations (:x) and values (keyed by gp_type); in the 2D amplitude case it also
    # holds :t, :f and per-element reshaped matrices (:reshaped).
    
    ## GP hyperparameters
    # :erb uses gp_params["erb"]; :amp uses the 2D table for noise/harmonic sources and
    # the 1D table otherwise. Each latent is traced at gp_type => Symbol(latent).
    gp_params = source_params["gp"]
    gp_latents = Dict(gp_type => Dict())
    # The loop variable shadows the argument; gp_latents has exactly one key, gp_type.
    for gp_type in keys(gp_latents)
        hyperpriors = gp_type === :erb ? gp_params["erb"] : 
            ((source_type == "noise" || source_type == "harmonic") ? gp_params["amp"]["2D"] : gp_params["amp"]["1D"] )
        for latent in keys(hyperpriors)
            hyperprior = hyperpriors[latent]; syml = Symbol(latent)
            gp_latents[gp_type][syml] = @trace(hyperprior["dist"](hyperprior["args"]...), gp_type => syml)
        end
    end
    
    ## Element spacing: hyperparameters for waits and durations
    tp_latents = Dict(:wait=>Dict(),:dur_minus_min=>Dict()) 
    tp_params = source_params["tp"]
    for tp_type in keys(tp_latents)
        hyperpriors = tp_params[String(tp_type)]
        for latent in keys(hyperpriors)
            hyperprior = hyperpriors[latent]; syml = Symbol(latent)
            tp_latents[tp_type][syml] = @trace(hyperprior["dist"](hyperprior["args"]...), tp_type => syml)
        end
        # Gamma arguments (shape a, scale mu/a); with Gen's gamma(shape, scale) the mean is mu.
        tp_latents[tp_type][:args] = (tp_latents[tp_type][:a], tp_latents[tp_type][:mu]/tp_latents[tp_type][:a]) #a, b for gamma
    end
    
    # Element count: uniform up to a maximum, or geometric.
    ne_params = source_params["n_elements"]
    n_elements = ne_params["type"] == "max" ? 
        @trace(uniform_discrete(1, ne_params["val"]),:n_elements) : 
        @trace(geometric(ne_params["val"]),:n_elements)    
    
    waits = []
    dur_minus_mins = []
    prev_gps = Dict(:x => [], :t => [], gp_type => [], :reshaped=>[])
    time_so_far = 0.0;
    for element_idx = 1:n_elements
        # The first wait is uniform over the scene; later waits are gamma-distributed.
        wait = element_idx == 1 ? @trace(uniform(0, scene_duration), (:element,element_idx)=>:wait) : 
            @trace(gamma(tp_latents[:wait][:args]...), (:element,element_idx)=>:wait)
        # truncated_gamma is defined outside this block (presumably extra_distributions.jl)
        # and receives source_params["duration_limit"] as its last argument.
        dur_minus_min = @trace(truncated_gamma(tp_latents[:dur_minus_min][:args]..., source_params["duration_limit"]), (:element,element_idx)=>:dur_minus_min); 

        push!(waits, wait)
        push!(dur_minus_mins, dur_minus_min)
        # Each element starts `wait` after the previous element's offset (after 0 for the
        # first one) and lasts dur_minus_min + steps["min"].
        duration = dur_minus_min + steps["min"]; onset = time_so_far + wait; 
        time_so_far = onset + duration; element_timing = [onset, time_so_far]

        ## Define points at which the GPs should be sampled
        # 1D layout (erb, or amp for a tone): sample times from get_element_gp_times.
        # 2D layout (amp for noise/harmonic): spectrotemporal points; ts are appended to
        # prev_gps[:t] and the frequency axis is stored in prev_gps[:f].
        # NOTE(review): `audio_sr` is not defined in this block — presumably a global
        # from an included helper; confirm. If neither branch matches, `x` is undefined
        # and the next statement throws.
        if gp_type === :erb || (source_type == "tone" && gp_type === :amp)
            x = get_element_gp_times(element_timing, steps["t"])
        elseif gp_type === :amp && (source_type == "noise"  || source_type == "harmonic")
            x, ts, prev_gps[:f] = get_gp_spectrotemporal(element_timing, steps, audio_sr)
            append!(prev_gps[:t],ts)
        end

        # First element: GP mean/covariance at x from the hyperparameters alone; later
        # elements: presumably conditioned on all earlier samples (prev_gps[:x],
        # prev_gps[gp_type]) — see get_cond_mu_cov.
        mu, cov = element_idx == 1 ? get_mu_cov(x, gp_latents[gp_type]) : 
                get_cond_mu_cov(x, prev_gps[:x], prev_gps[gp_type], gp_latents[gp_type])
        element_gp = @trace(mvnormal(mu, cov), (:element, element_idx) => gp_type)
        # 2D amplitude: also keep this element as a length(ts) x length(f) matrix.
        if (source_type == "harmonic" || source_type == "noise") && (gp_type == :amp)
            reshaped_gp = transpose(reshape(element_gp, (length(prev_gps[:f]), length(ts))))
            push!(prev_gps[:reshaped],reshaped_gp)
        end
        
        # Accumulate locations and values so later elements can condition on them.
        append!(prev_gps[:x],x) #could also be push
        append!(prev_gps[gp_type],element_gp)
                
    end
    
    return waits, dur_minus_mins, prev_gps
    
end

source_type = "tone"; gp_type = :amp;
function data_generator()
    # One training example for a GP-hyperparameter proposal, driven by the globals
    # `source_type` and `gp_type`. It simulates gp_model for that configuration,
    # returns the GP samples (prev_gps) as proposal input, and returns the traced GP
    # hyperparameters as constraints.
    # Both the simulation and the constraint addresses use the same globals, so they
    # cannot drift apart.
    trace = simulate(gp_model, (source_type, gp_type))
    _, _, prev_gps = get_retval(trace)

    # Choose the hyperprior table with the same rule gp_model uses, so every
    # constrained address is one the model actually traced.
    gp_params = source_params["gp"]
    hyperpriors = if gp_type === :erb
        gp_params["erb"]
    elseif source_type == "noise" || source_type == "harmonic"
        gp_params["amp"]["2D"]
    else
        gp_params["amp"]["1D"]
    end

    constraints = choicemap()
    for k in keys(hyperpriors)
        constraints[gp_type => Symbol(k)] = trace[gp_type => Symbol(k)]
    end

    return ((prev_gps,), constraints)
end

data_generator()

In [None]:
@gen function trainable_amp1D_proposal(gps)
    # Trainable proposal for the 1D amplitude-GP hyperparameters (:amp => :mu, :sigma,
    # :scale, :noise) given GP samples. `gps` is the prev_gps Dict produced by gp_model;
    # this block reads gps[:x] (sample locations) and gps[:amp] (sampled values).

    ## Mean and marginal std of the GP: linear maps of (mean, std, bias)
    @param B_mu_mu::Vector{Float64}
    @param B_logsigma_mu::Vector{Float64}
    @param B_mu_sigma::Vector{Float64}
    @param B_logsigma_sigma::Vector{Float64}
    @param log_std_replace::Float64   # log of the std feature used for a single sample

    mu_d = mean(gps[:amp])
    sigma_d = length(gps[:amp]) > 1 ? std(gps[:amp]) : exp(log_std_replace)
    X = [mu_d, sigma_d, 1]

    mu_mu = dot(X, B_mu_mu)
    logsigma_mu = dot(X, B_logsigma_mu)
    mu_sigma = dot(X, B_mu_sigma)
    logsigma_sigma = dot(X, B_logsigma_sigma)

    gp_mu = @trace(normal(mu_mu, exp(logsigma_mu)), :amp => :mu)
    gp_sigma = @trace(log_normal(mu_sigma, exp(logsigma_sigma)), :amp => :sigma)

    ## Temporal lengthscale of the GP
    @param W_mu_scale::Vector{Float64}
    @param C_mu_scale::Float64
    @param replace_mu_scale::Float64        # used when there are fewer than 2 samples

    @param W_logsigma_scale::Vector{Float64}
    @param C_logsigma_scale::Float64
    @param replace_logsigma_scale::Float64  # used when there are fewer than 2 samples

    # Centres and log-widths of Gaussian bumps over time lags; each pair's lag is
    # softly assigned to these bumps.
    @param mu_times::Vector{Float64}
    @param logsigma_times::Vector{Float64}

    len_gp = length(gps[:x])
    if len_gp > 1
        # Subsample if the GP is too big, since the pair loop below is quadratic.
        # randsubseq keeps each index independently (in order) with the given probability.
        max_len = 50
        if len_gp > max_len
            idx = Random.randsubseq(1:len_gp, min(max_len * 1.5 / len_gp, 1.0))
            idx = idx[1:min(max_len, length(idx))]
        else
            idx = 1:len_gp
        end

        # Feature vector = sum over pairs of r * p. Here r = |dy| / (dt * gp_sigma) and
        # p is the softmax over bumps of the normal log-density of dt. This equals the
        # product rs (1 x n) * ps (n x K).
        # It is accumulated generically instead of being written into preallocated
        # Float64 arrays: in Gen's parameter-gradient pass, mu_times/logsigma_times are
        # tracked (ReverseDiff) numbers, and Float64 storage cannot carry them.
        n_bumps = length(mu_times)
        pair_terms = []
        for i = 1:length(idx), j = (i+1):length(idx)
            i_idx = idx[i]; j_idx = idx[j]
            delta_t = gps[:x][j_idx] - gps[:x][i_idx]
            delta_y = abs(gps[:amp][j_idx] - gps[:amp][i_idx])
            r = delta_y / (delta_t * gp_sigma)
            lp = [Gen.logpdf(normal, delta_t, mu_times[g], exp(logsigma_times[g])) for g = 1:n_bumps]
            push!(pair_terms, r .* exp.(lp .- logsumexp(lp)))
        end
        # No pairs (subsample kept < 2 points): zero features, leaving only the offsets.
        features = isempty(pair_terms) ? zeros(n_bumps) : sum(pair_terms)

        mu_scale = sum(features[g] * W_mu_scale[g] for g = 1:n_bumps) + C_mu_scale
        logsigma_scale = sum(features[g] * W_logsigma_scale[g] for g = 1:n_bumps) + C_logsigma_scale
    else
        mu_scale = replace_mu_scale
        logsigma_scale = replace_logsigma_scale
    end
    @trace(log_normal(mu_scale, exp(logsigma_scale)), :amp => :scale)

    ## Epsilon, the local noise parameter of the GP
    # Could also calculate the deviation of a middle point from the straight line
    # between its neighbours.
    @param B_mu_noise::Vector{Float64}
    @param replace_mu_noise::Float64
    @param B_logsigma_noise::Vector{Float64}
    @param replace_logsigma_noise::Float64

    # Std over windows of three contiguous samples centred at every third index.
    # The neighbours at i-1 and i+1 must lie one time step away. The comparison uses a
    # tolerance, because the sample times are computed in floating point.
    xs = gps[:x]
    dt = steps["t"]
    tol = 1e-6 * dt
    local_stds = []
    for i = 1:3:length(xs)
        if 1 < i < length(xs) &&
                abs(xs[i-1] - (xs[i] - dt)) <= tol && abs(xs[i+1] - (xs[i] + dt)) <= tol
            push!(local_stds, std(gps[:amp][i-1:i+1]))
        end
    end
    if length(local_stds) > 1
        X = [mean(local_stds), std(local_stds), 1]
        mu_noise = dot(X, B_mu_noise)
        logsigma_noise = dot(X, B_logsigma_noise)
    else
        mu_noise = replace_mu_noise
        logsigma_noise = replace_logsigma_noise
    end
    @trace(log_normal(mu_noise, exp(logsigma_noise)), :amp => :noise)

    return nothing
end


In [None]:
# Zero-initialise every 3-element regression weight vector.
for name in (:B_mu_mu, :B_logsigma_mu, :B_mu_sigma, :B_logsigma_sigma,
             :W_mu_scale, :W_logsigma_scale, :B_mu_noise, :B_logsigma_noise)
    Gen.init_param!(trainable_amp1D_proposal, name, zeros(3))
end
# Time-lag bumps: centres 0.01, 0.8 and 2.0, all with log-width -2.
Gen.init_param!(trainable_amp1D_proposal, :mu_times, [0.01, 0.8, 2.0])
Gen.init_param!(trainable_amp1D_proposal, :logsigma_times, fill(-2.0, 3))
# Scalar offsets and fallback values start at zero.
for name in (:log_std_replace, :C_mu_scale, :replace_mu_scale, :C_logsigma_scale,
             :replace_logsigma_scale, :replace_mu_noise, :replace_logsigma_noise)
    Gen.init_param!(trainable_amp1D_proposal, name, 0.0)
end

update = Gen.ParamUpdate(Gen.FixedStepGradientDescent(1e-8), trainable_amp1D_proposal);
scores = Gen.train!(trainable_amp1D_proposal, data_generator, update;
    num_epoch=100, epoch_size=50, num_minibatch=1,
    minibatch_size=50, evaluation_size=10, verbose=false);

plot(scores)
xlabel("Iterations of stochastic gradient descent")
ylabel("Estimate of expected conditional log probability density");
title("GP")

In [None]:
# Print each trained parameter of the 1D amplitude proposal as "<name>: <value>".
foreach([:B_mu_mu, :B_logsigma_mu, :B_mu_sigma, :B_logsigma_sigma, :W_mu_scale,
         :W_logsigma_scale, :mu_times, :logsigma_times, :B_mu_noise, :B_logsigma_noise,
         :log_std_replace, :C_mu_scale, :replace_mu_scale, :C_logsigma_scale,
         :replace_logsigma_scale, :replace_mu_noise, :replace_logsigma_noise]) do name
    println(string(name), ": ", Gen.get_param(trainable_amp1D_proposal, name))
end


In [None]:
source_type = "tone"; gp_type = :erb;
function data_generator()
    # One training example for a GP-hyperparameter proposal, driven by the globals
    # `source_type` and `gp_type`. It simulates gp_model for that configuration,
    # returns the GP samples (prev_gps) as proposal input, and returns the traced GP
    # hyperparameters as constraints.
    # Both the simulation and the constraint addresses use the same globals, so they
    # cannot drift apart.
    trace = simulate(gp_model, (source_type, gp_type))
    _, _, prev_gps = get_retval(trace)

    # Choose the hyperprior table with the same rule gp_model uses, so every
    # constrained address is one the model actually traced.
    gp_params = source_params["gp"]
    hyperpriors = if gp_type === :erb
        gp_params["erb"]
    elseif source_type == "noise" || source_type == "harmonic"
        gp_params["amp"]["2D"]
    else
        gp_params["amp"]["1D"]
    end

    constraints = choicemap()
    for k in keys(hyperpriors)
        constraints[gp_type => Symbol(k)] = trace[gp_type => Symbol(k)]
    end

    return ((prev_gps,), constraints)
end

In [None]:
@gen function trainable_erb_proposal(gps)
    # Trainable proposal for the ERB (frequency) GP hyperparameters (:erb => :mu, :sigma,
    # :scale, :noise) given GP samples. `gps` is the prev_gps Dict produced by gp_model;
    # this block reads gps[:x] (sample locations) and gps[:erb] (sampled values).

    ## Mean and marginal std of the GP: linear maps of (mean, std, bias)
    @param B_mu_mu::Vector{Float64}
    @param B_logsigma_mu::Vector{Float64}
    @param B_mu_sigma::Vector{Float64}
    @param B_logsigma_sigma::Vector{Float64}
    @param log_std_replace::Float64   # log of the std feature used for a single sample

    mu_d = mean(gps[:erb])
    sigma_d = length(gps[:erb]) > 1 ? std(gps[:erb]) : exp(log_std_replace)
    X = [mu_d, sigma_d, 1]

    mu_mu = dot(X, B_mu_mu)
    logsigma_mu = dot(X, B_logsigma_mu)
    mu_sigma = dot(X, B_mu_sigma)
    logsigma_sigma = dot(X, B_logsigma_sigma)

    gp_mu = @trace(normal(mu_mu, exp(logsigma_mu)), :erb => :mu)
    gp_sigma = @trace(log_normal(mu_sigma, exp(logsigma_sigma)), :erb => :sigma)

    ## Temporal lengthscale of the GP
    @param W_mu_scale::Vector{Float64}
    @param C_mu_scale::Float64
    @param replace_mu_scale::Float64        # used when there are fewer than 2 samples

    @param W_logsigma_scale::Vector{Float64}
    @param C_logsigma_scale::Float64
    @param replace_logsigma_scale::Float64  # used when there are fewer than 2 samples

    # Centres and log-widths of Gaussian bumps over time lags; each pair's lag is
    # softly assigned to these bumps.
    @param mu_times::Vector{Float64}
    @param logsigma_times::Vector{Float64}

    len_gp = length(gps[:x])
    if len_gp > 1
        # Subsample if the GP is too big, since the pair loop below is quadratic.
        # randsubseq keeps each index independently (in order) with the given probability.
        max_len = 50
        if len_gp > max_len
            idx = Random.randsubseq(1:len_gp, min(max_len * 1.2 / len_gp, 1.0))
            idx = idx[1:min(max_len, length(idx))]
        else
            idx = 1:len_gp
        end

        # Feature vector = sum over pairs of r * p. Here r = |dy| / (dt * gp_sigma) and
        # p is the softmax over bumps of the normal log-density of dt. This equals the
        # product rs (1 x n) * ps (n x K).
        # It is accumulated generically instead of being written into preallocated
        # Float64 arrays: in Gen's parameter-gradient pass, mu_times/logsigma_times are
        # tracked (ReverseDiff) numbers, and Float64 storage cannot carry them.
        n_bumps = length(mu_times)
        pair_terms = []
        for i = 1:length(idx), j = (i+1):length(idx)
            i_idx = idx[i]; j_idx = idx[j]
            delta_t = gps[:x][j_idx] - gps[:x][i_idx]
            delta_y = abs(gps[:erb][j_idx] - gps[:erb][i_idx])
            r = delta_y / (delta_t * gp_sigma)
            lp = [Gen.logpdf(normal, delta_t, mu_times[g], exp(logsigma_times[g])) for g = 1:n_bumps]
            push!(pair_terms, r .* exp.(lp .- logsumexp(lp)))
        end
        # No pairs (subsample kept < 2 points): zero features, leaving only the offsets.
        features = isempty(pair_terms) ? zeros(n_bumps) : sum(pair_terms)

        mu_scale = sum(features[g] * W_mu_scale[g] for g = 1:n_bumps) + C_mu_scale
        logsigma_scale = sum(features[g] * W_logsigma_scale[g] for g = 1:n_bumps) + C_logsigma_scale
    else
        mu_scale = replace_mu_scale
        logsigma_scale = replace_logsigma_scale
    end
    @trace(log_normal(mu_scale, exp(logsigma_scale)), :erb => :scale)

    ## Epsilon, the local noise parameter of the GP
    # Could also calculate the deviation of a middle point from the straight line
    # between its neighbours.
    @param B_mu_noise::Vector{Float64}
    @param replace_mu_noise::Float64
    @param B_logsigma_noise::Vector{Float64}
    @param replace_logsigma_noise::Float64

    # Std over windows of three contiguous samples centred at every third index.
    # The neighbours at i-1 and i+1 must lie one time step away. The comparison uses a
    # tolerance, because the sample times are computed in floating point.
    xs = gps[:x]
    dt = steps["t"]
    tol = 1e-6 * dt
    local_stds = []
    for i = 1:3:length(xs)
        if 1 < i < length(xs) &&
                abs(xs[i-1] - (xs[i] - dt)) <= tol && abs(xs[i+1] - (xs[i] + dt)) <= tol
            push!(local_stds, std(gps[:erb][i-1:i+1]))
        end
    end
    if length(local_stds) > 1
        X = [mean(local_stds), std(local_stds), 1]
        mu_noise = dot(X, B_mu_noise)
        logsigma_noise = dot(X, B_logsigma_noise)
    else
        mu_noise = replace_mu_noise
        logsigma_noise = replace_logsigma_noise
    end
    @trace(log_normal(mu_noise, exp(logsigma_noise)), :erb => :noise)

    return nothing
end

# Zero-initialise every 3-element regression weight vector.
for name in (:B_mu_mu, :B_logsigma_mu, :B_mu_sigma, :B_logsigma_sigma,
             :W_mu_scale, :W_logsigma_scale, :B_mu_noise, :B_logsigma_noise)
    Gen.init_param!(trainable_erb_proposal, name, zeros(3))
end
# Time-lag bumps: centres 0.01, 0.8 and 2.0, all with log-width -2.
Gen.init_param!(trainable_erb_proposal, :mu_times, [0.01, 0.8, 2.0])
Gen.init_param!(trainable_erb_proposal, :logsigma_times, fill(-2.0, 3))
# Scalar offsets and fallback values start at zero.
for name in (:log_std_replace, :C_mu_scale, :replace_mu_scale, :C_logsigma_scale,
             :replace_logsigma_scale, :replace_mu_noise, :replace_logsigma_noise)
    Gen.init_param!(trainable_erb_proposal, name, 0.0)
end

update = Gen.ParamUpdate(Gen.FixedStepGradientDescent(1e-8), trainable_erb_proposal);
scores = Gen.train!(trainable_erb_proposal, data_generator, update;
    num_epoch=100, epoch_size=50, num_minibatch=1,
    minibatch_size=50, evaluation_size=10, verbose=false);

plot(scores)
xlabel("Iterations of stochastic gradient descent")
ylabel("Estimate of expected conditional log probability density");
title("erb GP")

In [None]:
@gen function trainable_amp2D_proposal(gps)
    # Trainable proposal for the 2D (spectrotemporal) amplitude-GP hyperparameters
    # (:amp => :mu, :sigma, :scale_t, :noise) given GP samples. `gps` is the prev_gps
    # Dict produced by gp_model; this block reads gps[:amp], gps[:t] and gps[:x].

    ## Mean and marginal std of the GP: linear maps of (mean, std, bias)
    @param B_mu_mu::Vector{Float64}
    @param B_logsigma_mu::Vector{Float64}
    @param B_mu_sigma::Vector{Float64}
    @param B_logsigma_sigma::Vector{Float64}
    @param log_std_replace::Float64   # log of the std feature used for a single sample

    mu_d = mean(gps[:amp])
    sigma_d = length(gps[:amp]) > 1 ? std(gps[:amp]) : exp(log_std_replace)
    X = [mu_d, sigma_d, 1]

    mu_mu = dot(X, B_mu_mu)
    logsigma_mu = dot(X, B_logsigma_mu)
    mu_sigma = dot(X, B_mu_sigma)
    logsigma_sigma = dot(X, B_logsigma_sigma)

    gp_mu = @trace(normal(mu_mu, exp(logsigma_mu)), :amp => :mu)
    gp_sigma = @trace(log_normal(mu_sigma, exp(logsigma_sigma)), :amp => :sigma)

    ## Temporal lengthscale of the GP
    @param W_mu_scale_t::Vector{Float64}
    @param C_mu_scale_t::Float64
    @param replace_mu_scale_t::Float64        # used when there are fewer than 2 time points

    @param W_logsigma_scale_t::Vector{Float64}
    @param C_logsigma_scale_t::Float64
    @param replace_logsigma_scale_t::Float64  # used when there are fewer than 2 time points

    # Centres and log-widths of Gaussian bumps over time lags.
    @param mu_times_t::Vector{Float64}
    @param logsigma_times_t::Vector{Float64}

    len_gp = length(gps[:t])
    if len_gp > 1
        # Subsample if the GP is too big, since the pair loop below is quadratic.
        max_len = 50
        if len_gp > max_len
            idx = Random.randsubseq(1:len_gp, min(max_len * 1.5 / len_gp, 1.0))
            idx = idx[1:min(max_len, length(idx))]
        else
            idx = 1:len_gp
        end

        # NOTE(review): idx ranges over gps[:t], but the lags and differences below
        # index gps[:x] and gps[:amp]. In the 2D case gps[:amp] holds one value per
        # (time, frequency) point, so these entries are probably not the intended
        # "two spectra" being compared — confirm against get_gp_spectrotemporal.
        #
        # Feature vector = sum over pairs of r * p, accumulated generically (not in
        # preallocated Float64 arrays), so tracked mu_times_t/logsigma_times_t values
        # can carry gradients during Gen's parameter-gradient pass.
        n_bumps = length(mu_times_t)
        pair_terms = []
        for i = 1:length(idx), j = (i+1):length(idx)
            i_idx = idx[i]; j_idx = idx[j]
            delta_t = gps[:x][j_idx] - gps[:x][i_idx]
            delta_y = abs(gps[:amp][j_idx] - gps[:amp][i_idx])
            r = delta_y / (delta_t * gp_sigma)
            lp = [Gen.logpdf(normal, delta_t, mu_times_t[g], exp(logsigma_times_t[g])) for g = 1:n_bumps]
            push!(pair_terms, r .* exp.(lp .- logsumexp(lp)))
        end
        features = isempty(pair_terms) ? zeros(n_bumps) : sum(pair_terms)

        mu_scale = sum(features[g] * W_mu_scale_t[g] for g = 1:n_bumps) + C_mu_scale_t
        logsigma_scale = sum(features[g] * W_logsigma_scale_t[g] for g = 1:n_bumps) + C_logsigma_scale_t
    else
        # Fallbacks use the `_t` parameters declared above.
        mu_scale = replace_mu_scale_t
        logsigma_scale = replace_logsigma_scale_t
    end
    @trace(log_normal(mu_scale, exp(logsigma_scale)), :amp => :scale_t)

    ## Epsilon, the local noise parameter of the GP
    # Could also calculate the deviation of a middle point from the straight line
    # between its neighbours.
    @param B_mu_noise::Vector{Float64}
    @param replace_mu_noise::Float64
    @param B_logsigma_noise::Vector{Float64}
    @param replace_logsigma_noise::Float64

    # Std over windows of three contiguous samples centred at every third index.
    # Neighbours must lie one time step away (tolerance for floating-point times).
    xs = gps[:x]
    dt = steps["t"]
    tol = 1e-6 * dt
    local_stds = []
    for i = 1:3:length(xs)
        if 1 < i < length(xs) &&
                abs(xs[i-1] - (xs[i] - dt)) <= tol && abs(xs[i+1] - (xs[i] + dt)) <= tol
            push!(local_stds, std(gps[:amp][i-1:i+1]))
        end
    end
    if length(local_stds) > 1
        X = [mean(local_stds), std(local_stds), 1]
        mu_noise = dot(X, B_mu_noise)
        logsigma_noise = dot(X, B_logsigma_noise)
    else
        mu_noise = replace_mu_noise
        logsigma_noise = replace_logsigma_noise
    end
    @trace(log_normal(mu_noise, exp(logsigma_noise)), :amp => :noise)

    return nothing
end


In [None]:
sort!([3, 2, 1])  # scratch: ascending sort of a literal vector

In [None]:
x = Dict()
# Scratch: stack ten 3x4 zero blocks side by side (same as repeated cat(...; dims=2)).
x[:rs] = reduce(hcat, [zeros(3, 4) for _ in 1:10])
println(size(x[:rs]))

In [None]:
rand(3:10, (3,))  # scratch: three random integers from 3 to 10

In [None]:
using Statistics: mean, std, cov;


In [None]:
cov(rand(Float64, 4, 5); dims=1)  # scratch: 5x5 covariance of the columns

In [None]:
prod(size(rand(4, 5)))  # scratch: element count of a 4x5 matrix

In [None]:
using Statistics: I

In [None]:
1.0 * I  # scratch: Float64 UniformScaling identity