diff --git a/tests/metrics/test_fidelity.py b/tests/metrics/test_fidelity.py index 5ff4f202..eeaf8e71 100644 --- a/tests/metrics/test_fidelity.py +++ b/tests/metrics/test_fidelity.py @@ -4,6 +4,7 @@ from ..utils import generate_model, generate_timeseries_model, generate_data, almost_equal from xplique.metrics import Insertion, Deletion, MuFidelity, InsertionTS, DeletionTS + def test_mu_fidelity(): # ensure we can compute the metric with consistents arguments input_shape, nb_labels, nb_samples = ((32, 32, 3), 10, 20) @@ -50,20 +51,23 @@ def test_perturbation_metrics(): model = generate_timeseries_model(input_shape, nb_labels) explanations = np.random.uniform(0, 1, x.shape) - for step in [-1, 2, 10]: - for max_percentage_perturbed in [0.2, 1.0]: - for baseline_mode in [0.0, "zero", "inverse", "negative"]: + for step in [-1, 10]: + for baseline_mode in [0.0, "inverse"]: + for metric in ["loss", "accuracy"]: score_insertion = InsertionTS( - model, x, y, metric="loss", baseline_mode=baseline_mode, - steps=step, max_percentage_perturbed=max_percentage_perturbed, + model, x, y, metric=metric, baseline_mode=baseline_mode, + steps=step, max_percentage_perturbed=0.2, )(explanations) score_deletion = DeletionTS( - model, x, y, metric="loss", baseline_mode=baseline_mode, - steps=step, max_percentage_perturbed=max_percentage_perturbed, + model, x, y, metric=metric, baseline_mode=baseline_mode, + steps=step, max_percentage_perturbed=0.2, )(explanations) for score in [score_insertion, score_deletion]: - assert 0.0 < score < 1 + if metric == "loss": + assert 0.0 < score + elif metric == "accuracy": + assert 0.0 <= score <= 1.0 def test_perfect_correlation(): diff --git a/tests/utils.py b/tests/utils.py index 6ca74a04..c3f831f3 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -30,7 +30,8 @@ def generate_timeseries_model(input_shape=(20, 10), output_shape=10): model.add(GlobalAveragePooling1D()) model.add(Dense(output_shape)) model.add(Activation('softmax')) - 
model.compile(loss='categorical_crossentropy', optimizer='sgd') + model.compile(loss='categorical_crossentropy', optimizer='sgd', + metrics=['accuracy']) return model