In [1]:
# Load Turing.
using Turing

# Load CSV, DataFrames
using CSV, DataFrames

# Load StatsPlots for visualizations and diagnostics.
using StatsPlots

# Functionality for splitting and normalizing the data.
using MLDataUtils: shuffleobs, splitobs, rescale!

# We need a softmax function which is provided by NNlib.
using NNlib: softmax

# Functionality for constructing arrays with identical elements efficiently.
using FillArrays

# Functionality for working with scaled identity matrices.
using LinearAlgebra

# Set a seed for reproducibility.
using Random

# For save
using JLD

using ReverseDiff

using Memoization

using Optim, StatsBase

Random.seed!(0)

TaskLocalRNG()

In [7]:
# Import the dataset.
data = DataFrame(CSV.File("dataForTuring//E4orth.csv"))
# Show twenty random rows.
data[rand(1:size(data, 1), 20), :]

Row,Column1,subj_id,test_part,word,letter,mask_A,mask_B,mask_C,mask_D,mask_E,mask_F,mask_G,mask_H,mask_I,mask_J,mask_K,mask_L,mask_M,mask_N,mask_O,mask_P,mask_Q,mask_R,mask_S,mask_T,mask_U,mask_V,mask_W,mask_X,mask_Y,mask_Z,post_A,post_B,post_C,post_D,post_E,post_F,post_G,post_H,post_I,post_J,post_K,post_L,post_M,post_N,post_O,post_P,post_Q,post_R,post_S,post_T,post_U,post_V,post_W,post_X,post_Y,post_Z,hit_bin,prior_A,prior_B,prior_C,prior_D,prior_E,prior_F,prior_G,prior_H,prior_I,prior_J,prior_K,prior_L,prior_M,prior_N,prior_O,prior_P,prior_Q,prior_R,prior_S,prior_T,prior_U,prior_V,prior_W,prior_X,prior_Y,prior_Z,eig_A,eig_B,eig_C,eig_D,eig_E,eig_F,eig_G,eig_H,eig_I,eig_J,eig_K,eig_L,eig_M,eig_N,eig_O,eig_P,⋯
Unnamed: 0_level_1,Int64,String15,String15,String15,String1,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Bool,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,⋯
1,19738,375d9eb,nonpretend,dalai lama,A,1,0,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,2.06114,-1.3453,-0.216165,0.484639,1.74338,-0.591815,0.171802,0.201978,0.947121,-0.269125,-0.00678563,0.937268,-0.00432235,1.96815,0.86768,-0.216165,-1.04752,0.726041,0.423673,0.457543,-0.318391,-0.751929,-0.597974,-0.813511,0.362707,-0.862777,True,-0.15156,-2.09311,0.355215,0.148453,1.36212,0.546431,-0.190891,0.896313,0.487563,-0.316423,-0.375168,-0.306654,0.0687078,-0.457902,0.704187,0.122382,0.343378,0.41223,0.767485,1.48299,0.437394,0.327831,0.571429,0.155273,-0.367784,0.17729,1.22701,-1.2671,1.6882,1.70004,1.34366,1.57894,1.67757,1.67215,1.62711,1.63807,1.66032,1.64097,1.72687,1.26097,1.60277,1.64633,⋯
2,35829,568629e6,nonpretend,tooth,H,1,1,0,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,0.730018,0.725541,-1.27525,-0.568648,1.54753,-0.761794,-0.249598,0.966766,0.823682,-0.761794,0.0668177,1.18276,-0.0855374,1.19298,0.998748,-0.422942,-0.761794,1.06623,0.826258,1.72334,0.698673,-0.444268,-0.0222716,-0.761794,-0.592368,-0.761794,True,1.07565,-0.762019,-1.8409,1.19387,1.5727,0.829882,0.281985,0.285226,0.657846,0.240782,-0.352926,-0.484948,0.238504,0.257245,0.645601,0.418188,0.224778,0.16393,0.479154,0.404476,-0.387909,0.17546,0.16744,0.239954,0.586923,0.218983,1.63269,1.61952,-1.18599,1.68532,1.41408,1.57053,1.69137,1.55488,1.70135,1.55018,1.71629,1.52002,1.67436,1.55373,1.594,1.70004,⋯
3,41445,995d7ff,pretend,tooth,T,0,1,1,1,0,1,1,0,0,1,1,1,1,1,0,1,1,1,1,1,1,1,1,1,1,1,-1.30374,0.0133245,0.0133245,0.0133245,-1.30374,0.0133245,0.0133245,-1.30374,-1.30374,0.0133245,0.0133245,0.0133245,0.0133245,0.0133245,-1.30374,0.0133245,0.0133245,0.0133245,0.0133245,4.35635,0.0133245,0.0133245,0.0133245,0.0133245,0.0133245,0.0133245,True,-2.06019,0.16185,0.504184,0.894552,-2.06019,0.373089,0.300641,-2.06019,-2.06019,-0.193488,-0.0292207,0.834046,0.404403,1.55693,-2.06019,0.277819,-0.20888,1.35471,1.44494,0.398111,0.497815,0.0254467,0.392196,-0.194284,0.289761,-0.214453,-1.20452,-0.500657,-0.336203,-0.148675,-1.20452,-0.39918,-0.433983,-1.20452,-1.20452,-0.671357,-0.592445,-0.177741,-0.384137,0.169523,-1.20452,-0.444946,⋯
4,3343,14560a55,pretend,dalai lama,E,0,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,-1.24586,-0.150638,-0.150638,2.13407,-0.150638,-0.150638,-0.150638,-0.150638,2.13407,-0.150638,-0.150638,2.13407,2.13407,-0.150638,-0.150638,-0.150638,-0.150638,-0.150638,-0.150638,-0.150638,-0.150638,-0.150638,-0.150638,-0.150638,-0.150638,-0.150638,False,-1.92078,0.0259308,0.2243,0.314659,1.74974,0.148335,0.106355,0.733601,0.731849,-0.179973,-0.0847868,0.279599,0.0306377,0.834323,0.950884,0.0931303,-0.188892,0.717147,0.76943,1.18908,0.220609,-0.0531093,0.159407,-0.180435,0.10005,-0.192121,-1.09978,-0.312671,-0.295348,-0.242479,-0.162129,-0.301982,-0.305648,-0.25087,-0.206046,-0.330653,-0.322341,-0.245541,-0.267283,-0.242074,-0.231894,-0.306803,⋯
5,10711,2436009e,nonpretend,head,D,0,1,1,1,0,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,-1.178,-0.0141283,-0.0141283,3.04597,-1.178,-0.0141283,-0.0141283,3.04597,-0.0141283,-0.0141283,-0.0141283,-0.0141283,-0.0141283,-0.0141283,-0.0141283,-0.0141283,-0.0141283,-0.0141283,-0.0141283,-0.0141283,-0.0141283,-0.0141283,-0.0141283,-0.0141283,-0.0141283,-0.0141283,True,-1.72288,0.0367813,0.309038,0.183787,-1.72288,0.204778,0.147161,0.572333,1.19208,-0.245817,-0.115176,0.571375,0.229683,1.14628,1.30626,0.129011,-0.258058,0.985459,1.05722,1.63318,0.303973,-0.0716993,0.219974,-0.24645,0.138508,-0.26249,-1.12242,-0.319172,-0.275013,-0.215954,-1.12242,-0.291924,-0.301269,-0.152934,-0.13179,-0.365007,-0.343818,-0.232464,-0.287884,-0.139218,-0.113271,-0.304213,⋯
6,2186,129615f2,pretend,head,A,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,2.13382,0.125206,0.657422,0.497056,1.68383,0.411061,-0.437972,1.11652,0.504991,-0.572895,0.531325,0.146552,-0.437972,2.12242,0.606004,-0.446799,-0.572895,-0.0618794,0.51808,0.0723764,-0.194605,-0.446799,-0.446799,-0.572895,-0.564069,-0.572895,True,0.0499775,-0.108336,-0.171758,0.256743,1.34203,-0.119695,0.374407,0.25767,0.848798,0.05258,-0.531926,0.435341,0.460481,-0.254736,0.901957,0.361236,0.0398119,1.0033,0.699586,1.59127,0.379113,0.151885,0.456116,0.0519195,0.44769,0.035189,1.13005,1.91139,1.95043,1.72268,1.53632,1.86934,1.59778,1.772,1.93183,1.50535,1.99409,1.9528,1.58886,1.57222,1.81877,1.5848,⋯
7,416,1055e9b4,nonpretend,iowa,I,1,1,1,1,1,1,1,1,1,1,1,1,1,1,0,1,1,1,1,1,1,1,1,1,1,1,2.50882,-0.152149,-0.152149,-0.152149,-0.152149,-0.152149,-0.152149,-0.152149,2.50882,-0.152149,-0.152149,-0.152149,-0.152149,-0.152149,-1.39023,-0.152149,-0.152149,-0.152149,-0.152149,-0.152149,-0.152149,-0.152149,2.50882,-0.152149,-0.152149,-0.152149,True,0.664163,-0.0271438,0.182686,0.421958,1.79626,0.102333,0.0579268,0.721413,0.46881,-0.244944,-0.144258,0.384871,0.121527,0.827954,-2.27802,0.0439382,-0.254378,0.704008,0.759312,1.20321,0.178782,-0.11075,-0.280397,-0.245432,0.0512578,-0.257794,-0.288182,-0.382597,-0.366279,-0.347671,-0.240795,-0.372528,-0.375981,-0.324383,-0.303374,-0.399535,-0.391705,-0.350556,-0.371035,-0.316098,-1.40454,-0.377069,⋯
8,38382,5996298d,nonpretend,taylor swift,W,0,1,1,1,0,1,1,1,0,1,1,1,1,1,0,1,1,1,1,1,0,1,1,1,0,1,-1.23017,-0.206018,-0.206018,-0.206018,-1.23017,1.70937,-0.206018,-0.206018,-1.23017,-0.206018,-0.206018,1.70937,-0.206018,-0.206018,-1.23017,-0.206018,-0.206018,1.70937,1.70937,1.70937,-1.23017,-0.206018,1.70937,-0.206018,-1.23017,-0.206018,True,-1.69567,0.332845,0.657318,1.02732,-1.69567,-0.22845,0.464395,1.49038,-1.69567,-0.00395253,0.151744,0.208456,0.562743,1.65513,-1.69567,0.442764,-0.0185412,0.701957,0.787477,1.4739,-1.69567,0.203559,-0.21034,-0.00470712,-1.69567,-0.0238233,-1.07238,-0.475999,-0.343753,-0.192952,-1.07238,-0.358586,-0.422383,-0.00421976,-1.07238,-0.613268,-0.54981,-0.180516,-0.382299,0.0629283,-1.07238,-0.431199,⋯
9,27410,4725bac1,pretend,dalai lama,N,0,0,1,0,0,1,1,0,0,1,1,1,1,1,0,0,1,0,0,0,0,1,1,1,1,1,-1.20313,-1.20313,0.0116517,-1.20313,-1.20313,0.0116517,0.0116517,-1.20313,-1.20313,0.0116517,0.0116517,3.02079,3.02079,0.0116517,-1.20313,-1.20313,0.0116517,-1.20313,-1.20313,-1.20313,-1.20313,0.0116517,0.0116517,0.0116517,0.0116517,0.0116517,False,-1.71238,-1.71238,1.01537,-1.71238,-1.71238,0.843956,0.749229,-1.71238,-1.71238,0.103142,0.317926,0.779023,0.217254,2.39186,-1.71238,-1.71238,0.0830165,-1.71238,-1.71238,-1.71238,-1.71238,0.389405,0.868939,0.102101,0.735003,0.0757298,-0.929284,-0.929284,-0.167924,-0.929284,-0.929284,-0.33159,-0.422037,-0.929284,-0.929284,-1.03893,-0.833851,0.396779,-0.139607,1.14637,-0.929284,-0.929284,⋯
10,10771,24462a01,nonpretend,eleven,V,1,1,1,1,0,1,1,1,1,1,1,0,1,1,1,1,1,1,1,1,1,1,1,1,1,1,-0.087638,-0.087638,-0.087638,-0.087638,-1.35032,-0.087638,-0.087638,-0.087638,-0.087638,-0.087638,-0.087638,-1.35032,-0.087638,3.10134,-0.087638,-0.087638,-0.087638,-0.087638,-0.087638,-0.087638,-0.087638,3.10134,-0.087638,-0.087638,-0.087638,-0.087638,True,1.18851,-0.0319234,0.203936,0.47289,-2.20706,0.113615,0.0637003,0.809493,0.968926,-0.276742,-0.163566,-2.20706,0.13519,0.847142,1.06784,0.0479763,-0.287347,0.789929,0.852094,1.35106,0.199548,-0.20801,0.126779,-0.277291,0.056204,-0.291186,-0.168445,-0.398102,-0.353719,-0.303108,-1.25726,-0.370715,-0.380108,-0.239768,-0.209766,-0.444171,-0.422874,-1.25726,-0.366655,-0.146358,-0.191153,-0.383067,⋯


In [8]:
# Recode columns
letters = map(string, collect('A':'Z'))
subjects = unique(data[:,:subj_id])
conditions = ["pretend", "nonpretend"]
data[!, :letter_index] = indexin(data[!, :letter], letters)
data[!, :subj_index] = indexin(data[!, :subj_id], subjects)
data[!, :pretend] = map(x -> x=="pretend" ? 1 : 0, data[:, :test_part])

# Show twenty random rows of the new dataframe
data[rand(1:size(data, 1), 20), [:subj_id, :subj_index, :letter, :letter_index, :test_part, :pretend]]

Row,subj_id,subj_index,letter,letter_index,test_part,pretend
Unnamed: 0_level_1,String15,Union…,String1,Union…,String15,Int64
1,4615914b,322,T,20,nonpretend,0
2,8860f4e,490,A,1,nonpretend,0
3,27461098,158,N,14,nonpretend,0
4,535e2b0,392,V,22,nonpretend,0
5,199629fa,88,Y,25,pretend,1
6,55558a43,413,E,5,nonpretend,0
7,2515d445,136,L,12,nonpretend,0
8,14960f6e,46,D,4,pretend,1
9,48259ff8,344,T,20,pretend,1
10,212627e6,98,D,4,pretend,1


In [13]:
subjects[79]

"1885a9f5"

In [9]:
prior_cols = [Symbol("prior_", i) for i in 'A':'Z']
post_cols = [Symbol("post_", i) for i in 'A':'Z']
eig_cols = [Symbol("eig_", i) for i in 'A':'Z']
mask_cols = [Symbol("mask_", i) for i in 'A':'Z']
subj_col = :subj_index
pretend_col = :pretend
target_col = :letter_index

# Turing requires data in matrix and vector form.
priors = Matrix(data[!, prior_cols])
posts = Matrix(data[!, post_cols])
eig = Matrix(data[!, eig_cols])
masks = Matrix(data[!, mask_cols])
subj = data[!, subj_col]
pretend = data[!, pretend_col]
target = data[!, target_col]

41445-element Vector{Union{Nothing, Int64}}:
  1
  3
  4
  5
  9
 10
 11
 12
 13
 15
 18
 19
 20
  ⋮
  5
  9
 13
 14
 15
 20
  1
  5
  8
  9
 15
 20

In [10]:
# Bayesian multinomial logistic regression
@model function logistic_regression(priors, posts, eig, masks, pretend, y)
    
    n = size(priors, 1) # number of rows
    length(y) == size(posts,1) == size(masks,1) == n ||
        throw(DimensionMismatch("number of observations in `as` and 'bs' `y` is not equal"))

    coef_prior ~ Normal(0, 10)
    coef_post ~ Normal(0, 10)
    coef_eig ~ Normal(0, 10)

    # Subject-level parameters: condition effets
    coef_prior_delta ~ Normal(0, 10)
    coef_post_delta ~ Normal(0, 10)
    coef_eig_delta ~ Normal(0, 10)
    
    for i in 1:n
        v_unmasked = softmax((coef_prior*priors[i,:]+coef_post*posts[i,:]+coef_eig*eig[i,:]) +
            pretend[i]*(coef_prior_delta*priors[i,:]+coef_post_delta*posts[i,:]+coef_eig_delta*eig[i,:]))
        v = v_unmasked.*masks[i,:]
        if sum(v)>0
            v=v/sum(v)
        else 
            v = masks[i,:]/sum(masks[i,:]) # if everything sum to 0, just choose randomly from the possibilities
        end
        y[i] ~ Categorical(v)
    end
end;

In [11]:
"""
    outfun(m, outfn="output.csv")

output the coefficient table of a fitted model to a file
"""
outfun = function(m, outfn="output.csv")
    ct = coeftable(m)
    coef_df = DataFrame(ct.cols, :auto);
    rename!(coef_df, ct.colnms, makeunique = true)
    coef_df[!, :term] = ct.rownms;
    CSV.write(outfn, coef_df);
end

for i = 1:length(unique(subj))
    i
    m = logistic_regression(priors[(subj.==i),:], posts[(subj.==i),:], eig[(subj.==i),:], masks[(subj.==i),:], pretend[(subj.==i)], target[(subj.==i)])
    mle1 = optimize(m, MLE())
    outfun(mle1, "model_fits//E4//subj$(i)_orth.csv")
end

# t = @time chain = sample(m, HMC(0.05, 10), MCMCThreads(), 1_500, 3)
# t = @time chain = sample(m, Prior(), 1500)

# Save the chain to a file
# save("my_chain.jld", "chain", chain)
# CSV.write("my_chains.csv", DataFrame(chain))

[33m[1m└ [22m[39m[90m@ Turing C:\Users\tanzor\.julia\packages\Turing\Suzsv\src\modes\OptimInterface.jl:243[39m
[33m[1m└ [22m[39m[90m@ Turing C:\Users\tanzor\.julia\packages\Turing\Suzsv\src\modes\OptimInterface.jl:243[39m
[33m[1m└ [22m[39m[90m@ Turing C:\Users\tanzor\.julia\packages\Turing\Suzsv\src\modes\OptimInterface.jl:243[39m
[33m[1m└ [22m[39m[90m@ Turing C:\Users\tanzor\.julia\packages\Turing\Suzsv\src\modes\OptimInterface.jl:243[39m
