**VISUM 2022 - Explainable AI - Hands-on Session**
# Explainability for Vision Models
**led by Mara Graziani**
* Postdoctoral researcher at HES-SO Valais and IBM Research x ZHAW
* mara.graziani@hevs.ch ; @mormontre

Contents

1.   Post-hoc methods: LIME and others
2.   Evaluation of explainability results
3.   Concept-based post-hoc attribution
4.   Interpretable modelling (for tabular data)

## Introduction 

## Take Aways

*   List item
*   List item

### Acknowledgements and References

I would like to thank Anna Hedström (<hedstroem.anna@gmail.com>) for her help with Quantus, the XAI evaluation toolbox used in this session. For any questions about Quantus, feel free to reach out to her directly. 


### Installation and Setup


In [2]:
!pip install captum opencv-python xmltodict
#!pip install torch==1.8.0+cu111 torchvision==0.9.0+cu111 -f https://download.pytorch.org/whl/torch_stable.html

Collecting captum
  Downloading captum-0.5.0-py3-none-any.whl (1.4 MB)
[?25l[K     |▎                               | 10 kB 19.7 MB/s eta 0:00:01[K     |▌                               | 20 kB 23.0 MB/s eta 0:00:01[K     |▊                               | 30 kB 16.5 MB/s eta 0:00:01[K     |█                               | 40 kB 11.6 MB/s eta 0:00:01[K     |█▏                              | 51 kB 6.2 MB/s eta 0:00:01[K     |█▍                              | 61 kB 7.1 MB/s eta 0:00:01[K     |█▋                              | 71 kB 7.3 MB/s eta 0:00:01[K     |█▉                              | 81 kB 6.9 MB/s eta 0:00:01[K     |██                              | 92 kB 7.6 MB/s eta 0:00:01[K     |██▎                             | 102 kB 6.9 MB/s eta 0:00:01[K     |██▌                             | 112 kB 6.9 MB/s eta 0:00:01[K     |██▊                             | 122 kB 6.9 MB/s eta 0:00:01[K     |███                             | 133 kB 6.9 MB/s eta 0:00:01[K 

In [3]:
!pip install quantus

Collecting quantus
  Downloading quantus-0.1.4-py3-none-any.whl (123 kB)
[K     |████████████████████████████████| 123 kB 6.3 MB/s 
[?25hCollecting pytest==6.2.5
  Downloading pytest-6.2.5-py3-none-any.whl (280 kB)
[K     |████████████████████████████████| 280 kB 38.7 MB/s 
Collecting opencv-python==4.5.5.62
  Downloading opencv_python-4.5.5.62-cp36-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (60.4 MB)
[K     |████████████████████████████████| 60.4 MB 29 kB/s 
[?25hCollecting numpy==1.19.5
  Downloading numpy-1.19.5-cp37-cp37m-manylinux2010_x86_64.whl (14.8 MB)
[K     |████████████████████████████████| 14.8 MB 35.3 MB/s 
[?25hCollecting pytest-cov==3.0.0
  Downloading pytest_cov-3.0.0-py3-none-any.whl (20 kB)
Collecting scikit-image==0.19.1
  Downloading scikit_image-0.19.1-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (13.3 MB)
[K     |████████████████████████████████| 13.3 MB 21.4 MB/s 
[?25hCollecting scikit-learn==0.24.2
  Downloading scikit_learn-0.24.

In [1]:
import quantus
#from google.colab import drive
import sys
import gc
import warnings
import pathlib
import numpy as np
import pandas as pd
import torch
import torchvision
from torchvision import transforms, datasets
import captum
from captum.attr import *
import random
import os
import cv2
import PIL
from xml.etree import ElementTree
import xmltodict
import collections
from IPython.display import clear_output

# Plotting specifics.
import matplotlib.pyplot as plt
from matplotlib.patches import Circle, RegularPolygon
from matplotlib.path import Path
from matplotlib.projections.polar import PolarAxes
from matplotlib.projections import register_projection
from matplotlib.spines import Spine
from matplotlib.transforms import Affine2D
import seaborn as sns

# Notebook settings.
#drive.mount('/content/drive', force_remount=True)
#path = "/content/drive/MyDrive/Projects"
#sys.path.append(f'{path}/quantus')



import quantus

sns.set(font_scale=1.25)
plt.style.use('seaborn-white')
plt.rcParams['ytick.labelleft'] = True
plt.rcParams['xtick.labelbottom'] = True
gc.collect()
torch.cuda.empty_cache()
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=DeprecationWarning)
%load_ext autoreload
%autoreload 2
clear_output()

ModuleNotFoundError: ignored

In [None]:
def load_sample(path: str) -> dict:
    """Open the image at ``path`` in RGB mode and return it with its path.

    Args:
        path: Location of the image file.

    Returns:
        ``{"input": <PIL image converted to RGB>, "path": path}``.
    """
    # ``import PIL`` alone does not import the ``PIL.Image`` submodule, so import
    # it explicitly instead of relying on another package having done so.
    from PIL import Image

    # ``convert`` returns a new image, so the source file can be closed as soon
    # as the conversion is done instead of waiting for garbage collection.
    with Image.open(path) as img:
        return {"input": img.convert("RGB"), "path": path}


def load_binary_mask(filename: str, path_data: str):
    """Build one binary bounding-box mask per annotated object class of an image.

    The annotation is read from
    ``<path_data>/Annotation/<label>/<label>_<id>.xml``. Here ``<label>`` is the
    name of the image's parent directory, and ``<id>`` is the part of the path
    (extension removed) after its last underscore.

    Args:
        filename: Path of the image file, using ``/`` as the separator.
        path_data: Root directory that contains the ``Annotation`` folder.

    Returns:
        Dict mapping each object ``name`` in the annotation to a
        ``(224, 224, 1)`` integer array. The array is 1 inside that class's
        bounding boxes and 0 elsewhere.
    """
    binary_mask = {}

    filename = os.path.splitext(filename)[0]

    # Get label and file name.
    label = filename.split("/")[-2]
    fname = filename.split("_")[-1]

    # Parse annotations.
    tree = ElementTree.parse(os.path.join(path_data, "Annotation/{}/{}_{}.xml".format(label, label, fname)))
    xmlstr = ElementTree.tostring(tree.getroot(), encoding="utf-8", method="xml")
    annotation = dict(xmltodict.parse(xmlstr))['annotation']

    width = int(annotation["size"]["width"])
    height = int(annotation["size"]["height"])

    # xmltodict yields a single mapping for one <object> element and a list for
    # several; normalise to a list so both cases share the same code path.
    objects = annotation["object"]
    if not isinstance(objects, list):
        objects = [objects]

    for obj in objects:
        # xmltodict >= 0.13 builds plain dicts instead of OrderedDicts, so check
        # against dict (OrderedDict is a subclass) rather than an exact type.
        if not isinstance(obj, dict):
            continue
        mask = binary_mask.get(obj['name'])
        if mask is None:
            # uint8 so that cv2.resize accepts the array (OpenCV has no int64 images).
            mask = np.zeros((height, width), dtype=np.uint8)
            binary_mask[obj['name']] = mask
        box = obj['bndbox']
        mask[int(box['ymin']):int(box['ymax']), int(box['xmin']):int(box['xmax'])] = 1

    # Resize to 224x224 (the size CustomTransform resizes images to). Nearest-
    # neighbour interpolation keeps the masks binary. Plain ``int`` replaces
    # ``np.int``, which NumPy 1.24 removed.
    return {name: cv2.resize(mask, (224, 224), interpolation=cv2.INTER_NEAREST).astype(int)[:, :, np.newaxis]
            for name, mask in binary_mask.items()}


class CustomTransform(torch.nn.Module):
    """Turn a ``load_sample`` record into a normalised image tensor plus its masks.

    Calling an instance on ``{"input": <PIL image>, "path": <str>}`` returns
    ``{"input": <image tensor>, "mask": <dict from load_binary_mask>}``.
    """

    # Resize to 224x224, convert to a tensor, and normalise with the mean/std
    # constants used by torchvision's ImageNet-pretrained models.
    image_transform = transforms.Compose([transforms.Resize((224, 224)),
                                          transforms.ToTensor(),
                                          transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

    def __init__(self, path):
        """Store ``path``, the dataset root handed to ``load_binary_mask``."""
        super().__init__()
        self.path = path

    def forward(self, img):
        """Transform the image and load the bounding-box masks for its path.

        Implemented as ``forward`` rather than by overriding ``__call__``, so that
        ``nn.Module.__call__`` (and any registered hooks) still runs.
        """
        return {"input": self.image_transform(img["input"]),
                "mask": load_binary_mask(img["path"], self.path)}


def CustomCollate(batch):
    """Pack ``(sample, class_index)`` items into ``(inputs, masks, targets)``.

    Each ``sample`` is a dict with an ``"input"`` tensor and a ``"mask"`` dict of
    arrays indexed as ``[:, :, 0]``. Only the first mask of every sample is kept.
    The dataset's class index is translated into an ImageNet class id through a
    fixed lookup table.

    Returns:
        inputs: the sample inputs stacked along a new leading dimension.
        masks: float tensor of the first mask per sample, trailing axis dropped.
        targets: tensor of the translated ImageNet class ids.
    """
    # Position in this list = dataset class index; value = ImageNet class id.
    imagenet_ids = [96, 126, 155, 292, 301, 347, 387, 405, 417, 426, 446, 546, 565, 573, 604, 758, 844, 890, 937,
                    954]

    image_list, mask_dicts, class_ids = [], [], []
    for item in batch:
        sample, class_index = item[0], item[1]
        image_list.append(sample["input"])
        mask_dicts.append(sample["mask"])
        class_ids.append(imagenet_ids[class_index])

    first_masks = [list(mask_dict.values())[0][:, :, 0] for mask_dict in mask_dicts]
    stacked_masks = torch.Tensor(np.array(first_masks))

    return torch.stack(image_list, 0), stacked_masks, torch.tensor(class_ids)


def get_imagenet_labels(path: str = ''):
    """Read ``<path>/assets/imagenet_labels.txt`` into an ``{int: str}`` mapping.

    Rows 0 and 1001 are skipped. Every other row is expected to look like
    ``<id>: '<label>',`` followed by a newline.
    """
    labels_file = f'{path}/assets/imagenet_labels.txt'
    skipped_rows = (0, 1001)
    id_to_label = {}
    with open(labels_file, 'r') as handle:
        for row_number, raw_row in enumerate(handle):
            if row_number in skipped_rows:
                continue
            entry = raw_row.split('b" ')[0]
            class_id = entry.split(':')[0]
            class_name = entry.split(": '")[1].split("',\n")[0]
            id_to_label[int(class_id)] = str(class_name)
    return id_to_label

def evaluate_model(model, data, device):
    """Return the top-1 accuracy of ``model`` over ``data``.

    Args:
        model: torch module mapping a batch of images to logits with classes on dim 1.
        data: iterable of ``(images, masks, labels)`` batches; the masks are ignored.
        device: device the images are moved to before the forward pass.

    Returns:
        Fraction of samples whose arg-max logit equals the label (a numpy float).

    Raises:
        ValueError: if ``data`` yields no batches.
    """
    model.eval()
    # Keep only per-batch predictions and concatenate once at the end. Appending
    # to a growing tensor with torch.cat re-copies everything seen so far on each
    # batch, and would also keep all logits resident on the device.
    predictions = []
    targets = []

    with torch.no_grad():
        for images, _, labels in data:
            logits = model(images.to(device))
            predictions.append(logits.argmax(dim=1).cpu())
            targets.append(labels.cpu())

    if not predictions:
        raise ValueError("evaluate_model: `data` yielded no batches.")

    return np.mean(torch.cat(predictions).numpy() == torch.cat(targets).numpy())

# Spider (radar) plot helper.
# Source code: https://matplotlib.org/stable/gallery/specialty_plots/radar_chart.html.

def radar_factory(num_vars, frame='circle'):
    """Create a radar chart with `num_vars` axes.

    This function creates a RadarAxes projection and registers it with
    matplotlib under the name ``'radar'``.

    Parameters
    ----------
    num_vars : int
        Number of variables for radar chart.
    frame : {'circle' | 'polygon'}
        Shape of frame surrounding axes. Any other value makes the RadarAxes
        raise ``ValueError``.

    Returns
    -------
    numpy.ndarray
        The ``num_vars`` evenly spaced axis angles in radians, starting at 0.
    """
    # calculate evenly-spaced axis angles
    theta = np.linspace(0, 2*np.pi, num_vars, endpoint=False)

    class RadarAxes(PolarAxes):

        name = 'radar'

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            # rotate plot such that the first axis is at the top
            self.set_theta_zero_location('N')

        def fill(self, *args, closed=True, **kwargs):
            """Override fill so that line is closed by default."""
            return super().fill(closed=closed, *args, **kwargs)

        def plot(self, *args, **kwargs):
            """Override plot so that line is closed by default.

            Returns the list of lines, like ``Axes.plot``.
            """
            lines = super().plot(*args, **kwargs)
            for line in lines:
                self._close_line(line)
            return lines

        def _close_line(self, line):
            """Append the first data point to the end so the outline is closed."""
            x, y = line.get_data()
            # FIXME: markers at x[0], y[0] get doubled-up
            if x[0] != x[-1]:
                x = np.concatenate((x, [x[0]]))
                y = np.concatenate((y, [y[0]]))
                line.set_data(x, y)

        def set_varlabels(self, labels, angles=None):
            """Label the spokes.

            ``angles`` (in degrees) defaults to the evenly spaced spoke angles.
            """
            if angles is None:
                angles = np.degrees(theta)
            self.set_thetagrids(angles=angles, labels=labels)

        def _gen_axes_patch(self):
            # The Axes patch must be centered at (0.5, 0.5) and of radius 0.5
            # in axes coordinates.
            if frame == 'circle':
                return Circle((0.5, 0.5), 0.5)
            elif frame == 'polygon':
                return RegularPolygon((0.5, 0.5), num_vars,
                                      radius=.5, edgecolor="k")
            else:
                raise ValueError("unknown value for 'frame': %s" % frame)

        def draw(self, renderer):
            """ Draw. If frame is polygon, make gridlines polygon-shaped."""
            if frame == 'polygon':
                gridlines = self.yaxis.get_gridlines()
                for gl in gridlines:
                    gl.get_path()._interpolation_steps = num_vars
            super().draw(renderer)

        def _gen_axes_spines(self):
            if frame == 'circle':
                return super()._gen_axes_spines()
            elif frame == 'polygon':
                # spine_type must be 'left'/'right'/'top'/'bottom'/'circle'.
                spine = Spine(axes=self,
                              spine_type='circle',
                              path=Path.unit_regular_polygon(num_vars))
                # unit_regular_polygon gives a polygon of radius 1 centered at
                # (0, 0) but we want a polygon of radius 0.5 centered at (0.5,
                # 0.5) in axes coordinates.
                spine.set_transform(Affine2D().scale(.5).translate(.5, .5)
                                    + self.transAxes)

                return {'polar': spine}
            else:
                raise ValueError("unknown value for 'frame': %s" % frame)

    register_projection(RadarAxes)
    return theta

In [None]:
!git clone https://github.com/EliSchwartz/imagenet-sample-images.git

Cloning into 'imagenet-sample-images'...
remote: Enumerating objects: 1005, done.[K
remote: Counting objects: 100% (3/3), done.[K
remote: Compressing objects: 100% (3/3), done.[K
remote: Total 1005 (delta 0), reused 0 (delta 0), pack-reused 1002[K
Receiving objects: 100% (1005/1005), 103.81 MiB | 26.29 MiB/s, done.


In [None]:
import os

files = os.listdir('imagenet-sample-images/')
# NOTE(review): the body of the loop below is missing in this export, so this
# cell does not parse as written. TODO: restore the intended per-file step or
# remove the loop.
for file_ in files[:100]:
  

In [None]:
# Settings data.
batch_size = 12

# load imagenet 


#dataset = torchvision.datasets.ImageNet("imagenet-sample-images", loader=load_sample, extensions=(".jpeg", "png"),
#                                             transform=CustomTransform(path="drive/MyDrive/Projects/quantus/tutorials/assets/imagenet_images/"))

# NOTE(review): CustomCollate reads sample["mask"], but load_sample only returns
# "input" and "path". With the transform below commented out, collating a batch
# will presumably fail with a KeyError. TODO: re-enable a transform that adds
# the masks (and point it at a folder that has an Annotation/ subtree).
dataset = torchvision.datasets.DatasetFolder("imagenet-sample-images/",
                                             loader=load_sample, extensions=(".jpeg", ".png", ".JPEG"))
#                                             transform=CustomTransform(path="imagenet-sample-images"))

test_loader = torch.utils.data.DataLoader(dataset, shuffle=True, batch_size=batch_size, collate_fn=CustomCollate)

# Load a batch of inputs, segmentation masks and outputs to use for evaluation.
# Use the builtin next(): recent PyTorch DataLoader iterators have no ``.next()`` method.
x_batch, s_batch, y_batch = next(iter(test_loader))
x_batch, s_batch, y_batch = x_batch.to(device), s_batch.to(device), y_batch.to(device)

# Visualise some inputs.
N = 6
# NOTE(review): hard-coded Google Drive location; adjust when not running on Colab.
mapping = get_imagenet_labels("/content/drive/MyDrive/Projects/quantus/tutorials")
fig, axes = plt.subplots(nrows=1, ncols=N, figsize=(N*3, int(N*2/3)))

for i in range(N):
    y_name = str(mapping[y_batch[i].item()]).split(",")[0]
    axes[i].imshow((np.moveaxis(quantus.denormalise(x_batch[i].cpu().numpy()), 0, -1) * 255).astype(np.uint8), vmin=0.0, vmax=1.0, cmap="gray")
    axes[i].title.set_text(f"ImageNet - {y_name}")
    axes[i].axis("off")
plt.show()

FileNotFoundError: ignored

In [None]:
# Load an ImageNet-pretrained ResNet-18 onto the selected device and report its
# top-1 accuracy on the batches produced by test_loader.
model = torchvision.models.resnet18(pretrained=True).to(device)
print(f"\nModel test accuracy: {(100 * evaluate_model(model, test_loader, device)):.2f}%")

Downloading: "https://download.pytorch.org/models/resnet18-5c106cde.pth" to /root/.cache/torch/hub/checkpoints/resnet18-5c106cde.pth


  0%|          | 0.00/44.7M [00:00<?, ?B/s]

NameError: ignored