Merge pull request #10 from PrimozGodec/fix-calibration
Temporarily fix explainer to work with calibration models
ajdapretnar committed Jan 7, 2021
2 parents 8766100 + ca754f0 commit 4081fbe
Showing 2 changed files with 47 additions and 13 deletions.
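
For context, a minimal sketch of the scenario this commit patches: explaining a model produced by a calibration wrapper. The import of CalibratedLearner is an assumption (mirroring how ThresholdLearner is imported in the tests below); only the learner construction and the compute_shap_values call are taken from this diff.

# Illustrative sketch only, not part of the commit.
from Orange.classification import CalibratedLearner, LogisticRegressionLearner
from Orange.data import Table

from orangecontrib.explain.explainer import compute_shap_values

iris = Table("iris")
# A calibration learner wraps a base learner, so the resulting model wraps
# the base model that defines the training domain.
model = CalibratedLearner(base_learner=LogisticRegressionLearner())(iris)
# compute_shap_values returns a 4-tuple; only the SHAP values are used here.
shap_values, _, _, _ = compute_shap_values(model, iris, iris)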
31 changes: 24 additions & 7 deletions orangecontrib/explain/explainer.py
@@ -102,9 +102,7 @@ def _explain_trees(
for i in range(0, len(data_sample), batch_size):
progress_callback(i / len(data_sample))
batch = data_sample.X[i : i + batch_size]
shap_values.append(
explainer.shap_values(batch, check_additivity=False)
)
shap_values.append(explainer.shap_values(batch, check_additivity=False))

shap_values = _join_shap_values(shap_values)
base_value = explainer.expected_value
@@ -152,7 +150,9 @@ def _explain_other_models(
for i, row in enumerate(data_sample.X):
progress_callback(i / len(data_sample))
shap_values.append(
explainer.shap_values(row, nsamples=100, silent=True, l1_reg="num_features(90)")
explainer.shap_values(
row, nsamples=100, silent=True, l1_reg="num_features(90)"
)
)
return (
_join_shap_values(shap_values),
@@ -205,8 +205,24 @@ def compute_shap_values(
progress_callback = dummy_callback
progress_callback(0, "Computing explanation ...")

data_transformed = model.data_to_model_domain(data)
reference_data_transformed = model.data_to_model_domain(reference_data)
#### workaround for bug with calibration
#### remove when fixed
from Orange.classification import (
ThresholdClassifier,
CalibratedClassifier,
)

trans_model = model
while isinstance(
trans_model, (ThresholdClassifier, CalibratedClassifier)
):
trans_model = trans_model.base_model
#### end of workaround for bug with calibration

data_transformed = trans_model.data_to_model_domain(data)
reference_data_transformed = trans_model.data_to_model_domain(
reference_data
)

shap_values, sample_mask, base_value = _explain_trees(
model,
@@ -422,7 +438,8 @@ def explain_predictions(
Domain(data.domain.attributes, None, data.domain.metas)
)
predictions = model(
classless_data, model.Probs if model.domain.class_var.is_discrete else model.Value
classless_data,
model.Probs if model.domain.class_var.is_discrete else model.Value,
)
# for regression - predictions array is 1d transform it shape N x 1
if predictions.ndim == 1:
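The workaround above simply walks through any calibration or threshold wrappers before asking the model to map data into its domain. A standalone sketch of the same idea (the helper name is illustrative; the actual change inlines this loop in compute_shap_values):

from Orange.classification import CalibratedClassifier, ThresholdClassifier

def unwrap_calibration(model):
    # Peel off ThresholdClassifier/CalibratedClassifier wrappers until the
    # innermost base model is reached; its data_to_model_domain matches the
    # domain the underlying model was trained on.
    while isinstance(model, (ThresholdClassifier, CalibratedClassifier)):
        model = model.base_model
    return model

# Mirrors the patched compute_shap_values:
# trans_model = unwrap_calibration(model)
# data_transformed = trans_model.data_to_model_domain(data)
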
29 changes: 23 additions & 6 deletions orangecontrib/explain/tests/test_explainer.py
@@ -1,18 +1,20 @@
import inspect
import unittest

import numpy as np
import pkg_resources

from Orange.classification import (
LogisticRegressionLearner,
RandomForestLearner,
SGDClassificationLearner,
SVMLearner,
TreeLearner,
ThresholdLearner,
)
from Orange.data import Table, Domain
from Orange.regression import LinearRegressionLearner
from Orange.tests.test_classification import LearnerAccessibility
from Orange.tests import test_regression
from Orange.tests import test_regression, test_classification
from Orange.widgets.data import owcolor
from orangecontrib.explain.explainer import (
compute_colors,
@@ -159,12 +161,17 @@ def test_class_not_predicted(self):
# missing class has all shap values 0
self.assertTrue(not np.any(shap_values[2].sum()))

@unittest.skip("Enable when learners fixed")
def test_all_classifiers(self):
""" Test explanation for all classifiers """
for learner in LearnerAccessibility.all_learners(None):
for learner in test_classification.all_learners():
with self.subTest(learner.name):
model = learner(self.iris)
if learner == ThresholdLearner:
# ThresholdLearner requires a binary class
continue
kwargs = {}
if "base_learner" in inspect.signature(learner).parameters:
kwargs = {"base_learner": LogisticRegressionLearner()}
model = learner(**kwargs)(self.iris)
shap_values, _, _, _ = compute_shap_values(
model, self.iris, self.iris
)
@@ -176,7 +183,7 @@ def test_all_classifiers(self):

@unittest.skipIf(
not hasattr(test_regression, "all_learners"),
"all_learners not available in Orange < 3.26"
"all_learners not available in Orange < 3.26",
)
def test_all_regressors(self):
""" Test explanation for all regressors """
@@ -569,6 +576,16 @@ def test_no_class(self):
self.assertTupleEqual(self.iris.X.shape, shap_values[0].shape)
self.assertTupleEqual((len(self.iris),), sample_mask.shape)

def test_remove_calibration_workaround(self):
"""
When this test starts to fail, remove the workaround in
explainer.py, lines 207-220, if the calibration bug is already
fixed: revert the pull request that adds those lines.
"""
self.assertGreater(
"3.29.0", pkg_resources.get_distribution("orange3").version
)


if __name__ == "__main__":
unittest.main()
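
The guard test above compares version strings lexicographically, which is adequate for the 3.2x releases it targets. A sketch of a numerically robust variant, assuming only pkg_resources.parse_version from setuptools (pkg_resources is already imported in this test module):

import unittest
import pkg_resources


class TestWorkaroundGuard(unittest.TestCase):
    def test_remove_calibration_workaround(self):
        # parse_version compares release components numerically, so e.g.
        # 3.9.0 < 3.29.0, unlike a plain string comparison.
        installed = pkg_resources.parse_version(
            pkg_resources.get_distribution("orange3").version
        )
        self.assertLess(installed, pkg_resources.parse_version("3.29.0"))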
