In [283]:
import tensorflow as tf
from tensorflow.keras.layers import Layer, Dense, Dropout, LayerNormalization, Concatenate, Input, Flatten
import pandas as pd
from sklearn.model_selection import train_test_split
import numpy as np

tf.keras.utils.set_random_seed(812)

In [17]:
class Gate(Layer):
    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.linear = Dense(units, activation = 'linear')
        self.sigmoid = Dense(units, activation = 'sigmoid')
    
    def call(self, inputs):
        return self.sigmoid(inputs) * self.linear(inputs)

In [128]:
class GRN(Layer):
    def __init__(self, units, drop, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.elu = Dense(units, activation = 'elu')
        self.linear = Dense(units, activation = 'linear')
        self.drop = Dropout(drop)
        self.gate = Gate(units)
        self.layer_norm = LayerNormalization()
        self.filter = Dense(units, activation = 'linear')
    
    def call(self, inputs):
        x = self.elu(inputs)
        x = self.linear(x)
        x = self.drop(x)
        if inputs.shape[-1] != self.units:
            inputs = self.filter(inputs)
        x = self.gate(x)
        x = x + inputs
        x = self.layer_norm(x)
        return x

In [219]:
# Concatenate into own GRU layer. And dot product with total GRN.

class VariableSelection(Layer):
    def __init__(self, units, drop, **kwargs):
        super().__init__(**kwargs)
        self.gru = GRN(units, drop)
        self.softmax = Dense(units, activation = 'softmax')
    
    def call(self, inputs):
        feat_list = list()
        for i in tf.unstack(inputs, axis = 1):
            x = self.gru(i)
            feat_list.append(x)
        individual_gru = Concatenate()(feat_list)
        individual_gru = tf.expand_dims(self.softmax(individual_gru), axis = -1)
        total = self.gru(inputs)
        output = tf.linalg.matmul(total, individual_gru)
        return output

In [228]:
def variable_selection(df):
    start = Input(shape = (df.shape[1], 1))
    x = VariableSelection(64, 0.6)(start)
    x = VariableSelection(32, 0.2)(start)
    x = VariableSelection(16, 0.2)(x)
    x = Flatten()(x)
    end = Dense(units = 1, activation = 'sigmoid')(x)
    model = tf.keras.Model(start, end)
    return model

In [230]:
def base_model(df):
    start = Input(shape = (df.shape[1], 1))
    x = Flatten()(start)
    end = Dense(units = 1, activation = 'sigmoid')(x)
    model = tf.keras.Model(start, end)
    return model

In [6]:
df = pd.read_csv('titanic.csv').select_dtypes(exclude = object)

In [284]:
mdl = variable_selection(df.drop('Survived', axis = 1))
mdl.compile(loss = 'binary_crossentropy', optimizer = 'SGD', metrics = ['Accuracy'])
history = mdl.fit(x = df.iloc[:, 1:], y = df.Survived, epochs = 30, verbose = 0)
print(max(history.history['Accuracy']))

0.6493799090385437


In [285]:
mdl = base_model(df.drop('Survived', axis = 1))
mdl.compile(loss = 'binary_crossentropy', optimizer = 'SGD', metrics = ['Accuracy'])
history = mdl.fit(x = df.iloc[:, 1:], y = df.Survived, epochs = 30, verbose = 0)
print(max(history.history['Accuracy']))

0.5873731970787048


DNN testing:

In [286]:
def variable_dnn(df):
    start = Input(shape = (df.shape[1], 1))
    x = VariableSelection(64, 0.2)(start)
    x = VariableSelection(32, 0.2)(x)
    x = VariableSelection(16, 0.2)(x)
    x = Flatten()(x)
    x = Dense(64, activation = 'relu')(x)
    x = Dense(32, activation = 'relu')(x)
    end = Dense(units = 1, activation = 'sigmoid')(x)
    model = tf.keras.Model(start, end)
    return model

In [287]:
def base_dnn(df):
    start = Input(shape = (df.shape[1], 1))
    x = Flatten()(start)
    x = Dense(64, activation = 'relu')(x)
    x = Dense(32, activation = 'relu')(x)
    end = Dense(units = 1, activation = 'sigmoid')(x)
    model = tf.keras.Model(start, end)
    return model

In [288]:
mdl = variable_dnn(df.drop('Survived', axis = 1))
mdl.compile(loss = 'binary_crossentropy', optimizer = 'SGD', metrics = ['Accuracy'])
history = mdl.fit(x = df.iloc[:, 1:], y = df.Survived, epochs = 30, verbose = 0)
print(max(history.history['Accuracy']))

0.6324689984321594


In [289]:
mdl = base_dnn(df.drop('Survived', axis = 1))
mdl.compile(loss = 'binary_crossentropy', optimizer = 'SGD', metrics = ['Accuracy'])
history = mdl.fit(x = df.iloc[:, 1:], y = df.Survived, epochs = 30, verbose = 0)
print(max(history.history['Accuracy']))

0.7113866806030273


Generalization ability using TTS

In [257]:
X_train, X_test, y_train, y_test = train_test_split(df.iloc[:, 1:], df.Survived)

In [290]:
mdl = variable_selection(df.drop('Survived', axis = 1))
mdl.compile(loss = 'binary_crossentropy', optimizer = 'SGD', metrics = ['Accuracy'])
mdl.fit(x = X_train, y = y_train, epochs = 30, verbose = 0)
preds = mdl.predict(X_test, verbose = 0)
(np.round(preds).flatten() == y_test).sum()/ len(X_test)

0.6216216216216216

In [291]:
mdl = base_model(df.drop('Survived', axis = 1))
mdl.compile(loss = 'binary_crossentropy', optimizer = 'SGD', metrics = ['Accuracy'])
mdl.fit(x = X_train, y = y_train, epochs = 30, verbose = 0)
preds = mdl.predict(X_test, verbose = 0)
(np.round(preds).flatten() == y_test).sum()/ len(X_test)

0.6846846846846847

So the results above is trained on the titanic dataset. We test variable selection model on three tests. No hidden network (input of (5, 1) -> output (1, )) network, a dnn network with two hidden layers with 64 and 32 nodes, then finally testing generalization ability.

With no hidden layers, variable selection network (VSN) is significantly stronger than the base model by roughly 0.06 (6%) points higher. However, VSN was worse in both DNN and generalization by roughly 0.06 in the opposite direction.

Note that the titanic dataset is small, and neural networks, in general, should not be used for this dataset. VSN also is generally better for noisy/ many feature datasets, which this clearly is not with 5 features. And so, the findings of this notebook just shows VSN is not good for a relatively clean and easy dataset (you can get 85% on this dataset with tree networks). 