299 refactor encodings and decodings #432

Merged: 13 commits, merged on Apr 22, 2024

@pat-alt (Member) commented on Apr 19, 2024

  • Refactors the encodings and decodings so that they are more streamlined. Instead of relying on conditional statements, encodings are now dispatched on the type of a new, unifying data.input_encoder field.
  • Refactors the redundancy check. It is now based on the convergence type and runs right before the counterfactual search begins; the search only proceeds if the counterfactual is not already redundant.
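To illustrate the dispatch-based design described above, here is a small, self-contained Julia sketch. All types and functions in it (AbstractEncoder, IdentityEncoder, ZScoreEncoder, ToyData, encode_array/decode_array on these toy types, encode_features, decode_features, ensure_zscore!) are hypothetical stand-ins and not the actual CounterfactualExplanations.jl API; only the pattern of storing a single encoder in an input_encoder field and letting method dispatch choose the encoding comes from the PR description.

```julia
using Statistics  # stdlib: mean, std

# Hypothetical encoder hierarchy (NOT the CounterfactualExplanations.jl API).
abstract type AbstractEncoder end

# No-op encoder: features are used as-is.
struct IdentityEncoder <: AbstractEncoder end

# Feature-wise standardization; stands in for a fitted data transform.
struct ZScoreEncoder <: AbstractEncoder
    mu::Vector{Float64}
    sigma::Vector{Float64}
end

# Toy container with a single, unifying encoder field (hypothetical).
mutable struct ToyData
    input_encoder::AbstractEncoder
end

# One method per encoder type: Julia picks the method from the runtime type
# of the encoder, so callers need no if/else over encoder kinds.
# Matrices are features x samples, so `mu`/`sigma` broadcast across columns.
encode_array(::ToyData, ::IdentityEncoder, x::AbstractMatrix) = x
encode_array(::ToyData, enc::ZScoreEncoder, x::AbstractMatrix) = (x .- enc.mu) ./ enc.sigma

decode_array(::ToyData, ::IdentityEncoder, z::AbstractMatrix) = z
decode_array(::ToyData, enc::ZScoreEncoder, z::AbstractMatrix) = z .* enc.sigma .+ enc.mu

# Entry points just forward whatever encoder is stored on the data object.
encode_features(data::ToyData, x::AbstractMatrix) = encode_array(data, data.input_encoder, x)
decode_features(data::ToyData, z::AbstractMatrix) = decode_array(data, data.input_encoder, z)

# Fit a default encoder only when the stored one is not of the required type,
# mirroring the `typeof(data.input_encoder) <: ...` checks in the diff.
function ensure_zscore!(data::ToyData, X::AbstractMatrix)
    if !(data.input_encoder isa ZScoreEncoder)
        mu = vec(mean(X; dims=2))
        sigma = vec(std(X; dims=2))
        data.input_encoder = ZScoreEncoder(mu, sigma)
    end
    return data
end
```

With data = ToyData(IdentityEncoder()), encode_features(data, X) returns X unchanged; after ensure_zscore!(data, X), the same call standardizes each feature row and decode_features inverts it. Supporting a new encoder type only needs new encode_array/decode_array methods; the entry points stay the same.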

@pat-alt linked an issue on Apr 19, 2024 that may be closed by this pull request (4 tasks)
convergence = Convergence.get_convergence_type(convergence, data.y_levels)
if generator.latent_space && !(typeof(data.input_encoder) <: GenerativeModels.AbstractGenerativeModel)
Contributor

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
- if generator.latent_space && !(typeof(data.input_encoder) <: GenerativeModels.AbstractGenerativeModel)
+ if generator.latent_space &&
+     !(typeof(data.input_encoder) <: GenerativeModels.AbstractGenerativeModel)

@info "No pre-trained generative model found. Training default VAE."
data.input_encoder = DataPreprocessing.fit_transformer(data, GenerativeModels.VAE)
end
if generator.dim_reduction && !(typeof(data.input_encoder) <: MultivariateStats.AbstractDimensionalityReduction)
Contributor

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
- if generator.dim_reduction && !(typeof(data.input_encoder) <: MultivariateStats.AbstractDimensionalityReduction)
+ if generator.dim_reduction &&
+     !(typeof(data.input_encoder) <: MultivariateStats.AbstractDimensionalityReduction)

Comment on lines 17 to 19
CounterfactualExplanations.GenerativeModels.retrain!(
-    counterfactual_data.generative_model, X, ys
+    counterfactual_data.input_encoder, X
)
Contributor

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
- CounterfactualExplanations.GenerativeModels.retrain!(
-     counterfactual_data.input_encoder, X
- )
+ CounterfactualExplanations.GenerativeModels.retrain!(counterfactual_data.input_encoder, X)

@@ -18,17 +27,43 @@ end

Helper function to encode an array `x` using a data transform `dt::StatsBase.AbstractDataTransform`.
"""
- function encode_array(dt::StatsBase.AbstractDataTransform, x::AbstractArray)
-     return StatsBase.transform(dt, x)
+ function encode_array(data::CounterfactualData, dt::StatsBase.AbstractDataTransform, x::AbstractArray)
Contributor

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
- function encode_array(data::CounterfactualData, dt::StatsBase.AbstractDataTransform, x::AbstractArray)
+ function encode_array(
+     data::CounterfactualData, dt::StatsBase.AbstractDataTransform, x::AbstractArray
+ )

@pat-alt merged commit 4132ca9 into main on Apr 22, 2024 (10 checks passed)

Successfully merging this pull request may close these issues.

Refactor encodings and decodings
1 participant