In [1]:
import numpy as np
import math
import copy
import bisect
import operator
import tensorflow as tf
from sklearn.preprocessing import StandardScaler

## Train Neural Network

In [2]:
# When True, the model-construction cell adds one small univariate
# ("main effect") network per input feature to the main network's output.
use_main_effect_nets = True # toggle this to use "main effect" nets

In [3]:
# Parameters
learning_rate = 0.01  # initial rate handed to the exponential decay schedule
num_epochs = 200  # number of passes over the training split
batch_size = 100  # rows per optimizer step
# NOTE(review): display_step is not referenced by the code visible in this
# notebook; the training loop reports every 50 epochs using a literal.
display_step = 100
l1_const = 5e-5  # multiplier on the summed L1 norm of the main-network weight matrices
num_samples = 30000 #30k datapoints, split 1/3-1/3-1/3

# Network Parameters
n_hidden_1 = 140 # 1st layer number of neurons
n_hidden_2 = 100 # 2nd layer number of neurons
n_hidden_3 = 60 # 3rd layer number of neurons
n_hidden_4 = 20 # 4th layer number of neurons
n_hidden_uni = 10  # width of each hidden layer in the univariate ("main effect") nets
# The synthetic generator draws 10 columns and synth_func indexes columns 0-9,
# so num_input has to stay 10 unless those are changed as well.
num_input = 10 # simple synthetic example input dimension
num_output = 1 # regression or classification output dimension

# tf Graph input
# Placeholders for a batch of features / targets; the batch dimension is left open.
X = tf.placeholder("float", [None, num_input])
Y = tf.placeholder("float", [None, num_output])

# Seed TensorFlow (graph-level) and NumPy; NumPy's generator drives the
# synthetic data draw and the per-epoch batch shuffling.
tf.set_random_seed(0)
np.random.seed(0)

In [4]:
# Interaction data generator
def synth_func(x):
    """Evaluate the synthetic target for every row of ``x``.

    The target is the sum of five interaction terms and two main effects.
    Using 1-based feature numbers, the ground-truth interactions are
    {1,2}, {2,3}, {3,4}, {1,4} and {4,5,7,8}; features 9 and 10 enter only
    as main effects, and feature 6 is not used at all.

    Args:
        x: 2-D array indexed as ``x[:, k]`` with at least 10 columns.

    Returns:
        1-D array with one target value per row of ``x``.
    """
    x1, x2, x3, x4, x5 = x[:,0], x[:,1], x[:,2], x[:,3], x[:,4]
    x7, x8, x9, x10 = x[:,6], x[:,7], x[:,8], x[:,9]

    term_12 = np.exp(np.fabs(x1 - x2))
    term_23 = np.fabs(x2 * x3)
    term_34 = -1 * np.power(np.power(x3, 2), np.fabs(x4))
    term_14 = np.power(x1 * x4, 2)
    term_4578 = np.log(
        np.power(x4, 2) + np.power(x5, 2) + np.power(x7, 2) + np.power(x8, 2)
    )
    additive_part = x9 + 1 / (1 + np.power(x10, 2))

    # Same left-to-right accumulation as the individual terms above.
    return term_12 + term_23 + term_34 + term_14 + term_4578 + additive_part

def gen_synth_data():
    """Draw the synthetic dataset and return standardized train/val/test splits.

    ``num_samples`` rows of 10 features are sampled uniformly from [-1, 1)
    with NumPy's global generator, and targets come from ``synth_func``.
    Rows are cut, in order, into three consecutive parts at ``num_samples//3``
    and ``2*num_samples//3``. Feature and target scalers are fitted on the
    first (training) part only and then applied to all three parts.

    Returns:
        Tuple ``(tr_x, va_x, te_x, tr_y, va_y, te_y)``; the ``*_y`` arrays
        have a trailing axis of length 1.
    """
    features = np.random.uniform(low=-1, high=1, size=(num_samples,10))
    targets = np.expand_dims(synth_func(features), axis=1)

    first_cut = num_samples//3
    second_cut = 2*num_samples//3
    parts = (
        slice(None, first_cut),
        slice(first_cut, second_cut),
        slice(second_cut, None),
    )
    x_parts = [features[part] for part in parts]
    y_parts = [targets[part] for part in parts]

    # Statistics come from the training part only.
    scaler_x = StandardScaler()
    scaler_y = StandardScaler()
    scaler_x.fit(x_parts[0])
    scaler_y.fit(y_parts[0])

    tr_x, va_x, te_x = [scaler_x.transform(chunk) for chunk in x_parts]
    tr_y, va_y, te_y = [scaler_y.transform(chunk) for chunk in y_parts]
    return tr_x, va_x, te_x, tr_y, va_y, te_y

# Get data
# Standardized train / validation / test splits of the synthetic dataset.
tr_x, va_x, te_x, tr_y, va_y, te_y = gen_synth_data()
# Number of training rows; used for the batch count and the learning-rate decay period.
tr_size = tr_x.shape[0]

In [5]:
# Parameters of the main fully connected network, looked up by layer name.
# Every entry is a variable initialised from a truncated normal with mean 0
# and stddev 0.1. Entries are created in the listed order (all weights, then
# all biases).
# NOTE(review): a graph-level seed is set earlier, so changing the creation
# order may change the initial values - confirm before reordering.
weights = {
    layer: tf.Variable(tf.truncated_normal(shape, mean=0, stddev=0.1))
    for layer, shape in (
        ('h1', [num_input, n_hidden_1]),
        ('h2', [n_hidden_1, n_hidden_2]),
        ('h3', [n_hidden_2, n_hidden_3]),
        ('h4', [n_hidden_3, n_hidden_4]),
        ('out', [n_hidden_4, num_output]),
    )
}
biases = {
    layer: tf.Variable(tf.truncated_normal([width], mean=0, stddev=0.1))
    for layer, width in (
        ('b1', n_hidden_1),
        ('b2', n_hidden_2),
        ('b3', n_hidden_3),
        ('b4', n_hidden_4),
        ('out', num_output),
    )
}

def get_weights_uninet():
    """Create a fresh set of weight variables for one univariate network.

    The network maps a single input column through three hidden layers of
    width ``n_hidden_uni`` to ``num_output`` outputs. Each matrix is
    initialised from a truncated normal (mean 0, stddev 0.1).

    Returns:
        Dict with keys ``'h1'``, ``'h2'``, ``'h3'``, ``'out'``, created in
        that order.
    """
    layer_shapes = (
        ('h1', [1, n_hidden_uni]),
        ('h2', [n_hidden_uni, n_hidden_uni]),
        ('h3', [n_hidden_uni, n_hidden_uni]),
        ('out', [n_hidden_uni, num_output]),
    )
    uni_weights = {}
    for layer, shape in layer_shapes:
        uni_weights[layer] = tf.Variable(tf.truncated_normal(shape, 0, 0.1))
    return uni_weights

def get_biases_uninet():
    """Create a fresh set of hidden-layer bias variables for one univariate network.

    Each bias vector has length ``n_hidden_uni`` and is initialised from a
    truncated normal (mean 0, stddev 0.1). No output bias is created.

    Returns:
        Dict with keys ``'b1'``, ``'b2'``, ``'b3'``, created in that order.
    """
    uni_biases = {}
    for layer in ('b1', 'b2', 'b3'):
        uni_biases[layer] = tf.Variable(tf.truncated_normal([n_hidden_uni], 0, 0.1))
    return uni_biases

In [6]:
# Create model
def normal_neural_net(x, weights, biases):
    """Build the main network: four ReLU hidden layers and a linear output.

    Args:
        x: input tensor, multiplied on the right by ``weights['h1']``.
        weights: mapping with keys ``'h1'``..``'h4'`` and ``'out'``.
        biases: mapping with keys ``'b1'``..``'b4'`` and ``'out'``.

    Returns:
        The output-layer tensor (no activation applied).
    """
    hidden = x
    for w_key, b_key in (('h1', 'b1'), ('h2', 'b2'), ('h3', 'b3'), ('h4', 'b4')):
        pre_activation = tf.add(tf.matmul(hidden, weights[w_key]), biases[b_key])
        hidden = tf.nn.relu(pre_activation)
    return tf.matmul(hidden, weights['out']) + biases['out']

def main_effect_net(x, weights, biases):
    """Build a small network: three ReLU hidden layers and a bias-free linear output.

    Args:
        x: input tensor, multiplied on the right by ``weights['h1']``.
        weights: mapping with keys ``'h1'``, ``'h2'``, ``'h3'`` and ``'out'``.
        biases: mapping with keys ``'b1'``, ``'b2'``, ``'b3'``; no output
            bias is read.

    Returns:
        The output-layer tensor (no activation, no bias).
    """
    hidden = x
    for depth in (1, 2, 3):
        affine = tf.add(tf.matmul(hidden, weights['h%d' % depth]), biases['b%d' % depth])
        hidden = tf.nn.relu(affine)
    return tf.matmul(hidden, weights['out'])

# L1 regularizer
def l1_norm(a):
    """Return the sum of the absolute values of all entries of ``a``."""
    magnitudes = tf.abs(a)
    return tf.reduce_sum(magnitudes)

In [7]:
# Construct model
# Main network: fully connected net over all input features.
net = normal_neural_net(X, weights, biases)

if use_main_effect_nets:  
    # One univariate network per input column. Each sees only column x_i
    # (expanded back to a 2-D [batch, 1] tensor) and gets its own freshly
    # created weight and bias variables.
    me_nets = []
    for x_i in range(num_input):
        me_net = main_effect_net(tf.expand_dims(X[:,x_i],1), get_weights_uninet(), get_biases_uninet())
        me_nets.append(me_net)
    # Model output = main network + sum of all univariate networks.
    net = net + sum(me_nets)

# Define optimizer
# loss_op is the plain mean squared error; it is what the training loop reports.
loss_op = tf.losses.mean_squared_error(labels=Y, predictions=net)
# loss_op = tf.sigmoid_cross_entropy_with_logits(labels=Y,logits=net) # use this in the case of binary classification
# L1 penalty over the main network's weight matrices only: `weights` holds
# neither the biases nor the univariate-net variables.
sum_l1 = tf.reduce_sum([l1_norm(weights[k]) for k in weights])
# Training objective: MSE plus the scaled L1 penalty.
loss_w_reg_op = loss_op + l1_const*sum_l1 

# `batch` is passed as global_step, so it counts optimizer steps.
# batch*batch_size is then the number of samples seen; with decay_steps=tr_size
# and staircase=True the rate is multiplied by 0.95 once per tr_size samples.
batch = tf.Variable(0)
decaying_learning_rate = tf.train.exponential_decay(learning_rate, batch*batch_size, tr_size, 0.95, staircase=True)
optimizer = tf.train.AdamOptimizer(learning_rate=decaying_learning_rate).minimize(loss_w_reg_op, global_step=batch)

In [8]:
# Build the variable-initializer op once and run that same op below, rather
# than creating a second, identical initializer inline.
init = tf.global_variables_initializer()
# Whole batches only: a trailing partial batch, if any, is never visited.
n_batches = tr_size//batch_size
config = tf.ConfigProto()
config.gpu_options.per_process_gpu_memory_fraction = 0.25
config.gpu_options.allow_growth = True
# The session is deliberately left open: later cells fetch the trained
# weights from it.
sess = tf.Session(config=config)
sess.run(init)

print('Initialized')

for epoch in range(num_epochs):

    # Visit the fixed mini-batches in a new random order each epoch; the rows
    # inside each batch stay the same.
    batch_order = list(range(n_batches))
    np.random.shuffle(batch_order)

    for i in batch_order:
        batch_x = tr_x[i*batch_size:(i+1)*batch_size]
        batch_y = tr_y[i*batch_size:(i+1)*batch_size]
        # One optimizer step; also fetch the current learning rate for reporting.
        _, lr = sess.run([optimizer,decaying_learning_rate], feed_dict={X:batch_x, Y:batch_y})

    # NOTE(review): the reporting interval is the literal 50; the display_step
    # setting defined with the other parameters is not consulted here.
    if (epoch+1) % 50 == 0:
        # Unregularised MSE on each full split, reported as RMSE.
        tr_mse = sess.run(loss_op, feed_dict={X:tr_x, Y:tr_y})
        va_mse = sess.run(loss_op, feed_dict={X:va_x, Y:va_y})
        te_mse = sess.run(loss_op, feed_dict={X:te_x, Y:te_y})
        print('Epoch', epoch+1)
        print('\t','train rmse', math.sqrt(tr_mse), 'val rmse', math.sqrt(va_mse), 'test rmse', math.sqrt(te_mse))
        print('\t','learning rate', lr)
        
print('done')

Initialized
Epoch 50
	 train rmse 0.03911299422164593 val rmse 0.046781548587818335 test rmse 0.04766639578022324
	 learning rate 0.000809947
Epoch 100
	 train rmse 0.035532514060659194 val rmse 0.043421838204422136 test rmse 0.044302592678017526
	 learning rate 6.23213e-05
Epoch 150
	 train rmse 0.0352709986999008 val rmse 0.043239377950527326 test rmse 0.04399459512626101
	 learning rate 4.79531e-06
Epoch 200
	 train rmse 0.035248727497338636 val rmse 0.04320754406068322 test rmse 0.043997104906904214
	 learning rate 3.68974e-07
done


## Interpret Weights

In [9]:
def preprocess_weights(w_dict):
    """Reduce a dict of layer weight matrices to the two arrays used for ranking.

    Args:
        w_dict: mapping with keys ``'h1'``, ``'h2'``, ..., ``'hN'`` and
            ``'out'``, each a NumPy weight matrix laid out as
            (inputs of the layer, outputs of the layer).

    Returns:
        ``(w_h1, w_agg)`` where ``w_h1`` is ``|h1|`` and ``w_agg`` is the
        product ``|h2| @ |h3| @ ... @ |hN| @ |out|``, which has one row per
        first-hidden-layer unit.
    """
    deepest = max(int(name[1:]) for name in w_dict.keys() if name.startswith('h'))

    # Start at the output layer and fold in hN, ..., h2 (h1 is kept separate).
    w_agg = np.abs(w_dict['out'])
    w_h1 = np.abs(w_dict['h1'])
    for level in range(deepest, 1, -1):
        layer_abs = np.abs(w_dict['h' + str(level)])
        w_agg = np.matmul(layer_abs, w_agg)

    return w_h1, w_agg

def _accumulate_interaction_strengths(w_h1, w_agg):
    """Sum, over first-hidden-layer units, the strength of each candidate interaction.

    For every unit the input features are taken in order of decreasing
    first-layer weight. Each prefix of two or more features is a candidate
    interaction, scored as (smallest weight in the prefix) * (summed
    aggregated weight of the unit). Scores of the same feature set are added
    across units.

    Returns:
        Dict mapping a sorted tuple of 1-based feature indices to its
        accumulated strength, in first-seen order.
    """
    strengths = dict()
    for unit in range(len(w_agg)):
        # Loop-invariant for this unit, so compute it once.
        unit_influence = np.sum(w_agg[unit])
        by_weight = sorted(enumerate(w_h1[:, unit]), key=lambda fw: fw[1], reverse=True)

        candidate = []      # 1-based feature indices, kept sorted
        seen_weights = []
        for feature_idx, weight in by_weight:
            bisect.insort(candidate, feature_idx + 1)
            seen_weights.append(weight)
            if len(candidate) == 1:
                # A single feature is not an interaction.
                continue
            key = tuple(candidate)
            strengths[key] = strengths.get(key, 0) + np.abs(min(seen_weights) * unit_influence)
    return strengths


def _prune_subsumed_interactions(ranked, max_kept=20000):
    """Walk a strength-sorted ranking and drop redundant sub-interactions.

    An interaction is skipped when its feature set is a strict subset of a
    currently tracked maximal set among the interactions kept so far. The
    walk stops once more than ``max_kept`` interactions have been kept.
    """
    kept = []
    # Feature sets of kept interactions that no later kept interaction strictly
    # contains. No member is a strict subset of another, so a newcomer cannot
    # be both a strict subset of one member and a strict superset of another.
    maximal_sets = []
    for features, strength in ranked:
        if len(kept) > max_kept:
            break
        feature_set = set(features)
        if any(feature_set < larger for larger in maximal_sets):
            continue
        # The newcomer replaces every tracked set it strictly contains.
        maximal_sets = [s for s in maximal_sets if not feature_set > s]
        maximal_sets.append(feature_set)
        kept.append((features, strength))
    return kept


def get_interaction_ranking(w_dict):
    """Rank variable-order feature interactions from a dict of layer weights.

    Args:
        w_dict: mapping of layer names (``'h1'``.. and ``'out'``) to NumPy
            weight matrices, as accepted by ``preprocess_weights``.

    Returns:
        List of ``(features, strength)`` pairs ordered by decreasing strength.
        ``features`` is a sorted tuple of 1-based input indices. Interactions
        that are strict subsets of a stronger interaction still tracked as
        maximal are omitted.
    """
    w_h1, w_agg = preprocess_weights(w_dict)

    interaction_strengths = _accumulate_interaction_strengths(w_h1, w_agg)
    interaction_sorted = sorted(
        interaction_strengths.items(), key=operator.itemgetter(1), reverse=True
    )
    return _prune_subsumed_interactions(interaction_sorted)

def get_pairwise_ranking(w_dict):
    """Rank all pairs of input features by pairwise interaction strength.

    For a pair (a, b), the strength is the sum over first-hidden-layer units
    (and outputs) of ``min(|w_a|, |w_b|) * w_agg``, where ``|w_a|`` and
    ``|w_b|`` are the features' absolute first-layer weights into the unit
    and ``w_agg`` is the unit's aggregated weight.

    Args:
        w_dict: mapping of layer names (``'h1'``.. and ``'out'``) to NumPy
            weight matrices, as accepted by ``preprocess_weights``.

    Returns:
        List of ``((a, b), strength)`` with 1-based indices and ``a < b``,
        ordered by decreasing strength. Ties keep ascending (a, b) order.
        Empty when there are fewer than two input features.
    """
    xdim = w_dict['h1'].shape[0]
    w_h1, w_agg = preprocess_weights(w_dict)

    # Enumerate each unordered pair exactly once, already in ascending order.
    # Building all ordered pairs and deleting the mirrored ones from the list
    # while iterating over it gives the same sequence, but only because every
    # deleted element happens to sit after the cursor, and it costs far more.
    pairs = [(a, b) for a in range(1, xdim + 1) for b in range(a + 1, xdim + 1)]

    pairwise_strengths = []
    for a, b in pairs:
        row_a = w_h1[a - 1]
        row_b = w_h1[b - 1]
        # Column vectors so the elementwise product lines up with w_agg's rows.
        wa = row_a.reshape(row_a.shape[0], 1)
        wb = row_b.reshape(row_b.shape[0], 1)
        strength = np.sum(np.abs(np.minimum(wa, wb) * w_agg))
        pairwise_strengths.append(((a, b), strength))

    # sorted() is stable, so equal strengths keep their enumeration order.
    return sorted(pairwise_strengths, key=operator.itemgetter(1), reverse=True)

In [10]:
# Evaluate the main network's weight variables in the training session.
# The ranking functions index the result by layer name ('h1'.., 'out').
w_dict = sess.run(weights)

## Get Interactions

In [11]:
# Variable-Order Interaction Ranking
# Displays (feature tuple, strength) pairs, strongest first; feature indices are 1-based.
get_interaction_ranking(w_dict)

[((1, 2), 3.9842937443447619),
 ((1, 4), 1.2878044545665732),
 ((5, 7), 1.1594737470149994),
 ((2, 3), 1.1573218256235123),
 ((7, 8), 1.0395813845098441),
 ((3, 4), 0.9955278397655154),
 ((5, 8), 0.74511629942584667),
 ((5, 7, 8), 0.65411340026184916),
 ((4, 7), 0.52771079540252686),
 ((4, 8), 0.51658676937230408),
 ((4, 5, 8), 0.44517305865883827),
 ((4, 5, 7, 8), 0.26200527045875788),
 ((1, 2, 3), 0.059168170135783404),
 ((1, 2, 4), 0.036433947985880635),
 ((1, 2, 8), 0.018344591895609306),
 ((4, 5, 6, 7, 8), 0.018325034528970718),
 ((3, 4, 8), 0.017842164263129234),
 ((1, 2, 9), 0.012655940991698011),
 ((1, 2, 4, 9), 0.012392999654720196),
 ((1, 4, 5, 7, 8), 0.010827778314705938),
 ((1, 2, 8, 9), 0.010519576074955239),
 ((3, 4, 5, 6, 7, 8), 0.0095966905355453491),
 ((1, 4, 6), 0.0094528580084776905),
 ((2, 4, 5, 7, 8), 0.0089620146900415421),
 ((1, 2, 3, 4), 0.0078430417855737296),
 ((1, 2, 5, 8, 9), 0.0070880441926419735),
 ((3, 4, 5, 8, 10), 0.0063181919977068901),
 ((1, 3, 4, 6),

In [12]:
# Pairwise Interaction Ranking
# Displays ((a, b), strength) for every feature pair, strongest first; indices are 1-based.
get_pairwise_ranking(w_dict)

[((1, 2), 4.0479612),
 ((5, 8), 1.6183388),
 ((5, 7), 1.6133682),
 ((7, 8), 1.5589437),
 ((1, 4), 1.3692775),
 ((2, 3), 1.1941909),
 ((4, 8), 1.0497335),
 ((3, 4), 1.0476193),
 ((4, 5), 1.0218934),
 ((4, 7), 0.83411992),
 ((1, 3), 0.095073581),
 ((2, 4), 0.080335304),
 ((1, 8), 0.070268914),
 ((3, 5), 0.063574985),
 ((2, 8), 0.061794207),
 ((3, 8), 0.059727035),
 ((1, 5), 0.058769718),
 ((2, 5), 0.051377397),
 ((2, 9), 0.049841423),
 ((4, 6), 0.048084557),
 ((1, 9), 0.047795787),
 ((5, 6), 0.047461249),
 ((6, 8), 0.046787921),
 ((8, 9), 0.045206197),
 ((5, 9), 0.044821486),
 ((4, 9), 0.04005),
 ((6, 7), 0.039612152),
 ((5, 10), 0.034415431),
 ((4, 10), 0.034250751),
 ((8, 10), 0.033772744),
 ((2, 7), 0.032941118),
 ((3, 7), 0.032652456),
 ((3, 9), 0.028732639),
 ((1, 7), 0.028270187),
 ((1, 10), 0.02776891),
 ((1, 6), 0.027052868),
 ((3, 6), 0.025879536),
 ((3, 10), 0.025838438),
 ((2, 6), 0.022900498),
 ((7, 9), 0.02238699),
 ((2, 10), 0.019447444),
 ((9, 10), 0.016316127),
 ((6, 9), 