In [13]:
using MCMCChains, Serialization, Plots, StableRNGs
include("misc.jl")
include("convergence_diagnostics.jl");

In [6]:
# Identifier of the sampling run to analyse; it is interpolated into the input/output
# paths of this notebook (e.g. "outputs/300_$(run)_posterior_samples.jls").
# NOTE: this global shadows `Base.run`. If it is left unassigned, "$(run)" silently
# interpolates the Base function and every path becomes "..._run_...". Assign it before
# anything else in the session references `run`, otherwise Julia refuses the assignment.
run = "200"
# Seeded RNG handed to trace_plots (presumably to choose which walkers are drawn), so the
# trace plots are reproducible.
rng = StableRNG(39485)
parameters = return_inferred_parameters();

Trace Plots

In [16]:
# Load the posterior samples; the slicing below treats the array as
# chain[parameter, walker, sample].
chain = deserialize("outputs/300_$(run)_posterior_samples.jls")
ndims, nwalkers, nsamples = size(chain)

# One trace-plot panel per inferred parameter.
p = [trace_plots(rng, chain[i, :, :], nwalkers, nsamples, parameters[i]) for i in 1:ndims]

# Plots rejects a grid whose cell count differs from the number of existing plots passed
# in, so a fixed (4, 2) grid only works for exactly 8 parameters. Use two columns when the
# count is even, otherwise let Plots choose a grid for `ndims` panels.
panel_layout = iseven(ndims) ? (ndims ÷ 2, 2) : ndims
my_plot = plot(p...; layout=panel_layout)
plot!(legendfontsize=4, titlefontsize=4, tickfontsize=4, guidefontsize=4,
      left_margin=2Plots.mm, bottom_margin=2Plots.mm)
savefig("outputs/400_convergence_diagnostics/400_$(run)_trace_plots.png")

"/Users/hollyhuber/Documents/structure_informed_cell_signaling2/gpcr/outputs/400_convergence_diagnostics/400_200_trace_plots.png"

ESS (effective sample size)

In [17]:
# ESS per parameter (computed by the project helper return_ess), cached on disk so the
# plotting cell can reload it without the full sample array.
my_ess = return_ess(chain, ndims, nwalkers, nsamples)
open("outputs/400_$(run)_ess.jls", "w") do io
    serialize(io, my_ess)
end
# Drop the global reference to the samples so the array can be garbage-collected.
chain = nothing

Plot ESS

In [19]:
# Bar chart of the cached per-parameter ESS against a minimum-ESS threshold.
# Stored as `ess_values` rather than `ess`: MCMCChains exports an `ess` function, and
# rebinding that name as a global in Main can clash with the imported binding.
ess_values = deserialize("outputs/400_$(run)_ess.jls")
# Threshold of 100 effective samples per walker. `nwalkers` is the walker count read from
# size(chain) in the trace-plot cell. It is not hard-coded here, because a fixed value
# would make the threshold disagree with the actual ensemble size.
min_ess = nwalkers*100
p1 = bar(1:ndims, ess_values, color=:deeppink4, label="ess")
plot!([0, ndims + 1], [min_ess, min_ess], color=:black, linewidth=4, label="minimum ess")
xticks!(1:ndims, parameters)
plot!(dpi=300, size=(10*100, 2*100))
savefig("outputs/400_convergence_diagnostics/400_$(run)_ess_plot.png")

"/Users/hollyhuber/Documents/structure_informed_cell_signaling2/gpcr/outputs/400_convergence_diagnostics/400_200_ess_plot.png"

# Reload the posterior and reorder it for MCMCChains, whose Chains constructor takes an
# (iterations × parameters × chains) array.
chain = deserialize("outputs/300_$(run)_posterior_samples.jls")
ndims, nwalkers, nsamples = size(chain)
parameters = return_inferred_parameters()
# chain is indexed as (parameter, walker, sample); permute it into a Float64 buffer laid
# out as (sample, parameter, walker), i.e. reshaped_chain[s, i, j] == chain[i, j, s].
reshaped_chain = zeros(nsamples, ndims, nwalkers)
permutedims!(reshaped_chain, chain, (3, 1, 2))

Trace Plots

# Trace plots of a few individual walkers for every parameter, one panel per parameter.
# Walker indices larger than the number of walkers in this run are skipped, so a small
# ensemble does not raise a BoundsError.
chains_to_plot = filter(<=(nwalkers), [1, 33, 50, 75, 98])
p = map(1:ndims) do parameter_to_plot
    panel = plot(xlabel="iteration", ylabel="$(parameters[parameter_to_plot])")
    for c in chains_to_plot
        plot!(panel, 1:nsamples, reshaped_chain[:, parameter_to_plot, c],
              label="chain $(c)")
    end
    panel
end
# Two columns when the parameter count is even; otherwise let Plots pick a grid.
# The layout must contain exactly one cell per panel.
plot(p...; layout=iseven(ndims) ? (ndims ÷ 2, 2) : ndims)

# Trace plot of the first parameter for a handful of individual walkers.
parameter_to_plot = 1
chains_to_plot = [1,33, 50, 75, 98]
first_walker = first(chains_to_plot)
plot(1:nsamples, reshaped_chain[:, parameter_to_plot, first_walker],
     label="chain $(first_walker)")
# Overlay the remaining walkers on the same axes.
for walker in chains_to_plot[2:end]
    plot!(1:nsamples, reshaped_chain[:, parameter_to_plot, walker], label="chain $(walker)")
end
xlabel!("iteration")
ylabel!("$(parameters[parameter_to_plot])")

Calculate Convergence Diagnostics (ESS and R-hat)

# Wrap the reordered samples in an MCMCChains object and compute per-parameter ESS and R-hat.
mcmchains = MCMCChains.Chains(reshaped_chain, parameters)
diagnostics = MCMCChains.ess_rhat(mcmchains)
# Convergence thresholds: R-hat ceiling (not used in this cell) and a minimum ESS of
# 100 effective samples per walker.
max_rhat = 1.01
min_ess = 100 * nwalkers;

# Bar chart of the MCMCChains ESS estimate for each parameter against the minimum-ESS
# threshold. Tick positions use `ndims` rather than a hard-coded 8, so the labels line up
# with however many parameters were inferred.
bar(1:ndims, diagnostics[:, :ess], color=:deeppink4, label="ess")
plot!([0, ndims + 1], [min_ess, min_ess], color=:black, linewidth=4, label="minimum ess")
xticks!(1:ndims, parameters)