In [None]:
using Revise
using Serialization
using MCMCChains
using Comrade
using Pyehtim
using Plots
using Pigeons
using CairoMakie
using PairPlots
using ColorSchemes
using FileIO
using Krang
using CairoMakie
using Distributions

In [None]:
# Shared colour palette for the figures in this notebook (`_cb` presumably stands for
# "colour-blind safe" — confirm). Suffix `_t` is the same colour at alpha 0.5 and `_t1`
# at alpha 0.25.
red_cb = colorant"rgba(84%, 11%, 38%, 1.0)";
blue_cb = colorant"rgba(12%, 53%, 89%, 1.0)";
orange_cb = colorant"rgba(100%, 75%, 3%, 1.0)";
blue_cb_t = colorant"rgba(12%, 53%, 89%, 0.5)";
orange_cb_t = colorant"rgba(100%, 75%, 3%, 0.5)";
blue_cb_t1 = colorant"rgba(12%, 53%, 89%, 0.25)";
orange_cb_t1 = colorant"rgba(100%, 75%, 3%, 0.25)";
white_cb_t = colorant"rgba(100%, 100%, 100%, 0.5)";
green_cb = colorant"rgba(0%, 30%, 25%, 1.0)"

# Notebook-wide Makie theme:
#   * every Axis hides its y ticks and y tick labels and uses unrotated y labels,
#   * legend line elements are drawn with linewidth 10,
#   * Density and Heatmap plots are created with `rasterize=true`.
theme_curr = Theme(
    Axis=(
        yticklabelsvisible=false,
        yticksvisible=false,
        titlesize=30,
        ylabelsize=20,
        ylabelrotation=0,
    ),
    LineElement=(
        linewidth=10,
    ),
    Density=(rasterize=true,),
    Heatmap=(rasterize=true,)
)
# Install the theme globally, combined with Makie's LaTeX font theme.
set_theme!(merge(theme_curr, theme_latexfonts()));

In [None]:
"""
    samples_to_plot(tsamples, prior_keys)

Return a new `Chains`, labelled with `prior_keys`, holding the `prior_keys` columns of
`tsamples` with the angular parameters `:pa, :ι, :χ, :η, :θo, :θs` rescaled by `180 / π`
(radians to degrees).

`:pa` is additionally wrapped into `[0, 360)`. Angular parameters that do not occur in
`prior_keys` are skipped.
"""
function samples_to_plot(tsamples, prior_keys)
    values_matrix = MCMCChains.to_matrix(tsamples[prior_keys])

    # Transform variables from radians to degrees
    for sym in (:pa, :ι, :χ, :η, :θo, :θs)
        col = findfirst(==(sym), prior_keys)
        # A parameter missing from `prior_keys` has no column; skip it instead of
        # failing on a `nothing` column index.
        isnothing(col) && continue

        column = view(values_matrix, :, col)
        if sym == :pa
            # `mod` (unlike `%`) always lands in [0, 360), whatever the sign of the input.
            column .= mod.(column .* 180 ./ π, 360)
        else
            column .= column .* 180 ./ π
        end
    end
    return Chains(values_matrix, prior_keys)
end

"""
    get_best_fit(path)

Read a best-fit parameter set from the text file at `path`.

The first three lines are skipped. The fourth line, from its 12th character onwards, is
parsed as a Julia expression and evaluated. The result is returned as a `NamedTuple` with
an extra trailing field `pa` fixed at `(288 - 360) / 180 * π`.
"""
function get_best_fit(path)
    # The `do` form closes the file even when reading throws.
    param_line = open(path) do io
        for _ in 1:3
            readline(io)
        end
        readline(io)
    end

    # NOTE(review): this evaluates text read from disk as Julia code; only point it at
    # trusted files.
    best_fit = eval(Meta.parse(param_line[12:end]))

    # `NamedTuple{names}` takes ONE tuple of values; splatting the values as separate
    # positional arguments is a MethodError for any non-empty `best_fit`.
    return NamedTuple{(keys(best_fit)..., :pa)}((values(best_fit)..., (288 - 360) / 180 * π))
end

"""
    get_nxcorr_vals(best_fit)

Return a `NamedTuple` built from all but the last two entries of `best_fit`.

Entries named `:θo`, `:θs`, `:χ`, `:ι` or `:η` are multiplied by `180 / π`; every other
value is passed through unchanged. A key named `:σ` is renamed to `:spec`.
"""
function get_nxcorr_vals(best_fit)
    angle_params = (:θo, :θs, :χ, :ι, :η)
    kept = keys(best_fit)[begin:end-2]

    converted = map(kept) do name
        value = best_fit[name]
        name in angle_params ? value * 180 / π : value
    end

    renamed = replace(kept, :σ => :spec)
    return NamedTuple{renamed}(converted)
end

"""
    convert_inimg(inimg)

Convert the Python image object `inimg` into a square Julia `Matrix{Float64}`.

The image is regridded via `inimg.regrid_image(μas2rad(100), 200)` (presumably field of
view and pixel count — confirm against the ehtim API), its `imvec` is reshaped into a
square matrix, and both axes are reversed.
"""
function convert_inimg(inimg)
    regridded = inimg.regrid_image(μas2rad(100), 200)
    pixels = pyconvert(Vector{Float64}, regridded.imvec)

    side = Int(sqrt(length(pixels)))
    img = reshape(pixels, (side, side))

    # Reverse rows and columns in one indexing step.
    return img[end:-1:1, end:-1:1]
end


In [None]:
# MAD
# Loads the two MAD snapshot runs and produces, for each one, the input image
# (`mad_inimg_*`) and the degree-converted posterior chain (`mad_samples_to_plot_*`).
rnd = 17

model_name = "ma+0.5_r20_GRMHD_snapshot"

# ---- Snapshot 1 ----
in_base = abspath(dirname(@__DIR__), "..", "visibilityDomain","results","all","Mad_Snapshot")
# NOTE(review): `prior`, `cpost`, `inimg`, `model` and `n_tempering_levels` are not defined
# anywhere in this notebook; presumably model_params.jl provides them — confirm.
include(joinpath(in_base, "model_params.jl"))

prior_keys = collect(keys(prior));

outpath = abspath(dirname(@__DIR__), ".." , "..","..","..","runs","visibilityDomain","data","$(model)")
mad_inimg_1 = convert_inimg(inimg)

reduced_recorders = Serialization.deserialize(joinpath(in_base, "round=$rnd", "checkpoint", "reduced_recorders.jls"))
# One trace vector per scan, stacked as rows. `reduce(hcat, ...)` avoids splatting 2^rnd
# arrays into a single varargs call.
s_array = permutedims(reduce(hcat, [reduced_recorders.traces[n_tempering_levels=>i] for i in 1:2^rnd]))

# The last column is dropped (presumably the log density stored with each trace — confirm).
samples = Chains(s_array[:,begin:end-1], collect(prior_keys))
# Push every draw through `transform(cpost, ·)` and rebuild a chain of the same shape.
sample_values = samples.value
transformed_rows = [collect(values(transform(cpost, sample_values[i, :]))) for i in 1:size(sample_values, 1)]
tsamples = Chains(reshape(permutedims(reduce(hcat, transformed_rows)), size(sample_values)), collect(prior_keys))
# NOTE(review): the HPD table is computed but not bound or shown (not the last expression).
MCMCChains.hpd(tsamples)
mad_samples_to_plot_1 = samples_to_plot(tsamples, prior_keys)

# ---- Snapshot 2 ----
in_base = abspath(dirname(@__DIR__), "..", "visibilityDomain","results","all","Mad_Snapshot_2")
include(joinpath(in_base, "model_params.jl"))
outpath = abspath(dirname(@__DIR__), ".." , "..","..","..","runs","visibilityDomain","data","$(model)")
mad_inimg_2 = convert_inimg(inimg)

reduced_recorders = Serialization.deserialize(joinpath(in_base, "round=$rnd", "checkpoint", "reduced_recorders.jls"))
s_array = permutedims(reduce(hcat, [reduced_recorders.traces[n_tempering_levels=>i] for i in 1:2^rnd]))

samples = Chains(s_array[:,begin:end-1], collect(prior_keys))
sample_values = samples.value
transformed_rows = [collect(values(transform(cpost, sample_values[i, :]))) for i in 1:size(sample_values, 1)]
tsamples = Chains(reshape(permutedims(reduce(hcat, transformed_rows)), size(sample_values)), collect(prior_keys))
MCMCChains.hpd(tsamples)

mad_samples_to_plot_2 = samples_to_plot(tsamples, prior_keys)



In [None]:
# SANE
# Loads the two SANE snapshot runs and produces, for each one, the input image
# (`sane_inimg_*`) and the degree-converted posterior chain (`sane_samples_to_plot_*`).
rnd = 14

# ---- Snapshot 1 ----
in_base = abspath(dirname(@__DIR__), "..", "visibilityDomain","results","all","Sane_Snapshot")
# NOTE(review): `prior`, `cpost`, `inimg`, `model` and `n_tempering_levels` are not defined
# anywhere in this notebook; presumably model_params.jl provides them — confirm.
include(joinpath(in_base, "model_params.jl"))

# Refresh the parameter names from the prior that is current after the include above, so
# this cell does not rely on a `prior_keys` left behind by another cell.
prior_keys = collect(keys(prior));

# NOTE(review): this path climbs three directories while the other `outpath` assignments in
# this notebook climb four — confirm which depth is intended.
outpath = abspath(dirname(@__DIR__), "..","..","..","runs","visibilityDomain","data","$(model)")
sane_inimg_1 = convert_inimg(inimg)

reduced_recorders = Serialization.deserialize(joinpath(in_base, "round=$rnd", "checkpoint", "reduced_recorders.jls"))
# One trace vector per scan, stacked as rows. `reduce(hcat, ...)` avoids splatting 2^rnd
# arrays into a single varargs call.
s_array = permutedims(reduce(hcat, [reduced_recorders.traces[n_tempering_levels=>i] for i in 1:2^rnd]))

# The last column is dropped (presumably the log density stored with each trace — confirm).
samples = Chains(s_array[:,begin:end-1], collect(prior_keys))
# Push every draw through `transform(cpost, ·)` and rebuild a chain of the same shape.
sample_values = samples.value
transformed_rows = [collect(values(transform(cpost, sample_values[i, :]))) for i in 1:size(sample_values, 1)]
tsamples = Chains(reshape(permutedims(reduce(hcat, transformed_rows)), size(sample_values)), collect(prior_keys))
# NOTE(review): the HPD table is computed but not bound or shown (not the last expression).
MCMCChains.hpd(tsamples)

sane_samples_to_plot_1 = samples_to_plot(tsamples, prior_keys)

# ---- Snapshot 2 ----
in_base = abspath(dirname(@__DIR__), "..", "visibilityDomain","results","all","Sane_Snapshot_2")
include(joinpath(in_base, "model_params.jl"))
outpath = abspath(dirname(@__DIR__), ".." , "..","..","..","runs","visibilityDomain","data","$(model)")
sane_inimg_2 = convert_inimg(inimg)


reduced_recorders = Serialization.deserialize(joinpath(in_base, "round=$rnd", "checkpoint", "reduced_recorders.jls"))
s_array = permutedims(reduce(hcat, [reduced_recorders.traces[n_tempering_levels=>i] for i in 1:2^rnd]))

samples = Chains(s_array[:,begin:end-1], collect(prior_keys))
sample_values = samples.value
transformed_rows = [collect(values(transform(cpost, sample_values[i, :]))) for i in 1:size(sample_values, 1)]
tsamples = Chains(reshape(permutedims(reduce(hcat, transformed_rows)), size(sample_values)), collect(prior_keys))
MCMCChains.hpd(tsamples)

sane_samples_to_plot_2 = samples_to_plot(tsamples, prior_keys)


In [None]:
# Posterior comparison figure: one row per parameter (prior indices 1, 2, 3, 5), one column
# per model (SANE, MAD). Each panel overlays the marginal densities of the two snapshots
# and marks the best-NxCORR value (black) and the true value (red).
fig = Figure(resolution=(500,600));
titles = (m_d=L"\theta_g", spin=L"\text{a}", θo=L"\theta_o", θs=L"\theta_s", pa=L"\text{p.a.}", rpeak=L"r_{\rm peak}", p1=L"p_1", p2=L"p_2", χ=L"\chi", ι=L"\iota", βv=L"\beta_v", spec=L"\sigma", η=L"\eta")

# The best-fit file depends only on the model, so read each one once rather than once per
# panel.
best_fits = map((sane="sa+0.94_r160_nall_tavg.fits", mad="ma+0.5_r20_nall_tavg.fits")) do fits_name
    get_best_fit(abspath(dirname(@__DIR__), "..", "..","runs","image_domain", fits_name, "JBOX", "best_nxcorr.txt"))
end

for (i,ty) in enumerate([1,2,3,5])
    param = prior_keys[ty]
    # Prior entries 3 to 5 are rescaled by 180/π; entry 5 is additionally shifted by 360.
    deg_scale = 5 >= ty >= 3 ? 180/π : 1
    pa_offset = ty==5 ? 360 : 0
    # NOTE(review): the 360 shift also enters the bandwidth of entry 5 — confirm intended.
    bandwidth = 0.02*(deg_scale*abs(prior[ty].b-prior[ty].a) + pa_offset)

    for (j,mdl) in enumerate([:sane, :mad])
        best_fit = best_fits[mdl]
        true_vals = (m_d=3.83, spin=mdl == :sane ? -0.94 : -0.5, θo=17, pa=360 - 72)

        snapshot_chains = mdl == :sane ? (sane_samples_to_plot_1, sane_samples_to_plot_2) : (mad_samples_to_plot_1, mad_samples_to_plot_2)

        ax = j == 1 ? Axis(fig[i, j], ylabel=titles[ty]) : Axis(fig[i, j])
        # Snapshot 1 in blue, snapshot 2 in orange.
        for (chain, color) in zip(snapshot_chains, (blue_cb_t, orange_cb_t))
            CairoMakie.density!(ax, chain[:,ty,1], color=color, orientation=:horizontal, bandwidth=bandwidth)
        end

        # Explicit key checks instead of a catch-all `try`, so real plotting errors are no
        # longer swallowed and a missing best-fit entry does not suppress the truth marker.
        if haskey(best_fit, param)
            CairoMakie.vlines!(ax, deg_scale*best_fit[param] + pa_offset, color=:black, linewidth=3.5)
        end
        if haskey(true_vals, param)
            CairoMakie.vlines!(ax, true_vals[param], color=red_cb)
        else
            println("no true value for $(param)")
        end
    end
end
# Invisible overlay axes that only carry the column titles.
hidedecorations!(Axis(fig[1:4,1], title="SANE", topspinevisible=false, bottomspinevisible=false, leftspinevisible=false, rightspinevisible=false))
hidedecorations!(Axis(fig[1:4,2], title="MAD", topspinevisible=false, bottomspinevisible=false, leftspinevisible=false, rightspinevisible=false))
CairoMakie.Legend(
    fig[5,1:2], 
    labelsize=20,
    [
        [CairoMakie.PolyElement(color=blue_cb, strokewidth=1), CairoMakie.PolyElement(color=orange_cb, strokewidth=1)],
        [CairoMakie.LineElement(color=red_cb, strokewidth=1), CairoMakie.LineElement(color=:black, strokewidth=1)]
    ], 
    [
        [L"\text{Snapshot 1}", L"\text{Snapshot 2}"],
        [L"\text{Truth}", L"\text{Best NxCORR}"],
    ],
    [nothing, nothing],
    tellheight=false,
    tellwidth=false,
    margin=(10, 10, 10, 10),
    halign=:center, valign=:bottom, orientation=:horizontal, nbanks=2
)
save(joinpath((@__DIR__), "snapshot_comparison.pdf"), fig)
fig


In [None]:
"""
    snapshot_panel!(fig, row, col, img, snapshot_label, model_label)

Add one image panel to `fig[row, col]` and return its axis.

The axis has a reversed x direction, hidden ticks and tick labels, and aspect ratio 1.
`img` is drawn as a heatmap with the `afmhot` colour scheme; `snapshot_label` is written
at (195, 195) and `model_label` at (195, 40), both in white.
"""
function snapshot_panel!(fig, row, col, img, snapshot_label, model_label)
    ax = Axis(fig[row, col], xreversed=true,
        xticksvisible=false,
        yticksvisible=false,
        xticklabelsvisible=false,
        yticklabelsvisible=false,
        aspect=1
    )
    CairoMakie.heatmap!(ax, img, colormap=ColorSchemes.afmhot, show_axis=false)
    CairoMakie.text!(ax, 195, 195, text=snapshot_label, align=(:left,:top), justification=:left, color=:white, fontsize=30)
    CairoMakie.text!(ax, 195, 40, text=model_label, align=(:left,:top), justification=:left, color=:white, fontsize=30)
    return ax
end

fig = Figure(resolution=(500,500))

# (row, column, image, snapshot label, model label): MAD on the top row, SANE below.
panels = (
    (1, 1, mad_inimg_1, L"\text{Snapshot 1}", L"\text{MAD}"),
    (1, 2, mad_inimg_2, L"\text{Snapshot 2}", L"\text{MAD}"),
    (2, 1, sane_inimg_1, L"\text{Snapshot 1}", L"\text{SANE}"),
    (2, 2, sane_inimg_2, L"\text{Snapshot 2}", L"\text{SANE}"),
)
for (row, col, img, snapshot_label, model_label) in panels
    snapshot_panel!(fig, row, col, img, snapshot_label, model_label)
end

save(joinpath((@__DIR__), "snapshot_images.pdf"), fig)
fig