In [1]:
import tensorflow as tf
from tensorflow.keras.layers import Layer, Dense, Dropout, LayerNormalization, Concatenate, Input, Flatten
import pandas as pd
from sklearn.model_selection import train_test_split
import numpy as np

# Global seed for the whole notebook. set_random_seed seeds Python's `random`,
# NumPy and TensorFlow in one call, so weight initialisation and dropout are
# repeatable when the notebook is run top to bottom on a fresh kernel.
SEED = 812
tf.keras.utils.set_random_seed(SEED)

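This notebook implements the variable selection network (VSN) from the Temporal Fusion Transformer paper (https://arxiv.org/abs/1912.09363). The VSN is the paper's feature-filtering component; the rest of that architecture is not implemented here.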

In [2]:
class Gate(Layer):
    """Gated Linear Unit: a linear projection scaled element-wise by a sigmoid gate.

    Both branches are ``Dense(units)`` layers applied to the same input. The
    sigmoid branch yields values in (0, 1) that control how much of each
    linear output is passed through.

    Args:
        units: width of both Dense branches and therefore of the output.
    """

    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        # Value branch (no activation) followed by the gate branch (sigmoid).
        self.linear = Dense(units)
        self.sigmoid = Dense(units, activation="sigmoid")

    def call(self, inputs):
        # Evaluate the value branch first so layers are built (and their
        # weights initialised) in the same order as before.
        projected = self.linear(inputs)
        gate_values = self.sigmoid(inputs)
        return projected * gate_values

class GRN(Layer):
    """Gated Residual Network block (Temporal Fusion Transformer).

    Computes ``LayerNorm(skip + Gate(Dropout(Dense(Dense_elu(inputs)))))``.
    ``skip`` is ``inputs`` unchanged when its last dimension already equals
    ``units``; otherwise it is a learned ``Dense(units)`` projection of ``inputs``.

    Args:
        units: output width of the block and of every internal Dense layer.
        drop: dropout rate applied to the hidden representation before the gate.
    """

    def __init__(self, units, drop, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.elu = Dense(units, activation='elu')   # hidden layer with ELU activation
        self.linear = Dense(units)                  # linear projection of the hidden layer
        self.dropout = Dropout(drop)
        self.gate = Gate(units)
        self.layer_norm = LayerNormalization()
        # Skip-connection projection; only called (and so only built) when the
        # input width differs from `units`.
        self.filter = Dense(units)

    def call(self, inputs):
        hidden = self.dropout(self.linear(self.elu(inputs)))
        needs_projection = inputs.shape[-1] != self.units
        skip = self.filter(inputs) if needs_projection else inputs
        return self.layer_norm(skip + self.gate(hidden))

In [3]:
class VSN(Layer):
    """Variable Selection Network (the feature-filtering part of the TFT paper).

    The last input axis is split into ``feat_num`` slices and each slice gets
    its own GRN. A shared GRN over the whole input, followed by a softmax
    Dense layer, produces one weight per slice. The output is the
    weight-averaged sum of the per-slice GRN outputs.

    Args:
        units: width of every GRN and therefore of the output.
        drop: dropout rate passed to every GRN.
        feat_num: number of slices the last input axis is split into; the
            last input dimension must be evenly divisible by it (``tf.split``).

    Shapes (assuming a rank-2 input ``(batch, width)`` -- TODO confirm for
    other ranks): input ``(batch, width)`` -> output ``(batch, units)``.
    """

    def __init__(self, units, drop, feat_num, **kwargs):
        super().__init__(**kwargs)
        self.feat_num = feat_num
        self.grn_list = [GRN(units, drop) for _ in range(feat_num)]   # one GRN per slice
        self.grn = GRN(units, drop)                                   # shared GRN over all slices
        self.softmax = Dense(units=feat_num, activation='softmax')    # per-slice selection weights

    def call(self, inputs):
        # Selection weights: (batch, feat_num) -> (batch, feat_num, 1).
        weights = self.softmax(self.grn(inputs))
        weights = tf.expand_dims(weights, axis=-1)

        # Per-slice embeddings, stacked to (batch, feat_num, units).
        slices = tf.split(inputs, self.feat_num, axis=-1)
        embedded = tf.stack(
            [slice_grn(piece) for slice_grn, piece in zip(self.grn_list, slices)],
            axis=1,
        )

        # (batch, 1, feat_num) @ (batch, feat_num, units) -> (batch, 1, units).
        combined = tf.matmul(weights, embedded, transpose_a=True)
        return tf.squeeze(combined, axis=1)

Base models and baseline tests:

In [4]:
def vsn_mdl(df, hidden_units=16, out_units=8, drop=0.5):
    """Two stacked VSN layers followed by a single sigmoid output unit.

    There are no Dense hidden layers: inputs -> VSN -> VSN -> Dense(1, sigmoid).

    Args:
        df: feature frame; only ``df.shape[1]`` (the number of feature columns)
            is used, to size the input.
        hidden_units: output width of the first VSN.
        out_units: output width of the second VSN.
        drop: dropout rate used inside every GRN.

    Returns:
        An uncompiled ``tf.keras.Model`` with one sigmoid output.
    """
    n_features = df.shape[1]
    inputs_1 = tf.keras.Input(shape=(n_features,))
    # First VSN: one GRN per input feature column.
    x = VSN(hidden_units, drop, n_features)(inputs_1)
    # The second VSN treats each of the first VSN's output units as a feature,
    # so its feat_num must equal hidden_units.
    x = VSN(out_units, drop, hidden_units)(x)
    output = Dense(units=1, activation='sigmoid')(x)
    return tf.keras.Model(inputs_1, output)

In [14]:
def base_model(df):
    """Baseline with no hidden layers: inputs feed one sigmoid unit directly.

    Trained with binary cross-entropy, this is equivalent to logistic regression.

    Args:
        df: feature frame; only its column count (``df.shape[1]``) is used.

    Returns:
        An uncompiled ``tf.keras.Model``.
    """
    n_features = df.shape[1]
    start = Input(shape=(n_features,))
    end = Dense(units=1, activation='sigmoid')(start)
    return tf.keras.Model(start, end)

In [7]:
# Keep only non-object (numeric) columns; object-dtype columns such as free
# text are dropped. The frame is expected to contain the 'Survived' label.
df = pd.read_csv('titanic.csv').select_dtypes(exclude = object)

# Sanity checks: the label must be present, and any NaN in the inputs would
# make the training loss NaN.
assert 'Survived' in df.columns, "expected a 'Survived' label column"
assert df.notna().all().all(), "numeric columns contain missing values"

In [10]:
# Derive the features with drop() rather than positional slicing, so the model
# is fitted on exactly the columns it was sized for, whatever the CSV order.
features = df.drop('Survived', axis = 1)
labels = df['Survived']
mdl = vsn_mdl(features)
mdl.compile(loss = 'binary_crossentropy', optimizer = 'SGD', metrics = ['Accuracy'])
history = mdl.fit(x = features, y = labels, epochs = 30, verbose = 0)
# Best per-epoch training accuracy (this cell uses no held-out data).
print(max(history.history['Accuracy']))

0.6719278693199158


In [13]:
# Derive the features with drop() rather than positional slicing, so the model
# is fitted on exactly the columns it was sized for, whatever the CSV order.
features = df.drop('Survived', axis = 1)
labels = df['Survived']
mdl = base_model(features)
mdl.compile(loss = 'binary_crossentropy', optimizer = 'SGD', metrics = ['Accuracy'])
history = mdl.fit(x = features, y = labels, epochs = 30, verbose = 0)
# Best per-epoch training accuracy (this cell uses no held-out data).
print(max(history.history['Accuracy']))

0.6009019017219543


DNN testing:

In [19]:
def variable_dnn(df, vsn_units=(64, 32, 16), dense_units=(64, 32), drop=0.2):
    """Stacked VSN layers, then a ReLU MLP head, then a sigmoid output unit.

    Args:
        df: feature frame; only ``df.shape[1]`` is used, to size the input.
        vsn_units: output width of each VSN layer, in order. Each VSN after
            the first treats the previous VSN's output units as its features,
            so its ``feat_num`` is derived automatically.
        dense_units: widths of the ReLU Dense layers after the VSN stack.
        drop: dropout rate used inside every GRN.

    Returns:
        An uncompiled ``tf.keras.Model`` with one sigmoid output.
    """
    start = Input(shape=(df.shape[1],))
    x = start
    n_features = df.shape[1]
    for units in vsn_units:
        x = VSN(units, drop, n_features)(x)
        n_features = units  # the next VSN splits this layer's output
    for units in dense_units:
        x = Dense(units, activation='relu')(x)
    end = Dense(units=1, activation='sigmoid')(x)
    return tf.keras.Model(start, end)

In [25]:
def base_dnn(df, dense_units=(64, 32)):
    """Plain ReLU MLP baseline with a sigmoid output unit (no VSN layers).

    Args:
        df: feature frame; only ``df.shape[1]`` is used, to size the input.
        dense_units: widths of the hidden ReLU Dense layers, in order. The
            default matches the Dense head used by ``variable_dnn``, so the
            two models differ only by the VSN stack.

    Returns:
        An uncompiled ``tf.keras.Model``.
    """
    start = Input(shape=(df.shape[1],))
    x = start
    for units in dense_units:
        x = Dense(units, activation='relu')(x)
    end = Dense(units=1, activation='sigmoid')(x)
    return tf.keras.Model(start, end)

In [21]:
# Derive the features with drop() rather than positional slicing, so the model
# is fitted on exactly the columns it was sized for, whatever the CSV order.
features = df.drop('Survived', axis = 1)
labels = df['Survived']
mdl = variable_dnn(features)
mdl.compile(loss = 'binary_crossentropy', optimizer = 'SGD', metrics = ['Accuracy'])
history = mdl.fit(x = features, y = labels, epochs = 30, verbose = 0)
# Best per-epoch training accuracy (this cell uses no held-out data).
print(max(history.history['Accuracy']))

0.6944757699966431


In [26]:
# Derive the features with drop() rather than positional slicing, so the model
# is fitted on exactly the columns it was sized for, whatever the CSV order.
features = df.drop('Survived', axis = 1)
labels = df['Survived']
mdl = base_dnn(features)
mdl.compile(loss = 'binary_crossentropy', optimizer = 'SGD', metrics = ['Accuracy'])
history = mdl.fit(x = features, y = labels, epochs = 30, verbose = 0)
# Best per-epoch training accuracy (this cell uses no held-out data).
print(max(history.history['Accuracy']))

0.7080045342445374


Generalization ability using a train/test split (TTS)

In [27]:
# Hold out 25% of rows (train_test_split's default, made explicit) for evaluation.
# Features come from drop() so they match the columns the models are sized for.
# An explicit random_state makes the split independent of how much of the global
# NumPy RNG earlier cells consumed; 812 is the seed used in the first cell.
X_train, X_test, y_train, y_test = train_test_split(
    df.drop('Survived', axis = 1), df['Survived'], test_size = 0.25, random_state = 812
)

In [29]:
mdl = vsn_mdl(df.drop('Survived', axis = 1))
mdl.compile(loss = 'binary_crossentropy', optimizer = 'SGD', metrics = ['Accuracy'])
mdl.fit(x = X_train, y = y_train, epochs = 30, verbose = 0)
preds = mdl.predict(X_test, verbose = 0)
# Held-out accuracy: round the sigmoid probabilities to 0/1 and take the
# fraction that match the true labels.
predicted_labels = np.round(preds).ravel()
np.mean(predicted_labels == y_test.to_numpy())

0.6981981981981982

In [30]:
mdl = base_model(df.drop('Survived', axis = 1))
mdl.compile(loss = 'binary_crossentropy', optimizer = 'SGD', metrics = ['Accuracy'])
mdl.fit(x = X_train, y = y_train, epochs = 30, verbose = 0)
preds = mdl.predict(X_test, verbose = 0)
# Held-out accuracy: round the sigmoid probabilities to 0/1 and take the
# fraction that match the true labels.
predicted_labels = np.round(preds).ravel()
np.mean(predicted_labels == y_test.to_numpy())

0.6711711711711712

The results above come from models trained on the Titanic dataset. We compare the variable selection network (VSN) against a plain baseline in three tests:
1. No hidden Dense layers (5 input features → 1 output).
2. A DNN with two hidden Dense layers of 64 and 32 units.
3. Generalization to a held-out test split.

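The VSN model beat the no-hidden-layer baseline both in training accuracy (0.672 vs. 0.601) and on the held-out split (0.698 vs. 0.671). However, with the deeper hidden network it fared slightly worse than the plain DNN (0.694 vs. 0.708).

These comparisons have important limits:
- The first two report the best per-epoch *training* accuracy, not held-out accuracy.
- Every comparison is a single run with one seed, and the held-out test uses a single split.

The differences may therefore not be significant.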

This is only a small experiment to probe the merits of the VSN architecture.

Ultimately, these tests are not conclusive. I have not applied standard scaling (or any other feature normalization), which is standard practice for neural networks. The dataset is also small and simple, not the kind of problem that needs a neural network, let alone an attention-style feature-selection mechanism.