In [None]:
import json
import os
from glob import glob

import numpy as np
import rise
import torch
from captum import attr
from lime.lime_image import LimeImageExplainer
from matplotlib import pyplot as plt
from numpy import random as np_rand
from PIL import Image
from scipy import stats
from sklearn.metrics import classification_report
from sklearn.preprocessing import MinMaxScaler
from torch import nn
from torchvision import models, transforms
from torchvision.transforms import v2
from tqdm.auto import tqdm

In [None]:
# Glob pattern for the test images; folders are laid out as <class>/<image id>/image.png.
DATA_PATH = "./data/funny_birds/v2/test_CO/**/**/image.png"

## Load model

ResNet50 with a single-channel input convolution and a custom classification head. Accuracy is approximately 0.90 or higher.

In [None]:
# Run on the first GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# ResNet50 backbone. The stock classifier created here is replaced below, so
# `num_classes` only sizes a layer that is discarded.
net = models.resnet50(num_classes=50)

# Replace the stem so the network accepts single-channel (grayscale) input.
net.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)

num_ftrs = net.fc.in_features

# Custom head that ends in 47 output logits.
net.fc = nn.Sequential(
    nn.Linear(num_ftrs, 128),
    nn.ReLU(inplace=True),
    nn.Dropout(p=0.7),
    nn.Linear(128, 47),
)

# `map_location` remaps the stored tensors onto `device`, so a checkpoint that
# was saved from a GPU still loads when only the CPU fallback is available.
net.load_state_dict(
    torch.load("res_9.pt", map_location=device, weights_only=True)
)
net = net.to(device)

# Inference mode: disables dropout and uses the stored batch-norm statistics.
net = net.eval()

In [None]:
# Preprocessing shared by every cell that feeds an image to `net`:
# PIL image -> tensor -> grayscale -> float dtype.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Grayscale(), v2.ToDtype(torch.float)]
)

In [None]:
# Classify every test image and print per-class precision / recall / F1.
preds = []
gts = []

# Pure inference: no autograd graph is needed, which saves memory and time.
with torch.no_grad():
    for img_path in tqdm(sorted(glob(DATA_PATH))):
        folder_path, img_name = os.path.split(img_path)
        # The class label is the grandparent folder: <class>/<image id>/image.png
        gt = int(folder_path.split(os.path.sep)[-2])

        img_pil = Image.open(img_path)
        img = transform(img_pil).unsqueeze(0)  # add the batch dimension

        pred = int(torch.argmax(net(img.to(device))).item())

        preds.append(pred)
        gts.append(gt)


# scikit-learn's signature is (y_true, y_pred). Passing the predictions first
# would swap the reported precision and recall of every class.
print(classification_report(gts, preds))

## Get GT Relevance

The relevance of each part is computed as:

$$
\text{Part}_{\text{imp}} = f(x) - f(x'),
$$

where $x'$ is the same image as $x$ but without the part being studied.

Each $\text{Part}_{\text{imp}}$ is multiplied by the pixels of its part (the difference between $x$ and $x'$). The sum of all parts, each weighted by its importance, yields the ground-truth saliency map, which is stored as a `.npy` file.

In [None]:
DATA_PATH = "./data/funny_birds/v2/test_CO/**/**/image.png"

# Maps each removable part to the rendering of the bird WITHOUT that part.
# Loop-invariant, so it is built once instead of once per image.
parts = {
    "wing": "body_beak_eye_foot_tail.png",
    "tail": "body_beak_eye_foot_wing.png",
    "foot": "body_beak_eye_tail_wing.png",
    "eye": "body_beak_foot_tail_wing.png",
    "beak": "body_eye_foot_tail_wing.png",
}

# np.save does not create missing folders; make sure the target exists.
os.makedirs("./output/GT_resnet50", exist_ok=True)

for img_path in tqdm(sorted(glob(DATA_PATH))):
    folder_path, img_name = os.path.split(img_path)
    # Folder layout: <class>/<image id>/image.png
    cls = folder_path.split(os.path.sep)[-2]
    img_id = folder_path.split(os.path.sep)[-1]

    original_data = Image.open(img_path)

    # Prepare the data for the model
    original_data = transform(original_data).unsqueeze(0)

    sal_map_gt = torch.zeros_like(original_data).float()
    res = dict()  # part name -> drop of the class score when the part is removed
    diff_arr = []  # weighted per-part maps of the current image (plotted in the next cell)

    # Only forward passes are needed, so skip building autograd graphs.
    with torch.no_grad():
        output_org = net(original_data.to(device))[0][int(cls)]  # [batch][class]

        for part, part_path in parts.items():
            path = os.path.join(folder_path, part_path)

            if not os.path.isfile(path):
                # Report the missing rendering and skip the part instead of
                # crashing in Image.open a few lines below.
                print(f"{path} - fora")
                continue

            data = Image.open(path)
            data = transform(data).unsqueeze(0)

            output = net(data.to(device))[0][int(cls)]  # [batch][class]
            res[part] = float((output_org - output).cpu())

            # Pixels that differ between the full image and the image without
            # the part, scaled by how much the class score changed.
            diff = original_data - data
            diff = diff * res[part]

            diff_arr.append(diff)

            sal_map_gt = sal_map_gt + diff

    sal_map_gt = sal_map_gt.cpu().numpy()[0, 0, :, :]
    # NOTE(review): files are keyed by image id only; this assumes ids are
    # unique across class folders - confirm against the dataset.
    with open(f"./output/GT_resnet50/{img_id}.npy", "wb") as f:
        np.save(f, sal_map_gt)

In [None]:
# Show the absolute weighted difference map of each part. `diff_arr` holds the
# maps of the last image processed by the ground-truth loop.
plt.figure(figsize=(10, 10))
for panel_idx, part_diff in enumerate(diff_arr, start=1):
    plt.subplot(1, 5, panel_idx)
    plt.axis("off")
    plt.imshow(np.abs(part_diff.cpu().numpy()[0, 0, :, :]));

# Get XAI

In [None]:
# Small constant that keeps divisions and logarithms finite.
epsilon = 1e-5


def _to_probability(info):
    """Rescale an array to [0, 1] and normalise it so it sums to (almost) one.

    Args:
        info: NumPy array with the values to convert. It is copied, so the
            caller's array is left untouched.

    Returns:
        NumPy array with the same shape as ``info``. The values are min-max
        scaled and then divided by their sum plus ``epsilon``.
    """
    values = np.copy(info)
    original_shape = values.shape

    # MinMaxScaler works per column, so present the data as a single column.
    column = values.reshape(-1, 1)
    rescaled = MinMaxScaler().fit_transform(column).reshape(original_shape)

    return rescaled / (np.sum(rescaled) + epsilon)


def kl(sal_map_gt, sal_map):
    """Kullback-Leibler divergence of a saliency map from the ground truth.

    Args:
        sal_map_gt: NumPy array with the ground truth saliency map.
        sal_map: NumPy array with the saliency map to compare.

    Returns:
        Float with the divergence ``sum(gt * log(gt / map))`` computed on the
        normalised maps.
    """
    # Adding epsilon creates new arrays and keeps the ratio and the logarithm
    # finite wherever a normalised map is exactly zero.
    gt_dist = _to_probability(sal_map_gt) + epsilon
    map_dist = _to_probability(sal_map) + epsilon

    return np.sum(gt_dist * np.log(gt_dist / map_dist))


def sim(sal_map_gt, sal_map):
    """Similarity (intersection) score between two saliency maps.

    Both maps are normalised with ``_to_probability``; the score is the sum of
    the element-wise minimum of the two normalised maps.

    Args:
        sal_map_gt: NumPy array with the ground truth saliency map.
        sal_map: NumPy array with the saliency map to compare.

    Returns:
        Float with the summed element-wise minimum of the two maps.
    """
    gt_dist = _to_probability(sal_map_gt)
    map_dist = _to_probability(sal_map)

    overlap = np.stack([map_dist, gt_dist]).min(axis=0)

    return overlap.sum()


def emd(sal_map_gt, sal_map):
    """Earth Mover's (Wasserstein) distance between two saliency maps.

    Each map is normalised with ``_to_probability`` and then divided by its
    maximum (when positive). The flattened values are handed to
    ``scipy.stats.wasserstein_distance``, which treats each argument as a set
    of observed values; pixel positions are therefore not taken into account.

    Args:
        sal_map_gt: NumPy array with the ground truth saliency map.
        sal_map: NumPy array with the saliency map to compare.

    Returns:
        Float between 0 and 1 with the EMD between the two saliency maps.
    """
    gt_values = _to_probability(sal_map_gt)
    map_values = _to_probability(sal_map)

    # Rescale so the largest value is 1; an all-zero map is left as it is.
    gt_peak = gt_values.max()
    if gt_peak > 0:
        gt_values = gt_values / gt_peak

    map_peak = map_values.max()
    if map_peak > 0:
        map_values = map_values / map_peak

    return stats.wasserstein_distance(map_values.flatten(), gt_values.flatten())


def _to_zero_one(info):
    return (info - info.min()) / (info.max() - info.min())


def AUC_Borji(sal_map_gt, sal_map, n_rep=100, step_size=0.1, rand_sampler=None):
    """Compute the Borji-style AUC of a saliency map against a ground truth map.

    Both maps are rescaled with ``_to_zero_one``. The ground truth is then
    binarised: pixels whose rescaled value exceeds 0.5 are the "fixations".
    A ROC curve is built by sweeping thresholds from 0 up to the largest
    sampled saliency value in steps of ``step_size``.

    The true positive rate is the fraction of fixation pixels whose saliency
    is at or above the threshold. The false positive rate is the same fraction
    measured on randomly drawn pixels (as many as there are fixations, drawn
    from ALL image pixels). The AUC is averaged over ``n_rep`` random draws.

    Parameters
    ----------
    sal_map_gt : real-valued array
        Ground truth saliency map; thresholded at 0.5 after rescaling.
    sal_map : real-valued array
        Saliency map to evaluate. It is indexed with the flattened ground
        truth mask, so both maps must have the same number of elements.
    n_rep : int, optional
        Number of repeats for random sampling of locations.
    step_size : float, optional
        Step size for sweeping through saliency map values.
    rand_sampler : callable, optional
        S_rand = rand_sampler(S, F, n_rep, n_fix)
        Sample the saliency map at random locations to estimate false positive.
        Return the sampled saliency values, S_rand.shape=(n_fix,n_rep).
        When None, locations are drawn with ``np_rand.randint``, i.e. from
        NumPy's global random state, so results vary unless it is seeded.

    Returns
    -------
    AUC : float, between [0,1]
        Mean AUC over the random draws, or ``np.nan`` (after printing a
        message) when no ground truth pixel exceeds the 0.5 threshold.
    """
    sal_map_gt = _to_zero_one(sal_map_gt)
    sal_map = _to_zero_one(sal_map)

    saliency_map = np.asarray(sal_map)
    # Binarise the rescaled ground truth: everything above 0.5 is a fixation.
    fixation_map = np.asarray(sal_map_gt) > 0.5
    # If there are no fixations to predict, return NaN
    if not np.any(fixation_map):
        print("no fixation to predict")
        return np.nan
    # The saliency map was already rescaled to [0, 1] above; the alternative
    # normalisation below is intentionally disabled.
    # saliency_map = _to_probability(saliency_map)

    S = saliency_map.ravel()
    F = fixation_map.ravel()
    S_fix = S[F]  # Saliency map values at fixation locations
    n_fix = len(S_fix)
    n_pixels = len(S)
    # For each fixation, sample n_rep values from anywhere on the saliency map
    if rand_sampler is None:
        # One call draws every random index at once, shape (n_fix, n_rep).
        r = np_rand.randint(0, n_pixels, [n_fix, n_rep])
        S_rand = S[
            r
        ]  # Saliency map values at random locations (fixated locations are not excluded)
    else:
        S_rand = rand_sampler(S, F, n_rep, n_fix)
    # Calculate AUC per random split (set of random locations)
    auc = np.zeros(n_rep) * np.nan
    for rep in range(n_rep):
        # Descending thresholds from the largest sampled value down to 0.
        thresholds = np.r_[0 : np.max(np.r_[S_fix, S_rand[:, rep]]) : step_size][::-1]
        # Two extra slots anchor the ROC curve at (0, 0) and (1, 1).
        tp = np.zeros(len(thresholds) + 2)
        fp = np.zeros(len(thresholds) + 2)
        tp[0] = 0
        tp[-1] = 1
        fp[0] = 0
        fp[-1] = 1
        for k, thresh in enumerate(thresholds):
            tp[k + 1] = np.sum(S_fix >= thresh) / float(n_fix)
            # Each repetition holds n_fix random samples, hence the same denominator.
            fp[k + 1] = np.sum(S_rand[:, rep] >= thresh) / float(n_fix)
        # NOTE(review): `np.trapezoid` is the NumPy 2.x name (older releases
        # call it `np.trapz`) - confirm the pinned NumPy version.
        auc[rep] = np.trapezoid(tp, fp)
    return np.mean(auc)  # Average across random splits


# Registry of comparison metrics, keyed by the name used in the results file.
# Every entry is called as metric(ground_truth_map, explanation_map).
metrics = dict(emd=emd, kl=kl, sim=sim, auc=AUC_Borji)

## Methods

### RISE

In [None]:
# RISE explainer wrapped around `net`; used by the "rise" entry of `methods_fn`.
# NOTE(review): (256, 256) is presumably the model input size - confirm
# against the `rise` module and the dataset resolution.
explainer_rise = rise.RISE(net, (256, 256), gpu_batch=1, device=device)
# Generates the masks used by the explainer and stores them in "masks.npy".
# NOTE(review): no random seed is set before this call, so the masks
# presumably differ between runs - confirm whether `rise` seeds internally.
explainer_rise.generate_masks(N=600, s=8, p1=0.1, savepath="masks.npy")

### LIME

In [None]:
# Shared LIME image explainer instance, used by `get_lime`.
explainer_lime = LimeImageExplainer()


def batch_predict(
    image: np.ndarray, network, multi_channel: bool = False, n_classes=None
) -> np.ndarray:
    """Run ``network`` on a channels-last batch of images.

    The batch is copied, converted from channels-last to channels-first and,
    unless ``multi_channel`` is set, reduced to its first channel before it is
    handed to ``network``.

    Args:
        image: NumPy array of shape (batch, height, width, channels).
        network: Callable that takes the channels-first NumPy batch and
            returns an array-like exposing ``reshape``.
        multi_channel: When False (default) only channel 0 is kept, giving a
            batch of shape (batch, 1, height, width).
        n_classes: Number of columns of the returned array. ``None`` is
            treated as 1.

    Returns:
        The output of ``network`` reshaped to (-1, n_classes).

    Raises:
        ValueError: If ``image`` is not 4-dimensional.
    """
    if n_classes is None:
        n_classes = 1

    batch = np.asarray(image)
    if batch.ndim != 4:
        raise ValueError(
            "image must have shape (batch, height, width, channels); "
            f"got an array with {batch.ndim} dimension(s)"
        )

    # Work on a copy so `network` never receives a view of the caller's buffer.
    batch = np.transpose(np.copy(batch), (0, 3, 1, 2))

    if not multi_channel:
        # Slice (rather than index) so the channel axis is preserved.
        batch = batch[:, 0:1, :, :]

    return network(batch).reshape((-1, n_classes))

In [None]:
def get_lime(img, label, *args, **kwargs):
    """Build a LIME relevance map for one image and one class.

    Args:
        img: Batch holding a single single-channel image, indexed as
            ``img[0, 0, :, :]``. A torch tensor (on any device) or a NumPy
            array is accepted.
        label: Class index whose explanation is extracted. It has to be one
            of the labels LIME explained for this image.
        *args: Ignored; accepted so the call signature stays flexible.
        **kwargs: Ignored; accepted so the call signature stays flexible.

    Returns:
        NumPy float64 array of shape (1, height, width). Each pixel holds the
        absolute LIME weight of the segment it belongs to.
    """
    # LIME works on NumPy arrays and a CUDA tensor cannot be converted
    # implicitly, so bring tensors back to host memory first.
    if isinstance(img, torch.Tensor):
        img = img.detach().cpu().numpy()

    def _net_forward(batch):
        # Forward pass only; gradients are not needed for LIME's sampling.
        with torch.no_grad():
            logits = net(torch.from_numpy(batch.astype(np.float32)).to(device))
        return logits.detach().cpu().numpy()

    explanation = explainer_lime.explain_instance(
        img[0, 0, :, :],
        lambda x: batch_predict(
            x,
            _net_forward,
            multi_channel=False,
            n_classes=47,
        ),
        num_samples=1500,
        batch_size=1,
        random_seed=42,
        progress_bar=False,
        hide_color=0,
    )

    mask = np.zeros(
        (explanation.segments.shape[0], explanation.segments.shape[1]),
        dtype=np.float64,
    )

    for key, val in explanation.local_exp[label]:
        # NOTE(review): segment 0 is never painted - presumably treated as
        # background; confirm this is intended.
        if key != 0:
            mask[explanation.segments == key] = abs(val)

    # Add the leading axis the callers expect: (1, height, width).
    return mask[np.newaxis]

### Gradient

In [None]:
def grad_fn(x, y, xai):
    """Run an attribution method and keep channel 0 of its output.

    Args:
        x: Input tensor; moved to ``device`` before the call.
        y: Target passed to the attribution method as ``target``.
        xai: Object exposing ``attribute(inputs, target=...)``.

    Returns:
        The attribution sliced as ``[:, 0, :, :]``.
    """
    attributions = xai.attribute(x.to(device), target=y)
    return attributions[:, 0, :, :]


# Captum Saliency explainer bound to `net`; used by the "grad" entry of `methods_fn`.
sal = attr.Saliency(net)

### SHAP

In [None]:
# Captum KernelShap explainer bound to `net`; used by the "shap" entry of `methods_fn`.
kernel_shap = attr.KernelShap(net)

### IG

In [None]:
# Captum IntegratedGradients explainer bound to `net`; used by the "ig" entry of `methods_fn`.
ig = attr.IntegratedGradients(net)

### SHAP

In [None]:
# NOTE(review): this repeats the `kernel_shap` assignment made in an earlier
# SHAP cell; one of the two cells is redundant.
kernel_shap = attr.KernelShap(net)

## Results

In [None]:
# JSON file where the per-method metric values are written and later read back.
RESULTS_PATH = "results_binary_img.json"

In [None]:
# Captum DeepLift explainer bound to `net`; used by the "dl" entry of `methods_fn`.
# NOTE(review): this cell sits under the Results heading rather than with the
# other explainers in the Methods section.
deep_lift = attr.DeepLift(net)

In [None]:
# Registry of explanation methods. Every entry is called positionally as
# method(image_batch, predicted_class) and returns the relevance map.
# The insertion order determines the order in which methods are processed.
methods_fn = {
    "dl": lambda img, pred: grad_fn(img, pred, deep_lift),
    "shap": lambda img, pred: kernel_shap.attribute(
        img.to(device), target=pred, n_samples=200
    ),
    "rise": lambda img, pred: explainer_rise(img)[pred],
    "lime": get_lime,
    "grad": lambda img, pred: grad_fn(img, pred, sal),
    "ig": lambda img, pred: grad_fn(img, pred, ig),
}

### Calculate XAI

In [None]:
# The image listing does not depend on the method, so glob the dataset once.
xai_image_paths = sorted(glob(DATA_PATH))

for method_name, method in methods_fn.items():
    # One output folder per method: results/bin_img/<method>/<image id>.npy
    out_folder_path = os.path.join("results", "bin_img", method_name)
    os.makedirs(out_folder_path, exist_ok=True)

    for img_path in tqdm(xai_image_paths, desc=method_name):
        folder_path, img_name = os.path.split(img_path)
        img_id = folder_path.split(os.path.sep)[-1]

        img_pil = Image.open(img_path)
        # Move the batch to the device once and reuse it below.
        img = transform(img_pil).unsqueeze(0).to(device)

        # The prediction only selects the class to explain, so it needs no
        # autograd graph. The explanation itself runs outside `no_grad`
        # because some of the registered methods call `grad_fn`, which runs a
        # Captum attribution.
        with torch.no_grad():
            cls_prediction = int(torch.argmax(net(img)).item())
        explanation = method(img, cls_prediction)

        if isinstance(explanation, torch.Tensor):
            explanation = explanation.detach().cpu().numpy()

        with open(os.path.join(out_folder_path, f"{img_id}.npy"), "wb") as f:
            np.save(f, explanation)

### Calculate measures

In [None]:
results = dict()

# Decide once, per image id, whether the model classifies the image correctly.
# Explanations are matched to images by id further down. Zipping the two
# sorted listings would be fragile: explanation files sort by image id, while
# image paths sort by class folder first, so the pairs can be misaligned.
# NOTE(review): this assumes image ids are unique across class folders, as
# the id-only file names already do - confirm against the dataset.
correctly_classified = dict()
with torch.no_grad():
    for img_path in tqdm(sorted(glob(DATA_PATH)), desc="predictions"):
        path_parts = os.path.split(img_path)[0].split(os.path.sep)
        gtai = int(path_parts[-2])
        img_id = path_parts[-1]

        img_pil = Image.open(img_path)
        img = transform(img_pil).unsqueeze(0)

        pred_ai = int(torch.argmax(net(img.to(device))).item())
        correctly_classified[img_id] = pred_ai == gtai


for method_name in methods_fn:
    results_method = {k: [] for k in metrics.keys()}
    out_folder_path = os.path.join("results", "bin_img", method_name)

    for pred_path in tqdm(
        sorted(glob(os.path.join(out_folder_path, "*.npy"))), desc=method_name
    ):
        _, image_id = os.path.split(pred_path)  # "<image id>.npy"
        img_id = os.path.splitext(image_id)[0]

        # Only correctly classified images are scored. Explanation files with
        # no matching test image are skipped as well.
        if not correctly_classified.get(img_id, False):
            continue

        gt_xai = np.load(f"./output/GT_resnet50/{image_id}")
        pred_xai = np.load(pred_path)

        for metric_name, metric_fn in metrics.items():
            res = metric_fn(gt_xai.flatten(), pred_xai.flatten())
            results_method[metric_name].append(float(res))
    results[method_name] = results_method

with open(RESULTS_PATH, "w") as f:
    json.dump(results, f)

In [None]:
# Reload the stored metrics and print mean and standard deviation per method.
with open(RESULTS_PATH) as f:
    results = json.load(f)

for method_name, method_info in results.items():
    print(method_name.upper())
    # "emd" is left out of the printed summary.
    reported = {
        metric_name: values
        for metric_name, values in method_info.items()
        if metric_name != "emd"
    }
    for metric_name, values in reported.items():
        # NaN entries (e.g. AUC with no fixation) are ignored by the nan-aware stats.
        print(f"{metric_name}: {np.nanmean(values)} - {np.nanstd(values)}")
    print("-" * 25)