In [1]:
import cv2
%matplotlib inline
from matplotlib import pyplot as plt
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
import numpy as np
from torchvision.transforms import transforms
from datasets.cifar10 import CIFAR10Loader
from utils.io_utils import *

import torch

from models.network.DDN import DDN

from pytorch_grad_cam import GradCAM, HiResCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, FullGrad
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
from utils.explain_utils import calculate_ifc

  warn(


In [2]:
# Compute device: prefer CUDA, but fall back to CPU so the notebook still runs
# on machines without a GPU (a hard-coded 'cuda' crashes at the first .to(DEVICE)).
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
EPOCHS = 100
BATCH_SIZE = 1
EXP_NAME = "train"          # sub-directory of save/ that holds the checkpoint
CLS = 10                    # presumably the number of classes (CIFAR-10) — TODO confirm
INPUT_SHAPE = (3,32,32)     # (channels, height, width) of the network input
# NOTE(review): EPOCHS, CLS, BETA, GAMMA and ETA are not referenced by the cells of this
# notebook; presumably they are kept in sync with the training configuration.
BETA = 1e-2
GAMMA = 1e-4
ETA = 1e-3

In [3]:
# Single project root for every path below. Override it with the DDN_PROJECT_ROOT
# environment variable instead of editing hard-coded absolute paths.
PROJECT_ROOT = os.environ.get('DDN_PROJECT_ROOT', 'D:/project/Discardable-Distributed-Networks')

# Preprocessing: to tensor -> resize to the network's (H, W) -> per-channel normalization
# with the commonly used CIFAR-10 mean/std (RGB channel order).
pipeline = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(INPUT_SHAPE[1:], antialias=True),
    transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010)),
])

root = f'{PROJECT_ROOT}/data'
dataLoader = CIFAR10Loader(root=root, transform=pipeline, batch_size=BATCH_SIZE)
MODEL_PATH = f'{PROJECT_ROOT}/save/{EXP_NAME}/'
create_directory_if_not_exists(MODEL_PATH)

# Uncomment once to export dataset images to disk. Presumably this produces the
# data/cifar-10-batches-py/val/<class>/<idx>.png files read by explain_Model — TODO confirm.
# dataLoader.generate_img(resize=(32,32))

# Dummy batch of one, only used to confirm the expected input shape.
input_sample = torch.randn((1,) + INPUT_SHAPE).to(DEVICE)
print("input size", input_sample.shape)

Files already downloaded and verified
Files already downloaded and verified
input size torch.Size([1, 3, 32, 32])


In [4]:


class DDN_base(DDN):
    """DDN variant whose blocks are always run through ``drop_forward(..., 1.0)``.

    It keeps the parameter layout of ``DDN``, so the same checkpoint can be
    loaded into either class, and only overrides ``forward``.
    """

    def __init__(self, in_places):
        super().__init__(in_places)

    def forward(self, x):
        """Run stem -> ``self.l`` x (block, fusion, downsample) -> head.

        Returns:
            In training mode, a tuple ``(mu_vars_list, fusion_list, logits)``.
            ``fusion_list`` holds one dict per stage with the flattened
            ``'feature'`` list and the flattened ``'fusion'`` map.
            In eval mode, only the head output.
        """
        collected_mu_vars = []
        stage_records = []
        hidden = self.conv1(x)
        for stage in range(self.l):
            # NOTE(review): 1.0 is presumably a keep ratio (nothing discarded) — confirm in DDN.
            block_out = self.blocks[stage].drop_forward(hidden, 1.0)

            mu_vars, fusion = self.fusions[stage](block_out['padding'])

            if self.training:
                # Intermediate statistics/features are only gathered while training.
                collected_mu_vars += mu_vars
                stage_records.append({
                    'feature': [self.flatten(feat) for feat in block_out['no_padding']],
                    'fusion': self.flatten(fusion),
                })
            hidden = self.downsamples[stage](fusion)

        logits = self.head(hidden)
        if not self.training:
            return logits
        return collected_mu_vars, stage_records, logits

In [5]:



def _grad_cam_overlay(model, input_tensor, targets, background_bgr, display_size, window_name, out_path):
    """Grad-CAM at ``model.downsamples[1]``: show and save the overlay, return the resized heat-map."""
    cam = GradCAM(model=model, target_layers=[model.downsamples[1]])
    heatmap = cam(input_tensor=input_tensor, targets=targets)[0, :]
    heatmap = cv2.resize(heatmap, display_size)
    # background is BGR in [0, 1]; use_rgb=False keeps the colormap in BGR as well,
    # which is what cv2.imshow / cv2.imwrite expect.
    overlay_bgr = show_cam_on_image(background_bgr, heatmap, use_rgb=False)
    cv2.imshow(window_name, overlay_bgr)
    cv2.imwrite(out_path, overlay_bgr)
    return heatmap


def explain_Model(model_expand, model_base,
                  img_path="D:\\project\\Discardable-Distributed-Networks\\data\\cifar-10-batches-py\\val\\0\\774.png",
                  target_class=0, display_size=(256, 256)):
    """Compare the Grad-CAM heat-maps of two models on one image and print ``1 - ifc``.

    Both models are put in eval mode. Grad-CAM is computed at ``downsamples[1]``
    of each model for ``target_class``. The overlays are shown in OpenCV windows
    and written to ``expand.png`` / ``base.png`` in the current working directory.
    The call blocks in ``cv2.waitKey(0)`` until a key is pressed in an OpenCV
    window, then closes the windows.

    Args:
        model_expand: model whose heat-map is labelled "expand".
        model_base: model whose heat-map is labelled "base".
        img_path: image file to explain (read with ``cv2.imread``).
        target_class: class index passed to ``ClassifierOutputTarget``.
        display_size: ``(width, height)`` of the displayed/saved images and heat-maps.

    Returns:
        ``1 - ifc``, where ``ifc`` is entry ``[0][1]`` of ``calculate_ifc`` applied
        to the two heat-maps (presumably their pairwise score — TODO confirm).

    Raises:
        FileNotFoundError: if OpenCV cannot read ``img_path``.
    """
    model_expand.eval()
    model_base.eval()

    img_bgr = cv2.imread(img_path)
    if img_bgr is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"cv2.imread could not read image: {img_path}")

    # OpenCV decodes to BGR, while `pipeline` normalizes with RGB CIFAR-10 statistics
    # (and torchvision's CIFAR-10 images are RGB), so the model must be fed RGB.
    img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
    data = pipeline(img_rgb).unsqueeze(0).to(DEVICE)

    # The display copy stays BGR, scaled to [0, 1] as show_cam_on_image requires.
    display_bgr = np.float32(cv2.resize(img_bgr, display_size)) / 255.0
    cv2.imshow("img", display_bgr)

    targets = [ClassifierOutputTarget(target_class)]
    grayscale_cam_expand = _grad_cam_overlay(model_expand, data, targets, display_bgr,
                                             display_size, "camf expand", "expand.png")
    grayscale_cam_base = _grad_cam_overlay(model_base, data, targets, display_bgr,
                                           display_size, "camf base", "base.png")

    ifc = calculate_ifc([torch.tensor(grayscale_cam_expand), torch.tensor(grayscale_cam_base)])[0][1]
    score = 1 - ifc
    print("1-ifc:", score)
    cv2.waitKey(0)
    cv2.destroyAllWindows()
    return score




In [6]:
def explain(MODEL_NAME='DDN'):
    """Load ``<MODEL_PATH>/<MODEL_NAME>.ckpt`` into ``DDN`` and ``DDN_base`` and compare their Grad-CAMs.

    Args:
        MODEL_NAME: checkpoint file stem inside ``MODEL_PATH``.

    Returns:
        Whatever ``explain_Model`` returns for the two models.
    """
    ckpt_path = os.path.join(MODEL_PATH, f'{MODEL_NAME}.ckpt')
    # Read the checkpoint once and share it. load_state_dict copies the tensors into
    # each model's parameters. map_location='cpu' lets a GPU-saved checkpoint load
    # on any machine; the models are moved to DEVICE afterwards.
    # NOTE(review): if this file is a plain state_dict, weights_only=True (torch>=1.13)
    # avoids unpickling arbitrary objects.
    state_dict = torch.load(ckpt_path, map_location='cpu')

    in_channels = INPUT_SHAPE[0]

    model_expand = DDN(in_places=in_channels)
    model_expand.load_state_dict(state_dict)
    model_expand = model_expand.to(DEVICE)

    model_base = DDN_base(in_places=in_channels)
    model_base.load_state_dict(state_dict)
    model_base = model_base.to(DEVICE)

    return explain_Model(model_expand, model_base)

In [7]:
# Grad-CAM comparison for the default checkpoint (save/<EXP_NAME>/DDN.ckpt).
# Blocks until a key is pressed in one of the OpenCV windows.
explain(MODEL_NAME='DDN')