In [None]:
import tensorflow as tf
import tensorflow.contrib.layers as layers
import numpy as np
import matplotlib.pyplot as plt

import sys
sys.path.insert(0, '../fairml')
#import plotting
#import generate
import models
#import actions
#import utils

In [None]:
# TensorFlow session settings: 8 threads each for intra-op and inter-op
# parallelism, up to 8 CPU devices, and soft device placement enabled.
config = tf.ConfigProto(
    intra_op_parallelism_threads = 8,
    inter_op_parallelism_threads = 8,
    allow_soft_placement = True,
    device_count = {'CPU': 8},
)

# Interactive session used by the training and evaluation cells below.
sess = tf.InteractiveSession(config = config)

In [None]:
def prepare_data(n_samples):
    """Generate a toy dataset of 2-D observations with known latent means.

    For every sample a parameter of interest (POI) and a nuisance parameter are
    drawn independently from U[0, 1). The observation is then drawn from a
    bivariate normal with mean (poi, nuisance) and covariance
    [[1, 0.2], [0.2, 1]].

    The sampling is vectorised: a draw from N(mean, cov) is obtained as
    mean + N(0, cov), so the covariance is decomposed once instead of once per
    sample. The distribution is the same as for a per-sample loop, but the
    order in which random numbers are consumed differs.

    Args:
        n_samples: number of samples to generate (may be 0).

    Returns:
        (data, pois, nuisances): arrays of shape (n_samples, 2), (n_samples,)
        and (n_samples,), respectively.
    """
    cov = np.array([[1, 0.2], [0.2, 1]])

    pois = np.random.uniform(low = 0.0, high = 1.0, size = n_samples)
    nuisances = np.random.uniform(low = 0.0, high = 1.0, size = n_samples)

    # per-sample means, one row per sample: (poi, nuisance)
    means = np.stack([pois, nuisances], axis = 1)

    # zero-mean correlated noise, shifted by the per-sample means
    noise = np.random.multivariate_normal(mean = [0.0, 0.0], cov = cov, size = n_samples)
    data = means + noise

    return data, pois, nuisances

In [None]:
# Draw the training set. The POI and nuisance targets get a trailing axis of
# length 1 so that they are 2-D column arrays.
data_train, pois_train, nuisances_train = prepare_data(8000)
pois_train = pois_train[:, np.newaxis]
nuisances_train = nuisances_train[:, np.newaxis]

In [None]:
# Hexagonal 2-D histogram of the training observations with log-scaled counts,
# saved to disk before being shown.
plt.hexbin(data_train[:, 0], data_train[:, 1], gridsize = 50, bins = 'log')
plt.colorbar()
ax = plt.gca()
ax.set_xlim(-3, 4)
ax.set_ylim(-3, 4)
plt.savefig("toy_filterdataset.pdf")
plt.show()

In [None]:
def filter_network(filter_input):
    """Build the filter network inside the "filter" variable scope.

    The network consists of two 20-unit ReLU layers followed by a linear layer
    with a single output.

    Args:
        filter_input: tensor fed into the first layer.

    Returns:
        (outputs, these_vars): the output tensor of the final linear layer and
        the list of global variables collected under the "filter" scope.
    """
    with tf.variable_scope("filter"):
        hidden = filter_input
        for _ in range(2):
            hidden = layers.relu(hidden, 20)
        outputs = layers.linear(hidden, 1)

    # variables are looked up by scope name after the layers have been created
    filter_scope_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope = "filter")
    return outputs, filter_scope_vars

In [None]:
# prepare input tensors
data_in = tf.placeholder(tf.float32, [None, 2], name = 'data_in')
nuisances_in = tf.placeholder(tf.float32, [None, 1], name = 'nuisances_in')
pois_in = tf.placeholder(tf.float32, [None, 1], name = 'pois_in')

filter_output, filter_vars = filter_network(data_in)

# prepare the two MINE blocks connected to the filter output:
# one paired with the nuisance parameter, one with the parameter of interest
T1_nuis, T2_nuis, MINE_vars_nuis = models.MINE(filter_output, nuisances_in, "MINE_nuis")
T1_pois, T2_pois, MINE_vars_pois = models.MINE(filter_output, pois_in, "MINE_pois")

# upon convergence, the MINE losses below give the negative mutual information
MINE_loss_nuis = models.MINE_loss(T1_nuis, T2_nuis)
MINE_loss_pois = models.MINE_loss(T1_pois, T2_pois)

# MINE optimizers: each op minimizes the loss of, and only updates the variables of,
# the MINE block it is named after (the optimizers are created in the same order as
# before, so the resulting graph is unchanged)
train_MINE_nuis = tf.train.AdamOptimizer(learning_rate = 0.01, beta1 = 0.3, beta2 = 0.5).minimize(MINE_loss_nuis, var_list = MINE_vars_nuis)
train_MINE_pois = tf.train.AdamOptimizer(learning_rate = 0.01, beta1 = 0.3, beta2 = 0.5).minimize(MINE_loss_pois, var_list = MINE_vars_pois)

# total loss: with the MINE losses being negative MI estimates, minimizing this
# increases the MI between filter output and POI and decreases (with weight 3)
# the MI between filter output and nuisance
total_loss = MINE_loss_pois - 3 * MINE_loss_nuis

# filter optimizer: only the filter variables are updated, the MINE blocks stay fixed
train_filter = tf.train.AdamOptimizer(learning_rate = 0.005, beta1 = 0.3, beta2 = 0.5).minimize(total_loss, var_list = filter_vars)

In [None]:
sess.run(tf.global_variables_initializer())

# training schedule
MINE_init_epochs = 200
batches_per_epoch = 200
number_epochs = 3
batch_size = 200

# feed dict holding the complete training set; shared by all full-dataset runs
full_feed = {data_in: data_train, nuisances_in: nuisances_train, pois_in: pois_train}

# pre-train both MINE blocks on the full training set while the filter stays untouched
for epoch in range(MINE_init_epochs):
    sess.run(train_MINE_pois, feed_dict = full_feed)
    sess.run(train_MINE_nuis, feed_dict = full_feed)

# initial MI values (the MINE losses are negated to obtain the MI estimates)
MI_pois = -sess.run(MINE_loss_pois, feed_dict = full_feed)
MI_nuis = -sess.run(MINE_loss_nuis, feed_dict = full_feed)

print("MI_nuis = {}, MI_pois = {}".format(MI_nuis, MI_pois))

# alternate between MINE updates and filter updates
for epochs in range(number_epochs):
    for batch in range(batches_per_epoch):
        # draw a random minibatch (indices are sampled with replacement)
        inds = np.random.choice(len(data_train), size = batch_size)
        data_batch = data_train[inds]
        pois_batch = pois_train[inds]
        nuis_batch = nuisances_train[inds]
        batch_feed = {data_in: data_batch, nuisances_in: nuis_batch, pois_in: pois_batch}

        # update both MINE blocks
        # NOTE(review): these updates use the full training set, not the minibatch
        # drawn above -- confirm this is intended
        sess.run(train_MINE_pois, feed_dict = full_feed)
        sess.run(train_MINE_nuis, feed_dict = full_feed)

        # update the filter on the minibatch
        sess.run(train_filter, feed_dict = batch_feed)

        if batch % 100 == 0:
            # debug output: current MI estimates on the full training set
            MI_pois = -sess.run(MINE_loss_pois, feed_dict = full_feed)
            MI_nuis = -sess.run(MINE_loss_nuis, feed_dict = full_feed)

            print("MI_nuis = {}, MI_pois = {}".format(MI_nuis, MI_pois))

# Evaluate the trained filter on 50000 points drawn uniformly from [-4, 4) x [-4, 4)
# to see its output as a function of the incoming random variable.
data_test = np.random.uniform(-4, 4, size = (50000, 2))

pred = sess.run(filter_output, feed_dict = {data_in: data_test})

def plot_contour(data_xy, data_z, x_low = -4, x_high = 4 , y_low = -4, y_high = 4, n_grid = 1000, levels = 20):
    """Draw a filled contour plot of scattered (x, y, z) samples on the current axes.

    The scattered values are linearly interpolated onto a regular
    n_grid x n_grid grid spanning [x_low, x_high] x [y_low, y_high]. Grid
    points outside the convex hull of the samples have no interpolated value
    and are masked, so they are left blank in the plot.

    The interpolation uses scipy.interpolate.griddata, because
    matplotlib.mlab.griddata is no longer available in Matplotlib >= 3.1.

    Args:
        data_xy: array whose columns 0 and 1 hold the x and y sample coordinates.
        data_z: 1-D array with one value per row of data_xy.
        x_low, x_high: extent of the grid along x.
        y_low, y_high: extent of the grid along y.
        n_grid: number of grid points along each axis.
        levels: passed to plt.contourf as its `levels` argument.

    Returns:
        None. Axis labels and a colorbar are added to the current figure.
    """
    from scipy.interpolate import griddata

    xi = np.linspace(x_low, x_high, n_grid)
    yi = np.linspace(y_low, y_high, n_grid)
    # meshgrid with default 'xy' indexing yields arrays of shape (len(yi), len(xi)),
    # which is the layout contourf expects for z
    grid_x, grid_y = np.meshgrid(xi, yi)

    points = np.column_stack([data_xy[:, 0], data_xy[:, 1]])
    zi = griddata(points, data_z, (grid_x, grid_y), method = "linear")
    # scipy marks points outside the convex hull with NaN; mask them explicitly
    zi = np.ma.masked_invalid(zi)

    plt.contourf(xi, yi, zi, levels = levels)
    plt.xlabel('x')
    plt.ylabel('y')
    plt.colorbar()
    
# Contour plot of the filter output over the test points, saved before being shown.
pred_flat = pred.flatten()
plot_contour(data_test, pred_flat)

plt.savefig("information_filter_contours_f1.pdf")
plt.show()