# ensemble.jl
"""
FluxEnsemble <: AbstractFluxModel
Constructor for deep ensembles trained in `Flux.jl`.
"""
struct FluxEnsemble <: AbstractFluxModel
model::Any
likelihood::Symbol
function FluxEnsemble(model, likelihood)
if likelihood ∈ [:classification_binary, :classification_multi]
new(model, likelihood)
else
throw(
ArgumentError(
"`type` should be in `[:classification_binary,:classification_multi]`"
),
)
end
end
end
# Outer constructor method:
function FluxEnsemble(model; likelihood::Symbol=:classification_binary)
@.(Flux.testmode!(model))
return FluxEnsemble(model, likelihood)
end
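# Usage sketch (illustrative, not part of the package API): assuming `Flux` is
# available, a deep ensemble can be wrapped from a plain vector of chains. The
# architecture (2 inputs, 32 hidden units, 1 output) and the ensemble size of 5
# are hypothetical choices for illustration only.
using Flux: Chain, Dense, relu

example_chains = [Chain(Dense(2, 32, relu), Dense(32, 1)) for _ in 1:5]
example_ensemble = FluxEnsemble(example_chains; likelihood=:classification_binary)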
"""
    logits(M::FluxEnsemble, X::AbstractArray)

Computes the average logits across all ensemble members.
"""
function logits(M::FluxEnsemble, X::AbstractArray)
    return sum(map(nn -> nn(X), M.model)) / length(M.model)
end

"""
    probs(M::FluxEnsemble, X::AbstractArray)

Computes the average predicted probabilities across all ensemble members: sigmoid outputs in the binary case, softmax outputs in the multi-class case.
"""
function probs(M::FluxEnsemble, X::AbstractArray)
    if M.likelihood == :classification_binary
        output = sum(map(nn -> Flux.σ.(nn(X)), M.model)) / length(M.model)
    elseif M.likelihood == :classification_multi
        output = sum(map(nn -> Flux.softmax(nn(X)), M.model)) / length(M.model)
    end
    return output
end
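# Sketch continuing the example above: `probs` averages the sigmoid output of
# each ensemble member over a (hypothetical) batch of 10 observations stored
# column-wise, as `Flux` models expect.
example_X = rand(Float32, 2, 10)
example_p = probs(example_ensemble, example_X)   # 1×10 matrix of probabilities
example_ŷ = Int.(example_p .> 0.5)               # hard labels for the binary case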
"""
FluxModelParams
Default Deep Ensemble training parameters.
"""
Base.@kwdef struct FluxEnsembleParams
loss::Symbol = :logitbinarycrossentropy
opt::Symbol = :Adam
n_epochs::Int = 100
batchsize::Int = 1
end
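# Sketch: thanks to `Base.@kwdef`, individual defaults can be overridden via
# keyword arguments. The values below are illustrative only.
example_params = FluxEnsembleParams(; n_epochs=50, batchsize=32)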
"""
train(M::FluxEnsemble, data::CounterfactualData; kwargs...)
Wrapper function to retrain.
"""
function train(M::FluxEnsemble, data::CounterfactualData; args=flux_training_params)
# Prepare data:
data = data_loader(data; batchsize=args.batchsize)
# Multi-class case:
if M.likelihood == :classification_multi
loss = :logitcrossentropy
else
loss = args.loss
end
# Setup:
ensemble = M.model
if flux_training_params.verbose
@info "Begin training Deep Ensemble"
end
count = 1
n_models = length(ensemble)
for model in ensemble
# Model name
models_done = repeat("#", count)
models_missing = repeat("-", n_models - count)
msg = "MLP $(count): $(models_done)$(models_missing) ($(count)/$(n_models))"
# Train:
forward!(
model,
data;
loss=args.loss,
opt=args.opt,
n_epochs=args.n_epochs,
model_name=msg,
)
count += 1
end
return M
end
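# Sketch: retraining the wrapped ensemble on a `CounterfactualData` object.
# `counterfactual_data` is a hypothetical, pre-constructed data object; when no
# `args` are passed, the package-level `flux_training_params` defaults apply.
#
#   M = train(M, counterfactual_data)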
"""
build_ensemble(K::Int;kw=(input_dim=2,n_hidden=32,output_dim=1))
Helper function that builds an ensemble of `K` models.
"""
function build_ensemble(K::Int; kwargs...)
ensemble = [build_mlp(; kwargs...) for i in 1:K]
return ensemble
end
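# Sketch: an ensemble of five MLPs; the dimensions below are illustrative and
# are simply forwarded to `build_mlp` as keyword arguments.
example_mlps = build_ensemble(5; input_dim=2, n_hidden=32, output_dim=1)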
"""
    FluxEnsemble(data::CounterfactualData, K::Int=5; kwargs...)

Constructs a deep ensemble of `K` models whose input and output dimensions are inferred from `data`.
"""
function FluxEnsemble(data::CounterfactualData, K::Int=5; kwargs...)

    # Basic setup:
    X, y = CounterfactualExplanations.DataPreprocessing.unpack_data(data)
    input_dim = size(X, 1)
    output_dim = size(y, 1)

    # Build deep ensemble:
    ensemble = build_ensemble(K; input_dim=input_dim, output_dim=output_dim, kwargs...)

    M = FluxEnsemble(ensemble; likelihood=data.likelihood)

    return M
end
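# Sketch: building a five-member ensemble directly from a (hypothetical)
# `CounterfactualData` object; dimensions and likelihood are taken from the data.
#
#   M = FluxEnsemble(counterfactual_data, 5)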