In [1]:
from exp.utils import *
from exp.models import *
from exp.losses import *
from tqdm.notebook import tqdm
from multiprocessing import Pool
from PIL import Image

import torch
import torch.nn as NN
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms

from matplotlib.colors import LinearSegmentedColormap
from captum.attr import visualization as viz
from captum.attr import IntegratedGradients
from captum.attr import GradientShap
from captum.attr import DeepLift
from captum.attr import Occlusion
from captum.attr import NoiseTunnel
from captum.attr import visualization as viz
from captum.attr import Saliency

from ipywidgets import interact

In [2]:
# Findings for which one binary model each is inspected below.
picked_labels = [
    "Atelectasis",
    "Cardiomegaly",
    "Pneumonia",
]

In [3]:
# Prefix of the saved model names; the per-label name is built later as
# f"{architecure}_{label}" and handed to load_model / save_insights.
# NOTE(review): the identifier is misspelled ("architecure"); it is kept as-is
# because a later cell refers to it by this exact name.
architecure = "MNASNET1_v2"

In [4]:
def _collect_insights(model, loader, device, nt_samples, n_noises=1, stdevs=0.2):
    """Collect predictions and SmoothGrad-squared Integrated-Gradients maps.

    For every batch of ``loader`` the model's sigmoid output is recorded and
    ``NoiseTunnel(IntegratedGradients(model))`` is evaluated ``n_noises``
    times against an all-zero baseline; the runs are averaged.

    Parameters
    ----------
    model : torch module, expected to be in eval mode and already on ``device``.
    loader : iterable yielding ``(X, y)`` batches; ``X`` must be 4-D.
    device : device the batches are moved to.
    nt_samples : number of noisy samples per ``NoiseTunnel.attribute`` call.
    n_noises : number of independent attribution runs that are averaged.
    stdevs : standard deviation of the noise passed to ``NoiseTunnel``.

    Returns
    -------
    (images, noise_tunnels, probs, truths) : tuple of np.ndarray
        One entry per example, in loader order.  ``images`` and
        ``noise_tunnels`` have axis 1 of each input moved to the last
        position (presumably channels-first -> channels-last for plotting;
        confirm against the dataset).  All four arrays are empty when the
        loader yields nothing.
    """
    sigmoid = NN.Sigmoid()
    nt = NoiseTunnel(IntegratedGradients(model))
    images, noise_tunnels, probs, truths = [], [], [], []

    for X, y in tqdm(loader):
        X, y = X.to(device), y.to(device)

        # Plain forward pass for the probability; no autograd graph needed.
        with torch.no_grad():
            probs.extend(sigmoid(model(X)).cpu().numpy())
        truths.extend(y.cpu().detach().numpy())

        runs = []
        for _ in range(n_noises):
            model.zero_grad()
            runs.append(nt.attribute(X,
                                     baselines=X * 0,
                                     nt_type='smoothgrad_sq',
                                     nt_samples=nt_samples, stdevs=stdevs))

        # Average over the repeated runs only.  The batch axis is kept and
        # every channel is preserved (the previous positive-example code
        # squeezed the batch axis and then indexed [0], which silently kept
        # only the first channel and broadcast it over the others).
        attr_ig_nt = torch.stack(runs).mean(dim=0).cpu().detach().numpy()

        noise_tunnels.extend(np.einsum("bcwh -> bwhc", attr_ig_nt))
        images.extend(np.einsum("bcwh -> bwhc", X.cpu().detach().numpy()))

    # Convert once, after the loop, instead of on every batch.
    return (np.array(images), np.array(noise_tunnels),
            np.array(probs), np.array(truths))


# ---- configuration shared by all labels ------------------------------------
seed = 92
n_noises = 1            # attribution runs averaged per example
insight_size = 10       # examples kept per class (positive / negative)
image_size = (224, 224)
bs = 1
# NOTE(review): positives and negatives historically used different sample
# counts (3 vs 12).  Kept as-is so saved insights stay comparable with earlier
# runs; unify them if the asymmetry was not intentional.
pos_nt_samples = 3
neg_nt_samples = 12

for label in picked_labels:
    print(f"Computing insights for '{label}'")

    model_name = f"{architecure}_{label}"
    labels = get_labels()
    device = get_cpu()

    model = load_model(model_name)
    model.eval()
    model = model.to(device)

    # Re-seed per label so each label's run is reproducible on its own.
    seed_everything(seed=seed)

    _, _, test_df = get_dataframes(include_labels=labels, small=False)
    test_df = get_binary_df(label, test_df)

    _, test_tfs = get_transforms(image_size=image_size)

    # First `insight_size` positive / negative rows of the test split.
    pos_df = test_df[test_df[label] > 0.5].iloc[:insight_size, :]
    neg_df = test_df[test_df[label] < 0.5].iloc[:insight_size, :]

    pos_ds = CRX8_Data(pos_df, get_image_path(), label, image_size=image_size, transforms=test_tfs)
    neg_ds = CRX8_Data(neg_df, get_image_path(), label, image_size=image_size, transforms=test_tfs)

    pos_dl = DataLoader(pos_ds, batch_size=bs, shuffle=False)
    neg_dl = DataLoader(neg_ds, batch_size=bs, shuffle=False)

    # Results are bound at notebook scope on purpose: the interactive
    # visualisation cell further down reads pos_images / pos_noise_tunnels
    # (of the last processed label).
    pos_images, pos_noise_tunnels, pos_probs, pos_truths = _collect_insights(
        model, pos_dl, device, nt_samples=pos_nt_samples, n_noises=n_noises)
    neg_images, neg_noise_tunnels, neg_probs, neg_truths = _collect_insights(
        model, neg_dl, device, nt_samples=neg_nt_samples, n_noises=n_noises)

    save_insights(pos_images, neg_images,
                  pos_noise_tunnels, neg_noise_tunnels,
                  pos_probs, neg_probs,
                  pos_truths, neg_truths,
                  model_name)
    print()

FERTIG()

Computing insights for 'Atelectasis'
Using the CPU!


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=10.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=10.0), HTML(value='')))


Insights saved to '/home/favi/work/crx8/insights/MNASNET1_v2_Atelectasis'

Computing insights for 'Cardiomegaly'
Using the CPU!


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=10.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=10.0), HTML(value='')))


Insights saved to '/home/favi/work/crx8/insights/MNASNET1_v2_Cardiomegaly'

Computing insights for 'Pneumonia'
Using the CPU!


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=10.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=10.0), HTML(value='')))


Insights saved to '/home/favi/work/crx8/insights/MNASNET1_v2_Pneumonia'

FERTIG! :D


In [None]:
# Colour map for the attribution overlay: black up to 25 % of the range,
# then a linear ramp to orange (#fc7b02) at the top.
_cmap_stops = [
    (0, '#000000'),
    (0.25, '#000000'),
    (1, '#fc7b02'),
]
default_cmap = LinearSegmentedColormap.from_list("orange", _cmap_stops, N=256)

In [None]:


# The image slider spans every stored positive example instead of a
# hard-coded 0..4, so it can neither hide computed examples nor index past
# the end when fewer than five exist.
@interact(image=(0, max(len(pos_images) - 1, 0)), percentage=(1, 99))
def drawit(image, percentage):
    """Overlay one positive example with its attribution heat map.

    Parameters
    ----------
    image : int
        Index into ``pos_images`` / ``pos_noise_tunnels`` (notebook globals
        left behind by the insight-computation cell).
    percentage : int
        Forwarded to Captum as ``outlier_perc``.
    """
    # The return value is bound to `_` so the notebook does not echo it.
    _ = viz.visualize_image_attr(pos_noise_tunnels[image],
                                 pos_images[image],
                                 method="blended_heat_map",
                                 sign="absolute_value",
                                 outlier_perc=percentage,
                                 show_colorbar=True,
                                 cmap=default_cmap,
                                 title="Overlayed Integrated Gradients \n with SmoothGrad Squared")