In [2]:
import os, sys

# Move from the notebook's folder to the directory two levels up (presumably the repository
# root, so that the project-local imports below such as `datasets` and `tasks` resolve).
# The starting directory is remembered in NOTEBOOK_DIR, so re-running this cell is
# idempotent instead of climbing two more levels every time.
NOTEBOOK_DIR = globals().get('NOTEBOOK_DIR', os.getcwd())
os.chdir(os.path.abspath(os.path.join(NOTEBOOK_DIR, '..', '..')))

# Import

In [28]:
import os, sys, inspect, time

import numpy as np
import torch 
import matplotlib.pyplot as plt
import pandas as pd
from datasets import USCensus, UCICrime
from tasks.fairness.experiment import train, test
from tasks.fairness.model import FairNet

import utils_os, utils_data

%load_ext autoreload
%autoreload 2

# Compute device. The recorded runs used the second GPU ('cuda:1'); fall back to the first
# GPU or the CPU so the notebook still runs on machines with fewer devices.
if torch.cuda.device_count() > 1:
    device = 'cuda:1'
elif torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

# Seed numpy and torch so runs are repeatable.
# NOTE(review): assumes FairNet/train draw randomness from these global RNGs — TODO confirm.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Data preparation

In [29]:
# Load census-tract data: features X, target Y and sensitive attribute A (here: 'Gender').
# NOTE(review): 'Proverty' looks like a misspelling of 'Poverty', but it must match whatever
# key USCensus.read_tract expects — verify against datasets.USCensus before changing it.
X, Y, A = USCensus.read_tract(label='Proverty', sensitive_attribute='Gender')
X = torch.Tensor(X).to(device)                     # input data
Y = torch.Tensor(Y).to(device)                     # target to predict
A = torch.Tensor(A).to(device)                     # sensitive attribute

# The sensitive attribute is also stored as the last column of X; check every row, not only row 0.
assert torch.equal(X[:, -1], A[:, -1]), 'last column of X must equal the sensitive attribute'
# NOTE(review): only row 0 is compared here; presumably a guard that the attribute is not also
# duplicated in column -2 — confirm the intent.
assert X[0, -2] != A[0, -1]

dim = 20                                           # base width unit for the network layers below
n, d = X.size()                                    # n rows, d input features
n_y, K = Y.size()                                  # K target columns
assert n_y == n, 'X and Y must have the same number of rows'

# Train/test split: the first N rows are used for training, the next N_TEST rows for testing.
# NOTE(review): the rows are not shuffled; if the tract file is ordered (e.g. by state), train
# and test come from different regions — consider a seeded permutation before slicing.
N = 20000                                          # number of training rows
N_TEST = 2500                                      # number of test rows
assert X.size(0) >= N + N_TEST, f'need at least {N + N_TEST} rows, got {X.size(0)}'
x, y, a = X[:N], Y[:N], A[:N]
x_test, y_test, a_test = X[N:N + N_TEST], Y[N:N + N_TEST], A[N:N + N_TEST]

# Standardise BOTH splits with the training mean/std, so test features are on the scale the
# model was trained on and no test-set statistics leak into preprocessing. Training columns
# with zero std (constant columns) are left unscaled instead of producing NaNs.
x_mean = x.mean(dim=0, keepdim=True)
x_std = x.std(dim=0, keepdim=True)
x_std = torch.where(x_std > 0, x_std, torch.ones_like(x_std))
x = (x - x_mean) / x_std
x_test = (x_test - x_mean) / x_std
print('train', x.size(), y.size())
print('test', x_test.size(), y_test.size())

DATASET = 'USCensus'
dim_z = 4*dim                                      # width of the learnt representation (dim_learnt)

train torch.Size([20000, 32]) torch.Size([20000, 1])
test torch.Size([2500, 32]) torch.Size([2500, 1])


# Experiments

### N/A

Plain model that has no constraint on I(Z; T). 

In [88]:
# Unconstrained baseline: no estimator of I(Z; T) is used.
estimator = 'NONE'

class Hyperparams(utils_os.ConfigDict):
    """Settings for the unconstrained baseline run (estimator 'NONE')."""
    def __init__(self):
        self.estimator = estimator                 # the global above, read at instantiation time
        self.beta = 10                             # NOTE(review): penalty weight; presumably has no effect for 'NONE' — confirm
        self.lr = 5e-4
        self.bs = 500
        self.dim_learnt = dim_z
        self.dim_sensitive = 1
hyperparams = Hyperparams()

# The last architecture entry is dim_z (not a repeated 4*dim), so the network's output width
# stays tied to hyperparams.dim_learnt if either is changed.
net = FairNet(architecture=[d, 10*dim, 10*dim, dim_z], dim_y=K, hyperparams=hyperparams)

train(hyperparams, estimator, x, y, net, DATASET)
test(hyperparams, estimator, x_test, y_test, net, DATASET)

t= 0 loss= -9.520732879638672 loss val= -5.054341793060303 adv_loss= tensor([0.], device='cuda:1') time= 0.0008563995361328125
t= 50 loss= -9.793397903442383 loss val= -9.495684623718262 adv_loss= tensor([0.], device='cuda:1') time= 0.0008587837219238281
t= 100 loss= -9.866706848144531 loss val= -9.397472381591797 adv_loss= tensor([0.], device='cuda:1') time= 0.0008535385131835938
t= 150 loss= -9.891358375549316 loss val= -9.36882495880127 adv_loss= tensor([0.], device='cuda:1') time= 0.0008504390716552734
t= 200 loss= -9.893041610717773 loss val= -9.363776206970215 adv_loss= tensor([0.], device='cuda:1') time= 0.0008552074432373047
t= 250 loss= -9.887506484985352 loss val= -9.363200187683105 adv_loss= tensor([0.], device='cuda:1') time= 0.0008544921875
t= 300 loss= -9.869596481323242 loss val= -9.3630952835083 adv_loss= tensor([0.], device='cuda:1') time= 0.0008604526519775391
t= 350 loss= -9.868520736694336 loss val= -9.363089561462402 adv_loss= tensor([0.], device='cuda:1') time= 0.

### Slice

The proposed slice-based method, which estimates a sliced version of I(Z; T).

In [84]:
# Proposed slice-based estimator of I(Z; T).
estimator = 'SLICE'

class Hyperparams(utils_os.ConfigDict):
    """Settings for the slice-based run (estimator 'SLICE')."""
    def __init__(self):
        self.estimator = estimator                 # the global above, read at instantiation time
        self.beta = 0.35                           # NOTE(review): presumably the weight of the I(Z; T) penalty — confirm
        self.lr = 5e-4
        self.bs = 500
        self.n_slice = 100                         # number of slices used by the SLICE estimator
        self.dim_learnt = dim_z
        self.dim_sensitive = 1
hyperparams = Hyperparams()

# The last architecture entry is dim_z (not a repeated 4*dim), so the network's output width
# stays tied to hyperparams.dim_learnt if either is changed.
net = FairNet(architecture=[d, 10*dim, 10*dim, dim_z], dim_y=K, hyperparams=hyperparams)

train(hyperparams, estimator, x, y, net, DATASET)
test(hyperparams, estimator, x_test, y_test, net, DATASET)

t= 0 loss= -0.32833436131477356 loss val= 0.8138784170150757 adv_loss= (0.9428901076316833, 0.9410136938095093) time= 0.0967864990234375
t= 50 loss= -0.3027324378490448 loss val= -0.2206546664237976 adv_loss= (0.17648135125637054, 0.0934160128235817) time= 0.07204580307006836
t= 100 loss= -0.30741024017333984 loss val= -0.2498663067817688 adv_loss= (0.17200900614261627, 0.07138828188180923) time= 0.07196831703186035
t= 150 loss= -0.26396113634109497 loss val= -0.2812604606151581 adv_loss= (0.16744105517864227, 0.04428320378065109) time= 0.07210969924926758
t= 200 loss= -0.23471036553382874 loss val= -0.27446249127388 adv_loss= (0.16574694216251373, 0.047728877514600754) time= 0.07109284400939941
t= 250 loss= -0.19777096807956696 loss val= -0.22445540130138397 adv_loss= (0.16239887475967407, 0.10240037739276886) time= 0.07114219665527344
t= 300 loss= -0.2553684711456299 loss val= -0.3185093104839325 adv_loss= (0.14165353775024414, 0.006020902190357447) time= 0.07230401039123535
t= 350 l

### Adversarial training

Baseline methods that use a neural network (the 'adversary') to estimate I(Z; T) or a proxy for it.

If the adversary is not trained sufficiently (controlled by `inner_epochs`), the learned Z will not be as fair.

In [92]:
estimator = 'RENYI'                            # neural Renyi correlation

class Hyperparams(utils_os.ConfigDict):
    """Settings for the adversarial run with the neural Renyi-correlation estimator."""
    def __init__(self):
        self.estimator = estimator             # the global above, read at instantiation time
        self.early_stop = False                # <-- don't use early stopping as the adversary is not mature at early stage
        self.beta = 0.15                       # NOTE(review): presumably the weight of the adversarial penalty — confirm
        self.lr = 5e-4
        self.bs = 500
        self.inner_epochs = 4                  # <-- adversary training per step; controls how well it is trained and the run time
        self.dim_learnt = dim_z
        self.dim_sensitive = 1
hyperparams = Hyperparams()

# The last architecture entry is dim_z (not a repeated 4*dim), so the network's output width
# stays tied to hyperparams.dim_learnt if either is changed.
net = FairNet(architecture=[d, 10*dim, 10*dim, dim_z], dim_y=K, hyperparams=hyperparams)

train(hyperparams, estimator, x, y, net, DATASET)
test(hyperparams, estimator, x_test, y_test, net, DATASET)

t= 0 loss= -0.0964183360338211 loss val= 0.8825427293777466 adv_loss= -0.9419568181037903 time= 0.14068889617919922
t= 50 loss= -0.07684874534606934 loss val= -0.1205780953168869 adv_loss= -0.04392237216234207 time= 0.18235993385314941
t= 100 loss= -0.13804253935813904 loss val= -0.11862627416849136 adv_loss= -0.05400653928518295 time= 0.17664313316345215
t= 150 loss= -0.06514017283916473 loss val= -0.13622711598873138 adv_loss= -0.05854476988315582 time= 0.188917875289917
t= 200 loss= 0.14819596707820892 loss val= 0.07584838569164276 adv_loss= -0.20300178229808807 time= 0.15832042694091797
t= 250 loss= 0.31101691722869873 loss val= 0.09185196459293365 adv_loss= -0.21724757552146912 time= 0.15630102157592773
t= 300 loss= 0.3326883912086487 loss val= 0.10909204185009003 adv_loss= -0.23380279541015625 time= 0.17449617385864258
t= 350 loss= 0.32776734232902527 loss val= 0.11558926105499268 adv_loss= -0.24150541424751282 time= 0.18619823455810547
t= 400 loss= 0.3161201477050781 loss val= 0

In [91]:
estimator = 'TC'                                # neural total correlation

class Hyperparams(utils_os.ConfigDict):
    """Settings for the adversarial run with the neural total-correlation estimator."""
    def __init__(self):
        self.estimator = estimator              # the global above, read at instantiation time
        self.early_stop = False                 # <-- don't use early stopping as the adversary is not mature at early stage
        self.beta = 0.05                        # NOTE(review): presumably the weight of the adversarial penalty — confirm
        self.lr = 5e-4
        self.bs = 500
        self.inner_epochs = 10                  # <-- adversary training per step; controls how well the TC adversary is trained and the run time
        self.dim_learnt = dim_z
        self.dim_sensitive = 1
hyperparams = Hyperparams()

# The last architecture entry is dim_z (not a repeated 4*dim), so the network's output width
# stays tied to hyperparams.dim_learnt if either is changed.
net = FairNet(architecture=[d, 10*dim, 10*dim, dim_z], dim_y=K, hyperparams=hyperparams)

train(hyperparams, estimator, x, y, net, DATASET)
test(hyperparams, estimator, x_test, y_test, net, DATASET)

t= 0 loss= -21.613868713378906 loss val= -0.42932307720184326 adv_loss= 0.4191444516181946 time= 0.6922318935394287
t= 50 loss= -0.7624861001968384 loss val= -0.7328280806541443 adv_loss= 0.6875702142715454 time= 0.689429759979248
t= 100 loss= -0.7343903183937073 loss val= -0.7385129332542419 adv_loss= 0.6911075115203857 time= 0.7346620559692383
t= 150 loss= -0.7233610153198242 loss val= -0.7395967245101929 adv_loss= 0.6910511255264282 time= 0.672919750213623
t= 200 loss= -0.7070338129997253 loss val= -0.7488800883293152 adv_loss= 0.6995567083358765 time= 0.7108669281005859
t= 250 loss= -0.6999318599700928 loss val= -0.7569356560707092 adv_loss= 0.706995964050293 time= 0.7031376361846924
t= 300 loss= -0.6850447058677673 loss val= -0.7667465209960938 adv_loss= 0.7123385071754456 time= 0.7192873954772949
t= 350 loss= -0.6666769981384277 loss val= -0.7731035947799683 adv_loss= 0.723172664642334 time= 0.7463462352752686
t= 400 loss= -0.6835117340087891 loss val= -0.7772664427757263 adv_los