# Graph learning modeling code

In [36]:
#set up

#load packages
using FileIO
using JLD2
using DataFrames
using CSV
using LinearAlgebra
using Statistics
using SpecialFunctions
using Logging
using CodecZlib
using EM2 #slightly modified version of EM (in code/em-ak-em if not installed)

# Pin BLAS to a single thread. Presumably this avoids oversubscribing cores when the
# model fit below runs with `threads=true` — TODO confirm.
BLAS.set_num_threads(1)

# Project helper scripts. `run_sr_td_future_dutch_rt_shift_trial_alltargets_keys`, used
# below, is not defined in the cells shown here and presumably comes from one of these
# files — verify.
include("code/em_scripts.jl")
include("code/sr_funcs_kids.jl")

"""
    unitnorm(x)

Standard normal cumulative distribution function Φ(x), applied elementwise
(works for scalars and arrays).

Computed as `erfc(-x/√2) / 2`, which equals `(1 + erf(x/√2)) / 2` mathematically.
The `erfc` form avoids catastrophic cancellation in the lower tail. For `x = -10`,
the `erf` form rounds to exactly `0.0`, whereas this form returns ≈ 7.6e-24.
"""
function unitnorm(x)
    return @. 0.5 * erfc(-x / sqrt(2))
end

unitnorm (generic function with 1 method)

## Load data and fit model

In [40]:
#process model data
trialdata_kids = DataFrame(CSV.File("../data/processed/model_data.csv"))

# Subjects in order of first appearance. `sub` is each row's 1-based position in this
# list. The Dict lookup is a single pass over the rows, instead of a linear
# `findfirst` scan of `subjects` for every row.
subjects = unique(trialdata_kids.subject_id)
subject_index = Dict(s => i for (i, s) in enumerate(subjects))
trialdata_kids.sub = [subject_index[s] for s in trialdata_kids.subject_id]

# NOTE(review): `within_block_trial * block_num` is not a monotonic trial counter.
# For example, (block 2, trial 3) -> 6 sorts before (block 1, trial 10) -> 10.
# If a global trial index was intended, something like
# `(block_num - 1) * trials_per_block + within_block_trial` may be what was meant.
# Confirm before changing, because the fitted results depend on this regressor.
# The division by 600 only rescales it; presumably 600 ≈ total trials — TODO confirm.
trialdata_kids.trial = (trialdata_kids.within_block_trial .* trialdata_kids.block_num)/600

# Age per subject, taken from that subject's first row, then mean-centred.
age = [trialdata_kids.age[findfirst(isequal(s), trialdata_kids.subject_id)]
       for s in subjects]
age = age .- mean(age)

trialdata_kids.targetid = trialdata_kids.stim_id

# Response-key code: 2 when the target button is "f", otherwise 1.
target_button = ones(Int, nrow(trialdata_kids))
target_button[trialdata_kids.target_button .== "f"] .= 2
trialdata_kids.keyid = target_button

# Validity flag as Bool. RTs are divided by 1000 (presumably ms -> s, consistent with
# the 0.2 cut-off). Responses faster than 0.2 are marked invalid.
trialdata_kids.isValid = trialdata_kids.isValid .== 1
trialdata_kids.rt = trialdata_kids.rt ./ 1000
trialdata_kids[trialdata_kids.rt .< 0.2, :isValid] .= false


fname = "fit_model"
fn = run_sr_td_future_dutch_rt_shift_trial_alltargets_keys

run_sr_td_future_dutch_rt_shift_trial_alltargets_keys (generic function with 1 method)

In [42]:
#run model
# Model options, passed to `fn` as keyword arguments below.
naive = true
add_αM = true
add_recency_ntrials = false
add_recency_lag10 = false
normalize_prediction = true
add_zero_order = true
warmup = -1           # NOTE(review): sentinel value whose meaning is defined by `fn` — confirm
covariates = age      # mean-centred per-subject ages computed in the data-processing cell

results = fn(
    trialdata_kids;
    naive = naive,
    add_αM = add_αM,
    threads = true,
    add_recency_ntrials = add_recency_ntrials,
    add_recency_lag10 = add_recency_lag10,
    add_zero_order = add_zero_order,
    normalize_prediction = normalize_prediction,
    warmup = warmup,
    covariates = covariates,
)
save(string(fname, ".jld2"), "results", results; compress = true)

(110,)
(110,)

iter: 10
betas: [-0.6 -1.44 -0.7 0.01 -1.13 -0.12 -1.15 -1.75 0.59 0.66 -0.02 0.08 0.03 0.03 0.05 0.09 0.08 0.0 -0.04 0.04 0.05 0.02 0.09 0.05 0.01; -0.05 0.02 0.21 0.0 -0.06 0.01 0.06 0.02 0.07 0.03 -0.0 0.0 -0.0 -0.0 -0.0 0.0 0.0 0.0 -0.0 0.0 0.0 0.0 0.0 0.0 0.0]
sigma: [0.06, 0.03, 0.58, 0.01, 0.95, 0.02, 0.58, 0.28, 0.52, 0.99, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
free energy: 38724.080631
change: [-0.006959, -0.000206, -0.003483, 0.045583, -0.053176, -0.086493, -0.072988, -0.018835, 0.031744, 0.018081, -0.004847, 0.000983, 0.000678, 0.004439, 0.001059, 0.000305, 0.000977, 0.032933, -0.002614, 0.001865, 0.000706, 0.002646, 8.4e-5, 0.002424, 0.002179, -0.000978, 0.002822, 0.004358, 0.060235, -0.049368, 0.228024, 0.122404, 0.114289, 0.01022, 0.037256, -0.043125, 0.000412, -0.016171, -0.012593, -0.006637, 0.001617, 0.003185, 0.002387, -0.002583, 0.006454, 0.005804, 0.005288, 0.012554, 0.001177, 0.01038, 0.004957, 0.00501, 0.004764, 

┌ Info: Running emerrors
└ @ Main /Users/katenussenbaum/Library/CloudStorage/Dropbox/research/studies/dev_sr/GL_manuscript/dev-graph-learning/modeling/code/sr_funcs_kids.jl:19


## Group-level model results

In [43]:
#read in data
results = load("fit_model.jld2", "results")
pvalues = results.pvalues
betas = results.betas
sigma = results.sigma
se = results.standarderrors
varnames = results.varnames

# `betas` has one row per group-level regressor and one column per parameter. Row 1 is
# read as the intercept and row 2 as the age effect below. P-values and standard errors
# are reshaped to size(betas') = (nparams, nregressors). This assumes their flat order is
# parameter-fastest within each regressor — verify against the EM package.
p_reshaped = reshape(pvalues, size(betas'))
se_reshaped = reshape(se, size(betas'))

# Intercept p-values (column 1) alongside intercept estimates (column 2).
data = hcat(p_reshaped[:, 1], betas[1, :])

# Non-zero entries of `sigma` (not used further in this cell).
sigma_vec = vec(sigma)
sigmas = filter(!iszero, sigma_vec)

# One row per parameter: intercept estimate, SE and p-value, then the age effect.
df = DataFrame(
    :parameter => varnames,
    :beta => data[:, 2],
    :se => se_reshaped[:, 1],
    :beta_p => data[:, 1],
    :age_beta => betas[2, :],
    :age_se => se_reshaped[:, 2],
    :age_p => p_reshaped[:, 2],
)

"""
    round_dataframe!(df::DataFrame; digits::Int=4)

Round every numeric column of `df` in place to `digits` decimal places and return `df`.

A column counts as numeric if its element type, ignoring `Missing`, is a `Number`. So
`Union{Missing,Float64}` columns are rounded too, and their `missing` entries stay
`missing`. All other columns (e.g. strings) are left untouched.
"""
function round_dataframe!(df::DataFrame; digits::Int=4)
    for col in eachcol(df)
        # `nonmissingtype` strips `Missing` from the element type, so that
        # `Union{Missing,Float64}` columns are not skipped. `round(missing; digits)`
        # returns `missing`.
        if nonmissingtype(eltype(col)) <: Number
            col .= round.(col; digits=digits)
        end
    end
    return df
end

# Round numeric columns in place to 4 decimal places (round_dataframe! returns the same
# DataFrame), then print the table.
print(round_dataframe!(df; digits=4))


[1m25×7 DataFrame[0m
[1m Row [0m│[1m parameter      [0m[1m beta    [0m[1m se      [0m[1m beta_p  [0m[1m age_beta [0m[1m age_se  [0m[1m age_p   [0m
     │[90m String         [0m[90m Float64 [0m[90m Float64 [0m[90m Float64 [0m[90m Float64  [0m[90m Float64 [0m[90m Float64 [0m
─────┼───────────────────────────────────────────────────────────────────────
   1 │ rt_μ            -0.4274   0.037    0.0      -0.0393   0.0085   0.0
   2 │ rt_σ            -1.4404   0.0196   0.0       0.0223   0.0046   0.0
   3 │ rt_shift        -0.7137   0.0995   0.0       0.2089   0.0252   0.0
   4 │ β_trial          0.0201   0.0103   0.0524    0.0015   0.0025   0.5533
   5 │ β_anticipation  -3.8013   0.4905   0.0      -0.2259   0.1109   0.0417
   6 │ β_zero_order    -0.165    0.0215   0.0       0.005    0.0074   0.4968
   7 │ α_zero_order    -1.5981   0.0734   0.0       0.0241   0.0253   0.341
   8 │ αM              -1.9743   0.1255   0.0       0.0301   0.0247   0.2229
   9 │ γ  

## Individual parameter estimates

In [44]:
# Per-subject parameter estimates: one column per model parameter, with `subject_id`
# prepended. Rows of `results.x` are assumed to follow the order of `subjects`, as the
# original hcat-based version also assumed — verify against the fitting code.
#
# The numeric matrix is turned into a typed DataFrame first, and the ids are inserted
# afterwards. `hcat(subjects, results.x)` would promote everything to one element type:
# integer ids would become Float64 (written as "101.0" in the CSV, breaking joins), and
# string ids would make every column `Any`.
varnames = ["subject_id"; results.varnames]

data = DataFrame(results.x, Symbol.(results.varnames))
insertcols!(data, 1, :subject_id => subjects)

#save df as csv
CSV.write("../data/processed/model_individ_params.csv", data)


"../data/processed/model_individ_params.csv"