In [2]:
import sys
sys.path.append('../')

import numpy as np
from scipy import ndimage

from skimage.filters import sobel_h
from skimage.filters import sobel_v
from scipy import stats


import os
import matplotlib
import matplotlib.pyplot as plt
import scienceplots
from tensorflow.python.client import device_lib

#plt.rcParams['figure.figsize'] = [10,10]

import tensorflow as tf
from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications.vgg16 import decode_predictions
from tensorflow.keras.applications import VGG16

from tensorflow.nn import depthwise_conv2d
from tensorflow.math import multiply, reduce_sum, reduce_mean,reduce_euclidean_norm, sin, cos, abs
from tensorflow import stack, concat, expand_dims

import tensorflow_probability as tfp

from utils.utils import *


plt.style.use(['science', 'ieee'])
plt.rcParams.update({'figure.dpi': '100'})

#### Experiment #1 : $\beta$ accross all layers (Top-10% filters)

In [3]:
model = VGG16(weights='imagenet',
                  include_top=False,
                  input_shape=(224, 224, 3))

In [4]:
k = 20 # Top 10% highest l2-magnitude filters

In [95]:
conv_layers = []
for l in model.layers:
    if 'conv2d' in str(type(l)).lower():
        conv_layers.append(l)

full_a_a = []
full_a_s = []
full_s_a = []
full_s_s = []
for l_num, l in enumerate(conv_layers[:-1]):
    print(f" ============ LAYER {l_num} ==================")
    top10F=topKfilters(model,l_num, k=k)
    print(l_num, top10F)
    filters = get_filter(model, l_num)[:,:,:, top10F]
    s, a = getSymAntiSymTF(filters)
    a_e = reduce_euclidean_norm(a, axis=[0,1])**2
    #print(a_mag.shape)

    s_e = reduce_euclidean_norm(s, axis=[0,1])**2
    beta = tf.reduce_mean(a_e/(s_e+a_e), axis=0)


    antisym_filters = tf.reshape(tf.where(beta >= 0.5), [-1])
    sym_filters = tf.reshape(tf.where(beta < 0.5), [-1])


    #print(antisym_filters)

    '''a_mag = reduce_euclidean_norm(a, axis=[0,1, 2])**2
    a_mag = a_mag[top10F]
    s_mag = reduce_euclidean_norm(s, axis=[0,1, 2])**2
    s_mag = s_mag[top10F]'''

    a_a = 0
    a_s = 0
    s_a = 0
    s_s = 0

    nxt_filters = get_filter(model, l_num+1)
    nxt_s, nxt_a = getSymAntiSymTF(nxt_filters)
    nxt_a_mag = reduce_euclidean_norm(nxt_a, axis=[0,1])**2
    nxt_s_mag = reduce_euclidean_norm(nxt_s, axis=[0,1])**2
    nxt_avg_amag = reduce_mean(nxt_a_mag, axis=0).numpy()
    nxt_avg_smag = reduce_mean(nxt_s_mag, axis=0).numpy()

    for idx in topKfilters(model, l_num+1, k=k):
        #print("Filter ", idx , " : ", topKchannels(l_num+1,idx),"--->",set(topKchannels(l_num+1,idx)).intersection(top10F))
        #Is this filter symetric or antisymmetric ?
        is_anti = nxt_avg_amag[idx] > nxt_avg_smag[idx]
        #print()
        # Out of its top-k connections, how many inputs are comming from sym-filters, 
        #                               how many inputs are comming from antisym-filters?

        top_k_connections = list(set(topKchannels(model,l_num+1,idx, k=20)).intersection(top10F))
        print("TKC",top_k_connections)
        #Check : does top_k_connections lie in antisym_filters or in sym_filters (or neither) ?  
        #print (top_k_connections, antisym_filters.shape)
        #print(tf.squeeze(antisym_filters) )

        #print("HERE" , set(tf.squeeze(antisym_filters).numpy()).intersection(set(top_k_connections)))
        #print("HERE" , set(tf.squeeze(sym_filters).numpy()).intersection(set(top_k_connections)))
        if is_anti: 
            a_a += len(list(set(antisym_filters.numpy()).intersection(set(top_k_connections))))
            print(idx, "Anti->Anti : ", list(set(antisym_filters.numpy()).intersection(set(top_k_connections))))
            s_a += len(list(set(sym_filters.numpy()).intersection(set(top_k_connections))))
            print(idx,"Sym->Anti : ", list(set(sym_filters.numpy()).intersection(set(top_k_connections))))

        else:
            a_s += len(list(set((antisym_filters).numpy()).intersection(set(top_k_connections))))
            print(idx, "Anti->Sym : ", list(set((antisym_filters).numpy()).intersection(set(top_k_connections))))
            s_s += len(list(set((sym_filters).numpy()).intersection(set(top_k_connections))))
            print(idx, "Sym->Sym : ", list(set((sym_filters).numpy()).intersection(set(top_k_connections))))

    num_connections = a_a+a_s+s_a+s_s
    full_a_a.append(a_a/num_connections)
    full_a_s.append(a_s/num_connections)
    full_s_a.append(s_a/num_connections)
    full_s_s.append(s_s/num_connections)

    print("A -> A ",  a_a)
    print("A -> S ",  a_s)
    print("S -> A ",  s_a)
    print("S -> S ",  s_s)
    print(top_k_connections)

0 [10, 45, 51, 0, 47, 62, 3, 37, 58, 26, 41, 34]
TKC []
18 Anti->Sym :  []
18 Sym->Sym :  []
TKC [34, 3, 37, 41, 45, 51, 58, 62]
59 Anti->Anti :  [3]
59 Sym->Anti :  []
TKC [0, 34, 3, 37, 41, 10, 45, 62]
19 Anti->Anti :  [0, 10, 3]
19 Sym->Anti :  []
TKC [0, 3, 37, 10, 47, 51, 58]
0 Anti->Anti :  [0, 10, 3]
0 Sym->Anti :  []
TKC [0, 37, 10, 45, 51, 58, 62]
29 Anti->Anti :  [0, 10]
29 Sym->Anti :  []
TKC [0, 3, 37, 10, 47, 51]
10 Anti->Anti :  [0, 10, 3]
10 Sym->Anti :  []
TKC [34, 3, 37, 10, 45, 51, 26, 62]
30 Anti->Anti :  [10, 3]
30 Sym->Anti :  []
TKC [0, 3, 37, 10, 47, 51, 58]
47 Anti->Anti :  [0, 10, 3]
47 Sym->Anti :  []
TKC [41]
52 Anti->Sym :  []
52 Sym->Sym :  []
TKC [3, 26, 10, 45, 51, 58, 62]
55 Anti->Anti :  [10, 3]
55 Sym->Anti :  []
TKC [34, 37, 10, 45, 51, 58, 62]
13 Anti->Sym :  [10]
13 Sym->Sym :  []
TKC [0, 34, 3, 37, 41, 10, 45, 47, 26, 62]
20 Anti->Anti :  [0, 10, 3]
20 Sym->Anti :  []
A -> A  22
A -> S  1
S -> A  0
S -> S  0
[0, 34, 3, 37, 41, 10, 45, 47, 26, 62]
1

In [92]:
num_connections

2759

In [108]:
full_a_a
f = open("figures/vgg16_as_tree.tex", "w+")

f.write("\\begin{tikzpicture}\n\\begin{scope}[every node/.style={circle,thick,draw}, scale=0.8, transform shape]\n ")

pos = len(conv_layers)
for i in range(len(conv_layers)):
    pos -= 2
    f.write(f"    \\node (A_{i}) at (-2,{pos}) {{A}};\n")
    f.write(f"    \\node (S_{i}) at (2,{pos}) {{S}};\n")


f.write("\end{scope}")

f.write("\\begin{scope}[>={Stealth[black]},every node/.style={fill=white,circle},every edge/.style={draw=black, thin} ,scale=0.8, transform shape]\n ")
for i in range(len(conv_layers)-1):

    aa_max = "[red, very thick]" if np.argmax([full_a_a[i], full_a_s[i],full_s_a[i],full_s_s[i]]) == 0 else ""
    as_max = "[red, very thick]" if np.argmax([full_a_a[i], full_a_s[i],full_s_a[i],full_s_s[i]]) == 1 else ""
    sa_max = "[red, very thick]" if np.argmax([full_a_a[i], full_a_s[i],full_s_a[i],full_s_s[i]]) == 2 else ""
    ss_max = "[red, very thick]" if np.argmax([full_a_a[i], full_a_s[i],full_s_a[i],full_s_s[i]]) == 3 else ""

    pos -= 2
    f.write(f"        \\path [->] (A_{i}) edge {as_max}  node[near start]  {{${full_a_s[i]:.2f}$}} (S_{i+1});\n")
    f.write(f"        \\path [->] (S_{i}) edge {sa_max}  node[near start] {{${full_s_a[i]:.2f}$}} (A_{i+1});\n")
    f.write(f"        \\path [->] (A_{i}) edge {aa_max}  node[left] {{${full_a_a[i]:.2f}$}} (A_{i+1});\n")
    f.write(f"        \\path [->] (S_{i}) edge {ss_max}  node[right] {{${full_s_s[i]:.2f}$}} (S_{i+1});\n")

f.write("\end{scope}")


f.write("\end{tikzpicture}")


f.close()


In [98]:
full_a_a

[0.9565217391304348,
 0.0,
 0.7222222222222222,
 0.32061068702290074,
 0.319047619047619,
 0.35664335664335667,
 0.42955326460481097,
 0.39285714285714285,
 0.33640552995391704,
 0.3782771535580524,
 0.05098684210526316,
 0.0]