In [110]:
from torchvision.models import resnet50, ResNet50_Weights
import numpy as np
import os
from transform_factory import tensorize, center_crop_224, resize_322, imagenet_normalize
from PIL import Image
import torch.nn.functional as F
import torch
import matplotlib.pyplot as plt

In [129]:
# config
seed = 3
data_num = 7
expl_method = "LayerXAct"

In [130]:
model = resnet50(weights=ResNet50_Weights.DEFAULT).eval()

In [131]:
with open(f"./val_seed_{seed}.npy", "rb") as f:
    filepath_list = np.load(f)

orig_imgs = []
orig_expls = []


for img_path in filepath_list[:data_num]:
    img_name = os.path.basename(img_path)

    orig_img = Image.open(img_path)
    orig_img = imagenet_normalize(tensorize(center_crop_224(resize_322(orig_img))))

    orig_imgs.append(orig_img)

    with open(f"results/val_seed_{seed}_pred_orig_eval_orig_transform_both_sign_all_reduction_sum/{img_name}_expl_{expl_method}_sample_2000_sigma_0.05_seed_{seed}_orig_true_config.npy", "rb") as f:
        orig_expl = np.load(f, allow_pickle=True)
        true_expls = np.load(f, allow_pickle=True)
        configs = np.load(f, allow_pickle=True)

        orig_expl = F.interpolate(torch.tensor(orig_expl).unsqueeze(0), (224, 224), mode='bilinear').squeeze(0).numpy()

        orig_expls.append(orig_expl)


orig_imgs = torch.stack(orig_imgs)
orig_expls = torch.tensor(np.stack(orig_expls))


In [132]:
orig_imgs.shape
orig_expls.shape

torch.Size([7, 1, 224, 224])

In [133]:
y = model(orig_imgs).argmax(dim = 1)

In [134]:
class AOPCTestor():
    def __init__(self, model) -> None:
        self.model = model
        self.softmax = torch.nn.Softmax(dim = 1)


    @staticmethod
    def perturbation(expl, img, ratio, mode="insertion"):
    # expl : [B, C=1, H, W]
    # img : [B, C=3, H, W]
        order = expl.flatten(1).argsort(descending=True)
        n_perturb = int(ratio * order.shape[1])
        n_order = order[:, n_perturb] 
        threshold = expl.flatten(1)[range(len(expl)), n_order]

        if mode == "insertion":
            mask = expl > threshold.reshape(len(expl), 1, 1).unsqueeze(1)
        elif mode == "deletion":
            mask = expl < threshold.reshape(len(expl), 1, 1).unsqueeze(1)        
        return (img * mask).detach()

    @staticmethod
    def conf_perturbation(expl, img, ratio, conf_high, conf_low, mode="insertion"):
        attr_order = expl.flatten(1).argsort(descending = True)
        

    def test_step(self, expl, img, label, mode="orig"):
        prob_list = []
        for ratio in np.arange(0, 1, 0.05):

            # if mode == "orig":
            img_p = self.perturbation(expl, img, ratio=ratio, mode=mode)
            # else:
                # img_p = self.conf_perturbation(expl, img, ratio=ratio)

            logit = self.model(img_p)
            prob = self.softmax(logit)

            aopc_prob = prob[range(len(label)), label].detach().mean()
            prob_list.append(aopc_prob)

        return prob_list

In [142]:
tester = AOPCTestor(model=model)
prob_list = tester.test_step(orig_expls, orig_imgs, y, mode="insertion")

In [143]:
print(prob_list, torch.stack(prob_list).mean())

[tensor(0.0007), tensor(0.0673), tensor(0.1571), tensor(0.2805), tensor(0.4405), tensor(0.4561), tensor(0.4427), tensor(0.4996), tensor(0.5040), tensor(0.5154), tensor(0.5757), tensor(0.6047), tensor(0.6672), tensor(0.6676), tensor(0.6514), tensor(0.6882), tensor(0.6874), tensor(0.6525), tensor(0.6697), tensor(0.6824)] tensor(0.4955)


In [144]:
conf_highs = []
conf_lows = []

alpha = 0.05

for img_path in filepath_list[:data_num]:
    img_name = os.path.basename(img_path)
    
    with open(f"results/val_seed_{seed}_pred_orig_eval_orig_transform_both_sign_all_reduction_sum/{img_name}_expl_{expl_method}_sample_2000_sigma_0.05_seed_{seed}_results.pkl", "rb") as f:
        results = np.load(f, allow_pickle=True)

    result = results[int(alpha / 0.05)]

    conf_highs.append(torch.tensor(result['conf_high']))
    conf_lows.append(torch.tensor(result['conf_low']))


conf_highs = torch.stack(conf_highs)
conf_lows = torch.stack(conf_lows)

  conf_highs.append(torch.tensor(result['conf_high']))
  conf_lows.append(torch.tensor(result['conf_low']))


In [145]:
tester = AOPCTestor(model=model)
# tester.test_step(conf_lows, orig_imgs, y, mode='deletion')
high_prob_list = tester.test_step(conf_highs, orig_imgs, y, mode='deletion')


In [146]:
print(high_prob_list, torch.stack(high_prob_list).mean())


[tensor(0.6810), tensor(0.6516), tensor(0.6147), tensor(0.6055), tensor(0.5416), tensor(0.4369), tensor(0.4124), tensor(0.3771), tensor(0.2613), tensor(0.1978), tensor(0.1859), tensor(0.1885), tensor(0.1092), tensor(0.1035), tensor(0.1372), tensor(0.1273), tensor(0.0766), tensor(0.0445), tensor(0.0083), tensor(0.0027)] tensor(0.2882)


In [147]:
low_prob_list = tester.test_step(conf_lows, orig_imgs, y, mode='insertion')


In [148]:
print(low_prob_list, torch.stack(low_prob_list).mean())


[tensor(0.0007), tensor(0.0958), tensor(0.2401), tensor(0.3177), tensor(0.4049), tensor(0.4452), tensor(0.5223), tensor(0.5074), tensor(0.4951), tensor(0.6043), tensor(0.5396), tensor(0.5700), tensor(0.5866), tensor(0.6275), tensor(0.6755), tensor(0.6619), tensor(0.6581), tensor(0.6846), tensor(0.6839), tensor(0.6946)] tensor(0.5008)
