In [None]:
# if running on colab install facenet-pytorch
ON_COLAB = 'google.colab' in str(get_ipython())

if ON_COLAB:
    !pip install -q facenet-pytorch zennit

if ON_COLAB:
    BASE_PATH = '/content/drive/MyDrive/xai_faces/'
    MODEL_PATH = '/content/drive/MyDrive/xai_faces/models/'
else:
    BASE_PATH = '../data/'
    MODEL_PATH = '../models/'
        
DARK_UNDERSAMPLED_PATH = BASE_PATH + 'dark_undersampled_cropped' 
LIGHT_UNDERSAMPLED_PATH = BASE_PATH + 'light_undersampled_cropped' 
DARK_MODEL_PATH = MODEL_PATH + 'dark_undersampled1.pt'
LIGHT_MODEL_PATH = MODEL_PATH + 'light_undersampled1.pt'

RANDOM_SEED = 80223

In [None]:
import sys
import numpy as np
import pandas as pd
from tqdm import tqdm
from PIL import Image

import torch
from torch.nn import Linear
from torchvision.datasets import ImageFolder
from torchvision.transforms import ToTensor, Normalize, Compose, Resize, CenterCrop

from facenet_pytorch import InceptionResnetV1

from zennit.image import imgify#, imsave
from zennit.attribution import Gradient, SmoothGrad
from zennit.composites import EpsilonGammaBox, EpsilonPlusFlat, SpecialFirstLayerMapComposite
from zennit.torchvision import ResNetCanonizer

from zennit.rules import Epsilon, ZPlus, ZBox, Norm, Pass, Flat
from zennit.types import Convolution, Activation, AvgPool, BatchNorm, MaxPool, Linear as AnyLinear

from xai_helpers import generate_xai_image1, generate_xai_image2, generate_xai_image3, get_image_data1

In [None]:
def find_index(dataset, image_path):
    """Return the position of ``image_path`` in ``dataset.imgs``.

    ``dataset.imgs`` is a sequence of ``(path, target)`` pairs; the index of the
    first pair whose path equals ``image_path`` is returned.

    Raises:
        ValueError: if no entry has that path.
    """
    hits = (
        position
        for position, (candidate, _label) in enumerate(dataset.imgs)
        if candidate == image_path
    )
    position = next(hits, None)
    if position is None:
        raise ValueError("Image path not found in the dataset")
    return position

def get_image_data(dataset, index, transform, class_to_idx):
    """Load one sample of ``dataset`` and prepare it for attribution.

    Parameters
    ----------
    dataset : torchvision.datasets.ImageFolder
        Dataset whose ``imgs`` list holds ``(path, class_index)`` pairs.
    index : int
        Position of the sample in ``dataset.imgs``.
    transform : callable
        Maps a PIL image to a tensor.
    class_to_idx : dict
        Class-name -> index mapping; only its length (the number of classes) is used.

    Returns
    -------
    tuple
        ``(input_image, input_target, image, data, target)``: the file path, the integer
        class index, the RGB PIL image, the transformed tensor with a leading batch axis,
        and a one-hot target tensor of shape ``(1, num_classes)``.
    """
    input_image, input_target = dataset.imgs[index]
    # Mirror torchvision's default ImageFolder loader: force 3-channel RGB (so grayscale or
    # RGBA files still yield 3 channels) and close the file instead of leaking the handle.
    with Image.open(input_image) as raw_image:
        image = raw_image.convert('RGB')
    # [None] prepends a batch axis of size 1.
    data = transform(image)[None]
    # Indexing with a list keeps the row 2-D: a one-hot of shape (1, num_classes).
    target = torch.eye(len(class_to_idx))[[input_target]]
    return input_image, input_target, image, data, target

def generate_xai_image(method, model, data, target, noise_level = 0.1, n_iter = 20, symmetric = False, cmap = 'hot'):
    """Compute a zennit attribution for ``data`` and render it as an image.

    Parameters
    ----------
    method : str
        One of 'Gradient', 'SmoothGrad', 'EpsilonPlusFlat' (LRP) or 'EpsilonGammaBox' (LRP).
    model : torch.nn.Module
        Model to explain.
    data : torch.Tensor
        Normalized input batch (presumably (1, 3, H, W) as built by ``get_image_data`` — TODO confirm).
    target : torch.Tensor
        One-hot output mask selecting the class whose attribution is computed.
    noise_level, n_iter :
        SmoothGrad parameters; ignored by the other methods.
    symmetric, cmap :
        Passed to ``zennit.image.imgify`` when rendering the relevance map.

    Returns
    -------
    PIL.Image.Image
        The attribution summed over the channel axis, rendered with ``cmap``.

    Raises
    ------
    ValueError
        If ``method`` is not one of the supported names.
    """
    if method == 'Gradient':
        attributor = Gradient(model = model)
    elif method == 'SmoothGrad':
        attributor = SmoothGrad(noise_level = noise_level, n_iter = n_iter, model = model)
    ### Layer-wise Relevance Propagation (LRP) with EpsilonPlusFlat
    elif method == 'EpsilonPlusFlat':
        attributor = Gradient(model = model, composite = EpsilonPlusFlat())
    ### LRP with EpsilonGammaBox
    elif method == 'EpsilonGammaBox':
        # EpsilonGammaBox needs the lowest and highest possible input values: 0. and 1. in pixel
        # space, mapped through the same per-channel ImageNet normalization as the inputs.
        transform_norm = Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        low, high = transform_norm(torch.tensor([[[[[0.]]] * 3], [[[[1.]]] * 3]]))
        attributor = Gradient(model = model, composite = EpsilonGammaBox(low = low, high = high))
    else:
        raise ValueError("Invalid method name. Choose either 'Gradient', 'SmoothGrad', 'EpsilonPlusFlat' or 'EpsilonGammaBox'.")

    # Entering the attributor registers the composite's hooks (if any) and removes them on exit.
    with attributor:
        output, attribution = attributor(data, target)

    # Sum over the channel axis to get one relevance value per pixel, then render it.
    relevance = attribution.detach().sum(1).cpu()
    return imgify(relevance, symmetric = symmetric, cmap = cmap)

In [None]:
# Per-channel ImageNet normalization; must match what the models were trained with.
transform_norm = Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))

# Spatial preprocessing applied to the PIL image before tensor conversion.
# NOTE(review): 160 assumes the cropped faces follow facenet-pytorch's default 160x160
# face size — confirm; for images already 160x160 these two steps are no-ops.
IMAGE_SIZE = 160
base_transform = Compose([
    Resize(IMAGE_SIZE),
    CenterCrop(IMAGE_SIZE),
])

# define the full tensor transform
transform = Compose([
    base_transform,
    ToTensor(),
    transform_norm,
])