In [1]:
# Load Turing.
using Turing

# Load CSV, DataFrames
using CSV, DataFrames

# Load StatsPlots for visualizations and diagnostics.
using StatsPlots

# Functionality for splitting and normalizing the data.
using MLDataUtils: shuffleobs, splitobs, rescale!

# We need a softmax function which is provided by NNlib.
using NNlib: softmax

# Functionality for constructing arrays with identical elements efficiently.
using FillArrays

# Functionality for working with scaled identity matrices.
using LinearAlgebra

# Set a seed for reproducibility.
using Random

# For save
using JLD

using ReverseDiff

using Memoization

using Optim, StatsBase

Random.seed!(0)

TaskLocalRNG()

In [2]:
# Import the dataset.
data = DataFrame(CSV.File("dataForTuring//E4orth.csv"))
# Show twenty random rows.
data[rand(1:size(data, 1), 20), :]

Row,Column1,subj_id,test_part,word,letter,mask_A,mask_B,mask_C,mask_D,mask_E,mask_F,mask_G,mask_H,mask_I,mask_J,mask_K,mask_L,mask_M,mask_N,mask_O,mask_P,mask_Q,mask_R,mask_S,mask_T,mask_U,mask_V,mask_W,mask_X,mask_Y,mask_Z,post_A,post_B,post_C,post_D,post_E,post_F,post_G,post_H,post_I,post_J,post_K,post_L,post_M,post_N,post_O,post_P,post_Q,post_R,post_S,post_T,post_U,post_V,post_W,post_X,post_Y,post_Z,hit_bin,prior_A,prior_B,prior_C,prior_D,prior_E,prior_F,prior_G,prior_H,prior_I,prior_J,prior_K,prior_L,prior_M,prior_N,prior_O,prior_P,prior_Q,prior_R,prior_S,prior_T,prior_U,prior_V,prior_W,prior_X,prior_Y,prior_Z,eig_A,eig_B,eig_C,eig_D,eig_E,eig_F,eig_G,eig_H,eig_I,eig_J,eig_K,eig_L,eig_M,eig_N,eig_O,eig_P,⋯
Unnamed: 0_level_1,Int64,String15,String15,String15,String1,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Bool,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,⋯
1,3146,1425ade0,pretend,taylor swift,T,1,1,1,1,0,1,1,1,0,1,1,1,1,0,0,1,1,1,1,1,1,1,1,1,1,1,0.830864,-0.369273,-0.369273,-0.369273,-0.369273,0.830864,-0.369273,-0.369273,-0.369273,-0.369273,-0.369273,0.830864,-0.369273,-0.369273,-0.369273,-0.369273,-0.369273,0.830864,0.830864,0.830864,-0.369273,-0.369273,0.830864,-0.369273,0.830864,-0.369273,True,2.09245,-0.39132,0.0886886,0.636048,-0.946493,-0.0951288,-0.196712,1.32108,-0.946493,-0.889562,-0.659232,0.551209,-0.051221,-0.946493,-0.946493,-0.228713,-0.911144,1.28127,1.40778,2.42324,0.0797582,-0.58258,-0.0683376,-0.890678,-0.211968,-0.918958,-0.438923,-0.438923,-0.438923,-0.438923,-0.438923,-0.438923,-0.438923,-0.438923,-0.438923,-0.438923,-0.438923,-0.438923,-0.438923,-0.438923,-0.438923,-0.438923,⋯
2,4598,1605d634,nonpretend,strawberry,N,0,1,1,1,0,1,1,1,0,1,0,1,0,1,0,0,1,1,1,1,0,1,1,1,1,1,-0.355635,1.18545,-0.355635,-0.355635,-0.355635,-0.355635,-0.355635,-0.355635,-0.355635,-0.355635,-0.355635,-0.355635,-0.355635,-0.355635,-0.355635,-0.355635,-0.355635,1.18545,1.18545,1.18545,-0.355635,-0.355635,1.18545,-0.355635,1.18545,-0.355635,False,-0.835584,-0.265437,0.227518,0.78964,-0.835584,0.0387427,-0.0655804,1.49315,-0.835584,-0.777117,-0.835584,0.702513,-0.835584,1.74345,-0.835584,-0.835584,-0.799281,1.45226,1.58219,2.62504,-0.835584,-0.461855,0.0662565,-0.778264,-0.081248,-0.807306,-0.327962,-0.327962,-0.327962,-0.327962,-0.327962,-0.327962,-0.327962,-0.327962,-0.327962,-0.327962,-0.327962,-0.327962,-0.327962,-0.327962,-0.327962,-0.327962,⋯
3,12327,2625e2cd,pretend,ninety six,N,1,1,1,1,0,1,1,0,0,1,1,1,1,1,1,1,1,1,0,0,1,1,0,0,0,1,-0.359493,-0.359493,-0.359493,-0.359493,-0.359493,-0.359493,-0.359493,-0.359493,-0.359493,-0.359493,-0.359493,-0.359493,-0.359493,8.98733,-0.359493,-0.359493,-0.359493,-0.359493,-0.359493,-0.359493,-0.359493,-0.359493,-0.359493,-0.359493,-0.359493,-0.359493,True,2.70555,-0.26793,0.30672,0.962,-0.932565,0.0866602,-0.0349519,-0.932565,-0.932565,-0.864409,-0.588666,0.860434,0.139225,2.07388,2.41155,-0.073262,-0.890246,1.73444,-0.932565,-0.932565,0.296029,-0.4969,-0.932565,-0.932565,-0.932565,-0.899601,-0.375629,-0.375629,-0.375629,-0.375629,-0.375629,-0.375629,-0.375629,-0.375629,-0.375629,-0.375629,-0.375629,-0.375629,-0.375629,-0.375629,-0.375629,-0.375629,⋯
4,20826,3865b208,nonpretend,head,E,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,0.859741,-0.0440042,0.195458,0.123304,0.657275,0.0846117,-0.297397,0.40202,0.126874,-0.358104,0.138723,-0.0343999,-0.297397,0.854611,0.172323,-0.301369,-0.358104,-0.12818,0.132763,-0.067774,-0.187898,-0.301369,-0.301369,-0.358104,-0.354133,-0.358104,True,1.10863,-0.604922,-0.273764,0.10386,2.27282,-0.400579,-0.470662,0.576466,0.800319,-0.948659,-0.789754,0.0453294,-0.370287,0.744613,0.9392,-0.492739,-0.963548,0.548998,0.63628,1.33685,-0.279925,-0.736872,-0.382096,-0.949429,-0.481187,-0.968939,2.92198,0.918056,1.67485,1.28014,2.82723,1.26627,-0.102907,2.17216,1.55075,-0.407659,1.50901,1.04104,-0.102907,3.30074,1.58518,-0.128717,⋯
5,22129,403606df,pretend,tooth,N,0,1,1,1,0,1,1,1,0,1,1,1,1,1,0,1,1,1,1,1,1,1,1,1,1,1,-0.314594,-0.314594,-0.314594,-0.314594,-0.314594,-0.314594,-0.314594,3.77513,-0.314594,-0.314594,-0.314594,-0.314594,-0.314594,-0.314594,-0.314594,-0.314594,-0.314594,-0.314594,-0.314594,3.77513,-0.314594,-0.314594,-0.314594,-0.314594,-0.314594,-0.314594,False,-0.697647,-0.279471,0.0820894,0.49438,-0.697647,-0.0563684,-0.132885,1.01037,-0.697647,-0.654764,-0.481272,0.430476,-0.0232955,1.19396,-0.697647,-0.156989,-0.671021,0.980384,1.07568,1.84056,0.0753627,-0.423534,-0.0361883,-0.655605,-0.144376,-0.676906,-0.271873,-0.271873,-0.271873,-0.271873,-0.271873,-0.271873,-0.271873,-0.271873,-0.271873,-0.271873,-0.271873,-0.271873,-0.271873,-0.271873,-0.271873,-0.271873,⋯
6,30412,506296a,pretend,eleven,O,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,-0.364273,-0.364273,-0.364273,-0.364273,1.07074,-0.364273,0.209733,0.209733,0.496736,-0.364273,-0.364273,0.496736,-0.364273,0.496736,-0.07727,-0.364273,-0.364273,-0.07727,-0.364273,1.07074,-0.364273,0.209733,0.209733,-0.364273,0.783739,-0.364273,False,1.0016,-0.546518,-0.247332,0.0938324,2.05338,-0.361904,-0.425221,0.52081,0.723051,-0.857069,-0.713506,0.040953,-0.334537,0.672722,0.848523,-0.445166,-0.870521,0.495994,0.574849,1.20778,-0.252899,-0.665729,-0.345206,-0.857765,-0.43473,-0.875391,-0.361109,-0.361109,-0.361109,-0.361109,2.77456,-0.361109,1.50564,1.50564,2.02328,-0.361109,-0.361109,2.11719,-0.361109,2.11719,0.683681,-0.361109,⋯
7,16547,3255eacb,pretend,taylor swift,O,0,0,1,0,0,1,1,1,0,1,1,0,1,1,1,1,1,0,1,0,1,1,1,1,0,1,-0.363835,-0.363835,-0.363835,-0.363835,-0.363835,2.00109,-0.363835,-0.363835,-0.363835,-0.363835,-0.363835,-0.363835,-0.363835,-0.363835,2.00109,-0.363835,-0.363835,-0.363835,2.00109,-0.363835,-0.363835,-0.363835,2.00109,-0.363835,-0.363835,-0.363835,True,-0.94722,-0.94722,0.560677,-0.94722,-0.94722,0.29292,0.144948,2.35584,-0.94722,-0.864291,-0.528781,-0.94722,0.356878,2.71087,3.12172,0.0983348,-0.895728,-0.94722,2.48213,-0.94722,0.547669,-0.417125,0.331945,-0.865917,-0.94722,-0.907111,-0.404725,-0.404725,-0.404725,-0.404725,-0.404725,-0.404725,-0.404725,-0.404725,-0.404725,-0.404725,-0.404725,-0.404725,-0.404725,-0.404725,-0.404725,-0.404725,⋯
8,1743,12355974,pretend,lemon,A,1,1,1,1,0,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,0.472936,0.152691,0.457671,-0.363003,-0.363003,-0.363003,-0.363003,0.457671,-0.363003,-0.363003,-0.363003,0.773808,0.773808,0.773808,0.773808,0.472936,-0.363003,0.167955,-0.347739,-0.363003,-0.363003,-0.363003,-0.363003,-0.363003,0.152691,-0.363003,False,1.37915,-0.535892,-0.165794,0.256233,-0.963944,-0.307522,-0.385845,0.784412,1.03459,-0.920049,-0.742459,0.19082,-0.273668,0.97233,1.1898,-0.410518,-0.936689,0.753714,0.851259,1.6342,-0.172679,-0.683358,-0.286865,-0.920909,-0.397607,-0.942714,1.05387,0.717292,1.02561,-0.39923,-0.39923,-0.39923,-0.39923,1.02561,-0.39923,-0.39923,-0.39923,1.50633,1.50633,1.27392,1.27392,1.05387,⋯
9,2568,1365ff5f,pretend,dalai lama,A,1,1,1,1,0,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,0.638681,-0.0207043,-0.178659,0.193875,-0.357475,-0.312771,0.120859,-0.0277825,0.376045,-0.312771,-0.0892506,0.367104,-0.133955,0.497491,0.424847,-0.178659,-0.357475,0.336929,0.139858,-0.0557225,-0.0110184,-0.312771,-0.312771,-0.312771,0.19164,-0.312771,True,1.42384,-0.553257,-0.171166,0.264536,-0.995179,-0.317486,-0.398347,0.80983,1.06811,-0.949862,-0.766517,0.197004,-0.282535,1.00384,1.22835,-0.42382,-0.967041,0.778137,0.878843,1.68716,-0.178275,-0.705501,-0.29616,-0.95075,-0.410491,-0.973261,4.57694,2.23567,1.19488,3.33048,-0.414113,0.0172918,2.75396,2.29062,3.96805,0.0172918,1.60988,3.96308,1.54815,4.31104,4.14021,0.894898,⋯
10,24499,4365a3bd,nonpretend,taylor swift,R,0,0,1,1,0,0,1,1,0,1,1,0,1,1,0,1,1,1,1,1,0,1,1,1,1,1,-0.366025,-0.366025,-0.366025,-0.366025,-0.366025,-0.366025,-0.366025,-0.366025,-0.366025,-0.366025,-0.366025,-0.366025,-0.366025,-0.366025,-0.366025,-0.366025,-0.366025,1.53731,1.53731,1.53731,-0.366025,-0.366025,1.53731,-0.366025,1.53731,-0.366025,True,-0.961334,-0.961334,0.322693,1.00163,-0.961334,-0.961334,-0.0313143,1.85134,-0.961334,-0.890717,-0.605019,-0.961334,0.149151,2.15366,-0.961334,-0.0710074,-0.917487,1.80196,1.95888,3.21845,-0.961334,-0.50994,0.12792,-0.892102,-0.0502378,-0.92718,-0.404541,-0.404541,-0.404541,-0.404541,-0.404541,-0.404541,-0.404541,-0.404541,-0.404541,-0.404541,-0.404541,-0.404541,-0.404541,-0.404541,-0.404541,-0.404541,⋯


In [3]:
# Recode columns
letters = map(string, collect('A':'Z'))
subjects = unique(data[:,:subj_id])
conditions = ["pretend", "nonpretend"]
data[!, :letter_index] = indexin(data[!, :letter], letters)
data[!, :subj_index] = indexin(data[!, :subj_id], subjects)
data[!, :pretend] = map(x -> x=="pretend" ? 1 : 0, data[:, :test_part])

# Show twenty random rows of the new dataframe
data[rand(1:size(data, 1), 20), [:subj_id, :subj_index, :letter, :letter_index, :test_part, :pretend]]

Row,subj_id,subj_index,letter,letter_index,test_part,pretend
Unnamed: 0_level_1,String15,Union…,String1,Union…,String15,Int64
1,3915d649,257,N,14,nonpretend,0
2,24160734,126,T,20,pretend,1
3,2566165d,141,I,9,nonpretend,0
4,4715d51f,331,I,9,nonpretend,0
5,1095eac8,10,A,1,pretend,1
6,110601e1,11,N,14,nonpretend,0
7,1485ddd9,45,L,12,pretend,1
8,3935afe3,259,N,14,pretend,1
9,36961206,236,V,22,pretend,1
10,7460fbe,478,N,14,nonpretend,0


In [13]:
subjects[79]

"1885a9f5"

In [4]:
prior_cols = [Symbol("prior_", i) for i in 'A':'Z']
post_cols = [Symbol("post_", i) for i in 'A':'Z']
eig_cols = [Symbol("eig_", i) for i in 'A':'Z']
mask_cols = [Symbol("mask_", i) for i in 'A':'Z']
subj_col = :subj_index
pretend_col = :pretend
target_col = :letter_index

# Turing requires data in matrix and vector form.
priors = Matrix(data[!, prior_cols])
posts = Matrix(data[!, post_cols])
eig = Matrix(data[!, eig_cols])
masks = Matrix(data[!, mask_cols])
subj = data[!, subj_col]
pretend = data[!, pretend_col]
target = data[!, target_col]

41445-element Vector{Union{Nothing, Int64}}:
  1
  3
  4
  5
  9
 10
 11
 12
 13
 15
 18
 19
 20
  ⋮
  5
  9
 13
 14
 15
 20
  1
  5
  8
  9
 15
 20

In [5]:
# Bayesian multinomial logistic regression
@model function logistic_regression(priors, posts, eig, masks, pretend, y)
    
    n = size(priors, 1) # number of rows
    length(y) == size(posts,1) == size(masks,1) == n ||
        throw(DimensionMismatch("number of observations in `as` and 'bs' `y` is not equal"))

    coef_prior ~ Normal(0, 10)
    coef_post ~ Normal(0, 10)
    coef_eig ~ Normal(0, 10)

    # Subject-level parameters: condition effets
    coef_prior_delta ~ Normal(0, 10)
    coef_post_delta ~ Normal(0, 10)
    coef_eig_delta ~ Normal(0, 10)
    
    for i in 1:n
        v_unmasked = softmax((coef_prior*priors[i,:]+coef_post*posts[i,:]+coef_eig*eig[i,:]) +
            pretend[i]*(coef_prior_delta*priors[i,:]+coef_post_delta*posts[i,:]+coef_eig_delta*eig[i,:]))
        v = v_unmasked.*masks[i,:]
        if sum(v)>0
            v=v/sum(v)
        else 
            v = masks[i,:]/sum(masks[i,:]) # if everything sum to 0, just choose randomly from the possibilities
        end
        y[i] ~ Categorical(v)
    end
end;

In [6]:
"""
    outfun(m, outfn="output.csv")

output the coefficient table of a fitted model to a file
"""
outfun = function(m, outfn="output.csv")
    ct = coeftable(m)
    coef_df = DataFrame(ct.cols, :auto);
    rename!(coef_df, ct.colnms, makeunique = true)
    coef_df[!, :term] = ct.rownms;
    CSV.write(outfn, coef_df);
end

for i = 1:length(unique(subj))
    i
    m = logistic_regression(priors[(subj.==i),:], posts[(subj.==i),:], eig[(subj.==i),:], masks[(subj.==i),:], pretend[(subj.==i)], target[(subj.==i)])
    mle1 = optimize(m, MLE())
    outfun(mle1, "model_fits//E4//subj$(i)_orth.csv")
end

# t = @time chain = sample(m, HMC(0.05, 10), MCMCThreads(), 1_500, 3)
# t = @time chain = sample(m, Prior(), 1500)

# Save the chain to a file
# save("my_chain.jld", "chain", chain)
# CSV.write("my_chains.csv", DataFrame(chain))