In [None]:
%env CUDA_VISIBLE_DEVICES=""

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from collections import OrderedDict
import tqdm
import keras
import itertools
from IPython.display import display, IFrame
from when_explanations_lie import cosine_similarity_dot, load_model
sns.set()

In [None]:
def cosine_similarity(U, V):
    """Cosine similarity between every column of ``V`` and every column of ``U``.

    Entry ``[i, j]`` of the result is cos(V[:, i], U[:, j]). Columns with zero
    norm are divided by zero and therefore yield NaN entries.
    """
    def unit_columns(mat):
        # scale each column to unit L2 norm
        return mat / np.linalg.norm(mat, axis=0, keepdims=True)

    return unit_columns(V).T @ unit_columns(U)

In [None]:
def measure_convergence(matrices, metric='std', relu=False, normalize=False,
                        include_self=False):
    """Record a metric of the running product of a chain of matrices.

    Starting from the identity, ``matrices`` are multiplied left to right
    (``product = product @ m``) and the metric is evaluated after every
    multiplication, so the result has one entry per element of ``matrices``.

    Args:
        matrices: sequence of 2-D arrays whose shapes chain under ``@``; the
            first one is typically the batch of inputs (one row per sample).
        metric: ``'cos_similarity'`` records the absolute cosine similarity of
            every distinct pair of rows of the running product; ``'std'``
            records ``np.std`` of the running product; any other callable is
            applied to the running product.
        relu: if True, clip the running product at 0 after every
            multiplication except the first one (the raw input).
        normalize: if True, rescale the running product with
            ``columns_sum_to_one`` after every multiplication.
        include_self: only used with ``'cos_similarity'``. Earlier versions
            also kept the diagonal (each row with itself, always 1), which
            biases the mean towards 1; pass True to reproduce those numbers.

    Returns:
        list with one entry per multiplication. For ``'cos_similarity'`` each
        entry is a 1-D array of |cos| values with NaNs (zero rows) dropped.

    Raises:
        ValueError: if ``metric`` is neither a known name nor callable.
    """
    if metric == 'cos_similarity':
        def measure(product):
            return _pairwise_abs_cosine(product, include_self)
    elif metric == 'std':
        # previously the string default was called directly -> TypeError
        measure = np.std
    elif callable(metric):
        measure = metric
    else:
        raise ValueError("unknown metric: {!r}".format(metric))

    if len(matrices) == 0:
        return []

    product = np.eye(matrices[0].shape[0])
    metrics = []
    for step, matrix in enumerate(matrices):
        product = product @ matrix
        # the first factor is the raw input; ReLU only follows later layers
        if relu and step > 0:
            product = np.clip(product, 0, np.inf)
        if normalize:
            product = columns_sum_to_one(product)
        metrics.append(measure(product))
    return metrics


def _pairwise_abs_cosine(rows, include_self=False):
    """|cos| of every unordered pair of rows of ``rows``, NaNs removed."""
    sims = cosine_similarity(rows.T, rows.T)
    # k=-1: strict lower triangle -> each distinct pair exactly once
    keep = np.tri(len(sims), k=0 if include_self else -1, dtype=bool)
    sims = sims[keep]
    return np.abs(sims[~np.isnan(sims)])


def rows_sum_to_one(w):
    """Divide ``w`` by its column totals.

    NOTE(review): despite the name, the sums are taken over axis 0, so it is
    each *column* of the result that sums to one. The behaviour is kept as-is
    because the stochastic-matrix experiments are built on it.
    """
    column_totals = np.sum(w, axis=0, keepdims=True)
    return w / column_totals

def columns_sum_to_one(w):
    """Divide ``w`` by its row totals.

    NOTE(review): despite the name, the sums are taken over axis 1, so it is
    each *row* of the result that sums to one.
    """
    row_totals = np.sum(w, axis=1, keepdims=True)
    return w / row_totals

def transpose_shapes(shapes):
    """Reverse a list of 2-D shapes and swap each ``(a, b)`` into ``(b, a)``.

    Turns the forward (in, out) layer shapes into the shapes seen when the
    chain is traversed backwards.
    """
    return [(shape[1], shape[0]) for shape in reversed(shapes)]

def conv_to_matrix(kernel):
    """Reduce a weight tensor to a 2-D (in, out) matrix.

    2-D inputs (dense weights) are returned unchanged. Otherwise the kernel is
    unpacked as ``(height, width, in, out)`` and the centre tap
    ``kernel[height // 2, width // 2]`` is returned; ranks other than 2 and 4
    raise ``ValueError`` from that unpacking.
    """
    if len(kernel.shape) != 2:
        height, width, _, _ = kernel.shape
        return kernel[height // 2, width // 2]
    return kernel

In [None]:
model, innv_net, _ = load_model('vgg16')

# Forward (in, out) matrix shape of every weight layer, in model order.
# Conv kernels are reduced to their centre tap (cin, cout); the dense layer
# with 25088 = 7 * 7 * 512 inputs (presumably the flattened 7x7 feature map)
# is reduced to a single spatial position, i.e. 25088 // 49 input rows.
vgg_forward_shps = []
for layer in model.layers:
    weights = layer.get_weights()
    # Only (kernel, bias) layers count. Skip the others explicitly instead of
    # swallowing every exception with a bare `except`.
    if len(weights) != 2:
        continue
    kernel = weights[0]
    if kernel.ndim == 4:
        shape = conv_to_matrix(kernel).shape
    elif kernel.shape[0] == 25088:
        shape = kernel[:25088 // (7 * 7), :].shape
    else:
        shape = kernel.shape
    vgg_forward_shps.append(shape)
    print(shape)

# Reversed order with swapped dims: the shapes met by a backward pass.
vgg_backward_shps = transpose_shapes(vgg_forward_shps)


# (kernel, bias) of every Conv2D / Dense layer, walking the model from the
# output back to the input; kernels are reduced to 2-D (in, out) matrices.
vgg_parameters = []
for layer in reversed(model.layers):
    if not isinstance(layer, (keras.layers.Conv2D, keras.layers.Dense)):
        continue
    w, b = layer.get_weights()
    if w.ndim == 4:
        w = conv_to_matrix(w)
    print(w.shape)
    if len(w) == 25088:
        # flattened 7x7xC input: keep only the centre (3, 3) spatial position
        cin, cout = w.shape
        w = w.reshape((7, 7, cin // (7 * 7), cout))[3, 3]
    vgg_parameters.append((w, b))

In [None]:
# Pattern-weighted matrices (elementwise A * W) with zeroed biases, one per
# layer, aligned with vgg_parameters (which runs output -> input). The patterns
# are reversed to line up; presumably they are stored input -> output — confirm.
pattern_weights = []
for parameters, pattern in zip(vgg_parameters, innv_net['patterns'][::-1]):
    w, b = parameters
    pattern_mat = conv_to_matrix(pattern)
    if pattern_mat.shape[0] == 25088:
        # flattened 7x7xC input: keep only the centre (3, 3) spatial position
        cin, cout = pattern_mat.shape
        pattern_mat = pattern_mat.reshape((7, 7, cin // (7 * 7), cout))[3, 3]
    weighted = w * pattern_mat
    pattern_weights.append([weighted, np.zeros_like(b)])

In [None]:
# Sanity check: negate every row i of w with cos_sim[i] < 0 (cos_sim computed
# against an all-ones reference), then recompute the similarities.
w = np.random.normal(size=(64, 64))
cos_sim = cosine_similarity_dot(w, np.ones_like(w))
negative = cos_sim < 0
print(cos_sim.shape, negative[None].shape)
mask = np.repeat(negative[None], 64, axis=0).T
w[mask] = -w[mask]

cosine_similarity_dot(w, np.ones_like(w))

In [None]:
def cosine_sim_great_zero(w, reference=None):
    """Negate, in place, every row of ``w`` whose similarity to ``reference`` is negative.

    ``cosine_similarity_dot(w, reference)`` must return one value per row of
    ``w``; row i is negated when that value is < 0. ``reference`` defaults to
    an all-ones array shaped like ``w``. ``w`` is modified in place and also
    returned.
    """
    ref = np.ones_like(w) if reference is None else reference
    similarity = cosine_similarity_dot(w, ref)
    n_cols = len(w[0])
    # broadcast the per-row flag across all columns
    flip = np.repeat((similarity < 0)[None], n_cols, axis=0).T
    w[flip] = -w[flip]
    return w

In [None]:
# NOTE(review): w.T @ w.T is (w @ w).T, not the Gram matrix w.T @ w — confirm intent.
plt.hist(np.ravel(w.T @ w.T))

In [None]:
square_size = 128
# one (square_size, square_size) matrix per VGG weight layer
square_shapes = [(square_size,) * 2 for _ in vgg_backward_shps]

In [None]:
# Sampling settings: random repetitions per configuration, and number of
# random input rows pushed through each matrix chain.
n_samples = 20
input_size = 32

with_relu = True
no_relu = False
use_nn = True   # simulate chains with the VGG16 layer shapes
use_sq = False  # simulate chains of square_size x square_size matrices

square_size = 128
square_shapes = [(square_size, square_size) for _ in range(len(vgg_forward_shps))]


# Each entry: (label, shapes, relu, matrix) where `matrix` is either a callable
# shape -> random matrix, or a list of (w, b) pairs whose w.T are used as
# fixed weights.
sq_nn_convergences = []
if use_sq:
    sq_nn_convergences.extend([
        ('sq vanilla', square_shapes, no_relu, lambda s: np.random.normal(size=s)),
        ('sq ReLU', square_shapes, with_relu, lambda s: np.random.normal(size=s)),
        ('sq stochastic', square_shapes, no_relu, lambda s: rows_sum_to_one(np.abs(np.random.normal(size=s)))),
        ('sq positive', square_shapes, no_relu, lambda s: np.abs(np.random.normal(size=s))),
        ('sq non-neg.', square_shapes, no_relu, lambda s: np.maximum(0, np.random.normal(size=s))),
    ])

if use_nn:
    sq_nn_convergences.extend([
        ('vanilla', vgg_backward_shps, no_relu, lambda s: np.random.normal(size=s)),
        ('ReLU', vgg_backward_shps, with_relu, lambda s: np.random.normal(size=s)),
        ('ReLU learned', vgg_backward_shps, with_relu, vgg_parameters),
        ('pattern $A \\odot W$', vgg_backward_shps, no_relu, pattern_weights),
        ('stochastic', vgg_backward_shps, no_relu, lambda s: rows_sum_to_one(np.abs(np.random.normal(size=s)))),
        ('positive', vgg_backward_shps, no_relu, lambda s: np.abs(np.random.normal(size=s))),
        ('non-neg.', vgg_backward_shps, no_relu, lambda s: np.maximum(0, np.random.normal(size=s))),
    ])


def get_alpha_beta_matrix(a, b):
    """Return a sampler ``wrapper(shape)`` for rescaled Gaussian matrices.

    Each call draws ``W ~ N(0, 1)`` of the given shape and returns
    ``a * W+ + b * W-``, where ``W+`` keeps the entries ``>= 0`` and ``W-``
    keeps the entries ``< 0`` (still negative — ``W-`` is not negated).
    """
    def wrapper(shape):
        sample = np.random.normal(size=shape)
        is_nonneg = sample >= 0
        is_neg = sample < 0
        return a * (sample * is_nonneg) + b * (sample * is_neg)
    return wrapper

def neg_idx_matrix(lam):
    """Return a sampler ``wrapper(shape)`` of |N(0, 1)| matrices with random sign flips.

    Each entry keeps its positive sign when a U[0, 1) draw exceeds ``lam`` and
    is negated otherwise, so for ``lam`` in [0, 1] entries are negated with
    probability ``lam``.
    """
    def wrapper(shape):
        magnitude = np.abs(np.random.normal(size=shape))
        keep_sign = np.random.uniform(size=shape) > lam
        return magnitude * keep_sign - (1 - keep_sign) * magnitude
    return wrapper

def _alpha_beta_entry(prefix, shapes, alpha):
    """Build one (label, shapes, relu, sampler) entry with beta = alpha - 1."""
    beta = alpha - 1
    return (prefix + '$\\alpha={},\\beta={}$'.format(alpha, beta),
            shapes, no_relu, get_alpha_beta_matrix(alpha, beta))


def sq_alpha_beta(alpha):
    """Alpha-beta entry on the square stand-in shapes (label prefixed with 'sq ')."""
    return _alpha_beta_entry('sq ', square_shapes, alpha)


def nn_alpha_beta(alpha):
    """Alpha-beta entry on the VGG16 backward layer shapes."""
    return _alpha_beta_entry('', vgg_backward_shps, alpha)

# Alpha-beta samplers on the VGG16 layer shapes, from alpha=10 down to alpha=1.
# Square-matrix variants (sq_alpha_beta) and random sign flips
# (neg_idx_matrix) are available but not part of this run.
alpha_beta_conv = [nn_alpha_beta(alpha) for alpha in (10, 5, 4, 3, 2, 1)]

# All configurations handed to the sampling loop.
convergences = alpha_beta_conv + sq_nn_convergences


In [None]:
def get_labels(convergences):
    """Return the label (first field) of each 4-tuple configuration entry."""
    labels = []
    for label, _shapes, _relu, _matrix in convergences:
        labels.append(label)
    return labels

In [None]:
from tqdm.auto import tqdm as progress_bar

# Seed once so that Restart & Run All reproduces the same samples and figures.
np.random.seed(0)

# corrs[label, sample, step] = mean |cos| between the rows of the running
# product after `step` multiplications (step 0 = the random input itself).
corrs = OrderedDict()

for label, shapes, relu, get_matrix in progress_bar(convergences):
    # Learned / pattern weights are fixed lists of (w, b): transpose them once.
    fixed_matrices = [w.T for w, _ in get_matrix] if isinstance(get_matrix, list) else None
    for sample in progress_bar(range(n_samples), desc=label):
        inputs = np.random.normal(size=(input_size, shapes[0][0]))
        if fixed_matrices is None:
            matrices = [get_matrix(shp) for shp in shapes]
        else:
            matrices = fixed_matrices
        vals = measure_convergence([inputs] + matrices, metric='cos_similarity', relu=relu)
        for step, v in enumerate(vals):
            corrs[label, sample, step] = np.nanmean(v)

In [None]:
# Matplotlib marker per curve, consumed in order when the linestyles are built.
markers = [
    "o", # circle
    "v", # triangle_down
    "^", # triangle_up
    "X", # x (filled)
    "s", # square
    "D", # diamond
    "<", # triangle_left
    "P", # plus (filled)
    "$O$", # mathtext: the letter O
    ">", # triangle_right
    "$V$", # mathtext: the letter V
    "$P$", # mathtext: the letter P
]

In [None]:
import os

# Output directory for the saved PDFs; portable equivalent of `mkdir -p`.
os.makedirs("figures/convergence_simulation/", exist_ok=True)

In [None]:
def plot_convergence(values, labels, linestyles, conf_intervals=None, title=None, save_path=None, ylogscale=False, clip_eps=1e-15,
                     legend='right', ymin=1e-12):
    """Plot similarity curves against the number of multiplications.

    Args:
        values: sequence of 1-D arrays, one curve each; index = number of
            multiplications.
        labels: legend label per curve.
        linestyles: per-curve dict of ``plt.plot`` keyword arguments; must
            contain ``'color'`` when ``conf_intervals`` are given.
        conf_intervals: optional sequence of ``(lower, upper)`` arrays drawn
            as shaded bands; ``None`` draws no bands.
        title: optional figure title.
        save_path: if given, the figure is also saved to this path.
        ylogscale: if True, ``1 - value`` is plotted on an inverted log axis
            whose ticks are relabelled ``1 - 1e-k``; otherwise the raw values
            are plotted on a normal linear axis.
        clip_eps: log mode only — band values are hidden where the curve's
            ``1 - value`` is below this.
        legend: ``'left'`` puts the legend left of the axes, anything else right.
        ymin: log mode only — smallest ``1 - value`` shown (top of the axis).
    """
    if conf_intervals is None:
        conf_intervals = itertools.repeat((None, None))
    with sns.axes_style('ticks', {'font.family': 'serif'}):
        plt.figure(figsize=(3.9, 2.2))
        for value, (conf_lower, conf_upper), label, linestyle in zip(
                values, conf_intervals, labels, linestyles):
            xs = np.arange(len(value))
            if ylogscale:
                inv_value = 1 - value
                plt.semilogy(xs, inv_value, label=label, **linestyle)
                if conf_upper is not None:
                    # bands live in the same 1 - x space; hide them where the
                    # curve itself has numerically reached 1
                    conf_upper = 1 - conf_upper
                    conf_lower = 1 - conf_lower
                    hidden = inv_value < clip_eps
                    conf_upper[hidden] = np.nan
                    conf_lower[hidden] = np.nan
            else:
                plt.plot(xs, value, label=label, **linestyle)
            if conf_upper is not None:
                plt.fill_between(xs, conf_upper, conf_lower, alpha=0.25, color=linestyle['color'])
        plt.title(title)
        if legend == 'left':
            plt.legend(bbox_to_anchor=(-.3, 1))
        else:
            plt.legend(bbox_to_anchor=(1, 1))
        plt.grid(True)
        if ylogscale:
            # axis inversion and "1 - t" tick labels only make sense for the
            # 1 - x log plot; previously they were also applied to linear plots
            _relabel_log_axis_as_one_minus(ymin)
            plt.ylabel('cosine similarity')
        else:
            plt.ylabel('abs. cosine similarity')
        plt.xlabel('multiplications')
        if save_path:
            plt.savefig(save_path, bbox_inches='tight', pad_inches=0.03)

        plt.show()


def _relabel_log_axis_as_one_minus(ymin):
    """Invert the current y axis down to ``ymin`` and label each tick ``t`` as ``1 - t``."""
    _, ymax = plt.ylim()
    # inverted limits: small 1 - x (similarity close to 1) ends up at the top
    plt.ylim(ymax, ymin)
    ticks, _ = plt.yticks()
    ticks = [t for t in ticks if ymin < t < 1]
    plt.yticks(ticks, labels=["1 - {:.0e}".format(t).replace("-0", "-") for t in ticks])
        
def collect_corr(corrs, selected_conv, percentile=99):
    """Summarise sampled similarities per configuration as median and central interval.

    Args:
        corrs: mapping ``(label, sample, step) -> float``.
        selected_conv: configuration 4-tuples ``(label, shapes, relu, matrix)``;
            ``1 + len(shapes)`` steps are read per sample.
        percentile: width of the central interval, e.g. 99 -> the 0.5th to
            99.5th percentiles.

    Returns:
        ``(values, conf_int)``: per configuration, the per-step median and a
        ``(lower, upper)`` tuple of per-step percentiles.

    Raises:
        KeyError: if ``corrs`` holds no samples for a label, or a step is missing.
    """
    tail = (100 - percentile) / 2
    values = []
    conf_int = []
    for label, shapes, _, _ in selected_conv:
        # Take the sample ids from corrs itself instead of the global
        # n_samples, so the summary always matches what was computed.
        samples = sorted({sample for lab, sample, _ in corrs if lab == label})
        if not samples:
            raise KeyError(label)
        n_steps = 1 + len(shapes)
        corr_per_label = np.array([
            [corrs[label, sample, step] for step in range(n_steps)]
            for sample in samples
        ])
        values.append(np.median(corr_per_label, 0))
        conf_int.append((np.percentile(corr_per_label, tail, axis=0),
                         np.percentile(corr_per_label, 100 - tail, axis=0)))
    return values, conf_int

In [None]:
values, conf_int = collect_corr(corrs, sq_nn_convergences, percentile=99)

# Solid lines for the first k curves, dashed afterwards; colours repeat every k.
k = 6
color_palette = sns.color_palette('colorblind', k)
line_kinds = ["-"] * k + ["--"] * k
linestyles = [
    dict(linestyle=kind, color=color, marker=marker)
    for kind, color, marker in zip(line_kinds, itertools.cycle(color_palette), markers)
]

labels = get_labels(sq_nn_convergences)
# conf_int is not passed, so no confidence bands are drawn.
plot_convergence(
    values,
    labels,
    linestyles[:len(values)],
    save_path="figures/convergence_simulation/convergence_nn_sq.pdf",
    ylogscale=True,
    legend='right',
)

In [None]:
# Embed the saved PDF for a quick visual check.
display(IFrame("figures/convergence_simulation/convergence_nn_sq.pdf", width=800, height=600))

In [None]:
values, conf_int = collect_corr(corrs, alpha_beta_conv, percentile=95)

# Solid lines for the first nc curves, dashed afterwards; colours repeat every nc.
nc = 6
color_palette = sns.color_palette('colorblind', nc)
linestyles = [
    dict(linestyle=kind, color=color, marker=marker)
    for kind, color, marker in zip(["-"] * nc + ["--"] * nc,
                                   itertools.cycle(color_palette),
                                   markers)
]

# conf_int is not passed, so no confidence bands are drawn.
plot_convergence(
    values,
    get_labels(alpha_beta_conv),
    linestyles[:len(values)],
    save_path="figures/convergence_simulation/convergence_ab.pdf",
    ylogscale=True,
    clip_eps=1e-20,
)


In [None]:
# Embed the saved PDF for a quick visual check.
display(IFrame("figures/convergence_simulation/convergence_ab.pdf", width=800, height=600))