tests: introduce Craft tests
Signed-off-by: Frederic Boisnard <frederic.boisnard@irt-saintexupery.com>
fredericboisnard committed Oct 9, 2023
1 parent 3327cac commit a5a8055
Showing 5 changed files with 613 additions and 2 deletions.
4 changes: 2 additions & 2 deletions setup.cfg
@@ -46,7 +46,7 @@ deps =
tf211: tensorflow ~= 2.11.0
-rrequirements.txt
commands =
-    pytest --cov=xplique --ignore=xplique/wrappers/pytorch.py --ignore=tests/wrappers/test_pytorch_wrapper.py {posargs}
+    pytest --cov=xplique --ignore=xplique/wrappers/pytorch.py --ignore=tests/wrappers/test_pytorch_wrapper.py --ignore=tests/concepts/test_craft_torch.py {posargs}

[testenv:py{38,39,310}-tf{25,28,211}-torch{111,113,200}]
deps =
@@ -60,7 +60,7 @@ deps =
torch200: torch
-rrequirements.txt
commands =
-    pytest --cov=xplique/wrappers/pytorch tests/wrappers/test_pytorch_wrapper.py
+    pytest --cov=xplique/wrappers/pytorch tests/wrappers/test_pytorch_wrapper.py tests/concepts/test_craft_torch.py

[mypy]
check_untyped_defs = True
245 changes: 245 additions & 0 deletions tests/concepts/test_craft_tf.py
@@ -0,0 +1,245 @@
import os

import numpy as np
import pytest
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Conv2D, Activation, Flatten, Input
from tensorflow.keras.optimizers import Adam

from xplique.concepts import CraftTf as Craft
from ..utils import generate_data, generate_model, generate_txt_images_data
from ..utils_functions import google_drive_helpers


def test_shape():
"""Ensure the output shape is correct"""

input_shapes = [(32, 32, 3), (32, 32, 1), (64, 64, 3), (64, 32, 3)]
nb_labels = 3
nb_samples = 100

for input_shape in input_shapes:
# Generate a fake dataset
x, y = generate_data(input_shape, nb_labels, nb_samples)
model = generate_model(input_shape, nb_labels)
model.compile(loss='categorical_crossentropy', optimizer=tf.keras.optimizers.Adam(0.1))

        # Cut the model in two parts (as explained in the paper):
        # the first part is g(.), our 'input_to_latent' model,
        # the second part is h(.), our 'latent_to_logit' model.
        #
        # Different index pairs exercise different activation shapes:
        # (0, 1) gives 4-dim activations for g, while (-1, -1) gives
        # 2-dim activations; only (-1, -1) is exercised in this loop.

for index_layer_g, index_layer_h in [(-1, -1)]:
g = tf.keras.Model(model.input, model.layers[index_layer_g].output)
h = tf.keras.Model(model.layers[index_layer_h].input, model.layers[-1].output)

# The activations must be positives
assert np.all(g(x) >= 0.0)

# Initialize Craft
number_of_concepts = 10
patch_size = 15
craft = Craft(input_to_latent_model = g,
latent_to_logit_model = h,
number_of_concepts = number_of_concepts,
patch_size = patch_size,
batch_size = 64)

            # Now we can fit the concepts using our images
# Focus on class id 0
class_id = 0
images_preprocessed = x[y.argmax(1)==class_id] # select only images of class 'class_id'
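            # craft.fit() factorizes the activations of patches taken from these
            # images: 'crops' are the patches themselves, 'crops_u' their
            # coefficients in the concept basis, and 'w' the concept bank
            # (one concept per row), as the shape checks below verify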
crops, crops_u, w = craft.fit(images_preprocessed, class_id)

# Checking shape of crops, crops_u, w
assert crops.shape[1] == crops.shape[2] == patch_size # Check patch sizes
assert crops.shape[0] == crops_u.shape[0] # Check numbers of patches
assert crops_u.shape[1] == w.shape[0] == number_of_concepts

# Importance estimation
importances = craft.estimate_importance()
assert len(importances) == number_of_concepts

# Checking the results of transform()
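            # transform() expresses images in the concept basis: a spatial
            # concept map when g() outputs 4-dim activations, a single
            # concept vector per image when it outputs 2-dim activations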
images_u = craft.transform(images_preprocessed)
if len(images_u.shape) == 4:
assert images_u.shape == (images_preprocessed.shape[0],
images_preprocessed.shape[1]-1,
images_preprocessed.shape[2]-1,
number_of_concepts)
elif len(images_u.shape) == 2:
assert images_u.shape == (images_preprocessed.shape[0], number_of_concepts)
else:
assert False # Should not happen

def test_wrong_layers():
"""Ensure that Craft complains when the input models are incompatible"""

input_shapes = [(32, 32, 3)]
nb_labels = 3

for input_shape in input_shapes:
# Generate a fake dataset
model = generate_model(input_shape, nb_labels)
model.compile(loss='categorical_crossentropy', optimizer=tf.keras.optimizers.Adam(0.1))

g = tf.keras.Model(model.input, model.layers[0].output)
h = lambda x: 2*x
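        # h is a plain function, not a tf.keras.Model, so Craft is expected to reject it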

# Initialize Craft
number_of_concepts = 10
patch_size = 15
with pytest.raises(TypeError):
Craft(input_to_latent_model = g,
latent_to_logit_model = h,
number_of_concepts = number_of_concepts,
patch_size = patch_size,
batch_size = 64)

def test_classifier():
""" Check the Craft results on a small fake dataset """

input_shape = (64, 64, 3)
nb_labels = 3
nb_samples = 200

# Create a dataset of 'ABC', 'BCD', 'CDE' images
x, y, nb_samples, _ = generate_txt_images_data(input_shape, nb_labels, nb_samples)

# train a small classifier on the dataset
def create_classifier_model(input_shape=(64, 64, 3), output_shape=10):
model = Sequential()
model.add(Input(shape=input_shape))
model.add(Conv2D(6, kernel_size=(2, 2)))
model.add(Activation('relu'))
model.add(Conv2D(6, kernel_size=(2, 2)))
model.add(Activation('relu'))
model.add(Conv2D(6, kernel_size=(2, 2)))
model.add(Activation('relu'))
model.add(Flatten())
model.add(Dense(output_shape))
model.add(Activation('softmax'))
opt = Adam(learning_rate=0.005)
model.compile(loss='categorical_crossentropy', optimizer=opt, metrics=['accuracy'])

return model

model = create_classifier_model(input_shape, nb_labels)

tf.keras.utils.set_random_seed(0)
tf.config.experimental.enable_op_determinism()

# Retrieve checkpoints
checkpoint_path = "tests/concepts/checkpoints/classifier_test_craft_tf.ckpt"
if not os.path.exists(f"{checkpoint_path}.index"):
os.makedirs("tests/concepts/checkpoints/", exist_ok=True)
identifier = "1NLA7x2EpElzEEmyvFQhD6VS6bMwS_bCs"
google_drive_helpers.download_file(identifier, f"{checkpoint_path}.index")

identifier = "1wDi-y9b-3I_a-ZtqRlfuib-D7Ox4j8pX"
google_drive_helpers.download_file(identifier, f"{checkpoint_path}.data-00000-of-00001")

model.load_weights(checkpoint_path)

acc = np.sum(np.argmax(model(x), axis=1) == np.argmax(y, axis=1)) / nb_samples
assert acc == 1.0

    # Cut the model in two parts (as explained in the paper):
    # the first part is g(.), our 'input_to_latent' model; the second is h(.), our 'latent_to_logit' model
g = tf.keras.Model(model.input, model.layers[-4].output)
h = tf.keras.Model(model.layers[-3].input, model.layers[-1].output)

assert np.all(g(x) >= 0.0)

# Init Craft on the full dataset
craft = Craft(input_to_latent_model = g,
latent_to_logit_model = h,
number_of_concepts = 3,
patch_size = 12,
batch_size = 32)
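
    # The expected best crops below are 12x12 binary templates (matching
    # patch_size=12), written as ASCII art and parsed with np.genfromtxt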

# Expected best crop for class 0 (ABC) is AB
AB_str = """
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 1 1 1 1 1 1 1 1
0 0 0 0 0 0 1 0 0 0 1 1
1 0 0 0 0 0 1 0 0 0 0 1
1 0 0 0 0 0 1 0 0 0 1 1
1 0 0 0 0 0 1 1 1 1 1 1
1 1 0 0 0 0 1 0 0 0 0 1
0 1 0 0 0 0 1 0 0 0 0 0
0 1 1 0 0 0 1 0 0 0 0 1
1 1 1 1 1 1 1 1 1 1 1 1
0 0 0 0 0 0 0 0 0 0 0 0
"""
AB = np.genfromtxt(AB_str.splitlines())

# Expected best crop for class 1 (BCD) is BC
BC_str = """
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
1 1 1 1 1 1 0 0 0 0 1 1
1 0 0 0 1 1 0 0 0 1 1 0
1 0 0 0 0 1 0 0 0 1 0 0
1 0 0 0 1 1 0 0 0 1 0 0
1 1 1 1 1 1 0 0 0 1 0 0
1 0 0 0 0 1 1 0 0 1 0 0
1 0 0 0 0 0 1 0 0 1 0 0
1 0 0 0 0 1 1 0 0 1 1 0
1 1 1 1 1 1 0 0 0 0 1 1
"""
BC = np.genfromtxt(BC_str.splitlines())

# Expected best crop for class 2 (CDE) is DE
DE_str = """
0 0 0 0 0 0 0 0 0 0 0 0
1 0 0 1 1 1 1 1 1 1 1 0
1 1 0 0 0 1 0 0 0 0 1 0
0 1 0 0 0 1 0 0 0 0 1 0
0 1 1 0 0 1 0 0 1 0 0 0
0 1 1 0 0 1 1 1 1 0 0 0
0 1 1 0 0 1 0 0 1 0 0 0
0 1 0 0 0 1 0 0 0 0 1 1
1 1 0 0 0 1 0 0 0 0 1 1
1 0 0 1 1 1 1 1 1 1 1 1
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
"""
DE = np.genfromtxt(DE_str.splitlines())

expected_best_crops = [AB, BC, DE]

    # Run one Craft study per class and check that, in each case, the best crop matches the expected one
for class_id in range(3):
        # Focus on class 'class_id': select only the images of that class
x_subset = x[np.argmax(y, axis=1)==class_id,:,:,:]

# fit craft on the selected class
crops, crops_u, w = craft.fit(x_subset, class_id)

# compute importances
importances = craft.estimate_importance()
assert importances[0] > 0.8

# find the best crop and compare it to the expected best crop
most_important_concepts = np.argsort(importances)[::-1]

# Find the best crop for the most important concept
c_id = most_important_concepts[0]
best_crops_ids = np.argsort(crops_u[:, c_id])[::-1]
best_crop = np.array(crops)[best_crops_ids[0]]

# Compare this best crop to the expectation
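        # Binarize the crop (sum over the RGB channels, then threshold) so it
        # can be compared pixel-wise with the 0/1 template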
predicted_best_crop = np.where(best_crop.sum(axis=2) > 0.25, 1, 0)
expected_best_crop = expected_best_crops[class_id].astype(np.uint8)

check = np.all(expected_best_crop == predicted_best_crop)
assert check
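
For reference, a minimal sketch of the CraftTf workflow these tests exercise, assuming a trained Keras classifier 'model' and a preprocessed image batch 'images' with one-hot 'labels'; the names mirror the calls in the tests above and should be read as an illustrative outline, not the library's documented API:

import tensorflow as tf
from xplique.concepts import CraftTf as Craft

# Cut the trained classifier: g maps inputs to a non-negative latent space,
# h maps that latent space to the logits
g = tf.keras.Model(model.input, model.layers[-4].output)
h = tf.keras.Model(model.layers[-3].input, model.layers[-1].output)

craft = Craft(input_to_latent_model=g,
              latent_to_logit_model=h,
              number_of_concepts=10,
              patch_size=12,
              batch_size=32)

# Fit on the images of a single class: returns the image patches, their
# coefficients in the concept basis, and the concept bank itself
class_id = 0
class_images = images[labels.argmax(1) == class_id]
crops, crops_u, w = craft.fit(class_images, class_id)

# Rank the concepts by importance and express images in the concept basis
importances = craft.estimate_importance()
concept_maps = craft.transform(class_images)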
