In [None]:
using Pkg
Pkg.activate(".")

## Table of Contents
- [1. General information](#1.-general-information)
- [2. Loading and working with the raw data](#2.-loading-and-working-with-the-raw-data)
- [3. Loading and working with the processed data](#3.-loading-and-working-with-the-processed-data)
- [4. Simulating the models](#4.-simulating-the-models)
- [5. Evaluating log-likelihood](#5.-evaluating-log-likelihood)

# 1. General information

In [None]:
using Revise
using PyPlot
using IMRLExploration
using DataFrames
using Statistics
using HypothesisTests

PyPlot.svg(true)
rcParams = PyPlot.PyDict(PyPlot.matplotlib."rcParams")
rcParams["svg.fonttype"] = "none"
rcParams["pdf.fonttype"] = 42

Colors = ["#004D66","#0350B5","#00CCF5"]
Legends = ["CHF2","CHF3","CHF4"];

# 2. Loading and working with the raw data

**If you only want to use the data without our code, you can find a minimal version of the raw data in `data/tidydata.CSV` (using the same notation as in the paper).**

If you would like to use our code, this is how to read the data of all 63 participants:

In [None]:
Data = Read_data_all();
length(Data)

`Data[n]` corresponds to the data of participant `n` in the format of structure `Str_Input` (see `src/Structs_form.jl`).

`Str_Input` has the following entries:

- `Sub`: Subject id (between 1 and 63)

- `Gender`: 1 for male and 0 for female

- `actions`: `actions[i]` is the sequence of actions during episode `i` (each action denoted by 0, 1, or 2)

- `states`: `states[i]` is a coarse-grained sequence of states during episode `i`, in which all stochastic states are merged into state `7` and all goal states into state `0`; trap states 7 and 8 (in the paper's notation) are encoded as states `8` and `9`, respectively

- `images`: `images[i][:,t]` is a vector encoding the exact state at time step `t` during episode `i`. `images[i][1,t]` indicates whether the state is in the stochastic part (`1`), is a goal state (`2`), or is a normal state (`0`). For the normal states (states 1 to 8), `images[i][2,t]` corresponds to `states[i][t]`. For the goal states, `images[i][2,t]` denotes the value of the goal (`0` for 2CHF, `1` for 3CHF, and `2` for 4CHF). For the stochastic states, `images[i][2,t]` is the index of the stochastic state (between 0 and 49).

- `trial_time`: `trial_time[i]` is the sequence of real times of the start of each trial, during episode `i`

- `resp_time`: `resp_time[i]` is the sequence of reaction times during episode `i`

- `TM`: the transition matrix for this participant. It is the same for all participants except for state 4, because of the action manipulation described in the paper.
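
As a minimal sketch of how the `states` encoding can be used on its own, the hypothetical helper `fraction_in` below (not part of `IMRLExploration`) computes the fraction of time steps an episode spends in a given set of states, e.g., the stochastic part (`7`) or the traps (`8` and `9`). The toy sequence is made up; with the real data one would pass, e.g., `Data[n].states[i]`.

```julia
# Hypothetical helper (not part of IMRLExploration): fraction of time steps
# that a coarse-grained state sequence spends in a given set of states.
function fraction_in(states::AbstractVector{<:Integer}, desired)
    isempty(states) && return NaN
    return count(s -> s in desired, states) / length(states)
end

# Toy episode in the `states` encoding described above:
# 7 = stochastic part, 8 and 9 = traps, 0 = goal, other values = normal states
toy_states = [1, 2, 7, 7, 7, 8, 8, 3, 4, 0]

frac_stoch = fraction_in(toy_states, (7,))
frac_traps = fraction_in(toy_states, (8, 9))
println("stochastic: $frac_stoch, traps: $frac_traps")
```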

#### Let's look at the data of participant 3:

In [None]:
n_sub = 3
d = Data[n_sub];
@show fieldnames(typeof(d))

In [None]:
d.images[1]

#### Let's look at the data in episode 1

In [None]:
n_epi = 1;
T = length(d.states[n_epi])

fig = figure(figsize=(12,6));
ax = subplot(2,2,1)
ax.plot(1:T, d.states[n_epi],"-o",color="k")
ax.plot((1:T)[d.states[n_epi] .== 7], d.states[n_epi][d.states[n_epi] .== 7],"o",color="b")
ax.plot((1:T)[d.states[n_epi] .== 8], d.states[n_epi][d.states[n_epi] .== 8],"o",color="r")
ax.legend(["all states","stochastic states", "state 8 (7 in the paper; trap)"])
ax.set_xlabel("t"); ax.set_ylabel("states[n_epi]"); 
ax.set_xlim([0,T+1]); ax.set_ylim([-0.1,9.1]); 

ax = subplot(2,2,3)
ax.plot(1:T, d.actions[n_epi],"-o",color="k")
ax.plot((1:T)[d.states[n_epi] .== 7], d.actions[n_epi][d.states[n_epi] .== 7],"o",color="b")
ax.plot((1:T)[d.states[n_epi] .== 8], d.actions[n_epi][d.states[n_epi] .== 8],"o",color="r")
ax.legend(["all states","stochastic states", "state 8 (7 in the paper; trap)"])
ax.set_xlabel("t"); ax.set_ylabel("actions[n_epi]"); 
ax.set_xlim([0,T+1]); ax.set_ylim([-0.1,2.1]); 

ax = subplot(1,2,2)
ax.plot(d.images[n_epi][1,:] .+ (rand(T) .* 0.2) , d.images[n_epi][2,:], "o", color = "k", alpha = 0.5)
ax.set_xlabel("images[n_epi][1,:]"); ax.set_ylabel("images[n_epi][2,:]"); 

tight_layout()
display(fig)

# 3. Loading and working with the processed data

You can directly read the processed data (after removing the outliers) using `Read_processed_data`.

The function returns the following variables:

- `Outliers`: A vector of length 63 indicating outliers (`Outliers = Long_Subject .| Quit_Subject`).

- `Long_Subject`: A vector of length 63 indicating the participants who took more than 3 times the group-averaged number of actions (see the paper).

- `Quit_Subject`: A vector of length 63 indicating the participants who quit the experiment.

- `Data`: The vector of `Str_Input` for the remaining participants (its length should be 57).

- `Goal_type_Set`: The vector of goal types (0 = 2CHF, 1 = 3CHF, 2 = 4CHF; its length should be 57).

- `Sub_Num`: The number of remaining participants (should be 57).
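
The relation between the three flag vectors can be illustrated on toy data. This sketch assumes, as the `.|` above suggests, that they are Boolean vectors, and that the kept participants are exactly those whose `Outliers` entry is `false` (so that 63 − 6 outliers = 57); the `_toy` names are purely illustrative.

```julia
# Toy flags for 6 hypothetical participants (the real vectors have length 63)
Long_Subject_toy = [false, true,  false, false, false, false]
Quit_Subject_toy = [false, false, false, true,  false, false]

# Same definition as in the list above: a participant is an outlier if
# either flag is set (element-wise OR)
Outliers_toy = Long_Subject_toy .| Quit_Subject_toy

# Indices of the participants that are kept
kept = findall(.!Outliers_toy)
n_kept = length(kept)
println("kept participants: ", kept, " (", n_kept, " in total)")
```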

In [None]:
# ------------------------------------------------------------------------------
# Loading data
# ------------------------------------------------------------------------------
Outliers, Long_Subject, Quit_Subject, Data, Goal_type_Set, Sub_Num =
        Read_processed_data();

Let's replicate Fig. 2A (a simplified version of `Func_plot_state_ratio_Epi1`; see `figures/Figure2AC.ipynb`).

In [None]:
# choosing episode 1
Epi = 1;
# information of states
Traps = [8,9]; Stoch = [7];
All_states = [1,2,3,4,5,6,7,8,9];

To do so, we use the function `Func_desired_states_visit`, which extracts, for all participants, information about their visits to a set of states `DesiredStates`: how many times a participant visits `DesiredStates` and, during each visit, how long they stay in `DesiredStates` (similar to Fig. 2A in [Xu and Modirshanechi et al. 2021 in PLOS Comp. Bio.](https://doi.org/10.1371/journal.pcbi.1009070)).
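
The core idea, splitting a state sequence into maximal runs inside `DesiredStates` and recording the length of each run, can be sketched as follows. `visit_lengths` is a hypothetical re-implementation for illustration only, not the code of `Func_desired_states_visit`, and it works on a single toy episode rather than on `Data`.

```julia
# Hypothetical sketch: split a state sequence into maximal runs inside
# `desired` and return the length (number of time steps) of each run.
function visit_lengths(states::AbstractVector{<:Integer}, desired)
    lengths = Int[]
    current = 0
    for s in states
        if s in desired
            current += 1              # still inside the desired set
        elseif current > 0
            push!(lengths, current)   # a visit just ended
            current = 0
        end
    end
    current > 0 && push!(lengths, current)  # visit still ongoing at the end
    return lengths
end

toy_states = [1, 8, 8, 2, 7, 7, 7, 9, 3, 8]
println(visit_lengths(toy_states, (8, 9)))  # visits to the traps
println(visit_lengths(toy_states, (7,)))    # visits to the stochastic part
```

In this sketch, consecutive time steps in state `8` followed by `9` would count as a single trap visit; whether `Func_desired_states_visit` treats such transitions the same way is not claimed here.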

For example, let's set `DesiredStates = Traps` or `DesiredStates = Stoch` and, for now, not split the episode into two halves (i.e., `first_half = false, second_half = false`):

In [None]:
y, dy, N, Lenghts, y_med, y_Q25, y_Q75 = 
        Func_desired_states_visit(Data, DesiredStates = Traps,
                                    Epi= Epi,
                                    first_half = false,
                                    second_half = false);
fig = figure(figsize = (8,4))
ax = subplot(1,1,1)
for i_sub = 1:Sub_Num
    ax.plot(Lenghts[i_sub])
end
ax.set_xlabel("Number of visits of the trap states")
ax.set_ylabel("Number of actions during each visit")
display(fig)

In [None]:
y, dy, N, Lenghts, y_med, y_Q25, y_Q75 = 
        Func_desired_states_visit(Data, DesiredStates = Stoch,
                                    Epi= Epi,
                                    first_half = false,
                                    second_half = false);
fig = figure(figsize = (8,4))
ax = subplot(1,1,1)
for i_sub = 1:Sub_Num
    ax.plot(Lenghts[i_sub])
end
ax.set_xlabel("Number of visits of the stochastic part")
ax.set_ylabel("Number of actions during each visit")
display(fig)

Using this function, we can evaluate the total number of actions that each participant takes within the first and the second halves of Episode 1.

In [None]:
NTraps = zeros(Sub_Num,2)
NStoch = zeros(Sub_Num,2)
NAll   = zeros(Sub_Num,2)

Options = [[ true, false],      # first half
            [false, true ]]     # second half
for i_opt = 1:2
        # extracting the information for the traps
        y, dy, N, Lenghts, y_med, y_Q25, y_Q75 = 
                Func_desired_states_visit(Data, DesiredStates = Traps,
                                    Epi= Epi,
                                    first_half = Options[i_opt][1],
                                    second_half = Options[i_opt][2])
        NTraps[:,i_opt] = sum.(Lenghts)
        # extracting the information for the stochastic part
        y, dy, N, Lenghts, y_med, y_Q25, y_Q75 =  
                Func_desired_states_visit(Data, DesiredStates = Stoch,
                                    Epi= Epi,
                                    first_half = Options[i_opt][1],
                                    second_half = Options[i_opt][2])
        NStoch[:,i_opt] = sum.(Lenghts)
        # extracting the information for "all states" = the length of the 1st/2nd halves
        y, dy, N, Lenghts, y_med, y_Q25, y_Q75 =  
                Func_desired_states_visit(Data, DesiredStates = All_states,
                                    Epi= Epi,
                                    first_half = Options[i_opt][1],
                                    second_half = Options[i_opt][2])
        NAll[:,i_opt] = sum.(Lenghts)
end


This is a participant-by-participant version (corresponding to the data points in Fig. 2A):

In [None]:
fig = figure(figsize = (12,10))
ax = subplot(2,1,1)
ax.bar((0:3:((Sub_Num-1)*3)) .+ 0, NTraps[:,1] ./ NAll[:,1], color="r")
ax.bar((0:3:((Sub_Num-1)*3)) .+ 1, NStoch[:,1] ./ NAll[:,1], color="b")
ax.set_xticks((0:9:((Sub_Num-1)*3)) .+ 0.5); ax.set_xticklabels(1:3:Sub_Num)
ax.set_xlim([-1,3*Sub_Num -1]); ax.set_ylim([0,1])
ax.legend(["traps", "stoch part"])
ax.set_xlabel("Participant")
ax.set_ylabel("Fraction of time steps, during the 1st half of E1, in ...")
ax = subplot(2,1,2)
ax.bar((0:3:((Sub_Num-1)*3)) .+ 0, NTraps[:,2] ./ NAll[:,2], color="r")
ax.bar((0:3:((Sub_Num-1)*3)) .+ 1, NStoch[:,2] ./ NAll[:,2], color="b")
ax.set_xticks((0:9:((Sub_Num-1)*3)) .+ 0.5); ax.set_xticklabels(1:3:Sub_Num)
ax.set_xlim([-1,3*Sub_Num -1]); ax.set_ylim([0,1])
ax.legend(["traps", "stoch part"])
ax.set_xlabel("Participant")
ax.set_ylabel("Fraction of time steps, during the 2nd half of E1, in ...")

display(fig)



Averaged version:

In [None]:
Y_traps = NTraps ./ NAll
Y_stoch = NStoch ./ NAll
Y = cat(Y_traps,Y_stoch,dims = 3)

mY = mean(Y,dims=1)[1,:,:]
dY = std(Y,dims=1)[1,:,:] ./ sqrt(Sub_Num)

# X-axis information
x_0 = [1,4]; x_0ticks = ["Traps", "Stoch"]; σ = 0.2

fig = figure(figsize=(4,7)); ax = subplot(1,1,1)
for i=1:2
        x = x_0 .+ (i-1)
        ax.bar(x,mY[i,:], color = Colors[i])
end
ax.legend(["1st half", "2nd half"])
for i=1:2
        x = x_0 .+ (i-1)
        ax.errorbar(x,mY[i,:],yerr=dY[i,:],color="k",
                    linewidth=1,drawstyle="steps",linestyle="",capsize=3)
        for j = 1:Sub_Num
                x_plot = x .+ 2*σ*(rand() - 0.5)
                ax.plot(x_plot,Y[j,i,:],".k",alpha = 0.5)
        end
end
ax.set_xticks(x_0.+0.5)
ax.set_xticklabels(x_0ticks)
ax.set_title("Fraction of time in Epi 1")
ax.set_xlim([0,6])
display(fig)
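
The error bars above are standard errors of the mean (`std ./ sqrt(Sub_Num)`). As a minimal sketch on made-up numbers, the same summary can be paired with a test of whether the fraction changes between the two halves. The sign-flip permutation test used here is our choice, picked only because it needs nothing beyond the standard library; it is not necessarily the statistics reported in the paper, and `HypothesisTests` (loaded in Section 1) offers standard alternatives such as `SignedRankTest`.

```julia
using Statistics, Random

# standard error of the mean
std_err(x) = std(x) / sqrt(length(x))

# paired, two-sided sign-flip permutation test on the mean difference
function signflip_pvalue(x, y; n_perm = 10_000, rng = MersenneTwister(0))
    d = x .- y
    observed = abs(mean(d))
    hits = 0
    for _ in 1:n_perm
        flipped = d .* rand(rng, (-1, 1), length(d))  # random sign per participant
        hits += abs(mean(flipped)) >= observed
    end
    return (hits + 1) / (n_perm + 1)
end

# toy fractions for the 1st and 2nd halves (one value per participant)
rng = MersenneTwister(1)
n = 20
frac_1st = clamp.(0.30 .+ 0.05 .* randn(rng, n), 0, 1)
frac_2nd = clamp.(0.20 .+ 0.05 .* randn(rng, n), 0, 1)

println("1st half: ", mean(frac_1st), " ± ", std_err(frac_1st))
println("2nd half: ", mean(frac_2nd), " ± ", std_err(frac_2nd))
p = signflip_pvalue(frac_1st, frac_2nd)
println("p-value: ", p)
```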


# 4. Simulating the models

For the modeling part, we need some additional packages:

In [None]:
using JLD2
using Random
using LogExpFunctions
using FitPopulations    # = an older version of LaplacianExpectationMaximization.jl
using ComponentArrays

Now, let's play around a bit with the computational models and see how we can simulate them. Let's first see how we can define an "agent." 

The general idea is that an agent has two parts: The "policy" part that includes the internal variables such as Q-values and action policy (specified by `Str_Agent_Policy`) and the "state" part that includes the agent's action and environmental state (specified by `Str_Agent_State`). 

The logic is that `Str_Agent_State` can be computed purely from the data, without any modeling assumptions, while `Str_Agent_Policy` captures all the modeling variables. 
In other words, `Str_Agent_State` can be seen as an alternative representation of the participants' data; see `src/Structs_form.jl`.

### Initialize the agent type

We first need to initialize the "type" of model we want to work with, by initializing `Str_Agent_Policy`. At this point, we only need to care about 3 things:

1. `total_leak` variable: Indicating whether the total count of states $\tilde{C}_s^{(t)}$ used for novelty will be leaked; it is always set to `true` for all algorithms in the paper.

2. `back_leak` variable: Indicator of whether there is a background leak of all counts; it is only set to `true` for the control algorithms in Section 3.4 of SI.

3. Combined intrinsic rewards using `NIS_object`: While the results in the paper focus on models with a *single* type of intrinsic reward, the implementation here is more general and allows linearly combining different intrinsic rewards, e.g., 
$$r_{{\rm int},t} = w_N \, \text{Novelty}_t + w_{I} \, \text{Inf-Gain}_t + w_S \, \text{Surprise}_t.$$
The `NIS_object` makes this possible efficiently through appropriate memory allocation.
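
The equation itself is just a weighted sum; the sketch below spells it out with made-up values for the three signals. It does not compute novelty, information gain, or surprise (see `src/Functions_for_rewards.jl` for those), it is not how `NIS_object` is implemented internally, and the function name is hypothetical.

```julia
# Hypothetical helper: linear combination of intrinsic rewards,
# r_int = w_N * novelty + w_I * inf_gain + w_S * surprise
function combined_intrinsic_reward(w, novelty, inf_gain, surprise)
    wN, wI, wS = w
    return wN * novelty + wI * inf_gain + wS * surprise
end

# A pure novelty-seeker (w_I = w_S = 0) ignores the other two signals
r_novelty = combined_intrinsic_reward((30.0, 0.0, 0.0), 0.5, 2.0, 3.0)
# An equal mixture of all three signals
r_mixed = combined_intrinsic_reward((1.0, 1.0, 1.0), 0.5, 2.0, 3.0)
println((r_novelty, r_mixed))
```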

In [None]:
# specifying total_leak and back_leak
Param = Str_Param(total_leak = true, back_leak = false); 
# for now, you can see these as just some dummy variables
Rs = [1.,1.,1.]; ws = [1.,1.,1.];
# initializing the agent
A = Str_Agent_Policy(Param, Rs, NIS_object(ws))
for name in fieldnames(Str_Agent_Policy)
    println("$(name)")
end


Here is the detailed explanation of the fields of `Str_Agent_Policy`:
| Field of `Str_Agent_Policy` | Meaning|
|-----------------------|----------------------------------------|
| `Param`               | The set of parameters used by the agent (see below)|
| `Func_eR_sas`         | A function that receives $s$, $a$, and $s'$ and returns the extrinsic rewards; see `src/Functions_for_environments.jl`|
| `Func_iR_sas`         | A function that receives $s$, $a$, and $s'$ and returns the intrinsic rewards; see `src/Functions_for_rewards.jl`|
| `C_s`                 | $\tilde{C}_s^{(t)}$ in the paper|
| `C_sa`                | $\tilde{C}_{s,a}^{(t)}$ in the paper|
| `C_sas`               | $\tilde{C}_{s,a,s'}^{(t)}$ in the paper|
| `θ_sas`               | $p^{(t)}(s';s,a) = \hat{\theta}^{(t)}_{s,a,s'}$ in the paper|
| `eR_sas`              | The array of extrinsic rewards evaluated using `Func_eR_sas`|
| `iR_sas`              | The array of intrinsic rewards evaluated using `Func_iR_sas`|
| `Q_MBe`               | $Q_{\rm MB, ext}^{(t)}$ in the paper|
| `Q_MBi`               | $Q_{\rm MB, int}^{(t)}$ in the paper|
| `U_e`                 | $U_{\rm MB, ext}^{(t)}$ in the paper (for prioritized sweeping)|
| `U_i`                 | $U_{\rm MB, int}^{(t)}$ in the paper (for prioritized sweeping)|
| `V_dummy`             | Some dummy vector defined for *a priori* memory allocation needed for prioritized sweeping (necessary for the automatic differentiation)|
| `P_dummy`             | Some dummy vector defined for *a priori* memory allocation needed for prioritized sweeping (necessary for the automatic differentiation)|
| `Q_MFe`               | $Q_{\rm MF, ext}^{(t)}$ in the paper|
| `Q_MFi`               | $Q_{\rm MF, int}^{(t)}$ in the paper|
| `E_e`                 | $e_{\rm ext}^{(t)}$ in the paper|
| `E_i`                 | $e_{\rm int}^{(t)}$ in the paper|
| `Q_MB_t`              | weighted combination of the MB Q-values at time $t$|
| `Q_MF_t`              | weighted combination of the MF Q-values at time $t$|
| `Q_t`                 | weighted combination of all Q-values at time $t$|
| `π_A_t`               | $\pi_{t}$ in the paper (i.e., the softmax transformation of `Q_t`)|
| `eR_t`                | $r_{{\rm ext}, t}$ in the paper|
| `eRPE_t`              | $RPE_{{\rm ext}, t}$ in the paper|
| `iR_t`                | $r_{{\rm int}, t}$ in the paper|
| `iRPE_t`              | $RPE_{{\rm int}, t}$ in the paper|
| `eR_max_t`            | Maximum extrinsic reward found so far (indicating the degree of reward optimism)|

### Agent's parameters

Now that we have initialized the agent `A`, we need to specify the set of parameters for this agent. This can be done in an outer loop, which facilitates model fitting and simulations. To do so, we can use the function `parameters` from the `FitPopulations.jl` package (an older version of `LaplacianExpectationMaximization.jl`) to define a set of parameters:

In [None]:
p = parameters(A);


Note that `p` lives in the space of (unconstrained) real numbers, for the sake of optimization (see the paper's SI for details). We can convert it to the model parameters `η` using `param2η`:


In [None]:
η = param2η(p)
η_names = [string(k) for k = keys(η)]
for (k, v) in pairs(η)
    println("$(k) = $(v)")
end
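
To make the distinction between `p` and `η` concrete: optimizers work best on unconstrained real numbers, so bounded parameters are commonly stored through an invertible transform. The sketch below shows two common choices (logistic for a parameter in (0, 1), exponential for a positive one) on hypothetical parameters `a` and `b`. Which transform `param2η` actually applies to each field is not shown here and may differ (see the paper's SI).

```julia
# Generic sketch (NOT necessarily the transform used by param2η)
# (0, 1)-bounded parameter <-> unconstrained real
to_unit(x)   = 1 / (1 + exp(-x))
from_unit(y) = log(y / (1 - y))
# positive parameter <-> unconstrained real
to_pos(x)    = exp(x)
from_pos(y)  = log(y)

# unconstrained values, as an optimizer would see them
p_toy = (a = 0.0, b = log(2.0))
# constrained, interpretable values
η_toy = (a = to_unit(p_toy.a), b = to_pos(p_toy.b))
# the round trip recovers the unconstrained values
p_back = (a = from_unit(η_toy.a), b = from_pos(η_toy.b))
println(η_toy)
```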

Note that the notation differs slightly between the code and the paper; here is how the fields of `η` map onto the paper's notation:
| Field of `η` | Meaning|
|----------|----------|
| `κ`           | $\kappa$ in the paper|
| `ϵ_new`       | $\epsilon_{\rm new}$ in the paper|
| `ϵ_obs`       | $\epsilon_{\rm known}$ in the paper|
| `λ_e`       | $\gamma_{\rm ext}$ in the paper|
| `λ_i`       | $\gamma_{\rm int}$ in the paper|
| `T_PS_e`       | $T_{PS,{\rm ext}}$ in the paper|
| `T_PS_i`       | $T_{PS,{\rm int}}$ in the paper|
| `ρ`       | $\rho$ in the paper|
| `μ_e`       | $\lambda_{\rm ext}$ in the paper|
| `μ_i`       | $\lambda_{\rm int}$ in the paper|
| `Q_e0`       | $Q_{\rm MF, ext}^{(0)}$ in the paper|
| `Q_i0`       | $Q_{\rm MF, int}^{(0)}$ in the paper|
| `β_MBe_1`     | $\beta_{\rm MB, ext}^{(1)}$ in the paper|
| `β_MBe_2_G0`     | $\beta_{\rm MB, ext}^{(2, {\rm 2CHF})}$ in the paper|
| `β_MBe_2_G1`     | $\beta_{\rm MB, ext}^{(2, {\rm 3CHF})}$ in the paper|
| `β_MBe_2_G2`     | $\beta_{\rm MB, ext}^{(2, {\rm 4CHF})}$ in the paper|
| `β_MBi_1`     | $\beta_{\rm MB, int}^{(1)}$ in the paper|
| `β_MBi_2_G0`     | $\beta_{\rm MB, int}^{(2, {\rm 2CHF})}$ in the paper|
| `β_MBi_2_G1`     | $\beta_{\rm MB, int}^{(2, {\rm 3CHF})}$ in the paper|
| `β_MBi_2_G2`     | $\beta_{\rm MB, int}^{(2, {\rm 4CHF})}$ in the paper|
| `β_MFe_1`     | $\beta_{\rm MF, ext}^{(1)}$ in the paper|
| `β_MFe_2_G0`     | $\beta_{\rm MF, ext}^{(2, {\rm 2CHF})}$ in the paper|
| `β_MFe_2_G1`     | $\beta_{\rm MF, ext}^{(2, {\rm 3CHF})}$ in the paper|
| `β_MFe_2_G2`     | $\beta_{\rm MF, ext}^{(2, {\rm 4CHF})}$ in the paper|
| `β_MFi_1`     | $\beta_{\rm MF, int}^{(1)}$ in the paper|
| `β_MFi_2_G0`     | $\beta_{\rm MF, int}^{(2, {\rm 2CHF})}$ in the paper|
| `β_MFi_2_G1`     | $\beta_{\rm MF, int}^{(2, {\rm 3CHF})}$ in the paper|
| `β_MFi_2_G2`     | $\beta_{\rm MF, int}^{(2, {\rm 4CHF})}$ in the paper|
| `Q_bias_1`     | $b({\rm middle})$ in the paper|
| `Q_bias_2`     | $b({\rm right})$ in the paper|
| `r_1`     | $r_1^*$ in the paper|
| `r_2`     | $r_2^*$ in the paper|
| `wN`     | $w_N$ above|
| `wI`     | $w_I$ above|
| `wS`     | $w_S$ above|




Now we can set the parameters as we wish by modifying `η` (or use, e.g., `IMRLExploration.η0_efficient_sim` for efficient simulations).

WARNING: Changes to `T_PS_e` and `T_PS_i` will be overwritten by the system, which always keeps them at 100 (see the paper's SI). Similarly, changes to `β_MBe_1` will be overwritten to always keep it at 0.1; this is done without loss of generality, as `wN`, `wI`, and `wS` are unbounded.

In [None]:
# let's consider perfect model-building
η = merge(η, (;κ = 1.))
# let's remove the model-free branch
η = merge(η, (;β_MFe_1 = 0., β_MFe_2_G0 = 0., β_MFe_2_G1 = 0., β_MFe_2_G2 = 0.))
η = merge(η, (;β_MFi_1 = 0., β_MFi_2_G0 = 0., β_MFi_2_G1 = 0., β_MFi_2_G2 = 0.))
# let's assume there is no difference between the extrinsic reward-values of different goals
η = merge(η, (;r_1 = 1., r_2 = 1.))
# and let's consider novelty-seeking
η = merge(η, (;wN = 30., wI = 0., wS = 0.))

# this can now be transformed back to the real-value space:
p = parameters(A,η);

for (k, v) in pairs(η)
    println("$(k) = $(v)")
end
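
The cell above relies on Julia's `merge` for `NamedTuple`s, assuming (as the `merge(η, (; ...))` calls suggest) that `η` is a `NamedTuple`. A minimal sketch of the semantics on a toy parameter set follows. Note that merging a key that does not exist in the first argument silently adds a new field instead of raising an error, so a typo in a parameter name could go unnoticed; `safe_merge` is a hypothetical guard against that.

```julia
# Toy parameter set (a NamedTuple); the real η has many more fields
η_toy = (κ = 0.5, wN = 1.0, wI = 1.0, wS = 1.0)

# merge returns a NEW NamedTuple: fields of the second argument override
# those of the first; all other fields are kept unchanged
η_novelty = merge(η_toy, (; wN = 30.0, wI = 0.0, wS = 0.0))
println(η_toy)      # unchanged
println(η_novelty)  # (κ = 0.5, wN = 30.0, wI = 0.0, wS = 0.0)

# Hypothetical helper: like merge, but refuses unknown field names
function safe_merge(η::NamedTuple, update::NamedTuple)
    unknown = setdiff(keys(update), keys(η))
    isempty(unknown) || error("unknown parameter(s): $(unknown)")
    return merge(η, update)
end
```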

### Simulation

Now that we have initialized the agent type `A` and specified its parameters `p`, we can simulate the agent for 5 episodes of our experimental paradigm:

In [None]:
# setting the random seed
rng = MersenneTwister(2023)
# setting the goal-type probabilities
G_type_prob = [1.,0.,0.]    # i.e., 100% probability of being in an environment with the 2CHF goal
# simulation
simulation_results = simulate(A, p; G_type_prob = G_type_prob, ifpass_env = true, rng = rng);

`simulation_results` has three parts:

1. `simulation_results.data` has 5 elements corresponding to 5 simulated episodes. Each element contains two parts:

    1.1. `simulation_results.data[i].AStates` is the sequence of `Str_Agent_State` structs during episode `i`.
    
    1.2. `simulation_results.data[i].G_type` is the goal type of episode `i` (must be the same for all episodes).

2. `simulation_results.logp` is the log-likelihood of the simulated data under the true parameters.

3. `simulation_results.TM` is the transition matrix of the environment (after the action manipulation in state 4; see the paper).
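
Since the goal type must be the same for all episodes, it can be thought of as a single draw from `G_type_prob` per simulated participant. A minimal sketch of one way such a draw could work, using the goal coding above (0 = 2CHF, 1 = 3CHF, 2 = 4CHF); `draw_goal_type` is hypothetical, and the sampling inside `simulate` may be implemented differently.

```julia
using Random

# Hypothetical sketch: draw one goal type (0 = 2CHF, 1 = 3CHF, 2 = 4CHF)
# from a probability vector like `G_type_prob`
function draw_goal_type(rng::AbstractRNG, probs::AbstractVector{<:Real})
    (all(>=(0), probs) && sum(probs) ≈ 1) ||
        throw(ArgumentError("probs must be a probability vector"))
    u = rand(rng)
    idx = findfirst(c -> u <= c, cumsum(probs))
    idx === nothing && (idx = length(probs))  # guard against round-off
    return idx - 1                            # shift to the 0-based goal coding
end

rng = MersenneTwister(2023)
g = draw_goal_type(rng, [1.0, 0.0, 0.0])   # always 0 (2CHF)

# with uniform probabilities, each goal type appears about equally often
counts = zeros(Int, 3)
for _ in 1:3000
    counts[draw_goal_type(rng, [1/3, 1/3, 1/3]) + 1] += 1
end
println(counts)
```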

We can convert `simulation_results` into a `Str_Input` (just like the participants' data analyzed above) and analyze it in the same way:

In [None]:
d = Str_SASeq2Input(simulation_results);
n_epi = 1;
T = length(d.states[n_epi])

fig = figure(figsize=(12,6));
ax = subplot(2,2,1)
ax.plot(1:T, d.states[n_epi],"-o",color="k")
ax.plot((1:T)[d.states[n_epi] .== 7], d.states[n_epi][d.states[n_epi] .== 7],"o",color="b")
ax.plot((1:T)[d.states[n_epi] .== 8], d.states[n_epi][d.states[n_epi] .== 8],"o",color="r")
ax.legend(["all states","stochastic states", "state 8 (7 in the paper; trap)"])
ax.set_xlabel("t"); ax.set_ylabel("states[n_epi]"); 
ax.set_xlim([0,T+1]); ax.set_ylim([-0.1,9.1]); 

ax = subplot(2,2,3)
ax.plot(1:T, d.actions[n_epi],"-o",color="k")
ax.plot((1:T)[d.states[n_epi] .== 7], d.actions[n_epi][d.states[n_epi] .== 7],"o",color="b")
ax.plot((1:T)[d.states[n_epi] .== 8], d.actions[n_epi][d.states[n_epi] .== 8],"o",color="r")
ax.legend(["all states","stochastic states", "state 8 (7 in the paper; trap)"])
ax.set_xlabel("t"); ax.set_ylabel("actions[n_epi]"); 
ax.set_xlim([0,T+1]); ax.set_ylim([-0.1,2.1]); 

ax = subplot(1,2,2)
ax.plot(d.images[n_epi][1,:] .+ (rand(T) .* 0.2) , d.images[n_epi][2,:], "o", color = "k", alpha = 0.5)
ax.set_xlabel("images[n_epi][1,:]"); ax.set_ylabel("images[n_epi][2,:]"); 

tight_layout()
display(fig)

Let's simulate a surprise-seeking agent now:

In [None]:
η = merge(η, (;wN = 0., wI = 0., wS = 30.))
p = parameters(A,η);
rng = MersenneTwister(2023)
G_type_prob = [1.,0.,0.]
simulation_results = simulate(A, p; G_type_prob = G_type_prob, ifpass_env = true, rng = rng);
d = Str_SASeq2Input(simulation_results);

n_epi = 1;
T = length(d.states[n_epi])

fig = figure(figsize=(12,6));
ax = subplot(2,2,1)
ax.plot(1:T, d.states[n_epi],"-o",color="k")
ax.plot((1:T)[d.states[n_epi] .== 7], d.states[n_epi][d.states[n_epi] .== 7],"o",color="b")
ax.plot((1:T)[d.states[n_epi] .== 8], d.states[n_epi][d.states[n_epi] .== 8],"o",color="r")
ax.legend(["all states","stochastic states", "state 8 (7 in the paper; trap)"])
ax.set_xlabel("t"); ax.set_ylabel("states[n_epi]"); 
ax.set_xlim([0,T+1]); ax.set_ylim([-0.1,9.1]); 

ax = subplot(2,2,3)
ax.plot(1:T, d.actions[n_epi],"-o",color="k")
ax.plot((1:T)[d.states[n_epi] .== 7], d.actions[n_epi][d.states[n_epi] .== 7],"o",color="b")
ax.plot((1:T)[d.states[n_epi] .== 8], d.actions[n_epi][d.states[n_epi] .== 8],"o",color="r")
ax.legend(["all states","stochastic states", "state 8 (7 in the paper; trap)"])
ax.set_xlabel("t"); ax.set_ylabel("actions[n_epi]"); 
ax.set_xlim([0,T+1]); ax.set_ylim([-0.1,2.1]); 

ax = subplot(1,2,2)
ax.plot(d.images[n_epi][1,:] .+ (rand(T) .* 0.2) , d.images[n_epi][2,:], "o", color = "k", alpha = 0.5)
ax.set_xlabel("images[n_epi][1,:]"); ax.set_ylabel("images[n_epi][2,:]"); 

tight_layout()
display(fig)

# 5. Evaluating log-likelihood

The same setup can be used for a model-based analysis of participants' behavior.

To do so, we can first specify the model similarly to before:

In [None]:
# initializing the agent
Param = Str_Param(total_leak = true, back_leak = false); 
Rs = [1.,1.,1.]; ws = [1.,1.,1.];
A = Str_Agent_Policy(Param, Rs, NIS_object(ws))

# setting the parameters for a novelty-seeking agent
pN = parameters(A); η = param2η(pN);
η = merge(η, (;κ = 1.))
η = merge(η, (;β_MFe_1 = 0., β_MFe_2_G0 = 0., β_MFe_2_G1 = 0., β_MFe_2_G2 = 0.))
η = merge(η, (;β_MFi_1 = 0., β_MFi_2_G0 = 0., β_MFi_2_G1 = 0., β_MFi_2_G2 = 0.))
η = merge(η, (;r_1 = 2., r_2 = 10.))
η = merge(η, (;wN = 5., wI = 0., wS = 0.));
pN = parameters(A,η);
# setting the parameters for a surprise-seeking agent
η = merge(η, (;wN = 0., wI = 0., wS = 5.));
pS = parameters(A,η);
# setting the parameters for an information-gain-seeking agent
η = merge(η, (;wN = 0., wI = 5., wS = 0.));
pI = parameters(A,η);
# setting the parameters for an agent with no intrinsic reward
η = merge(η, (;wN = 0., wI = 0., wS = 0.));
pR = parameters(A,η);


Next, we need to convert the participants' data into sequences of `Str_Agent_State`, using `Str_Input2Agents`:

In [None]:
Data_ns = Str_Input2Agents.(Data);

Now, let's look at participant 20:

In [None]:
n_sub = 20
draw = Data[n_sub]
d = Data_ns[n_sub];

lpsN, APolN = logp_pass_agent(d, A, pN); 
lpsS, APolS = logp_pass_agent(d, A, pS); 
lpsI, APolI = logp_pass_agent(d, A, pI); 
lpsR, APolR = logp_pass_agent(d, A, pR); 

The function `logp_pass_agent` takes as input the data of a participant `d`, an agent type `A`, and a set of parameters, e.g., `pN`. It returns:
1. `lps[i]`: the log-likelihood of the data in episode `i`
2. `APol[i]`: the sequence of `Str_Agent_Policy` during episode `i`
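
Each entry of `lps` can be thought of as a sum of per-step log-probabilities of the chosen actions, $\log \pi_t(a_t)$, where $\pi_t$ is the softmax of `Q_t` (see the table of `Str_Agent_Policy` fields). A stripped-down sketch of that computation for a generic softmax policy, assuming the per-step Q-values are given as a matrix; it omits the model's inverse temperatures, Q-value branches, and biases, so it illustrates the idea rather than re-implementing `logp_pass_agent`.

```julia
# numerically stable log(sum(exp.(x)))
function logsumexp_stable(x::AbstractVector{<:Real})
    m = maximum(x)
    return m + log(sum(exp.(x .- m)))
end

# Q: one column of action values per time step; actions: 1-based indices
function softmax_loglik(Q::AbstractMatrix{<:Real}, actions::AbstractVector{<:Integer})
    ll = 0.0
    for (t, a) in enumerate(actions)
        q = view(Q, :, t)
        ll += q[a] - logsumexp_stable(q)   # log π_t(a_t)
    end
    return ll
end

Q_toy = [0.0 1.0 2.0;
         0.0 0.0 0.0;
         0.0 0.0 0.0]   # 3 actions × 3 time steps
a_toy = [1, 1, 2]       # chosen actions (1-based; the data code them as 0, 1, 2)
println(softmax_loglik(Q_toy, a_toy))
```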

Then, we can, for example, compare the log-likelihood of the different models:

In [None]:
NEpi = 5
fig = figure(figsize=(6,6));
ax = subplot(1,1,1)
ax.bar((0:5:(5*(NEpi-1))) .+ 0, lpsN .- lpsR)
ax.bar((0:5:(5*(NEpi-1))) .+ 1, lpsI .- lpsR)
ax.bar((0:5:(5*(NEpi-1))) .+ 2, lpsS .- lpsR)

ax.legend(["novelty-seeking", "inf-gain-seeking", "surprise-seeking"])
ax.set_xticks((0:5:(5*(NEpi-1))) .+ 1); ax.set_xticklabels(1:NEpi)
ax.plot([0,5*5] .- 1.5,[0,0],"--k"); ax.set_xlim([0,5*5] .- 1.5)

ax.set_xlabel("Episode")
ax.set_ylabel("log P( Data | Model ) -  log P( Data | nIR )")

tight_layout()
display(fig)
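
The per-episode differences can also be aggregated over episodes. A minimal sketch on made-up log-likelihoods: sum over episodes and convert the totals into relative weights assuming equal prior weights for the models (a softmax of the totals). Because the parameters above are set by hand rather than fitted, such numbers are only illustrative; a proper model comparison would rely on fitted parameters and a criterion that accounts for model complexity.

```julia
# Toy per-episode log-likelihoods for four models (made-up numbers)
lps_toy = Dict(
    "novelty"  => [-50.0, -40.0, -35.0, -30.0, -28.0],
    "inf-gain" => [-52.0, -41.0, -36.0, -31.0, -29.0],
    "surprise" => [-60.0, -45.0, -40.0, -33.0, -30.0],
    "nIR"      => [-55.0, -44.0, -38.0, -32.0, -30.0],
)

# total log-likelihood over episodes
total = Dict(m => sum(v) for (m, v) in lps_toy)

# relative weights under equal prior weights (softmax of the totals)
model_names = collect(keys(total))
ll = [total[m] for m in model_names]
w = exp.(ll .- maximum(ll))
w ./= sum(w)
for (m, wi) in zip(model_names, w)
    println(rpad(m, 9), " total log-lik = ", total[m], ", relative weight = ", round(wi; digits = 4))
end
```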

Or we can look at the sequence of Q-values of specific actions.
For example, let's look at the Q-values of the different actions in state 4:

In [None]:
# encoding of state 4 as an image
state = 4
St = [0,state - 1]; 

# for intrinsic and extrinsic Q-values
Qis = Vector{Vector{Float64}}([])
Qes = Vector{Vector{Float64}}([])
As  = []

# looking at episode 1
n_epi = 1

# looping over time steps
T = length(APolN[n_epi])
for t = 1:T
    # checking if the agent has observed St yet
    if St ∈ d[n_epi].AStates[t].State_Set
        sid = findmax([St == s for s = d[n_epi].AStates[t].State_Set])[2]
        push!(Qis, APolN[n_epi][t].Q_MBi[sid,:])
        push!(Qes, APolN[n_epi][t].Q_MBe[sid,:])

        if d[n_epi].AStates[t].S_t == St
            push!(As, [t,d[n_epi].AStates[t].A_t])
        end
    else
        push!(Qis, [1,1,1] .* NaN)
        push!(Qes, [1,1,1] .* NaN)
    end
end
As = hcat(As...); Qis = hcat(Qis...); Qes = hcat(Qes...);

# plotting
fig = figure(figsize=(6,10));
ax = subplot(4,1,1)
ax.plot(1:T, draw.states[n_epi],"-o",color="k")
ax.plot((1:T)[draw.states[n_epi] .== state],
        draw.states[n_epi][draw.states[n_epi] .== state],"o",color="r")
ax.set_xlim([0,T+1])
ax.set_xlabel("trials within episode " * string(n_epi))
ax.set_ylabel("state")
ax.legend(["all states","in state " * string(state)])

ax = subplot(4,1,2)
ax.plot(1:T, draw.actions[n_epi],"-o",color="k")
ax.plot((1:T)[draw.states[n_epi] .== state], 
        draw.actions[n_epi][draw.states[n_epi] .== state],"o",color="r")
ax.set_xlim([0,T+1])
ax.set_xlabel("trials within episode " * string(n_epi))
ax.set_ylabel("action")
ax.legend(["all states","in state " * string(state)])

ax = subplot(4,1,3)
for i = 1:3
    ax.plot(1:T,Qis[i,:] .- Qis[1,:])
end
ax.set_xlim([0,T+1])
ax.set_xlabel("trials within episode " * string(n_epi))
ax.set_ylabel("Q_i(a,s=" * string(state) * ") - Q_i(a1,s=" * string(state) * ")")
ax.legend(["a1","a2","a3"])


ax = subplot(4,1,4)
for i = 1:3
    ax.plot(1:T,Qes[i,:] .- Qes[1,:])
end
ax.set_xlim([0,T+1])
ax.set_xlabel("trials within episode " * string(n_epi))
ax.set_ylabel("Q_e(a,s=" * string(state) * ") - Q_e(a1,s=" * string(state) * ")")
ax.legend(["a1","a2","a3"])


tight_layout()
display(fig)
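
The state lookup inside the loop above (a membership test, then `findmax` over a Boolean vector to get the row index) can be written more compactly with `findfirst`. A minimal sketch, assuming `State_Set` is an ordered vector of image encodings whose positions index the rows of `Q_MBi`/`Q_MBe`, as the loop implies; `q_row_or_nan` is a hypothetical name, and the image `[0, 3]` for state 4 follows the `[0, state - 1]` encoding used above.

```julia
# Hypothetical helper mirroring the loop above: return the row of Q that
# belongs to state `St` in `state_set`, or NaNs if `St` is not observed yet.
function q_row_or_nan(Q::AbstractMatrix{<:Real}, state_set::AbstractVector, St)
    sid = findfirst(==(St), state_set)
    sid === nothing && return fill(NaN, size(Q, 2))
    return Q[sid, :]
end

state_set_toy = [[0, 0], [0, 3], [1, 12]]   # images observed so far (toy)
Q_toy = [0.1 0.2 0.3;
         1.0 2.0 3.0;
         -1.0 0.0 1.0]                       # one row per observed state
q_state4 = q_row_or_nan(Q_toy, state_set_toy, [0, 3])  # row 2
q_unseen = q_row_or_nan(Q_toy, state_set_toy, [0, 5])  # not observed yet
println(q_state4, " ", q_unseen)
```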