Skip to content

Commit

Permalink
Make function _get_surrogate_model_replication_measure() public (#495)
Browse files Browse the repository at this point in the history
Signed-off-by: Gaurav Gupta <gaugup@microsoft.com>
  • Loading branch information
gaugup committed Jan 24, 2022
1 parent 589dbcd commit 3f1e586
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 3 deletions.
17 changes: 17 additions & 0 deletions python/interpret_community/mimic/mimic_explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -751,6 +751,10 @@ def __setstate__(self, state):

def _get_surrogate_model_replication_measure(self, training_data):
"""Return the metric which tells how well the surrogate model replicates the teacher model.
For classification scenarios, this function will return accuracy. For regression scenarios,
this function will return r2_score.
:param training_data: The data for getting the replication metric.
:type training_data: numpy.array or pandas.DataFrame or scipy.sparse.csr_matrix
:return: Metric that tells how well the surrogate model replicates the behavior of teacher model.
Expand All @@ -769,3 +773,16 @@ def _get_surrogate_model_replication_measure(self, training_data):
else:
replication_measure = r2_score(teacher_model_predictions, surrogate_model_predictions)
return replication_measure

def get_surrogate_model_replication_measure(self, training_data):
"""Return the metric which tells how well the surrogate model replicates the teacher model.
For classification scenarios, this function will return accuracy. For regression scenarios,
this function will return r2_score.
:param training_data: The data for getting the replication metric.
:type training_data: numpy.array or pandas.DataFrame or scipy.sparse.csr_matrix
:return: Metric that tells how well the surrogate model replicates the behavior of teacher model.
:rtype: float
"""
return self._get_surrogate_model_replication_measure(training_data=training_data)
9 changes: 6 additions & 3 deletions tests/test_mimic_explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,7 @@ def test_explain_raw_feats_regression(self, mimic_explainer):
def _verify_predictions_and_replication_metric(self, mimic_explainer, data):
predictions_main_model = mimic_explainer._get_teacher_model_predictions(data)
predictions_surrogate_model = mimic_explainer._get_surrogate_model_predictions(data)
replication_score = mimic_explainer._get_surrogate_model_replication_measure(data)
replication_score = mimic_explainer.get_surrogate_model_replication_measure(data)

assert predictions_main_model is not None
assert predictions_surrogate_model is not None
Expand All @@ -422,8 +422,11 @@ def _verify_predictions_and_replication_metric(self, mimic_explainer, data):
assert replication_score is not None and isinstance(replication_score, float)

if mimic_explainer.classes is None:
with pytest.raises(ScenarioNotSupportedException):
mimic_explainer._get_surrogate_model_replication_measure(
with pytest.raises(
ScenarioNotSupportedException,
match="Replication measure for regression surrogate not supported "
"because of single instance in training data"):
mimic_explainer.get_surrogate_model_replication_measure(
data[0].reshape(1, len(data[0])))

def test_explain_model_string_classes(self, mimic_explainer):
Expand Down

0 comments on commit 3f1e586

Please sign in to comment.