Skip to content

Commit

Permalink
test fidelity: adapt to harmonization
Browse files Browse the repository at this point in the history
  • Loading branch information
AntoninPoche committed Dec 10, 2021
1 parent edcb5c4 commit 03d3849
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 9 deletions.
20 changes: 12 additions & 8 deletions tests/metrics/test_fidelity.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from ..utils import generate_model, generate_timeseries_model, generate_data, almost_equal
from xplique.metrics import Insertion, Deletion, MuFidelity, InsertionTS, DeletionTS


def test_mu_fidelity():
# ensure we can compute the metric with consistents arguments
input_shape, nb_labels, nb_samples = ((32, 32, 3), 10, 20)
Expand Down Expand Up @@ -50,20 +51,23 @@ def test_perturbation_metrics():
model = generate_timeseries_model(input_shape, nb_labels)
explanations = np.random.uniform(0, 1, x.shape)

for step in [-1, 2, 10]:
for max_percentage_perturbed in [0.2, 1.0]:
for baseline_mode in [0.0, "zero", "inverse", "negative"]:
for step in [-1, 10]:
for baseline_mode in [0.0, "inverse"]:
for metric in ["loss", "accuracy"]:
score_insertion = InsertionTS(
model, x, y, metric="loss", baseline_mode=baseline_mode,
steps=step, max_percentage_perturbed=max_percentage_perturbed,
model, x, y, metric=metric, baseline_mode=baseline_mode,
steps=step, max_percentage_perturbed=0.2,
)(explanations)
score_deletion = DeletionTS(
model, x, y, metric="loss", baseline_mode=baseline_mode,
steps=step, max_percentage_perturbed=max_percentage_perturbed,
model, x, y, metric=metric, baseline_mode=baseline_mode,
steps=step, max_percentage_perturbed=0.2,
)(explanations)

for score in [score_insertion, score_deletion]:
assert 0.0 < score < 1
if metric == "loss":
assert 0.0 < score
elif score == "accuracy":
assert 0.0 <= score <= 1.0


def test_perfect_correlation():
Expand Down
3 changes: 2 additions & 1 deletion tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ def generate_timeseries_model(input_shape=(20, 10), output_shape=10):
model.add(GlobalAveragePooling1D())
model.add(Dense(output_shape))
model.add(Activation('softmax'))
model.compile(loss='categorical_crossentropy', optimizer='sgd')
model.compile(loss='categorical_crossentropy', optimizer='sgd',
metrics=['accuracy'])

return model

Expand Down

0 comments on commit 03d3849

Please sign in to comment.