In [None]:
using Pkg
# Move up one directory (presumably the project root: the data paths used further
# down, e.g. "data/modeling/Params.CSV", are relative to it) and activate the
# Julia environment found there.
cd("..")
Pkg.activate(pwd())

## General information

In this notebook, we will use the fitted parameters (saved at `data/modeling/Params.CSV`) to plot the results of the model-selection in Figure 4.

To re-fit the models yourself, see the folder `src/01_ModelFitting`.

*Warning:*
The results in the paper focus on the set of models with a *single* type of intrinsic reward. The code is more general than this setting and includes sets of models that linearly combine different intrinsic rewards, e.g., 
$$r_{{\rm int},t} = w_N \, \text{Novelty}_t + w_S \, \text{Surprise}_t + w_{I} \, \text{Inf-Gain}_t.$$

The fitted parameters saved at `data/modeling/Params.CSV` include the parameters of these models as well. See the constant `model_settings` in `src/GlobalConstants.jl` for details of the different models (characterized as constraints over the parameter space, e.g., by setting $w_S = w_I = 0$ for novelty-seeking).

*Even when combined models are included, novelty-seeking has the highest test log-likelihood, due to over-fitting and optimization challenges of the more complex models.*

In [None]:
using Revise
using PyPlot
using IMRLExploration
using JLD2
using CSV
using Random
using Statistics
using DataFrames
using LogExpFunctions
using Bootstrap

using FitPopulations
import FitPopulations: parameters, logp, sample, initialize!
import FitPopulations: gradient_logp
import FitPopulations: hessian_logp
import FitPopulations: maximize_logp
import FitPopulations: PopulationModel

using ComponentArrays

# Render inline figures as SVG and keep text editable in exported files:
# "svg.fonttype" = "none" stores text as text (not paths) in SVGs,
# "pdf.fonttype" = 42 embeds TrueType fonts in PDFs.
PyPlot.svg(true)
rcParams = PyPlot.PyDict(PyPlot.matplotlib."rcParams")
for (setting, value) in ("svg.fonttype" => "none", "pdf.fonttype" => 42)
    rcParams[setting] = value
end

## Loading data

In [None]:
# ------------------------------------------------------------------------------
# Loading data
# ------------------------------------------------------------------------------
# Read the pre-processed experimental data (diagnostic plotting disabled).
Outliers, Long_Subject, Quit_Subject, Data, Goal_type_Set, Sub_Num =
    Read_processed_data(; Plotting = false);
# First output of Func_EpiLenghts_all summed along its 2nd dimension
# (presumably rows = participants, columns = episodes — TODO confirm).
Epi_len = vec(sum(first(Func_EpiLenghts_all(Data)); dims = 2))
# Convert each element of `Data` into the input structure used by the agents.
Data_ns = Str_Input2Agents.(Data);

## Modeling information

In [None]:
# ------------------------------------------------------------------------------
# initializing the params
# ------------------------------------------------------------------------------
# Template agent. Its parameter layout provides the parameter names here and is
# reused later to map fitted values back onto agent parameters.
Param = Str_Param(; total_leak = true)
Rs    = ones(3)   # identical to [1., 1., 1.]
ws    = ones(3)
A = Str_Agent_Policy(Param, Rs, NIS_object(ws))
p = parameters(A)
p_names = [string(key) for key in keys(param2η(p))]
# model names; `model_set` (defined further down) holds indices into this vector
m_names = [string(key) for key in keys(model_settings)]

# ------------------------------------------------------------------------------
# parameters dataframe
# ------------------------------------------------------------------------------
# Fitted parameters; the evaluation loop reads columns named "<model>_f<fold>".
ηdf_CV = CSV.read("data/modeling/Params.CSV", DataFrame)

## Evaluating test log-likelihood

In [None]:
# Cross-validation folds: CV_inds[i] is the output of Func_GoalCV for fold i; its
# element [2] is used below as the set of test participants of that fold.
nfolds = 3; CV_inds = [Func_GoalCV(i; nfolds = nfolds) for i = 1:nfolds]
# evaluate every fold; tied to `nfolds` rather than a hard-coded 1:3
nfold_set = 1:nfolds;
# indices into `m_names`; focusing only on models that use a single reward function
model_set = [5,7,6,8];


In [None]:
# number of episodes evaluated per participant
N_epi = 5
# test log-likelihood for every participant, episode, and model
logp_vals_SbS = zeros(Sub_Num, N_epi, length(model_set))
# test accuracy rate for every participant, episode, and model
Acc_vals_SbS  = zeros(Sub_Num, N_epi, length(model_set))

# going over all participants
for i_sub = 1:Sub_Num
    # going over all models
    for i_model = eachindex(model_set)
        model_name = m_names[model_set[i_model]]
        # going over all folds; only the fold(s) where i_sub is a test participant
        # are evaluated
        for i_fold = eachindex(nfold_set)
            fold = nfold_set[i_fold]
            i_sub ∈ CV_inds[fold][2] || continue
            # fresh agent, parameterized with the values fitted for this model and fold
            A = Str_Agent_Policy(Param, Rs, NIS_object(ws));
            η = ηdf_CV[!, model_name * "_f" * string(fold)]
            p = ComponentArray(parameters(A, η))
            # evaluate the model on the participant's data; `lps` is written into
            # an N_epi-long slice (presumably one value per episode)
            lps, APol = logp_pass_agent(Data_ns[i_sub], A, p);
            # save the log-likelihood
            logp_vals_SbS[i_sub, :, i_model] .= lps
            # per-episode `AStates`, compared against the agent's output below
            # (was hard-coded to episodes 1:5; now follows N_epi)
            ASta = [Data_ns[i_sub][j].AStates for j = 1:N_epi];
            for i_epi = 1:N_epi
                # mean accuracy over all but the last entry of APol[i_epi]
                Acc_vals_SbS[i_sub, i_epi, i_model] =
                    mean([Func_agent_accuracy(APol[i_epi][t], ASta[i_epi][t])
                          for t = 1:(length(APol[i_epi]) - 1)])
            end
            println("------------------------")
            @show i_sub
            @show m_names[model_set[i_model]]
            @show lps
            @show Acc_vals_SbS[i_sub, :, i_model]
        end
    end
end

## Fixed effect (not shown in the paper)

In [None]:
# total test log-likelihood per participant and model (episodes summed out)
total_logp_vals_SbS = dropdims(sum(logp_vals_SbS; dims = 2); dims = 2)
# total per model (participants summed out)
y = vec(sum(total_logp_vals_SbS; dims = 1))

In [None]:
total_logp_vals_SbS = dropdims(sum(logp_vals_SbS; dims = 2); dims = 2)
# one bar color per goal condition (0 = 2CHF, 1 = 3CHF, 2 = 4CHF)
Colors = ["#004D66","#0350B5","#00CCF5"]

# Panels: (subplot index, participant mask, bar color, title).
# Each panel sums the participants' log-likelihoods per model and normalizes them
# with a max-shifted log-sum-exp, which gives log P(model | Data) under a uniform
# prior over models.
fig = figure(figsize=(8,8))
n_participants = size(total_logp_vals_SbS, 1)
panels = vcat([(1, trues(n_participants), "k", "all subjects")],
              [(g + 2, Goal_type_Set .== g, Colors[g + 1], string(g + 2) * "CHF subjects")
               for g = 0:2])
for (i_panel, mask, bar_color, panel_title) in panels
    ax = subplot(2,2,i_panel)
    y = vec(sum(total_logp_vals_SbS[mask, :]; dims = 1))
    y = y .- maximum(y)
    y = y .- log(sum(exp.(y)))
    x = 1:length(model_set)
    x_names = m_names[model_set]
    ax.bar(x, y, color = bar_color)
    ax.set_xticks(x)
    ax.set_xticklabels(x_names, fontsize = 9)
    ax.set_ylabel("log P(model | Data)")
    ax.set_title(panel_title)
end
tight_layout()
display(fig)



## Random effect (Fig 4B)

In [None]:
# hierarchical inference
subset = vcat([CV_inds[i][2] for i = nfold_set]...)
L_names = Goal_type_Set[subset]

# see src/Functions_MCMC_random_effects.jl for the details
L_matrix = deepcopy(total_logp_vals_SbS[subset,:])
R_matrix_samples, M_matrix_samples, R_samples_all, M_samples_all, 
      exp_r, d_exp_r, xp, pxp, exp_M, BOR = MCMC_BMS_Statistics(L_matrix,
            N_Chains=100, N_Sampling = Int(2e5), N_Sampling_BOR = Int(2e5),
            α = 1/size(L_matrix)[2])


In [None]:
# plotting
x_names = m_names[model_set]

y = exp_r; dy = d_exp_r; x = 1:length(y)
fig = figure(figsize=(12,6)); ax = subplot(1,2,1)
ax.bar(x,y, color="k",alpha=0.7)
ax.plot([x[1]-1,x[end]+1],[1,1] ./ length(y), 
            linestyle="dashed",linewidth=1,color="k")
title("Posterior Probabilities for Different Models")
ax.set_xticks(x)
ax.set_xticklabels(x_names,fontsize=9)
ax.set_ylabel("E[P(model) | Data ]")
ax.set_xlim([x[1]-1,x[end]+1])
ax.set_ylim([0,0.6])

y = pxp; x = 1:length(y)
ax = subplot(1,2,2)
ax.bar(x,y, color="k")
ax.plot([x[1]-1,x[end]+1],[1,1] ./ length(y), 
            linestyle="dashed",linewidth=1,color="k")
title("Protected exceedence probabilities")
ax.set_xticks(x)
ax.set_xticklabels(x_names,fontsize=9)
ax.set_ylabel("P[r_m > r_m' | Data ]")
ax.set_xlim([x[1]-1,x[end]+1])
ax.set_ylim([0,1.0])

tight_layout()
display(fig)


## Random effect per goal (Fig 4D)

In [None]:
# Random-effects model selection repeated separately for each goal condition.
# Test participants pooled across folds and their goal condition
# (0 = 2CHF, 1 = 3CHF, 2 = 4CHF); loop-invariant, so computed once.
test_subs  = vcat([CV_inds[i][2] for i = nfold_set]...)
test_goals = Goal_type_Set[test_subs]
x_names = m_names[model_set]

for j = 0:2
    # participants with goal condition j
    subset = test_subs[test_goals .== j]

    # row indexing already returns a fresh matrix; no deepcopy needed
    L_matrix = total_logp_vals_SbS[subset, :]
    R_matrix_samples, M_matrix_samples, R_samples_all, M_samples_all,
        exp_r, d_exp_r, xp, pxp, exp_M, BOR = MCMC_BMS_Statistics(L_matrix,
            N_Chains=100, N_Sampling = Int(2e5), N_Sampling_BOR = Int(2e5),
            α = 1/size(L_matrix, 2))

    fig = figure(figsize=(12,6))
    # label each figure with its goal group; otherwise the three figures
    # produced by this loop are indistinguishable
    fig.suptitle(string(j + 2) * "CHF subjects")

    # average
    y = exp_r; x = 1:length(y)
    ax = subplot(1,2,1)
    ax.bar(x, y, color="k", alpha=0.7)
    # dashed reference line at 1 / (number of models)
    ax.plot([x[1]-1, x[end]+1], [1,1] ./ length(y),
            linestyle="dashed", linewidth=1, color="k")
    ax.set_title("Posterior Probabilities for Different Models")
    ax.set_xticks(x)
    ax.set_xticklabels(x_names, fontsize=9)
    ax.set_ylabel("E[P(model) | Data ]")
    ax.set_xlim([x[1]-1, x[end]+1])
    ax.set_ylim([0,0.6])

    y = pxp; x = 1:length(y)
    ax = subplot(1,2,2)
    ax.bar(x, y, color="k")
    ax.plot([x[1]-1, x[end]+1], [1,1] ./ length(y),
            linestyle="dashed", linewidth=1, color="k")
    ax.set_title("Protected exceedence probabilities")
    ax.set_xticks(x)
    ax.set_xticklabels(x_names, fontsize=9)
    ax.set_ylabel("P[r_m > r_m' | Data ]")
    ax.set_xlim([x[1]-1, x[end]+1])
    ax.set_ylim([0,1.0])

    # reserve the top strip of the figure for the suptitle
    tight_layout(rect = [0, 0, 1, 0.95])
    display(fig)
end

## Accuracy rate for novelty-seeking (Fig 4C)

In [None]:
# Per-participant, per-episode accuracy of the novelty-seeking model (named "N").
i_nov = findfirst(==("N"), m_names[model_set])
isnothing(i_nov) && error("novelty-seeking model \"N\" is not part of `model_set`")
Nov_Acc_vals_SbS = Acc_vals_SbS[:, :, i_nov]

fig = figure(figsize=(6,8))
# all subjects: mean ± standard error of the mean across participants, per episode
ax = subplot(2,2,1)
y  = vec(mean(Nov_Acc_vals_SbS, dims=1))
dy = vec(std(Nov_Acc_vals_SbS, dims=1)) ./ sqrt(size(Nov_Acc_vals_SbS, 1))
x = 1:length(y); x_names = ["Epi " * string(i) for i = eachindex(y)]
ax.bar(x, y, color = "k", alpha = 0.7);
ax.errorbar(x, y, yerr=dy, color="k",
            linewidth=1, drawstyle="steps", linestyle="", capsize=3)
# dashed reference line at 1/3 (presumably chance level — TODO confirm)
ax.plot([x[1] - 1, x[end] + 1], [1,1] ./ 3, "--k");
ax.set_xticks(x)
ax.set_xticklabels(x_names, fontsize=9)
ax.set_ylabel("average accuracy")
ax.set_title("all subjects")
ax.set_ylim([1/3,1]); ax.set_xlim([x[1] - 1, x[end] + 1])

# group-by-group (goal condition g: 0 = 2CHF, 1 = 3CHF, 2 = 4CHF)
for g = 0:2
    ax = subplot(2,2,g + 2)
    group_acc = Nov_Acc_vals_SbS[Goal_type_Set .== g, :]
    y  = vec(mean(group_acc, dims=1))
    # SEM of this group: std over the group's own participants. Previously the std
    # was taken over ALL participants but divided by the group's √n.
    dy = vec(std(group_acc, dims=1)) ./ sqrt(size(group_acc, 1))
    x = 1:length(y); x_names = ["Epi " * string(i) for i = eachindex(y)]
    ax.bar(x, y, color = Colors[g + 1]);
    ax.errorbar(x, y, yerr=dy, color="k",
                linewidth=1, drawstyle="steps", linestyle="", capsize=3)
    ax.plot([x[1] - 1, x[end] + 1], [1,1] ./ 3, "--k");
    ax.set_xticks(x)
    ax.set_xticklabels(x_names, fontsize=9)
    ax.set_ylabel("average accuracy")
    ax.set_title(string(g + 2) * "CHF subjects")
    ax.set_ylim([1/3,1]); ax.set_xlim([x[1] - 1, x[end] + 1])
end
tight_layout()
display(fig)
