In [53]:
from torchvision.models import resnet50, ResNet50_Weights
import numpy as np
import os
from transform_factory import tensorize, center_crop_224, resize_322, imagenet_normalize
from PIL import Image
import torch.nn.functional as F
import torch
import matplotlib.pyplot as plt

In [54]:
# config
seed = 3
expl_method = "LayerXAct"

In [55]:
model = resnet50(weights=ResNet50_Weights.DEFAULT).eval()

In [56]:
with open(f"./val_seed_{seed}.npy", "rb") as f:
    filepath_list = np.load(f)

orig_imgs = []
orig_expls = []


for img_path in filepath_list[:1]:
    img_name = os.path.basename(img_path)

    orig_img = Image.open(img_path)
    orig_img = imagenet_normalize(tensorize(center_crop_224(resize_322(orig_img))))

    orig_imgs.append(orig_img)

    with open(f"results/val_seed_{seed}_pred_orig_eval_orig_transform_both_sign_all_reduction_sum/{img_name}_expl_{expl_method}_sample_2000_sigma_0.05_seed_{seed}_orig_true_config.npy", "rb") as f:
        orig_expl = np.load(f, allow_pickle=True)
        true_expls = np.load(f, allow_pickle=True)
        configs = np.load(f, allow_pickle=True)

        orig_expl = F.interpolate(torch.tensor(orig_expl).unsqueeze(0), (224, 224), mode='bilinear').squeeze(0).numpy()

        orig_expls.append(orig_expl)


orig_imgs = torch.stack(orig_imgs)
orig_expls = torch.tensor(np.stack(orig_expls))


In [57]:
orig_imgs.shape
orig_expls.shape

torch.Size([1, 1, 224, 224])

In [58]:
y = model(orig_imgs).argmax(dim = 1)

In [59]:
class AOPCTestor():
    def __init__(self, model) -> None:
        self.model = model
        self.softmax = torch.nn.Softmax(dim = 1)


    @staticmethod
    def perturbation(expl, img, ratio, mode="insertion"):
    # expl : [B, C=1, H, W]
    # img : [B, C=3, H, W]
        order = expl.flatten(1).argsort(descending=True)
        n_perturb = int(ratio * order.shape[1])
        n_order = order[:, n_perturb] 
        threshold = expl.flatten(1)[range(len(expl)), n_order]

        if mode == "insertion":
            mask = expl > threshold.reshape(len(expl), 1, 1).unsqueeze(1)
        elif mode == "deletion":
            mask = expl < threshold.reshape(len(expl), 1, 1).unsqueeze(1)        
        return (img * mask).detach()

    @staticmethod
    def conf_perturbation(expl, img, ratio, conf_high, conf_low, mode="insertion"):
        attr_order = expl.flatten(1).argsort(descending = True)
        

    def test_step(self, expl, img, label, mode="orig"):
        prob_list = []
        for ratio in np.arange(0, 1, 0.05):

            # if mode == "orig":
            img_p = self.perturbation(expl, img, ratio=ratio, mode=mode)
            # else:
                # img_p = self.conf_perturbation(expl, img, ratio=ratio)

            logit = self.model(img_p)
            prob = self.softmax(logit)

            aopc_prob = prob[range(len(label)), label].detach().mean()
            prob_list.append(aopc_prob)

        return prob_list

In [60]:
tester = AOPCTestor(model=model)
prob_list = tester.test_step(orig_expls, orig_imgs, y, mode="insertion")

In [61]:
print(prob_list, torch.stack(prob_list).mean())

[tensor(0.0008), tensor(0.4343), tensor(0.4013), tensor(0.4883), tensor(0.4493), tensor(0.4614), tensor(0.4866), tensor(0.5751), tensor(0.5752), tensor(0.5577), tensor(0.5684), tensor(0.5726), tensor(0.5646), tensor(0.6041), tensor(0.6090), tensor(0.5876), tensor(0.5976), tensor(0.6294), tensor(0.6582), tensor(0.6791)] tensor(0.5250)


In [62]:
conf_highs = []
conf_lows = []

alpha = 0.05

for img_path in filepath_list[:1]:
    img_name = os.path.basename(img_path)
    
    with open(f"results/val_seed_{seed}_pred_orig_eval_orig_transform_both_sign_all_reduction_sum/{img_name}_expl_{expl_method}_sample_2000_sigma_0.05_seed_{seed}_results.pkl", "rb") as f:
        results = np.load(f, allow_pickle=True)

    result = results[int(alpha / 0.05)]

    conf_highs.append(torch.tensor(result['conf_high']))
    conf_lows.append(torch.tensor(result['conf_low']))


conf_highs = torch.stack(conf_highs)
conf_lows = torch.stack(conf_lows)

  conf_highs.append(torch.tensor(result['conf_high']))
  conf_lows.append(torch.tensor(result['conf_low']))


In [63]:
tester = AOPCTestor(model=model)
# tester.test_step(conf_lows, orig_imgs, y, mode='deletion')
high_prob_list = tester.test_step(conf_highs, orig_imgs, y, mode='insertion')


In [64]:
print(high_prob_list, torch.stack(high_prob_list).mean())


[tensor(0.0008), tensor(0.4237), tensor(0.3915), tensor(0.4767), tensor(0.4399), tensor(0.4802), tensor(0.5314), tensor(0.5845), tensor(0.5868), tensor(0.5563), tensor(0.5368), tensor(0.5647), tensor(0.5651), tensor(0.6404), tensor(0.6058), tensor(0.6386), tensor(0.6163), tensor(0.6533), tensor(0.6110), tensor(0.6658)] tensor(0.5285)


In [39]:
low_prob_list = tester.test_step(conf_lows, orig_imgs, y, mode='deletion')


In [40]:
low_prob_list

[tensor(0.9325),
 tensor(0.8562),
 tensor(0.7736),
 tensor(0.6832),
 tensor(0.6072),
 tensor(0.5264),
 tensor(0.4559),
 tensor(0.3888),
 tensor(0.3296),
 tensor(0.2809),
 tensor(0.2395),
 tensor(0.2081),
 tensor(0.1850),
 tensor(0.1489),
 tensor(0.1249),
 tensor(0.1081),
 tensor(0.0777),
 tensor(0.0471),
 tensor(0.0276),
 tensor(0.0166)]