Test models with the xAI evaluation metrics of the ReVel framework
==================================================================

In [8]:
import torch

torch.set_num_threads(1)
from torch.utils.data.dataloader import DataLoader
import tqdm
import numpy as np
import json
import argparse
import pandas as pd
import os

from ReVel.LLEs import get_xai_model
from ReVel.perturbations import get_perturbation
from ReVel.load_data import load_data
from ReVel.revel.revel import ReVel
import SHIELD.procedures as procedures
# Device used for the classifier and for every tensor moved with .to(device).
device = "cpu"

# Experiment configuration.
n_class = 102  # number of classes (perturbation, classifier and ReVel)
ds = "Flowers"  # dataset name handed to load_data and stored in the CSV
iterations = 5  # passed as `times` to ReVel.evaluate
batch_size = 32  # NOTE(review): not referenced in the visible code (loader uses 1)
max_examples = 500  # forwarded to get_xai_model
samples = 10  # how many test items to evaluate
sigma = 12  # forwarded to get_xai_model
xai_model = "LIME"  # explainer name forwarded to get_xai_model
dim = 8  # forwarded to both get_perturbation and get_xai_model

csv_file = "./metrics.csv"
perturbation = get_perturbation(
    name="square", dim=dim, num_classes=n_class, final_size=(224, 224)
)
Test = load_data(ds, perturbation=perturbation, train=False, dir="./data")

# Restrict Test to `samples` items drawn at random without replacement.
if isinstance(samples, int):
    indices = np.random.choice(len(Test), size=samples, replace=False)
    Test = torch.utils.data.Subset(Test, indices)

# The loader is wrapped in iter(), so it can be consumed only once; re-create
# it before iterating a second time.
TestLoader = iter(DataLoader(Test, batch_size=1, shuffle=False))

classifier = procedures.classifier("efficientnet_v2_s", n_class)
# Change the path below to the directory where the model is saved.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
state_dict = torch.load(
    "../../../results/Flowers/SHIELD_efficientnet_v2_s_2.0/model.pt",
    map_location=device,
)
classifier.load_state_dict(state_dict)
classifier.to(device)
classifier.eval()
print("Loaded the pretrained model.")

Loaded the pretrained model.


ReVel metrics calculation for a model trained on the Flowers dataset
======================================================================

In [None]:
# Evaluate every test item with ReVel and accumulate the metrics in csv_file.
index = 0  # running counter of evaluated items, stored in the "index" column
for data in tqdm.tqdm(TestLoader, total=len(TestLoader)):
    inputs, labels = data
    for k, inp in enumerate(inputs):
        inp = inp.to(device)

        # inp dims: (C,H,W) -> (H,W,C)
        inp = inp.permute(1, 2, 0)

        # Use a separate name so `labels` is not overwritten while the batch
        # is still being iterated. The value is not used further down.
        label = labels[k].to(device)

        # A fresh explainer is built for every item.
        explainer = get_xai_model(
            name=xai_model,
            perturbation=perturbation,
            max_examples=max_examples,
            dim=dim,
            sigma=sigma,
        )

        def classify(image, model=classifier):
            """
            Run the classifier on a single channels-last image.

            :param image: A numpy array or tensor of shape HxWxC
            :return: The raw output of ``model`` for a batch holding this
                single image
            """
            if isinstance(image, np.ndarray):
                image = np.expand_dims(image, 0)
                image = torch.Tensor(image).to(device)
            else:
                image = torch.unsqueeze(image, 0)

            # image dims: (N,H,W,C) -> (N,C,H,W)
            image = image.permute(0, 3, 1, 2)

            return model(image)

        def model_fordward(
            X: np.ndarray, explainator=explainer, model=classify, img=inp
        ):
            """
            Classify ``img`` after perturbing the segments switched off in X.

            :param X: Sequence with one entry per segment; entries equal to 0
                select the segments handed to the perturbation as ``indexes``
            :param explainator: Explainer whose perturbation object is used
            :param model: Callable applied to the perturbed image
            :param img: Image to perturb (defaults to the current item)
            :return: Whatever ``model`` returns for the perturbed image
            """
            neutral = explainator.perturbation.fn_neutral_image(img)

            avoid = [i for i, value in enumerate(X) if value == 0]

            segments = explainator.perturbation.segmentation_fn(img.numpy())
            perturbation = explainator.perturbation.perturbation(
                img, neutral, segments=segments, indexes=avoid
            )
            return model(perturbation)

        segments = explainer.perturbation.segmentation_fn(inp.numpy())

        revel = ReVel(
            model_f=classify,
            model_g=model_fordward,
            instance=inp,
            lle=explainer,
            n_classes=n_class,
            segments=segments,
        )
        df = revel.evaluate(times=iterations)
        df.loc[:, "dataset"] = ds
        df.loc[:, "name"] = "SHIELD"
        df.loc[:, "index"] = index
        index += 1

        # Merge with any rows already on disk and rewrite the whole file, so
        # the CSV is up to date after every evaluated item.
        if os.path.exists(csv_file):
            bigDF = pd.read_csv(csv_file)
            bigDF = pd.concat([bigDF, df])
        else:
            bigDF = df
        bigDF.to_csv(csv_file, index=False)