Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 2 additions & 9 deletions dice_ml/counterfactual_explanations.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,6 @@ def __eq__(self, other_cf):
self.metadata == other_cf.metadata
return False

@property
def __dict__(self):
return {'cf_examples_list': self.cf_examples_list,
'local_importance': self.local_importance,
'summary_importance': self.summary_importance,
'metadata': self.metadata}

@property
def cf_examples_list(self):
return self._cf_examples_list
Expand Down Expand Up @@ -220,8 +213,8 @@ def to_json(self):
entire_dict, version=serialization_version)
return json.dumps(entire_dict)
else:
raise Exception("Unsupported serialization version {}".format(
serialization_version))
raise UserConfigValidationException(
"Unsupported serialization version {}".format(serialization_version))

@staticmethod
def from_json(json_str):
Expand Down
177 changes: 95 additions & 82 deletions tests/test_counterfactual_explanations.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,32 +10,6 @@

class TestCounterfactualExplanations:

@pytest.mark.parametrize("version", ['1.0', '2.0'])
def test_serialization_deserialization_counterfactual_explanations_class(self, version):

counterfactual_explanations = CounterfactualExplanations(
cf_examples_list=[],
local_importance=None,
summary_importance=None,
version=version)
assert counterfactual_explanations.cf_examples_list is not None
assert len(counterfactual_explanations.cf_examples_list) == 0
assert counterfactual_explanations.summary_importance is None
assert counterfactual_explanations.local_importance is None
assert counterfactual_explanations.metadata is not None
assert counterfactual_explanations.metadata['version'] is not None
assert counterfactual_explanations.metadata['version'] == version

counterfactual_explanations_as_json = counterfactual_explanations.to_json()
assert counterfactual_explanations_as_json is not None

recovered_counterfactual_explanations = CounterfactualExplanations.from_json(
counterfactual_explanations_as_json)

assert recovered_counterfactual_explanations is not None
assert recovered_counterfactual_explanations.metadata['version'] == version
assert counterfactual_explanations == recovered_counterfactual_explanations

def test_sorted_summary_importance_counterfactual_explanations(self):

unsorted_summary_importance = {
Expand Down Expand Up @@ -132,21 +106,6 @@ def test_sorted_local_importance_counterfactual_explanations(self):
assert list(unsorted_local_importance[index].keys()) != list(counterfactual_explanations.local_importance[index].keys())
assert list(sorted_local_importance[index].keys()) == list(counterfactual_explanations.local_importance[index].keys())

@pytest.mark.parametrize('version', ['3.0', ''])
def test_unsupported_versions_json_input(self, version):
json_str = json.dumps({'metadata': {'version': version}})
with pytest.raises(UserConfigValidationException) as ucve:
CounterfactualExplanations.from_json(json_str)

assert "Incompatible version {} found in json input".format(version) in str(ucve)

json_str = json.dumps({'metadata': {'versio': version}})
with pytest.raises(UserConfigValidationException) as ucve:
CounterfactualExplanations.from_json(json_str)

assert "No version field in the json input" in str(ucve)



@pytest.fixture
def random_binary_classification_exp_object():
Expand All @@ -165,90 +124,144 @@ def _initiate_exp_object(self, random_binary_classification_exp_object):
self.exp = random_binary_classification_exp_object # explainer object
self.data_df_copy = self.exp.data_interface.data_df.copy()

@pytest.mark.parametrize("desired_class, total_CFs", [(0, 2)])
@pytest.mark.parametrize("version", ['1.0', '2.0'])
def verify_counterfactual_explanations(self, counterfactual_explanations,
total_CFs, num_query_points, version,
local_importance_available=False,
summary_importance_available=False):
assert counterfactual_explanations is not None
assert counterfactual_explanations.cf_examples_list is not None
assert len(counterfactual_explanations.cf_examples_list) == num_query_points
if total_CFs is not None:
assert counterfactual_explanations.cf_examples_list[0].final_cfs_df.shape[0] == total_CFs
assert counterfactual_explanations.metadata is not None
assert counterfactual_explanations.metadata['version'] is not None
counterfactual_explanations.metadata['version'] = version
if local_importance_available:
assert counterfactual_explanations.local_importance is not None
assert len(counterfactual_explanations.local_importance) == num_query_points
else:
assert counterfactual_explanations.local_importance is None
if summary_importance_available:
assert counterfactual_explanations.summary_importance is not None
else:
assert counterfactual_explanations.summary_importance is None

@pytest.mark.parametrize("version", ['1.0', '2.0'])
@pytest.mark.parametrize("desired_class, total_CFs", [(0, 2)])
def test_random_counterfactual_explanations_output(self, desired_class,
sample_custom_query_1, total_CFs,
version):
counterfactual_explanations = self.exp.generate_counterfactuals(
query_instances=sample_custom_query_1, desired_class=desired_class,
total_CFs=total_CFs)

assert counterfactual_explanations is not None
assert len(counterfactual_explanations.cf_examples_list) == sample_custom_query_1.shape[0]
assert counterfactual_explanations.cf_examples_list[0].final_cfs_df.shape[0] == total_CFs
assert counterfactual_explanations.local_importance is None
assert counterfactual_explanations.summary_importance is None
self.verify_counterfactual_explanations(counterfactual_explanations, total_CFs,
sample_custom_query_1.shape[0], version)

counterfactual_explanations.metadata['version'] = version
json_output = counterfactual_explanations.to_json()
assert json_output is not None
assert json.loads(json_output).get('metadata').get('version') == version

recovered_counterfactual_explanations = CounterfactualExplanations.from_json(json_output)
assert recovered_counterfactual_explanations is not None
assert recovered_counterfactual_explanations == counterfactual_explanations
assert recovered_counterfactual_explanations.metadata['version'] == version
self.verify_counterfactual_explanations(counterfactual_explanations, total_CFs,
sample_custom_query_1.shape[0], version)

assert len(recovered_counterfactual_explanations.cf_examples_list) == sample_custom_query_1.shape[0]
assert recovered_counterfactual_explanations.cf_examples_list[0].final_cfs_df.shape[0] == total_CFs
assert recovered_counterfactual_explanations.local_importance is None
assert recovered_counterfactual_explanations.summary_importance is None
assert recovered_counterfactual_explanations == counterfactual_explanations

@pytest.mark.parametrize("desired_class, total_CFs", [(0, 10)])
@pytest.mark.parametrize("version", ['1.0', '2.0'])
@pytest.mark.parametrize("desired_class, total_CFs", [(0, 10)])
def test_random_local_importance_output(self, desired_class, sample_custom_query_1,
total_CFs, version):
counterfactual_explanations = self.exp.local_feature_importance(
query_instances=sample_custom_query_1, desired_class=desired_class,
total_CFs=total_CFs)

assert counterfactual_explanations is not None
assert len(counterfactual_explanations.cf_examples_list) == sample_custom_query_1.shape[0]
assert counterfactual_explanations.cf_examples_list[0].final_cfs_df.shape[0] == total_CFs
assert counterfactual_explanations.local_importance is not None
assert counterfactual_explanations.summary_importance is None
self.verify_counterfactual_explanations(counterfactual_explanations, total_CFs,
sample_custom_query_1.shape[0], version,
local_importance_available=True)

counterfactual_explanations.metadata['version'] = version
json_output = counterfactual_explanations.to_json()
assert json_output is not None
assert json.loads(json_output).get('metadata').get('version') == version

recovered_counterfactual_explanations = CounterfactualExplanations.from_json(json_output)
assert recovered_counterfactual_explanations is not None
assert recovered_counterfactual_explanations == counterfactual_explanations
assert recovered_counterfactual_explanations.metadata['version'] == version
self.verify_counterfactual_explanations(counterfactual_explanations, total_CFs,
sample_custom_query_1.shape[0], version,
local_importance_available=True)

assert len(recovered_counterfactual_explanations.cf_examples_list) == sample_custom_query_1.shape[0]
assert recovered_counterfactual_explanations.cf_examples_list[0].final_cfs_df.shape[0] == total_CFs
assert recovered_counterfactual_explanations.local_importance is not None
assert counterfactual_explanations.summary_importance is None
assert recovered_counterfactual_explanations == counterfactual_explanations

@pytest.mark.parametrize("desired_class, total_CFs", [(0, 10)])
@pytest.mark.parametrize("version", ['1.0', '2.0'])
@pytest.mark.parametrize("desired_class, total_CFs", [(0, 10)])
def test_random_summary_importance_output(self, desired_class, sample_custom_query_10,
total_CFs, version):
counterfactual_explanations = self.exp.global_feature_importance(
query_instances=sample_custom_query_10, desired_class=desired_class,
total_CFs=total_CFs)

assert counterfactual_explanations is not None
assert len(counterfactual_explanations.cf_examples_list) == sample_custom_query_10.shape[0]
assert counterfactual_explanations.cf_examples_list[0].final_cfs_df.shape[0] == total_CFs
assert counterfactual_explanations.local_importance is not None
assert counterfactual_explanations.summary_importance is not None
self.verify_counterfactual_explanations(counterfactual_explanations, total_CFs,
sample_custom_query_10.shape[0], version,
local_importance_available=True,
summary_importance_available=True)

counterfactual_explanations.metadata['version'] = version
json_output = counterfactual_explanations.to_json()
assert json_output is not None
assert json.loads(json_output).get('metadata').get('version') == version

recovered_counterfactual_explanations = CounterfactualExplanations.from_json(json_output)
assert recovered_counterfactual_explanations is not None
self.verify_counterfactual_explanations(counterfactual_explanations, total_CFs,
sample_custom_query_10.shape[0], version,
local_importance_available=True,
summary_importance_available=True)

assert recovered_counterfactual_explanations == counterfactual_explanations
assert recovered_counterfactual_explanations.metadata['version'] == version

assert len(recovered_counterfactual_explanations.cf_examples_list) == sample_custom_query_10.shape[0]
assert recovered_counterfactual_explanations.cf_examples_list[0].final_cfs_df.shape[0] == total_CFs
assert recovered_counterfactual_explanations.local_importance is not None
assert counterfactual_explanations.summary_importance is not None
@pytest.mark.parametrize("version", ['1.0', '2.0'])
def test_empty_counterfactual_explanations_object(self, version):

counterfactual_explanations = CounterfactualExplanations(
cf_examples_list=[],
local_importance=None,
summary_importance=None,
version=version)
self.verify_counterfactual_explanations(counterfactual_explanations, None,
0, version)

counterfactual_explanations_as_json = counterfactual_explanations.to_json()
assert counterfactual_explanations_as_json is not None

recovered_counterfactual_explanations = CounterfactualExplanations.from_json(
counterfactual_explanations_as_json)

self.verify_counterfactual_explanations(recovered_counterfactual_explanations, None,
0, version)

assert counterfactual_explanations == recovered_counterfactual_explanations

@pytest.mark.parametrize('unsupported_version', ['3.0', ''])
def test_unsupported_versions_from_json(self, unsupported_version):
json_str = json.dumps({'metadata': {'version': unsupported_version}})
with pytest.raises(UserConfigValidationException) as ucve:
CounterfactualExplanations.from_json(json_str)

assert "Incompatible version {} found in json input".format(unsupported_version) in str(ucve)

json_str = json.dumps({'metadata': {'versio': unsupported_version}})
with pytest.raises(UserConfigValidationException) as ucve:
CounterfactualExplanations.from_json(json_str)

assert "No version field in the json input" in str(ucve)

@pytest.mark.parametrize('unsupported_version', ['3.0', ''])
def test_unsupported_versions_to_json(self, unsupported_version):
counterfactual_explanations = CounterfactualExplanations(
cf_examples_list=[],
local_importance=None,
summary_importance=None,
version=unsupported_version)

with pytest.raises(UserConfigValidationException) as ucve:
counterfactual_explanations.to_json()

assert "Unsupported serialization version {}".format(unsupported_version) in str(ucve)