# Evaluation of fidelity by feature perturbation

This notebook performs the perturbation of input images according to the previously calculated importance estimates. The model outputs (logits) are recorded as a function of the fraction of perturbed pixels

First choose for which model you want to perform the perturbation:

In [1]:
models = {
        0: "cifar",
        1: "food101",
        2: "imgnet"
}
model_name = models[0]

### Imports

In [2]:
import torch
import saliency.core as saliency
import os
from torchvision import transforms
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from torch.nn import functional as F
import numpy as np
import pandas as pd
from tqdm import tqdm
import glob
from matplotlib_inline.backend_inline import set_matplotlib_formats
from scipy import stats
import seaborn as sns
import pickle
from numpy import trapz
from sklearn.model_selection import train_test_split

from numba import njit
from numba import cuda

In [3]:
%load_ext autoreload
%autoreload 2
from lightning_models.model_imagenet import ImgNet_ResNet
from lightning_models.model_cifar_resnet import CIFAR_ResNet
from lightning_models.model_food101 import Food101_ResNet

from torchvision.datasets import CIFAR10
from torchvision.datasets import ImageNet
from torchvision.datasets import Food101

Global seed set to 7
Global seed set to 7
Global seed set to 7


In [4]:
device = 'cuda:0'

### Define some defaults for plotting

In [5]:
sns.set_style("whitegrid")
set_matplotlib_formats('pdf', 'svg')

In [6]:
plt.rc('axes', titlesize=18)     # fontsize of the axes title
plt.rc('axes', labelsize=20)    # fontsize of the x and y labels
plt.rc('xtick', labelsize=18)    # fontsize of the tick labels
plt.rc('ytick', labelsize=18)    # fontsize of the tick labels
plt.rc('legend', fontsize=15)    # legend fontsize
plt.rc('font', size=18)
plt.rcParams['text.usetex'] = True

### Defining datasets

In [7]:
transform2d = transforms.Compose(
    [
        transforms.ToTensor(),
        # normalizes images to [-1,1]
        transforms.Normalize((0.5,), (0.5,)),
    ]
)

transform3d = transforms.Compose(
    [
        transforms.ToTensor(),
        # normalizes images to [-1,1]
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

transform_imgnet = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.PILToTensor(),
        transforms.ConvertImageDtype(torch.float),
        #transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

transform_food101 = transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize((0.561, 0.440, 0.312), (0.252, 0.256, 0.259)),
        ]
)

cifar = CIFAR10('./data', train=False, transform=transform3d)
imgnet = ImageNet(root='/home/lbrocki/AugmentData/data', 
                  split='val', 
                  transform=transform_imgnet)
food101 = Food101("/home/lbrocki/AugmentData/data/", split="test", transform=transform_food101)

In [20]:
test_indices = np.load('test_idx_val_imgnet.npy')
imgnet_test = torch.utils.data.Subset(imgnet, test_indices)

_, food_test_indices = train_test_split(
    np.arange(0,len(food101)), 
    test_size=5000, random_state=42
)
food101_test = torch.utils.data.Subset(food101, food_test_indices)

In [9]:
if(model_name == "cifar"):
    data = cifar
elif(model_name == "food101"):
    data = food101_test
elif(model_name == "imgnet"):
    data = imgnet_test
else:
    print("model name error")

### Obtain perturbation curves

In [10]:
@njit()
def get_grid2d(arr, descending=True):
    grid = []
    xs, ys = arr.shape
    for x in range(xs):
        for y in range(ys):
            grid.append([arr[x,y],x,y])
    grid = np.array(grid)
    grid_sorted = grid[grid[:, 0].argsort()]
    if(descending):
        return grid_sorted[::-1]
    else:
        return grid_sorted

In [11]:
def perturb_n(img, attribution, n, reverse=True):
    grid = get_grid2d(attribution, descending=reverse)

    x = grid[:,1].astype(np.int64)[:n]
    y = grid[:,2].astype(np.int64)[:n]
    
    img_cp = np.copy(img)

    img_cp[:,x,y] = 0.0
    
    return img_cp

In [12]:
def setup(model_path):
    estimator_path = 'estimators/'+model_path+'/'
    logits_path = 'logits/'+model_path+'/'
    
    if(model_name == "cifar"):
        model = CIFAR_ResNet().to(device)
    elif(model_name == "food101"):
        model = Food101_ResNet().to(device)
    elif(model_name == "imgnet"):
        model = ImgNet_ResNet().to(device)
    else:
        print("model name error")

    model.load_state_dict(torch.load(f"weights/{model_path}"))

    try:
        os.makedirs(logits_path)
    except Exception as e:
        print(e)

    model.eval()
    
    return estimator_path, logits_path, model

In [13]:
def perturb(estimator, estim_path, logits_path, model, 
            data, ch_mode, random_label=False, xinput=False):
    print(estim_path, logits_path)
        
    all_LiF = []
    all_MiF = []

    #by sorting maps_path in ascending order we make sure that the 
    #saliency maps match with the images and labels in the for loop 
    maps_path = glob.glob(estim_path+estimator+'/*')
    maps_path.sort()
#     for i in tqdm(range(len(data))):
    for i in tqdm(range(10)):
        img = data[i][0]
        c,h,w = img.shape
        if(estimator != 'random'):
            map_ = np.load(maps_path[i], allow_pickle=True)
            if(xinput):
                map_ = np.asarray(img)*map_
            if(ch_mode == 'abs_sum'):
                map_ = np.abs(map_)
        else:
            map_ = np.random.uniform(0, 1,(h, w))
        if(len(map_.shape) == 3):
            if(map_.shape[0] == 3):
                map_ = np.sum(map_, axis=0)
            else:
                map_ = map_.squeeze(0)
        assert len(map_.shape) == 2
        step = int(h*w*0.05)
        pixel_range = np.arange(0, h*w, step)

        predicted_label = model(img.unsqueeze(0).to(device)).argmax()
        batch_LiF = torch.zeros((len(pixel_range), c, h, w), device=device)
        batch_MiF = torch.zeros((len(pixel_range), c, h, w), device=device)
        
        for k, n in enumerate(pixel_range):
            img_p_LiF = perturb_n(img, map_, n, False)
            img_p_MiF = perturb_n(img, map_, n, True)

            img_p_LiF = torch.tensor(img_p_LiF, device=device)
            img_p_MiF = torch.tensor(img_p_MiF, device=device)

            batch_LiF[k] = img_p_LiF
            batch_MiF[k] = img_p_MiF
        
        with torch.no_grad():
            scores_LiF = model(batch_LiF)[:,predicted_label].detach().cpu().numpy()
            scores_MiF = model(batch_MiF)[:,predicted_label].detach().cpu().numpy()

        all_LiF.append(scores_LiF)
        all_MiF.append(scores_MiF)

    all_LiF = np.array(all_LiF)
    all_MiF = np.array(all_MiF)
    if(xinput):
        str_xi = 'xi_'
    else:
        str_xi = ''

    np.save(f"{logits_path}{estimator}_{ch_mode}_{str_xi}MiF", all_MiF)
    np.save(f"{logits_path}{estimator}_{ch_mode}_{str_xi}LiF", all_LiF)

In [14]:
perturb_method = ["rand", "rect", "none"]
model_paths = [f"{model_name}/{p}.pt" for p in perturb_method]

# sum: sum color channels
# abs_sum: first abs then sum

# specify for which importance estimators you want to perform perturbation
# the last argument toggles elementwise multiplication with input image
estim_list = [
    ['random', 'sum', False],
    ['intgrad', 'sum', False],
    ['intgrad', 'abs_sum', False],
    ['vanilla', 'sum', False],
    ['vanilla', 'abs_sum', False],
    ['vanilla', 'sum', True],
    ['vanilla', 'abs_sum', True],
    ['smooth', 'sum', False],
    ['smooth', 'abs_sum', False],
    ['smooth', 'sum', True],
    ['smooth', 'abs_sum', True],
    ['smooth_sq', 'sum', True],
    ['smooth_sq', 'sum', False]
]
random_label = False
for model_path in model_paths:
    for param in estim_list:
        estimator, ch_mode, xinput = param
        print(model_path, estimator, ch_mode)
        estim_path, logits_path, model = setup(model_path)
        perturb(estimator, estim_path, logits_path, model, data, ch_mode, random_label, xinput)

cifar/rand.pt intgrad sum
[Errno 17] File exists: 'logits/cifar/rand.pt/'
estimators/cifar/rand.pt/ logits/cifar/rand.pt/


100%|██████████| 10/10 [00:02<00:00,  3.57it/s]


cifar/rect.pt intgrad sum
[Errno 17] File exists: 'logits/cifar/rect.pt/'
estimators/cifar/rect.pt/ logits/cifar/rect.pt/


100%|██████████| 10/10 [00:00<00:00, 29.92it/s]


cifar/none.pt intgrad sum
[Errno 17] File exists: 'logits/cifar/none.pt/'
estimators/cifar/none.pt/ logits/cifar/none.pt/


100%|██████████| 10/10 [00:00<00:00, 30.27it/s]


### Calculate area between the MIF and LIF curves (fildelity metric)

In [15]:
def AUC(logits):
    x = np.linspace(0,100,len(logits))
    return trapz(logits/logits[0], x)
val_dict = {}

def fill_dict(val_dict, model_names):
    for model_path in model_names:
        print(model_path)
        val_dict[model_path] = {}
        for param in estim_list:
            str_xi = ""
            if(param[2]):
                str_xi = "_xi"
            estim = f"{param[0]}_{param[1]}{str_xi}"
            
            val_dict[model_path][estim] = {}
            logits_path = f'logits/{model_path}/'
            
            logits_MiF = np.load(f"{logits_path}{estim}_MiF.npy")
            logits_LiF = np.load(f"{logits_path}{estim}_LiF.npy")
            AUC_MiF = []
            AUC_LiF = []
            for i,j in zip(logits_MiF, logits_LiF):
                AUC_MiF.append(AUC(i))
                AUC_LiF.append(AUC(j))
            
            mean_MiF = np.mean(AUC_MiF)
            mean_LiF = np.mean(AUC_LiF)
            var_MiF = np.var(AUC_MiF)
            var_LiF = np.var(AUC_LiF)
                        
            std_diff = np.sqrt(var_MiF + var_LiF)
            
            ci_diff = stats.norm.interval(0.95, loc=mean_LiF-mean_MiF, 
                             scale=std_diff / np.sqrt(len(AUC_MiF)))
            
            pm = np.abs(mean_LiF-mean_MiF-ci_diff[0])
            
            val_dict[model_path][estim]["MIF"] = mean_MiF
            val_dict[model_path][estim]["LIF"] = mean_LiF
            val_dict[model_path][estim]["Diff"] = mean_LiF - mean_MiF
            val_dict[model_path][estim]["CI"] = pm

In [16]:
fill_dict(val_dict, model_paths)
with open('logits/logits_dict.pkl', 'wb') as handle:
    pickle.dump(val_dict, handle)

cifar/rand.pt
cifar/rect.pt
cifar/none.pt


In [17]:
with open('logits/logits_dict.pkl', 'rb') as handle:
    val_dict = pickle.load(handle)

In [18]:
pd.DataFrame(val_dict["cifar/rand.pt"])

Unnamed: 0,intgrad_sum
CI,13.412625
Diff,59.411431
LiF,87.880639
MiF,28.469208


MIF: area under MIF curve <br />
LIF: area under LIF curve <br />
Diff: area between those curves <br />
CI: symmetrical confidence interval for diff <br />

$\text{fidelity} = \text{Diff} \pm \text{CI}$

### Create a plot comparing the MIF and LIF curves for different pertubation schemes

In [None]:
fig, axes = plt.subplots(2,1, figsize=(5,10))
estim = 'intgrad_sum_MiF'
logits_list = []
for model_path in model_paths:
    logits_path = f'logits/{model_path}/'
    logits = np.load(logits_path+estim+'.npy')
    mean_logits = np.mean(logits, axis=0)
    logits_list.append([model_path, mean_logits/mean_logits[0]])
for model_name, logits in logits_list:
    if("rand" in model_name):
        label = "Proposed"
    elif("rect" in model_name):
        label = "Rectangle"
    else:
        label = "None"
    x = np.linspace(0, 100, logits.shape[0])
    axes[0].xaxis.set_ticklabels([])
    axes[0].plot(x, logits, label=label)
    axes[0].legend()
    axes[0].set_ylabel('MIF')
    
estim = 'intgrad_sum_LiF'
logits_list = []
for model_path in model_paths:
    logits_path = f'logits/{model_path}/'
    logits = np.load(logits_path+estim+'.npy')
    mean_logits = np.mean(logits, axis=0)
    logits_list.append([model_path, mean_logits/mean_logits[0]])
for model_name, logits in logits_list:
    x = np.linspace(0, 100, logits.shape[0])
    axes[1].plot(x, logits, label=label)
    axes[1].set_ylabel('LIF')
    axes[1].set_xlabel('percentage of masked pixels')
fig.text(-0.05, 0.5, 'logits', va='center', rotation='vertical')
# plt.savefig("graphics/cifar_intgrad.svg", bbox_inches = "tight")