From 4ce843d63574b907faa66621c90093f3fb811396 Mon Sep 17 00:00:00 2001 From: gaugup Date: Fri, 23 Apr 2021 00:07:34 -0700 Subject: [PATCH] Remove unused method and refactor tests Signed-off-by: gaugup --- dice_ml/counterfactual_explanations.py | 11 +- tests/test_counterfactual_explanations.py | 177 ++++++++++++---------- 2 files changed, 97 insertions(+), 91 deletions(-) diff --git a/dice_ml/counterfactual_explanations.py b/dice_ml/counterfactual_explanations.py index 5acb3b0f..7e19376a 100644 --- a/dice_ml/counterfactual_explanations.py +++ b/dice_ml/counterfactual_explanations.py @@ -67,13 +67,6 @@ def __eq__(self, other_cf): self.metadata == other_cf.metadata return False - @property - def __dict__(self): - return {'cf_examples_list': self.cf_examples_list, - 'local_importance': self.local_importance, - 'summary_importance': self.summary_importance, - 'metadata': self.metadata} - @property def cf_examples_list(self): return self._cf_examples_list @@ -220,8 +213,8 @@ def to_json(self): entire_dict, version=serialization_version) return json.dumps(entire_dict) else: - raise Exception("Unsupported serialization version {}".format( - serialization_version)) + raise UserConfigValidationException( + "Unsupported serialization version {}".format(serialization_version)) @staticmethod def from_json(json_str): diff --git a/tests/test_counterfactual_explanations.py b/tests/test_counterfactual_explanations.py index 1f7ea1f1..72140389 100644 --- a/tests/test_counterfactual_explanations.py +++ b/tests/test_counterfactual_explanations.py @@ -10,32 +10,6 @@ class TestCounterfactualExplanations: - @pytest.mark.parametrize("version", ['1.0', '2.0']) - def test_serialization_deserialization_counterfactual_explanations_class(self, version): - - counterfactual_explanations = CounterfactualExplanations( - cf_examples_list=[], - local_importance=None, - summary_importance=None, - version=version) - assert counterfactual_explanations.cf_examples_list is not None - assert len(counterfactual_explanations.cf_examples_list) == 0 - assert counterfactual_explanations.summary_importance is None - assert counterfactual_explanations.local_importance is None - assert counterfactual_explanations.metadata is not None - assert counterfactual_explanations.metadata['version'] is not None - assert counterfactual_explanations.metadata['version'] == version - - counterfactual_explanations_as_json = counterfactual_explanations.to_json() - assert counterfactual_explanations_as_json is not None - - recovered_counterfactual_explanations = CounterfactualExplanations.from_json( - counterfactual_explanations_as_json) - - assert recovered_counterfactual_explanations is not None - assert recovered_counterfactual_explanations.metadata['version'] == version - assert counterfactual_explanations == recovered_counterfactual_explanations - def test_sorted_summary_importance_counterfactual_explanations(self): unsorted_summary_importance = { @@ -132,21 +106,6 @@ def test_sorted_local_importance_counterfactual_explanations(self): assert list(unsorted_local_importance[index].keys()) != list(counterfactual_explanations.local_importance[index].keys()) assert list(sorted_local_importance[index].keys()) == list(counterfactual_explanations.local_importance[index].keys()) - @pytest.mark.parametrize('version', ['3.0', '']) - def test_unsupported_versions_json_input(self, version): - json_str = json.dumps({'metadata': {'version': version}}) - with pytest.raises(UserConfigValidationException) as ucve: - CounterfactualExplanations.from_json(json_str) - - assert "Incompatible version {} found in json input".format(version) in str(ucve) - - json_str = json.dumps({'metadata': {'versio': version}}) - with pytest.raises(UserConfigValidationException) as ucve: - CounterfactualExplanations.from_json(json_str) - - assert "No version field in the json input" in str(ucve) - - @pytest.fixture def random_binary_classification_exp_object(): @@ -165,8 +124,31 @@ def _initiate_exp_object(self, random_binary_classification_exp_object): self.exp = random_binary_classification_exp_object # explainer object self.data_df_copy = self.exp.data_interface.data_df.copy() - @pytest.mark.parametrize("desired_class, total_CFs", [(0, 2)]) @pytest.mark.parametrize("version", ['1.0', '2.0']) + def verify_counterfactual_explanations(self, counterfactual_explanations, + total_CFs, num_query_points, version, + local_importance_available=False, + summary_importance_available=False): + assert counterfactual_explanations is not None + assert counterfactual_explanations.cf_examples_list is not None + assert len(counterfactual_explanations.cf_examples_list) == num_query_points + if total_CFs is not None: + assert counterfactual_explanations.cf_examples_list[0].final_cfs_df.shape[0] == total_CFs + assert counterfactual_explanations.metadata is not None + assert counterfactual_explanations.metadata['version'] is not None + counterfactual_explanations.metadata['version'] = version + if local_importance_available: + assert counterfactual_explanations.local_importance is not None + assert len(counterfactual_explanations.local_importance) == num_query_points + else: + assert counterfactual_explanations.local_importance is None + if summary_importance_available: + assert counterfactual_explanations.summary_importance is not None + else: + assert counterfactual_explanations.summary_importance is None + + @pytest.mark.parametrize("version", ['1.0', '2.0']) + @pytest.mark.parametrize("desired_class, total_CFs", [(0, 2)]) def test_random_counterfactual_explanations_output(self, desired_class, sample_custom_query_1, total_CFs, version): @@ -174,81 +156,112 @@ def test_random_counterfactual_explanations_output(self, desired_class, query_instances=sample_custom_query_1, desired_class=desired_class, total_CFs=total_CFs) - assert counterfactual_explanations is not None - assert len(counterfactual_explanations.cf_examples_list) == sample_custom_query_1.shape[0] - assert counterfactual_explanations.cf_examples_list[0].final_cfs_df.shape[0] == total_CFs - assert counterfactual_explanations.local_importance is None - assert counterfactual_explanations.summary_importance is None + self.verify_counterfactual_explanations(counterfactual_explanations, total_CFs, + sample_custom_query_1.shape[0], version) - counterfactual_explanations.metadata['version'] = version json_output = counterfactual_explanations.to_json() assert json_output is not None assert json.loads(json_output).get('metadata').get('version') == version recovered_counterfactual_explanations = CounterfactualExplanations.from_json(json_output) - assert recovered_counterfactual_explanations is not None - assert recovered_counterfactual_explanations == counterfactual_explanations - assert recovered_counterfactual_explanations.metadata['version'] == version + self.verify_counterfactual_explanations(counterfactual_explanations, total_CFs, + sample_custom_query_1.shape[0], version) - assert len(recovered_counterfactual_explanations.cf_examples_list) == sample_custom_query_1.shape[0] - assert recovered_counterfactual_explanations.cf_examples_list[0].final_cfs_df.shape[0] == total_CFs - assert recovered_counterfactual_explanations.local_importance is None - assert recovered_counterfactual_explanations.summary_importance is None + assert recovered_counterfactual_explanations == counterfactual_explanations - @pytest.mark.parametrize("desired_class, total_CFs", [(0, 10)]) @pytest.mark.parametrize("version", ['1.0', '2.0']) + @pytest.mark.parametrize("desired_class, total_CFs", [(0, 10)]) def test_random_local_importance_output(self, desired_class, sample_custom_query_1, total_CFs, version): counterfactual_explanations = self.exp.local_feature_importance( query_instances=sample_custom_query_1, desired_class=desired_class, total_CFs=total_CFs) - assert counterfactual_explanations is not None - assert len(counterfactual_explanations.cf_examples_list) == sample_custom_query_1.shape[0] - assert counterfactual_explanations.cf_examples_list[0].final_cfs_df.shape[0] == total_CFs - assert counterfactual_explanations.local_importance is not None - assert counterfactual_explanations.summary_importance is None + self.verify_counterfactual_explanations(counterfactual_explanations, total_CFs, + sample_custom_query_1.shape[0], version, + local_importance_available=True) - counterfactual_explanations.metadata['version'] = version json_output = counterfactual_explanations.to_json() assert json_output is not None assert json.loads(json_output).get('metadata').get('version') == version recovered_counterfactual_explanations = CounterfactualExplanations.from_json(json_output) - assert recovered_counterfactual_explanations is not None - assert recovered_counterfactual_explanations == counterfactual_explanations - assert recovered_counterfactual_explanations.metadata['version'] == version + self.verify_counterfactual_explanations(counterfactual_explanations, total_CFs, + sample_custom_query_1.shape[0], version, + local_importance_available=True) - assert len(recovered_counterfactual_explanations.cf_examples_list) == sample_custom_query_1.shape[0] - assert recovered_counterfactual_explanations.cf_examples_list[0].final_cfs_df.shape[0] == total_CFs - assert recovered_counterfactual_explanations.local_importance is not None - assert counterfactual_explanations.summary_importance is None + assert recovered_counterfactual_explanations == counterfactual_explanations - @pytest.mark.parametrize("desired_class, total_CFs", [(0, 10)]) @pytest.mark.parametrize("version", ['1.0', '2.0']) + @pytest.mark.parametrize("desired_class, total_CFs", [(0, 10)]) def test_random_summary_importance_output(self, desired_class, sample_custom_query_10, total_CFs, version): counterfactual_explanations = self.exp.global_feature_importance( query_instances=sample_custom_query_10, desired_class=desired_class, total_CFs=total_CFs) - assert counterfactual_explanations is not None - assert len(counterfactual_explanations.cf_examples_list) == sample_custom_query_10.shape[0] - assert counterfactual_explanations.cf_examples_list[0].final_cfs_df.shape[0] == total_CFs - assert counterfactual_explanations.local_importance is not None - assert counterfactual_explanations.summary_importance is not None + self.verify_counterfactual_explanations(counterfactual_explanations, total_CFs, + sample_custom_query_10.shape[0], version, + local_importance_available=True, + summary_importance_available=True) - counterfactual_explanations.metadata['version'] = version json_output = counterfactual_explanations.to_json() assert json_output is not None assert json.loads(json_output).get('metadata').get('version') == version recovered_counterfactual_explanations = CounterfactualExplanations.from_json(json_output) - assert recovered_counterfactual_explanations is not None + self.verify_counterfactual_explanations(counterfactual_explanations, total_CFs, + sample_custom_query_10.shape[0], version, + local_importance_available=True, + summary_importance_available=True) + assert recovered_counterfactual_explanations == counterfactual_explanations - assert recovered_counterfactual_explanations.metadata['version'] == version - assert len(recovered_counterfactual_explanations.cf_examples_list) == sample_custom_query_10.shape[0] - assert recovered_counterfactual_explanations.cf_examples_list[0].final_cfs_df.shape[0] == total_CFs - assert recovered_counterfactual_explanations.local_importance is not None - assert counterfactual_explanations.summary_importance is not None + @pytest.mark.parametrize("version", ['1.0', '2.0']) + def test_empty_counterfactual_explanations_object(self, version): + + counterfactual_explanations = CounterfactualExplanations( + cf_examples_list=[], + local_importance=None, + summary_importance=None, + version=version) + self.verify_counterfactual_explanations(counterfactual_explanations, None, + 0, version) + + counterfactual_explanations_as_json = counterfactual_explanations.to_json() + assert counterfactual_explanations_as_json is not None + + recovered_counterfactual_explanations = CounterfactualExplanations.from_json( + counterfactual_explanations_as_json) + + self.verify_counterfactual_explanations(recovered_counterfactual_explanations, None, + 0, version) + + assert counterfactual_explanations == recovered_counterfactual_explanations + + @pytest.mark.parametrize('unsupported_version', ['3.0', '']) + def test_unsupported_versions_from_json(self, unsupported_version): + json_str = json.dumps({'metadata': {'version': unsupported_version}}) + with pytest.raises(UserConfigValidationException) as ucve: + CounterfactualExplanations.from_json(json_str) + + assert "Incompatible version {} found in json input".format(unsupported_version) in str(ucve) + + json_str = json.dumps({'metadata': {'versio': unsupported_version}}) + with pytest.raises(UserConfigValidationException) as ucve: + CounterfactualExplanations.from_json(json_str) + + assert "No version field in the json input" in str(ucve) + + @pytest.mark.parametrize('unsupported_version', ['3.0', '']) + def test_unsupported_versions_to_json(self, unsupported_version): + counterfactual_explanations = CounterfactualExplanations( + cf_examples_list=[], + local_importance=None, + summary_importance=None, + version=unsupported_version) + + with pytest.raises(UserConfigValidationException) as ucve: + counterfactual_explanations.to_json() + + assert "Unsupported serialization version {}".format(unsupported_version) in str(ucve)