In [None]:
import os
from glob import glob
import numpy as np

from PIL import Image
from scipy import stats
from tqdm.auto import tqdm
from matplotlib import pyplot as plt
from sklearn.preprocessing import MinMaxScaler

import rise
import torch
from torch import nn
from torchvision.transforms import v2
from torchvision import transforms, models

In [None]:
# Glob pattern for the test images: one `image.png` per <class>/<image id> folder.
DATA_PATH = "./data/funny_birds/v2/test_CO/**/**/image.png"

## Load model

ResNet-50 with a single-channel input layer and a custom classification head. Its accuracy is above ~0.90.

In [None]:
# Run on the first GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# ResNet-50 backbone; the `fc` layer created here is replaced by the custom head below.
net = models.resnet50(num_classes=50)

# Replace the stem convolution so the network accepts 1-channel (grayscale) input.
net.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)

# Custom classification head producing 47 output scores.
num_ftrs = net.fc.in_features
net.fc = nn.Sequential(
    nn.Linear(num_ftrs, 128),
    nn.ReLU(inplace=True),
    nn.Dropout(p=0.7),
    nn.Linear(128, 47),
)

# `map_location` remaps the stored tensors onto `device`, so a checkpoint saved
# on a GPU can still be loaded when only the CPU is available.
net.load_state_dict(torch.load('res_9.pt', map_location=device, weights_only=True))
net = net.to(device)

# Evaluation mode: dropout is disabled and normalisation layers use their running statistics.
net = net.eval()

In [None]:
# Preprocessing applied to every image before it is fed to the network,
# in order: PIL image -> tensor -> single channel -> float dtype.
_preprocessing_steps = [
    transforms.ToTensor(),
    transforms.Grayscale(),
    v2.ToDtype(torch.float),
]
transform = transforms.Compose(_preprocessing_steps)

## Get GT Relevance

Obtains the relevance of each part as:

$$
\text{Part}_{\text{imp}} = f(x) - f(x'),
$$

where $x'$ is the same image as $x$ with the studied part removed.

Each $\text{Part}_{\text{imp}}$ is multiplied by its respective part (the pixel-wise difference between $x$ and $x'$). The sum of all the parts, each multiplied by its importance, yields the ground-truth (GT) saliency map, which is stored as a `.npy` file.

In [None]:

# For every part, the file name of the image rendered WITHOUT that part.
# Defined once: it does not depend on the image being processed.
parts = {
    "wing": "body_beak_eye_foot_tail.png",
    "tail": "body_beak_eye_foot_wing.png",
    "foot": "body_beak_eye_tail_wing.png",
    "eye": "body_beak_foot_tail_wing.png",
    "beak": "body_eye_foot_tail_wing.png",
}

# Make sure the destination folder exists before anything is written into it.
os.makedirs("./output/GT_resnet50", exist_ok=True)

for img_path in tqdm(sorted(glob(DATA_PATH))):
    folder_path, img_name = os.path.split(img_path)
    # Folder layout: .../<class>/<image id>/image.png
    cls = folder_path.split(os.path.sep)[-2]
    img_id = folder_path.split(os.path.sep)[-1]

    # Prepare the data for the model (the file handle is closed right after decoding).
    with Image.open(img_path) as img_file:
        original_data = transform(img_file).unsqueeze(0)

    # Score of the true class for the complete image; no gradients are needed.
    with torch.no_grad():
        output_org = net(original_data.to(device))[0][int(cls)]  # [batch][class]

    sal_map_gt = torch.zeros_like(original_data).float()
    res = dict()
    # Weighted difference maps of this image; after the loop it holds those of the last image.
    diff_arr = []
    for part, part_path in parts.items():
        path = os.path.join(folder_path, part_path)

        # Report a missing part image and skip it instead of crashing on open.
        if not os.path.isfile(path):
            print(f"{path} - fora")
            continue

        with Image.open(path) as part_file:
            data = transform(part_file).unsqueeze(0)

        with torch.no_grad():
            output = net(data.to(device))[0][int(cls)]  # [batch][class]
        # Part importance: drop in the true-class score when the part is removed.
        res[part] = float((output_org - output).detach().cpu())

        # Pixels belonging to the part, weighted by the part importance.
        diff = (original_data - data) * res[part]
        diff_arr.append(diff)

        sal_map_gt = sal_map_gt + diff

    # Drop the batch and channel axes and store the map under the image id.
    # NOTE(review): files are keyed by image id only; this assumes ids are
    # unique across class folders — confirm.
    sal_map_gt = sal_map_gt.cpu().numpy()[0, 0, :, :]
    with open(f"./output/GT_resnet50/{img_id}.npy", 'wb') as f:
        np.save(f, sal_map_gt)

In [None]:
# Show the absolute weighted difference map of each part side by side.
# `diff_arr` holds the maps of the last image processed by the loop that builds it.
fig = plt.figure(figsize=(10, 10))
for idx, part_diff in enumerate(diff_arr):
    ax = fig.add_subplot(1, 5, idx + 1)
    ax.axis("off")
    ax.imshow(np.abs(part_diff.cpu().numpy()[0, 0, :, :]));

# Get XAI

In [None]:
# Small constant used by the metric helpers to avoid dividing by zero and taking log(0).
epsilon = 1e-5


def _to_probability(info):
    """Rescale an array so that its entries behave like a probability distribution.

    The values are first min-max scaled over the whole array and then divided by
    their sum (plus ``epsilon`` so an all-zero result does not divide by zero).
    The input is copied first and is never modified.

    Args:
        info: NumPy array with the values to convert.

    Returns:
        NumPy array with the same shape as ``info`` holding the rescaled values.
    """
    values = np.copy(info)
    original_shape = values.shape

    # MinMaxScaler scales per column, so the array is flattened into a single
    # column to scale all the entries together.
    flat_column = values.reshape(-1, 1)
    rescaled = MinMaxScaler().fit_transform(flat_column).reshape(original_shape)

    return rescaled / (rescaled.sum() + epsilon)


def kl(sal_map_gt, sal_map):
    """Kullback-Leibler divergence from a saliency map to the ground truth.

    Both maps are converted with ``_to_probability`` and shifted by ``epsilon``
    so that neither the ratio nor the logarithm sees a zero.

    Args:
        sal_map_gt: NumPy array with the ground truth saliency map.
        sal_map: NumPy array with the saliency map to compare.

    Returns:
        Float with the sum of ``p * log(p / q)``, where ``p`` comes from the
        ground truth and ``q`` from the compared map.
    """
    p = _to_probability(sal_map_gt) + epsilon
    q = _to_probability(sal_map) + epsilon

    return np.sum(p * np.log(p / q))

def emd(sal_map_gt, sal_map):
    """Earth Mover's (Wasserstein) distance between the values of two saliency maps.

    Both maps are converted with ``_to_probability``, divided by their maximum
    (when it is positive), flattened and passed to
    ``scipy.stats.wasserstein_distance``. The flattened pixel values are given
    as the samples of two one-dimensional distributions, so the pixel positions
    do not take part in the comparison.

    Args:
        sal_map_gt: NumPy array with the ground truth saliency map.
        sal_map: NumPy array with the saliency map to compare.

    Returns:
        Float with the distance between the two sets of values.
    """
    gt_values = _to_probability(sal_map_gt)
    map_values = _to_probability(sal_map)

    # Rescale so the largest value is 1; an all-zero map is left untouched.
    gt_peak = gt_values.max()
    gt_values = gt_values / (gt_peak if gt_peak > 0 else 1)

    map_peak = map_values.max()
    map_values = map_values / (map_peak if map_peak > 0 else 1)

    return stats.wasserstein_distance(map_values.ravel(), gt_values.ravel())

# Metric name -> function(ground truth map, explanation map) used in the evaluation.
metrics = dict(emd=emd, kl=kl)

## RISE

In [None]:
# Build the RISE explainer around the trained network.
# NOTE(review): the second argument is presumably the input size (height, width) — confirm against `rise`.
explainer_rise = rise.RISE(net, (256, 256), gpu_batch=1, device=device)
# Generate the masks used by the explainer and save them to `masks.npy`.
# NOTE(review): presumably N is the number of masks, s the mask grid size and p1 the
# probability of keeping a cell — confirm against `rise`.
explainer_rise.generate_masks(N=600, s=8, p1=0.1, savepath="masks.npy")

## LIME

In [None]:
from lime.lime_image import LimeImageExplainer

# Single LIME explainer instance shared by every call to `get_lime`.
explainer_lime = LimeImageExplainer()

def batch_predict(image: np.ndarray, network, multi_channel: bool = False, n_classes=None) -> np.ndarray:
    """Run ``network`` on a channels-last batch of images.

    The batch is copied, moved to channels-first layout and, unless
    ``multi_channel`` is set, reduced to its first channel before being handed
    to ``network``.

    Args:
        image: Array of shape (batch, height, width, channels).
        network: Callable that receives the channels-first batch
            (batch, channels, height, width) and returns an array-like object
            exposing ``reshape``.
        multi_channel: When False only the first channel is kept (as a
            1-channel batch); when True all the channels are passed through.
        n_classes: Number of columns of the returned array. ``None`` is
            treated as 1.

    Returns:
        The output of ``network`` reshaped to ``(-1, n_classes)``.

    Raises:
        ValueError: If ``image`` does not have exactly four dimensions.
    """
    if n_classes is None:
        n_classes = 1

    # Work on a copy so the network never receives a view of the caller's array.
    batch = np.copy(image)
    if batch.ndim != 4:
        raise ValueError(
            f"expected a batch of shape (batch, height, width, channels), got shape {batch.shape}"
        )

    # Channels-last -> channels-first.
    batch = np.transpose(batch, (0, 3, 1, 2))
    if not multi_channel:
        # Slice with 0:1 (not 0) to keep the channel axis.
        batch = batch[:, 0:1, :, :]

    return network(batch).reshape((-1, n_classes))
    

In [None]:
def get_lime(img, label, *args, **kwargs):
    """Compute a LIME saliency map for one image and one class.

    Args:
        img: Image batch indexed as ``img[0, 0, :, :]`` to obtain the 2-D image
            that is explained. Either a NumPy array or a torch tensor (on any
            device); tensors are moved to the CPU and converted to NumPy first.
        label: Class index whose explanation is turned into the map. It must
            be a key of the ``local_exp`` produced by LIME.
        *args: Ignored; accepted so every method shares one call signature.
        **kwargs: Ignored.

    Returns:
        NumPy array of shape (1, height, width) where every explained segment
        holds the absolute value of its LIME weight.
    """
    # LIME works on NumPy arrays; a tensor living on the GPU cannot be
    # converted implicitly, so bring it to the host explicitly.
    if isinstance(img, torch.Tensor):
        img = img.detach().cpu().numpy()

    def _predict(batch):
        # Only the scores are needed: skip the autograd bookkeeping for the
        # many forward passes LIME requests.
        with torch.no_grad():
            scores = net(torch.from_numpy(batch.astype(np.float32)).to(device))
        return scores.cpu().numpy()

    explanation = explainer_lime.explain_instance(
        img[0, 0, :, :],
        lambda x: batch_predict(
            x,
            _predict,
            multi_channel=False,
            n_classes=47,
        ),
        num_samples=1500,
        batch_size=1,
        random_seed=42,
        progress_bar=False,
        hide_color=0,
    )

    mask = np.zeros(explanation.segments.shape[:2], dtype=np.float64)

    # NOTE(review): the segment with id 0 is skipped and therefore stays at
    # zero in the map — confirm this is intended.
    for key, val in explanation.local_exp[label]:
        if key != 0:
            mask[explanation.segments == key] = abs(val)

    # Leading axis of size 1, as the callers of this function receive it.
    return np.array([mask])

## Methods

In [None]:
def _explain_with_rise(img, pred):
    """Return the RISE output for ``img`` at the index of the class ``pred``."""
    return explainer_rise(img)[pred]


# Method name -> callable(image batch, class index) returning the explanation.
methods_fn = {"rise": _explain_with_rise, "lime": get_lime}

In [None]:
# Number of images evaluated per method; set to None to evaluate every image.
MAX_IMAGES = 3

# The image list is the same for every method, so it is built only once.
img_paths = sorted(glob(DATA_PATH))[:MAX_IMAGES]

# results[method name][metric name] -> list with one score per evaluated image.
results = dict()

for method_name, method in methods_fn.items():
    results_method = {k: [] for k in metrics.keys()}
    for img_path in tqdm(img_paths, desc=method_name):
        folder_path, img_name = os.path.split(img_path)
        img_id = folder_path.split(os.path.sep)[-1]

        # Decode and preprocess the image; the file handle is closed right after.
        with Image.open(img_path) as img_pil:
            img = transform(img_pil).unsqueeze(0).to(device)

        # The explanation is requested for the class predicted by the network.
        with torch.no_grad():
            cls_prediction = int(torch.argmax(net(img)).item())
        explanation = method(img, cls_prediction)

        # The metrics operate on NumPy arrays: convert tensor explanations
        # (possibly on the GPU or attached to a graph) before scoring.
        if isinstance(explanation, torch.Tensor):
            explanation = explanation.detach().cpu().numpy()

        # Ground-truth map stored on disk under the image id.
        gt = np.load(f"./output/GT_resnet50/{img_id}.npy")

        for metric_name, metric_fn in metrics.items():
            results_method[metric_name].append(metric_fn(gt, explanation))
    results[method_name] = results_method


### Results

In [None]:
# Report every metric of every method as "mean - standard deviation", ignoring NaN scores.
for name, scores in results.items():
    print(name)
    for metric, values in scores.items():
        mean_value, std_value = np.nanmean(values), np.nanstd(values)
        print(f"{metric}: {mean_value} - {std_value}")
    print("-"*25)