In [None]:
from pathlib import Path
import tensorflow as tf
import io
import os
import numpy as np
from PIL import Image, ImageOps
from keras.preprocessing import image
import shap
from keras.applications.vgg16 import preprocess_input

# Directory where trained models are mounted. The first ``*.h5`` file found
# (searched recursively, in sorted path order so the choice is reproducible
# across runs) is the model that gets explained.
MOUNTED_MODELS_ROOT = Path("/storage")
model_candidates = sorted(MOUNTED_MODELS_ROOT.glob("**/*.h5"))
if not model_candidates:
    raise FileNotFoundError(f"No '*.h5' model found under {MOUNTED_MODELS_ROOT}")
model_path = model_candidates[0]


# Training config
INPUT_SIZE = (224, 224)  # Default input size for VGG16
# class name -> model output index (assumed to match the training label encoding)
CLASS_ENCODING = {"Maltese dog": 0, "Afghan hound": 1}
BATCH_SIZE = 32

model = tf.keras.models.load_model(model_path)
model.summary()

In [None]:
def img_to_numpy(im, target_size) -> np.ndarray:
    """Decode an image from one of several sources into an RGB array of a fixed size.

    Args:
        im: a ``np.ndarray`` (returned unchanged, assumed already decoded and
            resized), raw encoded image ``bytes``, a file path (``str`` or
            ``Path``), or an ``io.BytesIO`` holding encoded image data.
        target_size: size passed to ``PIL.Image.resize``, i.e. ``(width, height)``.
            ``INPUT_SIZE`` is square, so the order does not matter for it.

    Returns:
        The image as an array built by ``keras.preprocessing.image.img_to_array``.

    Raises:
        ValueError: if ``im`` is not one of the supported input types.
    """
    if isinstance(im, np.ndarray):
        return im
    # Wrap raw bytes so every remaining input goes through a single Image.open path.
    if isinstance(im, bytes):
        im = io.BytesIO(im)
    if not isinstance(im, (str, Path, io.BytesIO)):
        raise ValueError(f"Unexpected input type: {type(im)}")

    # The context manager closes files PIL opened itself from a path. Every image
    # derived below is a new in-memory image, so it stays usable after exit.
    with Image.open(im) as img_pil:
        # Best-effort: apply the EXIF orientation tag. If that fails, keep the
        # image as stored rather than aborting.
        try:
            img_pil = ImageOps.exif_transpose(img_pil)
        except Exception:
            pass
        img_pil = img_pil.convert("RGB")
        img_pil = img_pil.resize(target_size, Image.NEAREST)
    return image.img_to_array(img_pil)

In [None]:
# add SHAP js code to the notebook
shap.initjs()

# Select files for explanation. Sorting keeps the selection identical across
# runs, since glob order depends on the filesystem.
data_path = Path(os.environ["PROJECT"]) / "data" / "Images"
files_to_explain = sorted(data_path.glob("*.jpg"))[:4]
if not files_to_explain:
    raise FileNotFoundError(f"No '*.jpg' images found in {data_path}")
inputs = np.array([img_to_numpy(file, target_size=INPUT_SIZE) for file in files_to_explain])

# Define a masker that is used to mask out partitions of the input image.
# All inputs were resized to INPUT_SIZE, so the first image's shape applies to all.
masker = shap.maskers.Image("inpaint_telea", inputs[0].shape)

In [None]:
# define prediction function
def predict_fn(x):
    """Prediction function handed to the SHAP explainer.

    Args:
        x: batch of (masked) images as a numpy array, as produced by the masker.

    Returns:
        The model's output for the VGG16-preprocessed batch.
    """
    # preprocess_input may modify its argument in place, so pass a copy to keep
    # SHAP's masked images intact. Use its *return value*: relying on the in-place
    # side effect is incorrect, because the RGB->BGR channel flip only exists in
    # the returned array, and nothing is mutated when a dtype cast forces a copy.
    preprocessed = preprocess_input(x.copy())
    return model(preprocessed)

# Create an instance of the explainer. Output names are ordered by their encoded
# class index so that name i always labels model output i, independent of the
# order in which CLASS_ENCODING's entries happen to be written.
explainer = shap.Explainer(
    predict_fn,
    masker,
    output_names=sorted(CLASS_ENCODING, key=CLASS_ENCODING.get),
)

In [None]:
# Estimate SHAP values for every selected image using 500 evaluations of the
# underlying model. The model is queried in batches of BATCH_SIZE, and only the
# two highest-ranked outputs per image are explained.
shap_values = explainer(
    inputs,
    max_evals=500,
    batch_size=BATCH_SIZE,
    outputs=shap.Explanation.argsort.flip[:2],
)

In [None]:
# Overlay the per-pixel SHAP attributions on each explained image
# (shap.image_plot is an alias of shap.plots.image).
shap.plots.image(shap_values)