-
Notifications
You must be signed in to change notification settings - Fork 6
/
core_struct.jl
120 lines (105 loc) · 3.61 KB
/
core_struct.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
using ..GenerativeModels: GenerativeModels
using MultivariateStats: MultivariateStats
"""
    CounterfactualExplanation

A struct that collects all information relevant to a specific counterfactual explanation
for a single individual.

Instances are normally built with the outer constructor
`CounterfactualExplanation(x, target, data, M, generator; kwargs...)`, which also calls
`initialize!` on the new instance.
"""
mutable struct CounterfactualExplanation <: AbstractCounterfactualExplanation
    # The factual: the individual whose prediction is being explained.
    x::AbstractArray
    # The desired target in raw form, and the same target encoded via
    # `data.output_encoder` (as done by the outer constructor).
    target::RawTargetType
    target_encoded::EncodedTargetType
    # Counterfactual state that the search operates on. It starts out equal to `x` and is
    # passed through `encode_state!` during `initialize!`, so it may live in an encoded
    # space — NOTE(review): confirm the exact encoding semantics in `encode_state!`.
    s′::AbstractArray
    # Counterfactual in feature space; starts out equal to `x` — presumably kept in sync
    # with `s′` by `decode_state!`; verify.
    x′::AbstractArray
    # Data, fitted model to explain, and the generator driving the search.
    data::DataPreprocessing.CounterfactualData
    M::Models.AbstractFittedModel
    generator::Generators.AbstractGenerator
    # Search log; `nothing` until `initialize!` replaces it with a `Dict`.
    search::Union{Dict,Nothing}
    # Convergence criterion used to decide when the search stops.
    convergence::AbstractConvergence
    num_counterfactuals::Int
    # Initialization strategy, e.g. `:add_perturbation` (the constructor default).
    initialization::Symbol
end
"""
    CounterfactualExplanation(
        x::Union{AbstractArray,Integer},
        target::RawTargetType,
        data::CounterfactualData,
        M::Models.AbstractFittedModel,
        generator::Generators.AbstractGenerator;
        num_counterfactuals::Int=1,
        initialization::Symbol=:add_perturbation,
        convergence::Union{AbstractConvergence,Symbol}=:decision_threshold,
    )

Outer method to construct a `CounterfactualExplanation` structure for the factual `x` and
the desired `target`, and initialize it via `initialize!`.

# Arguments
- `x`: the factual, or an integer index that is passed to `select_factual(data, index)`
  to pick the factual out of `data`.
- `target`: the desired target in raw form; it is encoded with `data.output_encoder`.
- `data`: the counterfactual data. NOTE: `data` is mutated when `generator` needs a
  latent-space (`generator.latent_space`) or dimensionality-reduction
  (`generator.dim_reduction`) encoder that `data.input_encoder` does not already provide:
  a default VAE or PCA is fitted and stored in `data.input_encoder`.
- `M`: the fitted model to explain.
- `generator`: the counterfactual generator; a `deepcopy` of it is stored.

# Keywords
- `num_counterfactuals`: number of counterfactuals to generate (default `1`).
- `initialization`: initialization strategy (default `:add_perturbation`).
- `convergence`: convergence criterion, either an `AbstractConvergence` or a `Symbol`
  resolved through `Convergence.get_convergence_type` (default `:decision_threshold`).

# Returns
An initialized `CounterfactualExplanation`.
"""
function CounterfactualExplanation(
    x::Union{AbstractArray,Integer},
    target::RawTargetType,
    data::CounterfactualData,
    M::Models.AbstractFittedModel,
    generator::Generators.AbstractGenerator;
    num_counterfactuals::Int=1,
    initialization::Symbol=:add_perturbation,
    convergence::Union{AbstractConvergence,Symbol}=:decision_threshold,
)
    # Resolve a symbolic convergence criterion into a concrete convergence object:
    convergence = Convergence.get_convergence_type(convergence, data.y_levels)

    # Ensure `data` carries the input encoder the generator relies on. A freshly fitted
    # encoder is cached on `data` itself (mutating it) so later explanations reuse it.
    if generator.latent_space &&
        !(data.input_encoder isa GenerativeModels.AbstractGenerativeModel)
        @info "No pre-trained generative model found. Training default VAE."
        data.input_encoder = DataPreprocessing.fit_transformer(data, GenerativeModels.VAE)
    end
    if generator.dim_reduction &&
        !(data.input_encoder isa MultivariateStats.AbstractDimensionalityReduction)
        @info "No pre-trained dimensionality reduction model found. Training default PCA."
        data.input_encoder = DataPreprocessing.fit_transformer(data, MultivariateStats.PCA)
    end

    # Factual and target. An integer `x` is treated as an index selecting the factual
    # from `data`; any array is used as the factual directly.
    x = x isa Integer ? select_factual(data, Int(x)) : x
    target_encoded = data.output_encoder(target; y_levels=data.y_levels)

    # Instantiate. The counterfactual state `s′` and `x′` both start out as the factual.
    # NOTE(review): all three fields reference the same array object here; this relies on
    # the initialization steps not mutating `s′`/`x′` in place before replacing them.
    ce = CounterfactualExplanation(
        x,
        target,
        target_encoded,
        x,
        x,
        data,
        M,
        deepcopy(generator),
        nothing,
        convergence,
        num_counterfactuals,
        initialization,
    )

    # Initialize the search state (see `initialize!`):
    initialize!(ce)
    return ce
end
"""
    initialize!(ce::CounterfactualExplanation)

Initialize the search state of the counterfactual explanation `ce` in place and return
`ce`. This method is called by the outer `CounterfactualExplanation` constructor. It:

1. Creates `ce.search`, a `Dict{Symbol,Any}` that stores information about the search,
   starting with the iteration count (`0`) and the mutability constraints of `ce.data`.
2. Stores the potential neighbours under `:potential_neighbours`, but only if the
   objective needs them (`Objectives.needs_neighbours`).
3. Initializes the counterfactual state via `adjust_shape!`, `encode_state!`,
   `initialize_state!` and `decode_state!`.
4. Initializes the search path (`:path`), the per-feature change counter
   (`:times_changed_features`) and the loss history (`:loss`).
"""
function initialize!(ce::CounterfactualExplanation)
    # The search log holds heterogeneous values (counters, constraints, paths, losses), so
    # its value type is fixed to `Any` explicitly instead of being inferred from the first
    # two entries, which would break later writes if inference came out narrower.
    ce.search = Dict{Symbol,Any}(
        :iteration_count => 0,
        :mutability => DataPreprocessing.mutability_constraints(ce.data),
    )

    # Neighbours are only computed when the objective needs them. The dictionary was just
    # created, so there is no existing entry to preserve and a plain assignment suffices.
    if Objectives.needs_neighbours(ce)
        neighbours = CounterfactualExplanations.find_potential_neighbours(ce)
        ce.search[:potential_neighbours] = neighbours
    end

    # Initialize the counterfactual state. Each step's return value is piped into the
    # next — presumably each mutates `ce` and returns it; verify against the definitions.
    adjust_shape!(ce) |> encode_state! |> initialize_state! |> decode_state!

    # Seed the search history with the initial state and loss:
    ce.search[:path] = [ce.s′]
    ce.search[:times_changed_features] = zeros(size(decode_state(ce)))
    ce.search[:loss] = [Generators.total_loss(ce)]
    return ce
end