In [None]:
using Pkg
cd("../")
Pkg.activate(".")

## General information

In this notebook, we use the fitted parameters (saved at `data/modeling/Params.CSV`) to show how to perform posterior predictive checks (PPCs). *Please note that the simulations may take up to an hour depending on your device.*

To re-fit the models yourself, see the folder `src/01_ModelFitting`; for the full PPC-based model comparison, see the folder `src/02_PPCSimulation`.

In [None]:
using Revise
using PyPlot
using IMRLExploration
using JLD2
using Random
using Statistics
using DataFrames
using LogExpFunctions
using CSV

using FitPopulations
using ComponentArrays

PyPlot.svg(true)
rcParams = PyPlot.PyDict(PyPlot.matplotlib."rcParams")
rcParams["svg.fonttype"] = "none"
rcParams["pdf.fonttype"] = 42

Path_Save = "src/02_PPCSimulation/Figures/"


## Loading real data

In [None]:
# ------------------------------------------------------------------------------
# Loading data
# ------------------------------------------------------------------------------
ROutliers, RLong_Subject, RQuit_Subject, RData, RGoal_type_Set, RSub_Num =
        Read_processed_data(Plotting = true);
RData_ns = Str_Input2Agents.(RData);

## Modeling information

In [None]:
# ------------------------------------------------------------------------------
# initializing the params
# ------------------------------------------------------------------------------
Param = Str_Param(; total_leak=true)
Rs    = [1.,1.,1.]
ws    = [1.,1.,1.]
A = Str_Agent_Policy(Param, Rs, NIS_object(ws))
p = parameters(A)
p_names = [string(k) for k = keys(param2η(p))]
m_names = [string(m) for m = keys(model_settings)]

model_set = [5,7,6,8]; # focus only on models that use a single reward function

# ------------------------------------------------------------------------------
# parameters dataframe
# ------------------------------------------------------------------------------
ηdf_CV = CSV.read("data/modeling/Params.CSV", DataFrame)

## Simulating the models

The code below goes over all models in `model_set` and simulates each of them 1500 times (`n_sub = 1500`, all drawn from one seeded random number generator), following the protocol described in the Supplementary Information. The simulated data are saved as `src/02_PPCSimulation/Figures/<model name>/Data/sdata.jld2`, one file per model. These files are used to extract the general statistics for Figure 5B and SFigure 3 (see `src/02_PPCSimulation/PPC_Simulation_plot.jl`).
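
As a rough illustration of the parameter-sampling step in the loop below: for each simulated participant, one of the three cross-validation folds is drawn uniformly at random, and that fold's parameter column is used for the simulation. The sketch uses invented stand-ins (`η_toy`, `n_toy`, `rng_toy`), not values from `Params.CSV`, and needs only `Random` from the standard library.

```julia
using Random

# Toy stand-in for η: 2 parameters (rows) × 3 folds (columns); values are invented
η_toy = [0.1 0.2 0.3;
         1.0 2.0 3.0]
n_toy = 5                        # hypothetical number of simulated participants
rng_toy = MersenneTwister(2024)  # a single seeded generator, as in the loop below

# For each simulated participant, pick one fold uniformly at random ...
i_pars_toy = [rand(rng_toy, 1:size(η_toy, 2)) for _ in 1:n_toy]
# ... and take that fold's parameter column
sampled_toy = [η_toy[:, i] for i in i_pars_toy]
```

Storing the drawn fold indices (as `i_pars` does in the actual loop) makes it possible to trace each simulated participant back to the parameter set that generated it.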

In [None]:
n_sub = 1500    # you may want to test with a smaller number first
for i_model = eachindex(model_set)
        η = Array(ηdf_CV[!,[m_names[model_set[i_model]] * "_f" * string(i_fold)
                                                         for i_fold = 1:3]])
        i_pars = Vector{Float64}([])

        rng = MersenneTwister(2024)
        SData = Vector{Str_Input}([])
        SGoal_type_Set = Vector{Int64}([])
        for i = 1:n_sub
                println("--------------------------------")
                println("--------------------------------")
                @show m_names[model_set[i_model]]
                @show i
                
                # sampling parameters
                i_par = rand(rng, 1:3)
                push!(i_pars, i_par)
                p = parameters(A,η[:,i_par])
                # simulating the agent
                temp = simulate(A, p; ifpass_env = true, rng = rng)
                # reading out the goal type
                G_type = temp.data[1].G_type
                push!(SGoal_type_Set, G_type)
                temp = Str_SASeq2Input(temp)

                @show G_type
                @show Func_GoalType(temp)
                @show Func_EpiLenghts(temp)
                push!(SData,temp)
                println("--------------------------------")
                println("--------------------------------")
        end

        Data_Path_Save = Path_Save * string(m_names[model_set[i_model]]) * "/Data/sdata.jld2"
        mkpath(dirname(Data_Path_Save))   # make sure the output folder exists before saving
        save(Data_Path_Save, "η", η, "SData", SData, "i_pars", i_pars,
                             "SGoal_type_Set", SGoal_type_Set)
end

## Plotting qualitative results (as in Fig 2DF and Fig 5A)

To plot the results, you can treat the simulated data just like the real data (as in `figures/Figure2AC.ipynb`). The script `src/02_PPCSimulation/PPC_Simulation_plot.jl` reproduces and saves all the figures. It additionally extracts all 43 summary statistics described in the paper and saves them in `src/02_PPCSimulation/Figures/PPCStats.CSV`.

Here is one example for novelty-seeking:

In [None]:
i_model = 1
Colors = ["#004D66","#0350B5","#00CCF5"]
Data_Path_Load = Path_Save * string(m_names[model_set[i_model]]) * "/Data/sdata.jld2"
# loading the data
temp = load(Data_Path_Load)
SData = temp["SData"]
SOutliers, SLong_Subject, SQuit_Subject, SData, SGoal_type_Set, SSub_Num = 
        Read_processed_data(Data=SData, Plotting = true)

# WARNING: 
# If you have simulated a smaller number of participants, then you may get an 
# error with the following choice of number of data points to be plotted. 
# You can replace the line with
#       points_to_plot = -1; points_to_plot2 = -1
points_to_plot = 20; points_to_plot2 = 60

In [None]:
Func_plot_state_ratio_Epi25(SData, SGoal_type_Set;
        Colors = Colors,
        Sub_testing = false,
        Traps = [8,9],
        Stoch = [7],
        All_states = Array(1:9),
        ifsave=false,
        points_to_plot = points_to_plot)


## Plotting the quantitative results (Fig 5B and SFig 3)

The script `src/02_PPCSimulation/PPC_Simulation_plot.jl` extracts all 43 summary statistics described in the paper and saves them in `src/02_PPCSimulation/Figures/PPCStats.CSV`. The code below reads `src/02_PPCSimulation/Figures/PPCStats.CSV` and plots the quantitative results.
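
The quantity computed in the cell below can be read as a z-score-like relative error: the difference between a model's statistic and the reference statistic (the last row of `df_stats`, presumably the human data), divided by their combined standard error. A minimal sketch with invented numbers, where `relative_error` is a hypothetical helper and not part of `IMRLExploration`:

```julia
# Hypothetical helper mirroring the formula used in the cell below:
# (model statistic - reference statistic) / sqrt(se_model^2 + se_ref^2)
relative_error(s_model, ds_model, s_ref, ds_ref) =
    (s_model - s_ref) / sqrt(ds_model^2 + ds_ref^2)

# Invented numbers: model statistic 0.6 ± 0.03, reference statistic 0.5 ± 0.04
z_toy = relative_error(0.6, 0.03, 0.5, 0.04)
# z_toy ≈ 2.0: the model sits about two combined standard errors above the reference
```

The bar plots then summarize, per model, the mean and the median of the absolute value of this quantity across all statistics.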

In [None]:
# reading out the dataframe
df_stats = CSV.read(Path_Save * "PPCStats.CSV", DataFrame)
# picking the names of columns for the standard error of different statistics
dNames = filter(s -> startswith(s, "d"), names(df_stats))
# the names of columns for different statistics
Names  = replace.(dNames, r"^d" => "")

# evaluating the relative errors
ddf_stats = DataFrame(Models = m_names[model_set])
for s = Names
        ddf_stats[:,s] = zeros(length(model_set))
end
for i_model = eachindex(model_set)
        for s = Names
                i_model2 = findmax(df_stats.Models .== m_names[model_set[i_model]])[2]
                ddf_stats[i_model,s] = 
                        (df_stats[i_model2,s] - df_stats[end,s]) /
                        sqrt(df_stats[i_model2,"d"*s]^2 + df_stats[end,"d"*s]^2)
        end        
end

# plotting
ylims = [1.8,3.5]

x_names = m_names[model_set]; x = 1:length(model_set)

fig = figure(figsize=(9,6)); 
ax = subplot(1,2,1)
y = mean(abs.(Array(
                ddf_stats[ddf_stats.Models .== x_names, Names])), 
        dims=2)[:]
ax.bar(x, y, color = "k")
ax.set_xticks(x); ax.set_xticklabels(x_names)
ax.set_ylabel("average relative error")
ax.set_xlim([x[1]-1,x[end]+1])
ax.set_ylim(ylims)

ax = subplot(1,2,2)
y = median(abs.(Array(
                ddf_stats[ddf_stats.Models .== x_names, Names])), 
        dims=2)[:]
ax.bar(x, y, color = "k")
ax.set_xticks(x); ax.set_xticklabels(x_names)
ax.set_ylabel("median relative error")
ax.set_xlim([x[1]-1,x[end]+1])
ax.set_ylim(ylims)
tight_layout()
display(fig)



In [None]:
fig = figure(figsize=(20,3)); 
ax = subplot(1,1,1)
cp = ax.imshow(abs.(Array(ddf_stats[ddf_stats.Models .== x_names, Names])), cmap="Reds")
ax.set_xticks(0:(length(Names)-1))
ax.set_xticklabels(Names, rotation = 90)
ax.set_yticks(0:(x[end]-1))
ax.set_yticklabels(x_names)
fig.colorbar(cp, ax=ax)
tight_layout()
display(fig)


## Simulating efficient algorithms

To simulate the efficient algorithms (in Fig S1 and Fig S2), repeat the procedure above using the parameters in `IMRLExploration.η0_efficient_sim` instead of the fitted parameters. See `HowTo.ipynb` for an example.