# PatternAttribution


In this notebook, we analyze:

* SVD of weight $W$, pattern $A$ and $W \odot A$
* intermediate matrix chain items $S_l^{1/2} V_l U_{l+1} S_{l+1}^{1/2}$
* the ratio $\sigma_1 / \sigma_2$ of the two largest singular values

In [None]:
# select gpu: CUDA_VISIBLE_DEVICES lists the GPU ids that may be used; an empty
# value hides all GPUs, so everything runs on the CPU (set e.g. `0` to use one).
%env CUDA_VISIBLE_DEVICES=

In [None]:
%load_ext autoreload
%autoreload 2
import tensorflow
import tensorflow as tf
import warnings

import innvestigate
import matplotlib.pyplot as plt

import numpy as np
import PIL 
import json 
import copy
import contextlib

import imp
import numpy as np
import os

from skimage.measure import compare_ssim 
import pickle
from collections import OrderedDict
from IPython.display import IFrame, display

import keras
import keras.backend
import keras.models


import innvestigate
import innvestigate.applications.imagenet
import innvestigate.utils as iutils
import innvestigate.utils as iutils
import innvestigate.utils.visualizations as ivis
from innvestigate.analyzer.relevance_based.relevance_analyzer import LRP
from innvestigate.analyzer.base import AnalyzerNetworkBase, ReverseAnalyzerBase
from innvestigate.analyzer.deeptaylor import DeepTaylor
innvestigate.analyzer.analyzers
from innvestigate.analyzer import PatternAttribution

import time
import tqdm

import seaborn as sns

import itertools
import matplotlib as mpl
from tensorflow.python.client import device_lib

import deeplift

import os
import sys
from when_explanations_lie import *

In [None]:
# Show the devices TensorFlow can see, to check that the GPU selection above took effect.
device_lib.list_local_devices()

In [None]:
# Machine-specific ImageNet location: `imagenet_dir.json` maps hostnames to the
# ImageNet directory on that machine (the current hostname must be a key).
host = ! hostname
host = host[0]  # `!` captures stdout as a list of lines; take the hostname line

with open('imagenet_dir.json') as f:
    imagenet_dir = json.load(f)[host]

#imagenet_val_dir = "/home/leonsixt/tmp/imagenet/imagenet-raw/validation/"
# path to the exemplary image (relative; presumably resolved against `imagenet_dir` by ImageNetMeta — confirm)
ex_image_path = "n01534433/ILSVRC2012_val_00015410.JPEG"
# number of images to run the evaluation
#n_selected_imgs = 200
n_selected_imgs = 10

# NOTE(review): `load_weights` and `model_names` are not read by any cell in this notebook.
load_weights = True
model_names = ['resnet50', 'vgg16']

In [None]:
def load_model_and_meta(model_name, load_weights=True, clear_session=True):
    """Load a model together with its dataset metadata wrapper.

    Parameters
    ----------
    model_name : str
        ``'vgg16'`` or ``'resnet50'`` (ImageNet) or ``'cifar10'``.
    load_weights : bool
        Passed through to ``load_model``.
    clear_session : bool
        Call ``keras.backend.clear_session()`` before building the model.

    Uses the notebook-level settings ``n_selected_imgs``, ``imagenet_dir`` and
    ``ex_image_path``.

    Returns
    -------
    (model, meta)

    Raises
    ------
    ValueError
        If ``model_name`` is not supported. The check runs before the session
        is cleared, so a typo does not throw away the current Keras state.
    """
    imagenet_models = ('vgg16', 'resnet50')
    if model_name not in imagenet_models and model_name != 'cifar10':
        raise ValueError(
            "Unknown model_name {!r}; expected one of 'vgg16', 'resnet50', 'cifar10'."
            .format(model_name))

    if clear_session:
        keras.backend.clear_session()

    if model_name in imagenet_models:
        model, innv_net, _ = load_model(model_name, load_weights)
        meta = ImageNetMeta(model, model_name, innv_net, n_selected_imgs,
                            imagenet_dir, ex_image_path)
    else:
        model, _, _ = load_model('cifar10', load_weights)
        meta = CIFAR10Meta(model, n_selected_imgs)
    return model, meta
    

In [None]:
def symetric_min_max(x):
    """Return ``imshow`` keyword arguments for a colour scale symmetric around zero.

    The limits are ``-b`` and ``+b`` with ``b = max(-min(x), max(x))``, so zero maps
    to the centre of the diverging ``'seismic'`` colormap.
    """
    lowest = x.min()
    highest = x.max()
    return dict(
        vmin=min(lowest, -highest),
        vmax=max(-lowest, highest),
        cmap='seismic',
    )


In [None]:
# Keras VGG16 to analyse, plus its metadata wrapper (provides e.g. `meta.ex_image` and `meta.names`).
model, meta = load_model_and_meta('vgg16')
# Separate innvestigate VGG16; below it is only used for its precomputed patterns (`innv_net['patterns']`).
innv_net = innvestigate.applications.imagenet.vgg16(load_weights=True, load_patterns=True)

In [None]:
def to_mat(x):
    """Flatten a kernel to a 2-D ``(fan_in, cout)`` matrix.

    2-D inputs (dense kernels) are returned unchanged (same object). 4-D conv
    kernels ``(h, w, cin, cout)`` become ``(h * w * cin, cout)``.
    """
    if len(x.shape) != 2:
        kh, kw, c_in, c_out = x.shape
        x = x.reshape((kh * kw * c_in, c_out))
    return x


def view_as_conv1x1(x):
    """Reduce a kernel to a ``(cin, cout)`` matrix by keeping only its central spatial tap.

    * 4-D conv kernels ``(h, w, cin, cout)``: returns ``x[h // 2, w // 2]``.
    * 2-D kernels with 25088 (= 7 * 7 * 512) inputs: reshaped to
      ``(7, 7, 512, cout)`` and the centre tap ``[3, 3]`` is returned. Presumably
      this is VGG16's first dense layer on the flattened 7x7x512 feature map, and
      the reshape assumes channels-last flattening (TODO confirm).
    * Any other 2-D kernel is returned unchanged.
    """
    shape = x.shape
    if len(shape) != 2:
        kh, kw, _, _ = shape
        return x[kh // 2, kw // 2]

    c_in, c_out = shape
    if c_in != 25088:
        return x

    side = 7
    as_conv = x.reshape(side, side, c_in // (side * side), c_out)
    return as_conv[side // 2, side // 2, :, :]

In [None]:
# Collectors filled by the patched `_prepare_pattern` below (re-created whenever this cell runs).
patterns_x_weights = []
patterns = []
weights = []
pattern_layers = []


def _prepare_pattern(self, layer, state, pattern):
    """Recording replacement for ``PatternAttribution._prepare_pattern``.

    Picks the single weight array of ``layer`` whose shape equals ``pattern.shape``
    and returns ``pattern * kernel`` elementwise. innvestigate presumably uses that
    product in place of the kernel (confirm). As a side effect it appends the
    product, the pattern, the kernel and the layer to the module-level lists above.
    ``state`` is accepted for signature compatibility and is not used.

    Raises ``Exception`` unless exactly one weight array matches the pattern shape.
    """
    layer_weights = layer.get_weights()
    same_shape = [pattern.shape == candidate.shape for candidate in layer_weights]
    if np.sum(same_shape) != 1:
        raise Exception("Cannot match pattern to kernel.")
    kernel = layer_weights[np.argmax(same_shape)]
    product = np.multiply(pattern, kernel)

    for sink, item in ((patterns_x_weights, product),
                       (patterns, pattern),
                       (weights, kernel),
                       (pattern_layers, layer)):
        sink.append(item)
    return product


# Monkey-patch innvestigate so the analyzer calls the recording version above.
PatternAttribution._prepare_pattern = _prepare_pattern

In [None]:
# PatternAttribution analyzer for `model`, using the patterns shipped with innvestigate's VGG16.
pa = innvestigate.analyzer.create_analyzer(
    'pattern.attribution', model, patterns=innv_net['patterns']
)

In [None]:
# Building the analyzer model runs the patched `_prepare_pattern`, which appends to
# the collector lists. Empty them first so re-running this cell does not duplicate entries.
for collected in (patterns_x_weights, patterns, weights, pattern_layers):
    collected.clear()
pa.create_analyzer_model()

In [None]:
# Two 2-D views of every collected kernel:
#   *_mat : the full kernel flattened to (h*w*cin, cout) via `to_mat`
#   *_1x1 : only the central spatial tap, (cin, cout), via `view_as_conv1x1`
patterns_x_weights_mat = []
patterns_mat = []
weights_mat = []

patterns_x_weights_1x1 = []
patterns_1x1 = []
weights_1x1 = []

for pxw, pattern, weight in zip(patterns_x_weights, patterns, weights):
    patterns_x_weights_mat.append(to_mat(pxw))
    patterns_mat.append(to_mat(pattern))
    weights_mat.append(to_mat(weight))
    
    
    patterns_x_weights_1x1.append(view_as_conv1x1(pxw))
    patterns_1x1.append(view_as_conv1x1(pattern))
    weights_1x1.append(view_as_conv1x1(weight))

In [None]:
# Run the analyzer once on the example image (`hmap` is not used by the cells below).
hmap = pa.analyze(meta.ex_image)

In [None]:
# Number of kernel layers recorded by the patched `_prepare_pattern`.
len(weights)

In [None]:
# Sanity check: each layer's pattern shape equals its kernel shape (enforced in `_prepare_pattern`).
for layer, pattern, weight in zip(pattern_layers, patterns, weights):
    print(layer.name, pattern.shape, weight.shape)

In [None]:
# Singular values of the full (flattened) pattern matrices; plotted in the PatternNet figure at the end.
patterns_mat_sv = [np.linalg.svd(m, compute_uv=False) for m in patterns_mat]

## SVD of $ P \odot W$

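Let $M_l = P_l \odot W_l$, where $P_l$ is the computed pattern (written $A$ in the introduction) and $W_l$ is the weight matrix of layer $l$.

For backpropagation, PatternAttribution uses $M_l$ in place of $W_l$. Ignoring the ReLU masks, the backward signal therefore passes through the product $M_1 M_2 \cdots M_L$. The singular values of each $M_l$ show how strongly a layer stretches or collapses that signal.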

In [None]:

# Full SVD (U, s, V) of the central-tap view of every M_l = P_l ⊙ W_l.
pxw_svd = [np.linalg.svd(pw, full_matrices=True) for pw in patterns_x_weights_1x1]

In [None]:
# Full SVDs of the central-tap views of the plain weights W_l and of the patterns P_l.
weights_svd = [np.linalg.svd(w, full_matrices=True) for w in weights_1x1]
patterns_svd = [np.linalg.svd(w, full_matrices=True) for w in patterns_1x1]

In [None]:
# Singular-value spectra of the central-tap views, one figure per layer.
# Raw string for the label: '\o' is not a valid escape and triggers a warning.
for i, (w_svd, px_svd, p_svd) in enumerate(zip(weights_svd, pxw_svd, patterns_svd)):
    plt.plot(w_svd[1], label='weight')
    plt.plot(px_svd[1], label=r'$P \odot W$')
    plt.plot(p_svd[1], label='pattern')
    plt.title('layer {}: singular values'.format(i))
    plt.xlabel('index')
    plt.ylabel('singular value')
    plt.legend()
    plt.show()

# Visualise SVD for $W$, $A$, $W\odot A$

In [None]:

def plot_svd(usv, title=None, axes=None):
    """Plot an SVD triple as three panels: ``U`` image | singular values | ``V`` image.

    Parameters
    ----------
    usv : sequence ``(U, s, V)``
        Output of ``np.linalg.svd``.
    title : str, optional
        If a new figure is created, used as its suptitle. If ``axes`` is given,
        shown as the title of the first (``U``) panel so it is not silently dropped.
    axes : sequence of three matplotlib Axes, optional
        Axes to draw into; a new ``1 x 3`` figure is created when omitted.

    Returns
    -------
    The three axes that were drawn into.
    """
    if axes is None:
        fig, ax = plt.subplots(1, 3, figsize=(10, 2))
        fig.suptitle(title)
    else:
        ax = axes
        if title is not None:
            ax[0].set_title(title)

    u, s, v = usv[0], usv[1], usv[2]

    # U and V are orthogonal, so their entries lie in [-1, 1]; a fixed symmetric
    # scale keeps panels comparable across layers.
    im = ax[0].imshow(u, vmin=-1, vmax=1, cmap='seismic')
    plt.colorbar(im, ax=ax[0])

    ax[1].plot(s)
    if len(s) >= 2:  # the ratio needs at least two singular values
        ax[1].set_title("$\\sigma_1 / \\sigma_2 = {:.3f}$".format(s[0] / s[1]))

    im = ax[2].imshow(v, vmin=-1, vmax=1, cmap='seismic')
    plt.colorbar(im, ax=ax[2])
    return ax
            

In [None]:
# One row per layer: SVD panels (U | s | V) for weights, pattern and pattern x weight.
# The original loop stopped after index 44; the figure is now sized to the rows actually drawn.
max_rows = 45
n_rows = min(len(weights_svd), max_rows)
# squeeze=False keeps `axes` 2-D even when there is a single row.
fig, axes = plt.subplots(n_rows, 9, figsize=(25, 3.4 * n_rows), squeeze=False)

for i, (w_usv, p_usv, pw_usv) in enumerate(zip(
        weights_svd[:n_rows], patterns_svd[:n_rows], pxw_svd[:n_rows])):
    plot_svd(w_usv, "{}: weights".format(i), axes=axes[i, :3])
    plot_svd(p_usv, "{}: pattern".format(i), axes=axes[i, 3:6])
    plot_svd(pw_usv, "{}: pattern x weight".format(i), axes=axes[i, 6:9])

In [None]:
# Compare one layer's per-input-channel weight norms with its singular values.
layer_idx = 4
w = weights_1x1[layer_idx]
print(w.shape)
# Norm of each *row* of w, i.e. one value per input channel (despite the `_col_` name).
w_col_n = np.linalg.norm(w.T, axis=0)
plt.plot(np.sort(w_col_n)[::-1], label='row norms of $W$ (descending)')
plt.plot(weights_svd[layer_idx][1], label='singular values of $W$')
plt.title('layer {}: weight norms vs. singular values'.format(layer_idx))
plt.legend();

In [None]:
# Worked example: SVD factors of two consecutive layers (6 and 7) of M = P ⊙ W.
u1, s1, v1 = pxw_svd[6]
u2, s2, v2 = pxw_svd[7]

In [None]:
# Middle factor of M_6 M_7 = U_6 S_6 (V_6 U_7) S_7 V_7 (numpy returns V already transposed).
m = v1 @ u2 

In [None]:
# SVD of the middle factor (u, s, v are not used by the cells below).
u, s, v = np.linalg.svd(m)

In [None]:
def s_to_diag(s, length):
    """Return a ``length x length`` diagonal matrix holding ``s`` zero-padded at the end.

    Raises ``ValueError`` (from ``np.pad``) if ``length < len(s)``.
    """
    n_missing = length - len(s)
    padded = np.pad(s, pad_width=(0, n_missing))
    return np.diag(padded)

def get_svus_matrices(usv_list):
    """Build the chain items ``S_l^{1/2} V_l U_{l+1} S_{l+1}^{1/2}`` for consecutive layers.

    ``usv_list`` holds full SVDs ``(U, s, V)`` (``np.linalg.svd(..., full_matrices=True)``)
    of consecutive layers. The square-rooted singular values are zero-padded to the
    matching dimension (``V_l.shape[0]`` on the left, ``U_{l+1}.shape[1]`` on the right).
    Returns a list with ``len(usv_list) - 1`` matrices, or ``[]`` for fewer than two entries.

    The diagonal scalings are applied by broadcasting rather than multiplying by dense
    diagonal matrices. This avoids O(n^3) matmuls for wide layers, and ``V_l @ U_{l+1}``
    is computed only once.
    """
    svus = []
    for (_, s1, v1), (u2, s2, _) in zip(usv_list, usv_list[1:]):
        left = np.pad(np.sqrt(s1), (0, v1.shape[0] - len(s1)))
        right = np.pad(np.sqrt(s2), (0, u2.shape[1] - len(s2)))
        svus.append(left[:, None] * (v1 @ u2) * right[None, :])
    return svus

def get_vu_matrices(usv_list):
    """Return ``V_l @ U_{l+1}`` for every pair of consecutive SVD triples ``(U, s, V)``.

    The result has ``len(usv_list) - 1`` entries (empty for fewer than two triples).
    """
    return [
        v_prev @ u_next
        for (_, _, v_prev), (u_next, _, _) in zip(usv_list, usv_list[1:])
    ]

In [None]:
# Visualise V_6 U_7 on a colour scale symmetric around zero.
plt.imshow(m, **symetric_min_max(m))

In [None]:
# Right-multiply by layer 7's (zero-padded) singular values, i.e. V_6 U_7 S_7.
ms = m @ s_to_diag(s2, m.shape[1])
plt.imshow(ms, **symetric_min_max(ms))

In [None]:
# Chain items S_l^{1/2} V_l U_{l+1} S_{l+1}^{1/2} for every consecutive layer pair.
svus_weights = get_svus_matrices(weights_svd)
svus_pxw = get_svus_matrices(pxw_svd)
svus_patterns = get_svus_matrices(patterns_svd)

In [None]:
# Plain middle factors V_l U_{l+1} (without the singular-value scaling).
vu_weights = get_vu_matrices(weights_svd)
vu_pxw = get_vu_matrices(pxw_svd)
vu_patterns = get_vu_matrices(patterns_svd)

In [None]:
def plot_3_matrices(m_w, m_p, m_pxw,
                    zoom=None,
                    names=('weight', 'pattern', 'pattern x weight'),
                    index=None):
    """Show three matrices side by side on symmetric colour scales.

    Parameters
    ----------
    m_w, m_p, m_pxw : 2-D arrays
        Weight, pattern and pattern-x-weight matrices of the same layer.
    zoom : int, optional
        If given and ``m_w`` has more rows than ``zoom``, only the top-left
        ``zoom x zoom`` corner of each matrix is shown.
    names : sequence of three str
        Panel titles.
    index : optional
        Prefix for the suptitle, e.g. the layer index. The original read a global ``i``.
    """
    # The suptitle reports the full (unzoomed) shape.
    if index is None:
        title = "{}".format(m_w.shape)
    else:
        title = "{}: {}".format(index, m_w.shape)
    fig, ax = plt.subplots(1, 3, figsize=(20, 5))
    fig.suptitle(title)

    if zoom is not None and m_w.shape[0] > zoom:
        m_w = m_w[:zoom, :zoom]
        m_p = m_p[:zoom, :zoom]
        m_pxw = m_pxw[:zoom, :zoom]

    # Zero the top-left pattern entry for display only. Presumably this stops one
    # dominant value from squashing the colour scale (confirm). Copy first so the
    # caller's array (or the array a slice views) is not modified.
    m_p = m_p.copy()
    m_p[0, 0] = 0

    for axis, name, mat in zip(ax, names, (m_w, m_p, m_pxw)):
        axis.set_title(name)
        im = axis.imshow(mat, **symetric_min_max(mat))
        plt.colorbar(im, ax=axis)

    plt.show()


for i, (m_w, m_p, m_pxw) in enumerate(zip(vu_weights, vu_patterns, vu_pxw)):
    plot_3_matrices(m_w, m_p, m_pxw, index=i)


In [None]:
# Chain items S_l^{1/2} V_l U_{l+1} S_{l+1}^{1/2}; large ones are zoomed to the top-left corner.
zoom = 512
for i, (m_w, m_px, m_p) in enumerate(zip(svus_weights, svus_pxw, svus_patterns)):
    fig, ax = plt.subplots(1, 3, figsize=(20, 5))
    fig.suptitle("{}: {}".format(i, m_w.shape))

    if m_w.shape[0] > zoom:
        m_w = m_w[:zoom, :zoom]
        m_px = m_px[:zoom, :zoom]
        m_p = m_p[:zoom, :zoom]

    # Zero the top-left pattern entry for display only. Copy first: basic slicing
    # returns a view, so writing in place would corrupt `svus_patterns`, which is
    # decomposed again further down.
    m_p = m_p.copy()
    m_p[0, 0] = 0

    for axis, title, mat in zip(ax, ('weight', 'pattern', 'pattern x weight'),
                                (m_w, m_p, m_px)):
        axis.set_title(title)
        im = axis.imshow(mat, **symetric_min_max(mat))
        plt.colorbar(im, ax=axis)

    plt.show()

In [None]:
## Inter chain items

In [None]:
# SVD of every chain item; their spectra are compared with the per-layer ones below.
svus_weights_svd = [np.linalg.svd(m) for m in svus_weights] 
svus_patterns_svd = [np.linalg.svd(m) for m in svus_patterns] 
svus_pxw_svd = [np.linalg.svd(m) for m in svus_pxw] 

In [None]:
# Singular values of each chain item compared with those of the two layers it connects.
# For the weights only the first `zoom` values are shown; pattern x weight shows all.
zoom = 50
for label, chain_svd, layer_svd, n_shown in [
        ('weights', svus_weights_svd, weights_svd, zoom),
        ('pattern x weights', svus_pxw_svd, pxw_svd, None)]:
    print(label.upper())
    for i in range(len(chain_svd)):
        plt.plot(chain_svd[i][1][:n_shown], label='svus')
        plt.plot(layer_svd[i][1][:n_shown], label='layer l (prev)')
        plt.plot(layer_svd[i + 1][1][:n_shown], label='layer l+1 (next)')
        plt.title('{}: chain item {}'.format(label, i))
        plt.legend()
        plt.show()

In [None]:
# Normalised spectra s / s[0] of every chain item, one panel per item.
n_items = len(svus_weights_svd)
n_cols = 4
n_rows = max(1, (n_items + n_cols - 1) // n_cols)  # ceil division; 4x4 for up to 16 items
fig, axes = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 5 * n_rows), squeeze=False)
fig.suptitle('Normalised singular values of the chain items')
flat_axes = axes.flatten()

for i, (w_svd, p_svd, px_svd) in enumerate(zip(svus_weights_svd, svus_patterns_svd, svus_pxw_svd)):
    ax = flat_axes[i]
    sw = w_svd[1]
    sp = p_svd[1]
    spx = px_svd[1]
    ax.plot(sw / sw[0], label='weight')
    ax.plot(sp / sp[0], label='pattern')
    ax.plot(spx / spx[0], label=r'$P \odot W$')
    ax.set_title('chain item {}'.format(i))
    # Per-axis limit: plt.ylim only affected the current (last-created) subplot.
    ax.set_ylim(0, 1)
    ax.legend()

for unused_ax in flat_axes[n_items:]:
    unused_ax.set_visible(False)


## Ratio between singular values

In [None]:
# SVD of the product of all central-tap views (W, then P ⊙ W) except the first layer's.
# NOTE(review): the reason for excluding the first layer is not stated — confirm.
plot_svd(np.linalg.svd(np.linalg.multi_dot(weights_1x1[1:])))
plot_svd(np.linalg.svd(np.linalg.multi_dot(patterns_x_weights_1x1[1:])))

In [None]:
# Output directory for the PDF figures (pure-Python equivalent of `mkdir -p`).
os.makedirs('figures/patternattr', exist_ok=True)

In [None]:
# Readable names of the kernel layers (Conv2D / Dense), used as x tick labels below.
layer_names = []
for layer in model.layers:
    if not isinstance(layer, (keras.layers.Conv2D, keras.layers.Dense)):
        continue
    layer_names.append(meta.names.to_nice(layer.name))

In [None]:
def s0_to_s1(usv_list):
    """Return ``sigma_1 / sigma_2`` (first over second singular value) for each ``(U, s, V)`` triple."""
    ratios = []
    for _u, sigma, _v in usv_list:
        ratios.append(sigma[0] / sigma[1])
    return ratios

# NOTE(review): `colors` is not used by any visible cell.
colors = sns.color_palette('colorblind', n_colors=6)

# Figure size: 4 inch wide, golden-ratio aspect.
fig_width = 4
golden_mean = (np.sqrt(5) - 1.0) / 2.0
fig_height = fig_width * golden_mean

# per_layer  : sigma_1 / sigma_2 of each layer's central-tap view.
# inter_layer: the same ratio for the chain items between consecutive layers, drawn
#              with ticks shifted half a step so each point sits between its two layers.
# `mpl_styles` comes from the `when_explanations_lie` star import.
for name, (w_svd, px_svd) in [("per_layer", (weights_svd, pxw_svd)),
                              ("inter_layer", (svus_weights_svd, svus_pxw_svd))]:
    with sns.axes_style('ticks'):
        plt.figure(figsize=(fig_width, fig_height))
        plt.plot(s0_to_s1(w_svd), label='Gradient', **mpl_styles['Gradient'])
        plt.plot(s0_to_s1(px_svd), label='PatternAttr.', **mpl_styles['PatternAttr.'])
        # Raw string: '\,' is not a valid Python escape.
        plt.ylabel(r'$\sigma_1 \, / \, \sigma_2$')

        offset = 0 if name == 'per_layer' else -0.5
        plt.xticks(np.arange(len(layer_names)) + offset, layer_names, rotation=90)
        plt.grid(True)

        plt.legend(fontsize='small')
        plt.tight_layout()
        figpath = 'figures/patternattr/{}_pattern_attr_s1_s2.pdf'.format(name)
        plt.savefig(figpath, bbox_inches='tight', pad_inches=0.1)

        display(IFrame(figpath, 800, 400))
        plt.show()

In [None]:
# Product of sigma_1/sigma_2 for P ⊙ W: over the chain items vs. over the individual layers.
np.prod(s0_to_s1(svus_pxw_svd)), np.prod(s0_to_s1(pxw_svd))

In [None]:
# sigma_1 / sigma_2 of each layer's full (flattened) pattern matrix, labelled PatternNet.
with sns.axes_style('ticks'):
    fig_width = 4
    golden_mean = (np.sqrt(5) - 1.0) / 2.0  # aesthetic ratio
    fig_height = fig_width * golden_mean    # height in inches

    plt.figure(figsize=(fig_width, fig_height))
    plt.plot([sv[0] / sv[1] for sv in patterns_mat_sv], label='PatternNet',
             **mpl_styles['PatternNet'])
    # Raw string: '\,' is not a valid Python escape.
    plt.ylabel(r'$\sigma_1 \, / \, \sigma_2$')

    plt.xticks(np.arange(len(layer_names)), layer_names, rotation=90)
    plt.grid(True)

    plt.legend(fontsize='small')
    plt.tight_layout()
    figpath = 'figures/patternattr/pattern_net_s1_s2.pdf'
    plt.savefig(figpath, bbox_inches='tight', pad_inches=0.1)

    display(IFrame(figpath, 800, 400))
    plt.show()

In [None]:
# Output directory for the generated LaTeX definitions (pure-Python `mkdir -p`).
os.makedirs('export_defs', exist_ok=True)

In [None]:
from datetime import datetime, timezone


def latex_def(name, value):
    """Return a ``\\newcommand`` line defining the macro ``\\<name>`` to expand to ``value``.

    NOTE: LaTeX macro names may only contain letters; ``name`` is not validated.
    """
    return "\\newcommand{{\\{}}}{{{}}}".format(name, value)


def save_latex_defs(dictonary, filename):
    """Write one ``\\newcommand`` per ``(name, value)`` item of ``dictonary`` to ``filename``.

    The file starts with a do-not-edit comment and a UTC export timestamp, and
    ends with a newline. Existing files are overwritten.
    """
    # datetime.utcnow() is deprecated (Python 3.12) and returns a naive datetime;
    # use an aware UTC timestamp instead (its isoformat carries '+00:00').
    lines = [
        "% Automatically generated. Do not change!",
        "% Exported at {}".format(datetime.now(timezone.utc).isoformat()),
        "",
    ]
    lines.extend(latex_def(k, v) for k, v in dictonary.items())
    with open(filename, 'w', encoding='utf-8') as f:
        f.write("\n".join(lines) + "\n")
    
# LaTeX macros holding the product of sigma_1/sigma_2 over the inter-layer chain items,
# for the plain weights and for pattern x weight.
defs = {
    "weightSingularRatioProd":
        "{:0.2f}".format(np.prod(s0_to_s1(svus_weights_svd))),
    "patternSingularRatioProd":
        "{:0.2f}".format(np.prod(s0_to_s1(svus_pxw_svd))),
}

defs_path = './export_defs/pattern_attr_s1_s1_prod.tex'
save_latex_defs(defs, defs_path)
# Echo the generated file (replaces the `! cat` shell call).
with open(defs_path, encoding='utf-8') as f:
    print(f.read())