# Evaluation with metrics

## Loading imports, model and dataset

In [2]:
import torch
import pandas as pd
import numpy as np
from models.resnetv2 import ResNet50
from models.resnet_attention import resnet50 as ResNet50Att, ResNetAtt
from torchvision import datasets, transforms
from PIL import Image
import torch.utils.data as data
from melanoma.melanoma_loader import Melanoma_loader as melanoma_dataset
import matplotlib.pyplot as plt
import os
from PIL import Image
import quantus

from zennit.composites import EpsilonPlusFlat
from zennit.torchvision import ResNetCanonizer
from zennit.attribution import Gradient
from crp.image import imgify
from captum.metrics import infidelity


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
WIDTH = 256
HEIGHT = 256
ROOT = "data/train/train/"

model_paths = [
    "./model_resnet.pt",
    "./model_resnet_attention.pt",
    "./model_resnet_unbiased.pt",
    "./model_resnet_attention_unbiased.pt",
]
model_names = [
    "Biased Resnet",
    "Biased Resnet+att",
    "Unbiased Resnet",
    "Unbiased Resnet+att",
]
model_classes = [ResNet50, ResNet50Att, ResNet50, ResNet50Att]
model_urls = [
    "https://www.dropbox.com/s/wyma5jispzl63gr/resnet_bias_ckpt_epoch_49.pth?dl=0",
    "https://www.dropbox.com/s/wnvxz05hy2slymx/resnet_att_bias_ckpt_epoch_71.pth?dl=0",
    "https://www.dropbox.com/s/1apem4cqx7akycq/resnet_unb_ckpt_epoch_23.pth?dl=0",
    "https://www.dropbox.com/s/xl826hxrcrypon0/resnet_att_unb_ckpt_epoch_72.pth?dl=0",
]

models = []
for model_path, model_url, model_name, model_class in zip(
    model_paths, model_urls, model_names, model_classes
):
    if not os.path.exists(model_path):
        os.system(f"wget -O {model_path} {model_url}")

    if model_class == ResNet50:
        model = model_class(out_features=2, freeze=True, in_channels=3)
    elif model_class == ResNet50Att:
        model = model_class(pretrained=False)

    model.load_state_dict(
        torch.load(model_path, map_location="cpu")["model"], strict=False
    )
    model.eval()
    models.append(model)


In [8]:
def stacked_img(img_path, extra_img):
    img = Image.open(ROOT + img_path + ".jpg")

    new_im = Image.new('RGB', (2 * WIDTH, HEIGHT))
    new_im.paste(img, (0, 0))
    new_im.paste(extra_img, (WIDTH, 0))
    return new_im

def iterate_class(dataset, find_melanoma=1):
    for idx in range(len(dataset)):
        if dataset[idx][1] == find_melanoma:
            yield dataset[idx][0].view(1, 3, 256, 256), dataset.lookup_path(idx)


In [9]:
# Dataloader
dataset = melanoma_dataset(root = "data/train/train", ann_path = "melanoma/data/test_set.csv", 
                          transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Resize((256, 256)),
                                    transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])]))
dataloader = data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0)

positive_iterator = iterate_class(dataset, find_melanoma=1)
negative_iterator = iterate_class(dataset, find_melanoma=0)
iterator = iter(dataset)

## Evaluate model

### Create explanations

In [10]:
class AttentionWrapper(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        return self.model(x)[0]

def lrp_explainer(
    inputs,
    targets,
    model,
    abs=False,
    normalise=False,
    sum_channels=False,
    *args,
    **kwargs
) -> np.array:
    model.eval()
    if isinstance(inputs, tuple) and len(inputs) == 1:
        inputs = inputs[0]
    if not isinstance(inputs, torch.Tensor):
        inputs = (
            torch.Tensor(inputs)
            .reshape(
                -1,
                kwargs.get("nr_channels", 3),
                kwargs.get("img_size", 256),
                kwargs.get("img_size", 256),
            )
            .to(kwargs.get("device", None))
        )
    inputs.requires_grad = True
    if not isinstance(targets, torch.Tensor):
        targets = torch.as_tensor(targets).long().to(kwargs.get("device", None))

    assert (
        len(np.shape(inputs)) == 4
    ), "Inputs should be shaped (nr_samples, nr_channels, img_size, img_size) e.g., (1, 3, 256, 256)."

    # use the ResNet-specific canonizer
    canonizer = ResNetCanonizer()

    # create a composite, specifying the canonizers
    composite = EpsilonPlusFlat(canonizers=[canonizer])

    col1 = targets == 0
    col2 = targets == 1
    target = torch.vstack([col1, col2]).T.float()

    # create the attributor, specifying model and composite
    if isinstance(model, ResNetAtt):
        model = AttentionWrapper(model)

    model.requires_grad_(True)
    inputs.requires_grad_(True)
    target.requires_grad_(True)

    with Gradient(model=model, composite=composite) as attributor:
        # compute the model output and attribution
        output, attribution = attributor(inputs, target)

    if abs:
        attribution = torch.abs(attribution)

    # sum over the channels
    if sum_channels:
        relevance = attribution.sum(1)
    else:
        relevance = attribution

    explanation = relevance

    if isinstance(explanation, torch.Tensor):
        explanation = explanation.cpu().detach().numpy()

    if normalise:
        explanation = quantus.normalise_func.normalise_by_max(explanation)
        
    return explanation


In [6]:
POSITIVE = False

for i in range(100):
    print(i)
    if POSITIVE:
        x, path = next(positive_iterator)
        targets = [1] * x.shape[0]
    else:
        x, path = next(negative_iterator)
        targets = [0] * x.shape[0]

    for model, model_name in zip(models, model_names):
        # attention
        explanation = lrp_explainer(
            inputs=x,
            targets=targets,
            model=model,
            sum_channels=True,
            normalise=True
        )
        img = imgify(explanation[0], symmetric=True, cmap='coldnhot')
        new_img = stacked_img(path, img)
        os.makedirs(f'lrp_results/{int(POSITIVE)}/{path}/', exist_ok=True)
        new_img.save(f'lrp_results/{int(POSITIVE)}/{path}/{model_name}.png')
        # show the image
        # display(new_img)
    


0
1
2
3


KeyboardInterrupt: 

### Evaluate AvgSensitivity

In [14]:
# x, y = next(iter(dataloader))
from itertools import repeat, chain

# x, _ = next(positive_iterator)
# y = torch.tensor([1])
metric_init = quantus.AvgSensitivity(
    nr_samples=10,
    lower_bound=0.1,
    norm_numerator=quantus.norm_func.fro_norm,
    norm_denominator=quantus.norm_func.fro_norm,
    perturb_func=quantus.perturb_func.uniform_noise,
    similarity_func=quantus.similarity_func.difference,
    disable_warnings=True,
    normalise=True,
    abs=True,
)

SAMPLE_SIZE = 10

positives = []
negatives = []
for i in range(SAMPLE_SIZE):
    x, _ = next(positive_iterator)
    positives.append(x)
    x, _ = next(negative_iterator)
    negatives.append(x)

for model, model_name in zip(models, model_names):
    s = 0 
    for sample, target in zip(
        chain(positives, negatives),
        chain(repeat(1, SAMPLE_SIZE), repeat(0, SAMPLE_SIZE)),
    ):
        value = metric_init(
            model=model,
            x_batch=sample.numpy(),
            y_batch=np.array([target]),
            explain_func=lrp_explainer,
        )
        s += value[0]
    print(f"{model_name} avg sensitivity: {s / (2 * SAMPLE_SIZE)}")


Biased Resnet avg sensitivity: 0.013954376486944966
Biased Resnet+att avg sensitivity: 0.04411492634098976
Unbiased Resnet avg sensitivity: 0.02628094685147516
Unbiased Resnet+att avg sensitivity: 0.05559212098130957


### Evaluate Infidelity

In [19]:
def perturb_fn(inputs):
   noise = torch.tensor(np.random.normal(0, 0.01, inputs.shape)).float()
   return noise, inputs - noise

for model, model_name in zip(models, model_names):
   if isinstance(model, ResNetAtt):
      model = AttentionWrapper(model)
   
   s = 0
   for sample, target in zip(
      chain(positives, negatives),
      chain(repeat(1, SAMPLE_SIZE), repeat(0, SAMPLE_SIZE)),
   ):
      explanation = lrp_explainer(sample, torch.tensor([target]), model, abs=False, normalise=True, sum_channels=False)

      infid = infidelity(model, perturb_fn, x, torch.tensor(explanation, requires_grad=True), target=target)
      s += infid
   print(f'{model_name} infidelity: {s / (2 * SAMPLE_SIZE)}')

Biased Resnet infidelity: tensor([0.0480])
Biased Resnet+att infidelity: tensor([0.0731])
Unbiased Resnet infidelity: tensor([0.0808])
Unbiased Resnet+att infidelity: tensor([0.1556])
