In [None]:
using PyCall #Put this first! 
using GenTF
tf = pyimport("tensorflow")
tf.compat.v1.disable_eager_execution()
# add https://github.com/mcusi/GenTF
# using Pkg; ENV["PYTHON"] = "/Users/maddiecusimano/Documents/basa-gen/env/bin/python"; Pkg.build("PyCall")
# julia> using PyCall; println(PyCall.python)
# /Users/maddiecusimano/.julia/conda/3/bin/python


In [None]:
using Gen;
using Random;
using Statistics: mean, std, cor;
using LinearAlgebra: dot;
using StatsFuns: logsumexp, softplus;
using PyPlot
using SpecialFunctions: digamma,trigamma;
include("./time_helpers.jl")
include("./extra_distributions.jl")
include("./gaussian_helpers.jl")

In [None]:
# Load the model configuration: base_params.jl evaluates to (at least) four values,
# unpacked here in order.
source_params, steps, gtg_params, obs_noise = include("./base_params.jl")
# Audio sampling rate — presumably in Hz; confirm against the helpers that consume it.
audio_sr = 20000;
# Scene length used as the upper bound for element timing — presumably seconds; confirm.
scene_duration = 2.0

In [None]:
# Factory for the source-level generative model.
#
# Returns a Gen generative function `source_latent_model(latents)` that closes over
# `source_params`, `audio_sr`, `steps` and `scene_duration`. The generative function
# returns the tuple `(tp_latents, gp_latents, tp_elems, gp_elems)`.
function make_source_latent_model(source_params, audio_sr, steps, scene_duration)
    
    @gen function source_latent_model(latents)

        ##Single function for generating data for amortized inference to propose source-level latents
        # Wait: can generate on its own
        # Dur_minus_min: can generate on its own
        # GPs:
        # Can have a separate amortized inference move for ERB, Amp-1D, Amp-2D 
        # Need to sample wait and dur_minus_min to define the time points for the GP 
        #
        # Format of latents: 
        # latents = Dict(:gp => :amp OR :tp => :wait, :source_type => "tone")

        ### SOURCE-LEVEL LATENTS 
        ## Sample GP source-level latents if needed 
        gp_latents = Dict()
        if :gp in keys(latents)

            gp_params = source_params["gp"]
            gp_type = latents[:gp]; gp_latents[gp_type] = Dict();
            source_type = latents[:source_type]
            
            # Hyperprior table: ERB has its own; amplitude uses the 2D table for
            # noise/harmonic sources and the 1D table otherwise.
            hyperpriors = gp_type == :erb ? gp_params["erb"] : 
                ((source_type == "noise" || source_type == "harmonic") ? gp_params["amp"]["2D"] : gp_params["amp"]["1D"] )
    
            # Each hyperprior entry supplies a distribution ("dist") and its arguments ("args");
            # the sample is traced at address gp_type => Symbol(latent).
            for latent in keys(hyperpriors)
                hyperprior = hyperpriors[latent]; syml = Symbol(latent)
                gp_latents[gp_type][syml] = @trace(hyperprior["dist"](hyperprior["args"]...), gp_type => syml)
            end
            
        end

        ## Sample temporal source-level latents 
        # A :tp request samples only the requested temporal latent; a :gp request needs
        # both :wait and :dur_minus_min to place the GP in time.
        tp_latents = Dict()
        if :tp in keys(latents)
            tp_latents[latents[:tp]] = Dict()
        elseif :gp in keys(latents)
            tp_latents[:wait] = Dict()
            tp_latents[:dur_minus_min] = Dict()
        end            
        tp_params = source_params["tp"]
        for tp_type in keys(tp_latents)
            hyperpriors = tp_params[String(tp_type)]
            for latent in keys(hyperpriors)
                hyperprior = hyperpriors[latent]; syml = Symbol(latent)
                tp_latents[tp_type][syml] = @trace(hyperprior["dist"](hyperprior["args"]...), tp_type => syml)
            end
            # Requires the hyperprior table to define both "a" and "mu".
            tp_latents[tp_type][:args] = (tp_latents[tp_type][:a], tp_latents[tp_type][:mu]/tp_latents[tp_type][:a]) #a, b for gamma
        end

        ## Sample a number of elements 
        # "max" → uniform over 1:val; any other type → geometric(val).
        ne_params = source_params["n_elements"]
        n_elements = ne_params["type"] == "max" ? 
            @trace(uniform_discrete(1, ne_params["val"]),:n_elements) : 
            @trace(geometric(ne_params["val"]),:n_elements)    

        ### ELEMENT-LEVEL LATENTS 
        #Storage for what inputs are needed
        tp_elems = Dict( [ k => [] for k in keys(tp_latents)]... )
        gp_elems = Dict(); x_elems = [];
        if :gp in keys(latents)       
            
            gp_type = latents[:gp]; source_type = latents[:source_type]
            
            gp_elems[gp_type] = []
            gp_type = latents[:gp]
            gp_elems[:t] = []
            # Extra storage for the spectrotemporal (2D) amplitude case.
            if gp_type == :amp && (source_type == "noise" || source_type == "harmonic")
                gp_elems[:reshaped] = []
                gp_elems[:f] = []
                gp_elems[:tf] = []
            end 
            
        end
                
        time_so_far = 0.0;
        for element_idx = 1:n_elements
            
            if :wait in keys(tp_elems)
                # The first wait is uniform over the scene; later waits follow the gamma
                # defined by the sampled source-level latents.
                wait = element_idx == 1 ? @trace(uniform(0, scene_duration), (:element,element_idx)=>:wait) : 
                    @trace(gamma(tp_latents[:wait][:args]...), (:element,element_idx)=>:wait)
                push!(tp_elems[:wait], wait)
            end
            
            if :dur_minus_min in keys(tp_elems)
                dur_minus_min = @trace(truncated_gamma(tp_latents[:dur_minus_min][:args]..., source_params["duration_limit"]), (:element,element_idx)=>:dur_minus_min); 
                push!(tp_elems[:dur_minus_min], dur_minus_min)
            end
            
            if :gp in keys(latents)
                
                gp_type = latents[:gp]; source_type = latents[:source_type]
                duration = dur_minus_min + steps["min"]; onset = time_so_far + wait; 
                time_so_far = onset + duration; element_timing = [onset, time_so_far]
                
                # Stop before sampling a GP for an element that starts after the scene ends.
                if onset > scene_duration
                    break
                end
                
                ## Define points at which the GPs should be sampled
                x = []; ts = [];
                if gp_type === :erb || (source_type == "tone" && gp_type === :amp)
                    x = get_element_gp_times(element_timing, steps["t"])
                elseif gp_type === :amp && (source_type == "noise"  || source_type == "harmonic")
                    x, ts, gp_elems[:f] = get_gp_spectrotemporal(element_timing, steps, audio_sr)
                end

                # The first element is drawn from the unconditional GP; later elements
                # are conditioned on all previously sampled points and values.
                mu, cov = element_idx == 1 ? get_mu_cov(x, gp_latents[gp_type]) : 
                        get_cond_mu_cov(x, x_elems, gp_elems[gp_type], gp_latents[gp_type])
                element_gp = @trace(mvnormal(mu, cov), (:element, element_idx) => gp_type)
                
                ## Save the element data 
                append!(x_elems, x)
                if gp_type === :erb || (source_type == "tone" && gp_type === :amp)
                    append!(gp_elems[:t], x)
                    append!(gp_elems[gp_type], element_gp)
                elseif gp_type === :amp && (source_type == "noise"  || source_type == "harmonic")
                    append!(gp_elems[:t],ts)
                    append!(gp_elems[:tf],x) 
                    append!(gp_elems[gp_type], element_gp)
                    # View the flat sample as (n frequencies, n times) and grow the
                    # running matrix along the time axis.
                    reshaped_elem = reshape(element_gp, (length(gp_elems[:f]), length(ts))) 
                    if element_idx == 1
                        gp_elems[:reshaped] = reshaped_elem
                    else
                        gp_elems[:reshaped] = cat(gp_elems[:reshaped], reshaped_elem, dims=2)
                    end
                end
                
                # Stop once the current element has run past the end of the scene.
                if time_so_far > scene_duration
                    break
                end
           
            end

        end

        return tp_latents, gp_latents, tp_elems, gp_elems

    end
 
    return source_latent_model
    
end
source_latent_model = make_source_latent_model(source_params, audio_sr, steps, scene_duration);

In [None]:
# Wrap `source_latent_model` in a zero-argument sampler for amortized-inference training.
#
# Arguments
# - `source_latent_model`: the Gen generative function to simulate.
# - `latents`: Dict selecting what to sample, e.g. `Dict(:tp => :wait)` or
#   `Dict(:gp => :amp, :source_type => "tone")`.
#
# Returns a function `data_generator()` that simulates one trace and returns
# `(inputs, constraints)`:
# - `inputs` is `(tp_elems,)` for a :tp request and `(gp_elems,)` otherwise;
# - `constraints` is a choicemap holding the sampled source-level latents.
#
# NOTE: `source_params` is read from the enclosing (global) scope, not passed in.
function make_data_generator(source_latent_model, latents)

    function data_generator()

        trace = simulate(source_latent_model, (latents,))
        tp_latents, gp_latents, tp_elems, gp_elems = get_retval(trace)

        constraints = choicemap()
        if :tp in keys(latents)

            tp_latent = latents[:tp]
            constraints[tp_latent => :mu] = trace[tp_latent => :mu]
            constraints[tp_latent => :a] = trace[tp_latent => :a]

        elseif :gp in keys(latents)

            gp_type = latents[:gp]
            source_type = latents[:source_type]
            gp_params = source_params["gp"]
            # Select the hyperprior table with exactly the same rule the model uses
            # (noise/harmonic -> 2D, anything else -> 1D), so the addresses read here
            # are always the ones the model actually traced.
            hyperpriors = gp_type == :erb ? gp_params["erb"] :
                ((source_type == "noise" || source_type == "harmonic") ?
                    gp_params["amp"]["2D"] : gp_params["amp"]["1D"])
            for k in keys(hyperpriors)
                constraints[gp_type => Symbol(k)] = trace[gp_type => Symbol(k)]
            end

        end
        inputs = :tp in keys(latents) ? (tp_elems,) : (gp_elems,)

        return (inputs, constraints)

    end

    return data_generator

end;

In [None]:
# Sanity check of the training loop on a toy problem: learn a gamma proposal for
# :sigma when the observation carries no information about it.

# Toy model: :sigma ~ gamma(3, 1); the return value is an independent uniform draw.
@gen function gen_gamma_model()
    @trace(gamma(3,1),:sigma)
    return rand()
end

# Simulate the toy model once and package it as (proposal inputs, constraints).
function gamma_generator()
    toy_trace = simulate(gen_gamma_model, ())
    observation = get_retval(toy_trace)
    target = choicemap((:sigma, toy_trace[:sigma]))
    return ((observation,), target)
end

# Maps the unconstrained parameters onto valid (positive) gamma arguments.
nonlinearity=exp;

# Proposal: :sigma ~ gamma(exp(alpha), exp(beta)); the input `r` is ignored.
@gen function gen_gamma_proposal(r)
    @param alpha::Float64
    @param beta::Float64
    a = nonlinearity(alpha)
    b = nonlinearity(beta)
    @trace(gamma(a,b),:sigma)
end

for p in (:alpha, :beta)
    Gen.init_param!(gen_gamma_proposal, p, 0.0)
end
#GradientDescent(1e-3,10)
update = Gen.ParamUpdate(Gen.FixedStepGradientDescent(1e-4), gen_gamma_proposal);
scores = pain!(gen_gamma_proposal, gamma_generator, update,
    num_epoch=5000, epoch_size=20, num_minibatch=1,
    minibatch_size=20, evaluation_size=100, verbose=false);
plot(scores)
for _ in 1:3
    println("~~~~~~~~~")
end
# Report the learned gamma arguments on their natural (positive) scale.
for p in [:alpha, :beta]
    println("$(string(p)): ", nonlinearity(Gen.get_param(gen_gamma_proposal, p)))
end

In [None]:
# Latent requests understood by source_latent_model, keyed by the name of the
# proposal they are used to train.
latents = Dict{Any,Any}(
    :wait => Dict(:tp => :wait),
    :dur_minus_min => Dict(:tp => :dur_minus_min),
    :amp1D => Dict(:gp => :amp, :source_type => "tone"),
    :amp2D => Dict(:gp => :amp, :source_type => "noise"),
    :erb => Dict(:gp => :erb, :source_type => "tone"),
)

# One training-data sampler per latent request.
data_generators = Dict{Any,Any}(
    l => make_data_generator(source_latent_model, latents[l])
    for l in [:wait, :dur_minus_min, :amp1D, :amp2D, :erb]
);

In [None]:
# Time grid spanning the whole scene, at the GP time step steps["t"].
scene_t = get_element_gp_times([0,scene_duration], steps["t"]) 
# Accumulators filled by the cells below: network inputs, masks, and the mean /
# standard deviation used to normalise each input.
net_inputs = []; masks = []; ms = []; ss = [];

In [None]:
# Draw one 2-D amplitude GP sample and embed it on the full scene time grid.
gp_elem = first(first(data_generators[:amp2D]()))
# 1.0 where a scene time point is covered by an element, 0.0 elsewhere.
mask = map(t -> t in gp_elem[:t] ? 1.0 : 0.0, scene_t)
m = mean(gp_elem[:reshaped], dims=2);
s = std(gp_elem[:reshaped]) #What about STD only in time direction--> (nf,) vector

nf = size(gp_elem[:reshaped], 1)
# Walk the scene grid: copy the next GP column where the mask is on, otherwise
# fill with the per-frequency mean.
embedded_gp = []; k = 1;
for covered in mask
    if covered == 1
        push!(embedded_gp, gp_elem[:reshaped][:,k])
        k += 1
    else
        push!(embedded_gp, m)
    end
end

In [None]:
# Assemble the embedded columns into a matrix (one column per scene time point),
# then centre by the per-frequency mean and scale by the global standard deviation.
net_input = hcat(embedded_gp...)
net_input = (net_input .- m) ./ s
# Store with a leading singleton batch axis.
push!(net_inputs, reshape(net_input, 1, size(net_input)...))

mask = vcat(mask...)
push!(masks, reshape(mask, 1, :))

# Keep the normalisation statistics alongside the input.
push!(ms, m); push!(ss, s)

In [None]:
# Display the accumulated network inputs.
net_inputs

In [None]:
# Stack the stored inputs along axis 1 (the batch axis).
# Assumes each entry of net_inputs is 1 x nf x nt, as pushed by the preceding cell — confirm.
batch_net_input = cat(net_inputs..., dims=1)
szn = size(batch_net_input)
# Append a trailing singleton (channel) axis.
batch_net_input = reshape(batch_net_input, szn..., 1)
batch_mask_input = cat(masks..., dims=1)
szm = size(batch_mask_input) 
# Insert singleton axes at positions 2 and 4, then tile the mask szn[2] times along axis 2
# so it matches the value array's second dimension.
batch_mask_input = repeat(reshape(batch_mask_input, szm[1], 1, szm[2], 1), 1, szn[2], 1, 1)
# Values and mask become two channels along axis 4.
batch_net_input = cat(batch_net_input, batch_mask_input, dims=4)
# Swap axes 2 and 3 of both arrays.
batch_net_input = permutedims(batch_net_input, [1, 3, 2, 4])
batch_mask_input = permutedims(batch_mask_input, [1, 3, 2, 4])

In [None]:
# Scratch check: per-row mean of the embedded matrix, tiled into 4 columns.
repeat(mean(cat(embedded_gp..., dims=2),dims=2),1,4)

In [None]:
# Scratch check: stack the stored means side by side, then transpose so each stored mean is a row.
permutedims(cat(ms..., dims=2),[2,1])


In [None]:
# Factory for a trainable proposal over the temporal hyperparameters (:mu, :a) of
# `latent` (:wait or :dur_minus_min).
#
# The returned generative function takes `tp_elems` (Dict mapping `latent` to the
# sampled element values), traces `latent => :mu` and `latent => :a`, and returns nothing.
function create_trainable_tp_proposal(latent)

    @gen function trainable_tp_proposal(tp_elems)

        @param B_mean_mu::Vector{Float64}
        @param B_shape_mu::Float64
        @param C_shape_mu::Float64
        @param B_mean_alpha::Vector{Float64}
        @param B_shape_a::Float64
        @param C_shape_a::Float64
        @param MLE_estimate_alpha_0::Float64
        @param MOM_estimate_alpha_0::Float64

        n_elements = length(tp_elems[latent])

        # With nothing to estimate from (a :wait request with at most one element),
        # propose from fixed distributions and stop.
        informative = latent == :dur_minus_min || (latent == :wait && n_elements > 1)
        if !informative
            @trace(log_normal(-2.0,1.5), latent => :mu)
            @trace(gamma(20,1), latent => :a)
            return nothing
        end

        # The first wait is excluded from the statistics for :wait requests.
        W = latent == :wait ? tp_elems[latent][2:end] : tp_elems[latent]
        sample_mean = mean(W)
        enough_samples = length(W) > 2

        # Gamma shape, maximum-likelihood estimate: closed-form initial guess refined
        # by five Newton-Raphson steps; a learned constant replaces it for small samples.
        if enough_samples
            s = log(sample_mean) - mean(log.(W))
            alpha_mle = (3.0 - s + sqrt((s-3)^2 + 24*s))/(12.0 * s)
            for _ in 1:5
                residual = log(alpha_mle) - digamma(alpha_mle) - s
                slope = (1.0/alpha_mle) - trigamma(alpha_mle)
                alpha_mle = alpha_mle - (residual/slope)
            end
        else
            alpha_mle = exp(MLE_estimate_alpha_0)
        end

        # Gamma shape, method-of-moments estimate, with its own learned fallback.
        alpha_mom = if enough_samples
            (sample_mean^2)/(std(W)^2)
        else
            exp(MOM_estimate_alpha_0)
        end

        # Log-linear regression from the estimates to the mean of each proposal.
        X_mu = [log(sample_mean), 1.0]
        X_alpha = [log(alpha_mle), log(alpha_mom), 1.0]

        mean_mu_estimate = exp( dot(X_mu, B_mean_mu) )
        mean_alpha_estimate = exp( dot(X_alpha, B_mean_alpha) )
        # The gamma shape of each proposal depends on the number of elements.
        shape_mu = exp( B_shape_mu*n_elements + C_shape_mu)
        shape_a = exp( B_shape_a*n_elements + C_shape_a )

        # gamma(shape, mean/shape) has the requested mean.
        @trace(gamma(shape_mu, mean_mu_estimate/shape_mu), latent => :mu)
        @trace(gamma(shape_a, mean_alpha_estimate/shape_a), latent => :a)

        return nothing

    end

    return trainable_tp_proposal

end

In [None]:
# Registry of trainable proposals, keyed by latent name; filled by the cells below.
trainable_proposals = Dict()

In [None]:
# Register one temporal proposal per temporal latent.
for tp_latent in (:wait, :dur_minus_min)
    trainable_proposals[tp_latent] = create_trainable_tp_proposal(tp_latent)
end

In [None]:
# Initialise and train the temporal proposal selected by `latent`.
latent = :dur_minus_min

# Initial parameter values, applied in the listed order.
for (name, init) in [(:B_mean_mu, [1.0, 0.0]),
                     (:B_mean_alpha, zeros(3)),
                     (:MLE_estimate_alpha_0, 0.0),
                     (:MOM_estimate_alpha_0, 0.0),
                     (:B_shape_mu, 0.0),
                     (:B_shape_a, 0.0),
                     (:C_shape_mu, 0.0),
                     (:C_shape_a, 0.0)]
    Gen.init_param!(trainable_proposals[latent], name, init)
end

#Gen.FixedStepGradientDescent(1e-5)
update = Gen.ParamUpdate(Gen.GradientDescent(1e-4,10), trainable_proposals[latent]);
scores = pain!(trainable_proposals[latent], data_generators[latent], update,
    num_epoch=5000, epoch_size=5, num_minibatch=1,
    minibatch_size=5, evaluation_size=100, verbose=false);

# Learning curve.
plot(scores)
xlabel("Iterations of stochastic gradient descent")
ylabel("Estimate of expected conditional log probability density");
title("GP: $(latent)")

In [None]:
# Print the current value of every trainable parameter of the selected temporal proposal.
tp_param_names = [:B_mean_mu, :B_shape_mu, :B_mean_alpha, :B_shape_a,
                  :MLE_estimate_alpha_0, :MOM_estimate_alpha_0,
                  :C_shape_mu, :C_shape_a]
foreach(tp_param_names) do p
    println("$(string(p)): ", Gen.get_param(trainable_proposals[latent], p))
end

In [None]:
# Compare proposal samples with the ground-truth temporal hyperparameters.
#
# Draws `n_points` scenes from `data_generators[latent]`; for every scene whose
# element count satisfies `keep_scene`, samples the proposal `n_samples` times and
# records, per hyperparameter (:a and :mu), the true value (:c) and the mean of the
# proposed values (:t). Scenes that fail `keep_scene` are discarded.
#
# The element count is read with `latent` rather than a hard-coded key, so the same
# code works for :wait as well as :dur_minus_min.
function collect_tp_proposal_results(keep_scene, latent, n_points; n_samples=10)
    results = Dict(:a=>Dict(:c=>[], :t=>[]),
                   :mu=>Dict(:c=>[], :t=>[]))
    for _ in 1:n_points
        inputs, constraints = data_generators[latent]()
        tp_elems = inputs[1]
        keep_scene(length(tp_elems[latent])) || continue

        samples = Dict(:a=>[], :mu=>[])
        for _ in 1:n_samples
            choices, _, _ = propose(trainable_proposals[latent], (tp_elems,))
            for a in keys(results)
                push!(samples[a], choices[latent=>a])
            end
        end
        for a in keys(results)
            push!(results[a][:c], constraints[latent=>a])
            push!(results[a][:t], mean(samples[a]))
        end
    end
    return results
end

n_points = 1000
# Scenes with at least two elements vs. scenes with fewer than two.
results_gt_2 = collect_tp_proposal_results(n -> n >= 2, latent, n_points)
results_lt_2 = collect_tp_proposal_results(n -> n < 2, latent, n_points);

In [None]:
# Actual vs. predicted alpha, one scatter series per element-count group.
for group in (results_gt_2, results_lt_2)
    scatter(group[:a][:c], group[:a][:t])
end
#ylim([-0.5,100])
xlim([-0.5,50])
xlabel("Actual")
ylabel("Predicted=mean(10 samples from proposal)")
title("$latent, alpha parameter, n=$n_points")
legend([">=2elems", "<2elems"])

In [None]:
# Actual vs. predicted mu, one scatter series per element-count group.
for group in (results_gt_2, results_lt_2)
    scatter(group[:mu][:c], group[:mu][:t])
end
ylim([-0.5,7])
xlim([-0.5,7])
title("$latent, mu parameter, n=$n_points")
xlabel("Actual")
ylabel("Predicted=mean(10 samples from proposal)")
# Correlation is reported only for the group with at least two elements.
mu_gt_2 = results_gt_2[:mu]
println("R for >=2elems: ", cor(mu_gt_2[:c], mu_gt_2[:t]))
println("Max element duration (for truncated Gamma): ", source_params["duration_limit"])
legend([">=2elems", "<2elems"])

In [None]:
using PyCall #Put this first! 
using GenTF
@pyimport tensorflow as tf

# add https://github.com/mcusi/GenTF
# using Pkg; ENV["PYTHON"] = "/Users/maddiecusimano/Documents/basa-gen/env/bin/python"; Pkg.build("PyCall")
# julia> using PyCall; println(PyCall.python)
# /Users/maddiecusimano/.julia/conda/3/bin/python

In [None]:
# Scratch check: time grid for a scene ending at 3.5, and its number of points.
# NOTE(review): this overwrites the global scene_t with an end time that differs from
# scene_duration — confirm this is intended before relying on scene_t in later cells.
scene_t = get_element_gp_times([0,3.5], steps["t"])
length(scene_t)

In [None]:
# Scratch test of GenTF: a TFFunction that adds a trainable, zero-initialised tensor
# to a length-6 mask. Rebinds the globals `mask`, `nn`, `x` and `m`.
tf.compat.v1.reset_default_graph()
        
mask = tf.compat.v1.placeholder(dtype=tf.float64, shape=(1, 1, 6,))

#Defining network
Wy = tf.compat.v1.get_variable("Wy",(1,1,6),initializer=tf.compat.v1.initializers.constant(0.0),dtype=tf.float64)
# WL = tf.compat.v1.get_variable("WL",[1,1,1],initializer=tf.compat.v1.initializers.glorot_normal(),dtype=tf.float64)
# L1 = tf.compat.v1.nn.conv1d(mask, WL, (1,), padding="SAME", data_format="NCW")
y = mask + Wy

nn = TFFunction([Wy], [mask], y);

x = [1.0, 1.0, 1.0, 0.0, 0.0, 0.0]
println("Input type: ", typeof(x))
# Reshape to match the placeholder shape (1, 1, 6).
m = reshape(x, 1, 1, length(x))
println("Size of m: ", size(m))
y_out = nn(m)

In [None]:
# Scratch test of GenTF: a TFFunction applying a 1-D convolution (filter shape (1,1,6),
# stride 1, SAME padding, NWC layout) to a length-6 mask. Rebinds the globals `mask`,
# `nn`, `x` and `m`.
tf.compat.v1.reset_default_graph()
        
mask = tf.compat.v1.placeholder(dtype=tf.float64, shape=(1, 6, 1,))

#Defining network
WL = tf.compat.v1.get_variable("WL",(1,1,6),initializer=tf.compat.v1.initializers.glorot_uniform(),dtype=tf.float64)
L1 = tf.compat.v1.nn.conv1d(mask, WL, (1,), padding="SAME", data_format="NWC")

nn = TFFunction([WL], [mask], L1);

x = [1.0, 1.0, 1.0, 0.0, 0.0, 0.0]
println("Input type: ", typeof(x))
# Reshape to match the placeholder shape (1, 6, 1).
m = reshape(x, 1, length(x), 1)
println("Size of m: ", size(m))
y_out = nn(m)

In [None]:
# Scratch smoke test: feed a hand-made (values, mask) pair through `nn`, both as a plain
# call and through Gen.generate.
# NOTE(review): `nn` is called with two inputs here, which matches the network returned
# by define_neural_net rather than the single-placeholder TFFunctions built in the
# preceding cells — this cell looks execution-order dependent; confirm.
x = [1.0, 1.0, 1.0, 0.0, 0.0, 0.0]
y = [2.0, 3.0, 4.0, 3.0, 3.0, 3.0]
print(typeof(y))
# Column 1 is x (doubling as the mask channel), column 2 is y; reshaped to (1, 6, 1, 2).
net_input = cat(x, y, dims=2)
net_input = reshape(net_input, 1, length(x), 1, 2)
println(size(net_input))
net_mask = reshape(x, 1, length(x),1, 1)
println(size(net_mask))
# The returned initializer op is discarded, not run, in this cell.
tf.compat.v1.global_variables_initializer()
params = nn(net_input, net_mask)
println(params)
(trace, _) = Gen.generate(nn, (net_input,net_mask,))
get_choices(trace)

In [None]:
tf.compat.v1.reset_default_graph()

# Build the dilated-convolution network that maps an embedded GP (plus mask channel)
# to `n_params` distribution parameters, wrapped as a GenTF.TFFunction.
#
# Arguments
# - `filter_height`, `filter_width`: spatial size of every convolution filter.
# - `out_channels`: number of channels of every layer except the last.
# - `n_params`: number of channels of the last layer, i.e. number of outputs.
# - `dilations`: dilation along the height axis, one entry per convolution layer.
#
# Returns `(nn, weights)`: the TFFunction with inputs `[gp, mask]`, and the list of
# TF variables in creation order (W0, b0, W1, b1, ..., Wfinal, bfinal).
#
# Graph: each layer is convolution -> bias -> ReLU. The last layer's output is
# multiplied by the mask, summed over height and width, divided by the mask sum
# (a masked average), passed through ReLU, then scaled and shifted elementwise
# by Wfinal and bfinal.
function define_neural_net(filter_height, filter_width, out_channels, n_params; dilations=[1,2,4])

    batch_size = 1; in_channels = 2
    in_height = nothing  # left unspecified so scenes of any length can be fed
    in_width = 1         # width=frequency

    gp = tf.compat.v1.placeholder(dtype=tf.float64,
        shape=(batch_size, in_height, in_width, in_channels))
    mask = tf.compat.v1.placeholder(dtype=tf.float64, shape=(batch_size, in_height, 1, 1))

    # Convolution stack
    weights = []; layer_inputs = [gp]
    n_layers = length(dilations)
    for (layer, dilation) in enumerate(dilations)
        i = layer - 1  # variable names are zero-based: W0, b0, ...
        ic = layer == 1 ? in_channels : out_channels
        oc = layer == n_layers ? n_params : out_channels
        W = tf.compat.v1.get_variable("W$i",
                (filter_height, filter_width, ic, oc),
                dtype=tf.float64,
                initializer=tf.compat.v1.initializers.glorot_normal())
        # `(oc,)` is a 1-tuple; `(oc)` would just be the scalar `oc`.
        b = tf.compat.v1.get_variable("b$i", (oc,),
                dtype=tf.float64,
                initializer=tf.constant_initializer(value=0.0))
        append!(weights, [W, b])
        conv = tf.compat.v1.nn.convolution(layer_inputs[end], W, padding="SAME",
             dilation_rate=(dilation, 1), data_format="NHWC")
        conv = tf.nn.bias_add(conv, b, data_format="NHWC")
        push!(layer_inputs, tf.nn.relu(conv))
    end

    # Convert masked convolutions to parameter estimates
    tile_mask = tf.tile(mask, (1, 1, 1, n_params))
    masked_conv = tf.multiply(tile_mask, layer_inputs[end])
    O = tf.reduce_sum(masked_conv, axis=(1,2)) #batch_size, n_params,
    denom = tf.expand_dims(tf.reduce_sum(mask, axis=(1,2,3)), 1)
    O = tf.divide(O, denom)
    O = tf.nn.relu(O)

    # Final elementwise scale and shift
    Wf = tf.compat.v1.get_variable("Wfinal", (1, n_params), dtype=tf.float64,
              initializer=tf.compat.v1.initializers.random_normal(stddev=1.0))
    b = tf.compat.v1.get_variable("bfinal", (n_params,), dtype=tf.float64,
        initializer=tf.constant_initializer(value=0.0)) #put back to 1 and train longer?
    append!(weights, [Wf, b])
    dist_parameters = tf.nn.bias_add(tf.multiply(Wf, O), b)

    for w in weights
        println(w)
    end

    nn = GenTF.TFFunction(weights, [gp, mask], dist_parameters);

    return nn, weights

end
nn,nn_weights = define_neural_net(5, 1, 16, 8, dilations=[1,2,4,8,16]);

In [None]:
# Factory for a neural-network-based proposal over the hyperparameters of a 1-D GP
# (`latent` is :erb or :amp).
#
# The returned generative function takes `gp_elems` (Dict with the GP values under
# `latent` and their time points under :t), embeds the GP on the full scene time grid,
# runs it through the global TFFunction `nn`, and traces `latent => :mu`, `:sigma`,
# `:scale` and `:epsilon`. It returns nothing.
#
# Reads the globals `nn`, `scene_duration` and `steps`.
function create_trainable_gp1D_proposal(latent)

    @gen function trainable_gp1D_proposal(gp_elems)

        #format GP for input into neural network 
        scene_t = get_element_gp_times([0,scene_duration], steps["t"]) #If training on variable length scenes, change this
        mask = [t in gp_elems[:t] ? 1.0 : 0.0 for t in scene_t]
        gp_vals = gp_elems[latent]
        # The standard deviation is undefined for a single point.
        has_spread = length(gp_vals) > 1
        m = mean(gp_vals); s = has_spread ? std(gp_vals) : 1.0

        # Place GP values where the mask is on and the mean elsewhere, then centre.
        embedded_gp = []; k = 1;
        for t_idx in 1:length(scene_t)
            if mask[t_idx] == 1
                push!(embedded_gp, gp_vals[k])
                k += 1
            else
                push!(embedded_gp, m)
            end
        end
        embedded_gp = embedded_gp .- m

        # Two channels per time point: centred GP value and mask.
        net_input = cat(embedded_gp, mask, dims=2)
        batch_size=1; in_channels=2; in_height=length(scene_t); in_width=1;
        net_input = float.(reshape(net_input, batch_size, in_height, in_width, in_channels))
        net_mask = float.(reshape(mask, batch_size, in_height, 1, 1))
        dparams = @trace(nn(net_input,net_mask), :net_parameters)

        # mu: normal whose mean is the empirical mean scaled by the network.
        mean_mu_estimate = dparams[1]*m; shape_mu = softplus(dparams[2])
        @trace(normal(mean_mu_estimate, shape_mu), latent => :mu) 

        # sigma: gamma whose mean is a positive network output plus the empirical
        # standard deviation when one exists. The parentheses are required: without
        # them `a + b > 1 ? s : 0.0` parses as `(a + b > 1) ? s : 0.0`, which drops
        # the network term and can produce a zero mean.
        mean_sigma_estimate = softplus(dparams[3]) + (has_spread ? s : 0.0)
        shape_sigma = softplus(dparams[4])
        @trace(gamma(shape_sigma, mean_sigma_estimate/shape_sigma), latent => :sigma)

        mean_scale_estimate = softplus(dparams[5]); shape_scale = softplus(dparams[6]); 
        @trace(gamma(shape_scale, mean_scale_estimate/shape_scale), latent => :scale)        

        mean_epsilon_estimate = softplus(dparams[7]); shape_epsilon = softplus(dparams[8]); 
        @trace(gamma(shape_epsilon, mean_epsilon_estimate/shape_epsilon), latent => :epsilon)

        return nothing

    end

    return trainable_gp1D_proposal

end
trainable_proposals[:erb] = create_trainable_gp1D_proposal(:erb);


In [None]:
# Fixed evaluation set for the :erb proposal: evaluation_size draws from its data generator.
evaluation_size = 100
evalsets=Dict()
evalsets[:erb] = [data_generators[:erb]() for i = 1:evaluation_size];

In [None]:
# Train the neural-network-based proposal selected by `latent`.
latent = :erb

#-3.6
# The trainable parameters live in the TFFunction `nn`, not in the proposal itself,
# so the update must be built over nn's parameters. Defining it here also stops the
# cell from silently reusing an `update` left over from training another proposal.
update = Gen.ParamUpdate(Gen.FixedStepGradientDescent(1e-4),
    nn => collect(Gen.get_params(nn)));
scores = pain!(trainable_proposals[latent],
    data_generators[latent],
    update,
    num_epoch=10, epoch_size=2000, num_minibatch=500,
    minibatch_size=10, evaluation_size=100,
    eval_inputs_and_constraints=evalsets[latent], verbose=true);
scatter(1:length(scores),scores,marker=".")
xlabel("Iterations of stochastic gradient descent")
ylabel("Estimate of expected conditional log probability density");
title(string("GP net: ", string(latent)))

In [None]:
# Scratch check: softplus of a value close to zero.
softplus(-0.000133382)

In [None]:
# Build a Saver over nn's parameters keyed by TF variable name, write a checkpoint to
# ./nn.ckpt from nn's session, then immediately restore it into the same session.
saver = tf.compat.v1.train.Saver(Dict(string(var.name) => var for var in Gen.get_params(nn)))
saver.save(GenTF.get_session(nn), "./nn.ckpt")
saver.restore(GenTF.get_session(nn), "./nn.ckpt")

In [None]:
# Compare a sampled GP hyperparameter with the empirical standard deviation of the GP.
vartoplot = :sigma
n_points=1000
actual=[]; empirical=[]; single=[];
for _ in 1:n_points
    inputs, constraints = data_generators[latent]()
    gp = inputs[1][latent]
    truth = constraints[latent=>vartoplot]
    # A single-point GP has no standard deviation; track those draws separately.
    if length(gp) > 1
        push!(empirical, std(gp))
        push!(actual, truth)
    else
        push!(single, truth)
    end
end
scatter(actual, empirical,marker=".")
scatter(single,zeros(length(single)))
legend([">=2","<2"])
xlabel("Sampled $vartoplot")
ylabel("STD calculated from GP")
title("$vartoplot compared to empirical std of gp")
# Use the same range on both axes.
miny, maxy = extrema(empirical)
minx, maxx = extrema(actual)
xlim([min(miny,minx)-0.5, max(maxy,maxx)+0.5])
ylim([min(miny,minx)-0.5, max(maxy,maxx)+0.5])

In [None]:
# Actual vs. proposed value of one GP hyperparameter, over freshly generated scenes.
n_points = 1000
n_same = 1
vartoplot = :sigma
actual = []
predicted = []
for _ in 1:n_points
    inputs, constraints = data_generators[latent]()
    push!(actual, constraints[latent=>vartoplot])
    # Average n_same proposal samples drawn for the same scene.
    m = []
    for _ in 1:n_same
        choices, = propose(trainable_proposals[latent],(inputs[1],))
        push!(m, choices[latent=>vartoplot])
    end
    push!(predicted, mean(m))
    #push!(predicted, std(i[1][:erb])) 
end
scatter(actual,predicted,marker=".")
xlabel("Actual")
ylabel("Mean of $n_same proposal samples")
r = round(cor(actual,predicted),digits = 4)
title("$latent gp $vartoplot from neural network, r=$r, n=$n_points")
# Use the same range on both axes.
miny, maxy = extrema(predicted)
minx, maxx = extrema(actual)
xlim([min(miny,minx)-0.5, max(maxy,maxx)+0.5])
ylim([min(miny,minx)-0.5, max(maxy,maxx)+0.5])

In [None]:
# Print the current values of the final scale ("Wfinal") and shift ("bfinal") variables,
# evaluated in nn's session. Python attributes are accessed with dot syntax, as in the
# rest of the file, instead of the deprecated `obj[:attr]` form.
sess = get_session(nn)
V = [var for var in tf.compat.v1.global_variables() if var.op.name=="Wfinal"]
println("Wlinear: ", sess.run(V[1]))
v = [var for var in tf.compat.v1.global_variables() if var.op.name=="bfinal"]
println("blinear: ", sess.run(v[1]))

In [None]:
# Histogram of proposal samples for hyperparameter `a` given a hand-made single-point GP,
# overlaid with the prior density of `a`.
a = :scale
gp_val = [20.0]; t_val=[0.22]
n=length(gp_val)
h=[]
for _ in 1:1000
    choices, = propose(trainable_proposals[latent],(Dict(:erb=>gp_val,:t=>t_val),))
    push!(h,choices[latent => a])
end
hist(h,bins=20,density=true)
sp = source_params["gp"][string(latent)][string(a)]
# Evaluation grid for the prior density.
if a == :mu
    xvals = 0:50
else
    xvals = 0:0.05:6
end
plot(xvals, map(x -> exp(Gen.logpdf(sp["dist"],x,sp["args"]...)), xvals))
legend(["pdf","samples"])
title("Hist $a with $n gp point vs. prior pdf")

In [None]:
# Factory for a hand-featured, linear-in-parameters proposal over the hyperparameters of
# a 1-D GP (`latent` is :erb or :amp).
#
# Keyword arguments
# - `max_t`: maximum number of GP points used for the pairwise statistics; larger GPs
#   are randomly subsampled.
# - `check_n`: half-width (in points) of the contiguous window used for local stds.
# - `max_locals`: maximum number of local stds collected.
#
# The returned generative function takes `gp_elems` (GP values under `latent`, time
# points under :t), traces `latent => :mu`, `:sigma`, `:scale` and `:epsilon`, and
# returns nothing. It reads the global `steps`.
#
# NOTE(review): this shares its name and single positional argument with the
# neural-network-based factory defined earlier in the file, so evaluating this cell
# appears to replace that definition — confirm this is intended.
function create_trainable_gp1D_proposal(latent; max_t=50, check_n=2, max_locals=50)

    @gen function trainable_gp1D_proposal(gp_elems)

        #Mu and Sigma of GP
        @param B_mu_mu::Vector{Float64}
        @param B_logsigma_mu::Vector{Float64}
        @param B_mu_sigma::Vector{Float64}
        @param B_logsigma_sigma::Vector{Float64}
        @param log_std_replace::Float64

        # Features: empirical mean, empirical std (a learned constant when there is only
        # one point), number of points, and a bias term.
        mu = mean(gp_elems[latent])
        sigma = length(gp_elems[latent]) > 1 ? std(gp_elems[latent]) : exp(log_std_replace)
        X = [mu, sigma, length(gp_elems[latent]), 1]

        mu_mu = dot(X, B_mu_mu)
        logsigma_mu = dot(X, B_logsigma_mu)
        mu_sigma = dot(X, B_mu_sigma)
        logsigma_sigma = dot(X, B_logsigma_sigma)

        gp_mu = @trace(normal(mu_mu,exp(logsigma_mu)), latent => :mu)
        gp_sigma = @trace(log_normal(mu_sigma, exp(logsigma_sigma)), latent => :sigma)

        #Temporal lengthscale of GP
        @param W_mu_scale::Vector{Float64}
        @param Bn_mu_scale::Float64
        @param C_mu_scale::Float64
        @param replace_mu_scale::Float64

        @param W_logsigma_scale::Vector{Float64}
        @param Bn_logsigma_scale::Float64
        @param C_logsigma_scale::Float64
        @param replace_logsigma_scale::Float64

        @param mu_cats::Vector{Float64} #for categorizing time differences
        @param logsigma_cats::Vector{Float64}

        mu_scale = 0.0; logsigma_scale = 0.0; t = gp_elems[:t];
        if length(t) > 1

            #Subsample if GP is too big 
            idx = []; 
            if length(t) <= max_t
                idx = 1:length(t) # 1 2 3 4
            else
                idx = Random.randperm(length(t)) # 4 7 2 1 8 5 3 6 10 9
                idx = idx[1:max_t] # 4 7 2 1
                idx = sort(idx) # 1 2 4 7 
            end

            n = sum(1:length(idx)-1) #how many pairs we will have 
            delta_ts = Vector{Float64}(undef, n)
            delta_ys = Vector{Float64}(undef, n)
            rs = Array{Float64}(undef, 1, n)
            ps = [] #need to do it this way (rather than initialize an array) otherwise the gradients will break

            # For every pair of selected points: time gap, absolute value gap, the value
            # gap normalised by (time gap * sampled sigma), and a soft assignment of the
            # time gap to three Gaussian categories (softmax over the log densities).
            k = 1
            for i = 1:length(idx)
                for j = (i+1):length(idx)
                    i_idx = idx[i]; j_idx = idx[j]
                    delta_ts[k] = t[j_idx] - t[i_idx]
                    delta_ys[k] = abs(gp_elems[latent][j_idx] - gp_elems[latent][i_idx])
                    rs[k] = delta_ys[k]/(delta_ts[k]*gp_sigma)
                    lp = [Gen.logpdf(normal, delta_ts[k], mu_cats[g],exp(logsigma_cats[g])) for g = 1:3]
                    p = exp.(lp .- logsumexp(lp))
                    push!(ps, p)
                    k += 1
                end
            end

            # ps becomes n x 3; rs is 1 x n, so rs*ps*W is a 1 x 1 matrix.
            ps = transpose(cat(ps...,dims=2))
            mu_scale += (rs*ps*reshape(W_mu_scale, 3, 1))[1] + Bn_mu_scale*n + C_mu_scale #mu_scale(1)=rs(n,)*ps(n,3)*W(3,)
            logsigma_scale += (rs*ps*reshape(W_logsigma_scale, 3, 1))[1] + Bn_logsigma_scale*n + C_logsigma_scale 

        else

            # A single point gives no pairs: fall back to learned constants.
            mu_scale += replace_mu_scale
            logsigma_scale += replace_logsigma_scale

        end
        @trace(log_normal(mu_scale, exp(logsigma_scale)), latent => :scale)

        ## Epsilon, local noise parameter of GP
        #Check stds between closeby points only to get estimate
        #Could also calculate the deviation of a middle-point
        #from the straight line between its neighbours
        @param B_mu_epsilon::Vector{Float64}
        @param replace_mu_epsilon::Float64
        @param B_logsigma_epsilon::Vector{Float64}
        @param replace_logsigma_epsilon::Float64

        local_stds = []; segment_size = 2*check_n + 1
        for i = 1:segment_size:length(t) #It won't go beyond bounds, even if length(t) < segment_size
            # Only use point i if all of its neighbours at multiples of the time step are
            # present in t. The slice below assumes those neighbours sit at the adjacent
            # indices — presumably t is sorted and gap-free within an element; confirm.
            if issubset( [ t[i]+k*steps["t"] for k = -check_n:check_n ], t )
                local_gp = gp_elems[latent][i-check_n:i+check_n] #contiguous GP points
                push!(local_stds, std(local_gp))
            end
            if length(local_stds) == max_locals
                break
            end
        end
        
        mu_epsilon = 0.0; logsigma_epsilon = 0.0;
        if length(local_stds) > 1
            # Features: mean and std of the local stds, their count, and a bias term.
            mu_locals = mean(local_stds)
            std_locals = std(local_stds)
            X = [mu_locals, std_locals, length(local_stds), 1]

            mu_epsilon += dot(X, B_mu_epsilon)
            logsigma_epsilon += dot(X, B_logsigma_epsilon)
        else
            # Too few local windows: fall back to learned constants.
            mu_epsilon += replace_mu_epsilon
            logsigma_epsilon += replace_logsigma_epsilon
        end
        
        @trace(log_normal(mu_epsilon, exp(logsigma_epsilon)), latent => :epsilon)

        return nothing

    end

    return trainable_gp1D_proposal
    
end
trainable_proposals[:amp1D] = create_trainable_gp1D_proposal(:amp)
trainable_proposals[:erb] = create_trainable_gp1D_proposal(:erb);

In [None]:
# Initialise and train the hand-featured 1-D GP proposal selected by `latent`.
latent = :erb

# Zero-initialised weight vectors; a fresh array is created for every parameter.
for (names, len) in [([:B_mu_mu, :B_logsigma_mu, :B_mu_sigma, :B_logsigma_sigma,
                       :B_mu_epsilon, :B_logsigma_epsilon], 4),
                     ([:W_mu_scale, :W_logsigma_scale], 3)]
    for p in names
        Gen.init_param!(trainable_proposals[latent], p, zeros(len))
    end
end

#Gaussians:
Gen.init_param!(trainable_proposals[latent], :mu_cats, [0.01,0.8,2.0])
Gen.init_param!(trainable_proposals[latent], :logsigma_cats, [-1.0,-1.0,-1.0])

# Scalar parameters all start at zero.
scalar_params = [:log_std_replace,
                 :C_mu_scale, :replace_mu_scale, :C_logsigma_scale, :replace_logsigma_scale,
                 :Bn_mu_scale, :Bn_logsigma_scale,
                 :replace_mu_epsilon, :replace_logsigma_epsilon]
foreach(p -> Gen.init_param!(trainable_proposals[latent], p, 0.0), scalar_params)

update = Gen.ParamUpdate(Gen.FixedStepGradientDescent(1e-8), trainable_proposals[latent]);
scores = Gen.train!(trainable_proposals[latent], data_generators[latent], update,
    num_epoch=50, epoch_size=50, num_minibatch=1,
    minibatch_size=50, evaluation_size=10, verbose=false);

# Learning curve.
plot(scores)
xlabel("Iterations of stochastic gradient descent")
ylabel("Estimate of expected conditional log probability density");
title("GP: $(latent)")

In [None]:
# Print "name: value" for the listed parameters of the proposal selected by `latent`.
foreach((:B_mu_mu, :B_logsigma_mu, :B_mu_sigma, :B_logsigma_sigma,
         :B_mu_epsilon, :B_logsigma_epsilon,
         :W_mu_scale, :W_logsigma_scale,
         :log_std_replace,
         :C_mu_scale, :replace_mu_scale, :C_logsigma_scale, :replace_logsigma_scale,
         :Bn_mu_scale, :Bn_logsigma_scale,
         :replace_mu_epsilon, :replace_logsigma_epsilon)) do name
    println(string(name), ": ", Gen.get_param(trainable_proposals[latent], name))
end

In [None]:
"""
    create_trainable_gp2D_proposal(latent; max_t=5, check_nt=1, check_nf=2, max_locals=5)

Build a Gen proposal with trainable parameters for the hyperparameters of a 2-D GP whose
values are stored as a matrix indexed `[frequency, time]`.

The returned generative function takes one argument, `gp_elems`, which is read at the keys
`latent` (its mean, std and length are used as features), `:reshaped` (the matrix), `:t`
and `:f`. It makes traced choices at `latent => :mu`, `latent => :sigma`,
`latent => :scale_t`, `latent => :scale_f` and `latent => :epsilon`, and returns `nothing`.
`gp_elems` is left unmodified. The global `steps["t"]` is read for the time step.

# Keywords
- `max_t`: at most this many (randomly chosen) time points enter the pairwise features
  used for `:scale_t`.
- `check_nt`, `check_nf`: half-widths, in time and frequency samples, of the patches whose
  standard deviation feeds the `:epsilon` estimate.
- `max_locals`: stop collecting patches once this many have been found.
"""
function create_trainable_gp2D_proposal(latent; max_t=5, check_nt=1, check_nf=2, max_locals=5)

    @gen function trainable_gp2D_proposal(gp_elems)

        #Mu and Sigma of GP
        @param B_mu_mu::Vector{Float64}
        @param B_logsigma_mu::Vector{Float64}
        @param B_mu_sigma::Vector{Float64}
        @param B_logsigma_sigma::Vector{Float64}

        mu = mean(gp_elems[latent])
        sigma = std(gp_elems[latent])
        X = [mu, sigma, length(gp_elems[latent]), 1]

        mu_mu = dot(X, B_mu_mu)
        logsigma_mu = dot(X, B_logsigma_mu)
        mu_sigma = dot(X, B_mu_sigma)
        logsigma_sigma = dot(X, B_logsigma_sigma)

        gp_mu = @trace(normal(mu_mu,exp(logsigma_mu)), latent => :mu)
        gp_sigma = @trace(log_normal(mu_sigma, exp(logsigma_sigma)), latent => :sigma)
        # De-mean into a local matrix. Writing the result back into `gp_elems` would
        # subtract another `gp_mu` every time the same input is reused (for example an
        # evaluation set that is scored once per epoch).
        demeaned = gp_elems[:reshaped] .- gp_mu

        #Temporal lengthscale of GP
        @param W_mu_scale_t::Vector{Float64}
        @param Bn_mu_scale_t::Float64
        @param C_mu_scale_t::Float64
        @param replace_mu_scale_t::Float64

        @param W_logsigma_scale_t::Vector{Float64}
        @param Bn_logsigma_scale_t::Float64
        @param C_logsigma_scale_t::Float64
        @param replace_logsigma_scale_t::Float64

        @param mu_cats::Vector{Float64} #for categorizing time differences
        @param logsigma_cats::Vector{Float64}

        mu_scale_t = 0.0; logsigma_scale_t = 0.0; t = gp_elems[:t];
        if length(t) > 1

            #Subsample if GP is too big
            if length(t) <= max_t
                idx = 1:length(t) # 1 2 3 4
            else
                idx = Random.randperm(length(t)) # 4 7 2 1 8 5 3 6 10 9
                idx = idx[1:max_t] # 4 7 2 1
                idx = sort(idx) # 1 2 4 7
            end

            n = sum(1:length(idx)-1) #how many pairs we will have
            delta_ts = Vector{Float64}(undef, n)
            rs = [] #one correlation per pair of time columns
            ps = [] #need to do it this way (rather than initialize an array) otherwise the gradients will break

            k = 1
            for i = 1:length(idx)
                for j = (i+1):length(idx)
                    i_idx = idx[i]; j_idx = idx[j]
                    delta_ts[k] = t[j_idx] - t[i_idx]
                    #Differences between 2D and 1D case: correlate whole frequency columns
                    push!(rs, cor( demeaned[:,j_idx] , demeaned[:,i_idx] ))
                    #Soft assignment of the time difference to the three Gaussians
                    lp = [Gen.logpdf(normal, delta_ts[k], mu_cats[g],exp(logsigma_cats[g])) for g = 1:3]
                    p = exp.(lp .- logsumexp(lp))
                    push!(ps, p)
                    k += 1
                end
            end
            #cat(ps) is necessary to be able to compute the derivative given the list
            #You need to use a list rather than editing a pre-made vector in place
            ps = transpose(cat(ps...,dims=2))
            rs = reshape(cat(rs...,dims=1), 1, n)
            mu_scale_t += (rs*ps*reshape(W_mu_scale_t, 3, 1))[1] + Bn_mu_scale_t*n + C_mu_scale_t #mu_scale(1)=rs(n,)*ps(n,3)*W(3,)
            logsigma_scale_t += (rs*ps*reshape(W_logsigma_scale_t, 3, 1))[1] + Bn_logsigma_scale_t*n + C_logsigma_scale_t

        else

            #A single time point: fall back to the dedicated replacement parameters
            mu_scale_t += replace_mu_scale_t
            logsigma_scale_t += replace_logsigma_scale_t

        end
        @trace(log_normal(mu_scale_t, exp(logsigma_scale_t)), latent => :scale_t)

        #Frequency lengthscale of GP
        @param B_mu_scale_f::Vector{Float64} #len = sum(1:length(f)-1)+2
        @param B1_mu_scale_f::Vector{Float64}
        @param B_logsigma_scale_f::Vector{Float64}
        @param B1_logsigma_scale_f::Vector{Float64}

        f = gp_elems[:f];
        freq_feats = [];
        for i = 1:length(f)
            for j = i+1:length(f)
                if length(t) > 1
                    r = cor( demeaned[j,:], demeaned[i,:] )
                else
                    #Only one time point: use the absolute difference scaled by the sampled sigma
                    dy = abs(demeaned[j,1] - demeaned[i,1])
                    r = dy/gp_sigma
                end
                push!(freq_feats, r)
            end
        end
        #rcat is necessary to be able to compute the derivative given the list
        #You need to use a list rather than editing a pre-made vector in place
        rcat = cat(freq_feats..., length(t), 1.0, dims=1)
        #Separate weight vectors for the multi-time-point and single-time-point cases
        Bm = length(t) > 1 ? B_mu_scale_f : B1_mu_scale_f
        Bs = length(t) > 1 ? B_logsigma_scale_f : B1_logsigma_scale_f
        mu_scale_f = dot(Bm,rcat)
        logsigma_scale_f = dot(Bs,rcat)
        @trace(log_normal(mu_scale_f, exp(logsigma_scale_f)), latent => :scale_f)

        ## Epsilon, local noise parameter of GP
        #Check stds between closeby points only to get estimate
        #Could also calculate the deviation of a middle-point
        #from the straight line between its neighbours
        @param B_mu_epsilon::Vector{Float64}
        @param replace_mu_epsilon::Float64
        @param B_logsigma_epsilon::Vector{Float64}
        @param replace_logsigma_epsilon::Float64

        local_stds = []; segment_size_t = 2*check_nt + 1;
        for i = 1:segment_size_t:length(t) #It won't go beyond bounds, even if length(t) < segment_size
            #A patch is used only if every neighbouring time value is present in t
            #(exact membership test on the floating-point values)
            if issubset( [ t[i]+k*steps["t"] for k = -check_nt:check_nt ], t )
                j = rand(1+check_nf:length(f)-check_nf) #random patch centre along frequency
                local_gp = demeaned[j-check_nf:j+check_nf, i-check_nt:i+check_nt][:] #contiguous GP points
                push!(local_stds, std(local_gp))
            end
            if length(local_stds) == max_locals
                break
            end
        end

        mu_epsilon = 0.0; logsigma_epsilon = 0.0;
        if length(local_stds) > 1
            mu_locals = mean(local_stds)
            std_locals = std(local_stds)
            X = [mu_locals, std_locals, length(local_stds), 1.0]

            mu_epsilon += dot(X, B_mu_epsilon)
            logsigma_epsilon += dot(X, B_logsigma_epsilon)
        else
            mu_epsilon += replace_mu_epsilon
            logsigma_epsilon += replace_logsigma_epsilon
        end
        @trace(log_normal(mu_epsilon, exp(logsigma_epsilon)), latent => :epsilon)

        return nothing

    end

    return trainable_gp2D_proposal

end
trainable_proposals[:amp2D] = create_trainable_gp2D_proposal(:amp);

In [None]:
"""
    pain!(gen_fn, data_generator, update; num_epoch=1, epoch_size=1, num_minibatch=1,
          minibatch_size=1, evaluation_size=epoch_size, eval_inputs_and_constraints=[],
          verbose=false)

Training loop that reports the wall-clock time of each phase of every epoch.

Each epoch draws `epoch_size` `(inputs, constraints)` pairs from `data_generator()`, runs
`num_minibatch` gradient updates (each on `minibatch_size` randomly chosen examples of that
epoch), and then averages the `generate` weight of `gen_fn` over the evaluation set.

If `eval_inputs_and_constraints` is empty it is filled in place with `evaluation_size`
draws from `data_generator()`; otherwise it is used as given and `evaluation_size` is
replaced by its length.

Returns the vector of per-epoch average evaluation scores.
"""
function pain!(gen_fn::GenerativeFunction, data_generator::Function,
                update::ParamUpdate;
                num_epoch=1, epoch_size=1, num_minibatch=1, minibatch_size=1,
                evaluation_size=epoch_size, eval_inputs_and_constraints=[], verbose=false)

    history = Vector{Float64}(undef, num_epoch)

    # Held-out data: create it on demand, or adopt whatever the caller supplied.
    if isempty(eval_inputs_and_constraints)
        println("Making evaluation dataset:")
        for i in 1:evaluation_size
            print("$i ")
            push!(eval_inputs_and_constraints, data_generator())
        end
        println()
    else
        evaluation_size = length(eval_inputs_and_constraints)
        println("Using eval set of size $evaluation_size")
    end

    for epoch in 1:num_epoch

        # ---- Phase 1: fresh training data for this epoch ----
        started = time()
        verbose && println("epoch $epoch: generating $epoch_size training examples...")
        epoch_inputs = Vector{Tuple}(undef, epoch_size)
        epoch_choice_maps = Vector{ChoiceMap}(undef, epoch_size)
        for slot in 1:epoch_size
            (epoch_inputs[slot], epoch_choice_maps[slot]) = data_generator()
        end
        s2 = time() - started
        println("Time to generate data: $s2")

        # ---- Phase 2: gradient updates on random minibatches ----
        started = time()
        verbose && println("epoch $epoch: training using $num_minibatch minibatches of size $minibatch_size...")
        for _ in 1:num_minibatch
            chosen = Random.randperm(epoch_size)[1:minibatch_size]
            for (inputs, constraints) in zip(epoch_inputs[chosen], epoch_choice_maps[chosen])
                (trace, _) = generate(gen_fn, inputs, constraints)
                if accepts_output_grad(gen_fn)
                    accumulate_param_gradients!(trace, zero(get_retval(trace)))
                else
                    accumulate_param_gradients!(trace, nothing)
                end
            end
            apply!(update)
        end
        s2 = time() - started
        println("Time to train: $s2")

        # ---- Phase 3: average score on the held-out data ----
        started = time()
        verbose && println("epoch $epoch: evaluating on $evaluation_size examples...")
        avg_score = 0.
        for k in 1:evaluation_size
            held_out = eval_inputs_and_constraints[k]
            (_, weight) = generate(gen_fn, held_out[1], held_out[2])
            avg_score += weight
        end
        avg_score /= evaluation_size
        @assert ~isnan(avg_score)
        history[epoch] = avg_score

        verbose && println("epoch $epoch: est. objective value: $avg_score")
        s2 = time() - started
        println("Time to evaluate: $s2")

    end
    return history
end

In [None]:
latent = :amp2D

# Length of the frequency-pair feature vector used by the scale_f weights.
lf = length(get_element_gp_freqs(audio_sr, steps)); nf = sum(1:lf-1);

# Table of (parameter names, constructor of the initial value). The constructor is called
# once per parameter so that no two parameters share the same array.
for (names, initial) in (
        ((:B_mu_mu, :B_logsigma_mu, :B_mu_sigma, :B_logsigma_sigma,
          :B_mu_epsilon, :B_logsigma_epsilon), () -> zeros(4)),
        # Gaussians for scale_t
        ((:W_mu_scale_t, :W_logsigma_scale_t), () -> zeros(3)),
        ((:mu_cats,), () -> [0.01,0.8,2.0]),
        ((:logsigma_cats,), () -> [-1.0,-1.0,-1.0]),
        # scale_f
        ((:B_mu_scale_f, :B_logsigma_scale_f, :B1_mu_scale_f, :B1_logsigma_scale_f),
         () -> zeros(nf + 2)),
        # scalars
        ((:log_std_replace,
          :C_mu_scale_t, :Bn_mu_scale_t, :replace_mu_scale_t,
          :C_logsigma_scale_t, :Bn_logsigma_scale_t, :replace_logsigma_scale_t,
          :replace_mu_epsilon, :replace_logsigma_epsilon), () -> 0.0),
    )
    for name in names
        Gen.init_param!(trainable_proposals[latent], name, initial())
    end
end

# Train with the timing-instrumented loop; the score history is the value of the cell.
update = Gen.ParamUpdate(GradientDescent(1e-8, 100), trainable_proposals[latent]);
scores = pain!(trainable_proposals[latent], data_generators[latent], update;
    num_epoch=100, epoch_size=5, num_minibatch=1,
    minibatch_size=5, evaluation_size=400, verbose=true)

In [None]:
# Plot the score history of the most recent training run, titled with the current latent.
plot(scores)
xlabel("Iterations of stochastic gradient descent")
ylabel("Estimate of expected conditional log probability density");
title(string("GP: ", latent))

In [None]:
# Print "name: value" for the listed parameters of the proposal selected by `latent`.
foreach((:B_mu_mu, :B_logsigma_mu, :B_mu_sigma, :B_logsigma_sigma,
         :B_mu_epsilon, :B_logsigma_epsilon,
         :W_mu_scale_t, :W_logsigma_scale_t,
         :B_mu_scale_f, :B_logsigma_scale_f,
         :B1_mu_scale_f, :B1_logsigma_scale_f, :log_std_replace,
         :C_mu_scale_t, :Bn_mu_scale_t, :replace_mu_scale_t,
         :C_logsigma_scale_t, :Bn_logsigma_scale_t, :replace_logsigma_scale_t,
         :replace_mu_epsilon, :replace_logsigma_epsilon)) do name
    println(string(name), ": ", Gen.get_param(trainable_proposals[latent], name))
end

In [None]:
"""
    create_trainable_gp2Dmu_proposal(latent; max_t=50, check_nt=1, check_nf=2, max_locals=50)

Build a Gen proposal with trainable parameters for a 2-D GP whose mean is a vector with
one entry per row of the `[frequency, time]` matrix `gp_elems[:reshaped]`.

The returned generative function takes one argument, `gp_elems`, which is read at the keys
`:reshaped`, `:t` and `:f`. It makes traced choices at `latent => :mu` (multivariate
normal with diagonal covariance), `latent => :sigma`, `latent => :scale_t`,
`latent => :scale_f` and `latent => :epsilon`, and returns `nothing`.
`gp_elems` is left unmodified. The global `steps["t"]` is read for the time step.

# Keywords
- `max_t`: at most this many (randomly chosen) time points enter the pairwise features
  used for `:scale_t`.
- `check_nt`, `check_nf`: half-widths, in time and frequency samples, of the patches whose
  standard deviation feeds the `:epsilon` estimate.
- `max_locals`: stop collecting patches once this many have been found.
"""
function create_trainable_gp2Dmu_proposal(latent; max_t=50, check_nt=1, check_nf=2, max_locals=50)

    @gen function trainable_gp2Dmu_proposal(gp_elems)

        #Mu and Sigma of GP
        @param B_mu_mu::Matrix{Float64} #size=(f,f) ~ 400
        @param W_logsigma_mu::Matrix{Float64} # size=(f,f+2) ~440
        @param C_logsigma_mu::Vector{Float64} #if Diagonal, size=f; if Matrix, size=(f,f)
        @param log_std_replace::Vector{Float64} #if Diagonal, size=f; if Matrix size(f,f)

        observed = gp_elems[:reshaped]
        mean_spectrum = mean(observed,dims=2)[:,1]
        mean_estimate = B_mu_mu*mean_spectrum
        nt = size(observed)[2] #min value = 1
        delta_t = gp_elems[:t][end]-gp_elems[:t][1] #min value = 0, how far data has travelled
        if nt > 1
            sig = std(observed,dims=2)
            s = cat(sig..., nt, delta_t, dims=1)
            dia = exp.((W_logsigma_mu*s)[:,1] + C_logsigma_mu)
        else
            #std over a single column is undefined, so use the replacement parameter
            dia = exp.(log_std_replace)
        end
        gp_mu = @trace(mvnormal(mean_estimate, Diagonal(dia)), latent => :mu)

        #For the rest of the function, work with the de-mean'd elements
        @param B_mu_sigma::Vector{Float64}
        @param B_logsigma_sigma::Vector{Float64}
        # Subtract the sampled mean from every time column into a local matrix. Writing
        # the result back into `gp_elems` would shift `mean_spectrum` by another `gp_mu`
        # every time the same input is reused (for example an evaluation set that is
        # scored once per epoch).
        demeaned = observed .- gp_mu
        sigma = std(demeaned);
        X = [sigma, length(demeaned), 1]
        mu_sigma = dot(X, B_mu_sigma)
        logsigma_sigma = dot(X, B_logsigma_sigma)
        gp_sigma = @trace(log_normal(mu_sigma, exp(logsigma_sigma)), latent => :sigma)

        #Temporal lengthscale of GP
        @param W_mu_scale_t::Vector{Float64}
        @param Bn_mu_scale_t::Float64
        @param C_mu_scale_t::Float64
        @param replace_mu_scale_t::Float64

        @param W_logsigma_scale_t::Vector{Float64}
        @param Bn_logsigma_scale_t::Float64
        @param C_logsigma_scale_t::Float64
        @param replace_logsigma_scale_t::Float64

        @param mu_cats::Vector{Float64} #for categorizing time differences
        @param logsigma_cats::Vector{Float64}

        mu_scale_t = 0.0; logsigma_scale_t = 0.0; t = gp_elems[:t];
        if length(t) > 1

            #Subsample if GP is too big
            if length(t) <= max_t
                idx = 1:length(t) # 1 2 3 4
            else
                idx = Random.randperm(length(t)) # 4 7 2 1 8 5 3 6 10 9
                idx = idx[1:max_t] # 4 7 2 1
                idx = sort(idx) # 1 2 4 7
            end

            n = sum(1:length(idx)-1) #how many pairs we will have
            delta_ts = Vector{Float64}(undef, n)
            rs = Array{Float64}(undef, 1, n)
            ps = [] #need to do it this way (rather than initialize an array) otherwise the gradients will break

            k = 1
            for i = 1:length(idx)
                for j = (i+1):length(idx)
                    i_idx = idx[i]; j_idx = idx[j]
                    delta_ts[k] = t[j_idx] - t[i_idx]
                    #Differences between 2D and 1D case
                    p1 = demeaned[:,j_idx] #Already de-mean'd
                    p2 = demeaned[:,i_idx] #Already de-mean'd
                    #Correlation of the two frequency columns per unit of elapsed time
                    rs[k] = cor( p1 , p2 )/delta_ts[k]
                    #Soft assignment of the time difference to the three Gaussians
                    lp = [Gen.logpdf(normal, delta_ts[k], mu_cats[g],exp(logsigma_cats[g])) for g = 1:3]
                    p = exp.(lp .- logsumexp(lp))
                    push!(ps, p)
                    k += 1
                end
            end
            #cat(ps) is necessary to be able to compute the derivative given the list
            #You need to use a list rather than editing a pre-made vector in place
            ps = transpose(cat(ps...,dims=2))
            mu_scale_t += (rs*ps*reshape(W_mu_scale_t, 3, 1))[1] + Bn_mu_scale_t*n + C_mu_scale_t #mu_scale(1)=rs(n,)*ps(n,3)*W(3,)
            logsigma_scale_t += (rs*ps*reshape(W_logsigma_scale_t, 3, 1))[1] + Bn_logsigma_scale_t*n + C_logsigma_scale_t

        else

            #A single time point: fall back to the dedicated replacement parameters
            mu_scale_t += replace_mu_scale_t
            logsigma_scale_t += replace_logsigma_scale_t

        end
        @trace(log_normal(mu_scale_t, exp(logsigma_scale_t)), latent => :scale_t)

        #Frequency lengthscale of GP
        @param B_mu_scale_f::Vector{Float64} #len = sum(1:length(f)-1)+2
        @param B1_mu_scale_f::Vector{Float64}
        @param B_logsigma_scale_f::Vector{Float64}
        @param B1_logsigma_scale_f::Vector{Float64}

        f = gp_elems[:f];
        freq_feats = [];
        for i = 1:length(f)
            for j = i+1:length(f)
                if length(t) > 1
                    p1 = demeaned[j,:] #Already de-mean'd
                    p2 = demeaned[i,:] #Already de-mean'd
                    r = cor( p1, p2 )
                else
                    #Only one time point: use the absolute difference scaled by the sampled sigma
                    dy = abs(demeaned[j,1] - demeaned[i,1])
                    r = dy/gp_sigma
                end
                push!(freq_feats, r)
            end
        end
        #rcat is necessary to be able to compute the derivative given the list
        #You need to use a list rather than editing a pre-made vector in place
        rcat = cat(freq_feats..., length(t), 1.0, dims=1)
        #Separate weight vectors for the multi-time-point and single-time-point cases
        Bm = length(t) > 1 ? B_mu_scale_f : B1_mu_scale_f
        Bs = length(t) > 1 ? B_logsigma_scale_f : B1_logsigma_scale_f
        mu_scale_f = dot(Bm,rcat)
        logsigma_scale_f = dot(Bs,rcat)
        @trace(log_normal(mu_scale_f, exp(logsigma_scale_f)), latent => :scale_f)

        ## Epsilon, local noise parameter of GP
        #Check stds between closeby points only to get estimate
        #Could also calculate the deviation of a middle-point
        #from the straight line between its neighbours
        @param B_mu_epsilon::Vector{Float64}
        @param replace_mu_epsilon::Float64
        @param B_logsigma_epsilon::Vector{Float64}
        @param replace_logsigma_epsilon::Float64

        local_stds = []; segment_size_t = 2*check_nt + 1;
        for i = 1:segment_size_t:length(t) #It won't go beyond bounds, even if length(t) < segment_size
            #A patch is used only if every neighbouring time value is present in t
            #(exact membership test on the floating-point values)
            if issubset( [ t[i]+k*steps["t"] for k = -check_nt:check_nt ], t )
                j = rand(1+check_nf:length(f)-check_nf) #random patch centre along frequency
                local_gp = demeaned[j-check_nf:j+check_nf, i-check_nt:i+check_nt][:] #contiguous GP points
                push!(local_stds, std(local_gp))
            end
            if length(local_stds) == max_locals
                break
            end
        end

        mu_epsilon = 0.0; logsigma_epsilon = 0.0;
        if length(local_stds) > 1
            mu_locals = mean(local_stds)
            std_locals = std(local_stds)
            X = [mu_locals, std_locals, length(local_stds), 1.0]

            mu_epsilon += dot(X, B_mu_epsilon)
            logsigma_epsilon += dot(X, B_logsigma_epsilon)
        else
            mu_epsilon += replace_mu_epsilon
            logsigma_epsilon += replace_logsigma_epsilon
        end
        @trace(log_normal(mu_epsilon, exp(logsigma_epsilon)), latent => :epsilon)

        return nothing
    end

    return trainable_gp2Dmu_proposal

end

In [None]:
# Work on an independent copy of the source parameters. A plain assignment would alias
# `source_params`, so the overrides of the 2-D amplitude mean below would also change the
# prior seen by every latent model that was built from `source_params` earlier.
spectrum_source_params = deepcopy(source_params)

# Replace the prior on the 2-D amplitude GP mean by a multivariate normal over the GP
# frequencies, with covariance built from `noise_mu_params`.
noise_mu_params = Dict(:sigma=> 10.0, :scale=> 4.0, :epsilon=> 1.0)
f = get_element_gp_freqs(audio_sr, steps)
c = generate_covariance_matrix(f, noise_mu_params)
spectrum_source_params["gp"]["amp"]["2D"]["mu"]["dist"]=mvnormal
spectrum_source_params["gp"]["amp"]["2D"]["mu"]["args"]=(fill(10.0,length(f)),c)
spectrum_latent_model = make_source_latent_model(spectrum_source_params, audio_sr, steps, scene_duration)

# Register the data generator and the trainable proposal for :spectrum, then draw a fixed
# evaluation set once so that it can be reused across training runs.
evaluation_size = 400
latents[:spectrum] = Dict(:gp => :amp, :source_type => "noise")
data_generators[:spectrum] = make_data_generator(spectrum_latent_model, latents[:spectrum])
trainable_proposals[:spectrum] = create_trainable_gp2Dmu_proposal(:amp);
ic = Dict()
ic[:spectrum] = [data_generators[:spectrum]() for i = 1:evaluation_size];

In [None]:
# Report how many evaluation examples are stored under ic[:spectrum].
let n_eval = length(ic[:spectrum])
    println("Made ", n_eval, " datapoints for eval.")
end
#size(data_generators[:spectrum]()[2][:amp => :mu]) = (9,)

In [None]:
latent = :spectrum

# lf: number of GP frequencies; nf: number of unordered frequency pairs.
lf = length(get_element_gp_freqs(audio_sr, steps));
nf = sum(1:lf-1);

# Table of (parameter names, constructor of the initial value). The constructor is called
# once per parameter so that no two parameters share the same array.
for (names, initial) in (
        # spectrum mean: small diagonal start for the linear maps, zeros for the offsets
        ((:B_mu_mu,), () -> zeros(lf,lf) + Diagonal(fill(1e-5,lf))),
        ((:W_logsigma_mu,),
         () -> zeros(lf, lf+2) + cat(Diagonal(fill(1e-5,lf)),1e-5*ones(lf,2),dims=2)),
        ((:C_logsigma_mu, :log_std_replace), () -> zeros(lf)),
        # local epsilon noise
        ((:B_mu_epsilon, :B_logsigma_epsilon), () -> zeros(4)),
        # sigma of GP kernel & Gaussians for scale_t
        ((:B_mu_sigma, :B_logsigma_sigma, :W_mu_scale_t, :W_logsigma_scale_t),
         () -> zeros(3)),
        ((:mu_cats,), () -> [0.01,0.8,2.0]),
        ((:logsigma_cats,), () -> [-1.0,-1.0,-1.0]),
        # scale_f
        ((:B_mu_scale_f, :B_logsigma_scale_f, :B1_mu_scale_f, :B1_logsigma_scale_f),
         () -> zeros(nf + 2)),
        # scalars
        ((:C_mu_scale_t, :Bn_mu_scale_t, :replace_mu_scale_t,
          :C_logsigma_scale_t, :Bn_logsigma_scale_t, :replace_logsigma_scale_t,
          :replace_mu_epsilon, :replace_logsigma_epsilon), () -> 0.0),
    )
    for name in names
        Gen.init_param!(trainable_proposals[latent], name, initial())
    end
end

# Train against the pre-generated evaluation set ic[latent].
update = Gen.ParamUpdate(GradientDescent(1e-8, 100), trainable_proposals[latent]); #1e-8
scores = pain!(trainable_proposals[latent], data_generators[latent], update;
    num_epoch=1000, epoch_size=5, num_minibatch=1,
    minibatch_size=5, evaluation_size=400, eval_inputs_and_constraints=ic[latent], verbose=true)

# Plot the per-epoch evaluation scores.
plot(scores)
xlabel("Iterations of stochastic gradient descent")
ylabel("Estimate of expected conditional log probability density");
title(string("GP: ", latent))

In [None]:
# Print "name: value" for every parameter of the proposal selected by `latent`.
# Each parameter appears exactly once (the list used to contain :log_std_replace twice,
# which printed it twice).
for p in [:B_mu_mu,:W_logsigma_mu,:C_logsigma_mu,:log_std_replace,
    :B_mu_epsilon,:B_logsigma_epsilon,
    :B_mu_sigma,:B_logsigma_sigma,:W_mu_scale_t,:W_logsigma_scale_t,
        :mu_cats,:logsigma_cats,
    :B_mu_scale_f, :B_logsigma_scale_f, :B1_mu_scale_f, :B1_logsigma_scale_f,
        :C_mu_scale_t,:Bn_mu_scale_t,:replace_mu_scale_t,
        :C_logsigma_scale_t,:Bn_logsigma_scale_t,:replace_logsigma_scale_t,
        :replace_mu_epsilon,:replace_logsigma_epsilon]
    s = string(p)
    println("$s: ", Gen.get_param(trainable_proposals[latent], p))
end

In [None]:
# Heat-maps of hard-coded matrices.
# NOTE(review): these literals look like parameter values pasted from earlier training
# runs (presumably B_mu_mu and W_logsigma_mu) — confirm provenance before relying on them.
# `x` is reassigned before each imshow call, so only the last matrix remains bound to `x`
# after the cell has run.
x =  [9.04086e-6 2.44144e-5 4.27344e-5 -6.39516e-5 2.30337e-5 5.08731e-5 3.33898e-5 9.96364e-7 -4.10042e-5; -1.45731e-6 2.34339e-5 5.1541e-5 -2.76425e-5 3.94316e-5 6.06721e-5 3.30423e-5 -4.3137e-6 -3.0916e-5; -1.72753e-5 2.34302e-5 9.49068e-5 -2.84909e-5 8.27286e-6 9.26835e-7 2.58571e-5 -1.45681e-5 -2.27462e-5; 6.87276e-8 3.88148e-5 7.13896e-5 -4.49074e-5 8.19337e-6 8.52178e-6 4.72579e-5 -1.79114e-5 5.91304e-6; -9.16937e-6 4.75617e-5 5.14232e-5 -3.78198e-5 3.42397e-5 -1.04747e-5 8.92965e-6 -1.74021e-5 3.83545e-5; -2.48914e-5 2.84504e-5 5.00691e-5 -3.7815e-5 3.12402e-5 7.34944e-6 1.09918e-5 -5.00089e-6 3.13616e-5; -2.44663e-5 -1.41264e-5 5.41902e-5 -3.58634e-5 4.40222e-5 3.12245e-5 2.85718e-5 1.25084e-5 -2.69875e-6; -4.60631e-5 -4.45336e-5 3.49122e-5 -3.78671e-5 1.43824e-5 1.97026e-5 7.41281e-6 2.50954e-5 -6.98757e-6; -3.80551e-5 -3.90626e-6 3.39495e-5 -2.45203e-5 7.07236e-6 -1.01003e-6 -1.64635e-5 2.10555e-5 2.16596e-5]
plt.imshow(x)
# Disabled alternative matrix, kept for reference:
# y = [0.00756103 0.00789582 0.00861523 0.00837371 0.00816384 0.00770816 0.00873847 0.00829039 0.00851563 0.0255121 0.00129844; 0.00756103 0.00789582 0.00861523 0.00837371 0.00816384 0.00770816 0.00873847 0.00829039 0.00851563 0.0255121 0.00129844; 0.00756103 0.00789582 0.00861523 0.00837371 0.00816384 0.00770816 0.00873847 0.00829039 0.00851563 0.0255121 0.00129844; 0.00756103 0.00789582 0.00861523 0.00837371 0.00816384 0.00770816 0.00873847 0.00829039 0.00851563 0.0255121 0.00129844; 0.00756103 0.00789582 0.00861523 0.00837371 0.00816384 0.00770816 0.00873847 0.00829039 0.00851563 0.0255121 0.00129844; 0.00756103 0.00789582 0.00861523 0.00837371 0.00816384 0.00770816 0.00873847 0.00829039 0.00851563 0.0255121 0.00129844; 0.00756103 0.00789582 0.00861523 0.00837371 0.00816384 0.00770816 0.00873847 0.00829039 0.00851563 0.0255121 0.00129844; 0.00756103 0.00789582 0.00861523 0.00837371 0.00816384 0.00770816 0.00873847 0.00829039 0.00851563 0.0255121 0.00129844; 0.00756103 0.00789582 0.00861523 0.00837371 0.00816384 0.00770816 0.00873847 0.00829039 0.00851563 0.0255121 0.00129844]
# plt.imshow(y)
x = [-2.69401e-5 -8.06477e-6 -5.79068e-6 -3.43522e-5 2.64049e-5 1.51815e-5 -4.44227e-5 -1.55228e-5 1.84711e-5; -2.07865e-5 8.60812e-6 -5.41509e-6 -7.2196e-6 2.63588e-5 1.44334e-5 -3.13472e-5 -3.41347e-5 8.26784e-6; -2.2381e-5 4.44332e-6 4.57912e-5 2.86229e-5 4.57247e-5 2.26134e-5 -7.31034e-5 -5.7595e-5 2.98627e-5; 1.49586e-6 -3.62218e-6 6.22207e-5 5.23283e-5 9.82693e-5 6.75616e-5 -6.54816e-5 -6.67691e-5 2.78072e-5; -8.30655e-6 -2.1471e-5 1.83352e-5 1.761e-5 0.000114502 5.75503e-5 -9.3877e-5 -9.90538e-5 1.17878e-5; -4.07757e-5 8.22624e-6 2.42239e-5 1.29366e-5 7.45208e-5 5.29256e-5 -9.25707e-5 -7.30924e-5 7.57786e-6; -3.50436e-5 3.9368e-5 5.19086e-5 6.72993e-6 5.94673e-5 6.39709e-5 -8.69786e-5 -4.99406e-5 2.43322e-5; 2.03586e-6 1.90803e-5 2.6395e-5 2.27187e-5 4.6972e-5 3.21086e-5 -6.9864e-5 -8.75867e-6 5.25513e-5; -5.47772e-6 -2.13026e-5 -7.34398e-6 3.37024e-5 6.27379e-5 1.87207e-5 -1.681e-5 -1.02412e-5 2.97232e-5]
plt.imshow(x)
x = [0.00482351 0.00482301 0.00475211 0.00490983 0.00477718 0.00485861 0.00492406 0.004856 0.00498754 0.0178249 0.00069147; 0.00510025 0.00507656 0.00494271 0.00520305 0.00506702 0.00518367 0.00531098 0.00513295 0.00525123 0.0181254 0.000711474; 0.00500333 0.00506287 0.00486013 0.00509223 0.00499615 0.00494119 0.00515581 0.00512848 0.00503768 0.0179779 0.000720386; 0.00498367 0.00489816 0.00475871 0.00496714 0.00479857 0.00480101 0.00491759 0.00493257 0.00489266 0.0179533 0.000708448; 0.0051403 0.0050187 0.00499268 0.00497967 0.00491829 0.00496581 0.00508766 0.00517086 0.00511394 0.0178707 0.000709368; 0.0051091 0.00507743 0.00508256 0.00502558 0.00515143 0.00516184 0.00518949 0.00520754 0.00513222 0.0176809 0.000690717; 0.00516648 0.00510071 0.00510093 0.00502205 0.00509871 0.00508232 0.00514596 0.00514076 0.005121 0.0175336 0.000708668; 0.00519239 0.00511753 0.00509234 0.00503172 0.00501836 0.00511156 0.00514286 0.00523219 0.00531263 0.017629 0.000813126; 0.00528517 0.00522148 0.00517495 0.00510391 0.0050347 0.00506368 0.00511372 0.00529983 0.00532528 0.0179866 0.000790359]
plt.imshow(x)

In [None]:
# Compare values sampled by the data generator with values proposed by the trained
# proposal, over 100 independent draws.
#   rs      : correlation between the sampled and the proposed :mu vector, one per draw
#   results : for each scalar latent, the sampled values (:c) and the proposed values (:t)
latent=:spectrum
rs = []
results = Dict(:scale_t=>Dict(:c=>[], :t=>[]),
                :scale_f=>Dict(:c=>[], :t=>[]),
                :sigma=>Dict(:c=>[], :t=>[]),
                :epsilon=>Dict(:c=>[],:t=>[]))
for _ in 1:100
    # Keep the per-draw variables local to the loop: in a notebook, assigning to an
    # existing global name from inside a top-level loop overwrites that global (the
    # covariance matrix `c` built in an earlier cell was being replaced this way).
    local inputs, constraints, proposed
    inputs, constraints = data_generators[latent]()
    proposed, _, _ = propose(trainable_proposals[latent],(inputs[1],))
    push!(rs,cor(constraints[:amp=>:mu],proposed[:amp=>:mu]))
    for a in keys(results)
        push!(results[a][:c],constraints[:amp=>a])
        push!(results[a][:t],proposed[:amp=>a])
    end
end

In [None]:
# Histogram of the per-draw correlations; the title reports their mean to three digits.
let avg_r = round(mean(rs), digits=3)
    hist(rs)
    title(string("Correlation between sampled mu and predicted, mean: ", avg_r))
end

In [None]:
# Sampled (x) against proposed (y) values of scale_t.
let pairs_t = results[:scale_t]
    scatter(pairs_t[:c], pairs_t[:t])
    xlabel("Sampled scale_t")
    ylabel("Predicted scale_t")
end

In [None]:
# Sampled (x) against proposed (y) values of scale_f.
let pairs_f = results[:scale_f]
    scatter(pairs_f[:c], pairs_f[:t])
    xlabel("Sampled scale_f")
    ylabel("Predicted scale_f")
end

In [None]:
# Sampled (x) against proposed (y) values of epsilon.
let pairs_eps = results[:epsilon]
    scatter(pairs_eps[:c], pairs_eps[:t])
    xlabel("Sampled epsilon")
    ylabel("Predicted epsilon")
end

In [None]:
# Sampled (x) against proposed (y) values of sigma, with the y-axis clipped to [0, 45].
let pairs_sigma = results[:sigma]
    scatter(pairs_sigma[:c], pairs_sigma[:t])
    xlabel("Sampled sigma")
    ylabel("Predicted sigma")
    ylim([0,45])
end

### Define functions with the learned params built in, so that they can be used deterministically inside of involution calls

In [None]:
# Switch on IJulia's verbose mode.
# NOTE(review): presumably left over from debugging the kernel — confirm it is still needed.
IJulia.set_verbose()

In [None]:
# Heat-map of a hard-coded matrix after subtracting the 9x9 identity, so that only the
# deviation from the identity is shown.
# NOTE(review): the literal looks like a parameter value pasted from a training run
# (presumably B_mu_mu) — confirm provenance.
x=[1.0 2.53617e-8 1.66455e-8 4.6655e-8 1.08739e-8 5.86392e-9 4.55852e-8 -2.69102e-8 3.74473e-10; -2.87734e-9 1.0 8.10441e-9 1.6443e-8 3.66857e-9 4.11638e-10 1.52253e-8 -1.11827e-8 -7.00062e-10; -4.06295e-8 -1.05507e-8 1.0 -1.96244e-8 -1.77769e-9 2.61952e-8 -4.8047e-9 4.08587e-8 1.43087e-8; -2.41308e-8 -2.69301e-8 -5.15402e-8 1.0 -9.15042e-9 1.86673e-8 -3.60519e-8 5.43367e-8 1.21603e-8; 9.28191e-9 -1.15712e-9 6.31115e-9 -2.08482e-9 1.0 -5.46661e-9 -4.66172e-9 -4.1514e-9 -2.63898e-9; 6.78281e-8 3.49587e-8 9.70636e-8 6.47599e-8 9.74444e-9 1.0 3.57981e-8 -9.34781e-8 -2.69687e-8; 3.35265e-8 1.0935e-8 4.05314e-8 2.03186e-8 2.32841e-9 -2.19906e-8 1.0 -3.69765e-8 -1.2207e-8; 6.57322e-8 1.02229e-7 1.74281e-7 1.88714e-7 3.62453e-8 -5.51601e-8 1.44639e-7 1.0 -3.82369e-8; 5.12542e-8 6.02737e-8 1.13079e-7 1.11325e-7 2.0641e-8 -4.01075e-8 8.15196e-8 -1.19883e-7 1.0]
plt.imshow(x.-Diagonal(ones(9)))

In [None]:
# Registry for the deterministic proposals; starts empty with untyped keys and values.
deterministic_proposals = Dict{Any,Any}()

"""
    create_deterministic_tp_proposal(latent, B_mu_mu, B_logsigma_mu,
                                     B_mu_alpha, B_logsigma_alpha, log_std_replace)

Build a plain (non-Gen) function that maps observed elements to proposal distributions,
using fixed weight vectors instead of trainable parameters.

# Arguments
- `latent`: key under which the elements are looked up in the input and under which the
  result is returned.
- `B_mu_mu`, `B_logsigma_mu`: length-4 weights giving the location and log-scale of the
  log-normal proposal for `:mu`.
- `B_mu_alpha`, `B_logsigma_alpha`: length-4 weights giving the location and log-scale of
  the log-normal proposal for `:a`.
- `log_std_replace`: log of the value used in place of the standard deviation when only
  one element is available.

# Returns
A function `tp_elems -> Dict(latent => Dict(:mu => spec, :a => spec))`, where each `spec`
is `Dict("dist" => log_normal, "args" => (location, scale))`.
"""
function create_deterministic_tp_proposal(latent, B_mu_mu::Vector{Float64}, B_logsigma_mu::Vector{Float64},
    B_mu_alpha::Vector{Float64}, B_logsigma_alpha::Vector{Float64}, log_std_replace::Float64)

    function deterministic_tp_proposal(tp_elems)

        # Feature vector: mean, standard deviation, element count and a bias term.
        n = length(tp_elems[latent])
        mu = mean(tp_elems[latent])
        # std of a single element is NaN, so fall back to the replacement value.
        sigma = n > 1 ? std(tp_elems[latent]) : exp(log_std_replace)
        X = [mu, sigma, n, 1]

        mu_mu = dot(X, B_mu_mu)
        logsigma_mu = dot(X, B_logsigma_mu)
        mu_alpha = dot(X, B_mu_alpha)
        logsigma_alpha = dot(X, B_logsigma_alpha)

        # Both specs live in ONE inner dictionary. Passing them as two separate
        # arguments to `Dict` (a pair followed by a dictionary) has no matching method
        # and raised a MethodError on every call.
        return Dict(latent => Dict(
                        :mu => Dict("dist"=>log_normal, "args"=>(mu_mu, exp(logsigma_mu))),
                        :a => Dict("dist"=>log_normal, "args"=>(mu_alpha, exp(logsigma_alpha)))
                    ))

    end

    return deterministic_tp_proposal

end

# Build a deterministic proposal for the hyperparameters (:mu, :sigma, :scale, :epsilon) of a
# 1D Gaussian process stored under `latent`.
#
# Arguments
#   latent                      key of the GP values in `gp_elems`; also labels the result
#   B_mu_mu ... B_logsigma_sigma weights dotted with [mean, std, count, 1] of the GP values
#   log_std_replace             log of the std substitute used when there is a single GP value
#   W_*_scale, Bn_*_scale, C_*_scale
#                               weights / per-pair slope / intercept of the lengthscale readout
#   replace_*_scale             lengthscale parameters used when there is a single time point
#   mu_cats, logsigma_cats      means / log-stds of the 3 normal "categories" that time
#                               differences are softly assigned to
#   B_*_epsilon, replace_*_epsilon
#                               weights for the local-noise readout and its fallback values
# Keywords
#   max_t       maximum number of time points used for the pairwise lengthscale features
#   check_n     half-width (in points) of the contiguous windows used for the local stds
#   max_locals  maximum number of local windows collected
#
# Returns a closure `(gp_elems; gp_sigma=nothing) -> Dict(latent => Dict(name => spec, ...))`
# with `spec = Dict("dist" => distribution, "args" => (location, scale))`. Nothing is sampled.
# `gp_elems` must provide `gp_elems[latent]` (GP values) and `gp_elems[:t]` (their times).
# The closure also reads the global `steps["t"]`.
#
# NOTE(review): when `length(t) > max_t` the time points are subsampled with `Random.randperm`,
# so the result is not deterministic in that case; confirm that is acceptable for involutions.
function create_deterministic_gp1D_proposal(latent,B_mu_mu::Vector{Float64},B_logsigma_mu::Vector{Float64},B_mu_sigma::Vector{Float64},
         B_logsigma_sigma::Vector{Float64},log_std_replace::Float64, W_mu_scale::Vector{Float64},
        Bn_mu_scale::Float64,C_mu_scale::Float64,replace_mu_scale::Float64,W_logsigma_scale::Vector{Float64},
        Bn_logsigma_scale::Float64,C_logsigma_scale::Float64,replace_logsigma_scale::Float64,
        mu_cats::Vector{Float64},logsigma_cats::Vector{Float64}, B_mu_epsilon::Vector{Float64},
        replace_mu_epsilon::Float64, B_logsigma_epsilon::Vector{Float64},
        replace_logsigma_epsilon::Float64; max_t=50, check_n=2, max_locals=50)

    function deterministic_gp1D_proposal(gp_elems; gp_sigma=nothing)

        y = gp_elems[latent]
        t = gp_elems[:t]

        #Mu and Sigma of GP
        mu = mean(y)
        sigma = length(y) > 1 ? std(y) : exp(log_std_replace)
        X = [mu, sigma, length(y), 1]

        mu_mu = dot(X, B_mu_mu)
        logsigma_mu = dot(X, B_logsigma_mu)
        mu_sigma = dot(X, B_mu_sigma)
        logsigma_sigma = dot(X, B_logsigma_sigma)

        # The lengthscale features are normalised by the GP sigma. The previous code read an
        # undefined `gp_sigma` here (nothing is sampled in this deterministic version).
        # Callers that already hold a sigma value can pass it via the keyword; otherwise the
        # median of the proposed log-normal, exp(mu_sigma), is used as a plug-in.
        # NOTE(review): confirm the plug-in matches how the weights were trained.
        sigma_ref = gp_sigma === nothing ? exp(mu_sigma) : gp_sigma

        #Temporal lengthscale of GP
        if length(t) > 1

            #Subsample if GP is too big; indices are kept in increasing order
            idx = length(t) <= max_t ? (1:length(t)) : sort(Random.randperm(length(t))[1:max_t])

            #All ordered pairs (earlier index, later index) of the retained points
            pair_idx = [(idx[i], idx[j]) for i = 1:length(idx) for j = (i+1):length(idx)]
            n = length(pair_idx)

            #Per pair: |delta y| / (delta t * sigma)
            rs = [abs(y[b] - y[a]) / ((t[b] - t[a]) * sigma_ref) for (a, b) in pair_idx]

            #Per pair: normalised responsibilities of the 3 time-difference categories
            ps = map(pair_idx) do (a, b)
                lp = [Gen.logpdf(normal, t[b] - t[a], mu_cats[g], exp(logsigma_cats[g])) for g = 1:3]
                exp.(lp .- logsumexp(lp))
            end
            ps_mat = transpose(reduce(hcat, ps)) #(n,3)

            #rs(1,n) * ps(n,3) -> one feature per category, then the linear readout
            cat_features = vec(reshape(rs, 1, :) * ps_mat)
            mu_scale = dot(cat_features, W_mu_scale) + Bn_mu_scale*n + C_mu_scale
            logsigma_scale = dot(cat_features, W_logsigma_scale) + Bn_logsigma_scale*n + C_logsigma_scale

        else

            mu_scale = replace_mu_scale
            logsigma_scale = replace_logsigma_scale

        end

        ## Epsilon, local noise parameter of GP
        #Check stds between closeby points only to get estimate
        #Could also calculate the deviation of a middle-point
        #from the straight line between its neighbours
        local_stds = Float64[]
        segment_size = 2*check_n + 1
        for i = 1:segment_size:length(t)
            #Only use windows whose neighbouring time stamps t[i] + k*steps["t"] are all present in t
            if issubset( [ t[i]+k*steps["t"] for k = -check_n:check_n ], t )
                push!(local_stds, std(y[i-check_n:i+check_n])) #contiguous GP points
            end
            if length(local_stds) == max_locals
                break
            end
        end

        if length(local_stds) > 1
            X_locals = [mean(local_stds), std(local_stds), length(local_stds), 1]
            mu_epsilon = dot(X_locals, B_mu_epsilon)
            logsigma_epsilon = dot(X_locals, B_logsigma_epsilon)
        else
            mu_epsilon = replace_mu_epsilon
            logsigma_epsilon = replace_logsigma_epsilon
        end

        # Keys are the string "args" (the previous code used the undefined identifier `args`).
        return Dict(latent => Dict(:mu => Dict("dist"=>normal, "args"=>(mu_mu, exp(logsigma_mu))),
                            :sigma => Dict("dist"=>log_normal, "args"=>(mu_sigma, exp(logsigma_sigma))),
                            :scale => Dict("dist"=>log_normal, "args"=>(mu_scale, exp(logsigma_scale))),
                            :epsilon => Dict("dist"=>log_normal, "args"=>(mu_epsilon, exp(logsigma_epsilon)))
                ))

    end

    return deterministic_gp1D_proposal

end

# Build a deterministic proposal for the hyperparameters (:mu, :sigma, :scale_t, :scale_f,
# :epsilon) of a 2D (frequency x time) Gaussian process stored under `latent`.
#
# The previous body was an unfinished copy of the trainable version: it used Gen's `@param` /
# `@trace` inside plain Julia functions, overwrote the caller's `gp_elems[:reshaped]`,
# and returned `trainable_gp2D_proposal` instead of the closure defined here. It now mirrors
# the other deterministic factories: learned values are supplied by the caller and the closure
# returns distribution specs without sampling.
#
# Keywords
#   params      REQUIRED. Collection indexable by Symbol (e.g. NamedTuple or Dict{Symbol,Any})
#               holding the learned values formerly declared with `@param`:
#                 vectors: :B_mu_mu, :B_logsigma_mu, :B_mu_sigma, :B_logsigma_sigma,
#                          :W_mu_scale_t, :W_logsigma_scale_t, :mu_cats, :logsigma_cats,
#                          :B_mu_scale_f, :B1_mu_scale_f, :B_logsigma_scale_f,
#                          :B1_logsigma_scale_f, :B_mu_epsilon, :B_logsigma_epsilon
#                 scalars: :Bn_mu_scale_t, :C_mu_scale_t, :replace_mu_scale_t,
#                          :Bn_logsigma_scale_t, :C_logsigma_scale_t, :replace_logsigma_scale_t,
#                          :replace_mu_epsilon, :replace_logsigma_epsilon
#   max_t       maximum number of time points used for the pairwise temporal features
#   check_nt    half-width (in time points) of the local windows used for the noise estimate
#   check_nf    half-width (in frequency points) of those windows
#   max_locals  maximum number of local windows collected
#
# Returns a closure `(gp_elems; gp_sigma=nothing) -> Dict(latent => Dict(name => spec, ...))`
# with `spec = Dict("dist" => distribution, "args" => (location, scale))`.
# `gp_elems` must provide `gp_elems[latent]`, `gp_elems[:t]`, `gp_elems[:f]` and
# `gp_elems[:reshaped]`, a matrix indexed as [frequency index, time index].
# The closure also reads the global `steps["t"]`.
#
# NOTE(review): time subsampling (`Random.randperm`) and the frequency position of each local
# window (`rand`) are still random, so results are not reproducible call-to-call; confirm
# that is acceptable for involutions.
function create_deterministic_gp2D_proposal(latent; params=nothing, max_t=5, check_nt=1, check_nf=2, max_locals=5)

    params === nothing && throw(ArgumentError(
        "create_deterministic_gp2D_proposal requires the learned parameters via the `params` " *
        "keyword (a collection indexable by Symbol, e.g. a NamedTuple)"))

    B_mu_mu = params[:B_mu_mu]
    B_logsigma_mu = params[:B_logsigma_mu]
    B_mu_sigma = params[:B_mu_sigma]
    B_logsigma_sigma = params[:B_logsigma_sigma]

    W_mu_scale_t = params[:W_mu_scale_t]
    Bn_mu_scale_t = params[:Bn_mu_scale_t]
    C_mu_scale_t = params[:C_mu_scale_t]
    replace_mu_scale_t = params[:replace_mu_scale_t]

    W_logsigma_scale_t = params[:W_logsigma_scale_t]
    Bn_logsigma_scale_t = params[:Bn_logsigma_scale_t]
    C_logsigma_scale_t = params[:C_logsigma_scale_t]
    replace_logsigma_scale_t = params[:replace_logsigma_scale_t]

    mu_cats = params[:mu_cats] #for categorizing time differences
    logsigma_cats = params[:logsigma_cats]

    B_mu_scale_f = params[:B_mu_scale_f] #len = sum(1:length(f)-1)+2
    B1_mu_scale_f = params[:B1_mu_scale_f]
    B_logsigma_scale_f = params[:B_logsigma_scale_f]
    B1_logsigma_scale_f = params[:B1_logsigma_scale_f]

    B_mu_epsilon = params[:B_mu_epsilon]
    replace_mu_epsilon = params[:replace_mu_epsilon]
    B_logsigma_epsilon = params[:B_logsigma_epsilon]
    replace_logsigma_epsilon = params[:replace_logsigma_epsilon]

    function deterministic_gp2D_proposal(gp_elems; gp_sigma=nothing)

        y = gp_elems[latent]
        t = gp_elems[:t]
        f = gp_elems[:f]
        # The previous code subtracted a sampled mean from this matrix in place. Every use
        # below (correlation, pairwise difference, std) is unchanged by a constant shift, so
        # the matrix is read as-is and the caller's data is left untouched.
        grid = gp_elems[:reshaped]

        #Mu and Sigma of GP
        mu = mean(y)
        sigma = std(y)
        X = [mu, sigma, length(y), 1]

        mu_mu = dot(X, B_mu_mu)
        logsigma_mu = dot(X, B_logsigma_mu)
        mu_sigma = dot(X, B_mu_sigma)
        logsigma_sigma = dot(X, B_logsigma_sigma)

        # Only needed when there is a single time point (see the frequency features below).
        # Use the caller's value if given, else the median of the proposed log-normal.
        # NOTE(review): confirm the plug-in matches how the weights were trained.
        sigma_ref = gp_sigma === nothing ? exp(mu_sigma) : gp_sigma

        #Temporal lengthscale of GP
        if length(t) > 1

            #Subsample if GP is too big; indices are kept in increasing order
            idx = length(t) <= max_t ? (1:length(t)) : sort(Random.randperm(length(t))[1:max_t])

            #All ordered pairs (earlier index, later index) of the retained time points
            pair_idx = [(idx[i], idx[j]) for i = 1:length(idx) for j = (i+1):length(idx)]
            n = length(pair_idx)

            #Per pair: correlation across frequency between the two time slices
            rs_t = [cor(grid[:, b], grid[:, a]) for (a, b) in pair_idx]

            #Per pair: normalised responsibilities of the 3 time-difference categories
            ps = map(pair_idx) do (a, b)
                lp = [Gen.logpdf(normal, t[b] - t[a], mu_cats[g], exp(logsigma_cats[g])) for g = 1:3]
                exp.(lp .- logsumexp(lp))
            end
            ps_mat = transpose(reduce(hcat, ps)) #(n,3)

            #rs(1,n) * ps(n,3) -> one feature per category, then the linear readout
            cat_features = vec(reshape(rs_t, 1, :) * ps_mat)
            mu_scale_t = dot(cat_features, W_mu_scale_t) + Bn_mu_scale_t*n + C_mu_scale_t
            logsigma_scale_t = dot(cat_features, W_logsigma_scale_t) + Bn_logsigma_scale_t*n + C_logsigma_scale_t

        else

            mu_scale_t = replace_mu_scale_t
            logsigma_scale_t = replace_logsigma_scale_t

        end

        #Frequency lengthscale of GP
        #One feature per pair of frequency channels: correlation across time when there are
        #several time points, otherwise the sigma-normalised absolute difference
        rs_f = map([(i, j) for i = 1:length(f) for j = (i+1):length(f)]) do (i, j)
            if length(t) > 1
                cor(grid[j, :], grid[i, :])
            else
                abs(grid[j, 1] - grid[i, 1]) / sigma_ref
            end
        end
        rcat = vcat(rs_f, length(t), 1.0) #pair features, number of time points, bias
        #Separate weight vectors for the multi-time-point and single-time-point cases
        Bm = length(t) > 1 ? B_mu_scale_f : B1_mu_scale_f
        Bs = length(t) > 1 ? B_logsigma_scale_f : B1_logsigma_scale_f
        mu_scale_f = dot(Bm, rcat)
        logsigma_scale_f = dot(Bs, rcat)

        ## Epsilon, local noise parameter of GP
        #Check stds between closeby points only to get estimate
        #Could also calculate the deviation of a middle-point
        #from the straight line between its neighbours
        local_stds = Float64[]
        segment_size_t = 2*check_nt + 1
        for i = 1:segment_size_t:length(t)
            #Only use windows whose neighbouring time stamps t[i] + k*steps["t"] are all present in t
            if issubset( [ t[i]+k*steps["t"] for k = -check_nt:check_nt ], t )
                j = rand((1+check_nf):(length(f)-check_nf)) #random frequency centre that keeps the window in bounds
                patch = grid[(j-check_nf):(j+check_nf), (i-check_nt):(i+check_nt)] #contiguous GP points
                push!(local_stds, std(vec(patch)))
            end
            if length(local_stds) == max_locals
                break
            end
        end

        if length(local_stds) > 1
            X_locals = [mean(local_stds), std(local_stds), length(local_stds), 1.0]
            mu_epsilon = dot(X_locals, B_mu_epsilon)
            logsigma_epsilon = dot(X_locals, B_logsigma_epsilon)
        else
            mu_epsilon = replace_mu_epsilon
            logsigma_epsilon = replace_logsigma_epsilon
        end

        return Dict(latent => Dict(:mu => Dict("dist"=>normal, "args"=>(mu_mu, exp(logsigma_mu))),
                            :sigma => Dict("dist"=>log_normal, "args"=>(mu_sigma, exp(logsigma_sigma))),
                            :scale_t => Dict("dist"=>log_normal, "args"=>(mu_scale_t, exp(logsigma_scale_t))),
                            :scale_f => Dict("dist"=>log_normal, "args"=>(mu_scale_f, exp(logsigma_scale_f))),
                            :epsilon => Dict("dist"=>log_normal, "args"=>(mu_epsilon, exp(logsigma_epsilon)))
                ))

    end

    return deterministic_gp2D_proposal

end