In [1]:
import innvestigate
import pickle
import numpy as np
import matplotlib as plt
import pandas as pd

from tqdm import tqdm
from utils import plot
from pathlib import Path
from scipy.special import softmax
from torchvision import transforms
from keras.models import load_model
from keras.preprocessing.image import load_img, img_to_array

Using TensorFlow backend.
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


# Load models

In [2]:
# Root directory holding the models, the scrambled-image CSV and all outputs.
folder = Path("data")

In [3]:
resnet = load_model(Path(folder, "resnet.h5"))  # Keras HDF5 model file

Instructions for updating:
Colocations handled automatically by placer.




In [4]:
# NOTE(review): optimizer/loss look like placeholders - the model is only used for predict/LRP below; confirm compile is needed at all.
resnet.compile(loss="categorical_crossentropy", optimizer="adam")

In [5]:
stylenet = load_model(Path(folder, "stylenet.h5"))  # Keras HDF5 model file

In [6]:
# NOTE(review): same placeholder compile settings as for resnet; only predict/LRP are used later.
stylenet.compile(loss="categorical_crossentropy", optimizer="adam")

Transform the images in the same way the networks saw them during training.

In [7]:
# Resize the shorter side to 256 px, centre-crop to 224x224, convert to a
# (C, H, W) float tensor and normalise with the standard ImageNet mean/std.
train_transforms = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

#### Load the paths of the images we created earlier

In [9]:
data = pd.read_csv(folder / "imagenette2_scr/scramble.csv")
# The first CSV column holds the saved row index; promote it to the frame index.
data = data.set_index(data.columns[0])
data.head()

Unnamed: 0_level_0,class,path,scrambled_indices
Unnamed: 0,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
0,n01440764,/home/malte/Dokumente/Masterarbeit/data/imagen...,"{0: array([2, 3, 1, 0]), 1: array([0, 3, 1, 2]..."
1,n01440764,/home/malte/Dokumente/Masterarbeit/data/imagen...,"{0: array([1, 3, 0, 2]), 1: array([1, 3, 0, 2]..."
2,n01440764,/home/malte/Dokumente/Masterarbeit/data/imagen...,"{0: array([0, 1, 2, 3]), 1: array([1, 3, 0, 2]..."
3,n02102040,/home/malte/Dokumente/Masterarbeit/data/imagen...,"{0: array([2, 0, 1, 3]), 1: array([0, 2, 3, 1]..."
4,n02102040,/home/malte/Dokumente/Masterarbeit/data/imagen...,"{0: array([1, 3, 2, 0]), 1: array([2, 3, 0, 1]..."


#### Define a few simple helper functions

In [10]:
def process_img(img_path, transform=None):
    """Load an image from disk and preprocess it for the networks.

    Parameters
    ----------
    img_path : str or Path
        Path of the image file to load.
    transform : callable, optional
        Preprocessing applied to the loaded image. Defaults to the
        notebook-level ``train_transforms``. Assumed to return a
        channels-first (C, H, W) tensor/array, as ``transforms.ToTensor`` does.

    Returns
    -------
    img_arr : np.ndarray
        Transformed image with a leading batch axis, shape (1, C, H, W).
    img_plt : np.ndarray
        The same image without the batch axis and with channels moved last,
        shape (H, W, C), for plotting.
    """
    if transform is None:
        transform = train_transforms
    img = load_img(img_path)
    img_arr = np.array(transform(img)[np.newaxis, :])
    # Index the batch axis explicitly instead of squeeze(), which would also
    # drop any other size-1 axis (e.g. a single channel) and break moveaxis.
    img_plt = np.moveaxis(img_arr[0], 0, 2)

    return img_arr, img_plt

In [11]:
def evaluate_model(model, img_arr):
    """Run ``model`` on a preprocessed image batch and report the top class.

    Parameters
    ----------
    model
        Object with a Keras-style ``predict`` method.
    img_arr : np.ndarray
        Input batch. Only the first sample is scored; ``process_img``
        produces a batch of exactly one.

    Returns
    -------
    output : np.ndarray
        Raw output of ``model.predict``.
    output_class : np.integer
        Index of the highest-scoring class of the first sample.
    output_smax : np.floating
        Softmax probability of ``output_class``, computed over the class axis.
    """
    output = model.predict(img_arr)
    # NOTE(review): assumes predict returns unnormalised scores; if the model
    # already ends in a softmax layer this normalises twice - confirm.
    # Normalise per sample (class axis), not across the whole batch.
    probs = softmax(output, axis=-1)
    # argmax over the first sample only; a bare np.argmax(output) works on the
    # flattened array and yields out-of-range indices for batches > 1.
    output_class = np.argmax(output[0])
    output_smax = probs[0][output_class]

    return output, output_class, output_smax

In [12]:
def innv_lrp(model, img_arr, output_class):
    """Compute an LRP relevance map of ``model`` for one output neuron.

    Uses innvestigate's ``lrp.sequential_preset_a_flat`` analyzer in
    ``neuron_selection_mode="index"``, so the explained neuron is selected
    per call via ``neuron_selection``.

    Parameters
    ----------
    model : keras Model
        Network to explain.
    img_arr : np.ndarray
        Input batch of one image, channels-first (1, C, H, W) as produced by
        ``process_img``.
    output_class : int
        Index of the output neuron to explain.

    Returns
    -------
    np.ndarray
        Relevance map with the batch axis removed and channels moved last,
        shape (H, W, C).

    Notes
    -----
    Creating an analyzer builds a new Keras/TF graph. Doing that once per
    image is loop-invariant work; it likely contributes to the growing s/it
    in the logged runs. The analyzer is therefore cached on this function and
    reused while it is called with the same model object. The cache holds a
    reference to that model until a different model is passed in, or until
    ``innv_lrp._cached_analyzer = None`` is set.
    """
    cached = getattr(innv_lrp, "_cached_analyzer", None)
    if cached is not None and cached[0] is model:
        analyzer = cached[1]
    else:
        analyzer = innvestigate.create_analyzer("lrp.sequential_preset_a_flat",
                                                model, allow_lambda_layers=True, neuron_selection_mode="index")
        innv_lrp._cached_analyzer = (model, analyzer)
    a = analyzer.analyze(img_arr, neuron_selection=output_class)
    # (1, C, H, W) -> (C, H, W) -> (H, W, C)
    a_r = np.moveaxis(a.squeeze(), 0, 2)

    return a_r

In [13]:
def get_info(a_r, channel_axis=None):
    """Summarise an LRP relevance map by its mean and above-average pixels.

    The colour channels are summed into a 2-D map ``a_m``. From it, the
    overall mean, the mean of the positive values and the mean of the
    negative values are computed. The function also returns the coordinates
    of pixels more positive than the positive mean, and of pixels more
    negative than the negative mean.

    Parameters
    ----------
    a_r : np.ndarray
        Relevance map of one image with a colour-channel axis of size 3.
        ``innv_lrp`` returns it as (H, W, C).
    channel_axis : int, optional
        Axis to sum over. If None, the first axis of size 3 is used.

    Returns
    -------
    g_mean : float
        Mean of ``a_m``.
    pos_mean : float
        Mean of the positive entries of ``a_m`` (NaN if there are none).
    neg_mean : float
        Mean of the negative entries of ``a_m`` (NaN if there are none).
    pos_ind : np.ndarray
        ``np.argwhere`` coordinates with ``a_m > pos_mean``.
    neg_ind : np.ndarray
        ``np.argwhere`` coordinates with ``a_m < neg_mean``.

    Raises
    ------
    ValueError
        If ``channel_axis`` is None and no axis of ``a_r`` has size 3.
        Previously this silently summed over axis 0.
    """
    if channel_axis is None:
        size3_axes = [axis for axis, size in enumerate(a_r.shape) if size == 3]
        if not size3_axes:
            raise ValueError(
                f"expected an axis of size 3 (colour channels), got shape {a_r.shape}")
        channel_axis = size3_axes[0]
    a_m = a_r.sum(axis=channel_axis)
    # general mean
    g_mean = np.mean(a_m)
    # Means of the positive / negative values. Use NaN when a sign is absent,
    # instead of nanmean's "Mean of empty slice" RuntimeWarning.
    pos_vals = a_m[a_m > 0]
    neg_vals = a_m[a_m < 0]
    pos_mean = pos_vals.mean() if pos_vals.size else np.nan
    neg_mean = neg_vals.mean() if neg_vals.size else np.nan

    # Comparisons against NaN are all False, so these are empty when a sign is absent.
    # index of pixel values greater than positive mean
    pos_ind = np.argwhere(a_m > pos_mean)
    # index of pixel values smaller than negative mean
    neg_ind = np.argwhere(a_m < neg_mean)

    return g_mean, pos_mean, neg_mean, pos_ind, neg_ind

In [14]:
def analysis(model, data, m_name, save_fig=True, print_out=True, out_dir=None, close_fig=True):
    """Classify and LRP-explain every image listed in ``data``.

    For each row the image is loaded (``process_img``) and classified
    (``evaluate_model``). It is then explained with LRP for the predicted
    class (``innv_lrp``) and summarised (``get_info``).

    Parameters
    ----------
    model : keras Model
        Network to evaluate and explain.
    data : pandas.DataFrame
        One row per image. The row index is used as the result key, and the
        second column must hold the image path (as in ``scramble.csv``).
    m_name : str
        Model name, used as the prefix of saved heatmap file names and as the
        progress-bar label.
    save_fig : bool
        If True, render each heatmap and save it as
        ``<m_name>_<image stem>_lrp.png``.
    print_out : bool
        If True, log the predicted class and its softmax score per image.
    out_dir : str or Path, optional
        Directory for saved heatmaps. Defaults to the notebook-level ``folder_plt``.
    close_fig : bool
        If True, close each saved heatmap figure so figures do not pile up
        over the loop. Pass False to keep them open for inline display.

    Returns
    -------
    dict
        Maps each row index to a dict with keys ``g_mean``, ``pos_mean``,
        ``neg_mean``, ``pos_ind``, ``neg_ind``, ``output_class``,
        ``softmax`` and ``output``.
    """
    if save_fig:
        import matplotlib.pyplot as pyplot  # only needed to close saved figures
        if out_dir is None:
            out_dir = folder_plt
    data_dict = {}

    # Column 1 (positional) is the image path; the columns are class, path, scrambled_indices.
    rows = zip(data.index, data.iloc[:, 1])
    # total= lets tqdm show percentage and ETA instead of a bare "it" counter.
    for ind, img_path in tqdm(rows, total=len(data), desc=m_name):
        # load images to arrays
        img_arr, img_plt = process_img(img_path)
        # run model on image and save class and softmax
        output, output_class, output_smax = evaluate_model(model, img_arr)
        if print_out:
            # tqdm.write keeps the progress bar intact, unlike print()
            tqdm.write(f'Outputclass: {output_class}')
            tqdm.write(f'Softmax: {output_smax}')
        # run innvestigate with lrp and save plot
        a_r = innv_lrp(model, img_arr, output_class)
        if save_fig:
            hmap = plot(a_r, img_plt, dilation=.5, percentile=99, alpha=.3, cmap="coolwarm", _sum=True)
            hmap.figure.savefig(Path(out_dir) / f'{m_name}_{Path(img_path).stem}_lrp.png')
            if close_fig:
                pyplot.close(hmap.figure)
        # extract coordinates
        g_mean, pos_mean, neg_mean, pos_ind, neg_ind = get_info(a_r)
        # save to dictionary
        data_dict[ind] = {"g_mean": g_mean, "pos_mean": pos_mean, "neg_mean": neg_mean, "pos_ind": pos_ind, "neg_ind": neg_ind,
                          "output_class": output_class, "softmax": output_smax, "output": output}

    return data_dict

## Run LRP analysis over images

In [15]:
# Output directory for the heatmaps written by analysis(..., save_fig=True).
folder_plt = Path(folder, "imagenette2_lrp")
folder_plt.mkdir(exist_ok=True)

In [17]:
# ResNet pass: no heatmaps written, no per-image logging.
res_dict = analysis(resnet, data, "resnet", print_out=False, save_fig=False)

0it [00:00, ?it/s]

Instructions for updating:
Use tf.cast instead.


30it [27:57, 55.92s/it]


#### Save the data dictionary to a pickle file

In [18]:
# Persist the per-image ResNet LRP summaries.
with (folder / "lrp_regions_resnet.pkl").open("wb") as fh:
    pickle.dump(res_dict, fh)

In [16]:
# Drop the notebook's reference to the ResNet model before the StyleNet pass.
# NOTE(review): this only removes the Python name; whether the model's memory
# is actually released depends on other references (e.g. the Keras/TF session
# or cached analyzers) - confirm if freeing memory is the goal.
# NOTE(review): this cell's execution count (16) precedes the one above (18);
# check that Restart & Run All still reproduces the saved pickle.
del resnet

In [17]:
# StyleNet pass: no heatmaps written, per-image class/softmax logging on.
stl_dict = analysis(stylenet, data, "stylenet", save_fig=False, print_out=True)

0it [00:00, ?it/s]


Outputclass: 0

Softmax: 0.9501403570175171
Instructions for updating:
Use tf.cast instead.


1it [00:23, 23.28s/it]


Outputclass: 0

Softmax: 0.9995155930519104


2it [00:42, 22.20s/it]


Outputclass: 0

Softmax: 0.273191899061203


3it [01:03, 21.59s/it]


Outputclass: 217

Softmax: 0.9028735160827637


4it [01:25, 21.68s/it]


Outputclass: 217

Softmax: 0.6586723327636719


5it [01:47, 21.92s/it]


Outputclass: 217

Softmax: 0.9552500247955322


6it [02:12, 22.74s/it]


Outputclass: 632

Softmax: 0.6883887052536011


7it [02:37, 23.57s/it]


Outputclass: 485

Softmax: 0.8192846775054932


8it [03:05, 24.86s/it]


Outputclass: 482

Softmax: 0.4470788538455963


9it [03:34, 26.03s/it]


Outputclass: 491

Softmax: 0.9561533331871033


10it [04:05, 27.47s/it]


Outputclass: 491

Softmax: 0.7333689332008362


11it [04:38, 29.17s/it]


Outputclass: 882

Softmax: 0.8825289011001587


12it [05:13, 30.95s/it]


Outputclass: 497

Softmax: 0.7476836442947388


13it [05:49, 32.52s/it]


Outputclass: 442

Softmax: 0.21035031974315643


14it [06:29, 34.60s/it]


Outputclass: 406

Softmax: 0.4893115162849426


15it [07:09, 36.25s/it]


Outputclass: 566

Softmax: 0.9578974843025208


16it [07:56, 39.66s/it]


Outputclass: 513

Softmax: 0.518923282623291


17it [08:52, 44.62s/it]


Outputclass: 566

Softmax: 0.8354141116142273


18it [10:02, 52.04s/it]


Outputclass: 569

Softmax: 0.8913399577140808


19it [11:07, 56.05s/it]


Outputclass: 569

Softmax: 0.9990238547325134


20it [12:11, 58.48s/it]


Outputclass: 569

Softmax: 0.5900624394416809


21it [13:34, 65.75s/it]


Outputclass: 571

Softmax: 0.9908638596534729


22it [15:00, 71.71s/it]


Outputclass: 571

Softmax: 0.9972628951072693


23it [16:07, 70.45s/it]


Outputclass: 919

Softmax: 0.1431838423013687


24it [17:37, 76.35s/it]


Outputclass: 574

Softmax: 0.921381950378418


25it [19:19, 84.09s/it]


Outputclass: 574

Softmax: 1.0


26it [21:05, 90.65s/it]


Outputclass: 574

Softmax: 0.9999236464500427


27it [23:02, 98.57s/it]


Outputclass: 701

Softmax: 0.9997501373291016


28it [25:22, 110.95s/it]


Outputclass: 701

Softmax: 0.9995442032814026


29it [27:39, 118.80s/it]


Outputclass: 701

Softmax: 0.9856162667274475


30it [30:02, 60.09s/it] 


#### Save the data dictionary to a pickle file

In [18]:
# Persist the per-image StyleNet LRP summaries.
with (folder / "lrp_regions_stylenet.pkl").open("wb") as fh:
    pickle.dump(stl_dict, fh)