In [None]:
import warnings
warnings.filterwarnings('ignore')
warnings.simplefilter('ignore')

from datasets import load_dataset

import random
import json
from PIL import Image
import pandas as pd
from tqdm import tqdm
import time

import numpy as np
import matplotlib.pyplot as plt

import torch
import torchvision

import quantus

from xai_methods import RELAX, LaFAM, GradCAMHeatmap
from utils import *

def choose_device() -> str:
    """Return the torch device string to run on.

    Preference order: the first CUDA device, then Apple MPS (only when this
    torch build exposes ``torch.backends.mps``), then the CPU.

    Returns:
        One of ``"cuda:0"``, ``"mps"`` or ``"cpu"``.
    """
    if torch.cuda.is_available():
        return "cuda:0"
    if hasattr(torch.backends, "mps"):
        if torch.backends.mps.is_available():
            return "mps"
    return "cpu"


device = torch.device(choose_device())
# torch.cuda.get_device_name() only accepts CUDA devices; on the MPS/CPU
# fallbacks returned by choose_device() it raises, so print the device itself.
if device.type == "cuda":
    print(torch.cuda.get_device_name(device))
else:
    print(device)

# fix seed for reproducibility
seed = 123
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed) # multi-GPU

In [None]:
# Per-channel statistics shared by the forward and inverse normalisation.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# PIL image -> square 224 crop/resize -> tensor -> normalised tensor.
_forward_steps = [
    SquareCropAndResize(224),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
]
imagenet_transform = torchvision.transforms.Compose(_forward_steps)

# Undo the normalisation above in two steps: Normalize computes
# (x - mean) / std, so dividing by 1/std rescales and subtracting -mean shifts back.
_inverse_steps = [
    torchvision.transforms.Normalize(
        mean=[0.0, 0.0, 0.0], std=[1 / s for s in _NORM_STD]
    ),
    torchvision.transforms.Normalize(
        mean=[-m for m in _NORM_MEAN], std=[1.0, 1.0, 1.0]
    ),
]
inverse_transform = torchvision.transforms.Compose(_inverse_steps)

# Quantus metric classes handed to `evaluate` for every attribution map.
quantus_metrics = [
    quantus.PointingGame,
    quantus.TopKIntersection,
    quantus.Sparseness,
    quantus.AUC,
    quantus.AttributionLocalisation,
    quantus.RelevanceRankAccuracy,
    quantus.RelevanceMassAccuracy,
]


### Load models

In [None]:
# ResNet-50 with the IMAGENET1K_V1 torchvision weights, in eval mode on `device`.
resnet = torchvision.models.resnet50(weights="ResNet50_Weights.IMAGENET1K_V1")
resnet.eval()
resnet.to(device)

# Take only the CNN layers: every child module up to and including layer4.
# NOTE(review): get_layer_idx comes from utils; presumably it returns the
# position of `resnet.layer4` among `resnet.children()` - confirm.
layer_idx = get_layer_idx(resnet, resnet.layer4)
_resnet_children = list(resnet.children())
resnet_layer4 = torch.nn.Sequential(*_resnet_children[: layer_idx + 1])
resnet_layer4.eval();

In [None]:
from pl_bolts.models.self_supervised import SimCLR, SwAV

# Checkpoints published in the pl-bolts weights bucket.
_SIMCLR_CKPT = "https://pl-bolts-weights.s3.us-east-2.amazonaws.com/simclr/bolts_simclr_imagenet/simclr_imagenet.ckpt"
_SWAV_CKPT = "https://pl-bolts-weights.s3.us-east-2.amazonaws.com/swav/bolts_swav_imagenet/swav_imagenet.ckpt"

# SimCLR: keep only the encoder, in eval mode on `device`.
simclr = SimCLR.load_from_checkpoint(_SIMCLR_CKPT, strict=False).encoder
simclr = simclr.eval().to(device)

# Truncate the encoder after layer4 (same recipe as for the ResNet above).
layer_idx = get_layer_idx(simclr, simclr.layer4)
_simclr_children = list(simclr.children())
simclr_layer4 = torch.nn.Sequential(*_simclr_children[: layer_idx + 1])
simclr_layer4.eval().to(device)

# SwAV: the backbone lives under `.model`.
swav = SwAV.load_from_checkpoint(_SWAV_CKPT, strict=False).model
swav = swav.eval().to(device)

layer_idx = get_layer_idx(swav, swav.layer4)
_swav_children = list(swav.children())
swav_layer4 = torch.nn.Sequential(*_swav_children[: layer_idx + 1])
swav_layer4.eval().to(device);

# ImageNet-1k + ImageNet-S


In [None]:
# If the dataset is gated/private, make sure you have run huggingface-cli login
imgent_ds = load_dataset("imagenet-1k", streaming=False, split="validation")

# Class names of the "label" feature, read through the public `features`
# accessor instead of the private `_info` attribute.
imgnet_labels = imgent_ds.features["label"].names

# Number of validation images and number of classes.
len(imgent_ds), len(imgnet_labels)

### Convert huggingface dataset to dataframe

In [None]:
# Build one metadata row per validation image. The columns are read straight
# from the backing Arrow table (public `data` accessor rather than the private
# `_data` attribute), so only the stored path and label are touched here.
# NOTE(review): assumes every image entry carries a non-null "path" - confirm.
image_column = imgent_ds.data["image"]
label_column = imgent_ds.data["label"]

rows = []
for i, (img, label_scalar) in enumerate(zip(image_column, label_column)):
    label = label_scalar.as_py()  # plain Python value instead of an Arrow scalar
    basename = img["path"].as_py().rsplit("/", 1)[-1]
    rows.append(
        {
            # text after the last "_" with the extension removed
            "category": basename.split("_")[-1].split(".")[0],
            # everything before the first "_n"
            "filename": basename.split("_n")[0],
            "label_names": imgnet_labels[label],
            "label": label,
            "dataset_idx": i,
        }
    )

imgent_df = pd.DataFrame(rows)
del rows
imgent_df.head()

### Map ImageNet-1k to ImageNet-S

In [None]:
import glob

# Every file two directory levels below the ImageNet-S validation
# segmentation folder.
file_paths = glob.glob("../../ImageNetS919/validation-segmentation/*/*")

# Name of each file without its directory and without anything after the first ".".
imgnet_s_path = []
for path in file_paths:
    basename = path.split("/")[-1]
    imgnet_s_path.append(basename.split(".")[0])

print(len(file_paths), len(imgnet_s_path))
imgnet_s_path[:5]

In [None]:
# Keep only the ImageNet-1k rows whose filename also exists in ImageNet-S.
imgnet_s_df = imgent_df[imgent_df["filename"].isin(imgnet_s_path)].copy()

# Attach the segmentation file path to every remaining row. A dict lookup
# replaces the former list.index() scan per row; setdefault keeps the first
# path for a name, matching what list.index() returned.
_path_by_name = {}
for _name, _path in zip(imgnet_s_path, file_paths):
    _path_by_name.setdefault(_name, _path)
imgnet_s_df["file_path"] = imgnet_s_df["filename"].map(_path_by_name)

print(len(imgnet_s_df))
imgnet_s_df[['filename', 'file_path']].head()

In [None]:
# https://github.com/LUSSeg/ImageNet-S?tab=readme-ov-file#qa
def get_seg_id(path):
    """Return the distinct segmentation ids stored in an annotation image.

    Each pixel encodes its id as ``R + G * 256`` (channel 0 plus 256 times
    channel 1). Ids 0 and 1000 are dropped from the result.

    Args:
        path: Path of the segmentation image to open.

    Returns:
        Sorted 1-D ``numpy`` array (int32) of the remaining unique ids.
    """
    # Close the file handle as soon as the pixels are copied out.
    with Image.open(path) as segmentation:  # RGB
        # Widen before the arithmetic: image data is typically uint8, and
        # multiplying uint8 by the Python int 256 raises OverflowError under
        # NumPy 2 promotion rules.
        pixels = np.array(segmentation).astype(np.int32)
    seg = pixels[:, :, 1] * 256 + pixels[:, :, 0]  # R+G*256
    seg = np.unique(seg)
    seg = seg[(seg != 0) & (seg != 1000)]
    return seg


# Decode the ids present in every segmentation file (one array per row).
_seg_ids = imgnet_s_df["file_path"].apply(get_seg_id)
imgnet_s_df["segmentation_id"] = _seg_ids
imgnet_s_df.head()

In [None]:
# Keep only the images whose segmentation contains more than one id.
# NOTE(review): the PASCAL section below keeps single-label images instead
# (`label_count == 1`) - confirm that `> 1` is the intended condition here.
_has_multiple_ids = imgnet_s_df["segmentation_id"].apply(len) > 1
imgnet_s_df = imgnet_s_df[_has_multiple_ids].copy()
print(len(imgnet_s_df))
imgnet_s_df.head()

### Evaluation

In [None]:
def _first_output(outputs):
    """Return the first element of the model output (used as RELAX ``unpack_output``)."""
    return outputs[0]


def _build_relax(encoder):
    """Create a RELAX explainer for ``encoder`` with the settings shared by all models."""
    return RELAX(
        encoder,
        n_masks=2048,
        n_cells=7,
        occlusion_batch_size=1024,
        threshold=None,
        unpack_output=_first_output,
        device=device,
        mask_interpolation="nearest",
        heatmap_interpolation="nearest",
    )


# (display name, explainer) pairs evaluated for every image.
xai_methods = (
    ("ResNet LaFAM", LaFAM(resnet_layer4, interpolation="nearest", threshold=None)),
    ("SimCLR LaFAM", LaFAM(simclr_layer4, interpolation="nearest", threshold=None)),
    ("SimCLR RELAX", _build_relax(simclr)),
    ("SwAV LaFAM", LaFAM(swav_layer4, interpolation="nearest", threshold=None)),
    ("SwAV RELAX", _build_relax(swav)),
)

# Grad-CAM on the ResNet's layer4; it is given the dataset's class names.
gradcam_heatmap = GradCAMHeatmap(
    resnet,
    resnet.layer4,
    imgnet_labels,
    interpolation="nearest",
    threshold=None,
)

# Same crop/resize as the input image, with nearest-neighbour interpolation.
target_transform = SquareCropAndResize(224, interpolation=Image.NEAREST)

results = []

for _, row in tqdm(imgnet_s_df.iterrows(), total=len(imgnet_s_df)):
    ds_idx = int(row["dataset_idx"])

    # Model input: one normalised image as a batch of size 1 on `device`.
    img = imgent_ds[ds_idx]["image"]
    img = imagenet_transform(img.convert("RGB"))
    x_batch = img.to(device).unsqueeze(0)

    # Ground-truth mask. The file is closed once the pixels are copied out.
    with Image.open(row["file_path"]) as seg_img:
        seg = np.array(target_transform(seg_img))
    # Widen before decoding: image data is typically uint8, and uint8 * 256
    # raises OverflowError under NumPy 2 promotion rules.
    seg = seg.astype(np.int32)
    seg = seg[:, :, 1] * 256 + seg[:, :, 0]  # R + G * 256

    # The pixels carrying the largest id form the target region.
    # NOTE(review): get_seg_id treats 1000 as a non-class id; if a cropped
    # annotation contains 1000, the mask below marks that region - confirm intent.
    m = seg.max()
    mask = torch.zeros(seg.shape)
    mask[seg == m] = 1
    s_batch = mask.unsqueeze(0)

    # Grad-CAM baseline; it also provides the predicted label that is stored
    # with every method's result for this image.
    start = time.time()
    a_batch, pred_label = gradcam_heatmap(x_batch)
    total_time = time.time() - start

    result = {
        "dataset_idx": ds_idx,
        "total_time": total_time,
        "prediction": pred_label,
        "labels": row["label_names"],
        "xai_method": "ResNet Grad-CAM",
    }

    result.update(evaluate(x_batch, s_batch, a_batch.detach(), quantus_metrics, device))
    results.append(result)

    for name, xai_method in xai_methods:
        start = time.time()
        a_batch = xai_method(x_batch, silent=True)
        total_time = time.time() - start

        result = {
            "dataset_idx": ds_idx,
            "total_time": total_time,
            "prediction": pred_label,
            "labels": row["label_names"],
            "xai_method": name,
        }

        result.update(evaluate(x_batch, s_batch, a_batch, quantus_metrics, device))
        results.append(result)

In [None]:
# One row per (image, XAI method) pair collected by the evaluation loop above.
results_df = pd.DataFrame(results)
results_df.tail()

In [None]:
# Metric columns plus timing, averaged per XAI method (methods become columns).
cols = [metric.name for metric in quantus_metrics] + ["total_time", "xai_method"]
_per_method = results_df[cols].groupby("xai_method")
_per_method.mean().T

# PASCAL Segmentation


In [None]:
# Class names in id order; ids start at 1.
_PASCAL_CLASS_NAMES = (
    "aeroplane",
    "bicycle",
    "bird",
    "boat",
    "bottle",
    "bus",
    "car",
    "cat",
    "chair",
    "cow",
    "diningtable",
    "dog",
    "horse",
    "motorbike",
    "person",
    "potted plant",
    "sheep",
    "sofa",
    "train",
    "tv/monitor",
)

# Mapping from class id (1..20) to class name.
pascal_labels = dict(enumerate(_PASCAL_CLASS_NAMES, start=1))


# Masks get the same square crop/resize as the images, with nearest-neighbour
# interpolation.
_pascal_mask_transform = SquareCropAndResize(224, interpolation=Image.NEAREST)

# VOC 2012 "val" segmentation split, read from disk (no download).
pascal_ds = torchvision.datasets.VOCSegmentation(
    root="../data/VOCdevkit",
    year="2012",
    image_set="val",
    download=False,
    transform=imagenet_transform,
    target_transform=_pascal_mask_transform,
)



In [None]:
# Load labels from json file: the values of the top-level JSON object,
# in file order.
with open("imagenet_class_index.json") as f:
    _class_index = json.load(f)
imagenet_labels = list(_class_index.values())

### Create pandas dataframe from torch dataset


In [None]:
def _mask_record(idx, mask):
    """Summarise which class ids occur in one segmentation mask.

    Values 0 and 255 are not counted as labels.
    """
    present = np.unique(np.array(mask))
    present = present[(present != 0) & (present != 255)]
    return {
        "dataset_idx": idx,
        "labels": present,
        "label_count": len(present),
        "label_names": [pascal_labels[label] for label in present],
    }


# One record per dataset item; only the mask of each item is inspected.
data = [_mask_record(idx, mask) for idx, (_, mask) in enumerate(pascal_ds)]

# Create a pandas dataframe
pascal_df = pd.DataFrame(data)


print(f"Number of images: {len(pascal_df)}")

# Number of images in which each class name occurs.
_label_frequency = pascal_df["label_names"].explode().value_counts()
_label_frequency.plot(
    kind="bar", figsize=(6, 3), title="Pascal VOC 2012 - labels distribution"
)

In [None]:
# Keep the images that contain exactly one labelled class.
_single_label = pascal_df["label_count"] == 1
pascal_df = pascal_df[_single_label].copy()

# Classes that are dropped from the evaluation set.
exclude_classes = [
    "person",
    "tv/monitor",
    "sofa",
    "potted plant",
    "diningtable",
    "chair",
    "bottle",
]


def _has_excluded_class(names):
    """Return True when any excluded class name is among ``names``."""
    for cls in exclude_classes:
        if cls in names:
            return True
    return False


pascal_df = pascal_df[~pascal_df["label_names"].apply(_has_excluded_class)].copy()

print(f"Number of images: {len(pascal_df)}")

# Label distribution after filtering.
_label_frequency = pascal_df["label_names"].explode().value_counts()
_label_frequency.plot(
    kind="bar", figsize=(6, 3), title="Pascal VOC 2012 - labels distribution"
)

In [None]:
# torch dataset wrapper for pandas dataframe
class PascalVOC2012Seg(torch.utils.data.Dataset):
    """Dataset view over ``ds`` restricted to the rows of ``df``.

    Item ``idx`` is looked up in ``df`` by position; the row's ``dataset_idx``
    selects the underlying ``(image, segmentation)`` pair from ``ds``.
    """

    def __init__(self, df, ds):
        self.df = df
        self.ds = ds

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image, binary mask tensor, row as dict)`` for position ``idx``."""
        record = self.df.iloc[idx]
        image, raw_mask = self.ds[record.dataset_idx]

        mask = torch.tensor(np.array(raw_mask))

        # Remove the outline (value 255), then mark every remaining positive
        # value as foreground.
        mask.masked_fill_(mask == 255, 0)
        mask.masked_fill_(mask > 0, 1)

        return image, mask, record.to_dict()


# Filtered PASCAL rows (pascal_df) backed by the torchvision dataset (pascal_ds).
pasvoc2012_seg = PascalVOC2012Seg(pascal_df, pascal_ds)

In [None]:
# Settings shared by both RELAX explainers; only the wrapped model differs.
_relax_kwargs = dict(
    n_masks=2048,
    n_cells=7,
    occlusion_batch_size=1024,
    threshold=None,
    unpack_output=lambda x: x[0],  # take the first element of the model output
    device=device,
    mask_interpolation="nearest",
    heatmap_interpolation="nearest",
)
# Settings shared by the LaFAM explainers.
_lafam_kwargs = dict(interpolation="nearest", threshold=None)

# (display name, explainer) pairs evaluated for every image.
xai_methods = (
    ("ResNet LaFAM", LaFAM(resnet_layer4, **_lafam_kwargs)),
    ("SimCLR LaFAM", LaFAM(simclr_layer4, **_lafam_kwargs)),
    ("SimCLR RELAX", RELAX(simclr, **_relax_kwargs)),
    ("SwAV LaFAM", LaFAM(swav_layer4, **_lafam_kwargs)),
    ("SwAV RELAX", RELAX(swav, **_relax_kwargs)),
)

# Grad-CAM on the ResNet's layer4; here it is given the labels loaded from
# imagenet_class_index.json.
gradcam_heatmap = GradCAMHeatmap(
    resnet,
    resnet.layer4,
    imagenet_labels,
    interpolation="nearest",
    threshold=None,
)

results = []

# Batches of size 1: (image, binary mask, collated dataframe row).
_pascal_loader = torch.utils.data.DataLoader(pasvoc2012_seg, batch_size=1)

for _images, s_batch, row in tqdm(_pascal_loader):
    x_batch = _images.to(device)
    ds_idx = row["dataset_idx"].item()

    # Grad-CAM baseline; it also provides the predicted label that is stored
    # with every method's result for this image.
    start = time.time()
    a_batch, pred_label = gradcam_heatmap(x_batch)
    total_time = time.time() - start

    result = dict(
        dataset_idx=ds_idx,
        total_time=total_time,
        prediction=pred_label,
        labels=row["label_names"][0],
        xai_method="ResNet Grad-CAM",
    )
    result.update(evaluate(x_batch, s_batch, a_batch.detach(), quantus_metrics, device))
    results.append(result)

    # Label-free explainers on the same input and mask.
    for name, xai_method in xai_methods:
        start = time.time()
        a_batch = xai_method(x_batch, silent=True)
        total_time = time.time() - start

        result = dict(
            dataset_idx=ds_idx,
            total_time=total_time,
            prediction=pred_label,
            labels=row["label_names"][0],
            xai_method=name,
        )
        result.update(evaluate(x_batch, s_batch, a_batch, quantus_metrics, device))
        results.append(result)

In [None]:
# One row per (image, XAI method) pair collected by the PASCAL evaluation loop.
results_df = pd.DataFrame(results)
results_df.tail()

In [None]:
# Metric columns plus timing, averaged per XAI method (methods become columns).
cols = [metric.name for metric in quantus_metrics] + ["total_time", "xai_method"]
_per_method = results_df[cols].groupby("xai_method")
_per_method.mean().T