From ca754f01d993cbb58b9ecb9b75916db16e280aae Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Primo=C5=BE=20Godec?=
Date: Tue, 5 Jan 2021 13:36:25 +0100
Subject: [PATCH] Explainer temporary workaround for calibration model and
tests
---
orangecontrib/explain/explainer.py | 31 ++++++++++++++-----
orangecontrib/explain/tests/test_explainer.py | 29 +++++++++++++----
2 files changed, 47 insertions(+), 13 deletions(-)
diff --git a/orangecontrib/explain/explainer.py b/orangecontrib/explain/explainer.py
index 3070c28..51d7bd6 100644
--- a/orangecontrib/explain/explainer.py
+++ b/orangecontrib/explain/explainer.py
@@ -102,9 +102,7 @@ def _explain_trees(
for i in range(0, len(data_sample), batch_size):
progress_callback(i / len(data_sample))
batch = data_sample.X[i : i + batch_size]
- shap_values.append(
- explainer.shap_values(batch, check_additivity=False)
- )
+ shap_values.append(explainer.shap_values(batch, check_additivity=False))
shap_values = _join_shap_values(shap_values)
base_value = explainer.expected_value
@@ -152,7 +150,9 @@ def _explain_other_models(
for i, row in enumerate(data_sample.X):
progress_callback(i / len(data_sample))
shap_values.append(
- explainer.shap_values(row, nsamples=100, silent=True, l1_reg="num_features(90)")
+ explainer.shap_values(
+ row, nsamples=100, silent=True, l1_reg="num_features(90)"
+ )
)
return (
_join_shap_values(shap_values),
@@ -205,8 +205,24 @@ def compute_shap_values(
progress_callback = dummy_callback
progress_callback(0, "Computing explanation ...")
- data_transformed = model.data_to_model_domain(data)
- reference_data_transformed = model.data_to_model_domain(reference_data)
+ #### workaround for bug with calibration
+ #### remove when fixed
+ from Orange.classification import (
+ ThresholdClassifier,
+ CalibratedClassifier,
+ )
+
+ trans_model = model
+ while isinstance(
+ trans_model, (ThresholdClassifier, CalibratedClassifier)
+ ):
+ trans_model = trans_model.base_model
+ #### end of workaround for bug with calibration
+
+ data_transformed = trans_model.data_to_model_domain(data)
+ reference_data_transformed = trans_model.data_to_model_domain(
+ reference_data
+ )
shap_values, sample_mask, base_value = _explain_trees(
model,
@@ -422,7 +438,8 @@ def explain_predictions(
Domain(data.domain.attributes, None, data.domain.metas)
)
predictions = model(
- classless_data, model.Probs if model.domain.class_var.is_discrete else model.Value
+ classless_data,
+ model.Probs if model.domain.class_var.is_discrete else model.Value,
)
# for regression - predictions array is 1d transform it shape N x 1
if predictions.ndim == 1:
diff --git a/orangecontrib/explain/tests/test_explainer.py b/orangecontrib/explain/tests/test_explainer.py
index eefc646..2abfaaa 100644
--- a/orangecontrib/explain/tests/test_explainer.py
+++ b/orangecontrib/explain/tests/test_explainer.py
@@ -1,6 +1,8 @@
+import inspect
import unittest
import numpy as np
+import pkg_resources
from Orange.classification import (
LogisticRegressionLearner,
@@ -8,11 +10,11 @@
SGDClassificationLearner,
SVMLearner,
TreeLearner,
+ ThresholdLearner,
)
from Orange.data import Table, Domain
from Orange.regression import LinearRegressionLearner
-from Orange.tests.test_classification import LearnerAccessibility
-from Orange.tests import test_regression
+from Orange.tests import test_regression, test_classification
from Orange.widgets.data import owcolor
from orangecontrib.explain.explainer import (
compute_colors,
@@ -159,12 +161,17 @@ def test_class_not_predicted(self):
# missing class has all shap values 0
self.assertTrue(not np.any(shap_values[2].sum()))
- @unittest.skip("Enable when learners fixed")
def test_all_classifiers(self):
""" Test explanation for all classifiers """
- for learner in LearnerAccessibility.all_learners(None):
+ for learner in test_classification.all_learners():
with self.subTest(learner.name):
- model = learner(self.iris)
+ if learner == ThresholdLearner:
+                    # ThresholdLearner requires a binary class
+ continue
+ kwargs = {}
+ if "base_learner" in inspect.signature(learner).parameters:
+ kwargs = {"base_learner": LogisticRegressionLearner()}
+ model = learner(**kwargs)(self.iris)
shap_values, _, _, _ = compute_shap_values(
model, self.iris, self.iris
)
@@ -176,7 +183,7 @@ def test_all_classifiers(self):
@unittest.skipIf(
not hasattr(test_regression, "all_learners"),
- "all_learners not available in Orange < 3.26"
+ "all_learners not available in Orange < 3.26",
)
def test_all_regressors(self):
""" Test explanation for all regressors """
@@ -569,6 +576,16 @@ def test_no_class(self):
self.assertTupleEqual(self.iris.X.shape, shap_values[0].shape)
self.assertTupleEqual((len(self.iris),), sample_mask.shape)
+ def test_remove_calibration_workaround(self):
+ """
+        When this test starts to fail, remove the workaround in
+        explainer.py (lines 207-220) if already fixed - revert the pull
+        request that adds those lines.
+ """
+ self.assertGreater(
+ "3.29.0", pkg_resources.get_distribution("orange3").version
+ )
+
if __name__ == "__main__":
unittest.main()