In [31]:
import Pkg
# Activate the project environment found in the current working directory.
Pkg.activate(".")

# Packages this notebook relies on. The list is only consumed by the
# commented-out `Pkg.add` call below, which can be re-enabled to install them.
dependencies = String[
    "IJulia", "Revise", "Turing", "Bijectors", "ParetoSmooth",
    "LogExpFunctions", "StatsPlots", "DataFrames", "JLD2", "CSV",
]

# Pkg.add(dependencies)

# Print the packages recorded in the active project.
Pkg.status()

using Revise, Turing, Bijectors, ParetoSmooth, LogExpFunctions, StatsPlots, DataFrames, JLD2, CSV
using Random, LinearAlgebra

# Project source files, listed in the order in which they are included.
jlfiles = [
    "code/DataPreparation.jl", "code/Utils.jl", "code/MyModels.jl",
    "code/models/glm.jl", "code/models/glmhmm.jl", "code/models/iohmm.jl",
]

# Evaluate each file in turn so its definitions become available here.
foreach(include, jlfiles)

# Report how many threads this Julia session was started with.
println("Running on ", Threads.nthreads(), " threads.")

# Fix the random seed. Defining the constant alone does not affect any RNG,
# so the global RNG is seeded explicitly; `SEED` stays available for code
# that wants to build its own RNG from it.
const SEED = 123
Random.seed!(SEED);

[32m[1m  Activating[22m[39m project at `c:\Users\michi\WorkSpace\gitwork\mich2000jp\IPD_IOHMM`


[32m[1mStatus[22m[39m `C:\Users\michi\WorkSpace\gitwork\mich2000jp\IPD_IOHMM\Project.toml`
  [90m[76274a88] [39mBijectors v0.15.14
  [90m[336ed68f] [39mCSV v0.10.15
  [90m[a93c6f00] [39mDataFrames v1.8.1
  [90m[7073ff75] [39mIJulia v1.33.0
  [90m[033835bb] [39mJLD2 v0.6.3
  [90m[2ab3a3ac] [39mLogExpFunctions v0.3.29
  [90m[a68b5a21] [39mParetoSmooth v0.7.16
  [90m[295af30f] [39mRevise v3.13.0
  [90m[f3b207a7] [39mStatsPlots v0.15.8
[32m⌃[39m [90m[fce5fe82] [39mTuring v0.40.5
[36m[1mInfo[22m[39m Packages marked with [32m⌃[39m have new versions available and may be upgradable.
Running on 12 threads.


In [4]:
## Data Preparation ----------------------------------------------------
# Two CSV inputs and the JLD2 file used for the processed data.
FIX_PATH, RAND_PATH = "data/fix.csv", "data/rand.csv"
DATA_PATH = "data/data.jld2"

# Hand the three paths to the project helper `prepare_data`
# (presumably it reads the CSVs and writes DATA_PATH — see its definition).
prepare_data(FIX_PATH, RAND_PATH, DATA_PATH)

# Load one dataset per condition (:FP first, then :SP) from DATA_PATH.
data_fp, data_sp = map(cond -> load_data(DATA_PATH, condition=cond), (:FP, :SP));

Reading CSV files...
Saving processed data to data/data.jld2 ...
Data saved successfully.


In [None]:
## MCMC Settings ----------------------------------------------------
# Number of samples, burn-in iterations, and chains.
n_iter, n_burnin, n_chains = 2000, 2000, 12

# Identify the run: condition tag, model name, number of states, file prefix.
trt, model_name = "sp", "glmhmm"
K_states = 2
prefix = "TDist3_Sig50"

# NUTS sampler (first argument 0.8) using the ForwardDiff AD backend.
sampler = NUTS(0.8; adtype=AutoForwardDiff())

# Echo the chosen settings, one line each.
for line in (
    "=== Analysis Settings ===",
    "MCMC: $n_iter samples, $n_burnin burn-in, $n_chains chains",
    "Model: $model_name, K=$K_states, Condition=$trt, Prefix=$prefix",
)
    println(line)
end

=== Analysis Settings ===
MCMC: 2000 samples, 2000 burn-in, 12 chains
Model: glmhmm, K=2, Condition=fp, Prefix=TDist3_Sig50


In [33]:
## MCMC Run ----------------------------------------------------
# Run identifier shared by the chain file and every output file.
title = "$(trt)_$(model_name)_K$(K_states)_$prefix"
CHAIN_PATH   = "chain/$title.jld2"
SUMMARY_PATH = "output/$(title)_summary.csv"
LOO_PATH     = "output/$(title)_loo.csv"
PLOT_PATH    = "output/$(title)_plot.png"
PLOT_GQ_PATH = "output/$(title)_plot_gq.png"

# Pick the dataset that matches the condition tag. An unrecognised tag is
# rejected instead of silently falling back to the SP data, which could pair
# the wrong dataset with the chain loaded below.
data = if trt == "fp"
    data_fp
elseif trt == "sp"
    data_sp
else
    throw(ArgumentError("trt must be \"fp\" or \"sp\", got \"$trt\""))
end

# Build the model objects for the selected model/data, then bring the stored
# `chain` variable from CHAIN_PATH into scope.
model, model_gq = model_selector(model_name, data, K_states)
@load CHAIN_PATH chain

1-element Vector{Symbol}:
 :chain

In [34]:
# Relabel the loaded chain with the project helper, then save a plot of the
# whole chain and a separate plot of its `:lp` entry.
chain_relabeled = relabel_chain(chain, K_states)
p = plot(chain_relabeled)
savefig(p, "chain.png")
plp = plot(chain_relabeled[:lp])
savefig(plp, "lp.png")

# Rebuild the model with `track=true` on `data`, the dataset selected for the
# current `trt`, so the generated quantities are computed on the same
# condition as the chain (a hard-coded `data_fp` is wrong whenever trt != "fp").
# NOTE(review): the model is still hard-coded to `glmhmm`; confirm this is
# intended when `model_name` is set to something else.
model_gq = glmhmm(data, K_states, track=true)
gq = generated_quantities(model_gq, chain_relabeled)
chain_gq = convert_gq(gq)
p = plot(chain_gq)
savefig(p, "gq.png")

"c:\\Users\\michi\\WorkSpace\\gitwork\\mich2000jp\\IPD_IOHMM\\gq.png"

In [35]:
"""
    chain_selected(chain, chn)

Return `chain[:, :, chn]`: the part of `chain` whose third index is in `chn`.

`chn` may be any vector of integer indices, e.g. `[1, 3]` or a range such as `1:8`.
"""
function chain_selected(chain, chn::AbstractVector{<:Integer})
    return chain[:, :, chn]
end
# Keep only the chains with third-dimension index 1 through 8. Note that
# `chain_relabeled` is overwritten with the subset from here on.
chn = collect(1:8)
chain_relabeled = chain_selected(chain_relabeled, chn)
p = plot(chain_relabeled)
savefig(p, "chain_selected.png")
plp = plot(chain_relabeled[:lp])
savefig(plp, "lp_selected.png")

# Recompute the generated quantities for the retained chains on `data`, the
# dataset selected for the current `trt` (a hard-coded `data_fp` is wrong
# whenever trt != "fp").
model_gq = glmhmm(data, K_states, track=true)
gq = generated_quantities(model_gq, chain_relabeled)
chain_gq = convert_gq(gq)
p = plot(chain_gq)
savefig(p, "gq_selected.png")

"c:\\Users\\michi\\WorkSpace\\gitwork\\mich2000jp\\IPD_IOHMM\\gq_selected.png"

In [36]:
## Post Processing ----------------------------------------------------
# Make sure every output directory exists before the slow steps run, so a
# missing folder cannot make the final save fail after PSIS-LOO has finished.
foreach(mkpath ∘ dirname, (SUMMARY_PATH, PLOT_PATH, PLOT_GQ_PATH, LOO_PATH))

chain_gq = convert_gq(gq)

# Summary statistics and HPD intervals (alpha = 0.05) for the sampled
# parameters and for the generated quantities, joined on the parameter name.
# `order = :left` keeps the rows in the order of the summary table.
println("summarizing results...")
df_summary = DataFrame(summarystats(chain_relabeled))
df_summary_gq = DataFrame(summarystats(chain_gq))
df_hpd = DataFrame(MCMCChains.hpd(chain_relabeled, alpha=0.05))
df_hpd_gq = DataFrame(MCMCChains.hpd(chain_gq, alpha=0.05))
df = leftjoin(df_summary, df_hpd, on = :parameters, order = :left)
df_gq = leftjoin(df_summary_gq, df_hpd_gq, on = :parameters, order = :left)
df_stacked = vcat(df, df_gq)
display(df_stacked)

println("Plotting MCMC Results...")
p1 = plot(chain_relabeled)
p2 = plot(chain_gq)

# Run the project's PSIS-LOO helper and reshape its `estimates` table so that
# each statistic becomes a row and each original column a separate column.
println("PSIS-LOO Calculation...")
loo = RunPSISLOO(model, chain_relabeled)
df_loo = DataFrame(loo.estimates)
df_loo = unstack(df_loo, :statistic, :column, :value)

println("Saving Outputs...")
CSV.write(SUMMARY_PATH, df_stacked)
savefig(p1, PLOT_PATH)
savefig(p2, PLOT_GQ_PATH)
CSV.write(LOO_PATH, df_loo)

println("All done!")

summarizing results...


Row,parameters,mean,std,mcse,ess_bulk,ess_tail,rhat,ess_per_sec,lower,upper
Unnamed: 0_level_1,Symbol,Float64,Float64,Float64,Float64,Float64,Float64,Missing,Float64?,Float64?
1,beta0[1],-1.64465,0.0739455,0.0018363,1619.77,1713.44,1.00236,missing,-1.78302,-1.49302
2,beta0[2],-6.14491,1.13552,0.0341919,1461.09,1109.45,1.00487,missing,-8.30426,-4.35009
3,beta1[1],1.38266,0.108955,0.00261364,1743.96,1985.51,1.00188,missing,1.16492,1.5907
4,beta1[2],18.3322,11.9367,0.399687,1076.92,1063.34,1.0102,missing,7.49481,38.1376
5,beta2[1],1.18533,0.102981,0.0024387,1770.36,2032.3,1.00371,missing,0.976737,1.38407
6,beta2[2],1.56428,1.798,0.0453896,1877.04,1518.69,1.00261,missing,-2.41849,4.96936
7,beta3[1],0.839698,0.151898,0.00353779,1843.64,2210.12,1.00251,missing,0.53527,1.1344
8,beta3[2],4.13794,7.91213,0.336892,1869.79,1022.54,1.00417,missing,-7.56014,17.5495
9,"trans[1, 1]",0.996805,0.00128323,2.56558e-05,2223.86,1671.14,1.00138,missing,0.994314,0.999058
10,"trans[2, 1]",0.00319499,0.00128323,2.56558e-05,2223.86,1671.14,1.00138,missing,0.000941894,0.00568555


Plotting MCMC Results...
PSIS-LOO Calculation...
There are 1 subjects with pareto k > 0.7, and 1 subjects with 0.5 < pareto k ≤ 0.7.
Subject 11: pareto k = 0.6973594849262147
Subject 16: pareto k = 0.8543224971017566
Saving Outputs...
All done!

[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mNo source provided for samples; variables are assumed to be from a Markov Chain. If the samples are independent, specify this with keyword argument `source=:other`.



