-
Notifications
You must be signed in to change notification settings - Fork 1
/
nesting.jl
65 lines (52 loc) · 2.13 KB
/
nesting.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
"""
    _isfulldummy(term::AbstractTerm) -> Bool

Return `true` only when `term` is a `CategoricalTerm` whose contrasts matrix is a
`StatsModels.ContrastsMatrix{StatsModels.FullDummyCoding}`; every other term
(including categorical terms with any other coding) yields `false`.
"""
_isfulldummy(::AbstractTerm) = false

function _isfulldummy(term::CategoricalTerm)
    coding = term.contrasts
    return coding isa StatsModels.ContrastsMatrix{StatsModels.FullDummyCoding}
end
"""
    _fulldummycheck(outer::InteractionTerm)

Validate an interaction used as the outer (grouping) side of a nesting.

All terms of `outer` except the last must satisfy `_isfulldummy`, i.e. be
categorical terms with `FullDummyCoding`. The last term is not checked.
Returns `nothing` when the check passes.

# Throws
- `ArgumentError` if any of the checked (leading) terms is not full-dummy coded.
"""
function _fulldummycheck(outer::InteractionTerm)
    # The last term is deliberately excluded: an interaction produced by an
    # earlier nesting (`a / b` expands to `a + fulldummy(a) & b`) ends with the
    # previously nested term `b`, which keeps its own coding.
    all(_isfulldummy, outer.terms[1:end-1]) ||
        throw(ArgumentError("Outer interactions in a nesting must consist only " *
                            "of categorical terms with FullDummyCoding, got $outer"))
    return nothing
end

"""
    group / term

Generate predictors for `term` within each level of `group`. Implemented as
`group + fulldummy(group) & term`, i.e. the main effect of `group` plus the
interaction of `term` with a full-dummy-coded copy of `group`.
"""
function Base.:(/)(outer::CategoricalTerm, inner::AbstractTerm)
    # `&` binds tighter than `+`; parentheses only make that explicit.
    return outer + (fulldummy(outer) & inner)
end
# Nest every term of a tuple (e.g. `g / (a + b)`) within the categorical `outer`:
# the result starts from `outer` itself and then adds `fulldummy(outer) & t`
# for each term `t` of `inner`, left to right.
function Base.:(/)(outer::CategoricalTerm, inner::TupleTerm)
    grouping = fulldummy(outer)
    return foldl(inner; init=outer) do acc, term
        acc + (grouping & term)
    end
end
# With a tuple on the left (e.g. `(a + b) / c`), only the LAST term of `outer`
# acts as the grouping for `inner`; all preceding terms are carried over
# unchanged, so `(a + b) / c` becomes `a + (b / c)`.
function Base.:(/)(outer::TupleTerm, inner::Union{AbstractTerm, TupleTerm})
    leading = outer[1:end-1]
    nested = last(outer) / inner
    return leading + nested
end
# Nest `inner` within an interaction of grouping terms: `outer + outer & inner`.
function Base.:(/)(outer::InteractionTerm, inner::AbstractTerm)
    # An interaction is expected to arrive here only through repeated nesting,
    # e.g. `(a / b) / c`, where `a / b` has already produced `fulldummy(a) & b`.
    # A user could still write such an interaction by hand, so validate it.
    _fulldummycheck(outer)
    nested = outer & inner
    return outer + nested
end
# Nest each term of the tuple `inner` within the interaction `outer`:
# `outer + outer & t1 + outer & t2 + ...`, accumulated left to right.
function Base.:(/)(outer::InteractionTerm, inner::TupleTerm)
    # An interaction is expected to arrive here only through repeated nesting,
    # e.g. `(a / b) / (c + d)`, where `a / b` has already produced
    # `fulldummy(a) & b`. A user could still write such an interaction by hand,
    # so validate it.
    _fulldummycheck(outer)
    interactions = (outer & term for term in inner)
    return foldl(+, interactions; init=outer)
end
# Fallback for nesting when the grouping side is not categorical (the
# CategoricalTerm and InteractionTerm methods are more specific and take
# precedence). It always throws an `ArgumentError` explaining how to mark the
# grouping variable as categorical. `inner` may also be a tuple of terms (e.g.
# `x / (a + b)`), so those inputs get the same explanation instead of a bare
# `MethodError`.
function Base.:(/)(outer::AbstractTerm, inner::Union{AbstractTerm, TupleTerm})
    throw(ArgumentError("Nesting terms requires categorical grouping term, " *
                        "got $outer / $inner. Manually specify $outer as " *
                        "`CategoricalTerm` in hints/contrasts"))
end
# Turn a formula-level `outer / inner` (parsed as a call to `/`) into a nesting:
# apply the schema to both sides, then expand via the `/` methods for terms.
function StatsModels.apply_schema(
    t::FunctionTerm{typeof(/)},
    sch::StatsModels.FullRank,
    Mod::Type{<:RegressionModel},
)
    if length(t.args) != 2
        throw(ArgumentError("malformed nesting term: $t (Exactly two arguments required)"))
    end
    # A literal constant on either side (e.g. `x / 2`) is not a nesting; the
    # term is returned untouched. NOTE(review): presumably it is then handled
    # downstream as ordinary division — confirm.
    if any(arg -> arg isa ConstantTerm, t.args)
        return t
    end
    outer = apply_schema(first(t.args), sch, Mod)
    inner = apply_schema(last(t.args), sch, Mod)
    return outer / inner
end