In [None]:
pip install gast==0.2.2

In [None]:
# Colab-only magic: switch the runtime to TensorFlow 1.x, which the graph-mode
# code below (K.gradients evaluated with sess.run) requires.
%tensorflow_version 1.x

In [None]:
from tensorflow import keras
from tensorflow.keras.layers import *
from tensorflow.keras.callbacks import ModelCheckpoint
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import tensorflow.keras.backend as K
from tensorflow.keras.utils import *
import pickle
import tensorflow as tf
import re
# Sanity check: the NTK helpers below use TF1 graph-mode APIs (K.gradients, sess.run).
print(tf.__version__)

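- Data: load MNIST, one-hot encode the labels, and flatten the images into 784-dim vectors scaled to [0, 1]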

In [None]:
# Raw MNIST arrays, already split into train and test by Keras.
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
print(*(arr.shape for arr in (x_train, y_train, x_test, y_test)))

In [None]:
# Idempotent preprocessing: the guards make re-running this cell a no-op instead of
# one-hot encoding / flattening / rescaling already-processed arrays a second time
# (to_categorical on one-hot rows would silently produce a 3-D array).
if y_train.ndim == 1:
    y_train1 = y_train.copy()  # integer class labels, kept for per-class sampling
    y_train = to_categorical(y_train, num_classes=10)
    y_test = to_categorical(y_test, num_classes=10)
if x_train.ndim == 3:
    # Flatten each image into a 784-vector and scale pixels from [0, 255] to [0, 1].
    x_train = x_train.reshape(-1, 784) / 255.0
    x_test = x_test.reshape(-1, 784) / 255.0
print(x_train.shape, y_train.shape, x_test.shape, y_test.shape)

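- Utils: eigenvalue and class-balanced sampling helpers, layer/weight selectors for the two sub-networks, and the empirical NTK (Gram matrix of per-example parameter gradients)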

In [None]:

def getEig(g_mat):
    """Return the real parts of the eigenvalues of the square matrix `g_mat`.

    np.linalg.eig also computes eigenvectors; they are discarded here. Imaginary
    parts are dropped -- for a real symmetric matrix (e.g. a Gram/NTK matrix)
    they are zero up to rounding.
    """
    eigenvalues, _eigenvectors = np.linalg.eig(g_mat)
    return eigenvalues.real

def getRandomIndices(max_lim, num_ind):
    """Return `num_ind` distinct indices drawn uniformly at random from range(max_lim).

    Uses NumPy's global RNG, so results are reproducible via np.random.seed.

    Args:
        max_lim: size of the index range [0, max_lim).
        num_ind: number of indices to return; if it exceeds max_lim, all
            max_lim indices are returned (in random order).

    Returns:
        1-D integer ndarray of length min(num_ind, max_lim).
    """
    all_indices = np.arange(0, max_lim)
    # Shuffle `all_indices` itself: np.random.shuffle works in place and returns
    # None, so shuffling a fresh np.arange copy would leave the prefix unshuffled.
    np.random.shuffle(all_indices)
    return all_indices[:num_ind]

def getUniqueLabelExamples(x_train, y_train, num_class, num_each_class):
    """Sample a class-balanced batch: `num_each_class` random examples of each class.

    Class membership is read from `y_train` itself (argmax of one-hot rows, or the
    values directly when `y_train` is 1-D), so no notebook-global label array is needed.

    Args:
        x_train: array whose first axis indexes examples.
        y_train: one-hot labels of shape (N, C), or integer labels of shape (N,).
        num_class: classes 0 .. num_class-1 are sampled.
        num_each_class: examples drawn per class, without replacement.

    Returns:
        (x_batch, y_batch) with num_class * num_each_class rows, grouped by class
        in ascending order; y_batch keeps the label format of `y_train`.

    Raises:
        ValueError: if some class has fewer than `num_each_class` examples.
    """
    labels = y_train if y_train.ndim == 1 else np.argmax(y_train, axis=1)
    total = num_class * num_each_class
    x_batch = np.empty((total,) + x_train.shape[1:], dtype=x_train.dtype)
    y_batch = np.empty((total,) + y_train.shape[1:], dtype=y_train.dtype)

    for cls in range(num_class):
        in_class = labels == cls
        x_cls = x_train[in_class]
        y_cls = y_train[in_class]
        if x_cls.shape[0] < num_each_class:
            raise ValueError("class {} has only {} examples, need {}".format(
                cls, x_cls.shape[0], num_each_class))
        indices = getRandomIndices(x_cls.shape[0], num_each_class)
        rows = slice(cls * num_each_class, (cls + 1) * num_each_class)
        x_batch[rows] = x_cls[indices]
        y_batch[rows] = y_cls[indices]

    return x_batch, y_batch

# Keras' TF1 session. getNTK evaluates gradients with sess.run in this session, so it
# sees the weights updated by model.fit -- presumably fit runs in this same session
# as long as the Keras graph/session is not reset; confirm if that ever changes.
sess = tf.compat.v1.keras.backend.get_session()

def _layersNamedLike(model, prefix):
    """Return the layers of `model` (in model order) named `prefix` followed only by digits."""
    pattern = re.compile("^" + re.escape(prefix) + "[0-9]*$")
    return [layer for layer in model.layers if pattern.match(layer.name)]

def getNetwork1Layers(model):
    """Return the gating-network layers of a GaLU model: those named "R<digits>"."""
    return _layersNamedLike(model, "R")

def getNetwork2Layers(model):
    """Return the value-network layers of a GaLU model: those named "G<digits>"."""
    return _layersNamedLike(model, "G")
    
def freezeWeights(layers):
    """Mark every layer in `layers` as non-trainable, in place (returns None)."""
    for each_layer in layers:
        each_layer.trainable = False

def _firstWeightOfEach(layers):
    """Return the first weight variable of each layer, whether it is trainable or frozen.

    trainable_weights + non_trainable_weights is used so a layer frozen with
    freezeWeights is still included. For Dense layers the first weight is
    presumably the kernel, so biases are not returned -- confirm if that matters.
    """
    return [(layer.trainable_weights + layer.non_trainable_weights)[0] for layer in layers]

def getNetwork1TrainableWts(model):
    """Return the first weight of each gating-network ("R<digits>") layer of `model`."""
    return _firstWeightOfEach(getNetwork1Layers(model))

def getNetwork2TrainableWts(model):
    """Return the first weight of each value-network ("G<digits>") layer of `model`."""
    return _firstWeightOfEach(getNetwork2Layers(model))

# Symbolic per-output gradient tensors, keyed by tensor/variable names (unique within
# the TF1 graph), so repeated getNTK calls reuse the same ops instead of adding a new
# gradient subgraph to the graph on every call.
_NTK_GRADIENTS = {}

def _perOutputGradients(model, wts):
    """Return, for each output unit k, the symbolic gradients of model.output[:, k] w.r.t. `wts`."""
    key = (model.output.name, tuple(w.name for w in wts))
    if key not in _NTK_GRADIENTS:
        num_outputs = K.int_shape(model.output)[-1]
        _NTK_GRADIENTS[key] = [K.gradients(model.output[:, k], wts)
                               for k in range(num_outputs)]
    return _NTK_GRADIENTS[key]

def _initializeUninitialized(variables):
    """Initialize, in `sess`, only those of `variables` that are not initialized yet."""
    flags = sess.run([tf.compat.v1.is_variable_initialized(v) for v in variables])
    missing = [v for v, ok in zip(variables, flags) if not ok]
    if missing:
        sess.run(tf.compat.v1.variables_initializer(missing))

def getNTK(model, x_batch, wts, epoch_i):
    """Empirical NTK Gram matrix of `model` on `x_batch` w.r.t. the weights `wts`.

    Entry (i, j) is sum_k <d f_k(x_i)/d wts, d f_k(x_j)/d wts>, where f_k is the k-th
    unit of model.output, i.e. the multi-output NTK traced over the output units.
    Gradients are taken per output unit because K.gradients sums over its target
    tensor, and the sum of a softmax output is identically 1: differentiating
    model.output as a whole would give an (up to rounding) all-zero NTK.

    Args:
        model: Keras model whose variables live in the module-level `sess`.
        x_batch: inputs; the first axis indexes examples, each fed as a batch of one.
        wts: variables to differentiate with respect to
            (e.g. getNetwork1TrainableWts(model)).
        epoch_i: when 0, model variables that are still uninitialized are
            initialized first. Already-initialized weights are left untouched, so
            calling this twice at epoch 0 does not re-randomize the model in between.

    Returns:
        float64 ndarray of shape (n, n) with n = x_batch.shape[0].
    """
    if epoch_i == 0:
        _initializeUninitialized(model.weights)

    n = x_batch.shape[0]
    ntk = np.zeros((n, n))
    if n == 0 or not wts:
        return ntk

    # Accumulate one output unit at a time so only one (n, num_params) feature
    # matrix is held in memory at once.
    for grads_k in _perOutputGradients(model, wts):
        phi = None
        for i in range(n):
            grad = sess.run(grads_k, feed_dict={model.input: x_batch[i:i + 1]})
            row = np.concatenate([g.ravel() for g in grad])
            if phi is None:
                phi = np.zeros((n, row.size))
            phi[i] = row
        ntk += phi.dot(phi.T)

    return ntk


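- Models: GaLU network — soft gates (a scaled sigmoid of the `R*` layers' pre-activations, where the `R*` layers form a ReLU network) multiply the hidden layers of a value network (layers `G*`)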

In [None]:
eps, beta = 0.1, 4
class SoftGate(Layer):
    """Smooth gating activation: (1 + eps) * sigmoid(beta * x), applied elementwise.

    In getGalu it is applied to a gating-network pre-activation, and the result
    multiplies the matching value-network layer.

    Args:
        eps: gate scale offset; for eps > -1 the output lies in (0, 1 + eps).
            Defaults to the module-level `eps` as it was when this class was defined.
        beta: sigmoid sharpness; larger values approach a hard 0 / (1 + eps) step.
            Defaults to the module-level `beta` as it was when this class was defined.
    """

    def __init__(self, eps=eps, beta=beta, **kwargs):
        super(SoftGate, self).__init__(**kwargs)
        # Stored per instance: fixed at construction and serializable via get_config,
        # rather than re-reading notebook globals on every call.
        self.eps = eps
        self.beta = beta

    def build(self, input_shape):
        super(SoftGate, self).build(input_shape)  # no weights; just marks the layer as built

    def call(self, x):
        return (1 + self.eps) * K.sigmoid(self.beta * x)

    def compute_output_shape(self, input_shape):
        return input_shape  # elementwise, so the shape is unchanged

    def get_config(self):
        """Include eps/beta so the layer can be re-created from its config."""
        config = super(SoftGate, self).get_config()
        config.update({'eps': self.eps, 'beta': self.beta})
        return config

def getGalu(depth, width, input_dim=784, num_classes=10):
    """Build an (uncompiled) GaLU model: a gating network whose soft gates multiply a value network.

    Layer naming, relied on by getNetwork1Layers / getNetwork2Layers:
      * "R1".."R{depth-1}": Dense layers of the gating network. Each pre-activation
        feeds a ReLU (input to the next R layer) and a SoftGate (the gate).
      * "G1".."G{depth-1}": Dense layers of the value network, each multiplied
        elementwise by the gate from the R layer at the same depth ("G{i}A").
      * "G{depth}": the softmax output layer on top of the last gated value layer.

    Args:
        depth: number of Dense layers on the value path, including the output layer; >= 2.
        width: units in every hidden Dense layer.
        input_dim: length of the flat input vector (784 for flattened MNIST).
        num_classes: units in the softmax output layer.

    Returns:
        keras.Model named 'galu_model'.

    Raises:
        ValueError: if depth < 2.
    """
    if depth < 2:
        raise ValueError("getGalu needs depth >= 2 (got {}): the value path has at least "
                         "one gated hidden layer plus the output layer".format(depth))

    inputs = Input(shape = (input_dim, ))

    gate_pre = Dense(units = width, activation = 'linear', name = "R1")(inputs)
    gate_act = Activation('relu')(gate_pre)
    gate = SoftGate()(gate_pre)

    value_pre = Dense(units = width, activation = 'linear', name = "G1")(inputs)
    value_gated = Multiply(name = "G1A")([value_pre, gate])

    for i in range(depth - 2):
        layer_idx = str(i + 2)
        gate_pre = Dense(units = width, activation = 'linear', name = "R" + layer_idx)(gate_act)
        gate_act = Activation('relu')(gate_pre)
        gate = SoftGate()(gate_pre)

        value_pre = Dense(units = width, activation = 'linear', name = "G" + layer_idx)(value_gated)
        value_gated = Multiply(name = "G" + layer_idx + "A")([value_pre, gate])

    outputs = Dense(units = num_classes, activation = "softmax", name = "G" + str(depth))(value_gated)
    return keras.Model(inputs = inputs, outputs = outputs, name = 'galu_model')

In [None]:
depth, width = 6, 128        # depth = Dense layers on the value path incl. the softmax output (see getGalu)
lr = 1e-2
loss = keras.losses.categorical_crossentropy
opt = keras.optimizers.SGD   # optimizer class; instantiated per experiment as opt(lr)
batch_size = 32
num_exp = 5                  # independent runs, each from a fresh initialization
num_epochs = 60

# Fix the RNG seeds so Restart & Run All reproduces the probe-set sample and the weight
# initializations: the TF graph-level seed makes otherwise-unseeded initializers
# deterministic, and np.random drives getRandomIndices (and presumably Keras' per-epoch
# shuffling in TF1 -- confirm).
seed = 0
np.random.seed(seed)
tf.compat.v1.set_random_seed(seed)

# Results store: each key holds one list per experiment.
history_galu = {'acc':[], 'val_acc':[], 'loss': [], 'val_loss': [],
                'norm_value': [], 'norm_gate': [],
                'trace_value': [], 'trace_gate': []}

In [None]:
# NTK probe set: 20 random training examples from each of the 10 digit classes.
x_batch, y_batch = getUniqueLabelExamples(x_train, y_train, num_class=10, num_each_class=20)
print(x_batch.shape, y_batch.shape)

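- Train GaLU: train `num_exp` independently initialized models, recording accuracy after every epoch and the NTK trace and Frobenius norm w.r.t. the value (`G*`) and gate (`R*`) weights every 5 epochs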

In [None]:
# NTK statistics are recorded every `ntk_every` epochs, starting at epoch 0 (before training).
ntk_every = 5

def _ntkTraceAndNorm(model, x_probe, wts, epoch_i):
    """Return (trace, Frobenius norm) of the NTK of `model` on `x_probe` w.r.t. `wts`."""
    ntk = getNTK(model, x_probe, wts, epoch_i)
    return np.trace(ntk), np.linalg.norm(ntk)

for exp_i in range(num_exp):
    print("_____________EXP:{}____________".format(exp_i+1))
    model_galu = getGalu(depth, width)
    model_galu.compile(loss = loss, optimizer = opt(lr), metrics = ['acc'])
    acc_temp, val_acc_temp = [], []
    loss_temp, val_loss_temp = [], []
    trace_temp_value, trace_temp_gate = [], []
    norm_temp_value, norm_temp_gate = [], []

    for epoch_i in range(num_epochs):
        if epoch_i % ntk_every == 0:
            # "value" = value-network weights (G* layers); "gate" = gating-network weights (R* layers).
            trace_v, norm_v = _ntkTraceAndNorm(
                model_galu, x_batch, getNetwork2TrainableWts(model_galu), epoch_i)
            trace_temp_value.append(trace_v)
            norm_temp_value.append(norm_v)

            trace_g, norm_g = _ntkTraceAndNorm(
                model_galu, x_batch, getNetwork1TrainableWts(model_galu), epoch_i)
            trace_temp_gate.append(trace_g)
            norm_temp_gate.append(norm_g)

        # One epoch per fit() call so the NTK can be probed between epochs.
        history = model_galu.fit(x_train, y_train, validation_data = (x_test, y_test), verbose = 0,
                                 batch_size = batch_size, epochs = 1)
        epoch_log = history.history
        acc_temp.append(epoch_log['acc'][0])
        val_acc_temp.append(epoch_log['val_acc'][0])
        loss_temp.append(epoch_log['loss'][0])
        val_loss_temp.append(epoch_log['val_loss'][0])

    history_galu['acc'].append(acc_temp)
    history_galu['val_acc'].append(val_acc_temp)
    history_galu['loss'].append(loss_temp)
    history_galu['val_loss'].append(val_loss_temp)
    history_galu['trace_value'].append(trace_temp_value)
    history_galu['trace_gate'].append(trace_temp_gate)
    history_galu['norm_value'].append(norm_temp_value)
    history_galu['norm_gate'].append(norm_temp_gate)

    # Max over all epochs of this run; `history` only holds the last single-epoch fit.
    print("GaLU: MAX ACC = {:.4f}, MAX VAL ACC = {:.4f}".format(np.max(acc_temp),
                                                                np.max(val_acc_temp)))

In [None]:
# Best (max over epochs) accuracy of each experiment, then mean/std across experiments.
max_acc_per_exp = np.max(history_galu['acc'], axis = 1)
max_val_acc_per_exp = np.max(history_galu['val_acc'], axis = 1)
print("GaLU: mean_max_acc = {:.4f}, mean_max_val_acc = {:.4f}, std_max_val_acc = {:.4f}".format(
    np.mean(max_acc_per_exp), np.mean(max_val_acc_per_exp), np.std(max_val_acc_per_exp)))

In [None]:
# Average each NTK statistic over experiments (axis 0): one point per NTK measurement.
trace_value = np.mean(history_galu['trace_value'], axis = 0)
trace_gate = np.mean(history_galu['trace_gate'], axis = 0)
norm_value = np.mean(history_galu['norm_value'], axis = 0)
norm_gate = np.mean(history_galu['norm_gate'], axis = 0)

# The training loop measures the NTK every 5 epochs, starting at epoch 0.
ntk_epochs = np.arange(len(trace_value)) * 5

fig, ax = plt.subplots()
ax.plot(ntk_epochs, norm_value, '--', label = 'norm_value')
ax.plot(ntk_epochs, norm_gate, '--', label = 'norm_gate')
ax.plot(ntk_epochs, trace_value, label = 'trace_value')
ax.plot(ntk_epochs, trace_gate, label = 'trace_gate')
ax.set(title = 'GaLU NTK trace / Frobenius norm (mean over experiments)',
       xlabel = 'epoch', ylabel = 'value')
ax.legend()
plt.show()