# Graph learning modeling code

In [4]:
# ---- setup ----
import Pkg
cd(@__DIR__)
Pkg.activate(@__DIR__)
Pkg.develop(path=joinpath(@__DIR__, "code", "em-ak-em"))

using FileIO
using JLD2
using DataFrames
using CSV
using LinearAlgebra
using Statistics
using SpecialFunctions
using Logging
using CodecZlib
using EM2

BLAS.set_num_threads(1)

include(joinpath(@__DIR__, "code", "em_scripts.jl"))
include(joinpath(@__DIR__, "code", "sr_funcs_kids.jl"))

function unitnorm(x)
    0.5 .+ 0.5 .* erf.(x ./ sqrt(2))
end

[32m[1m  Activating[22m[39m project at `~/Dropbox (Personal)/research/studies/dev_sr/GL_manuscript/dev-graph-learning/modeling`
[32m[1m   Resolving[22m[39m package versions...
[36m[1m     Project[22m[39m No packages added to or removed from `~/Dropbox (Personal)/research/studies/dev_sr/GL_manuscript/dev-graph-learning/modeling/Project.toml`
[36m[1m    Manifest[22m[39m No packages added to or removed from `~/Dropbox (Personal)/research/studies/dev_sr/GL_manuscript/dev-graph-learning/modeling/Manifest.toml`


unitnorm (generic function with 1 method)

## Load data and fit model

In [5]:
#import data
trialdata_kids = DataFrame(CSV.File("../data/processed/model_data.csv"))

#process model data
subjects = unique(trialdata_kids.subject_id)
trialdata_kids.sub = [findfirst(x -> x == y, subjects) for y in trialdata_kids.subject_id]
age = [trialdata_kids[trialdata_kids.subject_id .== s, :age][1] for s in subjects]
age = age .- mean(age)
trialdata_kids.targetid = trialdata_kids.stim_id
target_button = zeros(Int, nrow(trialdata_kids)) .+ 1
target_button[trialdata_kids.target_button .== "f"] .= 2    
trialdata_kids.keyid = target_button
trialdata_kids.isValid = trialdata_kids.isValid .== 1
trialdata_kids.rt = trialdata_kids.rt ./ 1000
trialdata_kids[trialdata_kids.rt .< 0.2, :isValid] .= false

#code block as 0/1
trialdata_kids.block .= trialdata_kids.block_num .- 1

#code overall trial number (for model without block)
#trialdata_kids.trial .= trialdata_kids.within_block_trial + (trialdata_kids.block_num .- 1) .* 300

#code within_block_trial (for model with block)
trialdata_kids.trial .= trialdata_kids.within_block_trial

#determine model-fitting function
fname = "fit_model"
fn = run_sr_td_future_dutch_rt_shift_trial_alltargets_keys

trialdata_kids


Row,subject_id,age,block_num,within_block_trial,node,target_button,stim_id,rt,isValid,bad_browser_sub,bad_missed_trials,bad_acc_trials,fast_rt_subs,incomplete_subs,long_break_subs,exclude,sub,targetid,keyid,block,trial
Unnamed: 0_level_1,Int64,Float64,Int64,Int64,Int64,String1,Int64,Float64,Bool,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64
1,14085,9.34,1,1,13,f,10,1.0072,true,0,0,0,0,0,0,0,1,10,2,0,1
2,14085,9.34,1,2,11,j,6,0.8461,true,0,0,0,0,0,0,0,1,6,1,0,2
3,14085,9.34,1,3,12,f,4,0.6793,true,0,0,0,0,0,0,0,1,4,2,0,3
4,14085,9.34,1,4,11,j,6,0.6953,true,0,0,0,0,0,0,0,1,6,1,0,4
5,14085,9.34,1,5,12,j,4,0.3014,false,0,0,0,0,0,0,0,1,4,1,0,5
6,14085,9.34,1,6,14,j,11,1.0695,false,0,0,0,0,0,0,0,1,11,1,0,6
7,14085,9.34,1,7,11,f,6,0.1619,false,0,0,0,0,0,0,0,1,6,2,0,7
8,14085,9.34,1,8,13,f,10,0.784,true,0,0,0,0,0,0,0,1,10,2,0,8
9,14085,9.34,1,9,12,f,4,0.8125,true,0,0,0,0,0,0,0,1,4,2,0,9
10,14085,9.34,1,10,15,j,3,0.9222,true,0,0,0,0,0,0,0,1,3,1,0,10


In [6]:
#run model
naive = true
add_αM = true
add_recency_ntrials = false
add_recency_lag10 = false
normalize_prediction = true
add_zero_order = true
warmup = -1
covariates = age

results = fn(trialdata_kids; naive, add_αM, threads=true, add_recency_ntrials, add_recency_lag10, add_zero_order, normalize_prediction, warmup, covariates)
save("$(fname).jld2", "results", results; compress=true)


(106,)
(106,)

iter: 10
betas: [-0.65 -1.46 -0.75 0.0 -0.05 -0.28 -0.14 -1.19 -0.7 0.65 0.9 -0.02 0.08 0.03 0.03 0.05 0.09 0.09 0.0 -0.04 0.04 0.06 0.02 0.09 0.05 0.01; -0.05 0.02 0.21 0.0 -0.0 -0.04 0.01 0.04 0.01 0.13 0.04 -0.0 0.0 0.0 -0.0 -0.0 0.0 0.0 0.0 -0.0 0.0 0.0 0.0 0.0 0.0 0.0]
sigma: [0.06, 0.03, 0.6, 0.0, 0.01, 0.08, 0.02, 0.61, 0.47, 0.61, 1.14, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
free energy: 37461.773381
change: [-0.001993, -0.000259, -0.002402, 0.006172, -0.000739, -0.011854, -0.139967, -0.106067, -0.061041, 0.003931, 0.039166, -0.002227, 0.000997, 0.000289, 0.003391, 0.00182, 0.000615, 0.00032, 0.014923, -0.002408, 0.002862, 0.000376, 0.003234, 0.000122, 0.002937, 0.002079, -0.0012, 0.001731, 0.004315, 0.124017, -0.024362, -0.136637, 0.555501, 0.61075, 0.104253, 0.09988, 0.162324, -0.044842, 0.003683, 0.040895, -0.025005, -0.019198, 0.00194, 0.001963, 0.020502, -0.011565, 0.000725, 0.027646, 0.018129, 0.001944, 0.001088, 0.002535

┌ Info: Running emerrors
└ @ Main /Users/katenuss/Dropbox (Personal)/research/studies/dev_sr/GL_manuscript/dev-graph-learning/modeling/code/sr_funcs_kids.jl:19


## Group-level model results

In [7]:
#read in data
results = load("fit_model.jld2", "results")
pvalues = results.pvalues
betas = results.betas
sigma = results.sigma
se = results.standarderrors
varnames = results.varnames

#get p values and betas for estimates
p_reshaped = reshape(pvalues, size(betas'))
data = hcat(p_reshaped[:,1], betas[1,:])

#get sigmas
sigma_vec = vec(sigma)
sigmas = [x for x in sigma_vec if x!=0]

#get standard errors
se_reshaped = reshape(se, size(betas'))

#make data frame 
df = DataFrame(data, [:beta_p, :beta])
df.parameter = varnames;
df.se = se_reshaped[:, 1]
df.age_beta = betas[2,:]
df.age_p = p_reshaped[:,2]
df.age_se = se_reshaped[:, 2]


select!(df, [:parameter, :beta, :se, :beta_p, :age_beta, :age_se, :age_p])

function round_dataframe!(df::DataFrame; digits::Int=4)
    for col in eachcol(df)
        if eltype(col) <: Number
            col .= round.(col, digits=digits)
        end
    end
    return df
end

round_dataframe!(df)

# Print the DataFrame
print(df)


[1m26×7 DataFrame[0m
[1m Row [0m│[1m parameter      [0m[1m beta    [0m[1m se      [0m[1m beta_p  [0m[1m age_beta [0m[1m age_se  [0m[1m age_p   [0m
[1m     [0m│[90m String         [0m[90m Float64 [0m[90m Float64 [0m[90m Float64 [0m[90m Float64  [0m[90m Float64 [0m[90m Float64 [0m
─────┼───────────────────────────────────────────────────────────────────────
   1 │ rt_μ            -0.6385   0.0266   0.0      -0.0495   0.0062   0.0
   2 │ rt_σ            -1.4527   0.0192   0.0       0.0228   0.0044   0.0
   3 │ rt_shift        -0.7374   0.1001   0.0       0.213    0.0249   0.0
   4 │ β_trial          0.0001   0.0      0.0698    0.0      0.0      0.6842
   5 │ β_block         -0.0508   0.0082   0.0      -0.0014   0.002    0.4783
   6 │ β_anticipation  -0.4059   0.0792   0.0      -0.0719   0.0139   0.0
   7 │ β_zero_order    -0.2263   0.0277   0.0       0.0128   0.0084   0.127
   8 │ α_zero_order    -1.7464   0.0714   0.0       0.0118   0.0166   0.4768
   9 

## Individual parameter estimates

In [8]:
#add subject_id to results.x
df = hcat(subjects, results.x)

#add subject id to varnames
varnames = ["subject_id"; results.varnames]

#convert df to table
data = DataFrame(df, Symbol.(varnames))

#save df as csv
CSV.write("../data/processed/model_individ_params.csv", data)

data



Row,subject_id,rt_μ,rt_σ,rt_shift,β_trial,β_block,β_anticipation,β_zero_order,α_zero_order,αM,γ,λ,β_targets_2,β_targets_3,β_targets_4,β_targets_5,β_targets_6,β_targets_7,β_targets_8,β_targets_9,β_targets_10,β_targets_11,β_targets_12,β_targets_13,β_targets_14,β_targets_15,β_key_2
Unnamed: 0_level_1,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64
1,14085.0,-0.352252,-1.65049,-2.42249,-0.000110665,-0.0775474,-0.0383141,-0.284238,-1.8154,-0.875323,-0.827613,0.0230361,0.0319223,0.0888993,0.0459095,0.0418469,0.0582195,0.0700697,0.0796053,-0.040676,-0.0438931,0.0183775,0.0566694,0.00923412,0.0619912,0.0337397,0.0177248
2,14086.0,-0.407733,-1.67943,-1.98846,0.000327192,-0.0133157,-0.121264,-0.282289,-1.91299,-0.829624,-0.761114,0.0341544,0.0327853,0.0534098,0.0156877,0.0400792,0.0417732,0.0949934,0.0174108,-0.0419434,-0.0318216,0.0298347,0.0934264,-0.00137381,0.115803,0.0224748,0.0617798
3,14087.0,-0.246481,-1.7836,-2.01583,-0.000358281,-0.0826548,0.0583159,-0.353464,-1.82041,-0.691289,-0.286067,0.320121,0.00953664,0.0381408,0.0199365,0.0345702,0.0613568,0.0664592,0.112409,0.0392823,-0.024478,0.0189572,0.0507528,-0.00808548,0.0495366,0.0713233,-0.0143189
4,14089.0,-0.281421,-1.57166,-2.00208,0.000169757,-0.200042,-0.0378283,-0.223146,-1.79079,-0.836537,-0.215143,0.273577,-0.0226094,0.0234469,0.0342931,0.0257773,0.0690756,0.0757206,0.0554437,0.00183524,-0.045837,0.0397544,0.0320976,0.00362553,0.0705032,0.0464925,-0.000758946
5,14090.0,-0.278547,-1.39871,-2.34968,0.000102335,-0.0066976,-0.0145958,-0.195805,-1.78788,-0.877082,-0.358381,0.192459,0.00420229,0.0571558,0.017267,0.0186782,0.0590156,0.0735431,0.106316,-0.00416176,-0.0313094,0.0291299,0.0491347,0.0300431,0.0815025,0.0418727,-0.0570825
6,14091.0,-0.922528,-1.29009,0.163258,7.64166e-5,0.0712104,-0.465409,-0.294817,-1.75073,-0.904393,1.1267,0.584062,0.00914049,0.0682806,0.0435102,0.0142613,0.0669847,0.0842686,0.0923575,0.0181223,-0.0293008,0.0390365,0.0551426,0.0268575,0.11241,0.038445,0.0131104
7,14092.0,-0.363552,-1.41382,-1.225,0.000205951,0.00916104,-0.575142,0.0596559,-1.69924,-1.14543,1.85362,0.482677,-0.0129656,0.0749694,0.0307685,0.035916,0.0482605,0.101254,0.0724975,0.00485881,-0.00700954,0.0534426,0.0741188,0.00170139,0.0720973,0.0659929,-0.0281933
8,14093.0,-0.761975,-1.37308,0.138793,6.07387e-5,-0.0140935,-0.643183,-0.174621,-1.70279,-0.863351,1.16085,0.714031,0.0168573,0.0749175,0.037036,0.0207293,0.0456276,0.0906376,0.08315,0.0271866,-0.0548403,0.0595532,0.0601675,0.00944941,0.0895757,0.0532643,0.0380341
9,14094.0,-0.894623,-1.447,0.439289,-0.00021121,-0.0541196,-0.734895,-0.255058,-1.74914,-0.858761,1.35673,0.596705,-0.0332855,0.127377,0.0335106,0.00968858,0.0995558,0.105053,0.0999197,0.019519,-0.0390208,0.0572805,0.0661843,-0.00185934,0.0770197,0.0584888,0.093818
10,14095.0,-0.815901,-1.28274,0.402136,0.000198004,-0.0501107,-0.774633,-0.236477,-1.69749,-0.877787,1.30811,1.13522,-0.0286709,0.0960479,0.049766,0.0241085,0.0544879,0.117304,0.116425,0.0312264,-0.0490324,0.0667028,0.0815208,0.000821054,0.0817909,0.056723,-0.0234463
