In [1]:
# All notebook imports live in this cell. The original import order is kept on purpose:
# tensorflow and graph_tool can be sensitive to the order in which they are loaded.
from tensorflow.keras.models import load_model
from tensorflow.keras.layers import Dense
import mnist
import dnn
import os.path
import graph_tool as gt
from graph_tool.centrality import eigenvector
from graph_tool.inference.minimize import minimize_nested_blockmodel_dl
import nn2graph
import numpy as np
import pandas as pd
import statsmodels.stats as stats
from statsmodels.stats.weightstats import ztest
import matplotlib.pyplot as plt
import seaborn as sns

#### Build or load model

In [2]:
# Location of the cached Keras model: loaded when present, otherwise written after training.
MODEL_PATH = 'data/outputs/mnist_dnn_4x64_10.h5'
load_saved_model = True  # set to False to retrain even when a saved model exists

if load_saved_model and os.path.exists(MODEL_PATH):
    model = load_model(MODEL_PATH)
else:
    model = dnn.build_dense(mnist.X_train.shape[-1], hidden_dims=64)
    model.fit(mnist.X_train, mnist.y_train, batch_size=128, epochs=10,
              validation_data=(mnist.X_test, mnist.y_test))
    # Create the output directory first so saving does not fail on a fresh checkout.
    os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)
    model.save(MODEL_PATH)

In [3]:
# Metrics on the test split, exactly as returned by `model.evaluate`.
# NOTE(review): the recorded output (first value ~2.29, second ~0.097) is roughly what a
# 10-class model guessing uniformly would score -- confirm the loaded model is trained.
test_metrics = model.evaluate(mnist.X_test, mnist.y_test, verbose=0)
print(test_metrics)

# Model outputs for every test sample.
y_pred = model.predict(mnist.X_test)

[2.2895871978759765, 0.0974]


#### Convert neural network to graph

In [4]:
# Number of test samples requested from `mnist.subsample_test`.
N = 5000
rand_inds = mnist.subsample_test(N)

# NOTE(review): no seed is set in the visible cells; if `subsample_test` draws at random,
# the subsample (and every figure below) changes between runs -- confirm.
X_sub = mnist.X_test[rand_inds]
g_full, g_xs = nn2graph.dense_activations_to_graph(model, X_sub)

In [None]:
# Vertex count of every layer: the input width followed by the width of each Dense layer.
layer_sizes = [mnist.X_test.shape[-1]] + [layer.units for layer in model.layers if isinstance(layer, Dense)]

# `g_xs` was built from `mnist.X_test[rand_inds]` (and the plotting cells pair it with
# `y_test_orig[rand_inds]`), so position k in `g_xs` corresponds to the k-th subsampled
# label, not to test-set row k.
# NOTE(review): `y_test_orig` is not defined in any visible cell -- presumably the integer
# digit labels of the test split; it must be defined before this cell runs.
labels_sub = y_test_orig[rand_inds]

# Sample graphs with the most / fewest edges in the subsample.
edge_counts = [g.num_edges() for g in g_xs]
i_max = np.argmax(edge_counts)
i_min = np.argmin(edge_counts)
g_max = g_xs[i_max]
g_min = g_xs[i_min]

# Sample graphs bucketed by digit label (0-9).
g_by_class = [[g for g, label in zip(g_xs, labels_sub) if label == i] for i in range(10)]
print('edge max label: {}'.format(labels_sub[i_max]))
print('edge min label: {}'.format(labels_sub[i_min]))

#### Mean weighted degree distributions per class

In [None]:
# One row per sample graph: its digit label and the mean of its 'degree' vertex property
# taken over all vertices.
whole_graph_means = [float(g.vp['degree'].a.mean()) for g in g_xs]
g_xs_mean_degree_by_class = pd.DataFrame({
    'digit class': y_test_orig[rand_inds],
    'mean weighted degree': whole_graph_means,
})

fig, ax = plt.subplots(figsize=(8,6))
sns.violinplot(x='digit class', y='mean weighted degree',
               data=g_xs_mean_degree_by_class, ax=ax)
ax.set_title('Mean weighted degree by digit label')
plt.show()

In [None]:
# Same violin plot as above, one panel per layer. `a[st-l:st]` takes the `l` entries that end
# at cumulative offset `st`, i.e. layer i's vertices if vertices are stored layer by layer
# (assumed -- the slicing relies on it).
digit_labels = y_test_orig[rand_inds]
fig, axs = plt.subplots(2, 3, figsize=(3*8,2*6))
for i, (l, st) in enumerate(zip(layer_sizes, np.cumsum(layer_sizes))):
    ax = axs[0 if i < 3 else 1, i%3]
    # Use a layer-specific name: rebinding `g_xs_mean_degree_by_class` here would overwrite
    # the whole-network frame that the z-test cell further down reads.
    g_xs_mean_degree_by_class_l = pd.DataFrame({'digit class': digit_labels,
                                                'mean weighted degree': [float(g.vp['degree'].a[st-l:st].mean()) for g in g_xs]})
    sns.violinplot(x='digit class', y='mean weighted degree', data=g_xs_mean_degree_by_class_l,
                   ax=ax)
    ax.set_title('Mean weighted degree by digit label for layer {}'.format(i))
plt.show()

In [None]:
# Per-layer standard deviation of the 'degree' vertex property, grouped by digit label.
# One panel per entry of `layer_sizes`: a fixed 1x5 grid has one panel too few for the
# six layers (input + 4 hidden + output) and fails with an IndexError on the last one.
n_layers = len(layer_sizes)
digit_labels = y_test_orig[rand_inds]
fig, axs = plt.subplots(1, n_layers, figsize=(n_layers*12,8), squeeze=False)
for i, (l, st) in enumerate(zip(layer_sizes, np.cumsum(layer_sizes))):
    ax = axs[0, i]
    g_xs_std_degree_by_class = pd.DataFrame({'digit class': digit_labels,
                                             'stddev weighted degree': [float(g.vp['degree'].a[st-l:st].std()) for g in g_xs]})
    sns.violinplot(x='digit class', y='stddev weighted degree', data=g_xs_std_degree_by_class,
                   ax=ax)
    ax.set_title('Standard deviation of weighted degree by digit label for layer {}'.format(i))
plt.show()

In [None]:
# Pairwise two-sample z-tests on the whole-network mean degree, one test per pair of digits.
# The frame is rebuilt here so the cell does not depend on whichever cell last assigned
# `g_xs_mean_degree_by_class` (the per-layer plotting loop reuses that name).
g_xs_mean_degree_by_class = pd.DataFrame({'digit class': y_test_orig[rand_inds],
                                          'mean weighted degree': [float(g.vp['degree'].a.mean()) for g in g_xs]})

groups_by_class = g_xs_mean_degree_by_class.groupby('digit class').groups.items()
# Only the strict lower triangle is filled; the diagonal and upper triangle keep the
# initial value 1 and do not correspond to any test.
# NOTE(review): the p-values are not corrected for multiple comparisons.
p_vals = np.ones((10,10))
for i, (c_1, inds_1) in enumerate(groups_by_class):
    for j, (c_2, inds_2) in enumerate(groups_by_class):
        if j >= i:
            continue
        sample_1 = g_xs_mean_degree_by_class.loc[inds_1, 'mean weighted degree']
        sample_2 = g_xs_mean_degree_by_class.loc[inds_2, 'mean weighted degree']
        z, p = ztest(sample_1, sample_2)
        p_vals[i, j] = p
plt.figure(figsize=(8,6))
sns.heatmap(p_vals, cmap='Blues_r')
plt.title('p-values for mean degree z-tests, per digit')
plt.show()

In [None]:
# Pairwise z-tests between digit classes, repeated for every layer (one heatmap per layer).
fig, axs = plt.subplots(2, 3, figsize=(3*8,2*6))
digit_labels = y_test_orig[rand_inds]
layer_ends = np.cumsum(layer_sizes)
for k, (l, st) in enumerate(zip(layer_sizes, layer_ends)):
    ax = axs[int(k >= 3), k % 3]

    # Mean of the 'degree' entries in the slice [st-l, st) for every sample graph.
    layer_means = [float(g.vp['degree'].a[st-l:st].mean()) for g in g_xs]
    g_xs_mean_degree_by_class_l = pd.DataFrame({'digit class': digit_labels,
                                                'mean weighted degree': layer_means})

    # One series of per-sample means per digit class, in groupby order.
    groups_by_class = g_xs_mean_degree_by_class_l.groupby('digit class').groups.items()
    class_samples = [g_xs_mean_degree_by_class_l.loc[inds, 'mean weighted degree']
                     for _, inds in groups_by_class]

    # Fill the strict lower triangle only; every other entry stays at 1.
    p_vals = np.ones((10,10))
    for i, sample_1 in enumerate(class_samples):
        for j in range(i):
            z, p = ztest(sample_1, class_samples[j])
            p_vals[i, j] = p

    sns.heatmap(p_vals, cmap='Blues_r', ax=ax)
    ax.set_title('p-values for mean degree z-tests, per digit for layer {}'.format(k))
plt.show()

#### Mean weighted degrees vs number of active inputs

In [None]:
# Whole-network mean degree of each sample graph against the row sum of its input vector.
# The inputs are read through `mnist.X_test`, the same array the graphs were built from;
# a bare `X_test` is not defined in this notebook.
# NOTE(review): 'active inputs' is the SUM of each input row; it equals a count of active
# pixels only if the inputs are binary -- confirm how `mnist.X_test` is scaled.
g_xs_degree_by_inputs = pd.DataFrame({'active inputs': np.sum(mnist.X_test[rand_inds], axis=1),
                                      'digit class': y_test_orig[rand_inds],
                                      'mean weighted degree': [float(g.vp['degree'].a.mean()) for g in g_xs]})
fig, ax = plt.subplots(figsize=(12,8))
sns.scatterplot(x='active inputs', y='mean weighted degree', hue='digit class',
                palette=sns.color_palette('hls', 10), alpha=0.4,
                data=g_xs_degree_by_inputs,
                ax=ax)
ax.set_title('Mean weighted degree by number of active inputs')
plt.show()

In [None]:
# Layer widths of the 4x64 model: input, four hidden layers, output.
# NOTE(review): this hard-coded list rebinds the `layer_sizes` derived from the model in an
# earlier cell; keep the two in sync if the architecture changes.
layer_sizes = [784, 64, 64, 64, 64, 10]

# Columns that do not depend on the layer are computed once. The inputs are read through
# `mnist.X_test`, the array the graphs were built from (a bare `X_test` is not defined).
active_inputs = np.sum(mnist.X_test[rand_inds], axis=1)
digit_labels = y_test_orig[rand_inds]

fig, axs = plt.subplots(2, 3, figsize=(3*12,2*8))
for i, (l, st) in enumerate(zip(layer_sizes, np.cumsum(layer_sizes))):
    ax = axs[0 if i < 3 else 1, i%3]
    # Mean of the 'degree' entries in the slice [st-l, st) for every sample graph.
    g_xs_mean_degree_vs_inputs = pd.DataFrame({'active inputs': active_inputs,
                                               'digit class': digit_labels,
                                               'mean weighted degree': [float(g.vp['degree'].a[st-l:st].mean()) for g in g_xs]})
    sns.scatterplot(x='active inputs', y='mean weighted degree', hue='digit class',
                    palette=sns.color_palette('hls', 10), alpha=0.4,
                    data=g_xs_mean_degree_vs_inputs,
                    ax=ax)
    ax.set_title('Mean weighted degree by number of active inputs for layer {}'.format(i))
plt.show()

In [None]:
# Weighted eigenvector centrality of every vertex of `g`.
# NOTE(review): `g` is not assigned in any visible cell (comprehension variables do not leak
# in Python 3); bind it to the intended graph before running this cell.
# `max_iter` is passed as an integer: graph-tool documents it as an int, and the float
# literal 1.0E5 may be rejected by the underlying binding.
eig_max, eig_vp = eigenvector(g, weight=g.ep['weight'], max_iter=100000)

In [None]:
# (vertex index, eigenvector centrality) for every vertex of `g`, in iteration order.
eig_centralities = []
for v in g.vertices():
    eig_centralities.append((g.vertex_index[v], eig_vp[v]))

In [None]:
# Scatter eigenvector centrality against vertex index, coloured by layer.
# Layer membership is derived from `layer_sizes`: `input_dims` / `hidden_dims` are not defined
# in any visible cell, and the cumulative layer sizes give the same thresholds
# (input width, then one more hidden width per layer) when all hidden layers share a width.
layer_starts = np.cumsum(layer_sizes)[:-1]  # first vertex index of every layer after the input
vertex_ids = np.array([v for v, _ in eig_centralities])
centralities = np.array([eigc for _, eigc in eig_centralities])
# side='right' puts a vertex that sits exactly on a boundary into the later layer; every
# index at or beyond the last boundary falls into the final (output) group.
vertex_layer = np.searchsorted(layer_starts, vertex_ids, side='right')

n_layers = len(layer_sizes)
layer_names = ['inputs'] + ['layer {}'.format(k) for k in range(1, n_layers - 1)] + ['output']
layer_colors = ['gray', 'r', 'b', 'g', 'c', 'm']

fig, axs = plt.subplots(1, 2, figsize=(16,6))
# Left panel shows every layer; right panel drops the input layer.
panels = [(axs[0], 0, 'Hidden layers + inputs'),
          (axs[1], 1, 'Hidden layers only')]
for ax, first_layer, panel_title in panels:
    for k in range(first_layer, n_layers):
        in_layer = vertex_layer == k
        ax.scatter(vertex_ids[in_layer], centralities[in_layer],
                   c=layer_colors[k % len(layer_colors)], label=layer_names[k])
    ax.legend()
    ax.set_xlabel('vertex label')
    ax.set_ylabel('eigenvector centrality')
    ax.set_title(panel_title)
fig.suptitle('Eigenvector centrality vs. nodes')
plt.show()

In [None]:
# Fit a nested block model to `g_max`, the sample graph with the most edges.
# `minimize_nested_blockmodel_dl` is imported in the notebook's import cell.
# NOTE(review): no RNG seed is set in the visible cells, so the fitted partition may differ
# between runs -- confirm whether that matters for the figures below.
min_state = minimize_nested_blockmodel_dl(g_max)
print(min_state)

In [None]:
# Draw the block-model state currently bound to `min_state`.
min_state.draw()

In [None]:
# First level of the block hierarchy returned by `get_bs()`: one block index per vertex.
x = min_state.get_bs()[0]

# Block indices of the first 784 vertices laid out as a 28x28 image
# (presumably the input-layer vertices, one per pixel -- confirm the vertex ordering).
plt.imshow(x[:784].reshape((28,28)))
plt.colorbar()
plt.title('Block index of the first 784 vertices')
plt.show()

# `i_max` is a position in `g_xs`, i.e. in the subsample, so the matching input has to be
# looked up through `rand_inds` rather than by row number in the full test set.
plt.imshow(mnist.X_test[rand_inds][i_max].reshape((28,28)))
plt.title('Input of the sample graph with the most edges')
plt.show()

In [None]:
# One row per vertex of `g_max`: its block index (first level of `get_bs()`), its position in
# the vertex order, and its 'layer' vertex property. pandas is already imported as `pd` in
# the notebook's import cell.
bs = min_state.get_bs()[0]
df = pd.DataFrame({'group': bs, 'node label': np.arange(len(bs)), 'layer': g_max.vp['layer'].a})

grid = sns.catplot(x='group', y='node label', data=df, hue='layer')
grid.set(title='Block membership of every vertex, coloured by layer')
plt.show()

In [None]:
# Fit a nested block model to `g_min`, the sample graph with the fewest edges.
# `minimize_nested_blockmodel_dl` is imported in the notebook's import cell.
# NOTE: this rebinds `min_state`. After this cell the `g_max` cells above no longer see the
# state they were written for; re-run the `g_max` fit before re-running them.
min_state = minimize_nested_blockmodel_dl(g_min)
print(min_state)
min_state.draw(layout='sfdp')

In [None]:
# Block index (first level of `get_bs()`) of the state in `min_state`, plotted against
# vertex position 0..num_vertices-1 of `g_min`.
plt.scatter(range(g_min.num_vertices()), min_state.get_bs()[0])