## Experiment
We train a ReLU network on the binary MNIST dataset (digits 4 vs. 7). We plot $y^T(H_t)^{-1}y$ over time, where the path Gram matrix is $H_t = {\Phi_t}^T{\Phi_t} = (x^Tx)\odot\lambda_t$ and $\lambda_t(s, s')$ is the total number of paths that are active for both input examples $s$ and $s'$.

## Import Libraries

In [None]:
# System setup cell: installs LaTeX packages through apt-get, then downloads the
# type1cm package from CTAN, builds type1cm.sty and registers it with texhash.
# NOTE(review): presumably needed so matplotlib can render text with LaTeX;
# no cell visible here enables that - confirm it is still required.
! sudo apt-get install texlive-latex-recommended #1
! sudo apt-get install dvipng texlive-fonts-recommended #2
! wget http://mirrors.ctan.org/macros/latex/contrib/type1cm.zip #3
! unzip type1cm.zip -d /tmp/type1cm #4
! cd /tmp/type1cm/type1cm/ && sudo latex type1cm.ins  #5
! sudo mkdir /usr/share/texmf/tex/latex/type1cm #6
! sudo cp /tmp/type1cm/type1cm/type1cm.sty /usr/share/texmf/tex/latex/type1cm #7
! sudo texhash #8

- binaryMnist_galu-vs-ReluThenGalu

In [None]:
import tensorflow as tf
print(tf.__version__)
import keras as tfk
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import time
import re
from sklearn.model_selection import *
import keras.backend as K
from keras.layers import *

## Data

In [None]:
# Target values used for the two classes: labels[0] for digit 4, labels[1] for digit 7.
labels = [-1, 1]
def getMnist():
    """Load MNIST and reduce it to the binary 4-vs-7 problem.

    Pixel values are divided by the maximum of the training images, only the
    digits 4 and 7 are kept, their targets are remapped to labels[0] and
    labels[1] respectively, and every image is flattened to one row.

    Returns:
        x_train, y_train, x_test, y_test as NumPy arrays.
    """
    (raw_train_x, raw_train_y), (raw_test_x, raw_test_y) = tfk.datasets.mnist.load_data()
    # Both splits are scaled by the maximum of the *training* images.
    scale = np.max(raw_train_x)

    def prepare(images, digits):
        # Scale, keep only the 4s and 7s, relabel, then flatten each image.
        keep = ((digits == 4) | (digits == 7))
        kept_images = (images / scale)[keep]
        targets = digits[keep].astype(np.int32)
        targets[targets == 4] = labels[0]
        targets[targets == 7] = labels[1]
        return kept_images.reshape(kept_images.shape[0], -1), targets

    x_train, y_train = prepare(raw_train_x, raw_train_y)
    x_test, y_test = prepare(raw_test_x, raw_test_y)

    return x_train, y_train, x_test, y_test

In [None]:
# Load the binary (4 vs 7) MNIST split and print the array shapes as a sanity check.
x_train, y_train, x_test, y_test = getMnist()
print(x_train.shape, y_train.shape, x_test.shape, y_test.shape)

## Utils

In [None]:
def getEig(g_mat):
    """Return the real parts of the eigenvalues of a square matrix.

    Args:
        g_mat: square array accepted by np.linalg.eigvals.

    Returns:
        1-D real array holding the real part of every eigenvalue.
    """
    # Only eigenvalues are needed, so use eigvals instead of eig, which would
    # also compute (and then discard) the eigenvectors.
    return np.real(np.linalg.eigvals(g_mat))

def getRandomIndices(max_lim, num_ind):
    """Return `num_ind` distinct indices drawn at random from range(max_lim).

    Args:
        max_lim: exclusive upper bound of the index range.
        num_ind: number of indices to return; if it exceeds max_lim, all
            max_lim indices are returned (in random order).

    Returns:
        1-D integer array of unique indices.
    """
    all_indices = np.arange(0, max_lim)
    # Shuffle the array that is actually returned. Shuffling a freshly built
    # temporary leaves all_indices sorted, so the "random" pick would always
    # be 0..num_ind-1.
    np.random.shuffle(all_indices)
    return all_indices[: num_ind]

def getUniqueLabelExamples(x_train, y_train, num_class, num_each_class):
    """Assemble a batch holding `num_each_class` examples of every label.

    The classes are taken from the module-level `labels` list; `num_class`
    only determines the size of the returned arrays. Examples of one label
    occupy a contiguous run of rows, in the order the labels are listed.

    Args:
        x_train: 2-D array of examples, one per row.
        y_train: array of labels aligned with the rows of x_train.
        num_class: number of classes used to size the batch.
        num_each_class: number of examples selected per class.

    Returns:
        (x_batch, y_batch): the selected rows and their labels.
    """
    total = num_each_class * num_class
    # The leading rows of x_train serve only as a correctly shaped buffer;
    # they are overwritten class by class below.
    x_batch = x_train[:total, :].copy()
    y_batch = np.zeros((total, ))

    for class_pos, label in enumerate(labels):
        rows = slice(class_pos * num_each_class, (class_pos + 1) * num_each_class)
        class_examples = x_train[(y_train == label), :]
        picked = getRandomIndices(class_examples.shape[0], num_each_class)
        x_batch[rows, :] = class_examples[picked, :]
        y_batch[rows] = label

    return x_batch, y_batch

# Handle to the Keras backend session.
# NOTE(review): `sess` is not referenced in the visible cells - confirm whether it is still needed.
sess = K.get_session()

def relu(z):
    """Elementwise rectifier: max(0, z) for a scalar or an array."""
    rectified = np.maximum(0, z)
    return rectified

def getLayersOutput(wts, X):
    """Run a bias-free forward pass and collect the output of every layer.

    Each weight matrix except the last is followed by relu; the last one is
    applied linearly.

    Args:
        wts: sequence of weight matrices, applied in order as X @ w.
        X: input array.

    Returns:
        List starting with X itself, followed by one output per weight matrix.
    """
    collected = [X]
    current = X
    for weight in wts[:-1]:
        current = relu(np.matmul(current, weight))
        collected.append(current)
    # Final layer: linear, no activation.
    collected.append(np.matmul(current, wts[-1]))
    return collected

def getLayerActivation(wts, X):
    """Return the 0/1 activation pattern of every hidden layer.

    Args:
        wts: sequence of weight matrices passed to getLayersOutput.
        X: input array passed to getLayersOutput.

    Returns:
        List with one integer array per hidden layer; an entry is 1 where
        the layer output is strictly positive and 0 otherwise.
    """
    layer_outputs = getLayersOutput(wts, X)
    # Skip the first entry (the input) and the last one (the linear output).
    # Builtin int replaces np.int, an alias for it that was removed in
    # NumPy 1.24 and now raises AttributeError.
    return [(output > 0).astype(int) for output in layer_outputs[1:-1]]

def getActivityGram(layer_activity):
    """Multiply the per-layer activation Gram matrices elementwise.

    For every layer the Gram matrix activity @ activity.T is formed, and the
    matrices of all layers are combined with an elementwise product.

    Args:
        layer_activity: non-empty list of 2-D activation arrays, one row per
            input example.

    Returns:
        Square float array with one row and column per input example.
    """
    n_examples = layer_activity[0].shape[0]
    overlap = np.ones((n_examples, n_examples))
    for mask in layer_activity:
        overlap = overlap * np.dot(mask, mask.T)

    return overlap

## Model

In [None]:
# Units in every hidden Dense layer built by getRelu(); also sets the
# +/- sqrt(2/width) weight magnitude used by my_init2().
width = 100
# Total number of Dense layers in getRelu(): depth - 1 ReLU layers plus the linear output layer.
depth = 6

In [None]:
from keras import backend as K
from keras.layers import Layer

In [None]:
# Input dimension passed to tfk.Input in getRelu().
input_sz = 784
# Units of the final linear Dense layer in getRelu().
output_sz = 1
# Number of single-epoch model.fit calls per run in the training loop.
num_epochs = 50
# Mini-batch size used by those model.fit calls.
batch_size = 64
# Number of independent runs of the training loop.
num_exp = 5

def my_init2(shape, dtype = None):
    """Kernel initializer drawing every weight from {-sqrt(2/width), +sqrt(2/width)}.

    Args:
        shape: shape of the weight array to create.
        dtype: dtype forwarded to tf.Variable.

    Returns:
        A trainable tf.Variable of the requested shape whose entries are
        -sqrt(2/width) or +sqrt(2/width) with equal probability.
    """
    magnitude = np.sqrt(2/width)
    # The probabilities must be given by keyword: the third positional
    # parameter of np.random.choice is `replace`, not `p`, so a positional
    # [0.5, 0.5] was silently interpreted as a truthy `replace` flag.
    signs = np.random.choice([-magnitude, magnitude], size = shape, p = [0.5, 0.5])
    return tf.Variable(signs, dtype = dtype, trainable = True)

def getRelu():
    """Build the bias-free fully connected ReLU network.

    The model has depth - 1 ReLU Dense layers of `width` units named
    "R1" .. "R<depth-1>", followed by a linear Dense layer of `output_sz`
    units named "R<depth>". Every kernel is initialised with my_init2.

    Returns:
        An uncompiled Keras model named 'mnist_model'.
    """
    def hidden_layer(layer_name):
        # All hidden layers share the same configuration apart from the name.
        return Dense(units = width, activation = 'relu', use_bias = False,
                     kernel_initializer = my_init2, name = layer_name)

    inputs = tfk.Input(shape = (input_sz, ))
    hidden = hidden_layer("R1")(inputs)

    for layer_no in range(2, depth):
        hidden = hidden_layer("R" + str(layer_no))(hidden)

    output_layer = Dense(units = output_sz, activation = "linear", use_bias = False,
                         kernel_initializer = my_init2, name = "R" + str(depth))
    outputs = output_layer(hidden)

    return tfk.Model(inputs = inputs, outputs = outputs, name = 'mnist_model')

In [None]:
# Per-run results collected by the training loop: each value is a list that
# receives one list of measurements per experiment run.
history_relu = {'mean_squared_error' : [], 'val_mean_squared_error' : [], 'norm_normalized' : [], 
                'norm_unnormalized' : []}

In [None]:
# Fixed probe batch (100 examples for each of the 2 classes) on which the Gram
# matrices are evaluated during training.
x_batch, y_batch = getUniqueLabelExamples(x_train, y_train, 2, 100)
# Gram matrix of the raw inputs, x_batch @ x_batch.T.
input_gram = np.dot(x_batch, x_batch.T)
print(x_batch.shape, y_batch.shape)

## Train

In [None]:
# Run num_exp independent experiments. In each one a fresh network is trained
# for num_epochs epochs, and before every second epoch the quantity
# y^T K_t^{-1} y is recorded on the probe batch, where
# K_t = input_gram * (activation Gram of the current weights), elementwise.
for exp_i in range(num_exp):
    model = getRelu()

    optimizer = tfk.optimizers.RMSprop(lr = 1e-4)
    model.compile(optimizer= optimizer, loss = 'mse', metrics = ['mse'])
    # NOTE(review): this one-sample fit looks like a warm-up to build the
    # model; it does apply one update before the first measurement - confirm.
    model.fit(x_train[:1], y_train[:1], epochs = 1, batch_size = 1, verbose = 0)

    mse_temp = []
    val_mse_temp = []
    norm_temp_normalized = []
    norm_temp_unnormalized = []

    for epoch_i in range(num_epochs):
        if epoch_i % 2 == 0:
            wts = model.get_weights()
            layer_activity = getLayerActivation(wts, x_batch)
            K_t = np.multiply(input_gram, getActivityGram(layer_activity))
            # Same matrix rescaled to unit trace.
            K_t_normalized = K_t / np.sum(np.diagonal(K_t))

            # y^T K^{-1} y is obtained from a linear solve rather than an
            # explicit inverse: it is cheaper and numerically more stable
            # when the Gram matrix is ill-conditioned.
            norm = np.dot(y_batch, np.linalg.solve(K_t, y_batch))
            norm_temp_unnormalized.append(norm)

            norm = np.dot(y_batch, np.linalg.solve(K_t_normalized, y_batch))
            norm_temp_normalized.append(norm)

        history = model.fit(x_train, y_train, validation_data=(x_test, y_test),
                            epochs = 1, batch_size = batch_size, verbose = 0)

        # NOTE(review): the history key for metrics = ['mse'] depends on the
        # Keras version (some versions record it as 'mse') - confirm.
        mse_temp.append(history.history['mean_squared_error'][0])
        val_mse_temp.append(history.history['val_mean_squared_error'][0])

    history_relu['mean_squared_error'].append(mse_temp)
    history_relu['val_mean_squared_error'].append(val_mse_temp)
    history_relu['norm_normalized'].append(norm_temp_normalized)
    history_relu['norm_unnormalized'].append(norm_temp_unnormalized)

In [None]:
# Mean and standard deviation of the normalized norm across the experiment runs.
mean_norm = np.mean(history_relu['norm_normalized'], axis = 0)
mean_std = np.std(history_relu['norm_normalized'], axis = 0)
# The training loop records the norm only when epoch_i % 2 == 0, so the k-th
# measurement belongs to epoch 2 * k; scale the x values so the "epochs" axis
# label is accurate.
measured_epochs = 2 * np.arange(mean_norm.shape[0])
plt.errorbar(x = measured_epochs, y = mean_norm, yerr = mean_std, marker = 'o')
plt.ylabel("norm", fontsize = 14)
plt.xlabel("epochs", fontsize = 14)
plt.show()