In [1]:
import os

import torch
from tempo2.models import TempoLinear
import matplotlib.pyplot as plt
from PIL import Image

from torchvision.transforms import Compose, Resize, ToTensor, InterpolationMode, Grayscale, GaussianBlur, Normalize
from zennit.torchvision import ResNetCanonizer
from zennit.composites import EpsilonPlusFlat
from zennit.attribution import Gradient

from zennit.image import imgify, imsave

In [2]:
# File name of the checkpoint (looked up inside ../../model_zoo/) for the model
# that is compared against the fixed "baseline.pth" checkpoint.
MODEL_NAME = "vanilla.pth"

In [3]:
# load models
def _load_tempo_linear(checkpoint_path):
    """
    Build a TempoLinear model and restore its parameters from ``checkpoint_path``.

    Args:
        checkpoint_path: path of a file holding a state dict that
            ``TempoLinear.load_state_dict`` accepts.

    Returns:
        ``(state_dict, model)`` - the loaded state dict and the model, switched
        to evaluation mode.
    """
    # map_location="cpu" lets checkpoints that were saved from a GPU load on
    # machines without CUDA; the model itself is never moved off the CPU here.
    state_dict = torch.load(checkpoint_path, map_location="cpu")

    model = TempoLinear(out_features=24, weights=None, freeze_backbone=True)
    model.load_state_dict(state_dict)

    # `model.requires_grad = True` would merely store a new Python attribute on
    # the module; `requires_grad_()` is the nn.Module API that actually flips the
    # flag on every parameter (freeze_backbone=True presumably disabled some).
    model.requires_grad_(True)
    model.eval()

    return state_dict, model


weights_tp, model_tp = _load_tempo_linear(f'../../model_zoo/{MODEL_NAME}')
weights_bl, model_bl = _load_tempo_linear('../../model_zoo/baseline.pth')

In [4]:
def get_relevance(model, data, num_classes):
    """
    Computes the relevance heatmap for the model's own classification decision.

    The class with the highest output is determined first; the attribution is
    then computed for exactly that class with zennit's ``Gradient`` attributor
    under an ``EpsilonPlusFlat`` composite (with a ``ResNetCanonizer``).

    Args:
        model: the network to explain.
        data: input tensor. NOTE(review): the argmax runs over the flattened
            output and the target has a single row, so this assumes ``data``
            holds exactly one sample - confirm before passing larger batches.
        num_classes: number of model outputs; width of the one-hot target.

    Returns:
        ``(pred, relevance)`` - the predicted class index as a Python int and
        the attribution summed over dimension 1 (presumably the colour
        channels of an (N, C, H, W) input).
    """
    # Plain forward pass, before any composite hooks are registered, to find
    # the class the model actually predicts.
    with torch.no_grad():
        logits = model(data)
    pred = torch.argmax(logits.flatten()).item()

    # One-hot row selecting the predicted class. It is cast to the logits'
    # device and dtype so it can be used as the output gradient even when the
    # model does not run in float32 on the CPU (no-op in the default case).
    target = torch.eye(num_classes)[[pred]].to(logits)

    # create a composite, specifying the canonizers
    composite = EpsilonPlusFlat(canonizers=[ResNetCanonizer()])

    # create the attributor, specifying model and composite; the model output
    # it also returns is not needed here
    with Gradient(model=model, composite=composite) as attributor:
        _, attribution = attributor(data, target)

    relevance = attribution.sum(1)

    return pred, relevance

In [5]:
# Locations of the test images and their ground-truth masks, relative to the
# notebook's working directory.
IMAGE_DIR = "imgsFine/leftImg8bit/Test/"
MASK_DIR = "gtFine/Test/"
# An image and its mask are paired through the first KEY_LEN characters of
# their file names.
KEY_LEN = 8

# os.listdir() returns entries in arbitrary order. Sorting makes the dataset
# order (and therefore the per-image output further down) reproducible, and
# makes it deterministic which file wins if two names share the same prefix.
images = {
    fname[:KEY_LEN]: IMAGE_DIR + fname
    for fname in sorted(os.listdir(IMAGE_DIR))
}
# Only the "*color.png" files of the ground-truth directory are used as masks.
masks = {
    fname[:KEY_LEN]: MASK_DIR + fname
    for fname in sorted(os.listdir(MASK_DIR))
    if fname.endswith("color.png")
}

In [6]:
# Image pipeline: resize with nearest-neighbour interpolation (an int size
# rescales the shorter edge to 128 px), then convert the PIL image to a tensor.
transform_img = Compose(
    transforms=[
        Resize(size=128, interpolation=InterpolationMode.NEAREST),
        ToTensor(),
    ]
)

# Mask pipeline: identical resize so mask and image stay aligned, then collapse
# the colour-coded mask to a single grey channel before converting to a tensor.
transform_msk = Compose(
    transforms=[
        Resize(size=128, interpolation=InterpolationMode.NEAREST),
        Grayscale(),
        ToTensor(),
    ]
)

In [7]:
def get_fat_mask(mask: Image.Image) -> torch.Tensor:
    """
    Turns a mask image into a binary tensor mask and adds a border to it.

    The mask is passed through ``transform_msk`` and then dilated twice: each
    pass blurs with a 3x3 Gaussian kernel and marks every pixel whose blurred
    value is positive, i.e. every pixel with a non-zero neighbour.

    Args:
        mask: PIL image of the object mask; pixels that end up as zero after
            ``transform_msk`` count as background.

    Returns:
        Tensor of zeros and ones. The dtype is float64, because ``.to(float)``
        maps Python's ``float`` to ``torch.float64``.
    """
    mask = transform_msk(mask)

    # Pin sigma: without it GaussianBlur draws a random sigma from (0.1, 2.0) on
    # every call. Only the sign of the blurred value is used, so the exact
    # sigma is irrelevant - except that a very small one can make the off-centre
    # kernel weights underflow, which made the result depend on the RNG.
    blur = GaussianBlur(kernel_size=3, sigma=1.0)

    fat_mask = mask
    for _ in range(2):
        # blur -> threshold grows the mask by one pixel in every direction
        fat_mask = (blur(fat_mask) > 0).to(float)

    return fat_mask

In [8]:
# One (image tensor, fattened mask) pair per test image; the mask is looked up
# through the key it shares with the image.
dataset = [
    (transform_img(Image.open(image_path)), get_fat_mask(Image.open(masks[key])))
    for key, image_path in images.items()
]

In [9]:
def RR(heatmap: torch.FloatTensor, mask: torch.FloatTensor) -> float:
    """
    Relative relevance of the object region.

    Divides the mean of the mask-weighted heatmap by the mean of the whole
    heatmap. For a 0/1 mask of the same size this is the share of the total
    relevance that falls on the pixels attributed to the object.

    Args:
        heatmap: relevance values.
        mask: weights multiplied element-wise onto the heatmap (1 = object,
            0 = background); must be broadcastable against ``heatmap``.

    Returns:
        The ratio. With tensor inputs this is a zero-dimensional tensor rather
        than a built-in float; it is NaN or infinite when the heatmap mean is 0.
    """
    on_object = heatmap * mask
    return on_object.mean() / heatmap.mean()

In [10]:
# Number of model outputs; must match the `out_features` the models were built with.
NUM_CLASSES = 24

# Built once instead of once per image: the transform is stateless. The values
# are the usual ImageNet channel means / standard deviations.
normalize = Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))

# Relative relevance per test image, for the evaluated model and the baseline.
rrs_tp = []
rrs_bl = []

for img, msk in dataset:
    # [None] adds the batch dimension the models expect
    data = normalize(img)[None]
    _, heatmap_tp = get_relevance(model_tp, data, num_classes=NUM_CLASSES)
    _, heatmap_bl = get_relevance(model_bl, data, num_classes=NUM_CLASSES)

    rr_tp = RR(heatmap_tp, msk)
    rr_bl = RR(heatmap_bl, msk)

    rrs_tp.append(rr_tp)
    rrs_bl.append(rr_bl)

    print(f"Tempo: {rr_tp}, Baseline: {rr_bl}")

print()
print(f"Tempo mean: {torch.tensor(rrs_tp).mean().item()}, Baseline mean: {torch.tensor(rrs_bl).mean().item()}")

Tempo: 0.7310643676631893, Baseline: 0.6538211564856075
Tempo: 0.7330982465051845, Baseline: 0.6004862648435708
Tempo: 0.7095254124845454, Baseline: 0.6224232602351775
Tempo: 0.6711422495583182, Baseline: 0.6401419504550205
Tempo: 0.7811425089061296, Baseline: 0.643317880517269
Tempo: 0.7165129279365784, Baseline: 0.6428706431471269
Tempo: 0.7638038151765815, Baseline: 0.6851929580778142
Tempo: 0.6947636900565572, Baseline: 0.545822104039594
Tempo: 0.7639014344575629, Baseline: 0.7247159878145255
Tempo: 0.7459748995106995, Baseline: 0.6320593829560376
Tempo: 0.7593232581104572, Baseline: 0.6429396042306966
Tempo: 0.7999292356196809, Baseline: 0.6589728701728628
Tempo: 0.7351277195817723, Baseline: 0.647914192274921
Tempo: 0.7911635548672756, Baseline: 0.6711955485872786
Tempo: 0.8289141469109743, Baseline: 0.6596835484934062
Tempo: 0.7470525684869768, Baseline: 0.6674330387547781
Tempo: 0.8495459675479227, Baseline: 0.749068905804232
Tempo: 0.8282129570269651, Baseline: 0.7605298977679