-
Notifications
You must be signed in to change notification settings - Fork 2
/
modelchains.jl
113 lines (98 loc) · 2.79 KB
/
modelchains.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
#################### ModelChains ####################
#################### Constructors ####################
"""
    ModelChains(c::Chains, m::Model, stats, stat_names, params, cs, samplers)

Build a `ModelChains` from the sampled values held in `c`, the model `m` they
were drawn from, and the accompanying sampler bookkeeping. See `Chains()`.

The fields `value`, `range`, `names`, `chains`, `trees` and `tree_names` are
read from `c` and forwarded, together with the remaining arguments, to the
field-wise `ModelChains` constructor.

# Arguments
- `c::Chains`: source of the sampled values and their metadata.
- `m::Model`: the model associated with the samples.
- `stats::Array{Float64}`: statistics array (presumably per-iteration sampler
  statistics — confirm).
- `stat_names::Vector{AbstractString}`: names labelling `stats`.
- `params::SimulationParameters`: simulation settings.
- `cs::Union{Nothing,ConvergenceStorage}`: convergence storage, or `nothing`.
- `samplers::Vector{Vector{Sampler}}`: sampler collection (presumably one
  vector of samplers per chain — confirm).
"""
function ModelChains(
    c::Chains,
    m::Model,
    stats::Array{Float64},
    stat_names::Vector{AbstractString},
    params::SimulationParameters,
    cs::Union{Nothing,ConvergenceStorage},
    samplers::Vector{Vector{Sampler}},
)
    return ModelChains(
        c.value,
        c.range,
        c.names,
        c.chains,
        m,
        c.trees,
        c.tree_names,
        stats,
        stat_names,
        params,
        cs,
        samplers,
    )
end
#################### Conversions ####################
# Strip a `ModelChains` down to a plain `Chains`: only the value array, range,
# names, chain indices, trees and tree names are kept; the model, statistics,
# simulation parameters, convergence storage and samplers are dropped.
function Base.convert(::Type{Chains}, mc::ModelChains)
    return Chains(
        mc.value,
        mc.range,
        mc.names,
        mc.chains,
        mc.trees,
        mc.tree_names,
    )
end
#################### Indexing ####################
# Subset `mc` by iteration `window`, parameter `names` and `chains`.
# `names` is first mapped to column indices with `names2inds`; the slicing itself
# is done by the `Chains` method on the converted object. The model, statistics,
# simulation parameters, convergence storage and samplers of `mc` are attached to
# the result as-is.
# NOTE(review): `stats`/`samplers` are not subset to `window`/`chains` — confirm
# this is intended.
function Base.getindex(mc::ModelChains, window, names, chains)
    plain = convert(Chains, mc)
    cols = names2inds(mc, names)
    subset = getindex(plain, window, cols, chains)
    return ModelChains(
        subset,
        mc.model,
        mc.stats,
        mc.stat_names,
        mc.sim_params,
        mc.conv_storage,
        mc.samplers,
    )
end
# Single-node form: wrap the key in a one-element vector and defer to the
# vector method.
function names2inds(mc::ModelChains, nodekey::Symbol)
    return names2inds(mc, Symbol[nodekey])
end
"""
    names2inds(mc::ModelChains, nodekeys::Vector{Symbol}) -> Vector{Int}

Return the positions in `mc.names` of every name belonging to the nodes in
`nodekeys`. Indices are grouped node by node in the order of `nodekeys`. Within
a node they follow the order of `names(mc.model, key)`.

Throws an `ArgumentError` naming every node for which at least one of its names
is not present in `mc.names`.
"""
function names2inds(mc::ModelChains, nodekeys::Vector{Symbol})::Vector{Int}
    inds = Int[]
    # Named `absent` (not `missing`) to avoid shadowing `Base.missing`.
    absent = Symbol[]
    for key in nodekeys
        # `indexin` yields `nothing` for each name not found in `mc.names`.
        keyinds = indexin(names(mc.model, key), mc.names)
        if nothing in keyinds
            push!(absent, key)
        else
            append!(inds, keyinds)
        end
    end
    if !isempty(absent)
        throw(
            ArgumentError(
                string("chain values are missing for nodes : ", join(absent, ", ")),
            ),
        )
    end
    return inds
end
# Return, as a `Vector{Symbol}`, the keys of model nodes of type `ntype` (with any
# extra `at...` arguments passed to `keys(model, ntype, at...)`) that are also
# dependent nodes, keeping only those whose names all appear in `mc.names`.
# When `ntype == :dependent` the dependent keys are used directly and `at...` is
# ignored.
function Base.keys(mc::ModelChains, ntype::Symbol, at...)
    model = mc.model
    candidates = if ntype == :dependent
        keys(model, :dependent)
    else
        intersect(keys(model, ntype, at...), keys(model, :dependent))
    end
    return Symbol[
        key for key in candidates if all(n -> n in mc.names, names(model, key))
    ]
end
#################### Auxiliary Functions ####################
# Return a transformed copy of `c.value`; `c.value` itself is not modified.
#
# 1. For each node that is both monitored and stochastic, the columns of
#    `c.names` that belong to the node are passed through
#    `unlist(node, relist(node, x), true)`. This is applied via `mapslices` along
#    dimension 2, i.e. once per fixed first/third index. The trailing `true` is
#    presumably the "transform to unconstrained scale" flag of `unlist` — confirm.
# 2. Each column not handled in step 1 whose minimum is > 0 is replaced by
#    `logit` of its values if its maximum is also < 1, otherwise by `log`.
#    Other columns are left as copied.
function link(c::ModelChains)
    linked = copy(c.value)
    untouched = 1:length(c.names)
    for key in intersect(keys(c.model, :monitor), keys(c.model, :stochastic))
        node = c.model[key]
        cols = findall((in)(names(node)), c.names)
        isempty(cols) && continue
        tolinked = x -> unlist(node, relist(node, x), true)
        linked[:, cols, :] = mapslices(tolinked, linked[:, cols, :], dims = 2)
        untouched = setdiff(untouched, cols)
    end
    for j in untouched
        column = linked[:, j, :]
        minimum(column) > 0.0 || continue
        if maximum(column) < 1.0
            linked[:, j, :] = logit.(column)
        else
            linked[:, j, :] = log.(column)
        end
    end
    return linked
end