In [None]:
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
from sklearn.feature_selection import mutual_info_regression

import sys
sys.path.insert(0, '../fairml')
import plotting
import generate
import models
import actions
import utils

# Generate data

In [None]:
# Build the evaluation samples once up front (large; meant to be generated a
# single time and reused by every later cell).
n_test_samples = 100000

generate_toys = generate.generate_toys

# One call without z, then one call per z in {1, 0, -1}.
# NOTE(review): presumably `z` pins the sensitive attribute to that value --
# confirm against generate.generate_toys.
X, Y, Z = generate_toys(n_test_samples)
X1, Y1, Z1 = generate_toys(n_test_samples, z=1)
X0, Y0, Z0 = generate_toys(n_test_samples, z=0)
X_1, Y_1, Z_1 = generate_toys(n_test_samples, z=-1)

# Collect the (X, Y, Z) triples under human-readable labels.
test_data = {
    'all Z': (X, Y, Z),
    'Z=1': (X1, Y1, Z1),
    'Z=0': (X0, Y0, Z0),
    'Z=-1': (X_1, Y_1, Z_1),
}

# Train two classifiers, in stages

In [None]:
# Run counter: the training cell increments it and uses it to build the name
# prefix for that run's graph, so each re-run gets distinct names.
ctr = 0
# TensorFlow session used by the training and evaluation cells below.
sess = tf.InteractiveSession()

In [None]:
# Hyper-parameters for this training run.
n_samples = 10000      # sample count handed to every actions.train call
n_epochs = 100         # passed to actions.train for the pretraining steps
learning_rate = 0.005  # learning rate of every Adam optimiser built in this cell
n_adv_cycles = 30      # number of iterations of the adversarial loop
lam = 1                # lambda in the s2Y loss: weight of the MI(s2Y, Z) term
n_clf = 1              # passed to actions.train (in place of n_epochs) for each s2Y update
n_adv = 5              # passed to actions.train (in place of n_epochs) for each MINE update

# Bump the run counter and derive this run's name prefix from it.
ctr += 1
name = 'Nb{}'.format(ctr)

#######################
# input placeholders
#######################
# x_in carries two columns per row; y_in and z_in are one value per row.
x_in = tf.placeholder(dtype=tf.float32, shape=[None, 2], name='X1_X2')
y_in = tf.placeholder(dtype=tf.float32, shape=[None], name='Y')
z_in = tf.placeholder(dtype=tf.float32, shape=[None], name='Z')
# The list is handed to actions.train together with the data generator.
inputs = [x_in, y_in, z_in]

#######################
# create the stage 1 classifier graph, loss, and optimisation
#######################
clf_output, vars_D = models.classifier(x_in, name+'clf')
loss_D = models.classifier_loss(clf_output, y_in)
opt_D = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss_D, var_list=vars_D)

#######################
# create the stage 2 model: s2Y (i.e. correlate with Y, no dependence on Z)
#######################
# s2Y takes the stage 1 classifier output as its input.
s2Y, vars_s2Y = models.classifier(clf_output, name+'s2Y')

# and its two MINE graphs: one for (s2Y, Z) and one for (s2Y, Y).
# Both must exist because the adversarial loop below trains each of them and
# the s2Y loss is built from both MINE losses.
T_xy__s2Y_Z, T_x_y__s2Y_Z, vars_s2Y_Z = models.MINE(s2Y, z_in, name+'_s2Y_Z_MINE', deep=True)
loss_s2Y_Z = models.MINE_loss(T_xy__s2Y_Z, T_x_y__s2Y_Z)
opt_s2Y_Z = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss_s2Y_Z, var_list=vars_s2Y_Z)

T_xy__s2Y_Y, T_x_y__s2Y_Y, vars_s2Y_Y = models.MINE(s2Y, y_in, name+'_s2Y_Y_MINE', deep=True)
loss_s2Y_Y = models.MINE_loss(T_xy__s2Y_Y, T_x_y__s2Y_Y)
opt_s2Y_Y = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss_s2Y_Y, var_list=vars_s2Y_Y)

# Aliases for the names of the former single (s2Y, Z) MINE, kept so that any
# other cell still referring to them keeps working.
T_xy, T_x_y, vars_MINE = T_xy__s2Y_Z, T_x_y__s2Y_Z, vars_s2Y_Z
loss_MINE, opt_MINE = loss_s2Y_Z, opt_s2Y_Z

# determine the s2Y loss function:
# -----------------------------
# maximise MI between s2Y and Y, while minimising MI between s2Y and sensitive parameter Z:
# maximise MI(s2Y, Y) - lambda * MI(s2Y, Z)
# -----------------------------
# this equals:
# minimise - MI(s2Y, Y) + lambda * MI(s2Y, Z)
# and since MI = - loss_MINE, also have:
# minimise: loss_s2Y_Y - lambda * loss_s2Y_Z

loss_s2Y = loss_s2Y_Y - lam*loss_s2Y_Z
# Only the stage 2 variables are updated by this optimiser (var_list).
opt_s2Y = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss_s2Y, var_list=vars_s2Y)

#######################
#######################
# run the trainings
#######################
#######################

# initialise the variables
sess.run(tf.global_variables_initializer())

# pretrain the classifier
actions.train(sess, opt_D, loss_D, inputs, generate_toys, n_samples, n_epochs, 'Classifier Loss (L_D)')

# plot the performance of the classifier
#pred = utils.sigmoid(sess.run(clf_output, feed_dict={x_in:X, y_in:Y}))
#pred1 = utils.sigmoid(sess.run(clf_output, feed_dict={x_in:X1, y_in:Y1}))
#pred0 = utils.sigmoid(sess.run(clf_output, feed_dict={x_in:X0, y_in:Y0}))
#pred_1 = utils.sigmoid(sess.run(clf_output, feed_dict={x_in:X_1, y_in:Y_1}))
#test_data['preds'] = pred, pred1, pred0, pred_1
#plotting.plot_classifier_performance(test_data, 'media/plots/clf_pretrained')

# pretrain the two MINEs:
s2Y_val = sess.run(s2Y, feed_dict={x_in:X, y_in:Y})

# quick look at the first few stage 2 outputs and the matching Z values
n_preview = 10
print(s2Y_val[:n_preview, :])
print(Z[:n_preview])

# Number of test samples used for every mutual_info_regression estimate in
# this cell (the same value for the pretraining reference and the monitoring
# inside the adversarial loop, so the estimates are comparable).
N = 1000
true_MI_Z = mutual_info_regression(s2Y_val[:N, :], Z[:N])[0]
true_MI_Y = mutual_info_regression(s2Y_val[:N, :], Y[:N])[0]
# The negated MI estimate is passed as the last argument; a MINE loss equals
# minus the MI it estimates (see the loss derivation above).
actions.train(sess, opt_s2Y_Z, loss_s2Y_Z, inputs, generate_toys, n_samples, n_epochs, 'MINE loss: s2Y and Z', -true_MI_Z)
actions.train(sess, opt_s2Y_Y, loss_s2Y_Y, inputs, generate_toys, n_samples, n_epochs, 'MINE loss: s2Y and Y', -true_MI_Y)

# now do the adversarial part (modified loss function for the classifier)
losses = {'s2Y':[], 's2Y_Z':[], 's2Y_Y':[], 'D':[]}
MIs = {'Y':[], 'Z':[]}
# Report roughly ten times over the whole loop; max() keeps the stride at
# least 1 so the modulo below is safe for fewer than ten cycles.
report_every = max(1, n_adv_cycles//10)
for c in range(n_adv_cycles):

    # report progress
    if c % report_every == 0:
        print('{c}/{t}'.format(c=c, t=n_adv_cycles))

    # update both MINEs
    losses['s2Y_Z'].append(actions.train(sess, opt_s2Y_Z, loss_s2Y_Z, inputs, generate_toys, n_samples, n_adv, None))
    losses['s2Y_Y'].append(actions.train(sess, opt_s2Y_Y, loss_s2Y_Y, inputs, generate_toys, n_samples, n_adv, None))

    # train s2Y
    losses['s2Y'].append(actions.train(sess, opt_s2Y, loss_s2Y, inputs, generate_toys, n_samples, n_clf, None))

    # monitor D loss, true mutual information
    l_D, s2Y_val = sess.run([loss_D, s2Y], feed_dict={x_in:X, y_in:Y})
    losses['D'].append(l_D)
    MIs['Y'].append(mutual_info_regression(s2Y_val[:N, :], Y[:N])[0])
    MIs['Z'].append(mutual_info_regression(s2Y_val[:N, :], Z[:N])[0])

# Three stacked panels sharing the x axis, one point per adversarial cycle:
# the s2Y loss on top, then the Z and Y mutual-information curves.
cycles = range(n_adv_cycles)
fig, ax = plt.subplots(3, figsize=(7,7), sharex=True)
ax_loss, ax_z, ax_y = ax

ax_loss.set_title('Losses for s2Y')
ax_loss.plot(cycles, losses['s2Y'], c='k', label='Loss_s2Y')

# Dotted line: negated MINE loss; solid line: mutual_info_regression estimate.
ax_z.plot(cycles, -np.array(losses['s2Y_Z']), 'r:', label='MINE(s2Y, Z)')
ax_z.plot(cycles, MIs['Z'], c='r', label='MI(s2Y, Z)')

ax_y.plot(cycles, -np.array(losses['s2Y_Y']), 'b:', label='MINE(s2Y, Y)')
ax_y.plot(cycles, MIs['Y'], c='b', label='MI(s2Y, Y)')
ax_y.set_xlabel('Adversarial cycles')

for panel in (ax_loss, ax_z, ax_y):
    panel.legend(loc='best')
plt.show()

# plot the performance of the stage 2 model (s2Y)
# Evaluate s2Y on each test set in turn (all Z, Z=1, Z=0, Z=-1) and pass the
# raw output through utils.sigmoid.
pred, pred1, pred0, pred_1 = (
    utils.sigmoid(sess.run(s2Y, feed_dict={x_in: features, y_in: labels}))
    for features, labels in ((X, Y), (X1, Y1), (X0, Y0), (X_1, Y_1))
)
test_data['preds'] = pred, pred1, pred0, pred_1
plotting.plot_classifier_performance(test_data, 'media/plots/clf_s2Y')