
calculate the feature importance with mean absolute SHAP values #3507

Merged: 13 commits merged into optuna:master on May 17, 2022

Conversation

@liaison (Contributor) commented Apr 24, 2022

Motivation

This PR adds the functionality to calculate the feature importance of the hyperparameters across the trials of a study, based on SHAP values.

Description of the changes

Here is the issue that this PR addresses.
In order to calculate the SHAP values, we need a surrogate model, e.g. RandomForest. This PR uses the RandomForest model from MeanDecreaseImpurityImportanceEvaluator.
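
For context, here is a minimal sketch of the underlying idea on synthetic data (not the PR's actual implementation, which plugs into Optuna's evaluator interface): fit a RandomForest surrogate on (hyperparameter, objective) pairs, then rank parameters by mean absolute SHAP value.

import numpy as np
import shap
from sklearn.ensemble import RandomForestRegressor

rng = np.random.RandomState(0)
X = rng.rand(100, 3)             # encoded hyperparameter values, one row per trial
y = 2 * X[:, 0] + rng.rand(100)  # objective values; only the first parameter matters
param_names = ["x0", "x1", "x2"]

# Fit the surrogate model and explain it with SHAP.
forest = RandomForestRegressor(n_estimators=64, max_depth=64, random_state=0).fit(X, y)
shap_values = shap.TreeExplainer(forest).shap_values(X)  # shape: (n_trials, n_params)

# Feature importance = mean absolute SHAP value per parameter.
importances = dict(zip(param_names, np.abs(shap_values).mean(axis=0)))
print(importances)  # "x0" should dominate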

@github-actions bot added the optuna.importance label Apr 24, 2022
@nzw0301 (Collaborator) commented Apr 25, 2022

As mentioned in #3448 (comment), I also suggest adding this shap feature to optuna.integration, not optuna.importance.

@nzw0301 nzw0301 self-assigned this Apr 25, 2022
@himkt (Member) commented Apr 25, 2022

I'd like @contramundum53 to review the PR as another reviewer!

@nzw0301 (Collaborator) commented Apr 25, 2022

To fix the CI error due to the missing shap library, could you add shap as an element of both the `"testing": [` and the `"integration": [` lists (around line 141 as of commit 9d80368) in https://github.com/optuna/optuna/blob/master/setup.py?
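
For reference, the suggested change would look roughly like this (a sketch; the actual extras lists in setup.py contain many more entries, elided here):

extras_require = {
    "integration": [
        # ... existing integration dependencies ...
        "shap",
    ],
    "testing": [
        # ... existing testing dependencies ...
        "shap",
    ],
}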

@codecov-commenter commented

Codecov Report

Merging #3507 (9d80368) into master (1a7a51b) will increase coverage by 0.00%.
The diff coverage is 90.00%.

❗ Current head 9d80368 differs from the pull request's most recent head dd00a81. Consider uploading reports for the commit dd00a81 to get more accurate results.

@@           Coverage Diff           @@
##           master    #3507   +/-   ##
=======================================
  Coverage   91.56%   91.56%           
=======================================
  Files         158      158           
  Lines       12226    12228    +2     
=======================================
+ Hits        11195    11197    +2     
  Misses       1031     1031           
Impacted Files Coverage Δ
optuna/cli.py 20.47% <0.00%> (ø)
optuna/study/study.py 96.19% <ø> (ø)
optuna/visualization/_pareto_front.py 98.57% <ø> (ø)
optuna/study/_optimize.py 98.63% <100.00%> (+<0.01%) ⬆️
optuna/study/_tell.py 97.95% <100.00%> (+0.02%) ⬆️
optuna/visualization/_contour.py 98.12% <100.00%> (ø)
optuna/visualization/_parallel_coordinate.py 100.00% <100.00%> (ø)
optuna/visualization/_slice.py 98.43% <100.00%> (ø)


@HideakiImamura added the feature label Apr 25, 2022
@nzw0301 (Collaborator) commented Apr 26, 2022

Thank you for your update! I've skimmed the main logic, and it looks nice. One minor point came to mind, so let me write it here so we remember it.

I'm wondering if we could split the random-forest fitting in optuna/importance/_mean_decrease_impurity.py into a separate function, because the current shap evaluator computes its importance scores through the parent class MeanDecreaseImpurityImportanceEvaluator, whose own importance-score calculation could be skipped when we evaluate SHAP-based scores. But this can be addressed in a follow-up PR.

@g-votte (Member) commented Apr 27, 2022

In my opinion, ShapleyImportanceEvaluator should not inherit from MeanDecreaseImpurityImportanceEvaluator but should hold it in a "has-a" relationship.

Per the Liskov substitution principle, a subclass should not break any behavior of its superclass. In other words, as long as an evaluator is a (grand)child instance of MeanDecreaseImpurityImportanceEvaluator, it should evaluate a study with the mean decrease impurity algorithm. In the current PR, however, ShapleyImportanceEvaluator breaks the behavior of MeanDecreaseImpurityImportanceEvaluator by replacing it with a totally different algorithm.

Alternative approach 1

We can implement ShapleyImportanceEvaluator as a direct subclass of BaseImportanceEvaluator and have it hold a MeanDecreaseImpurityImportanceEvaluator as a property.

class ShapleyImportanceEvaluator(BaseImportanceEvaluator):
    def __init__(
        self, *, n_trees: int = 64, max_depth: int = 64, seed: Optional[int] = None
    ) -> None:
        _imports.check()

        # Hold the mean-decrease-impurity evaluator in a "has-a" relationship.
        self._backend_evaluator = MeanDecreaseImpurityImportanceEvaluator(
            n_trees=n_trees, max_depth=max_depth, seed=seed
        )
        self._explainer: Optional[TreeExplainer] = None

    def evaluate(...) -> Dict[str, float]:
        # Train a RandomForest via the backend evaluator.
        self._backend_evaluator.evaluate(study=study, params=params, target=target)

        # Get the necessary properties of self._backend_evaluator in the following lines.
        ...

Alternative approach 2

This is related to @nzw0301's previous comment. We might be able to remove even the "has-a" relationship once we split the random-forest fitting out as a static function; see the sketch below. Both ShapleyImportanceEvaluator and MeanDecreaseImpurityImportanceEvaluator would then rely on the shared function instead of being tightly coupled to each other.
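
A minimal sketch of what such a shared function could look like (the name _fit_random_forest and the signature are illustrative assumptions, not code from the PR):

from typing import Optional

import numpy as np
from sklearn.ensemble import RandomForestRegressor


def _fit_random_forest(
    trans_params: np.ndarray,
    trans_values: np.ndarray,
    n_trees: int = 64,
    max_depth: int = 64,
    seed: Optional[int] = None,
) -> RandomForestRegressor:
    # Shared helper: both evaluators would call this instead of one
    # reusing (or inheriting from) the other.
    forest = RandomForestRegressor(
        n_estimators=n_trees, max_depth=max_depth, random_state=seed
    )
    return forest.fit(trans_params, trans_values)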

I propose rewriting ShapleyImportanceEvaluator using either of the alternative approaches above. What do you think? @liaison @nzw0301 @himkt @contramundum53

@nzw0301 (Collaborator) commented Apr 27, 2022

@g-votte Thank you for your suggestion and the detailed explanation! I agree with you.

@liaison (Contributor, author) commented Apr 27, 2022

Hi @g-votte @nzw0301, I actually quite agree with you both. To tell you a secret, it bothered me a bit to have ShapleyImportanceEvaluator inherit from MeanDecreaseImpurityImportanceEvaluator.

Admittedly, I took a little shortcut to keep the implementation changes minimal. I thought about extracting the logic of retrieving trial parameters and fitting a RandomForest into a separate function, but that logic requires quite a lot of context (tedious to pass everything as parameters), and the exception handling in between complicates things further. Inheriting all this logic seemed like an easy choice that minimized the changes. Voilà. So you could say it was a choice of "optimization" :)

We could further refactor the classes should we have more evaluators in the future, or should we want a different surrogate model (other than RandomForest) for the SHAP value calculation. Let me know what you think.

Review thread on the following diff:

 for i in feature_importances_reduced.argsort()[::-1]:
-    param_importances[param_names[i]] = feature_importances_reduced[i].item()
+    param_importances[self._param_names[i]] = feature_importances_reduced[i].item()

 return param_importances
A Member commented:
Thank you for your PR!
As also mentioned by @nzw0301, we don't actually need to introduce these member fields. Instead of having _trans_params, _trans_values and _param_names as member fields, we could make another function _evaluate returning (param_importances, trans_params, trans_values, param_names) in MeanDecreaseImpurityImportanceEvaluator and call that function from both ShapleyImportanceEvaluator.evaluate() and MeanDecreaseImpurityImportanceEvaluator.evaluate().
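
A minimal, hypothetical sketch of that shape (signatures are simplified and the real fitting logic is stubbed out; only the return-tuple structure is the point):

from typing import Dict, List, Tuple

import numpy as np


class MeanDecreaseImpurityImportanceEvaluator:
    def _evaluate(self, study) -> Tuple[Dict[str, float], np.ndarray, np.ndarray, List[str]]:
        # In the real code this would transform the completed trials, fit the
        # RandomForest, and compute the mean-decrease-impurity scores.
        param_names = ["x", "y"]
        trans_params = np.zeros((1, len(param_names)))
        trans_values = np.zeros(1)
        param_importances = {name: 1.0 / len(param_names) for name in param_names}
        return param_importances, trans_params, trans_values, param_names

    def evaluate(self, study) -> Dict[str, float]:
        # The public API still returns only the importances; the extra return
        # values exist for ShapleyImportanceEvaluator to reuse.
        param_importances, _, _, _ = self._evaluate(study)
        return param_importances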

@nzw0301 (Collaborator) commented Apr 29, 2022

I would like to suggest splitting the current evaluation methods into (i) a pre-processing part that returns trans, trans_params, and trans_values, and (ii) an importance-score computation part, because the evaluator classes (including this shap class) share the common parameter transformation; _evaluate would then not need to return the importance score, transformed parameters, etc. Alternatively, we could simply implement the same parameter-transformation logic in ShapleyImportanceEvaluator.evaluate.

I'm still not sure which approach is optimal.

@liaison (Contributor, author) commented:

Hi @contramundum53 @nzw0301, I do understand your concerns about separation of concerns and removing the class members. The problem is that the implementation of MeanDecreaseImpurityImportanceEvaluator.evaluate() is intertwined with many states and variables.

If we were to have a separate method to retrieve the parameters from the trials, then returning the tuple (trans, trans_params, trans_values) wouldn't be sufficient: we would also need the distributions, which are used at the end of the function. In addition, there is some exception raising and there are early returns in between that would need to be handled properly. I'm not convinced that there is much to gain by doing so.

Would the code be more modularized? Probably. But one might argue that having the class encapsulate the state information is itself a good way to modularize the code, which is the current situation.

A Collaborator commented:

Right, but then the pre-processing would take distributions as an argument to be reused after the random-forest fitting. It would not need to return an empty OrderedDict; we might return it before or after the preprocessing.

What do you think, @g-votte? I would like to know whether the current implementation matches the first approach suggested in #3507 (comment). Specifically, can we introduce new private attributes to MeanDecreaseImpurityImportanceEvaluator for ShapleyImportanceEvaluator?

@liaison (Contributor, author) commented Apr 30, 2022

Hi @nzw0301 @g-votte @contramundum53, I've given it some more thought.

First of all, I think it is a good thing to keep the class MeanDecreaseImpurityImportanceEvaluator stateful, rather than insisting on a stateless evaluate() function. One might want to look into the evaluator after the evaluation if anything goes wrong, and the state information can be useful for other purposes (e.g. passing it on to ShapleyImportanceEvaluator).

If you agree with the above point, then we can address the points that make us uneasy (myself included), namely: (1) we directly access the attributes of MeanDecreaseImpurityImportanceEvaluator (although Python doesn't enforce access control anyway), and (2) it is not strictly necessary to run MeanDecreaseImpurityImportanceEvaluator's own evaluation inside ShapleyImportanceEvaluator.

To address these two concerns, here is what I propose (hopefully the code is self-explanatory):

       # Fit a RandomForest model on the retrieved trial parameters.
       model = self._backend_evaluator.fit(study=study, params=params, target=target)

       # Retrieve the transformed trial parameters via the accessor function.
       trans_params, trans_values, param_names = self._backend_evaluator.get_params()

PS: the above APIs are inspired by sklearn's BaseEstimator.

@liaison (Contributor, author) commented:

Hi @nzw0301 @g-votte @contramundum53, any feedback on the above proposal?

A Member commented:

Sorry, we were on vacation.
Thanks, I agree with your proposal.
One minor comment: I think it would be more readable and robust to make the return value of get_params() a dict instead of a tuple.

A Member commented:

Sorry for the late reply. I think the current PR reflects what I suggested in #3507 (comment).

As mentioned by @nzw0301 in #3507 (review), we'd need some follow-up refactoring to extract the common logic among importance evaluators and to get rid of unnecessary calculation in the current approach. #3552 will address those problems shortly.

@nzw0301 (Collaborator) left a review comment:

Sorry for the delayed response. I agree with your suggestion this time: make this importance evaluator stateful in this PR.
Since, thanks to this PR, we've noticed that all the importance evaluators share common logic, we should refactor them in the future, but I think this PR should not address such refactoring. Hence, the pull request looks good to me now.

Nevertheless, I've left a few comments that can be fixed in a follow-up pull request.

Since this PR was created, #3500 has introduced two new tests, test_mean_decrease_impurity_importance_evaluator_with_infinite and test_multi_objective_mean_decrease_impurity_importance_evaluator_with_infinite, to the tests of MeanDecreaseImpurityImportanceEvaluator. We should add similar test functions for the shap counterpart.
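
A hypothetical sketch of such a counterpart test (the test name mirrors #3500; the exact setup and assertions here are assumptions):

import math

import optuna
from optuna.integration.shap import ShapleyImportanceEvaluator


def test_shapley_importance_evaluator_with_infinite() -> None:
    def objective(trial: optuna.Trial) -> float:
        x = trial.suggest_float("x", -10, 10)
        return x**2

    study = optuna.create_study()
    study.optimize(objective, n_trials=10)

    # Add one trial whose objective value is infinite.
    study.add_trial(
        optuna.trial.create_trial(
            params={"x": 1.0},
            distributions={"x": optuna.distributions.FloatDistribution(-10, 10)},
            value=float("inf"),
        )
    )

    importances = ShapleyImportanceEvaluator(seed=0).evaluate(study)
    assert all(math.isfinite(v) for v in importances.values())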

liaison and others added 3 commits on May 12, 2022, each co-authored by Kento Nozawa <k_nzw@klis.tsukuba.ac.jp>.
@contramundum53 (Member) commented:

I experimented with this objective and ShapleyImportanceEvaluator.

This is the result for ShapleyImportanceEvaluator:
[Screenshot: hyperparameter importance plot, 2022-05-13]

Just for reference, this is the result for the default fANOVA importance evaluator:
[Screenshot: hyperparameter importance plot, 2022-05-13]

The result looks good to me.

@himkt removed their assignment May 17, 2022
@contramundum53 merged commit b099d23 into optuna:master May 17, 2022
@nzw0301 added this to the v3.0.0-b1 milestone May 17, 2022
@hrntsm mentioned this pull request Aug 31, 2022