In [48]:
from PIL import Image
import torchvision.transforms as transforms
import torch

from zennit.rules import Epsilon
from zennit.composites import EpsilonPlusFlat, EpsilonAlpha2Beta1Flat, DeconvNet, GuidedBackprop, BetaSmooth, ExcitationBackprop
from zennit.attribution import Gradient

import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import numpy as np
import os
from dataclasses import dataclass
from typing import Callable
from scipy.stats import pearsonr


from util.cnn_parameters import IMG_HEIGHT, IMG_WIDTH, TEST_IMAGES_PATH, TRAIN_IMAGES_PATH
from util.torch_architecture import CNN
from zennit.image import imgify, imsave


from captum.attr import DeepLift, LRP

INPUT_IMG_FILE = "input.jpg"
RELEVANCE_IMG_FILE = "relevance.png"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [6]:
model = CNN()
model.load_state_dict(torch.load("cats_dogs_cnn.pth", map_location=device))
model.eval()

  model.load_state_dict(torch.load("cats_dogs_cnn.pth", map_location=device))


CNN(
  (model): Sequential(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Dropout(p=0.25, inplace=False)
    (5): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU()
    (7): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (9): Dropout(p=0.25, inplace=False)
    (10): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU()
    (12): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (13): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (14): Dropout(p=0.25, inplace=False)
    (15): Conv2d(128, 256, kernel_size=(3, 3), stride=

In [4]:
def prepare_image(image_path):

    image = Image.open(image_path).convert("RGB")

    transform = transforms.Compose([
        transforms.Resize((IMG_WIDTH, IMG_HEIGHT)),
        transforms.ToTensor(),
    ])

    input_image = transform(image).unsqueeze(0)

    # print("Input image shape:", input_image.shape)  # [1, 3, 128, 128]
    return input_image

In [7]:
input_img = prepare_image(INPUT_IMG_FILE)
# print(model.requires_grad)
torch.enable_grad()
print(model(input_img))

tensor([[0.2032]], grad_fn=<SigmoidBackward0>)


In [57]:
def deeplift_method(model, input_img):
    baseline = torch.zeros_like(input_img)
    dl = DeepLift(model)
    attributions, delta = dl.attribute(input_img, baseline, target=0, return_convergence_delta=True)
    output = model(input_img)
    print('DeepLift Attributions:', attributions)
    print('Convergence Delta:', delta)
    return output, attributions

In [41]:
def plot_dl(input_image, relevance):
    for _, inp in enumerate(input_image):
        imsave(INPUT_IMG_FILE, inp.detach())
    relevance = torch.detach(relevance)
    absrel = relevance.abs().sum(1)
    imsave(RELEVANCE_IMG_FILE, absrel[0], vmin=0, vmax=absrel[0].amax())
    image = imgify(absrel[0], vmin=0, vmax=absrel[0].amax())

    image1 = mpimg.imread(INPUT_IMG_FILE)
    image2 = mpimg.imread(RELEVANCE_IMG_FILE)

    fig, axes = plt.subplots(1, 2, figsize=(10, 5))

    axes[0].imshow(image1)
    axes[0].axis('off')
    axes[0].set_title("Original Image")

    axes[1].imshow(image2)
    axes[1].axis('off')
    axes[1].set_title("DeepLIFT Heatmap")

    plt.tight_layout()
    plt.show()

    # os.remove(INPUT_IMG_FILE)
    # os.remove(RELEVANCE_IMG_FILE)


In [58]:
output, attributions = deeplift_method(model, input_img)
plot_dl(input_img, attributions)

TypeError: Module of type <class 'torch.nn.modules.flatten.Flatten'> has no rule defined and nodefault rule exists for this module type. Please, set a ruleexplicitly for this module and assure that it is appropriatefor this type of layer.

# Metrics


In [34]:
# Metrics utils

DAUC = "DAUC"
IAUC = "IAUC"
DELETION = "Deletion"
INSERTION = "Insertion"

@dataclass
class CorrMetricsParams:
    perturbed_image: Callable
    mask: Callable
    label: str


def cloned_image(input_image):
    return input_image.clone()

def zeros_image(input_image):
    return torch.zeros_like(input_image)

def mask_set_zero(abs_relevance, num_remove, sorted_indices):
    mask = torch.ones_like(abs_relevance.view(-1))
    mask[sorted_indices[:num_remove]] = 0  # Set most relevant pixels to 0
    return mask

def mask_set_one(abs_relevance, num_add, sorted_indices):
    mask = torch.zeros_like(abs_relevance.view(-1))
    mask[sorted_indices[:num_add]] = 1  # Set most relevant pixels to 1
    return mask

del_corr = CorrMetricsParams(
    perturbed_image=cloned_image,
    mask=mask_set_zero,
    label="removed"
)

ins_corr = CorrMetricsParams(
    perturbed_image=zeros_image,
    mask=mask_set_one,
    label="added"
)

metrics_params = {
    DAUC: True,
    IAUC: False,
    DELETION: del_corr,
    INSERTION: ins_corr,
}

def apply_mask(input_image, mask, abs_relevance):
    mask = mask.view(abs_relevance.shape).unsqueeze(0).repeat(3, 1, 1)
    perturbed_image = input_image * mask
    return perturbed_image

def get_abs_relevance(relevance):
    return relevance.abs().sum(1)[0]

def get_num_pixels(input_image):
    return input_image.shape[2] * input_image.shape[3]


In [35]:

def compute_auc(model, input_image, output, relevance, auc, steps=10, plot=False):
    model.eval()

    abs_relevance = get_abs_relevance(relevance)
    sorted_indices = abs_relevance.view(-1).argsort(descending=metrics_params[auc])

    perturbed_image = cloned_image(input_image)
    perturbation_curve = []
    num_pixels = get_num_pixels(input_image)

    for step in range(1, steps + 1):
        num_remove = int((step / steps) * num_pixels)
        mask = mask_set_zero(abs_relevance, num_remove, sorted_indices)
        perturbed_image = apply_mask(input_image, mask, abs_relevance)

        with torch.no_grad():
            output = model(perturbed_image)

        perturbation_curve.append(output.item())

    x_vals = np.linspace(0, 1, steps)
    auc_value = np.trapz(perturbation_curve, x_vals)

    if plot:
        plt.figure(figsize=(6, 4))
        plt.plot(x_vals, perturbation_curve, marker='o', label="Score Drop")
        plt.xlabel("Fraction of pixels occluded")
        plt.ylabel("Model Output Score")
        plt.title(f"{auc} Evaluation (AUC = {auc_value:.4f})")
        plt.legend()
        plt.show()

    return auc_value, perturbation_curve

def compute_ins_del_correlation(model, input_image, output, relevance, corr, steps=10, plot=False):
    model.eval()

    abs_relevance = get_abs_relevance(relevance)
    sorted_indices = abs_relevance.view(-1).argsort(descending=True)

    perturbed_image = metrics_params[corr].perturbed_image(input_image)
    perturbation_curve = []
    num_pixels = get_num_pixels(input_image)

    for step in range(1, steps + 1):
        num_remove = int((step / steps) * num_pixels)

        mask = metrics_params[corr].mask(abs_relevance, num_remove, sorted_indices)
        perturbed_image = apply_mask(input_image, mask, abs_relevance)

        with torch.no_grad():
            output = model(perturbed_image)

        perturbation_curve.append(output.item())

    x_vals = np.linspace(0, 1, steps)
    correlation, _ = pearsonr(x_vals, perturbation_curve)  # Pearson

    if plot:
        plt.figure(figsize=(6, 4))
        plt.plot(x_vals, perturbation_curve, marker='o', label="Score Drop")
        plt.xlabel(f"Fraction of pixels {metrics_params[corr].label}")
        plt.ylabel("Model Output Score")
        plt.title(f"{corr} Correlation = {correlation:.4f}")
        plt.legend()
        plt.show()

    return correlation

def compute_average_drop(model, input_image,output, relevance, steps=10, plot=False):
    model.eval()

    abs_relevance = get_abs_relevance(relevance)
    sorted_indices = abs_relevance.view(-1).argsort(descending=True)

    perturbed_image = cloned_image(input_image)
    drop_curve = []
    initial_output = model(input_image).item()
    num_pixels = get_num_pixels(input_image)

    for step in range(1, steps + 1):
        num_remove = int((step / steps) * num_pixels)

        mask = mask_set_zero(abs_relevance, num_remove, sorted_indices)
        perturbed_image = apply_mask(input_image, mask, abs_relevance)

        with torch.no_grad():
            output = model(perturbed_image)

        drop = initial_output - output.item()
        drop_curve.append(drop)

    average_drop = np.mean(drop_curve)

    if plot:
        plt.figure(figsize=(6, 4))
        plt.plot(np.linspace(0, 1, steps), drop_curve, marker='o', label="Output Drop")
        plt.xlabel("Fraction of pixels removed")
        plt.ylabel("Drop in model output")
        plt.title(f"Average Drop = {average_drop:.4f}")
        plt.legend()
        plt.show()

    return average_drop

def compute_increase_in_confidence(model, input_image, output, relevance, steps=10, plot=False):

    model.eval()

    abs_relevance = get_abs_relevance(relevance)
    sorted_indices = abs_relevance.view(-1).argsort(descending=True)

    perturbed_image = zeros_image(input_image)
    confidence_curve = []
    num_pixels = get_num_pixels(input_image)

    for step in range(1, steps + 1):
        num_add = int((step / steps) * num_pixels)

        mask = mask_set_one(abs_relevance, num_add, sorted_indices)
        perturbed_image = apply_mask(input_image, mask, abs_relevance)

        with torch.no_grad():
            output = model(perturbed_image)

        confidence = torch.sigmoid(output).item()
        confidence_curve.append(confidence)

    increase_in_confidence = confidence_curve[-1] - confidence_curve[0]

    if plot:
        plt.figure(figsize=(6, 4))
        plt.plot(np.linspace(0, 1, steps), confidence_curve, marker='o', label="Confidence")
        plt.xlabel("Fraction of pixels added")
        plt.ylabel("Confidence (probability of positive class)")
        plt.title(f"Increase in Confidence = {increase_in_confidence:.4f}")
        plt.legend()
        plt.show()

    return increase_in_confidence



In [37]:
# For one example
input_image = input_img
relevance = attributions
iauc_value, curve = compute_auc(model, input_image, output, relevance, auc=IAUC,steps=10)
print(f"IAUC Score: {iauc_value:.4f}")

dauc_value, curve = compute_auc(model, input_image, output, relevance, auc=DAUC,steps=10)
print(f"DAUC Score: {dauc_value:.4f}")

insertion_correlation = compute_ins_del_correlation(model, input_image, output, relevance, corr=INSERTION,steps=10)
print(f"Insertion Correlation: {insertion_correlation:.4f}")

deletion_correlation = compute_ins_del_correlation(model, input_image, output, relevance, corr=DELETION,steps=10)
print(f"Deletion Correlation: {deletion_correlation:.4f}")

average_drop = compute_average_drop(model, input_image, output, relevance, steps=10)
print(f"Average Drop: {average_drop:.4f}")

increase_in_confidence = compute_increase_in_confidence(model, input_image, output, relevance, steps=10)
print(f"Increase in Confidence: {increase_in_confidence:.4f}")

IAUC Score: 0.1603
DAUC Score: 0.2813
Insertion Correlation: -0.0693
Deletion Correlation: 0.7572
Average Drop: -0.0643
Increase in Confidence: -0.0619


# Batch Metrics


In [44]:

def get_output_relevance(input_images):
    outputs = []
    relevances = []
    for i in input_images:
        o, r = deeplift_method(model, i)
        outputs.append(o)
        relevances.append(r)
    return outputs, relevances

def batch_dauc(input_images):
    outputs, relevances = get_output_relevance(input_images)

    dauc_values = []
    for input_image, output, relevance in zip(input_images, outputs, relevances):
        dauc_value, _ = compute_auc(model, input_image, output, relevance, auc=DAUC,steps=10)
        dauc_values.append(dauc_value)

    return np.mean(dauc_values)

def batch_iauc(input_images):
    outputs, relevances = get_output_relevance(input_images)

    iauc_values = []
    for input_image, output, relevance in zip(input_images, outputs, relevances):
        iauc_value, _ = compute_auc(model, input_image, output, relevance, auc=IAUC,steps=10)
        iauc_values.append(iauc_value)

    return np.mean(iauc_values)

def batch_inser_corr(input_images):
    outputs, relevances = get_output_relevance(input_images)

    corr_values = []
    for input_image, output, relevance in zip(input_images, outputs, relevances):
        corr_value = compute_ins_del_correlation(model, input_image, output, relevance, corr=INSERTION,steps=10)
        corr_values.append(corr_value)

    return np.mean(corr_values)

def batch_del_corr(input_images):
    outputs, relevances = get_output_relevance(input_images)

    corr_values = []
    for input_image, output, relevance in zip(input_images, outputs, relevances):
        corr_value = compute_ins_del_correlation(model, input_image, output, relevance, corr=DELETION,steps=10)
        corr_values.append(corr_value)

    return np.mean(corr_values)

def batch_avg_drop(input_images):
    outputs, relevances = get_output_relevance(input_images)

    values = []
    for input_image, output, relevance in zip(input_images, outputs, relevances):
        val = compute_average_drop(model, input_image, output, relevance, steps=10)
        values.append(val)

    return np.mean(values)

def batch_confidence_incr(input_images):
    outputs, relevances = get_output_relevance(input_images)

    values = []
    for input_image, output, relevance in zip(input_images, outputs, relevances):
        val = compute_increase_in_confidence(model, input_image, output, relevance, steps=10)
        values.append(val)

    return np.mean(values)

In [45]:
NUM_IMAGES = 200
img_paths_for_metrics = [TEST_IMAGES_PATH + "\\" + str(i) + ".jpg" for i in range(1, NUM_IMAGES+1)]
input_images_for_metrics = [prepare_image(img_path) for img_path in img_paths_for_metrics]

print("Average DAUC: " + str(batch_dauc(input_images_for_metrics)))
print("Average IAUC: " + str(batch_iauc(input_images_for_metrics)))
print("Average Insertion Correlation: " + str(batch_inser_corr(input_images_for_metrics)))
print("Average Deletion Correlation: " + str(batch_del_corr(input_images_for_metrics)))
print("Average Average Drop: " + str(batch_avg_drop(input_images_for_metrics)))
print("Average Increase in Confidence: " + str(batch_confidence_incr(input_images_for_metrics)))
# 0.0014


DeepLift Attributions: tensor([[[[ 1.5315e-05,  1.4570e-05, -1.0388e-05,  ...,  2.6049e-05,
            1.1423e-05,  4.0170e-05],
          [ 1.5981e-05,  5.8556e-05, -1.1238e-04,  ...,  3.7323e-05,
            5.8492e-05,  5.0017e-05],
          [-3.5276e-05, -8.6901e-05, -6.7606e-05,  ..., -1.4547e-05,
           -5.0833e-06,  9.6527e-05],
          ...,
          [ 6.9517e-07, -4.6465e-04, -4.3124e-05,  ..., -2.4652e-05,
           -9.0594e-05,  2.6677e-05],
          [-5.0655e-05, -1.9300e-04, -1.5264e-04,  ..., -1.5958e-05,
           -4.4039e-05,  7.5364e-06],
          [-1.0994e-05, -7.3606e-05, -2.2913e-05,  ..., -3.8549e-05,
           -2.1956e-05,  8.7577e-06]],

         [[ 9.2220e-06,  1.2351e-05, -9.4092e-06,  ...,  1.2294e-05,
            5.0117e-07,  6.3581e-05],
          [ 1.3592e-05,  2.9299e-05, -2.1345e-04,  ...,  2.6217e-05,
            5.7270e-05,  8.7305e-05],
          [-1.3217e-05, -3.3131e-05, -6.0864e-06,  ...,  5.9139e-06,
           -1.9498e-05,  7.6877e-05