Skip to content

Commit

Permalink
test causal fidelity tabular
Browse files Browse the repository at this point in the history
  • Loading branch information
DavidPetiteau committed Jan 25, 2022
1 parent 318a019 commit 7809cd8
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 4 deletions.
30 changes: 28 additions & 2 deletions tests/metrics/test_fidelity.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import tensorflow as tf
import numpy as np

from ..utils import generate_model, generate_timeseries_model, generate_data, almost_equal
from xplique.metrics import Insertion, Deletion, MuFidelity, InsertionTS, DeletionTS
from ..utils import generate_model, generate_timeseries_model, generate_regression_model, generate_data, almost_equal
from xplique.metrics import Insertion, Deletion, MuFidelity, InsertionTS, DeletionTS, InsertionTab, DeletionTab


def test_mu_fidelity():
Expand Down Expand Up @@ -70,6 +70,32 @@ def test_perturbation_metrics():
assert 0.0 <= score <= 1.0


def test_regression_metrics():
# ensure we can compute insertion/deletion metric with consistent arguments
input_shape, nb_labels, nb_samples = ((20, 10), 5, 50)
x, y = generate_data(input_shape, nb_labels, nb_samples)
model = generate_regression_model(input_shape, nb_labels)
explanations = np.random.uniform(0, 1, x.shape)

for step in [5, 10]:
for baseline_mode in [0.0, lambda x: x-0.5]:
for metric in ["loss", "accuracy"]:
score_insertion = InsertionTab(
model, x, y, metric=metric, baseline_mode=baseline_mode,
steps=step, max_percentage_perturbed=0.2,
)(explanations)
score_deletion = DeletionTab(
model, x, y, metric=metric, baseline_mode=baseline_mode,
steps=step, max_percentage_perturbed=0.2,
)(explanations)

for score in [score_insertion, score_deletion]:
if metric == "loss":
assert 0.0 < score
elif score == "accuracy":
assert 0.0 <= score <= 1.0


def test_perfect_correlation():
"""Ensure we get perfect score if the correlation is perfect"""
# we ensure perfect correlation if the model return the sum of the input,
Expand Down
4 changes: 2 additions & 2 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ def generate_regression_model(features_shape, output_shape=1):
model.add(Dense(4, activation='relu'))
model.add(Dense(4, activation='relu'))
model.add(Dense(output_shape))
model.compile(loss='mean_absolute_error',
optimizer='sgd')
model.compile(loss='mean_absolute_error', optimizer='sgd',
metrics=['accuracy'])

return model

Expand Down

0 comments on commit 7809cd8

Please sign in to comment.