In [1]:
from IPython.display import clear_output
# Wipe whatever output this cell has rendered so far, so the saved notebook
# starts with an empty first cell.
clear_output()

In [1]:
import pathlib
import random
import copy
import gc
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

import torch
import torch.nn as nn



import quantus

# Apply seaborn's default plotting theme to all matplotlib figures.
sns.set()

# Enable GPU: use the first CUDA device when available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

import os

# Working directory that the relative paths used below (checkpoint, `src`
# package) are resolved against. The PHYSIOEX_ICML_DIR environment variable
# overrides it so the notebook can run on other machines; when it is unset the
# original location is used unchanged.
PROJECT_DIR = os.environ.get(
    "PHYSIOEX_ICML_DIR",
    "/home/guido/github/physioex-private/papers/icml/",
)
os.chdir( PROJECT_DIR )
print( "Current working directory: ", os.getcwd() )


  from .autonotebook import tqdm as notebook_tqdm


Current working directory:  /home/guido/github/physioex-private/papers/icml


In [5]:
# load dataset and model
from src.model import TimeModule
from collections import OrderedDict

# Restore the trained network in eval mode, then append a softmax over dim 1
# so that `fn` outputs per-class probabilities instead of raw scores.
fn = TimeModule.load_from_checkpoint( "synt_model/checkpoint/epoch=4-step=7628-val_acc=0.96.ckpt" ).eval()
fn = nn.Sequential( OrderedDict( [ ( "fn", fn ), ("softmax", nn.Softmax(dim=1)) ] ) ).to(device)

from src.synt_data import SyntDataset, CLASS_DESC
data = SyntDataset(CLASS_DESC)

from torch.utils.data import DataLoader

# Single source of truth for the batch size. The loader previously hard-coded
# 128 while an unused `batch_size = 32` sat next to it; 128 is kept because it
# is the value the loader actually used.
batch_size = 128

loader = DataLoader(data, batch_size=batch_size, num_workers=8)

# Only the first batch is explained and scored in the cells below.
x_batch, y_batch = next(iter(loader))

torch.Size([70000, 1000]) torch.Size([70000])


In [17]:
def explainer_wrapper(**kwargs):
    """Dispatch to the explainer selected by ``kwargs["method"]``.

    The remaining keywords (``model``, ``inputs``, ``targets`` and anything
    else the caller supplies) are forwarded to the selected explainer. Keywords
    the explainer does not declare are dropped unless it takes ``**kwargs``;
    without this, the bookkeeping key ``method`` would raise ``TypeError`` in
    every explainer that has a fixed signature.

    Raises:
        KeyError: if ``method`` is not supplied.
        ValueError: if ``method`` does not name a known explainer.
    """
    import inspect

    method = kwargs["method"]

    # Resolve lazily so that only the requested explainer has to be defined.
    if method == "Saliency":
        explainer = saliency_explainer
    elif method == "IntegratedGradients":
        explainer = intgrad_explainer
    elif method == "InputXGrad":
        explainer = inputxgrad_explainer
    elif method == "SpectralGradients":
        explainer = spectralgrads_explainer
    elif method == "WSpectralGradients":
        explainer = wspectralgrads_explainer
    else:
        raise ValueError("Pick an explaination function that exists.")

    declared = inspect.signature(explainer).parameters
    accepts_extra = any(
        param.kind is inspect.Parameter.VAR_KEYWORD for param in declared.values()
    )
    if not accepts_extra:
        # Keep only the keywords this explainer can actually receive.
        kwargs = {key: value for key, value in kwargs.items() if key in declared}

    return explainer(**kwargs)


def saliency_explainer( model, inputs, targets, **kwargs ):
    """Attribute ``model``'s prediction for the target class with ``Saliency``.

    Args:
        model: callable passed to ``Saliency`` as ``f``.
        inputs: batch to explain; a numpy array or a torch tensor.
        targets: one class index per sample; array-like or tensor.
        **kwargs: ignored; accepted so extra keywords from the caller
            (e.g. ``method``) do not raise.

    Returns:
        numpy array with the attribution of each sample's target class.
    """
    from src.Saliency import Saliency

    gc.collect()
    torch.cuda.empty_cache()

    # Work on a private float32 copy on `device`. detach().clone() is the
    # recommended way to copy an existing tensor; torch.tensor(tensor) does the
    # same but emits a UserWarning on every call.
    inputs = torch.as_tensor( inputs ).detach().clone().to( device = device, dtype = torch.float32 )
    saliency = Saliency( f = model ).to(device)

    explanations = saliency( inputs )
    # explanations is expected to be (batch_size, num_classes, num_features);
    # pick, for every sample, the row of its own target class.
    sample_idx = torch.arange( explanations.shape[0], device = explanations.device )
    target_idx = torch.as_tensor( targets, dtype = torch.long, device = explanations.device )
    explanations = explanations[ sample_idx, target_idx ].detach().cpu()

    return explanations.numpy()

def inputxgrad_explainer( model, inputs, targets, **kwargs ):
    """Attribute ``model``'s prediction for the target class with ``InputXGradient``.

    Args:
        model: callable passed to ``InputXGradient`` as ``f``.
        inputs: batch to explain; a numpy array or a torch tensor.
        targets: one class index per sample; array-like or tensor.
        **kwargs: ignored; accepted so extra keywords forwarded by
            ``explainer_wrapper`` / Quantus (e.g. ``method``) do not raise.

    Returns:
        numpy array with the attribution of each sample's target class
        (numpy, like ``saliency_explainer``, because Quantus expects arrays).
    """
    from src.Saliency import InputXGradient

    gc.collect()
    torch.cuda.empty_cache()

    # Quantus hands over numpy arrays; tensors already on `device` in float32
    # are passed through without a copy.
    inputs = torch.as_tensor( inputs, dtype = torch.float32, device = device )
    inputxgrad = InputXGradient( f = model ).to(device)

    explanations = inputxgrad( inputs )
    # explanations is expected to be (batch_size, num_classes, num_features);
    # pick, for every sample, the row of its own target class.
    sample_idx = torch.arange( explanations.shape[0], device = explanations.device )
    target_idx = torch.as_tensor( targets, dtype = torch.long, device = explanations.device )
    explanations = explanations[ sample_idx, target_idx ].detach().cpu()

    return explanations.numpy()

def intgrad_explainer( model, inputs, targets, baseline : float = .0, n_points : int = 100, **kwargs ):
    """Attribute ``model``'s prediction for the target class with ``IntegratedGradients``.

    Args:
        model: callable passed to ``IntegratedGradients`` as ``f``.
        inputs: batch to explain; a numpy array or a torch tensor.
        targets: one class index per sample; array-like or tensor.
        baseline: forwarded to ``IntegratedGradients`` as ``baseline``.
        n_points: forwarded to ``IntegratedGradients`` as ``n_points``.
        **kwargs: ignored; accepted so extra keywords forwarded by
            ``explainer_wrapper`` / Quantus (e.g. ``method``) do not raise.

    Returns:
        numpy array with the attribution of each sample's target class
        (numpy, like ``saliency_explainer``, because Quantus expects arrays).
    """
    from src.IntegratedGradients import IntegratedGradients

    gc.collect()
    torch.cuda.empty_cache()

    # Quantus hands over numpy arrays; tensors already on `device` in float32
    # are passed through without a copy.
    inputs = torch.as_tensor( inputs, dtype = torch.float32, device = device )
    intgrad = IntegratedGradients( f = model, baseline = baseline, n_points = n_points ).to(device)

    explanations = intgrad( inputs )
    # explanations is expected to be (batch_size, num_classes, num_features);
    # pick, for every sample, the row of its own target class.
    sample_idx = torch.arange( explanations.shape[0], device = explanations.device )
    target_idx = torch.as_tensor( targets, dtype = torch.long, device = explanations.device )
    explanations = explanations[ sample_idx, target_idx ].detach().cpu()

    return explanations.numpy()
    
def spectralgrads_explainer( model, inputs, targets, fs : int = 100,  Q = 5, nperseg = 200, noverlap = 100, **kwargs ):
    """Attribute ``model``'s prediction with ``SpectralGradients``, summed over frequencies.

    Args:
        model: callable passed to ``SpectralGradients`` as ``f``.
        inputs: batch to explain; a numpy array or a torch tensor.
        targets: one class index per sample; array-like or tensor.
        fs, Q, nperseg, noverlap: forwarded unchanged to ``SpectralGradients``.
        **kwargs: ignored; accepted so extra keywords forwarded by
            ``explainer_wrapper`` / Quantus (e.g. ``method``) do not raise.

    Returns:
        numpy array with the attribution of each sample's target class
        (numpy, like ``saliency_explainer``, because Quantus expects arrays).
    """
    from src.SpectralGradients import SpectralGradients

    gc.collect()
    torch.cuda.empty_cache()

    # Quantus hands over numpy arrays; tensors already on `device` in float32
    # are passed through without a copy.
    inputs = torch.as_tensor( inputs, dtype = torch.float32, device = device )
    spectralgrads = SpectralGradients( f = model, fs = fs, Q = Q, nperseg = nperseg, noverlap = noverlap ).to(device)

    explanations = spectralgrads( inputs )

    # explanations is expected to be (batch_size, num_classes, frequencies,
    # num_features); collapse the frequency axis first.
    explanations = explanations.sum(dim=2)

    # Then pick, for every sample, the row of its own target class.
    sample_idx = torch.arange( explanations.shape[0], device = explanations.device )
    target_idx = torch.as_tensor( targets, dtype = torch.long, device = explanations.device )
    explanations = explanations[ sample_idx, target_idx ].detach().cpu()

    return explanations.numpy()

def wspectralgrads_explainer( model, inputs, targets, fs : int = 100,  Q = 5, nperseg = 200, noverlap = 100, **kwargs ):
    """Frequency-weighted variant of the ``SpectralGradients`` attribution.

    Each frequency slice is scaled by the absolute value of its own total
    attribution (summed over the last axis) before the slices are added up.

    Args:
        model: callable passed to ``SpectralGradients`` as ``f``.
        inputs: batch to explain; a numpy array or a torch tensor.
        targets: one class index per sample; array-like or tensor.
        fs, Q, nperseg, noverlap: forwarded unchanged to ``SpectralGradients``.
        **kwargs: ignored; accepted so extra keywords forwarded by
            ``explainer_wrapper`` / Quantus (e.g. ``method``) do not raise.

    Returns:
        numpy array with the attribution of each sample's target class
        (numpy, like ``saliency_explainer``, because Quantus expects arrays).
    """
    from src.SpectralGradients import SpectralGradients

    gc.collect()
    torch.cuda.empty_cache()

    # Quantus hands over numpy arrays; tensors already on `device` in float32
    # are passed through without a copy.
    inputs = torch.as_tensor( inputs, dtype = torch.float32, device = device )
    spectralgrads = SpectralGradients( f = model, fs = fs, Q = Q, nperseg = nperseg, noverlap = noverlap ).to(device)

    explanations = spectralgrads( inputs )

    # explanations is expected to be (batch_size, num_classes, frequencies,
    # num_features). One non-negative weight per (sample, class, frequency).
    weights = torch.abs( explanations.sum(dim = -1) )

    # Scale every frequency slice by its weight, then collapse the frequency axis.
    explanations = explanations * weights.unsqueeze(-1)
    explanations = explanations.sum(dim=2)

    # Pick, for every sample, the row of its own target class.
    sample_idx = torch.arange( explanations.shape[0], device = explanations.device )
    target_idx = torch.as_tensor( targets, dtype = torch.long, device = explanations.device )
    explanations = explanations[ sample_idx, target_idx ].detach().cpu()

    return explanations.numpy()
    

In [14]:

def _explain_first_batch( banner, explainer ):
    """Print ``banner``, then run ``explainer`` on the first batch with its labels on ``device``."""
    print( banner )
    return explainer( fn, x_batch.to(device), y_batch.to(device) )


print( "Explaining the model with different methods." )

# Same execution order as the progress banners: one attribution array per method.
a_batch_saliency = _explain_first_batch( " Saliency...", saliency_explainer )
a_batch_inputxgrad = _explain_first_batch( " InputXGrad...", inputxgrad_explainer )
a_batch_intgrad = _explain_first_batch( " Integrated Gradients...", intgrad_explainer )
a_batch_spectralgrads = _explain_first_batch( " Spectral Gradients...", spectralgrads_explainer )
a_batch_wspectralgrads = _explain_first_batch( " WSpectral Gradients...", wspectralgrads_explainer )

# Keyed by the method names that `explainer_wrapper` understands.
explanations = dict(
    Saliency = a_batch_saliency,
    IntegratedGradients = a_batch_intgrad,
    InputXGrad = a_batch_inputxgrad,
    SpectralGradients = a_batch_spectralgrads,
    WSpectralGradients = a_batch_wspectralgrads,
)


Explaining the model with different methods.
 Saliency...
 InputXGrad...
 Integrated Gradients...


KeyboardInterrupt: 

In [22]:
# Score every explanation method with Quantus' MaxSensitivity metric on the
# first batch. `scores` is filled one method at a time, so an interrupted run
# still keeps the methods that finished.
metric = quantus.MaxSensitivity()

method_names = (
    "Saliency",
    "IntegratedGradients",
    "InputXGrad",
    "SpectralGradients",
    "WSpectralGradients",
)

scores = {}

for method in method_names:
    # Fresh kwargs dict per method; `method` selects the explainer inside
    # `explainer_wrapper`.
    explain_kwargs = {"method": method}
    score = metric(
        model=fn,
        x_batch=x_batch.numpy(),
        y_batch=y_batch.numpy(),
        device=device,
        explain_func=explainer_wrapper,
        explain_func_kwargs=explain_kwargs,
    )

    scores[method] = np.array(score)

KeyboardInterrupt: 

In [21]:
# Number of entries currently held in `scores`.
# NOTE(review): this cell's execution count (21) is lower than the scoring
# cell's (22), so the recorded output presumably predates the current
# definition of `scores` — re-run after scoring to confirm.
len(scores)

128