In [46]:
import os
from data.utils import get_dataset
import captum
import random
import torch
import torchvision
from torchvision import transforms
import numpy as np
from captum.attr import IntegratedGradients
from captum.attr import Saliency
from captum.attr import DeepLift
from captum.attr import NoiseTunnel


from DebiAN.models.simple_cls import get_simple_classifier

from matplotlib import pyplot as plt
from matplotlib import cm
from PIL import Image
from torchvision import models
from tqdm import tqdm
from os import path

# Functions

In [16]:
from matplotlib import cm
import numpy as np

def overlay(input, cam, alpha=0.9, colormap="jet", resample=None):
    """Blend a colour-mapped saliency map on top of an image.

    Inspired by https://github.com/frgfm/torch-cam/blob/master/torchcam/utils.py

    Args:
        input: image tensor accepted by ``transforms.ToPILImage``.
        cam: saliency map tensor. It is min-max normalised to [0, 1] on a copy,
            so the caller's tensor is left untouched.
        alpha: weight of the original image in the blend; the coloured map
            gets ``1 - alpha``.
        colormap: Matplotlib colormap name, or a colormap object.
        resample: PIL resampling filter used to resize ``cam`` to the image
            size. Defaults to ``Image.NEAREST``.

    Returns:
        PIL image with the blended result.
    """
    if resample is None:
        resample = Image.NEAREST

    img = transforms.ToPILImage()(input)

    # Normalise a float32 copy to [0, 1]; in-place ops would mutate the caller's tensor
    # (and fail outright if it requires grad).
    cam = cam.detach().float()
    cam = cam - cam.min()
    cam_max = cam.max()
    if cam_max > 0:  # a constant map would otherwise divide by zero and yield NaNs
        cam = cam / cam_max
    cam_img = transforms.ToPILImage(mode='F')(cam)

    if isinstance(colormap, str):
        # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap remains.
        from matplotlib import pyplot as plt
        cmap = plt.get_cmap(colormap)
    else:
        cmap = colormap

    # Resize the mask to the image size and apply the colormap (RGB channels only).
    resized_cam = cam_img.resize(img.size, resample=resample)
    # Values are squared before colouring, presumably to de-emphasise weak attributions.
    heatmap = (255 * cmap(np.asarray(resized_cam) ** 2)[:, :, :3]).astype(np.uint8)

    # Overlay the image with the mask.
    blended = alpha * np.asarray(img) + (1 - alpha) * heatmap
    return Image.fromarray(blended.astype(np.uint8))


def attribute_image_features(net, algorithm, input, label, **kwargs):
    """Run a Captum-style attribution for one input/target pair.

    ``net.zero_grad()`` is called first so each attribution starts from a
    clean gradient state.

    Args:
        net: model exposing ``zero_grad()``.
        algorithm: object exposing ``attribute(inputs, target=..., **kwargs)``.
        input: input to explain, passed positionally to ``algorithm.attribute``.
        label: target forwarded as ``target=``.
        **kwargs: extra keyword arguments forwarded unchanged to ``attribute``.

    Returns:
        Whatever ``algorithm.attribute`` returns (e.g. a tuple of attributions and
        delta when ``return_convergence_delta=True`` is passed).
    """
    net.zero_grad()  # drop gradients left over from previous calls
    return algorithm.attribute(input, target=label, **kwargs)

In [17]:
# Seed every RNG source used here (Python, NumPy, PyTorch CPU and all CUDA
# devices) so that stochastic steps are repeatable across runs.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(1234)

# Parameters

In [47]:
data_path  ='data/'
result_path  ='npz/cmnist/'
split='test'
percent='5pct'
_last_hidden_layer=100

model_path ="DebiAN/exp/cmnist/MLP_{}/debian_bs_256_wd_1E-04_lr_1E-03_cmnist_{}/last.pth".format(_last_hidden_layer,percent)
print(model_path)



Saliency_methods=['integrated_gradient','smoothgrad','deeplift']
Saliency_methods=['smoothgrad','deeplift']

In [19]:
# Create the output directory for the .npz files; this is a no-op if it already exists,
# so report which case happened instead of always claiming it was created.
_result_dir_existed = os.path.isdir(result_path)
os.makedirs(result_path, exist_ok=True)
print("Directory '%s' %s" % (result_path, "already exists" if _result_dir_existed else "created successfully"))

Directory 'npz/cmnist/' created successfully


# Load data and model

In [31]:
# Load the evaluation split using the same `percent` as the checkpoint and the
# output file names (previously hardcoded to '1pct', inconsistent with `percent`).
dataset = get_dataset('cmnist',
                      data_dir=data_path,
                      dataset_split=split,
                      transform_split="valid",
                      percent=percent)

# An empty dataset would silently produce empty saliency files; stop instead.
if len(dataset) == 0:
    raise RuntimeError("Error - data not loaded from '{}' (split={}, percent={})".format(data_path, split, percent))

model = get_simple_classifier('mlp', last_hidden_layer=_last_hidden_layer)
# Inputs are never moved off the CPU in the attribution loop, so load the checkpoint
# onto the CPU; otherwise a GPU-saved checkpoint fails on a CPU-only machine.
state_dict = torch.load(model_path, map_location='cpu')
model.load_state_dict(state_dict['model'])
model = model.eval()
print(model)

MLP(
  (feature): Sequential(
    (0): Linear(in_features=2352, out_features=100, bias=True)
    (1): ReLU()
    (2): Linear(in_features=100, out_features=100, bias=True)
    (3): ReLU()
    (4): Linear(in_features=100, out_features=100, bias=True)
    (5): ReLU()
  )
  (classifier): Linear(in_features=100, out_features=10, bias=True)
)


# Extract saliency maps for each selected method
## Available methods: Integrated Gradients, SmoothGrad (squared, over Integrated Gradients), DeepLift

In [None]:
# One Captum explainer factory per supported method. Each explainer is built once
# per method and reused for every sample.
_EXPLAINER_FACTORIES = {
    'integrated_gradient': IntegratedGradients,
    'smoothgrad': lambda net: NoiseTunnel(IntegratedGradients(net)),
    'deeplift': DeepLift,
}


def _compute_saliency_map(net, method, explainer, x, target):
    """Attribute the single-sample batch ``x`` to class ``target`` and collapse channels.

    Args:
        net: model being explained (its gradients are reset by attribute_image_features).
        method: one of the keys of ``_EXPLAINER_FACTORIES``.
        explainer: Captum attribution object built for ``method``.
        x: batch containing one sample; every method uses ``x * 0`` as baseline.
        target: integer class index to explain.

    Returns:
        2-D numpy array: the attribution of the sample summed over its first
        (channel) axis.

    Raises:
        ValueError: if ``method`` is not supported.
    """
    if method == 'integrated_gradient':
        attr, _delta = attribute_image_features(net, explainer, x, target,
                                                baselines=x * 0, return_convergence_delta=True)
    elif method == 'smoothgrad':
        # Integrated Gradients averaged over noisy copies (squared SmoothGrad).
        attr = attribute_image_features(net, explainer, x, target, baselines=x * 0,
                                        nt_type='smoothgrad_sq', nt_samples=100, stdevs=0.2)
    elif method == 'deeplift':
        attr = attribute_image_features(net, explainer, x, target, baselines=x * 0)
    else:
        raise ValueError("Unknown saliency method: {!r}".format(method))
    # Drop the batch axis, move channels last, then sum them into a single map.
    return np.transpose(attr.squeeze(0).cpu().detach().numpy(), (1, 2, 0)).sum(2)


# Fail before the (long) loop rather than after processing earlier methods.
_unknown_methods = [m for m in Saliency_methods if m not in _EXPLAINER_FACTORIES]
if _unknown_methods:
    raise ValueError("Unknown saliency method(s): {}".format(_unknown_methods))

for method in Saliency_methods:
    print(method)
    explainer = _EXPLAINER_FACTORIES[method](model)

    # Maps are split by whether the bias label agrees with the target label.
    align_saliencies_maps = []
    conflict_saliencies_maps = []

    for idx_img in tqdm(range(len(dataset)), desc=method):
        image, label, _ = dataset[idx_img]
        l_target, l_bias = label
        l_target = l_target.item()
        l_bias = l_bias.item()

        x = torch.unsqueeze(image, 0)  # add the batch dimension
        x.requires_grad = True

        saliency_map = _compute_saliency_map(model, method, explainer, x, l_target)

        if l_bias == l_target:
            align_saliencies_maps.append(saliency_map)
        else:
            conflict_saliencies_maps.append(saliency_map)

    # np.savez stacks each list into one array stored under the default key 'arr_0'.
    print("align_data {} samples for method {}".format(len(align_saliencies_maps), method))
    print("conflict_data {} samples for method {}".format(len(conflict_saliencies_maps), method))
    print("saving saliencies to npz for method {}".format(method))
    # NOTE(review): the name puts `split` before the hidden size ("..._mlptest_100_...");
    # kept unchanged so existing loaders still find the files.
    npz_name_allign = 'align_' + method + "_mlp{}_{}_cmnist_{}".format(split, _last_hidden_layer, percent)
    np.savez(path.join(result_path, npz_name_allign), align_saliencies_maps)
    npz_name_conflict = 'conflict_' + method + "_mlp{}_{}_cmnist_{}".format(split, _last_hidden_layer, percent)
    np.savez(path.join(result_path, npz_name_conflict), conflict_saliencies_maps)

smoothgrad


 23%|██▎       | 2307/10000 [10:12<33:37,  3.81it/s]