-
Notifications
You must be signed in to change notification settings - Fork 6
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account.
299 refactor encodings and decodings #432
Conversation
src/counterfactuals/core_struct.jl
Outdated
convergence = Convergence.get_convergence_type(convergence, data.y_levels) | ||
if generator.latent_space && !(typeof(data.input_encoder) <: GenerativeModels.AbstractGenerativeModel) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
if generator.latent_space && !(typeof(data.input_encoder) <: GenerativeModels.AbstractGenerativeModel) | |
if generator.latent_space && | |
!(typeof(data.input_encoder) <: GenerativeModels.AbstractGenerativeModel) |
src/counterfactuals/core_struct.jl
Outdated
@info "No pre-trained generative model found. Training default VAE." | ||
data.input_encoder = DataPreprocessing.fit_transformer(data, GenerativeModels.VAE) | ||
end | ||
if generator.dim_reduction && !(typeof(data.input_encoder) <: MultivariateStats.AbstractDimensionalityReduction) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
if generator.dim_reduction && !(typeof(data.input_encoder) <: MultivariateStats.AbstractDimensionalityReduction) | |
if generator.dim_reduction && | |
!(typeof(data.input_encoder) <: MultivariateStats.AbstractDimensionalityReduction) |
test/models/generative_models.jl
Outdated
CounterfactualExplanations.GenerativeModels.retrain!( | ||
counterfactual_data.generative_model, X, ys | ||
counterfactual_data.input_encoder, X | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
CounterfactualExplanations.GenerativeModels.retrain!( | |
counterfactual_data.generative_model, X, ys | |
counterfactual_data.input_encoder, X | |
) | |
CounterfactualExplanations.GenerativeModels.retrain!(counterfactual_data.input_encoder, X) |
src/counterfactuals/encodings.jl
Outdated
@@ -18,17 +27,43 @@ end | |||
|
|||
Helper function to encode an array `x` using a data transform `dt::StatsBase.AbstractDataTransform`. | |||
""" | |||
function encode_array(dt::StatsBase.AbstractDataTransform, x::AbstractArray) | |||
return StatsBase.transform(dt, x) | |||
function encode_array(data::CounterfactualData, dt::StatsBase.AbstractDataTransform, x::AbstractArray) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
function encode_array(data::CounterfactualData, dt::StatsBase.AbstractDataTransform, x::AbstractArray) | |
function encode_array( | |
data::CounterfactualData, dt::StatsBase.AbstractDataTransform, x::AbstractArray | |
) |
`data.input_encoder` field.