-
Notifications
You must be signed in to change notification settings - Fork 6
/
growing_spheres.jl
41 lines (36 loc) · 1.28 KB
/
growing_spheres.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
"""
    generate_counterfactual(
        x::Matrix,
        target::RawTargetType,
        data::DataPreprocessing.CounterfactualData,
        M::Models.AbstractFittedModel,
        generator::Generators.GrowingSpheresGenerator;
        num_counterfactuals::Int=1,
        convergence::Union{AbstractConvergence,Symbol}=Convergence.DecisionThresholdConvergence(;
            decision_threshold=(1 / length(data.y_levels)), max_iter=1000
        ),
        kwrgs...,
    )

Method of `generate_counterfactual` specialised for the `GrowingSpheresGenerator`.

A `CounterfactualExplanation` is constructed from the positional arguments together with
`num_counterfactuals` and `convergence`. It is then passed, in this order, through
`Generators.growing_spheres_generation!` and `Generators.feature_selection!`. Finally its
`x′` field is set to its `s′` field and the explanation is returned.

# Keywords
- `num_counterfactuals`: forwarded to the `CounterfactualExplanation` constructor
  (default `1`).
- `convergence`: forwarded to the `CounterfactualExplanation` constructor. Defaults to a
  `Convergence.DecisionThresholdConvergence` with
  `decision_threshold = 1 / length(data.y_levels)` and `max_iter = 1000`.
- `kwrgs...`: any further keyword arguments are accepted but are not used by this method.

# Returns
The `CounterfactualExplanation` object after both generator steps have been applied.
"""
function generate_counterfactual(
    x::Matrix,
    target::RawTargetType,
    data::DataPreprocessing.CounterfactualData,
    M::Models.AbstractFittedModel,
    generator::Generators.GrowingSpheresGenerator;
    num_counterfactuals::Int=1,
    convergence::Union{AbstractConvergence,Symbol}=Convergence.DecisionThresholdConvergence(;
        decision_threshold=(1 / length(data.y_levels)), max_iter=1000
    ),
    kwrgs...,
)
    explanation = CounterfactualExplanation(
        x,
        target,
        data,
        M,
        generator;
        num_counterfactuals=num_counterfactuals,
        convergence=convergence,
    )

    # Both steps mutate `explanation` in place; feature selection runs on the result of
    # the sphere-growing step, so the order must be kept.
    Generators.growing_spheres_generation!(explanation)
    Generators.feature_selection!(explanation)

    # Growing spheres does not support encodings, so x′ is simply s′.
    explanation.x′ = explanation.s′

    return explanation
end