In [1]:
using LinearAlgebra, Distributions, Combinatorics, Random, Kronecker, SpecialFunctions
include("../DCM_model/BBVI_utils.jl")

"""
    TDCMObs{T<:AbstractFloat}

Container for the observed data of the TDCM.

# Fields
- `Y::Array{Int,3}`: integer response array. NOTE(review): the notebook output
  shows a 1000×2×25 array; presumably respondents × time points × items — confirm.
- `Q::Matrix{Int}`: Q-matrix (loaded from the R object's `Q_matrix` component).
- `D::Vector{Matrix{Int}}`: delta matrices; the 5-argument convenience constructor
  derives them from `Q` via `generate_delta`.
- `U::Vector{Vector{Matrix{T}}}`: nested covariate matrices (filled from the R
  object's `X_group` component in the loading cell).
- `X::Vector{Vector{Matrix{T}}}`: nested covariate matrices (filled from the R
  object's `X_ind` component in the loading cell).
- `group::Vector{Int}`: group labels (presumably one per respondent — confirm).
"""
struct TDCMObs{T<:AbstractFloat}
    Y::Array{Int, 3}
    Q::Matrix{Int}
    D::Vector{Matrix{Int}}
    U::Vector{Vector{Matrix{T}}}
    X::Vector{Vector{Matrix{T}}}
    group::Vector{Int}
end

"""
    TDCMObs(Y, Q, U, X, group)

Convenience constructor: compute the delta matrices `D` from the Q-matrix with
`generate_delta` (presumably provided by the included `BBVI_utils.jl`), then
forward everything to the field-wise constructor.
"""
function TDCMObs(Y::Array{Int, 3}, Q::Matrix{Int},
                 U::Vector{Vector{Matrix{T}}}, X::Vector{Vector{Matrix{T}}},
                 group::Vector{Int}) where {T <: AbstractFloat}
    return TDCMObs(Y, Q, generate_delta(Q), U, X, group)
end

"""
    TDCMmodel{T<:AbstractFloat}

Object bundling the observed data, the prior hyperparameters, and the
variational-distribution parameters of the TDCM fit.

The nesting depth of each parameter field mirrors the model's index structure.
TODO(review): document which model index each nesting level corresponds to.

# Fields
- `obs`: the observed data (`TDCMObs{T}`).
- `mu_omega_prior`, `V_omega_prior`, `a_tau_prior`, `b_tau_prior`: prior
  distribution parameters.
- `enable_parallel`: when `true`, extra memory is allocated based on the number
  of threads available in the environment.
- `pi_star`, `mu_beta_star`, `mu_gamma_init_star`, `V_gamma_init_star`,
  `mu_gamma_t_star`, `V_gamma_t_star`, `mu_omega_star`, `V_omega_star`,
  `a_tau_star`, `b_tau_star`: variational distribution parameters.
"""
struct TDCMmodel{T <: AbstractFloat}
    # data; parameterized by `T` so the field is concretely typed (an unparameterized
    # `TDCMObs` is a UnionAll and would box the field)
    obs                 :: TDCMObs{T}
    # prior distribution parameters
    # (one declaration per line, no trailing commas: a trailing comma would join
    # these lines into a single tuple expression instead of separate fields)
    mu_omega_prior      :: Vector{Vector{Vector{Vector{T}}}}
    V_omega_prior       :: Vector{Vector{Vector{Matrix{T}}}}
    a_tau_prior         :: Vector{Vector{Vector{T}}}
    b_tau_prior         :: Vector{Vector{Vector{T}}}
    # This option allocates extra memory based on the number of threads available in the environment
    enable_parallel     :: Bool
    # Variational distribution parameters
    pi_star             :: Vector{Vector{Vector{T}}}
    mu_beta_star        :: Vector{Vector{T}}
    mu_gamma_init_star  :: Vector{Vector{Vector{T}}}
    V_gamma_init_star   :: Vector{Vector{Matrix{T}}}
    mu_gamma_t_star     :: Vector{Vector{Vector{Vector{T}}}}
    V_gamma_t_star      :: Vector{Vector{Vector{Matrix{T}}}}
    mu_omega_star       :: Vector{Vector{Vector{Vector{Vector{T}}}}}
    V_omega_star        :: Vector{Vector{Vector{Vector{Matrix{T}}}}}
    a_tau_star          :: Vector{Vector{Vector{Vector{T}}}}
    b_tau_star          :: Vector{Vector{Vector{Vector{T}}}}
end

In [2]:
using RCall

R"""
load("TDCM_data.RData")
"""
TDCM_data = @rget data_single_level

# `rcopy` hands single-column covariates back as Julia `Vector`s, but `TDCMObs`
# stores every covariate block as a `Matrix`. Promote numeric vectors to
# single-column matrices and pass anything else through unchanged, so the
# `Matrix{Float64}` conversion below still reports unexpected element types.
as_covariate_matrix(x::AbstractVector{<:Number}) = reshape(x, :, 1)
as_covariate_matrix(x) = x

# Build fresh nested arrays instead of assigning into `TDCM_data` in place:
# in-place assignment throws when an inner container is concretely typed as a
# vector of vectors, and it would silently modify the loaded data.
as_covariate_blocks(nested) = [[as_covariate_matrix(x) for x in inner] for inner in nested]

Y = Array{Int, 3}(TDCM_data[:Y])
Q = convert(Matrix{Int64}, TDCM_data[:Q_matrix])
# Group- and individual-level covariates get the same vector → matrix promotion.
U = Vector{Vector{Matrix{Float64}}}(as_covariate_blocks(TDCM_data[:X_group]))
X = Vector{Vector{Matrix{Float64}}}(as_covariate_blocks(TDCM_data[:X_ind]))
group = Vector{Int64}(TDCM_data[:group])
obs = TDCMObs(Y, Q, U, X, group)
;

In [5]:
# Inspect the delta matrices that the `TDCMObs` constructor derived from Q via `generate_delta`.
getproperty(obs, :D)

25-element Vector{Matrix{Int64}}:
 [1 0; 1 1; 1 0; 1 1]
 [1 0 0 0; 1 1 0 0; 1 0 1 0; 1 1 1 1]
 [1 0; 1 0; 1 1; 1 1]
 [1 0 0 0; 1 1 0 0; 1 0 1 0; 1 1 1 1]
 [1 0; 1 1; 1 0; 1 1]
 [1 0; 1 1; 1 0; 1 1]
 [1 0 0 0; 1 1 0 0; 1 0 1 0; 1 1 1 1]
 [1 0 0 0; 1 1 0 0; 1 0 1 0; 1 1 1 1]
 [1 0 0 0; 1 1 0 0; 1 0 1 0; 1 1 1 1]
 [1 0; 1 0; 1 1; 1 1]
 [1 0 0 0; 1 1 0 0; 1 0 1 0; 1 1 1 1]
 [1 0; 1 0; 1 1; 1 1]
 [1 0; 1 1; 1 0; 1 1]
 [1 0 0 0; 1 1 0 0; 1 0 1 0; 1 1 1 1]
 [1 0; 1 1; 1 0; 1 1]
 [1 0; 1 0; 1 1; 1 1]
 [1 0; 1 0; 1 1; 1 1]
 [1 0; 1 0; 1 1; 1 1]
 [1 0 0 0; 1 1 0 0; 1 0 1 0; 1 1 1 1]
 [1 0; 1 1; 1 0; 1 1]
 [1 0; 1 0; 1 1; 1 1]
 [1 0; 1 0; 1 1; 1 1]
 [1 0 0 0; 1 1 0 0; 1 0 1 0; 1 1 1 1]
 [1 0; 1 1; 1 0; 1 1]
 [1 0; 1 0; 1 1; 1 1]

In [25]:
# Did the first entry of the first element of `X_ind` arrive from R as a plain
# numeric vector (rather than a matrix)? `let` keeps the helper name out of the globals.
let first_cov = TDCM_data[:X_ind][1][1]
    isa(first_cov, Vector{<:Number})
end

true

In [27]:
# Echo the response array so the notebook renders it.
identity(Y)

1000×2×25 Array{Float64, 3}:
[:, :, 1] =
 0.0  0.0
 0.0  0.0
 0.0  0.0
 1.0  0.0
 0.0  0.0
 0.0  0.0
 0.0  1.0
 0.0  0.0
 0.0  0.0
 0.0  0.0
 1.0  1.0
 0.0  0.0
 0.0  0.0
 ⋮    
 1.0  1.0
 0.0  0.0
 0.0  0.0
 0.0  0.0
 0.0  0.0
 0.0  0.0
 1.0  1.0
 0.0  0.0
 0.0  0.0
 1.0  0.0
 1.0  0.0
 1.0  0.0

[:, :, 2] =
 0.0  0.0
 0.0  0.0
 0.0  0.0
 0.0  0.0
 0.0  0.0
 0.0  0.0
 0.0  0.0
 0.0  0.0
 0.0  0.0
 0.0  0.0
 0.0  1.0
 0.0  0.0
 1.0  0.0
 ⋮    
 0.0  1.0
 1.0  0.0
 0.0  1.0
 0.0  0.0
 1.0  0.0
 0.0  0.0
 0.0  1.0
 0.0  1.0
 0.0  0.0
 0.0  0.0
 0.0  0.0
 0.0  0.0

[:, :, 3] =
 0.0  0.0
 0.0  0.0
 1.0  0.0
 0.0  1.0
 0.0  0.0
 0.0  0.0
 0.0  1.0
 0.0  1.0
 0.0  0.0
 0.0  0.0
 0.0  0.0
 1.0  0.0
 0.0  0.0
 ⋮    
 0.0  1.0
 0.0  0.0
 0.0  0.0
 0.0  0.0
 0.0  0.0
 0.0  0.0
 0.0  0.0
 0.0  1.0
 0.0  0.0
 0.0  0.0
 0.0  0.0
 0.0  0.0

;;; … 

[:, :, 23] =
 0.0  0.0
 0.0  0.0
 0.0  0.0
 0.0  0.0
 0.0  0.0
 0.0  0.0
 1.0  0.0
 0.0  0.0
 0.0  0.0
 0.0  0.0
 0.0  0.0
 1.0  0.0
 0.0  0.0
 ⋮    
 1.