In [56]:
import Pkg
Pkg.activate(".")

dependencies = [
    "IJulia",
    "Revise",
    "Turing",
    "Bijectors",
    "ParetoSmooth",  
    "LogExpFunctions",    
    "StatsPlots",             
    "DataFrames",       
    "JLD2",            
    "CSV"
]

# Pkg.add(dependencies)
Pkg.status()

using Revise, Turing, Bijectors, ParetoSmooth, LogExpFunctions, StatsPlots, DataFrames, JLD2, CSV
using Random, LinearAlgebra

jlfiles = [
    "code/DataPreparation.jl",
    "code/Utils.jl",
    "code/MyModels.jl",
    "code/models/glm.jl",
    "code/models/glmhmm.jl",
    "code/models/iohmm.jl",
]
for file in jlfiles
    include(file)
end

# Check number of threads
println("Running on ", Threads.nthreads(), " threads.")
# Set random seed
const SEED = 123;

[32m[1m  Activating[22m[39m project at `c:\Users\michi\WorkSpace\gitwork\mich2000jp\IPD_IOHMM`


[32m[1mStatus[22m[39m `C:\Users\michi\WorkSpace\gitwork\mich2000jp\IPD_IOHMM\Project.toml`
  [90m[76274a88] [39mBijectors v0.15.14
  [90m[336ed68f] [39mCSV v0.10.15
  [90m[a93c6f00] [39mDataFrames v1.8.1
  [90m[7073ff75] [39mIJulia v1.33.0
  [90m[033835bb] [39mJLD2 v0.6.3
  [90m[2ab3a3ac] [39mLogExpFunctions v0.3.29
  [90m[a68b5a21] [39mParetoSmooth v0.7.16
  [90m[295af30f] [39mRevise v3.13.0
  [90m[f3b207a7] [39mStatsPlots v0.15.8
[32m⌃[39m [90m[fce5fe82] [39mTuring v0.40.5
[36m[1mInfo[22m[39m Packages marked with [32m⌃[39m have new versions available and may be upgradable.
Running on 12 threads.


In [57]:
## Data Preparation ----------------------------------------------------
FIX_PATH = "data/fix.csv"
RAND_PATH = "data/rand.csv"
DATA_PATH = "data/data.jld2"
prepare_data(FIX_PATH, RAND_PATH, DATA_PATH)
data_fp = load_data(DATA_PATH, condition=:FP)
data_sp = load_data(DATA_PATH, condition=:SP);

Reading CSV files...
Saving processed data to data/data.jld2 ...
Data saved successfully.
Data loaded successfully. Excluded IDs: Int64[]
Data loaded successfully. Excluded IDs: Int64[]


In [38]:
## MCMC Settings ----------------------------------------------------
n_iter   = 2000    # Number of samples
n_burnin = 2000    # Burn-in
n_chains = 12       # Number of chains

trt        = "sp"
model_name = "glmhmm"
K_states   = 2
prefix     = "TDist3_Sig50"

sampler = NUTS(0.8; adtype=AutoForwardDiff())

println("=== Analysis Settings ===")
println("MCMC: $n_iter samples, $n_burnin burn-in, $n_chains chains")
println("Model: $model_name, K=$K_states, Condition=$trt, Prefix=$prefix")

=== Analysis Settings ===
MCMC: 2000 samples, 2000 burn-in, 12 chains
Model: glmhmm, K=2, Condition=sp, Prefix=TDist3_Sig50


In [39]:
## MCMC Run ----------------------------------------------------
title = "$(trt)_$(model_name)_K$(K_states)_$prefix"
CHAIN_PATH    = "chain/$title.jld2"
SUMMARY_PATH = "output/$(title)_summary.csv"
LOO_PATH     = "output/$(title)_loo.csv"
PLOT_PATH     = "output/$(title)_plot.png"
PLOT_GQ_PATH  = "output/$(title)_plot_gq.png"

data = trt == "fp" ? data_fp : data_sp
model, model_gq = model_selector(model_name, data, K_states)
@load CHAIN_PATH chain

1-element Vector{Symbol}:
 :chain

In [42]:
chain_relabeled = relabel_chain(chain, K_states)
p = plot(chain_relabeled)
savefig(p, "chain.png")
plp = plot(chain_relabeled[:lp])
savefig(plp, "lp.png")
model_gq = glmhmm(data_fp, K_states, track=true)
gq = generated_quantities(model_gq, chain_relabeled)
chain_gq = convert_gq(gq)
p = plot(chain_gq)
savefig(p, "gq.png")

"c:\\Users\\michi\\WorkSpace\\gitwork\\mich2000jp\\IPD_IOHMM\\gq.png"

In [48]:
function chain_selected(chain, chn::Vector{Int})
    return chain[:, :, chn]
end
chn = [3,4,10,12]
chain_select = chain_selected(chain_relabeled, chn)
p = plot(chain_select)
savefig(p, "chain_selected.png")
plp = plot(chain_select[:lp])
savefig(plp, "lp_selected.png")
model_gq = glmhmm(data_fp, K_states, track=true)
gq = generated_quantities(model_gq, chain_select)
chain_gq = convert_gq(gq)
p = plot(chain_gq)
savefig(p, "gq_selected.png")

"c:\\Users\\michi\\WorkSpace\\gitwork\\mich2000jp\\IPD_IOHMM\\gq_selected.png"

In [50]:
## Post Processing ----------------------------------------------------
chain_gq = convert_gq(gq)
println("summarizing results...")
df_summary = DataFrame(summarystats(chain_select))
df_summary_gq = DataFrame(summarystats(chain_gq))
df_hpd = DataFrame(MCMCChains.hpd(chain_select, alpha=0.05))
df_hpd_gq = DataFrame(MCMCChains.hpd(chain_gq, alpha=0.05))
df = leftjoin(df_summary, df_hpd, on = :parameters)
df_gq = leftjoin(df_summary_gq, df_hpd_gq, on = :parameters)
df_stacked = vcat(df, df_gq)
display(df_stacked)

println("Plotting MCMC Results...")
p1 = plot(chain_select)
p2 = plot(chain_gq)

println("PSIS-LOO Calculation...")
loo = RunPSISLOO(model, chain_select)
df_loo =DataFrame(loo.estimates)
df_loo = unstack(df_loo, :statistic, :column, :value)

println("Saving Outputs...")
CSV.write(SUMMARY_PATH, df_stacked)
savefig(p1, PLOT_PATH)
savefig(p2, PLOT_GQ_PATH)
CSV.write(LOO_PATH, df_loo)

println("All done!")

summarizing results...


Row,parameters,mean,std,mcse,ess_bulk,ess_tail,rhat,ess_per_sec,lower,upper
Unnamed: 0_level_1,Symbol,Float64,Float64,Float64,Float64,Float64,Float64,Missing,Float64?,Float64?
1,beta0[1],-1.19862,0.0561328,0.00200715,791.652,1268.09,1.00969,missing,-1.30496,-1.08717
2,beta0[2],-6.76286,2.11499,0.120997,374.413,579.671,1.01534,missing,-11.3393,-4.40141
3,beta1[1],0.90479,0.0912275,0.00347361,691.323,1286.34,1.01106,missing,0.716996,1.07253
4,beta1[2],12.0386,6.96984,0.490225,289.082,298.178,1.01994,missing,6.97643,24.4603
5,beta2[1],0.227804,0.0855608,0.00255121,1127.93,1713.01,1.00124,missing,0.059801,0.390119
6,beta2[2],0.0429746,1.82879,0.059847,1325.66,842.784,1.00846,missing,-4.04258,3.24302
7,beta3[1],0.983248,0.139045,0.00383225,1316.79,1832.23,1.00085,missing,0.725694,1.26695
8,beta3[2],2.00027,3.15985,0.114209,1111.23,725.5,1.00862,missing,-2.93907,9.02128
9,"trans[1, 1]",0.988319,0.00250537,8.1597e-05,916.929,1584.92,1.0025,missing,0.983215,0.99284
10,"trans[2, 1]",0.0116813,0.00250537,8.1597e-05,916.929,1584.92,1.0025,missing,0.00715959,0.0167849


Plotting MCMC Results...
PSIS-LOO Calculation...
All subjects have pareto k ≤ 0.5.
Saving Outputs...


[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mNo source provided for samples; variables are assumed to be from a Markov Chain. If the samples are independent, specify this with keyword argument `source=:other`.


All done!


In [77]:
function encode_experience(exp_seq)
    levels = sort(unique(exp_seq))
    map_exp = Dict(l => i for (i, l) in enumerate(levels))
    y = [map_exp[e] for e in exp_seq]
    return y, levels
end

using Plots.PlotMeasures

function plot_player_experience(data::ExperimentData, player_id::Int)

    exp_seq = data.experiences[player_id]
    levels = sort(unique(exp_seq))
    map_exp = Dict(l => i for (i, l) in enumerate(levels))
    y = [map_exp[e] for e in exp_seq]
    
    T = length(y)

    p = plot(
        2:(T+1),
        y;
        seriestype = :step,
        linewidth = 2,
        legend = false,
        xlabel = "Round",
        ylabel = "Experience",
        yticks = (1:length(levels), string.(levels)),
        title = "Player ID = $player_id",
        margin = 10mm,
        size = (900, 300)
    )

    return p
end
plot_player_experience(data_fp, 16)


LoadError: ArgumentError: Package Plots not found in current path, maybe you meant `import/using .Plots`.
- Otherwise, run `import Pkg; Pkg.add("Plots")` to install the Plots package.

In [60]:
data_fp.experiences[16]

99-element Vector{String3}:
 "CDC"
 "CDC"
 "CCC"
 "CCC"
 "CCC"
 "CCC"
 "CDC"
 "CDC"
 "CCC"
 "CCC"
 "CCC"
 "CCC"
 "CDC"
 ⋮
 "CDC"
 "CDC"
 "CDC"
 "CDC"
 "CDC"
 "CDC"
 "CDC"
 "CDC"
 "CDC"
 "CDC"
 "CDC"
 "CDC"