In [1]:
from torchvision.models import resnet50, ResNet50_Weights
import numpy as np
import os
from transform_factory import tensorize, center_crop_224, resize_322, imagenet_normalize
from PIL import Image
import torch.nn.functional as F
import torch
import matplotlib.pyplot as plt

In [2]:
# config
seed = 0
expl_method = "LayerXAct"

In [3]:
model = resnet50(weights=ResNet50_Weights.DEFAULT).eval()

In [4]:
with open(f"./val_seed_{seed}.npy", "rb") as f:
    filepath_list = np.load(f)

orig_imgs = []
orig_expls = []


for img_path in filepath_list[:60]:
    img_name = os.path.basename(img_path)

    orig_img = Image.open(img_path)
    orig_img = imagenet_normalize(tensorize(center_crop_224(resize_322(orig_img))))

    orig_imgs.append(orig_img)

    with open(f"results/val_seed_{seed}_pred_orig_eval_orig_transform_both_sign_all_reduction_sum/{img_name}_expl_{expl_method}_sample_2000_sigma_0.05_seed_{seed}_orig_true_config.npy", "rb") as f:
        orig_expl = np.load(f, allow_pickle=True)
        true_expls = np.load(f, allow_pickle=True)
        configs = np.load(f, allow_pickle=True)

        orig_expl = F.interpolate(torch.tensor(orig_expl).unsqueeze(0), (224, 224), mode='bilinear').squeeze(0).numpy()

        orig_expls.append(orig_expl)


orig_imgs = torch.stack(orig_imgs)
orig_expls = torch.tensor(np.stack(orig_expls))


In [5]:
orig_imgs.shape
orig_expls.shape

torch.Size([60, 1, 224, 224])

In [6]:
y = model(orig_imgs).argmax(dim = 1)

In [13]:
class AOPCTestor():
    def __init__(self, model) -> None:
        self.model = model
        self.softmax = torch.nn.Softmax(dim = 1)


    @staticmethod
    def perturbation(expl, img, ratio, mode="insertion"):
    # expl : [B, C=1, H, W]
    # img : [B, C=3, H, W]
        order = expl.flatten(1).argsort(descending=True)
        n_perturb = int(ratio * order.shape[1])
        n_order = order[:, n_perturb] 
        threshold = expl.flatten(1)[range(len(expl)), n_order]

        if mode == "insertion":
            mask = expl > threshold.reshape(len(expl), 1, 1).unsqueeze(1)
        elif mode == "deletion":
            mask = expl < threshold.reshape(len(expl), 1, 1).unsqueeze(1)        
        return (img * mask).detach()

    @staticmethod
    def conf_perturbation(expl, img, ratio, mode="insertion"):
        order = expl.flatten(1).argsort(descending = True)

    def test_step(self, expl, img, label, mode="orig"):
        for ratio in np.arange(0, 1, 0.05):

            # if mode == "orig":
            img_p = self.perturbation(expl, img, ratio=ratio, mode=mode)
            # else:
                # img_p = self.conf_perturbation(expl, img, ratio=ratio)

            logit = self.model(img_p)
            prob = self.softmax(logit)

            aopc_prob = prob[range(len(label)), label].detach().mean()
            print(aopc_prob)


In [14]:
tester = AOPCTestor(model=model)
tester.test_step(orig_expls, orig_imgs, y, mode="deletion")

tensor(0.7970)
tensor(0.7525)
tensor(0.7331)
tensor(0.6698)
tensor(0.6286)
tensor(0.5587)
tensor(0.5047)
tensor(0.4680)
tensor(0.4318)
tensor(0.3970)
tensor(0.3448)
tensor(0.2864)
tensor(0.2550)
tensor(0.2207)
tensor(0.1773)
tensor(0.1340)
tensor(0.0861)
tensor(0.0608)
tensor(0.0226)
tensor(0.0067)


In [15]:
conf_highs = []
conf_lows = []

alpha = 0.05

for img_path in filepath_list[:60]:
    img_name = os.path.basename(img_path)
    
    with open(f"results/val_seed_{seed}_pred_orig_eval_orig_transform_both_sign_all_reduction_sum/{img_name}_expl_{expl_method}_sample_2000_sigma_0.05_seed_{seed}_results.pkl", "rb") as f:
        results = np.load(f, allow_pickle=True)

    result = results[int(alpha / 0.05)]

    conf_highs.append(torch.tensor(result['conf_high']))
    conf_lows.append(torch.tensor(result['conf_low']))


conf_highs = torch.stack(conf_highs)
conf_lows = torch.stack(conf_lows)

In [17]:
tester = AOPCTestor(model=model)
tester.test_step(conf_lows, orig_imgs, y, mode='deletion')

tensor(0.7947)
tensor(0.7463)
tensor(0.6975)
tensor(0.6705)
tensor(0.6470)
tensor(0.6061)
tensor(0.5746)
tensor(0.5341)
tensor(0.5231)
tensor(0.5064)
tensor(0.4614)
tensor(0.4480)
tensor(0.4294)
tensor(0.3770)
tensor(0.3579)
tensor(0.3152)
tensor(0.2666)
tensor(0.2114)
tensor(0.1080)
tensor(0.0330)
