In [None]:
using GaussianProcesses
ess

In [None]:
# Much of this notebook is based on https://github.com/STOR-i/GaussianProcesses.jl/blob/master/notebooks/Regression.ipynb
using Gen
using Random 
using Optim
using PyPlot
import Plots #For GP "plot"
using Statistics: cor
include("./extra_distributions.jl")
include("./time_helpers.jl")
include("./gaussian_helpers.jl")

In [None]:
# Build a Gen generative function that samples source-level hyperparameters and
# per-element latents (timing and Gaussian-process values) for one sound source.
#
# Arguments:
#   source_params : nested dict of hyperpriors. This function reads the keys
#                   "gp", "tp", "n_elements" and "duration_limit". Each hyperprior
#                   entry is expected to provide "dist" (a Gen distribution) and
#                   "args" (its arguments).
#   audio_sr      : forwarded unchanged to get_gp_spectrotemporal
#                   (presumably the audio sampling rate -- confirm).
#   steps         : dict of step sizes. steps["t"] and steps["min"] are read here;
#                   the whole dict is forwarded to get_gp_spectrotemporal.
#
# Returns:
#   The inner generative function source_latent_model(latents, scene_duration),
#   which closes over the three arguments above.
function make_source_latent_model(source_params, audio_sr, steps)
    
    # Generative function over one source.
    #   latents        : dict selecting what to sample, e.g.
    #                    Dict(:gp => :amp, :source_type => "tone") or Dict(:tp => :wait)
    #   scene_duration : upper bound used for the first onset and for stopping
    #                    element generation.
    # Returns (tp_latents, gp_latents, tp_elems, gp_elems).
    @gen function source_latent_model(latents, scene_duration)

        ##Single function for generating data for amortized inference to propose source-level latents
        # Wait: can generate on its own
        # Dur_minus_min: can generate on its own
        # GPs:
        # Can have a separate amortized inference move for ERB, Amp-1D, Amp-2D 
        # Need to sample wait and dur_minus_min to define the time points for the GP 
        #
        # Format of latents: 
        # latents = Dict(:gp => :amp OR :tp => :wait, :source_type => "tone")

        ### SOURCE-LEVEL LATENTS 
        ## Sample GP source-level latents if needed 
        gp_latents = Dict()
        if :gp in keys(latents)

            gp_params = source_params["gp"]
            gp_type = latents[:gp]; gp_latents[gp_type] = Dict();
            source_type = latents[:source_type]
            
            # Hyperprior table: ERB GP, 2D amplitude GP (noise/harmonic), or 1D amplitude GP (anything else).
            hyperpriors = gp_type == :erb ? gp_params["erb"] : 
                ((source_type == "noise" || source_type == "harmonic") ? gp_params["amp"]["2D"] : gp_params["amp"]["1D"] )
    
            # Each hyperparameter is traced at address gp_type => Symbol(name).
            for latent in keys(hyperpriors)
                hyperprior = hyperpriors[latent]; syml = Symbol(latent)
                gp_latents[gp_type][syml] = @trace(hyperprior["dist"](hyperprior["args"]...), gp_type => syml)
            end
            
        end

        ## Sample temporal source-level latents 
        # A :tp request samples only that temporal latent; a :gp request needs both
        # :wait and :dur_minus_min to place the GP sample points in time.
        tp_latents = Dict()
        if :tp in keys(latents)
            tp_latents[latents[:tp]] = Dict()
        elseif :gp in keys(latents)
            tp_latents[:wait] = Dict()
            tp_latents[:dur_minus_min] = Dict()
        end            
        tp_params = source_params["tp"]
        for tp_type in keys(tp_latents)
            hyperpriors = tp_params[String(tp_type)]
            for latent in keys(hyperpriors)
                hyperprior = hyperpriors[latent]; syml = Symbol(latent)
                tp_latents[tp_type][syml] = @trace(hyperprior["dist"](hyperprior["args"]...), tp_type => syml)
            end
            # Requires the hyperprior table to define both "a" and "mu".
            tp_latents[tp_type][:args] = (tp_latents[tp_type][:a], tp_latents[tp_type][:mu]/tp_latents[tp_type][:a]) #a, b for gamma
        end

        ## Sample a number of elements 
        # "max" -> uniform over 1:val; any other type -> geometric(val).
        ne_params = source_params["n_elements"]
        n_elements = ne_params["type"] == "max" ? 
            @trace(uniform_discrete(1, ne_params["val"]),:n_elements) : 
            @trace(geometric(ne_params["val"]),:n_elements)    

        ### ELEMENT-LEVEL LATENTS 
        #Storage for what inputs are needed
        tp_elems = Dict( [ k => [] for k in keys(tp_latents)]... )
        gp_elems = Dict(); x_elems = [];
        if :gp in keys(latents)       
            
            gp_type = latents[:gp]; source_type = latents[:source_type]
            
            gp_elems[gp_type] = []
            # NOTE(review): this reassignment repeats the one three lines above and looks redundant.
            gp_type = latents[:gp]
            gp_elems[:t] = []
            if gp_type == :amp && (source_type == "noise" || source_type == "harmonic")
                gp_elems[:reshaped] = []
                gp_elems[:f] = []
                gp_elems[:tf] = []
            end 
            
        end
                
        # End time of the most recently generated element.
        time_so_far = 0.0;
        for element_idx = 1:n_elements
            
            # The first wait is uniform over the scene; later waits follow the sampled gamma.
            if :wait in keys(tp_elems)
                wait = element_idx == 1 ? @trace(uniform(0, scene_duration-steps["t"]), (:element,element_idx)=>:wait) : 
                    @trace(gamma(tp_latents[:wait][:args]...), (:element,element_idx)=>:wait)
                push!(tp_elems[:wait], wait)
            end
            
            if :dur_minus_min in keys(tp_elems)
                dur_minus_min = @trace(truncated_gamma(tp_latents[:dur_minus_min][:args]..., source_params["duration_limit"]), (:element,element_idx)=>:dur_minus_min); 
                push!(tp_elems[:dur_minus_min], dur_minus_min)
            end
            
            if :gp in keys(latents)
                
                # NOTE(review): `wait` and `dur_minus_min` are only assigned above when both
                # keys are in tp_elems; if `latents` holds both :tp and :gp, one of them is
                # undefined here -- confirm callers never pass both.
                gp_type = latents[:gp]; source_type = latents[:source_type]
                duration = dur_minus_min + steps["min"]; onset = time_so_far + wait; 

                # Stop before tracing a GP for an element that starts after the scene ends.
                # Its wait and dur_minus_min have already been pushed to tp_elems.
                if onset > scene_duration
                    break
                end

                time_so_far = onset + duration; element_timing = [onset, time_so_far]

                ## Define points at which the GPs should be sampled
                x = []; ts = [];
                if gp_type === :erb || (source_type == "tone" && gp_type === :amp)
                    x = get_element_gp_times(element_timing, steps["t"])
                elseif gp_type === :amp && (source_type == "noise"  || source_type == "harmonic")
                    x, ts, gp_elems[:f] = get_gp_spectrotemporal(element_timing, steps, audio_sr)
                end

                # The first element uses the unconditional mean/covariance; later elements
                # are conditioned on all previously sampled points and values.
                mu, cov = element_idx == 1 ? get_mu_cov(x, gp_latents[gp_type]) : 
                        get_cond_mu_cov(x, x_elems, gp_elems[gp_type], gp_latents[gp_type])
                element_gp = @trace(mvnormal(mu, cov), (:element, element_idx) => gp_type)
                
                ## Save the element data 
                append!(x_elems, x)
                if gp_type === :erb || (source_type == "tone" && gp_type === :amp)
                    append!(gp_elems[:t], x)
                    append!(gp_elems[gp_type], element_gp)
                elseif gp_type === :amp && (source_type == "noise"  || source_type == "harmonic")
                    append!(gp_elems[:t],ts)
                    append!(gp_elems[:tf],x) 
                    append!(gp_elems[gp_type], element_gp)
                    # Reshape the flat sample to (n_frequencies, n_times) and concatenate along time.
                    reshaped_elem = reshape(element_gp, (length(gp_elems[:f]), length(ts))) 
                    if element_idx == 1
                        gp_elems[:reshaped] = reshaped_elem
                    else
                        gp_elems[:reshaped] = cat(gp_elems[:reshaped], reshaped_elem, dims=2)
                    end
                end

                # Stop once the element just generated runs past the end of the scene.
                if time_so_far > scene_duration
                    break
                end
           
            end

        end

        return tp_latents, gp_latents, tp_elems, gp_elems

    end
 
    return source_latent_model
    
end
# Build a zero-argument closure that simulates a batch of traces from
# `source_latent_model` and packages them as training data.
#
# Arguments:
#   source_latent_model : Gen generative function taking (latents, scene_duration)
#                         and returning (tp_latents, gp_latents, tp_elems, gp_elems).
#   latents             : dict selecting the target latent, containing either
#                         :tp => <temporal latent> or :gp => <gp type>.
#   batch_size          : number of traces simulated per call.
#
# Keywords:
#   min_scene_duration, max_scene_duration : range from which one scene duration
#                         is drawn uniformly per batch (rounded to 3 digits).
#
# Returns:
#   batch_data_generator(), which returns (inputs, constraints):
#     inputs      = (per-trace element dicts, scene_duration), using the temporal
#                   element dicts for a :tp request and the GP element dicts otherwise;
#     constraints = Gen choice map holding the source-level hyperparameters of
#                   trace i at addresses (latent_name, i) => parameter_name.
function make_batch_data_generator(source_latent_model, latents, batch_size;
                                   min_scene_duration=0.5, max_scene_duration=2.5)

    min_scene_duration <= max_scene_duration ||
        throw(ArgumentError("min_scene_duration must not exceed max_scene_duration"))

    function batch_data_generator()

        # A single scene duration is shared by every trace in the batch.
        scene_duration = round((max_scene_duration-min_scene_duration)*rand() + min_scene_duration,digits=3)

        traces = []; tpels = []; gpels = [];
        for _ in 1:batch_size
            trace = simulate(source_latent_model, (latents,scene_duration))
            push!(traces, trace)
            _, _, tp_elems, gp_elems = get_retval(trace)
            push!(tpels, tp_elems); push!(gpels, gp_elems)
        end

        constraints = choicemap()
        for (i, trace) in enumerate(traces)

            if haskey(latents, :tp)

                tp_latent = latents[:tp]
                for param in (:mu, :a)
                    constraints[(tp_latent, i) => param] = trace[tp_latent => param]
                end

            elseif haskey(latents, :gp)

                # Read the hyperparameter names from the trace itself rather than from a
                # global `source_params`: the model traces exactly its hyperpriors under
                # `gp_type => name`, so this stays in sync with whatever parameters the
                # model was built with.
                gp_type = latents[:gp]
                hyper_choices = get_submap(get_choices(trace), gp_type)
                for (param, value) in get_values_shallow(hyper_choices)
                    constraints[(gp_type, i) => param] = value
                end

            end

        end
        inputs = haskey(latents, :tp) ? (tpels,scene_duration,) : (gpels,scene_duration,)

        return (inputs, constraints)

    end
    
    return batch_data_generator
    
end;

In [None]:
# Experiment configuration: which latent to study, batch size, and audio rate.
latent = :erb
batch_size = 256
audio_sr = 20000

# Load model parameters and build the generative model from them.
source_params, steps, gtg_params, obs_noise = include("./base_params.jl")
source_latent_model = make_source_latent_model(source_params, audio_sr, steps)

# One batch generator per latent of interest; only :erb is registered here.
latents = Dict{Any,Any}(:erb => Dict(:gp => :erb, :source_type => "tone"))
data_generators = Dict{Any,Any}()
data_generators[latent] = make_batch_data_generator(source_latent_model, latents[latent], batch_size)

In [None]:
# Simulate one batch. d[1] is (per-trace element dicts, scene_duration) and
# d[2] is the choice map of true hyperparameters.
d = data_generators[latent]()

# Inspect the 15th simulated source of the batch.
single_datapoint = first(first(d))[15]
# To read a true hyperparameter: mu_constraint = d[2][(:erb, 1)=>:mu]
scatter(single_datapoint[:t], single_datapoint[:erb])

In [None]:
# Initial GP for the selected datapoint: constant mean, squared-exponential
# kernel, and a fixed starting value for the log observation noise.
mConstant = GaussianProcesses.MeanConst(25.0)
kern = GaussianProcesses.SE(0.0, 0.0)
logObsNoise = -1.0

# The element vectors are untyped, so convert them to Float64 before fitting.
gp = let t = Float64.(single_datapoint[:t]), y = Float64.(single_datapoint[:erb])
    GaussianProcesses.GP(t, y, mConstant, kern, logObsNoise)
end

In [None]:
# Plot the GP with its initial, not-yet-optimized hyperparameters.
Plots.plot(gp)

### Maximum Likelihood Estimates

In [None]:
# Optimize mean, kernel and noise parameters of `gp` in place, within box bounds.
# NOTE(review): the kernel bounds presumably apply to the log-scale SE parameters -- confirm.
optimize!(gp; domean=true, kern=true, noise=true, meanbounds=[[0],[50]], kernbounds = [[-15, -15], [5, 5]])

In [None]:
# Plot the GP again after the optimize! call above.
Plots.plot(gp)

In [None]:
# For every simulated source in the batch, record the true hyperparameters and
# the maximum-likelihood estimates obtained from the best of `n_reps` randomly
# initialized optimizations.
actual = Dict(name => [] for name in (:mu, :sigma, :scale, :epsilon))
predict = Dict(name => [] for name in (:mu, :sigma, :scale, :epsilon))
n_reps = 3
for i in 1:batch_size

    # True hyperparameters of trace i, as stored in the batch constraints.
    for name in (:mu, :sigma, :scale, :epsilon)
        push!(actual[name], d[2][(:erb, i) => name])
    end

    datum = d[1][1][i]
    fits = []; fit_scores = []
    for restart in 1:n_reps

        # Random restart: mean in [0, 50), log kernel parameters and log noise ~ N(0, 1).
        mConstant = GaussianProcesses.MeanConst(Random.rand() * 50)
        kern = GaussianProcesses.SE(first(Random.randn(1)), first(Random.randn(1)))
        logObsNoise = first(Random.randn(1))
        gp = GaussianProcesses.GP(
            Float64.(datum[:t]), Float64.(datum[:erb]), mConstant, kern, logObsNoise
        )

        optimize!(
            gp;
            domean=true,
            kern=true,
            noise=true,
            meanbounds=[[0], [60]],
            kernbounds=[[-15, -15], [5, 5]],
        )

        push!(fit_scores, gp.mll)
        push!(fits, gp)

    end

    # Keep the restart with the highest marginal log-likelihood.
    gp = fits[argmax(fit_scores)]

    # Kernel and noise parameters are stored in log space; exponentiate them so
    # they are compared with the true values on the natural scale.
    kernel_params = GaussianProcesses.get_params(gp.kernel)
    push!(predict[:mu], GaussianProcesses.get_params(gp.mean)[1])
    push!(predict[:scale], exp(kernel_params[1]))
    push!(predict[:sigma], exp(kernel_params[2]))
    push!(predict[:epsilon], exp(GaussianProcesses.get_params(gp.logNoise)[1]))

end

In [None]:
# 2x2 grid of true-vs-MLE scatter plots, one panel per hyperparameter, each
# titled with the correlation between true and estimated values.
for (panel, name) in enumerate([:mu, :sigma, :scale, :epsilon])
    subplot(2, 2, panel)
    scatter(actual[name], predict[name])

    # Axis labels only on the bottom-left panel.
    if panel == 3
        xlabel("Actual")
        ylabel("MLE")
    end

    r = round(cor(actual[name], predict[name]), digits=4)
    title(panel == 1 ? "MLE($n_reps opts)- $name, r=$r" : "$name, r=$r")

    # Use the same padded limits on both axes so the diagonal is the identity line.
    lower = min(minimum(predict[name]), minimum(actual[name])) - 0.5
    upper = max(maximum(predict[name]), maximum(actual[name])) + 0.5
    xlim([lower, upper])
    ylim([lower, upper])
end
tight_layout()

### Elliptical Slice Sampling 

In [None]:
# Display the sampled ERB GP values of the selected datapoint.
single_datapoint[:erb]

In [None]:
# Estimate, from 1000 Monte Carlo draws each, the mean and standard deviation
# of the mean parameter and of the log-transformed sigma/scale/epsilon
# parameters. These moments are used later to build Normal priors.
x = map(_ -> Gen.uniform(0, 40), 1:1000)
mu_mean, mu_std = mean(x), std(x)
println("Mu -- mean: ", mu_mean, " std: ", mu_std)

x = log.(map(_ -> Gen.gamma(3, 1), 1:1000))
log_sigma_mean, log_sigma_std = mean(x), std(x)
println("Sigma -- mean: ", log_sigma_mean, " std: ", log_sigma_std)

x = log.(map(_ -> Gen.gamma(0.5, 1), 1:1000))
log_scale_mean, log_scale_std = mean(x), std(x)
println("scale -- mean: ", log_scale_mean, " std: ", log_scale_std)

x = log.(map(_ -> Gen.gamma(1.0, 0.1), 1:1000))
log_epsilon_mean, log_epsilon_std = mean(x), std(x)
println("epsilon -- mean: ", log_epsilon_mean, " std: ", log_epsilon_std)

# Fresh GP on the selected datapoint, to be sampled with elliptical slice sampling.
mConstant = GaussianProcesses.MeanConst(25.0)
kern = GaussianProcesses.SE(0.0, 0.0)
logObsNoise = -1.0
gp_ess = let t = Float64.(single_datapoint[:t]), y = Float64.(single_datapoint[:erb])
    GaussianProcesses.GP(t, y, mConstant, kern, logObsNoise)
end

In [None]:
# Seeded RNG so the chain is reproducible.
rng = Random.seed!(MersenneTwister(2134), 2134)

# Normal priors built from the moment estimates computed above.
set_priors!(gp_ess.mean, [Distributions.Normal(mu_mean, mu_std)])
set_priors!(
    gp_ess.kernel,
    [
        Distributions.Normal(log_scale_mean, log_scale_std),
        Distributions.Normal(log_sigma_mean, log_sigma_std),
    ],
)
set_priors!(gp_ess.logNoise, [Distributions.Normal(log_epsilon_mean, log_epsilon_std)])

# Run elliptical slice sampling and report the elapsed time.
@time chain = ess(rng, gp_ess; nIter=5000)

In [None]:
# Same seeded setup as the previous run, but using the unqualified `Normal`
# and passing `lik=true` explicitly to the sampler.
rng = Random.seed!(MersenneTwister(2134), 2134)

for (component, priors) in (
    (gp_ess.mean, [Normal(mu_mean, mu_std)]),
    (
        gp_ess.kernel,
        [Normal(log_scale_mean, log_scale_std), Normal(log_sigma_mean, log_sigma_std)],
    ),
    (gp_ess.logNoise, [Normal(log_epsilon_mean, log_epsilon_std)]),
)
    set_priors!(component, priors)
end

@time chain = ess(rng, gp_ess; nIter=5000, lik=true)

In [None]:
# Per-parameter mean of the chain (each row is one parameter), followed by the
# covariance between parameters across samples.
println(mean(chain; dims=2))
m = cov(chain; dims=2)

In [None]:
# Show the first element returned by make_posdef! for the chain covariance.
# NOTE(review): the `!` suggests `m` is modified in place -- confirm before reusing `m`.
GaussianProcesses.make_posdef!(m)[1]

In [None]:
# Scratch check: rebind `rng` to a new seed (2143, not the 2134 used for the
# ESS runs above) and draw one standard-normal sample from it.
rng = MersenneTwister(2143)
randn(rng)

In [None]:
## Trace plot of the elliptical slice sampling chain, one line per parameter.
## To see order of parameters in "chains"
# println("All params: ", GaussianProcesses.get_params(gp))
# println("Mean: ", GaussianProcesses.get_params(gp.mean))
# println("LogScale/LogSigma: ", GaussianProcesses.get_params(gp.kernel))
# println("LogEpsilon: ", GaussianProcesses.get_params(gp.logNoise))
# `chain` holds one parameter per row, so transpose it to plot one line per parameter.
PyPlot.plot(chain')
legend(["Noise", "Mean", "SE log scale", "SE log sigma"])
# Title spelling corrected from "Eliptical".
title("Elliptical Slice Sampling Chain")

In [None]:
# Plot the GP that was sampled with elliptical slice sampling.
Plots.plot(gp_ess)

In [None]:
# Histogram of each parameter's marginal from the ESS chain. Rows of `chain`
# follow the order noise, mean, kernel scale, kernel sigma; every row except
# the mean is exponentiated before plotting.
for (row, name) in enumerate([:epsilon, :mu, :scale, :sigma])
    subplot(2, 2, row)
    samples = chain[row, :]
    hist(name == :mu ? samples : exp.(samples))
    title("Marginal on $name")
end
plt.tight_layout()

### Variational Inference: doesn't work! It also fails to reproduce the results of the reference .ipynb notebook (presumably the one in the GaussianProcesses.jl repository).

In [None]:
# Switch to the 203rd simulated source of the batch (requires batch_size >= 203)
# and plot it on a fixed y-range.
single_datapoint = first(first(d))[203]
scatter(single_datapoint[:t], single_datapoint[:erb])
ylim([-5, 50]);

In [None]:
# GP with an explicit Gaussian likelihood for variational inference, using a
# random initial mean in [0, 50) and N(0, 1) initial log kernel parameters.
mC = GaussianProcesses.MeanConst(first(Random.rand(1)) * 50)
kern = GaussianProcesses.SE(first(Random.randn(1)), first(Random.randn(1)))
l = GaussianProcesses.GaussLik(1.0)
gp_vi = let t = Float64.(vec(single_datapoint[:t])), y = Float64.(vec(single_datapoint[:erb]))
    GaussianProcesses.GP(t, y, mC, kern, l)
end

In [None]:
# Fit the variational approximation Q with nits=10_000_000 and time the call,
# then display the GP object.
# NOTE(review): this is a very large `nits`; expect a long run time.
@time Q = vi(gp_vi;nits=10000000);
gp_vi

In [None]:
# Draw `nsamps` function samples from the variational posterior on a 50-point
# grid over the observed inputs, then summarize them by quantiles.
ymean = []
nsamps = 1000
xtest = collect(range(minimum(gp_vi.x), stop=maximum(gp_vi.x), length=50))
visamples = Array{Float64}(undef, nsamps, size(xtest, 1))

# Running sum of the predictive mean, turned into an average after the loop.
# NOTE(review): predict_y does not use the loop index, so unless
# rand(gp_vi, xtest, Q) mutates gp_vi this could be computed once -- confirm.
m = zeros(size(xtest))
for draw in 1:nsamps
    visamples[draw, :] = rand(gp_vi, xtest, Q)
    m = m + predict_y(gp_vi, xtest)[1]
end
m = m ./ nsamps

# Column j of `visamples` holds every sample at xtest[j].
q10 = [quantile(col, 0.1) for col in eachcol(visamples)]
q50 = [quantile(col, 0.5) for col in eachcol(visamples)]
q90 = [quantile(col, 0.9) for col in eachcol(visamples)];


In [None]:
# Plot the median of the variational-posterior samples with a 10%-90% band,
# overlaid with the data and the averaged predictive mean.
# `ribbon` expects (lower, upper) distances from the plotted line rather than
# absolute bounds, so pass the gaps between the median and the outer quantiles.
Plots.plot(xtest, q50, ribbon=(q50 .- q10, q90 .- q50), leg=true, fmt=:png, label="quantiles")
Plots.scatter!(single_datapoint[:t],single_datapoint[:erb], label="data")
Plots.ylims!((-5,100))
Plots.plot!(xtest, m, label="posterior mean", w=2)