In [1]:
import os, sys

# Move the working directory two levels up so that the project-local imports in the
# next cell (datasets, tasks, utils_os) can resolve — presumably this lands in the
# repository root; TODO confirm.
# NOTE: the path is relative, so re-running this cell climbs another two levels.
# Restart the kernel before executing it again.
os.chdir('../..')

# Import

In [33]:
import os, sys, inspect, time

import numpy as np
import torch 
import matplotlib.pyplot as plt
import pandas as pd
from datasets import USCensus, UCICrime
from tasks.fairness.experiment import train, test
from tasks.fairness.model import FairNet
import utils_os

%load_ext autoreload
%autoreload 2

# Device on which all tensors of this notebook are placed (see the `.to(device)` calls
# in the data-preparation cell). Use the first GPU when CUDA is available and fall back
# to the CPU otherwise, so the notebook still runs on machines without a GPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Data preparation

In [34]:
# Load the full dataset as three arrays: inputs X, targets Y, sensitive attribute A.
# NOTE(review): the 'adult' data is read through the UCICrime loader while USCensus is
# also imported — confirm this is the intended loader.
X, Y, A = UCICrime.read_dataset(name='adult')
X = torch.Tensor(X).to(device)                     # input data
Y = torch.Tensor(Y).to(device)                     # target to predict
A = torch.Tensor(A).to(device)                     # sensitive attribute
# Sanity checks on the first row only: the last column of X must equal the sensitive
# attribute, and the second-to-last column must differ from it.
assert X[0, -1] == A[0, -1]                        # the last column of X is also the sensitive attribute
assert X[0, -2] != A[0, -1]
dim = 20                                           # base width; layer widths below are multiples of it
n, d = X.size()                                    # n rows, d input features
n, K = Y.size()                                    # K target columns

N = 20000                                          # number of training rows (taken from the top of the data)
N_TEST = 2500                                      # number of test rows (the rows directly after the training rows)
x, y, a = X[0:N, :], Y[0:N, :], A[0:N, :]
x_test, y_test, a_test = X[N:N+N_TEST, :], Y[N:N+N_TEST, :], A[N:N+N_TEST, :]

# Standardise every feature with the TRAINING mean/std only and apply the same
# transform to the test split. Using separate test-set statistics would put the two
# splits on different scales and leak evaluation-data statistics into preprocessing.
x_mean = x.mean(dim=0, keepdim=True)
x_std = x.std(dim=0, keepdim=True).clamp(min=1e-8)  # guard against constant columns (std == 0)
x = (x - x_mean)/x_std
x_test = (x_test - x_mean)/x_std
print('train', x.size(), y.size())
print('test', x_test.size(), y_test.size())

DATASET = 'UCIAdult'                               # dataset tag passed to train() / test()
dim_z = 4*dim                                      # used as hyperparams.dim_learnt in the experiment cells

train torch.Size([20000, 16]) torch.Size([20000, 1])
test torch.Size([2500, 16]) torch.Size([2500, 1])


# Experiments

### N/A

Plain model that has no constraint on I(Z; T).

In [132]:
# Estimator tag for this run; the section text describes it as the plain model with no
# constraint on I(Z; T). It is read by Hyperparams below and passed to train()/test().
estimator = 'NONE'

class Hyperparams(utils_os.ConfigDict):
    """Settings for the 'NONE' run; the instance is passed to FairNet, train and test.

    NOTE: every experiment cell in this notebook redefines a class with this same name,
    so the most recently executed definition is the one in effect.
    """
    def __init__(self): 
        # Captures the notebook-level `estimator` at instantiation time.
        self.estimator = estimator
        # NOTE(review): beta looks like a loss weighting — confirm in tasks.fairness.experiment.
        self.beta = 10
        # lr / bs: presumably learning rate and batch size — confirm against train().
        self.lr = 1e-3
        self.bs = 500
        # Set to dim_z (= 4*dim, defined in the data-preparation cell).
        self.dim_learnt = dim_z
        # Presumably the width of the sensitive attribute — TODO confirm.
        self.dim_sensitive = 1
hyperparams = Hyperparams()

# Widths passed as `architecture`: the input width d, 10*dim twice, and dim_z last.
# dim_z (== 4*dim, set in the data-preparation cell) replaces the literal 4*dim so the
# last width cannot drift from hyperparams.dim_learnt, which is also set to dim_z.
net = FairNet(architecture=[d, 10*dim, 10*dim, dim_z], dim_y=K, hyperparams=hyperparams)

# Run training on the training split, then evaluation on the held-out split.
train(hyperparams, estimator, x, y, net, DATASET)
test(hyperparams, estimator, x_test, y_test, net, DATASET)

t= 0 loss= -9.857749938964844 loss val= -1.1813234090805054 adv_loss= tensor([0.], device='cuda:0') time= 0.0008447170257568359
t= 50 loss= -9.96342658996582 loss val= -9.973118782043457 adv_loss= tensor([0.], device='cuda:0') time= 0.0008435249328613281
t= 100 loss= -9.957209587097168 loss val= -9.987390518188477 adv_loss= tensor([0.], device='cuda:0') time= 0.0008449554443359375
t= 150 loss= -9.970044136047363 loss val= -9.990811347961426 adv_loss= tensor([0.], device='cuda:0') time= 0.0008385181427001953
t= 200 loss= -9.93735408782959 loss val= -9.993398666381836 adv_loss= tensor([0.], device='cuda:0') time= 0.0008389949798583984
t= 250 loss= -9.951817512512207 loss val= -9.972981452941895 adv_loss= tensor([0.], device='cuda:0') time= 0.0008325576782226562
t= 300 loss= -9.93704605102539 loss val= -9.986382484436035 adv_loss= tensor([0.], device='cuda:0') time= 0.0008449554443359375
t= 350 loss= -9.964373588562012 loss val= -9.990212440490723 adv_loss= tensor([0.], device='cuda:0') t

### Slice

The proposed slice-based method, which estimates a slice of I(Z; T).

In [128]:
# Estimator tag for this run: the slice-based method described in the section text.
# It is read by Hyperparams below and passed to train()/test().
estimator = 'SLICE'

class Hyperparams(utils_os.ConfigDict):
    """Settings for the 'SLICE' run; the instance is passed to FairNet, train and test.

    NOTE: every experiment cell in this notebook redefines a class with this same name,
    so the most recently executed definition is the one in effect.
    """
    def __init__(self): 
        # Captures the notebook-level `estimator` at instantiation time.
        self.estimator = estimator
        # NOTE(review): beta looks like a loss weighting — confirm in tasks.fairness.experiment.
        self.beta = 0.15
        # lr / bs: presumably learning rate and batch size — confirm against train().
        self.lr = 5e-4
        self.bs = 500
        # Presumably the number of slices used by the slice estimator — TODO confirm.
        self.n_slice = 100
        # Set to dim_z (= 4*dim, defined in the data-preparation cell).
        self.dim_learnt = dim_z
        # Presumably the width of the sensitive attribute — TODO confirm.
        self.dim_sensitive = 1
hyperparams = Hyperparams()

# Widths passed as `architecture`: the input width d, 10*dim twice, and dim_z last.
# dim_z (== 4*dim, set in the data-preparation cell) replaces the literal 4*dim so the
# last width cannot drift from hyperparams.dim_learnt, which is also set to dim_z.
net = FairNet(architecture=[d, 10*dim, 10*dim, dim_z], dim_y=K, hyperparams=hyperparams)

# Run training on the training split, then evaluation on the held-out split.
train(hyperparams, estimator, x, y, net, DATASET)
test(hyperparams, estimator, x_test, y_test, net, DATASET)

t= 0 loss= -0.08379819989204407 loss val= 0.9518726468086243 adv_loss= (0.9930687546730042, 0.9925546050071716) time= 0.08807969093322754
t= 50 loss= -0.1278180032968521 loss val= 0.025787904858589172 adv_loss= (0.23736535012722015, 0.158733069896698) time= 0.06826949119567871
t= 100 loss= -0.13896586000919342 loss val= -0.09850331395864487 adv_loss= (0.17163525521755219, 0.04567597061395645) time= 0.06440305709838867
t= 150 loss= -0.12926752865314484 loss val= -0.08577495813369751 adv_loss= (0.18564406037330627, 0.05951729416847229) time= 0.06843852996826172
t= 200 loss= -0.08085440844297409 loss val= -0.06429358571767807 adv_loss= (0.17469634115695953, 0.08533938974142075) time= 0.06384634971618652
t= 250 loss= -0.09801480919122696 loss val= -0.12339253723621368 adv_loss= (0.15174585580825806, 0.02621607296168804) time= 0.0688786506652832
t= 300 loss= -0.03187737613916397 loss val= -0.09900850057601929 adv_loss= (0.1579950600862503, 0.04757241532206535) time= 0.0694875717163086
t= 35

### Adversarial training

Baseline methods that use a neural network (i.e., the *adversary*) to estimate I(Z; T) or its proxies.

If the adversary is not trained sufficiently (the amount of training is controlled by *hyperparams.inner_epochs*), the learned representation Z will be less fair.

In [143]:
# Estimator tag for this run; read by Hyperparams below and passed to train()/test().
estimator = 'RENYI'                             # neural Renyi correlation

class Hyperparams(utils_os.ConfigDict):
    """Settings for the 'RENYI' adversarial run; the instance is passed to FairNet, train and test.

    NOTE: every experiment cell in this notebook redefines a class with this same name,
    so the most recently executed definition is the one in effect.
    """
    def __init__(self): 
        # Captures the notebook-level `estimator` at instantiation time.
        self.estimator = estimator
        self.early_stop = False                 # <-- don't use early stopping, as the adversary is not mature at an early stage
        # NOTE(review): beta looks like a loss weighting — confirm in tasks.fairness.experiment.
        self.beta = 0.85
        # lr / bs: presumably learning rate and batch size — confirm against train().
        self.lr = 5e-4
        self.bs = 500
        self.inner_epochs = 6                   # <-- this adjusts the execution time of the neural Renyi model
        # Set to dim_z (= 4*dim, defined in the data-preparation cell).
        self.dim_learnt = dim_z
        # Presumably the width of the sensitive attribute — TODO confirm.
        self.dim_sensitive = 1
hyperparams = Hyperparams()

# Widths passed as `architecture`: the input width d, 10*dim twice, and dim_z last.
# dim_z (== 4*dim, set in the data-preparation cell) replaces the literal 4*dim so the
# last width cannot drift from hyperparams.dim_learnt, which is also set to dim_z.
net = FairNet(architecture=[d, 10*dim, 10*dim, dim_z], dim_y=K, hyperparams=hyperparams)

# Run training on the training split, then evaluation on the held-out split.
train(hyperparams, estimator, x, y, net, DATASET)
test(hyperparams, estimator, x_test, y_test, net, DATASET)

t= 0 loss= -0.6693118810653687 loss val= 0.8527640700340271 adv_loss= -0.9893918037414551 time= 0.28438353538513184
t= 50 loss= -0.8343992829322815 loss val= -0.8046528100967407 adv_loss= -0.04346594214439392 time= 0.2792022228240967
t= 100 loss= -0.8413237929344177 loss val= -0.8395995497703552 adv_loss= -0.018952084705233574 time= 0.2871394157409668
t= 150 loss= -0.6148027777671814 loss val= -0.6775297522544861 adv_loss= -0.16563567519187927 time= 0.27283668518066406
t= 200 loss= -0.3584189713001251 loss val= -0.6378015279769897 adv_loss= -0.20467928051948547 time= 0.28690195083618164
t= 250 loss= -0.2297525405883789 loss val= -0.6344290971755981 adv_loss= -0.20855972170829773 time= 0.27259278297424316
t= 300 loss= -0.17110639810562134 loss val= -0.6288067102432251 adv_loss= -0.21423959732055664 time= 0.29013872146606445
t= 350 loss= -0.21498364210128784 loss val= -0.6272209882736206 adv_loss= -0.2162967026233673 time= 0.2738804817199707
t= 400 loss= -0.1487886905670166 loss val= -0.

In [147]:
# Estimator tag for this run; read by Hyperparams below and passed to train()/test().
estimator = 'TC'                                # neural total correlation

class Hyperparams(utils_os.ConfigDict):
    """Settings for the 'TC' adversarial run; the instance is passed to FairNet, train and test.

    NOTE: every experiment cell in this notebook redefines a class with this same name,
    so the most recently executed definition is the one in effect.
    """
    def __init__(self): 
        # Captures the notebook-level `estimator` at instantiation time.
        self.estimator = estimator
        self.early_stop = False                 # <-- don't use early stopping, as the adversary is not mature at an early stage
        # NOTE(review): beta looks like a loss weighting — confirm in tasks.fairness.experiment.
        self.beta = 0.05
        # lr / bs: presumably learning rate and batch size — confirm against train().
        self.lr = 5e-4
        self.bs = 500
        self.inner_epochs = 20                  # <-- this adjusts the execution time of the neural TC adversary (this run uses 'TC', not Renyi)
        # Set to dim_z (= 4*dim, defined in the data-preparation cell).
        self.dim_learnt = dim_z
        # Presumably the width of the sensitive attribute — TODO confirm.
        self.dim_sensitive = 1
hyperparams = Hyperparams()

# Widths passed as `architecture`: the input width d, 10*dim twice, and dim_z last.
# dim_z (== 4*dim, set in the data-preparation cell) replaces the literal 4*dim so the
# last width cannot drift from hyperparams.dim_learnt, which is also set to dim_z.
net = FairNet(architecture=[d, 10*dim, 10*dim, dim_z], dim_y=K, hyperparams=hyperparams)

# Run training on the training split, then evaluation on the held-out split.
train(hyperparams, estimator, x, y, net, DATASET)
test(hyperparams, estimator, x_test, y_test, net, DATASET)

t= 0 loss= -66.26215362548828 loss val= -0.23371586203575134 adv_loss= 0.21377435326576233 time= 1.561725378036499
t= 50 loss= -5.484129428863525 loss val= -0.6484896540641785 adv_loss= 0.6053942441940308 time= 1.5729913711547852
t= 100 loss= -5.089438438415527 loss val= -0.6892072558403015 adv_loss= 0.6407406330108643 time= 1.4096150398254395
t= 150 loss= -0.7078949809074402 loss val= -0.6804920434951782 adv_loss= 0.6308146715164185 time= 1.518554925918579
t= 200 loss= -0.6553541421890259 loss val= -0.6620994210243225 adv_loss= 0.6104956865310669 time= 1.5644960403442383
t= 250 loss= -0.6213163137435913 loss val= -0.6264081597328186 adv_loss= 0.5759978890419006 time= 1.5502920150756836
t= 300 loss= -0.6003984212875366 loss val= -0.608903169631958 adv_loss= 0.5632209777832031 time= 1.4453811645507812
t= 350 loss= -0.5980075001716614 loss val= -0.609339714050293 adv_loss= 0.5509132146835327 time= 1.2585937976837158
t= 400 loss= -0.5630080103874207 loss val= -0.5909461379051208 adv_loss=