# Analyze modifications of the gamma rule on the entire MNIST CNN network

#### Loading stuff

In [None]:
1

In [None]:
%load_ext autoreload
%autoreload 2

import os
from tqdm import tqdm
import copy
from functools import partial

import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
# import seaborn as sns
# import pandas as pd

# import quantus

from util.util_cnn import CNNModel, data_loaders, first_mnist_batch, test, \
                params_to_filename, params_from_filename, load_mnist_v4_models

from util.util_gamma_rule import \
                plot_vals_lineplot, plot_multiplicative_change, \
                col_norms_for_matrices, \
                global_conv_matrix_from_pytorch_layer, \
                calc_mats_batch, calc_vals_batch

from util.util_lrp import layerwise_forward_pass, compute_relevancies, LRP_global_mat, calc_mats_batch_functional, forward_and_explain
from util.util_matrix_norms import calc_norm_dict
from util.util_data_summary import *
from util.util_pickle import *
from util.naming import *

from learning_lrp import perturb_point

import util.util_tutorial as tut_utils

In [None]:
del CNNModel
from util.util_cnn import CNNModel

In [None]:
%matplotlib inline

In [None]:
# load data and pretrained models
data, target = first_mnist_batch()

model_dict = load_mnist_v4_models()
model_d3 = model_dict[d3_tag]

In [None]:
ll =list(model_dict.keys())
ll.sort()
ll

### Print layers and hidden activation shapes

In [None]:
A, layers = layerwise_forward_pass(model_d3, data)
A_shapes = [a.shape[1:] for a in A]

for i, (a, l) in enumerate(zip(A, layers)):
    print(i, "Input:", list(a.shape), '->', str(l).split('(')[0],
    (list(l.weight.shape)) if isinstance(l, torch.nn.Conv2d) else "")

## Svals of LRP transition of "D3" 
(an unnecessarily deep MNIST conv network, with 2x3 Convolutional layers)

In [None]:
for i, (a, l) in enumerate(zip(*layerwise_forward_pass(model_d3, data))):
    print(i, "Input:", list(a.shape), '->', str(l).split('(')[0],
    (list(l.weight.shape)) if isinstance(l, torch.nn.Conv2d) else "")

### Global LRP, modify individual Layers with Gamma

In [None]:
d3_after_conv_layer

In [None]:
n_points = 20 # len(data)
mat_funcs = [partial(LRP_global_mat, model=model_d3, l_lb=l_ub-2, l_ub=l_ub, l_inp=1) for l_ub in d3_after_conv_layer]
LRP_m0_to_1__individual_gamma__gammas40 = calc_mats_batch_functional(mat_funcs, gammas40, data[:n_points].reshape((n_points, -1)))

In [None]:
# Singular values for the "individual gamma" LRP matrices on the gammas5 grid.
# With the flag set to False the values are read back via load_data; flipping it to
# True recomputes them with calc_vals_batch and stores them via save_data.
# NOTE(review): LRP_m0_to_1__individual_gamma__gammas5 is not defined anywhere in the
# visible part of this notebook (only the gammas40 variant is computed above), so the
# recompute branch would raise a NameError until those matrices are computed first.
if False:
    svals__m0_to_1__individual_gamma__gammas5, _ = calc_vals_batch(LRP_m0_to_1__individual_gamma__gammas5, num_vals='auto', tqdm_for='point')
    save_data('d3', 'svals__m0_to_1__individual_gamma__gammas5', svals__m0_to_1__individual_gamma__gammas5)
else:
    svals__m0_to_1__individual_gamma__gammas5 = load_data('d3', 'svals__m0_to_1__individual_gamma__gammas5')
    
svals__m0_to_1__individual_gamma__gammas5.shape

In [None]:
if False:
    svals__m0_to_1__individual_gamma__gammas40, _ = calc_vals_batch(LRP_m0_to_1__individual_gamma__gammas40, num_vals='auto', tqdm_for='point')
    save_data('d3', 'svals__m0_to_1__individual_gamma__gammas40', svals__m0_to_1__individual_gamma__gammas40)
else:
    svals__m0_to_1__individual_gamma__gammas40 = load_data('d3', 'svals__m0_to_1__individual_gamma__gammas40')
    
svals__m0_to_1__individual_gamma__gammas40.shape

In [None]:
plot_vals_lineplot(svals__m0_to_1__individual_gamma__gammas40[:, :1], gammas40, xlim=8, num_vals_total=200, colormap='seismic')

In [None]:
plot_vals_lineplot(svals__m0_to_1__individual_gamma__gammas40[:, :1], gammas40, xlim=8, num_vals_total=200, colormap='seismic')

In [None]:
plot_vals_lineplot(svals__m0_to_1__individual_gamma__gammas40[:, :1], gammas40, xlim=4, num_vals_total=200, colormap='seismic')

In [None]:
# pulled up from below for comparison
plot_vals_lineplot(svals__m0_to_1___cascading_gamma__gammas40[:, :1], gammas40, xlim=4, num_vals_total=200, colormap='seismic')

In [None]:
gammas40

In [None]:
indices = [0, 23, -1]
gammas40[indices]



#### Normalized multiplicative change

In [None]:
# harmonic mean as a summary - y axis linear between 0 and 1
plot_multiplicative_change(svals__m0_to_1__individual_gamma__gammas40, gammas40, normalize=True, hmean='points', yscale='linear', xlim=1, ylim=(0,1), sharey=True)

In [None]:
# individual points - y axis linear between 0 and 1
for p in range(10):
    print('point', p)
    plot_multiplicative_change(svals__m0_to_1__individual_gamma__gammas40[:, p:p+1], gammas40, normalize=True, yscale='linear', xlim=1, ylim=(0,1), sharey=True)

In [None]:
# harmonic mean as a summary - log-scaled, dynamic y axis
plot_multiplicative_change(svals__m0_to_1__individual_gamma__gammas40, gammas40, normalize=True, hmean='points', xlim=4, sharey=False)

In [None]:
# individual points - log-scaled, dynamic y axis
for p in range(10):
    print('point', p)
    plot_multiplicative_change(svals__m0_to_1__individual_gamma__gammas40[:, p:p+1], gammas40, normalize=True, xlim=4, sharey=False)

#### Absolute multiplicative change

In [None]:
# harmonic mean as a summary
plot_multiplicative_change(svals__m0_to_1__individual_gamma__gammas40, gammas40, normalize=False, hmean='points', yscale='linear', sharey=True, xlim=4)

In [None]:
# harmonic mean as a summary
plot_multiplicative_change(svals__m0_to_1__individual_gamma__gammas40, gammas40, normalize=False, hmean='points', xlim=4)

### Global LRP - m1 to 1, modify individual Layers with Gamma

In [None]:
d3_after_conv_layer

In [None]:
n_points = 1 # len(data)
mat_funcs = [partial(LRP_global_mat, model=model_d3, l_lb=l_ub-2, l_ub=l_ub, l_inp=1, l_out=-2) for l_ub in d3_after_conv_layer]
LRP_m1_to_1__individual_gamma__gammas40 = calc_mats_batch_functional(mat_funcs, gammas40, data[:n_points].reshape((n_points, -1)))

In [None]:
LRP_m1_to_1__individual_gamma__gammas40.shape

In [None]:
if False:
    svals__m1_to_1__individual_gamma__gammas40, _ = calc_vals_batch(LRP_m1_to_1__individual_gamma__gammas40, num_vals='auto', tqdm_for='point')
    save_data('d3', 'svals__m1_to_1__individual_gamma__gammas40', svals__m1_to_1__individual_gamma__gammas40)
else:
    svals__m1_to_1__individual_gamma__gammas40 = load_data('d3', 'svals__m1_to_1__individual_gamma__gammas40')
    
svals__m1_to_1__individual_gamma__gammas40.shape

### Global LRP modify increasing numbers of layers with Gamma

In [None]:
mat_funcs = [partial(LRP_global_mat, model=model_d3, l_leq=l_leq, l_inp=1) for l_leq in d3_after_conv_layer]
LRP_m0_to_1__cascading_gamma__gammas_0_1_21_inf = calc_mats_batch_functional(mat_funcs, gammas_0_1_21_inf, data[:].reshape((100, -1)))

In [None]:
mat_funcs = [partial(LRP_global_mat, model=model_d3, l_leq=l_leq, l_inp=1) for l_leq in d3_after_conv_layer]
LRP_m0_to_1__cascading_gamma__gammas5 = calc_mats_batch_functional(mat_funcs, gammas5, data[:].reshape((100, -1)))

In [None]:
mat_funcs = [partial(LRP_global_mat, model=model_d3, l_leq=l_leq, l_inp=1) for l_leq in d3_after_conv_layer]
LRP__m0_to_1__cascading_gamma__gammas40 = calc_mats_batch_functional(mat_funcs, gammas40, data[:].reshape((100, -1)))

In [None]:
# Singular values of the cascading-gamma LRP matrices on the gammas5 grid.
# Flip the flag to True to recompute and store them; otherwise the stored result is loaded.
if False:
    # The matrices for this grid are computed earlier in this notebook under the name
    # LRP_m0_to_1__cascading_gamma__gammas5; the name used here before
    # (LRP_trans__cascading_gamma__gammas5) is never defined and raised a NameError.
    svals__m0_to_1__cascading_gamma__gammas5, _ = calc_vals_batch(LRP_m0_to_1__cascading_gamma__gammas5, num_vals='auto', tqdm_for='point')
    save_data('d3', 'svals__m0_to_1__cascading_gamma__gammas5', svals__m0_to_1__cascading_gamma__gammas5)
else:
    svals__m0_to_1__cascading_gamma__gammas5 = load_data('d3', 'svals__m0_to_1__cascading_gamma__gammas5')

svals__m0_to_1__cascading_gamma__gammas5.shape

In [None]:
# Singular values of the cascading-gamma LRP matrices on the gammas_0_1_21_inf grid.
# Flip the flag to True to recompute and store them; otherwise the stored result is loaded.
if False:
    # tqdm_for now uses 'point' like every other calc_vals_batch call in this notebook
    # (was 'pnt'). NOTE(review): confirm against calc_vals_batch that 'pnt' was not a
    # deliberately different option.
    svals__m0_to_1__cascading_gamma__gammas_0_1_21_inf, _ = calc_vals_batch(LRP_m0_to_1__cascading_gamma__gammas_0_1_21_inf, num_vals='auto', tqdm_for='point')
    save_data('d3', 'svals__m0_to_1__cascading_gamma__gammas_0_1_21_inf', svals__m0_to_1__cascading_gamma__gammas_0_1_21_inf)
else:
    svals__m0_to_1__cascading_gamma__gammas_0_1_21_inf = load_data('d3', 'svals__m0_to_1__cascading_gamma__gammas_0_1_21_inf')

svals__m0_to_1__cascading_gamma__gammas_0_1_21_inf.shape

In [None]:
# Singular values of the cascading-gamma LRP matrices on the gammas40 grid.
# Flip the flag to True to recompute and store them; otherwise the stored result is loaded.
if False:
    # The matrices are computed earlier in this notebook as
    # LRP__m0_to_1__cascading_gamma__gammas40 (two underscores before "cascading");
    # the triple-underscore spelling used here before is never defined.
    svals__m0_to_1___cascading_gamma__gammas40, _ = calc_vals_batch(LRP__m0_to_1__cascading_gamma__gammas40, num_vals='auto', tqdm_for='point')
    save_data('d3', 'svals__m0_to_1___cascading_gamma__gammas40', svals__m0_to_1___cascading_gamma__gammas40)
else:
    svals__m0_to_1___cascading_gamma__gammas40 = load_data('d3', 'svals__m0_to_1___cascading_gamma__gammas40')

svals__m0_to_1___cascading_gamma__gammas40.shape

#### Spectra

In [None]:
print("The top most line plots the spectra for gamma = 0.15 !!")

i=0 # plot only gammas starting here
plot_vals_lineplot(svals__m0_to_1___cascading_gamma__gammas40[4:, :1, i:], gammas40[i:]
             , spectra=True
             , norm_s1=True # , legend=False
            #  , norm_g0=True
             , ylim='p95'
             # , colormap='seismic'
             , yscale='log'
             , one_plot_per='weight'
            #  , legend="lower left"
             )

In [None]:
plot_last_sval_maximum(svals__m0_to_1___cascading_gamma__gammas40, gammas40)

In [None]:
plot_last_sval(svals__m0_to_1___cascading_gamma__gammas40, gammas40)

In [None]:

plot_determinant(svals__m0_to_1___cascading_gamma__gammas40, gammas40)

In [None]:
i=0#21 # plot only gammas starting here
plot_vals_lineplot(svals__m0_to_1___cascading_gamma__gammas40[-2:, :10, i:], gammas40[i:]
             , spectra=True
             , norm_s1=True # , legend=False
            #  , norm_g0=True
             , ylim='p95'
             # , colormap='seismic'
             , yscale='log'
             , one_plot_per='point'
            #  , legend="lower left"
             )

In [None]:
i=21 # plot only gammas starting here
plot_vals_lineplot(svals__m0_to_1___cascading_gamma__gammas40[-2:, :10, :i], gammas40[:i]
             , spectra=True
             , norm_s1=True # , legend=False
            #  , norm_g0=True
             , ylim='p95'
             # , colormap='seismic'
             , yscale='log'
             , one_plot_per='point'
            #  , legend="lower left"
             )

#### Absolute Svals

In [None]:
plot_vals_lineplot(svals__m0_to_1___cascading_gamma__gammas40[:, :1], gammas40, xlim=8, num_vals_total=200, colormap='seismic')

In [None]:
distribution_plot(svals__m0_to_1__cascading_gamma__gammas5[:, :], gammas5, aggregate_over='points')

In [None]:
distribution_plot(svals__m0_to_1__cascading_gamma__gammas5[:, :], gammas5, aggregate_over='points', mode='violin')

In [None]:
plot_vals_lineplot(svals__m0_to_1__cascading_gamma__gammas_0_1_21_inf, gammas_0_1_21_inf, ylabel="Svals of global LRP", ylim=25)

In [None]:
svals = svals__m0_to_1__cascading_gamma__gammas_0_1_21_inf
# Largest value along the last axis at the final gamma index, for every entry of the
# first two axes (gamma configuration x datapoint, per the plot title).
maxi = svals[:, :, -1].max(axis=2)
# x position = index along the first axis, repeated for every datapoint. The grid is
# derived from maxi itself instead of the hard-coded 6 x 100, so the cell keeps working
# when the stored array has a different number of configurations or points.
config_index = np.broadcast_to(np.arange(maxi.shape[0])[:, None], maxi.shape)
plt.scatter(config_index, maxi)
plt.yscale('log')
plt.title('Max Sval at gamma=inf, for six different gamma configurations, and 100 datapoints')

#### Picking out a point, for detailed plot (for proposal)

In [None]:
for i in range(10):
    plot_vals_lineplot(svals__m0_to_1__cascading_gamma__gammas_0_1_21_inf[:, i:i+1], gammas_0_1_21_inf, ylabel="Svals of global LRP")#, ylim=25)

In [None]:
for i in [12, 17, 15, 16, 33, 34, 35]:
    print(i)
    plot_vals_lineplot(svals__m0_to_1__cascading_gamma__gammas_0_1_21_inf[:, i:i+1], gammas_0_1_21_inf, ylabel="Svals of global LRP")#, ylim=25)

#### Proposal Plot

In [None]:
svals = svals__m0_to_1___cascading_gamma__gammas40

In [None]:
fig, axs = plot_vals_lineplot(svals[3:4, :1], gammas40, xlim=4, ylabel="Singular values", title="", xscale='linear', figsize=(8,4), show=False)
axs[0].set_xticks([0,1,2,3,4])
axs[0].set_xticklabels([0,1,2,3,4])
axs[0].set_yticks([0,1,2,3,4])
axs[0].set_yticklabels([0,1,2,3,4])

plt.legend()

#### Log scaled x axis, and derivative

In [None]:
fig, axs = plot_vals_lineplot(svals[:12,:, 30:], gammas40[30:], ylabel="Svals of global LRP", xscale='log', show=0)
for ax in axs: 
    ax.set_ylim((0, 24))

In [None]:
fig, axs = plot_vals_lineplot(svals[:12], gammas40, ylabel="Svals of global LRP", xscale='log', show=0)
for ax in axs: 
    ax.set_ylim((0, 24))
    ax.set_xlim((1e-5, 1e6))

In [None]:
fig, axs = plot_vals_lineplot(svals[:12], gammas40, ylabel="Svals of global LRP", xscale='log', show=0)
for ax in axs: 
    ax.set_ylim((0, 2))
    ax.set_xlim((1e-5, 1e6))

In [None]:
plot_vals_lineplot(svals[:12], gammas40, ylabel="Svals of global LRP", xscale='log', ylim=6)

In [None]:
slope = np.diff(svals, axis=2) / np.diff(gammas40[None, None, :, None], axis=2)

In [None]:
slope.min(axis=(1,2,3))

In [None]:
plot_vals_lineplot(slope[:12], gammas40[1:], ylabel="Svals of global LRP", xscale='log', ylim=(-80, 2))

In [None]:
plot_vals_lineplot(slope[:12], gammas40[1:], ylabel="Svals of global LRP", xscale='log', ylim=(-1150, 2))

In [None]:
plot_vals_lineplot(slope[:12], gammas40[1:], title="Derivative of Svals of global LRP", ylabel="Derivative per Sval", xscale='log', sharey=False, ylim="p100")

In [None]:
axs[0].get_shared_x_axes()

In [None]:
plot_vals_lineplot(svals[:12], gammas40, ylabel="Svals of global LRP", xscale='log')

In [None]:
plot_vals_lineplot(svals[:12], gammas40, ylabel="Svals of global LRP", xscale='log')

#### Normalized multiplicative change

In [None]:
# harmonic mean as a summary - y axis linear between 0 and 1
plot_multiplicative_change(svals__m0_to_1___cascading_gamma__gammas40, gammas40, normalize=True, hmean='points', yscale='linear', xlim=1, ylim=(0,1), sharey=True)

In [None]:
# harmonic mean as a summary - log-scaled, dynamic y axis
plot_multiplicative_change(svals__m0_to_1___cascading_gamma__gammas40, gammas40, normalize=True, hmean='points', yscale='log', xlim=4, ylim='p100')

#### Absolute multiplicative change

In [None]:
# total multiplicative change
for p in range(3):
    print('Point', p)
    plot_multiplicative_change(svals__individual_layer__gammas5[:, p:p+1, :], gammas5, normalize=False, yscale='linear', sharey=True, num_vals_total=400)

In [None]:
# harmonic mean as a summary
plot_multiplicative_change(svals__individual_layer__gammas5, gammas5, normalize=False, hmean='points', yscale='linear', sharey=True, num_vals_total=400)

In [None]:
svals = svals__m0_to_1___cascading_gamma__gammas40

In [None]:
plot_multiplicative_change(svals[:12, :1], gammas40, xlim=8, ylim='p100')# , one_plot_per='point')

In [None]:
for p in range(4):
    print('point', p)
    plot_multiplicative_change(svals[:, p:p+1], gammas40, xlim=8, ylim='p100')

In [None]:
# harmonic mean as a summary
plot_multiplicative_change(svals, gammas40, xlim=2, ylim='p100', hmean='points')

In [None]:
# harmonic mean as a summary
plot_multiplicative_change(svals, gammas40, xlim=2, ylim='p100', hmean='points', yscale='linear')

### Global (except first layer (box rule) and last layer (a Dense layer))

In [None]:
mat_funcs = [partial(LRP_global_mat, model=model_d3, l_leq=l_leq, l_inp=1, l_out=-2, delete_unactivated_subnetwork=True) for l_leq in d3_after_conv_layer[:-1]]
LRP__m1_to_1___cascading_gamma__gammas5 = calc_mats_batch_functional(mat_funcs, gammas5, data[:20])

In [None]:
mat_funcs = [partial(LRP_global_mat, model=model_d3, l_leq=l_leq, l_inp=1, l_out=-2) for l_leq in d3_after_conv_layer[:-1]]
LRP__m1_to_1___cascading_gamma__gammas_0_1_21_inf = calc_mats_batch_functional(mat_funcs, gammas_0_1_21_inf, data, tqdm_for='point')

In [None]:
if False:
    svals__m1_to_1___cascading_gamma__gammas5, _ = calc_vals_batch(LRP__m1_to_1___cascading_gamma__gammas5, num_vals='auto', tqdm_for='point')
    save_data('d3', 'svals__m1_to_1___cascading_gamma__gammas5', svals__m1_to_1___cascading_gamma__gammas5)
else:
    svals__m1_to_1___cascading_gamma__gammas5 = load_data('d3', 'svals__m1_to_1___cascading_gamma__gammas5')

In [None]:
if False:
    svals__m1_to_1___cascading_gamma__gammas_0_1_21_inf, _ = calc_vals_batch(LRP__m1_to_1___cascading_gamma__gammas_0_1_21_inf, num_vals='auto', tqdm_for='point')
    save_data('d3', 'svals__m1_to_1___cascading_gamma__gammas_0_1_21_inf', svals__m1_to_1___cascading_gamma__gammas_0_1_21_inf)
else:
    svals__m1_to_1___cascading_gamma__gammas_0_1_21_inf = load_data('d3', 'svals__m1_to_1___cascading_gamma__gammas_0_1_21_inf')

In [None]:
vals = svals__m1_to_1___cascading_gamma__gammas5
vals.min(), vals[vals>0].min()

In [None]:
# linear scale
plot_vals_lineplot(svals__m1_to_1___cascading_gamma__gammas5, gammas5, ylabel="Svals of global LRP", ylim=200)

In [None]:
# log scale
# ylim was passed twice (ylim=200 and ylim="p100"); a repeated keyword argument is a
# SyntaxError, so the cell could not run. The fixed limit of 200 belongs to the
# linear-scale plot above; the "p100" setting is kept here.
plot_vals_lineplot(svals__m1_to_1___cascading_gamma__gammas5, gammas5, ylabel="Svals of global LRP", yscale='log', ylim="p100")

In [None]:
distribution_plot(svals__m1_to_1___cascading_gamma__gammas5[:, :], gammas5, aggregate_over='points')

In [None]:
distribution_plot(svals__m1_to_1___cascading_gamma__gammas5[:, :], gammas5, aggregate_over='points', mode='violin')

#### Hists per point

In [None]:
distribution_plot(svals__m1_to_1___cascading_gamma__gammas5[:, :4], gammas5, aggregate_over='x')

In [None]:
distribution_plot(svals__m1_to_1___cascading_gamma__gammas5[:, :4], gammas5, aggregate_over='x')

#### Compare Singular values with L1-induced norm

In [None]:
norm_dict = calc_norm_dict(LRP__m1_to_1___cascading_gamma__gammas5, svals__m1_to_1___cascading_gamma__gammas5)

In [None]:
tag_line = ["L1_lower", "L1_upper", "Linf_lower", "Linf_upper", "sqrt_L1_Linf", "L2", "L1", "Linf", "frobenius"]
norms = np.stack(list(norm_dict[tag] for tag in tag_line), axis=3)
print(norms.shape)

In [None]:
tag_line = ["L1_lower", "L1_upper", "Linf_lower", "Linf_upper", "sqrt_L1_Linf", "L2", "frobenius"]
norms = np.stack(list(norm_dict[tag] for tag in tag_line), axis=3)
print(norms.shape)

In [None]:
tag_line = ["L1_upper", "Linf_upper", "sqrt_L1_Linf", "frobenius"]
upper_bounds = np.stack(list(norm_dict[tag] for tag in tag_line), axis=3)

minis = np.argmin(upper_bounds, axis=3)

print('The best bounds are...')
for (ind, count) in zip(*np.unique(minis, return_counts=True)):
    print(tag_line[ind], '\t', count, 'x')

In [None]:
tag_line = ["sqrt_L1_Linf", "L2", "L1", "Linf", "frobenius"]
norms = np.stack(list(norm_dict[tag] for tag in tag_line), axis=3)
print(norms.shape)

In [None]:
plot_vals_lineplot(norms[[0, -3, -1], :7], gammas5, tag_line=tag_line, ylim="p100", one_plot_per='point', yscale='log')


In [None]:
plot_vals_lineplot(norms, gammas5, tag_line=tag_line, ylim="p99", one_plot_per='point')

#### Calculate [generalized Determinant](https://arxiv.org/pdf/2111.14840.pdf)  and generalized Rank through Singular values

In [None]:
# Generalized determinant: product of the non-zero singular values of each matrix.
# (To analyse the gammas5 run instead, use svals__m1_to_1___cascading_gamma__gammas5
# together with gammas5.)
svals = svals__m1_to_1___cascading_gamma__gammas_0_1_21_inf
# Replace exact zeros by 1 so they drop out of the product. np.where builds a new array;
# the previous in-place assignment overwrote the loaded singular values themselves, so
# every later cell reading them (e.g. the generalized rank) saw 1 instead of 0.
svals = np.where(svals == 0, 1, svals)
# np.prod replaces np.product, which is deprecated and no longer exists in NumPy 2.x.
determinants = np.prod(svals, axis=3, keepdims=True)

# log scale
plot_vals_lineplot(np.log10(determinants)[:, :], gammas_0_1_21_inf, ylabel="Determinant of global LRP. Log10.", ylim=(-300, 0))

In [None]:
# Generalized rank: number of singular values strictly above `cutoff`, counted along axis 3.
cutoff = 1e-2

# The first assignment is immediately overwritten; only the gammas_0_1_21_inf values are used.
svals = svals__m1_to_1___cascading_gamma__gammas5
svals = svals__m1_to_1___cascading_gamma__gammas_0_1_21_inf
generalized_rank = np.sum(svals>cutoff, axis=3, keepdims=True)

# y axis spans 0 .. 1.2x the size of axis 3 (no yscale argument is passed here,
# unlike the cells that request yscale='log')
plot_vals_lineplot(generalized_rank, gammas_0_1_21_inf, ylabel="Generalized rank", title=f"Generalized rank: #Svals > {cutoff}", ylim=(0, svals.shape[3]*1.2))

In [None]:
for a in A:
    print((a[:5] > 0).sum(axis=(1,2,3)))

### Increasing number of layers

In [None]:
mat_funcs = [partial(LRP_global_mat, model=model_d3, l_inp=l_inp, l_out=l_out, delete_unactivated_subnetwork=True) for l_inp, l_out in [(11, 12), (9, 12), (7, 12), (4, 12), (2, 12)]]
LRP__increasing_num_layers_backwards__gammas5 = calc_mats_batch_functional(mat_funcs, gammas5, data[:5])

In [None]:
# Inspect the matrices computed in the previous cell: overall shape plus the first
# point/gamma entry of every configuration. The second element used to read
# LRP__individual_layer__gammas5, which is only created further down in the notebook.
LRP__increasing_num_layers_backwards__gammas5.shape, LRP__increasing_num_layers_backwards__gammas5[:, :1, :1]

In [None]:
if False:
    svals__increasing_num_layers_backwards__gammas5, _ = calc_vals_batch(LRP__increasing_num_layers_backwards__gammas5, num_vals='auto', tqdm_for="gamma")
    save_data('d3', 'svals__increasing_num_layers_backwards__gammas5', svals__increasing_num_layers_backwards__gammas5)
else:
    svals__increasing_num_layers_backwards__gammas5 = load_data('d3', 'svals__increasing_num_layers_backwards__gammas5')
svals__increasing_num_layers_backwards__gammas5.shape

#### Plot Svals

In [None]:
distribution_plot(svals__increasing_num_layers_backwards__gammas5[:, :, [0, 2, 4]], [0, 0.25, 'inf'], aggregate_over='points', cutoff=1e-8)

In [None]:
distribution_plot(svals__increasing_num_layers_backwards__gammas5[:, :, [0, 2, 4]], [0, 0.25, 'inf'], aggregate_over='points', cutoff=1e-6, mode='violin')

### Individual layers

In [None]:
mat_funcs = [partial(LRP_global_mat, model=model_d3, l_inp=l_inp, l_out=l_out, delete_unactivated_subnetwork=True) for l_inp, l_out in [(2, 3), (4, 5), (7,8), (9, 10), (11, 12)]]
LRP__individual_layer__gammas5 = calc_mats_batch_functional(mat_funcs, gammas5, data[:20].reshape((20, -1)))

In [None]:
mat_funcs = [partial(LRP_global_mat, model=model_d3, l_inp=l_inp, l_out=l_out, delete_unactivated_subnetwork=True) for l_inp, l_out in [(2, 3), (4, 5), (7,8), (9, 10), (11, 12)]]
LRP__individual_layer__gammas40 = calc_mats_batch_functional(mat_funcs, gammas40, data[:100].reshape((100, -1)), tqdm_for='point')

In [None]:
LRP__individual_layer__gammas5.shape, LRP__individual_layer__gammas5[:, :1, :1]

In [None]:
if False:
    svals__individual_layer__gammas5, _ = calc_vals_batch(LRP__individual_layer__gammas5, num_vals='auto', tqdm_for="gamma")
    save_data('d3', 'svals__individual_layer__gammas5', svals__individual_layer__gammas5)
else:
    svals__individual_layer__gammas5 = load_data('d3', 'svals__individual_layer__gammas5')
svals__individual_layer__gammas5.shape

In [220]:
if True:
    svals__individual_layer__gammas40, _ = calc_vals_batch(LRP__individual_layer__gammas40, num_vals='auto', tqdm_for="point")
    save_data('d3', 'svals__individual_layer__gammas40', svals__individual_layer__gammas40)
else:
    svals__individual_layer__gammas40 = load_data('d3', 'svals__individual_layer__gammas40')
svals__individual_layer__gammas40.shape

100it [00:24,  4.02it/s]


(5, 100, 40, 3211)

In [221]:
svals__individual_layer__gammas40 = load_data('d3', 'svals__individual_layer__gammas40')

#### New: Plot one line (Spectra) per Gamma

In [None]:
del plot_vals_lineplot
from util.util_gamma_rule import plot_vals_lineplot

In [None]:
plot_vals_lineplot(svals__individual_layer__gammas5[:, :1], gammas5
             , spectra=True
             , norm_s1=True
            #  , norm_g0=True
             , ylim='p100'
             , legend='upper left'
             # , colormap='seismic'
             # , yscale='linear'
             )

In [None]:
i=0 # plot only gammas starting here
plot_vals_lineplot(svals__individual_layer__gammas5[:2, :3, i:], gammas5[i:]
             , spectra=True
             , norm_s1=True, legend=False
            #  , norm_g0=True
             , ylim='p95'
             # , colormap='seismic'
             , yscale='log'
             , one_plot_per='point'
             )

In [None]:
plot_last_sval(svals__individual_layer__gammas5, gammas5)

In [None]:
plot_last_sval_maximum(svals__individual_layer__gammas5, gammas5)

In [None]:
vals = svals__individual_layer__gammas5.copy()
vals /= vals[:, :, :, :1]

# distribution_plot(vals[:, :, [0, 2, 4]], [0, 0.25, 'inf'], aggregate_over='points', cutoff=1e-6)
distribution_plot(vals, gammas5, aggregate_over='points', cutoff=1e-6)

In [None]:

plot_spectra(svals__individual_layer__gammas5[4:5, :1]
             , norm_s1=True
             # , colormap='seismic'
             # , yscale='linear'
             )

#### Absolute Svals

In [None]:
plot_vals_lineplot(svals__individual_layer__gammas5[:, :1], gammas5, num_vals_total=400, ylim="p95", colormap='seismic')

In [None]:
distribution_plot(svals__individual_layer__gammas5[:, :, [0, 2, 4]], [0, 0.25, 'inf'], aggregate_over='points', cutoff=1e-6)

In [None]:
distribution_plot(svals__individual_layer__gammas5[:, :, [0, 2, 4]], [0, 0.25, 'inf'], aggregate_over='points', cutoff=1e-6, mode='violin')

#### Normalized multiplicative change

In [None]:
# normalized change
for p in range(3):
    print('Point', p)
    plot_multiplicative_change(svals__individual_layer__gammas5[:, p:p+1, :], gammas5, normalize=True, yscale='linear', sharey=True, num_vals_total=400)

In [None]:
# harmonic mean as a summary
plot_multiplicative_change(svals__individual_layer__gammas5, gammas5, normalize=True, hmean='points', yscale='linear', sharey=True, num_vals_total=300)

#### Absolute multiplicative change

In [None]:
# total multiplicative change
for p in range(3):
    print('Point', p)
    plot_multiplicative_change(svals__individual_layer__gammas5[:, p:p+1, :], gammas5, normalize=False, yscale='linear', sharey=True, num_vals_total=400)

In [None]:
# harmonic mean as a summary
plot_multiplicative_change(svals__individual_layer__gammas5, gammas5, normalize=False, hmean='points', yscale='linear', sharey=True, num_vals_total=400)

#### L1, L2, Linf operator norms, and bounds on them

In [None]:
mat_funcs = [partial(LRP_global_mat, model=model_d3, l_inp=l_inp, l_out=l_out, delete_unactivated_subnetwork='mask') for l_inp, l_out in [(2, 3), (4, 5), (7,8), (9, 10), (11, 12)]]
LRP__individual_layer__gammas5 = calc_mats_batch_functional(mat_funcs, gammas5, data[:5].reshape((5, -1)))

In [None]:
svals__individual_layer__gammas5 = load_data('d3', 'svals__individual_layer__gammas5')

In [None]:
for i, (a, l) in enumerate(zip(*layerwise_forward_pass(model_d3, data))):
    print(i, "Input:", list(a.shape), '->', str(l).split('(')[0],
    (f"Weight shape: {list(l.weight.shape)}, Filter size: {l.weight[0].numel()}") if isinstance(l, torch.nn.Conv2d) else "")

In [None]:
Fs = [72,200,200,144,144]

In [None]:
A, A_pos, A_neg, layers = layerwise_forward_pass(model_d3, data, pos_neg=True)
for i in range(len(A)):
    p, n = A_pos[i], A_neg[i]
    if p is not None and n is not None:
        mask = p+n>0
        p, n = p[mask], n[mask]
        print(i, (-n/p).mean())

In [None]:
norm_dict__individual_layer__gammas5 = calc_norm_dict(LRP__individual_layer__gammas5, svals__individual_layer__gammas5[:, :5], num_filter_entries='lrp')
norm_dict = norm_dict__individual_layer__gammas5

In [None]:
tag_line = ["L1_upper", "Linf_upper", "sqrt_L1_Linf", "frobenius"]
upper_bounds = np.stack(list(norm_dict[tag] for tag in tag_line), axis=3)
minis = np.argmin(upper_bounds, axis=3)

print('The tightest upper bounds are...')
for (ind, count) in zip(*np.unique(minis, return_counts=True)):
    print(tag_line[ind], '\t', count, 'x')

In [None]:
# plot L1, L2, Linf operator norms

tag_line = ["L2", "L1", "Linf", "Linf by L1"]
norms = np.stack(list(norm_dict[tag] for tag in tag_line), axis=3)
print(norms.shape)

In [None]:
plot_vals_lineplot(norms[:, :1], gammas5, ylim="p100", yscale='log', ylabel="Norms", tag_line=tag_line)

In [None]:
plot_vals_lineplot(norms[:, :1], gammas5, ylim="p100", yscale='log', ylabel="Norms", tag_line=tag_line)

In [None]:
plot_vals_lineplot(norms[:, :], gammas5, ylim="p100", yscale='log', ylabel="Norms")

In [None]:
# how tight are bounds on l1 norm?

tag_line = ["L2", "sqrt_L1_Linf", "sqrt_L1_Linf by L1", "frobenius"]
bounds = np.stack(list(norm_dict[tag] for tag in tag_line), axis=3)
print(bounds.shape)

In [None]:
plot_vals_lineplot(bounds[:, :1], gammas5, ylim="p100", yscale='log', ylabel="Bounds", tag_line=tag_line)

In [None]:
plot_vals_lineplot(bounds[:, :], gammas5, ylim="p100", yscale='log', ylabel="Bounds")

In [None]:
mul_difference = norm_dict["sqrt_L1_Linf"] / norm_dict["L2"]
plot_vals_lineplot(mul_difference, gammas5, ylim="p100", yscale='linear', ylabel="Bounds")

In [None]:
mul_difference = norm_dict["sqrt_L1_Linf by L1"] / norm_dict["L2"]
plot_vals_lineplot(mul_difference, gammas5, ylim="p100", yscale='linear', ylabel="Bounds")

In [None]:
mul_difference[:, :, -1]

In [None]:
mul_difference = norm_dict["Linf"] / norm_dict["L2"]
plot_vals_lineplot(mul_difference, gammas5, ylim="p100", yscale='linear', ylabel="Bounds")

## Sparsity of singular vectors (of Individual layers)

In [None]:
A, layers = layerwise_forward_pass(model_d3, data)
A_shapes = [a.shape[1:] for a in A]

for i, (a, l) in enumerate(zip(A, layers)):
    print(i, "Input:", list(a.shape), '->', str(l).split('(')[0],
    (list(l.weight.shape)) if isinstance(l, torch.nn.Conv2d) else "")

In [None]:
l = 2
A_shapes[l]

#### Layer 2 (2nd Conv)

In [None]:
l=2
forw_l2 = global_conv_matrix_from_pytorch_layer(layers[l], inp_shape=A_shapes[l], out_shape=A_shapes[l+1])
forw_l2

In [None]:
n_points=20

In [None]:
LRP__l2__gammas5 = calc_mats_batch([forw_l2], A[l][:n_points].reshape((n_points, -1)).detach(), gammas=gammas5, mode='back clip')
LRP__l2__gammas5.shape

In [None]:
res = calc_vals_batch(LRP__l2__gammas5[:, :3], num_vals=500, tqdm_for='point', return_vecs=True)

In [None]:
if False:
    svals__l2__gammas5, svecs__l2__gammas5 = calc_vals_batch(LRP__l2__gammas5[:, :20], num_vals='auto', tqdm_for='point', return_vecs=True)
    save_data('d3', 'svals__l2__gammas5', svals__l2__gammas5)
    save_data('d3', 'svecs__l2__gammas5', svecs__l2__gammas5)
else:
    svals__l2__gammas5 = load_data('d3', 'svals__l2__gammas5')
    svecs__l2__gammas5 = load_data('d3', 'svecs__l2__gammas5')
    
svals__l2__gammas5.shape, svecs__l2__gammas5.shape

In [None]:
# Singular values and vectors of the layer-2 LRP matrices on the gammas_0_1_21_inf grid.
# Flip the flag to True to recompute and store them; otherwise the stored result is loaded.
if False:
    # This branch used to decompose LRP__l2__gammas5, i.e. it would have stored the
    # gammas5 results under the gammas_0_1_21_inf name. Build the matrices for the
    # intended gamma grid first, with the same call as for gammas5 above.
    LRP__l2__gammas_0_1_21_inf = calc_mats_batch([forw_l2], A[l][:n_points].reshape((n_points, -1)).detach(), gammas=gammas_0_1_21_inf, mode='back clip')
    svals__l2__gammas_0_1_21_inf, svecs__l2__gammas_0_1_21_inf = calc_vals_batch(LRP__l2__gammas_0_1_21_inf[:, :20], num_vals='auto', tqdm_for='point', return_vecs=True)
    save_data('d3', 'svals__l2__gammas_0_1_21_inf', svals__l2__gammas_0_1_21_inf)
    save_data('d3', 'svecs__l2__gammas_0_1_21_inf', svecs__l2__gammas_0_1_21_inf)
else:
    svals__l2__gammas_0_1_21_inf = load_data('d3', 'svals__l2__gammas_0_1_21_inf')
    svecs__l2__gammas_0_1_21_inf = load_data('d3', 'svecs__l2__gammas_0_1_21_inf')

svals__l2__gammas_0_1_21_inf.shape, svecs__l2__gammas_0_1_21_inf.shape

In [None]:
def svec_sparsity(svecs, cutoff=1e-6, axis=4):
    """Return the fraction of entries whose magnitude exceeds ``cutoff``.

    Note that despite the name this is a density: 1.0 means every entry along
    ``axis`` is larger than ``cutoff`` in absolute value, 0.0 means none is.

    Args:
        svecs: array of singular vectors; the vector entries lie along ``axis``.
        cutoff: magnitude threshold; entries with ``abs(entry) <= cutoff`` count as zero.
        axis: axis that is averaged away (default 4, as in the original expression).

    Returns:
        Array with ``axis`` removed, holding the per-vector fraction in [0, 1].
    """
    # The original wrote `np.abs(svecs)>cutoff * 1.`, where `* 1.` bound to the cutoff
    # rather than to the mask. The mean of a boolean mask is already a float fraction,
    # so no cast is needed.
    return (np.abs(svecs) > cutoff).mean(axis=axis)

svec_sparsity__l2__gammas5 = svec_sparsity(svecs__l2__gammas5)
svec_sparsity__l2__gammas_0_1_21_inf = svec_sparsity(svecs__l2__gammas_0_1_21_inf)

In [None]:
# svals
plot_vals_lineplot(svals__l2__gammas_0_1_21_inf[:, :3], gammas_0_1_21_inf, one_plot_per='point', ylabel="Top 300 Svals", ylim=40)


In [None]:
# svals
plot_vals_lineplot(svals__l2__gammas_0_1_21_inf[:, :3], gammas_0_1_21_inf, one_plot_per='point', ylabel="Top 300 Svals (Log)", yscale='log', ylim="p99")


In [None]:
# All Svecs get dense with gamma -> inf
plot_vals_lineplot(svec_sparsity__l2__gammas_0_1_21_inf[:, :3], gammas_0_1_21_inf, one_plot_per='point', ylabel="Top 300 Svec sparsity", ylim=1)


In [None]:
# all svecs sparsity, more points
plot_vals_lineplot(svec_sparsity__l2__gammas_0_1_21_inf[:, :8, :, :50], gammas_0_1_21_inf, one_plot_per='point', ylabel="Top 50 Svec sparsity", ylim=1)


In [None]:
# top svecs: get dense slower
plot_vals_lineplot(svec_sparsity__l2__gammas_0_1_21_inf[:, :3, :, :20], gammas_0_1_21_inf, one_plot_per='point', ylabel="Svec sparsity", ylim=1)

In [None]:
# non-top svecs: get dense fast, or are dense from the beginning.
plot_vals_lineplot(svec_sparsity__l2__gammas_0_1_21_inf[:, :3, :, 30:], gammas_0_1_21_inf, one_plot_per='point', ylabel="Svec sparsity", ylim=1)

## Plot where the entries of Top singular vectors lie. 
For that, take the absolute values and sum over channels/filters. Are they concentrated in certain xy positions?

In [None]:
A[2].shape[1:]

In [None]:
np.array(gammas_0_1_21_inf)[[0,1,2,4,8,16,21]]

In [None]:
vecs_per_weight.shape, svecs__l2__gammas_0_1_21_inf[:1, :1, :, np.array([0,2])].shape, i_gammas, i_vecs

In [None]:
# Show where the mass of the leading singular vectors lies: for each selected gamma
# (one row) and each selected singular vector (one column), plot the absolute vector
# entries summed over the first axis of img_shape.
i_vecs = np.arange(20)  # alternative selection: [0, 1, 2, 3, 4, 5, 50]
i_gammas = [0,1,2,4,8,16,21]

# Big dataset (gammas_0_1_21_inf). For the small dataset use svals__l2__gammas5,
# svecs__l2__gammas5 and gammas5 here instead.
vals_per_weight, vecs_per_weight = svals__l2__gammas_0_1_21_inf, svecs__l2__gammas_0_1_21_inf
gammas = np.array(gammas_0_1_21_inf)

# Keep only the first entry of the two leading axes, the selected gammas and the
# selected vectors.
vals_per_weight = vals_per_weight[:1, :1, i_gammas][:, :, :, i_vecs]
vecs_per_weight = vecs_per_weight[:1, :1, i_gammas][:, :, :, i_vecs]
gammas = gammas[i_gammas]

img_shape = (8, 28, 28) # in 2nd layer

for vals_per_point, vecs_per_point in zip(vals_per_weight, vecs_per_weight):
    for vals_per_gamma, vecs_per_gamma in zip(vals_per_point, vecs_per_point):

        n_vecs, n_gammas = len(i_vecs), len(vals_per_gamma)
        # The grid has n_gammas rows and n_vecs columns, so axs[g] is the row of axes for
        # gamma g. The original transposed this array and then iterated it per gamma,
        # which paired every gamma with a *column* of only n_gammas axes: the inner zip
        # stopped after n_gammas vectors and the remaining ones were never drawn.
        # squeeze=False keeps axs two-dimensional even for a single row or column.
        fig, axs = plt.subplots(n_gammas, n_vecs, figsize=(n_vecs*15, n_gammas*15), squeeze=False)
        fig.tight_layout(pad=0)

        for vecs, ax_per_vec in zip(vecs_per_gamma, axs):
            # one 2D image per vector
            vecs_reshaped = np.abs(vecs.reshape((len(i_vecs), *img_shape))).sum(axis=1)

            for vec, ax in zip(vecs_reshaped, ax_per_vec):
                ax.imshow(vec)

                ax.set_box_aspect(1)
                ax.set_xticks([])
                ax.set_yticks([])

        plt.show()

In [None]:
# Same visualisation as the previous cell, for the first 40 singular vectors:
# one image per (singular vector, gamma) pair, showing the absolute vector entries
# summed over the first axis of img_shape.
# The first i_vecs assignment is immediately overwritten by the second.
i_vecs = [0,1,2,3,4,5, 50]
i_vecs = np.arange(40)

# small dataset
# (overwritten by the "big dataset" assignments below; kept as a manual toggle)
vals_per_weight, vecs_per_weight = svals__l2__gammas5, svecs__l2__gammas5
gammas = gammas5

# big dataset
i_gammas = [0,1,2,4,8,16,21]
vals_per_weight, vecs_per_weight = svals__l2__gammas_0_1_21_inf, svecs__l2__gammas_0_1_21_inf
gammas = np.array(gammas_0_1_21_inf)
# keep only the first entry of the two leading axes, the selected gammas and the selected vectors
vals_per_weight, vecs_per_weight, gammas = vals_per_weight[:1, :1, i_gammas][:, :, :, i_vecs], vecs_per_weight[:1, :1, i_gammas][:, :, :, i_vecs], gammas[i_gammas]


img_shape = (8, 28, 28) # in 2nd layer

for vals_per_point, vecs_per_point in zip(vals_per_weight, vecs_per_weight):
    for vals_per_gamma, vecs_per_gamma in zip(vals_per_point, vecs_per_point):
        
        # nrow = number of selected vectors, ncol = number of selected gammas
        nrow, ncol = len(i_vecs), len(vals_per_gamma)
        # NOTE(review): matplotlib's figsize is (width, height), but here the width is
        # scaled by the number of rows and the height by the number of columns - confirm
        # this is intended.
        fig, axs = plt.subplots(nrow, ncol, figsize=(nrow*15, ncol*15))
        fig.tight_layout(pad=0)
        # after the transpose, axs[g] holds the axes of all vectors for gamma g
        axs = np.array(axs).T
        
        for vals, vecs, ax_per_vec, gamma in zip(vals_per_gamma, vecs_per_gamma, axs, gammas):
            # reshape each flat vector to img_shape, then sum absolute values over its first axis
            vecs_reshaped = np.abs(vecs.reshape((len(i_vecs), *img_shape))).sum(axis=1)

            for i_vec, val, vec, ax in zip(i_vecs, vals, vecs_reshaped, ax_per_vec):
                ax.imshow(vec)
        
                ax.set_box_aspect(1)
                ax.set_xticks([])
                ax.set_yticks([])
                # ax.set_title(f"gamma = {gamma}, ")
        
        
        plt.show()

## Plot ratio of first and second Sval

In [None]:
plot_vals_lineplot(vals_per_weight, gammas, one_plot_per='point', yscale='log', ylim="p99.99")

In [None]:
vals_per_weight = svals__l2__gammas5
gammas = gammas5

vals_per_weight = svals__l2__gammas_0_1_21_inf
gammas = gammas_0_1_21_inf

ratio_two_largest_svals = vals_per_weight[:, :, :, 0] / vals_per_weight[:, :, :, 1]
plot_vals_lineplot(ratio_two_largest_svals[:, :, :, None], gammas, ylim="p95")

## Plot correlation between Svec sparsity and Sval.

In [None]:
svals__l2__gammas5.shape, svec_sparsity__l2__gammas5.shape

In [None]:
vals_per_weight, spars_per_weight = svals__l2__gammas5, svec_sparsity__l2__gammas5
gammas = gammas5

assert vals_per_weight.shape == spars_per_weight.shape

for vals_per_point, spars_per_point in zip(vals_per_weight, spars_per_weight):
    for vals_per_gamma, spars_per_gamma in zip(vals_per_point, spars_per_point):
        fig, axs = plt.subplots(1, len(vals_per_gamma), figsize=(20, 4), sharex=True, sharey=True)
        
        for vals, spars, ax, gamma in zip(vals_per_gamma, spars_per_gamma, axs, gammas):
            ax.set_box_aspect(1)
            ax.set_xscale('log')
            ax.set_title(f"gamma = {gamma}")
            ax.scatter(vals, spars)
                
        plt.show()

## Similarity between singular vectors of similar datapoints

Calculate the similarity between the svecs of a pair of points. 
- Point 1 is a datapoint from the test set (or its following hidden activations).
- Point 2 is a slightly perturbed version of Point 1.

In [None]:
def perturb_point(point, k=1, mode='normal', var=.5, clip=(0, 1), activated_subnetwork=True, seed=1):
    """Return `k` randomly perturbed copies of `point`.

    NOTE: this definition shadows the `perturb_point` imported from
    `learning_lrp` at the top of the notebook.

    Args:
        point: tensor to perturb.
        k: number of perturbed copies to draw.
        mode: noise model. Only modes containing 'normal' (additive Gaussian
            noise with mean 0) are implemented.
        var: handed to `torch.normal` as its second argument, i.e. it is used
            as the standard deviation of the noise, not as its variance.
        clip: `(low, high)` bounds applied to the result. Either bound, or the
            whole argument, may be None to skip that clipping step.
        activated_subnetwork: if True, noise is only added where `point > 0`;
            all other entries keep their value (before clipping).
        seed: seed passed to `torch.manual_seed`. This reseeds torch's global
            RNG on every call, so equal arguments give equal results.

    Returns:
        Tensor of shape `(k, *point.shape)`.

    Raises:
        ValueError: if `mode` is not supported (previously this surfaced as an
            UnboundLocalError further down).
    """
    if 'normal' not in mode:
        raise ValueError(f"Unsupported perturbation mode: {mode!r} (expected a mode containing 'normal')")

    torch.manual_seed(seed)
    perturbation = torch.normal(0, var, size=(k, *point.shape))

    if activated_subnetwork:
        # Zero the noise wherever the reference point is not strictly positive.
        perturbation *= point[None] > 0

    point_perturbed = point[None] + perturbation

    if clip is not None:
        low, high = clip[0], clip[1]
        if low is not None:
            point_perturbed = point_perturbed.clip(min=low)
        if high is not None:
            point_perturbed = point_perturbed.clip(max=high)

    return point_perturbed

### Layer 2 (perturb layer 0 input)

In [None]:
# Build the matrix for layer 2 from the PyTorch layer, using the per-layer
# activation shapes (`A_shapes`) recorded near the top of the notebook.
l=2
forw_l2 = global_conv_matrix_from_pytorch_layer(layers[l], inp_shape=A_shapes[l], out_shape=A_shapes[l+1])
# Bare expression: displayed as the cell output.
forw_l2

In [None]:
# First sample of the first activation list entry (the input to layer 0).
point = A[0][0]
# Row 0 is the unperturbed point, rows 1..4 are perturbed copies (noise scale var=.1, noise on every entry).
points_perturbed = torch.vstack((point[None], perturb_point(point, k=4, activated_subnetwork=False, var=.1)))
points_perturbed.shape

#### Visualize Perturbed points

In [None]:
# Show all points next to each other in one 28-row image.
# assumes points_perturbed is (n_points, 1, 28, 28) — TODO confirm
plt.imshow(points_perturbed.transpose(0, 2)[:, 0].reshape((28, -1)))

In [None]:
#### Playground

In [None]:
# One matrix function: LRP_global_mat with the model and the layer range 2 -> 3 fixed.
mat_funcs = [partial(LRP_global_mat, model=model_d3, l_inp=2, l_out=3)]
# Evaluate it for every gamma in gammas5 and every (flattened) point.
LRP_mats = calc_mats_batch_functional(mat_funcs, gammas5, points_perturbed.view((len(points_perturbed), -1)), tqdm_for='point')

In [None]:
# Number of values/vectors requested per matrix.
n_svecs = 100
# `return_vecs=True` makes the call return vectors in addition to the values.
svals, svecs = calc_vals_batch(LRP_mats, num_vals=n_svecs, tqdm_for='point', return_vecs=True)

svals.shape, svecs.shape

In [None]:
# Compare point 0 (unperturbed) with each of the points 1..4.
# NOTE: `svec_similarity` is defined in a cell further down in this notebook
# and has to be executed before this one.
for i in range(1, 5):
    svec_similarity(*svecs[0, [0, i]])

In [None]:
def svec_sparsity(svecs, cutoff=1e-6):
    """Return the fraction of entries along axis 4 whose magnitude exceeds `cutoff`.

    Despite the name this is a density: 1.0 means no entry is (near) zero.
    The former lambda read `> cutoff * 1.`, which multiplies the cutoff rather
    than the boolean mask; the mean of the boolean mask is the same number.
    """
    return (np.abs(svecs) > cutoff).mean(axis=4)


svec_sparsity__l2__perturbed_points__gammas5 = svec_sparsity(svecs)
# svals
# NOTE(review): the y-label says "Top 50 Svals" although the plotted quantity is the value computed above.
plot_vals_lineplot(svec_sparsity__l2__perturbed_points__gammas5, gammas5, one_plot_per='point', ylabel="Top 50 Svals", ylim=1)

#### Correctly labeled

In [None]:
# One matrix function: LRP_global_mat with the model and the layer range 2 -> 3 fixed.
mat_funcs = [partial(LRP_global_mat, model=model_d3, l_inp=2, l_out=3)]
# NOTE(review): the result is named "var05", but the cell that builds `points_perturbed`
# above calls perturb_point with var=.1 — confirm which noise level was actually used.
LRP__l2__perturbed_points_var05__gammas5 = calc_mats_batch_functional(mat_funcs, gammas5, points_perturbed.view((len(points_perturbed), -1)), tqdm_for='point')

In [None]:
n_svecs = 100
# Manual switch: change `False` to `True` to recompute and save the results;
# as written, the previously saved results are loaded instead.
if False:
    svals__l2__perturbed_points_var05__gammas5, svecs__l2__perturbed_points_var05__gammas5 = calc_vals_batch(LRP__l2__perturbed_points_var05__gammas5, num_vals=n_svecs, tqdm_for='point', return_vecs=True)
    save_data('d3', 'svals__l2__perturbed_points_var05__gammas5', svals__l2__perturbed_points_var05__gammas5)
    save_data('d3', 'svecs__l2__perturbed_points_var05__gammas5', svecs__l2__perturbed_points_var05__gammas5)
else:
    svals__l2__perturbed_points_var05__gammas5 = load_data('d3', 'svals__l2__perturbed_points_var05__gammas5')
    svecs__l2__perturbed_points_var05__gammas5 = load_data('d3', 'svecs__l2__perturbed_points_var05__gammas5')
    
svals__l2__perturbed_points_var05__gammas5.shape, svecs__l2__perturbed_points_var05__gammas5.shape

In [None]:
n_svecs = 50
# Manual switch: change `False` to `True` to recompute and save the results;
# as written, the previously saved results are loaded instead.
# NOTE(review): the recompute branch reads `LRP__l2__perturbed_points_var01__gammas5`,
# which none of the nearby cells assign — it has to be computed first.
if False:
    svals__l2__perturbed_points_var01__gammas5, svecs__l2__perturbed_points_var01__gammas5 = calc_vals_batch(LRP__l2__perturbed_points_var01__gammas5, num_vals=n_svecs, tqdm_for='point', return_vecs=True)
    save_data('d3', 'svals__l2__perturbed_points_var01__gammas5', svals__l2__perturbed_points_var01__gammas5)
    save_data('d3', 'svecs__l2__perturbed_points_var01__gammas5', svecs__l2__perturbed_points_var01__gammas5)
else:
    svals__l2__perturbed_points_var01__gammas5 = load_data('d3', 'svals__l2__perturbed_points_var01__gammas5')
    svecs__l2__perturbed_points_var01__gammas5 = load_data('d3', 'svecs__l2__perturbed_points_var01__gammas5')
    
svals__l2__perturbed_points_var01__gammas5.shape, svecs__l2__perturbed_points_var01__gammas5.shape

In [None]:
def svec_similarity(p1_svecs, p2_svecs):
    """Plot the absolute pairwise inner products between two sets of singular vectors.

    Args:
        p1_svecs: array of shape (n_panels, n_vecs_1, dim) for the unperturbed point.
        p2_svecs: array of shape (n_panels, n_vecs_2, dim) for the perturbed point.
            The leading axis is presumably one entry per gamma — confirm against the caller.

    Prints the shape of `p1_svecs` and the maximum similarity per panel, then
    shows all panels side by side, separated by a one-pixel column.
    """
    print(p1_svecs.shape)
    # Inner product of every vector in p1 with every vector in p2, per panel.
    # This equals the cosine similarity only for unit-norm vectors (assumed for singular vectors).
    # The sign of a singular vector is arbitrary, hence the absolute value.
    cosine_sim = np.abs(np.einsum('ijx,ikx->ijk', p1_svecs, p2_svecs))
    print(cosine_sim.max(axis=(1, 2)))

    # The separator takes its height from the data, not from the global `n_svecs`,
    # so the function also works when the two disagree. Its value 2 lies above vmax.
    separator = np.full((cosine_sim.shape[1], 1), 2)
    panels = [cosine_sim[0]]
    for panel in cosine_sim[1:]:
        panels += [separator, panel]

    plt.figure(figsize=(15, 10))
    plt.imshow(np.hstack(panels), vmin=0, vmax=1)

    plt.ylabel('Svecs of unperturbed point (Decreasing Svals from Top to Down)')
    plt.xlabel('Svecs of perturbed point (Decreasing Svals from Left to Right)')
    plt.show()

#### Variance 0.1

In [None]:
# Compare point 0 (unperturbed) with each of the points 1..4.
# The similarity is defined on singular *vectors*, so the `svecs__…` array is
# passed here (the cell previously passed the `svals__…` array by mistake).
for i in range(1, 5):
    svec_similarity(*svecs__l2__perturbed_points_var01__gammas5[0, [0, i]])

In [None]:
def svec_sparsity(svecs, cutoff=1e-6):
    """Return the fraction of entries along axis 4 whose magnitude exceeds `cutoff`.

    Despite the name this is a density: 1.0 means no entry is (near) zero.
    """
    return (np.abs(svecs) > cutoff).mean(axis=4)


# The measure reduces over axis 4, i.e. it needs the singular-vector array;
# the cell previously passed the `svals__…` array by mistake.
svec_sparsity__l2__perturbed_points__gammas5 = svec_sparsity(svecs__l2__perturbed_points_var01__gammas5)
# svals
# NOTE(review): the y-label says "Top 50 Svals" although the plotted quantity is the value computed above.
plot_vals_lineplot(svec_sparsity__l2__perturbed_points__gammas5, gammas5, one_plot_per='point', ylabel="Top 50 Svals", ylim=1)

#### Variance 0.5

In [None]:
# Compare point 0 (unperturbed) with each of the points 1..4.
for i in range(1, 5):
    svec_similarity(*svecs__l2__perturbed_points_var05__gammas5[0, [0, i]])

In [None]:
def svec_sparsity(svecs, cutoff=1e-6):
    """Return the fraction of entries along axis 4 whose magnitude exceeds `cutoff`.

    Despite the name this is a density: 1.0 means no entry is (near) zero.
    The former lambda read `> cutoff * 1.`, which multiplies the cutoff rather
    than the boolean mask; the mean of the boolean mask is the same number.
    """
    return (np.abs(svecs) > cutoff).mean(axis=4)


svec_sparsity__l2__perturbed_points__gammas5 = svec_sparsity(svecs__l2__perturbed_points_var05__gammas5)
# svals
# NOTE(review): the y-label says "Top 50 Svals" although the plotted quantity is the value computed above.
plot_vals_lineplot(svec_sparsity__l2__perturbed_points__gammas5, gammas5, one_plot_per='point', ylabel="Top 50 Svals", ylim=1)

### Layer 2 (perturb hidden activations directly)
The issue with perturbing hidden activations directly with Gaussian noise might be that it changes their characteristics unduly (even when they are clipped properly, etc.).
In later layers of the NN, the hidden activations have already been successfully condensed to the *Signal* and contain less of the *Distractor*. (Ref. Alber PatternNet paper)

Some evidence in that direction is that the Singular Vectors of an LRP matrix with [hidden_activation + noise] as the reference point are less sparse than the Singular Vectors of a usual LRP matrix with just [hidden_activation] as the reference point.

In [None]:
# Build the matrix for layer 2 from the PyTorch layer, using the per-layer
# activation shapes (`A_shapes`) recorded near the top of the notebook.
l=2
forw_l2 = global_conv_matrix_from_pytorch_layer(layers[l], inp_shape=A_shapes[l], out_shape=A_shapes[l+1])
# Bare expression: displayed as the cell output.
forw_l2

In [None]:
# First sample of the activations that feed layer `l` (hidden activations, not the network input).
point = A[l][0]
# Row 0 is the unperturbed activation, followed by the output of get_perturbations.
# NOTE(review): `get_perturbations` is not defined in the nearby cells — presumably it comes
# from one of the star imports or an earlier cell; verify that it returns k=4 perturbed copies.
points_perturbed = torch.vstack((point[None], get_perturbations(point, k=4)))
points_perturbed.shape

In [None]:
# Quick sanity check: sum over all entries of all points.
points_perturbed.sum()

In [None]:
# Per point: fraction of non-zero entries and fraction of negative entries.
# The number of points is taken from the data instead of being hard-coded to 5.
(
    ((points_perturbed.reshape((len(points_perturbed), -1)) != 0)*1.).mean(axis=1),
    ((points_perturbed.reshape((len(points_perturbed), -1)) < 0)*1.).mean(axis=1),
)

In [None]:
# LRP matrices for the perturbed hidden activations built above, one per gamma in gammas5.
# NOTE(review): this used to read `points`, which this section never assigns; the
# perturbed activations of this section are `points_perturbed`. Revert if `points`
# from an earlier cell was really intended.
LRP__l2__perturbed_activations__gammas5 = calc_mats_batch([forw_l2], points_perturbed.detach().reshape(len(points_perturbed), -1), gammas=gammas5, mode='back clip')
LRP__l2__perturbed_activations__gammas5.shape

In [None]:
# Number of singular values/vectors used by the following cells.
n_svecs = 50

In [None]:
# Manual switch: change `False` to `True` to recompute and save the results;
# as written, the previously saved results are loaded instead.
if False:
    svals__l2__perturbed_activations__gammas5, svecs__l2__perturbed_activations__gammas5 = calc_vals_batch(LRP__l2__perturbed_activations__gammas5, num_vals=n_svecs, tqdm_for='point', return_vecs=True)
    save_data('d3', 'svals__l2__perturbed_activations__gammas5', svals__l2__perturbed_activations__gammas5)
    save_data('d3', 'svecs__l2__perturbed_activations__gammas5', svecs__l2__perturbed_activations__gammas5)
else:
    svals__l2__perturbed_activations__gammas5 = load_data('d3', 'svals__l2__perturbed_activations__gammas5')
    svecs__l2__perturbed_activations__gammas5 = load_data('d3', 'svecs__l2__perturbed_activations__gammas5')
    
svals__l2__perturbed_activations__gammas5.shape, svecs__l2__perturbed_activations__gammas5.shape

In [None]:

# Take the first two entries of axis 1 (presumably point 0 = unperturbed and point 1 = first
# perturbed copy — confirm) and keep at most the first n_svecs vectors of each.
p1, p2 = svecs__l2__perturbed_activations__gammas5[0, :2, :, :n_svecs, :]
p1.shape

In [None]:
# Absolute inner product of every vector in p1 with every vector in p2, per leading index.
# For unit-norm vectors this is the cosine similarity up to sign.
cosine_sim = np.abs(np.einsum('ijx,ikx->ijk', p1, p2))
cosine_sim.shape

In [None]:
# Largest similarity per leading index (maximum over all vector pairs).
cosine_sim.max(axis=(1,2))

In [None]:
# Largest similarity per leading index (maximum over all vector pairs).
# NOTE(review): identical to the previous cell — probably kept to compare two runs.
cosine_sim.max(axis=(1,2))

In [None]:
# Lay the per-index similarity matrices side by side in one image, on a colour scale from 0 to the global maximum.
# Requires cosine_sim.shape[1] == n_svecs for the reshape to keep the rows intact.
plt.imshow(cosine_sim.transpose((1,0,2)).reshape((n_svecs, -1)), vmin=0, vmax=cosine_sim.max())

In [None]:
# Lay the per-index similarity matrices side by side in one image, on a colour scale from 0 to the global maximum.
# NOTE(review): identical to the previous cell — probably kept to compare two runs.
plt.imshow(cosine_sim.transpose((1,0,2)).reshape((n_svecs, -1)), vmin=0, vmax=cosine_sim.max())

In [None]:
# Same side-by-side view, with a one-pixel separator column between neighbouring panels.
maxi = cosine_sim.max()
l = [cosine_sim[0]]
# The separator is filled with the global maximum, so it is drawn in the top colour of the scale.
# Its height is n_svecs, which has to match cosine_sim.shape[1] for hstack to work.
for p in cosine_sim[1:]: l += [np.full((n_svecs, 1), maxi), p]

plt.figure(figsize=(15, 10))
plt.imshow(np.hstack(l), vmin=0, vmax=maxi)

In [None]:
# Similarity matrix for the last entry of the leading axis (presumably the last gamma).
i_gamma = -1

sims = cosine_sim[i_gamma]
maxi= sims.max()

# The picture does not depend on any per-vector index, so it is drawn once.
# (It used to be redrawn unchanged n_svecs times inside a loop whose variable was never used.)
plt.imshow(sims, vmin=0, vmax=maxi)
plt.show()
    