# Analyze the output of Bayesian Fitting
To run this notebook for analyzing an already performed MCMC sampling of the posterior: <br>
1. Multiple MCMC runs for a single chain should be placed in consecutive folders titled "Run1", "Run2", etc. <br> If only one run was conducted, it should be placed in a folder titled "Run1".
2. Update the idbeg and idend values in "Load libraries and data" to match the Run folders containing your simulation <br>
3. Update common.jl's datamat to match your mcmc-chain's simulated time span if it doesn't match already <br>
4. Set the options under "Declare Notebook Options"
## Load libraries and data

In [None]:
using CSV, DataFrames, Plots, Measures, Statistics
# Use the PyPlot backend for every figure in this notebook
pyplot()

# Load MCMC library (project sources, in dependency order as originally listed)
include("common.jl")
include("odesolver.jl")
include("gibbs.jl")

# Load data: merge the run folders numbered idbeg through idend
idbeg = 1
idend = 1
M_mcmc, M_err = mergeoutputs(idbeg, idend)

# One column of M_mcmc per retained sample
nsmp = size(M_mcmc, 2)
println("$nsmp samples taken ...");

## Declare Notebook Options

In [None]:
# Whether the fitted model included a change point (Δpt) analysis
flagΔpt = true

# Final calendar day the data was fit to
#  (the notebook assumes data() outputs tspan[2] for gibbs trajectories)
dayfit = Date(2021, 9, 2)

# Stride between samples shown in the trace plots (1 shows every sample)
nstep = 1

# Width of the middle quantile range for the MCMC posteriors:
# qthresh without σ-errors, σthresh with σ-errors
qthresh = 0.95
σthresh = 0.95;

### Prepare ODH data

In [None]:
# Read the ODH table from the CSV path configured on `sheet`
fname = sheet.csv_odh
dfodh = CSV.read(fname, DataFrame)

# Age-group column labels, youngest to oldest
cols = [
    "0-9", "10-19", "20-29",
    "30-39", "40-49", "50-59",
    "60-69", "70-79", "80+",
]

# Calendar anchors for the trajectories: day0 is the reference date and
# sheet.tspan holds the first/last simulated day as offsets from it
dates = dfodh.time
day0 = Date(2020, 1, 1)
dayi = day0 + Day(sheet.tspan[1])
daym = day0 + Day(sheet.tspan[2]);

In [None]:
# Last day present in the ODH table, and the later of that day and the last
# simulated day
dayodh = dates[end];
dayf = max(dayodh, daym);

# Day offset of a date from day0, using the public Dates.value accessor
# rather than reading the Day struct's internal field
daynum(d) = Dates.value(d - day0)

ti, tf, tm = daynum(dayi), daynum(dayf), daynum(daym);
tfit, todh = daynum(dayfit), daynum(dayodh);

# Reported daily infections (in thousands) over the whole record.
# NOTE(review): the final two rows are excluded, presumably because they are
# incomplete -- confirm against the ODH sheet.
pnow = plot(daynum.(dfodh[1:end-2, :time]), dfodh[1:end-2, :daily_confirm] / 1000;
            labels="", linewidth=1, xticks=LinRange(0, 608, 5),
            guidefont=14, titlefont=16, tickfont=12,
            ylabel="reported infections (thousands)", xlabel="date",
            size=(650, 375), margin=5mm);

# Mark the first trajectory day and the last fitted day
vline!(daynum.([dayi, dayfit]); labels="", linewidth=1);

# Render the day offsets on the x axis as calendar dates
plot!(xformatter = (x -> Dates.format(Day(x) + day0, "mm/dd")))
savefig("odh.pdf");

# Analyze MCMC trajectories

In [None]:
# Compute the error bands which include the independent-increment error

# Posterior samples of the error std dev (row 177 of M_mcmc) and, when a change
# point was fit, of the post-change std dev (row 187), with summary statistics
σs = M_mcmc[177,:];
avgσ = mean(σs);
confσ = quantile(σs,σthresh);
Δσs = flagΔpt ? M_mcmc[187,:] : [NaN];
confΔσ = flagΔpt ? quantile(Δσs,σthresh) : NaN;
days = dayi-day0;

σMᵗʳ_high = deepcopy(Mᵗʳ);
σMᵗʳ_low = deepcopy(Mᵗʳ);
ntrajs = size(Mᵗʳ)[1];

# Per-sample index of the last time step that still uses the pre-change σ.
# The change point (row 182) is shifted by the day offset of dayi and clamped to
# one less than the number of rows of quantlow; Inf means "never switch".
ptmax = size(quantlow,1) - 1;
ptids = flagΔpt ? min.(floor.(M_mcmc[182,1:ntrajs] .- days.value), ptmax) : fill(Inf,ntrajs);

"""
    addσbands!(high, low, σ, Δσ, ptids, nage, ntime)

Widen every trajectory in place by two accumulated standard deviations.

`high` and `low` hold one trajectory per row, laid out as `nage` consecutive
groups of `ntime` columns. Within each group the variance starts at zero and
grows by `σ[smp]^2` per step while the step index is `<= ptids[smp]`, and by
`Δσ[smp]^2` afterwards. `high` is raised by `2√var`; `low` is lowered by the
same amount and floored at zero. `Δσ` is only read for steps past `ptids[smp]`.
"""
function addσbands!(high, low, σ, Δσ, ptids, nage, ntime)
    for smp in axes(high, 1)
        ptid = ptids[smp]
        for aid in 1:nage
            varnow = 0.0    # variance restarts for every age group
            for tid in 1:ntime
                col = (aid-1)*ntime + tid
                sd = (tid <= ptid) ? σ[smp] : Δσ[smp]
                varnow += sd^2
                halfwidth = 2*√(varnow)
                high[smp,col] += halfwidth
                low[smp,col] = max(low[smp,col] - halfwidth, 0.)
            end
        end
    end
    return nothing
end

# Loop over all mcmc smp's, age groups, days and include their own errors
addσbands!(σMᵗʳ_high, σMᵗʳ_low, σs, Δσs, ptids, 9, dwnsmp);

# Compute the σthresh quantiles for these (σthresh is the option declared for
# bands that include σ-errors; qthresh is for the bands without them)
qlo = (1-σthresh)/2;
σMquants = Matrix{Float64}(undef,9*dwnsmp,2);
for i in 1:9*dwnsmp
    σMquants[i,1] = quantile(view(σMᵗʳ_low,:,i), qlo)
    σMquants[i,2] = quantile(view(σMᵗʳ_high,:,i), 1-qlo)
end
σquantlow = reshape(σMquants[:,1],dwnsmp,9); σquanthigh = reshape(σMquants[:,2],dwnsmp,9);
myσcolor = :blue; # Color for σMCMC bands

In [None]:
# Plot the data: reported vs. predicted daily infections, one panel per age group

# Shared upper y-limit: largest reported value in the plotted window plus headroom.
# Matrix(df) is the supported conversion; convert(Matrix, df) is no longer
# available in current DataFrames releases.
ymax = maximum(Matrix(dfodh[ti:todh,2:end-2])) + 1000;

# Age-group columns in panel order
agegroups = ["0-9","10-19","20-29","30-39","40-49","50-59","60-69","70-79","80+"];

"""
    agepanel(k, group; datalabel="", predlabel="", extra...)

Build the panel for age group number `k`, stored in column `group` of `dfodh`.

Draws the reported series over `ti:todh`, a vertical line at `tfit`, and the
midpoint of `quantlow[:,k]` and `quanthigh[:,k]` as a dashed line with a ribbon
reaching `σquantlow[:,k]` and `σquanthigh[:,k]`. `datalabel` and `predlabel`
are the legend entries of the two series; `extra` keywords are forwarded to
the initial `plot` call. Returns the plot object.
"""
function agepanel(k, group; datalabel="", predlabel="", extra...)
    panel = plot(ti:todh, dfodh[!,group][ti:todh];
                 labels=datalabel, title=group, size=(200,200), extra...)
    quantmean = .5*quantlow[:,k] + .5*quanthigh[:,k]
    vline!(panel, [tfit]; labels="", color=:green)
    plot!(panel, qttaxis[1:dwnsmp], quantmean;
          ribbon=(quantmean-σquantlow[:,k], σquanthigh[:,k]-quantmean),
          linestyle=:dash, ylims=(0.,ymax), labels=predlabel,
          fillalpha=.2, color=myσcolor,
          guidefont=12, titlefont=12, tickfont=8,
          xticks=LinRange(425,620,4), xlims=(425,620))
    # Render day offsets from day0 as calendar dates
    plot!(panel; xformatter = (x->Dates.format(Day(x)+day0,"mm/dd")))
    return panel
end

# Only the first panel carries legend entries; only the seventh (bottom-left of
# the 3x3 grid) carries axis labels
panelopts = Dict(1 => (datalabel="reported", predlabel="predicted 7 day avg", legend=:topleft),
                 7 => (xlabel="date", ylabel="daily infections"));
panels = [agepanel(k, g; get(panelopts, k, NamedTuple())...) for (k, g) in enumerate(agegroups)];
p1, p2, p3, p4, p5, p6, p7, p8, p9 = panels;

println("Dates run from $dayi to $dayf while simulated up to $daym")
lay = @layout [a b c; d e f; g h i];
plot(p1,p2,p3,p4,p5,p6,p7,p8,p9, layout=lay, size = (800,500))
savefig("prediction.pdf");

In [None]:
# Compute model uncertainty in prediction aggregated across all age groups

# Sum the nine age groups of every sampled trajectory; the result has one row
# per time step and one column per sample
Mtrajsagg = reshape(Mtrajs,:,9,nsmp); Mtrajsagg = reshape(sum(Mtrajsagg,dims=2),:,nsmp);
σMtrajsagg_high = deepcopy(Mtrajsagg);
σMtrajsagg_low = deepcopy(Mtrajsagg);

# Per-sample index of the last time step that still uses the pre-change σ.
# The change point (row 182) is shifted by the day offset `days` and clamped to
# one less than the number of rows of quantlow; Inf means "never switch".
aggptmax = size(quantlow,1) - 1;
aggptids = flagΔpt ? min.(floor.(M_mcmc[182,:] .- days.value), aggptmax) : fill(Inf,nsmp);

"""
    addaggσbands!(high, low, σ, Δσ, ptids)

Widen every aggregated trajectory (one per column) in place.

For each sample the variance starts at zero and grows by `σ[smp]^2` per time
step while the step index is `<= ptids[smp]`, and by `Δσ[smp]^2` afterwards.
`high` is raised by `2√var*3`; `low` is lowered by the same amount and floored
at zero. `Δσ` is only read for steps past `ptids[smp]`.
"""
function addaggσbands!(high, low, σ, Δσ, ptids)
    for smp in axes(high, 2)
        ptid = ptids[smp]
        varnow = 0.0    # variance restarts for every sample
        for i in axes(high, 1)
            sd = (i <= ptid) ? σ[smp] : Δσ[smp]
            varnow += sd^2
            # NOTE(review): the factor 3 is presumably √9 for nine independent
            # age groups -- confirm
            halfwidth = 2*√(varnow)*3
            high[i,smp] += halfwidth
            low[i,smp] = max(low[i,smp] - halfwidth, 0.)
        end
    end
    return nothing
end

# Rows 177 and 187 of M_mcmc hold the pre- and post-change error std devs
addaggσbands!(σMtrajsagg_high, σMtrajsagg_low,
              M_mcmc[177,:], flagΔpt ? M_mcmc[187,:] : [NaN], aggptids);

# Quantiles of the σ-inclusive bands use σthresh, the option declared for
# bands with σ-errors
aggqlo = (1-σthresh)/2;
σMaggquants = Matrix{Float64}(undef,dwnsmp,2);
for i in 1:dwnsmp
    σMaggquants[i,1] = quantile(view(σMtrajsagg_low,i,:), aggqlo)
    σMaggquants[i,2] = quantile(view(σMtrajsagg_high,i,:), 1-aggqlo)
end

In [None]:
# Plot the aggregate model prediction

# Reported ODH daily cases over the plotted window, against day offsets
odhdays = [Dates.value(d - Date("2020-01-01")) for d in dfodh[ti:todh, :time]]
pagg = plot(odhdays, dfodh[ti:todh, :daily_confirm];
            labels="ODH daily cases",
            linewidth=1,
            legend=:topleft, legendfont=8,
            guidefont=12, titlefont=12, tickfont=12,
            size=(350,250));

# Dashed midpoint of the aggregated band, with the band itself as a ribbon
aggband_lo = σMaggquants[:,1]
aggband_hi = σMaggquants[:,2]
quantmean = .5*aggband_lo + .5*aggband_hi
plot!(qttaxis[1:dwnsmp], quantmean;
      ribbon=(quantmean-aggband_lo, aggband_hi-quantmean),
      labels="predicted 7 day avg",
      color=myσcolor, fillalpha=.2, linestyle=:dash,
      xlims=(425,620), xticks=LinRange(425,620,4),
      ylims=(0,10000), yticks=LinRange(0,10000,3),
      guidefont=12, titlefont=12, tickfont=12);

# Mark the last fitted day
vline!([tfit]; labels="", color=:green);

# Axis formatters: day offsets as calendar dates, counts as whole thousands
fmtdate = x -> Dates.format(Day(x)+Date("2020-01-01"),"mm/dd")
fmtkilo = x -> first(split(string(x/1000),"."))*"k"
plot!(xformatter = fmtdate, yformatter = fmtkilo)

In [None]:
# Write the current figure (the aggregate prediction plot) to a PDF file
savefig("predictionagg.pdf");

In [None]:
# Widen the display limits so the parameter table is not truncated
ENV["COLUMNS"] = 1000
ENV["ROWS"] = 1000
println("Best parameters:")

# Parameter label => row of M_mcmc; every value is read at sample idbest
bestrows = [
    "r0" => 12, "α" => 31, "ω" => 32,
    "rptλ" => 176, "bayσ" => 177,
    "vι0" => 179,
    "rptλE" => 180, "rptλI" => 181,
    "Δpt" => 182, "Δr0" => 183,
    "Δα" => 184, "Δω" => 185,
    "Δrptλ" => 186, "Δbayσ" => 187,
]
bestvals = Dict{String,Float64}(name => M_mcmc[row,idbest] for (name, row) in bestrows)
bestvals["L1 Err"] = M_err[idbest]

# One-row table of the values at idbest
df = DataFrame(bestvals)

## Collect marginal distributions
Note that a $\Delta$ prefix denotes the value of a parameter estimated for the period after the change point, except for $\Delta$pt, which is the change point date itself.

In [None]:
# Posterior means (and the other summary statistics reported by describe)

# Column name => row of M_mcmc holding that parameter's samples
samplerows = (:r0 => 12, :α => 31, :ω => 32, :rptλ => 176, :bayσ => 177,
              :Δr0 => 183, :Δα => 184, :Δω => 185, :Δrptλ => 186, :Δbayσ => 187)
dfsmp = DataFrame((name => M_mcmc[row,:] for (name, row) in samplerows)...);
describe(dfsmp)

In [None]:
# Probability-normalized histogram of the samples in one row of M_mcmc
marginalhist(row, name; kw...) =
    histogram(M_mcmc[row,:]; normalize=:probability, labels="", title=name, kw...)

p1 = marginalhist(12, "r0"; guidefonts = 12);
p2 = marginalhist(31, "α");
p3 = marginalhist(32, "ω");
p4 = marginalhist(176, "rptλ");
p5 = marginalhist(177, "bayσ");
p7 = marginalhist(179, "vι0");

# Arrange the six marginals on a 3x2 grid and save the figure
lay = @layout [a b;c d; e f];
plot(p2,p3,p1,p4,p5,p7, layout=lay, size=(850,425))
savefig("marginals.pdf")

In [None]:
# Change point marginals
if flagΔpt
    # Probability-normalized histogram of the samples in one row of M_mcmc
    cphist = (row, name) ->
        histogram(M_mcmc[row,:]; normalize=:probability, labels="", title=name)

    p1 = cphist(183, "Δr0");
    p2 = cphist(184, "Δα");
    p3 = cphist(185, "Δω");
    p4 = cphist(186, "Δrptλ");
    p5 = cphist(187, "Δbayσ");

    #  If the smallest sampled Δpt is Inf, substitute a dummy plot for the histogram
    if minimum(M_mcmc[182,:]) != Inf
        p7 = cphist(182, "Δpt");
    else
        p7 = plot([0.,0.],[0.,0.],labels="",title="Did not run Δpt");
    end

    lay = @layout [g h; a b; c d];
    plot(p7,p5,p2,p3,p1,p4, layout=lay, size=(850,425))
    savefig("marginals_dltpt.pdf")
end

## Trace Plots

In [None]:
# Trace of one row of M_mcmc, keeping every nstep-th sample
traceplot(row, name) = plot(M_mcmc[row,1:nstep:end]; labels="", title=name)

q1 = traceplot(12, "r0");
q2 = traceplot(31, "α");
q3 = traceplot(32, "ω");
q4 = traceplot(176, "rptλ");
q5 = traceplot(177, "bayσ");
q7 = traceplot(179, "vι0");

# Arrange the six traces on a 3x2 grid and save the figure
lay = @layout [a b;c d; e f];
plot(q2,q3,q1,q4,q5,q7, layout=lay, size=(850,425))
savefig("trace.pdf")

In [None]:
# Change point trace plots
if flagΔpt
    # Trace of one row of M_mcmc, keeping every nstep-th sample
    cptrace = (row, name) -> plot(M_mcmc[row,1:nstep:end]; labels="", title=name)

    q1 = cptrace(183, "Δr0");
    q2 = cptrace(184, "Δα");
    q3 = cptrace(185, "Δω");
    q4 = cptrace(186, "Δrptλ");
    q5 = cptrace(187, "Δbayσ");
    q7 = cptrace(182, "Δpt");

    lay = @layout [g h; a b; c d];
    plot(q7,q5,q2,q3,q1,q4, layout=lay, size=(850,425))
    savefig("trace_dltpt.pdf")
end

## $\Delta$pt by dates

In [None]:
if flagΔpt
    # Day offset from 2020-01-01 rendered as a calendar date
    asdate = x -> Dates.format(Day(x)+Date("2020-01-01"),"mm/dd")

    # Sampled change points (row 182 of M_mcmc)
    Δpts = M_mcmc[182,:]

    # Density histogram of the change point, x axis shown as dates
    p1 = histogram(Δpts;
                   labels="", title="Δpt", normalize=:pdf,
                   bins=553:559, linewidth=0,
                   xlims=(553,559), xticks=LinRange(553,559,3),
                   guidefont=12, titlefont=12, tickfont=8)
    plot!(p1; xformatter = asdate)

    # Trace of every nstep-th sampled change point, y axis shown as dates
    q1 = plot(Δpts[1:nstep:end];
              labels="", title="Δpt", xticks=false,
              ylims=(554,562), yticks=LinRange(554,562,5),
              guidefont=12, titlefont=12, tickfont=8);
    plot!(q1; yformatter = asdate)

    lay = @layout [a b];
    plot(p1,q1,layout=lay,size=(600,250))
    savefig("dltpt.pdf");
end