diff --git a/orangecontrib/explain/explainer.py b/orangecontrib/explain/explainer.py index 3070c28..51d7bd6 100644 --- a/orangecontrib/explain/explainer.py +++ b/orangecontrib/explain/explainer.py @@ -102,9 +102,7 @@ def _explain_trees( for i in range(0, len(data_sample), batch_size): progress_callback(i / len(data_sample)) batch = data_sample.X[i : i + batch_size] - shap_values.append( - explainer.shap_values(batch, check_additivity=False) - ) + shap_values.append(explainer.shap_values(batch, check_additivity=False)) shap_values = _join_shap_values(shap_values) base_value = explainer.expected_value @@ -152,7 +150,9 @@ def _explain_other_models( for i, row in enumerate(data_sample.X): progress_callback(i / len(data_sample)) shap_values.append( - explainer.shap_values(row, nsamples=100, silent=True, l1_reg="num_features(90)") + explainer.shap_values( + row, nsamples=100, silent=True, l1_reg="num_features(90)" + ) ) return ( _join_shap_values(shap_values), @@ -205,8 +205,24 @@ def compute_shap_values( progress_callback = dummy_callback progress_callback(0, "Computing explanation ...") - data_transformed = model.data_to_model_domain(data) - reference_data_transformed = model.data_to_model_domain(reference_data) + #### workaround for bug with calibration + #### remove when fixed + from Orange.classification import ( + ThresholdClassifier, + CalibratedClassifier, + ) + + trans_model = model + while isinstance( + trans_model, (ThresholdClassifier, CalibratedClassifier) + ): + trans_model = trans_model.base_model + #### end of workaround for bug with calibration + + data_transformed = trans_model.data_to_model_domain(data) + reference_data_transformed = trans_model.data_to_model_domain( + reference_data + ) shap_values, sample_mask, base_value = _explain_trees( model, @@ -422,7 +438,8 @@ def explain_predictions( Domain(data.domain.attributes, None, data.domain.metas) ) predictions = model( - classless_data, model.Probs if model.domain.class_var.is_discrete else model.Value + classless_data, + model.Probs if model.domain.class_var.is_discrete else model.Value, ) # for regression - predictions array is 1d transform it shape N x 1 if predictions.ndim == 1: diff --git a/orangecontrib/explain/tests/test_explainer.py b/orangecontrib/explain/tests/test_explainer.py index eefc646..2abfaaa 100644 --- a/orangecontrib/explain/tests/test_explainer.py +++ b/orangecontrib/explain/tests/test_explainer.py @@ -1,6 +1,8 @@ +import inspect import unittest import numpy as np +import pkg_resources from Orange.classification import ( LogisticRegressionLearner, @@ -8,11 +10,11 @@ SGDClassificationLearner, SVMLearner, TreeLearner, + ThresholdLearner, ) from Orange.data import Table, Domain from Orange.regression import LinearRegressionLearner -from Orange.tests.test_classification import LearnerAccessibility -from Orange.tests import test_regression +from Orange.tests import test_regression, test_classification from Orange.widgets.data import owcolor from orangecontrib.explain.explainer import ( compute_colors, @@ -159,12 +161,17 @@ def test_class_not_predicted(self): # missing class has all shap values 0 self.assertTrue(not np.any(shap_values[2].sum())) - @unittest.skip("Enable when learners fixed") def test_all_classifiers(self): """ Test explanation for all classifiers """ - for learner in LearnerAccessibility.all_learners(None): + for learner in test_classification.all_learners(): with self.subTest(learner.name): - model = learner(self.iris) + if learner == ThresholdLearner: + # ThresholdLearner require binary class + continue + kwargs = {} + if "base_learner" in inspect.signature(learner).parameters: + kwargs = {"base_learner": LogisticRegressionLearner()} + model = learner(**kwargs)(self.iris) shap_values, _, _, _ = compute_shap_values( model, self.iris, self.iris ) @@ -176,7 +183,7 @@ def test_all_classifiers(self): @unittest.skipIf( not hasattr(test_regression, "all_learners"), - "all_learners not available in Orange < 3.26" + "all_learners not available in Orange < 3.26", ) def test_all_regressors(self): """ Test explanation for all regressors """ @@ -569,6 +576,16 @@ def test_no_class(self): self.assertTupleEqual(self.iris.X.shape, shap_values[0].shape) self.assertTupleEqual((len(self.iris),), sample_mask.shape) + def test_remove_calibration_workaround(self): + """ + When this test start to fail remove the workaround in + explainer.py-207:220 if allready fixed - revert the pullrequest + that adds those lines. + """ + self.assertGreater( + "3.29.0", pkg_resources.get_distribution("orange3").version + ) + if __name__ == "__main__": unittest.main()