In [1]:
import os, sys

# Move the kernel's working directory two levels up from where the notebook
# started (presumably the repository root, so that the project-local imports
# in the next cell resolve -- confirm against the repo layout).
# The starting directory is remembered on the first run so that re-running this
# cell always lands in the same place instead of climbing two more levels.
if '_NOTEBOOK_DIR' not in globals():
    _NOTEBOOK_DIR = os.getcwd()
os.chdir(os.path.join(_NOTEBOOK_DIR, '..', '..'))

# Import

In [5]:
import os, sys, inspect, time

import numpy as np
import torch 
import matplotlib.pyplot as plt
import pandas as pd
from datasets import USCensus, UCICrime
from tasks.fairness.experiment import train, test
from tasks.fairness.model import FairNet

import utils_os, utils_data

%load_ext autoreload
%autoreload 2

# Device on which all tensors in this notebook are placed.
# Keep using the second GPU when the machine has one, otherwise fall back to
# the first GPU, and finally to the CPU, so the notebook still runs elsewhere.
if torch.cuda.is_available():
    device = 'cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0'
else:
    device = 'cpu'

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Data preparation

In [3]:
# Load the census-tract data and move it to the selected device.
# NOTE(review): the label string 'Proverty' is kept verbatim -- confirm it is the
# spelling that USCensus.read_tract expects before "correcting" it.
X, Y, A = USCensus.read_tract(label='Proverty', sensitive_attribute='Gender')
X = torch.Tensor(X).to(device)                     # input data
Y = torch.Tensor(Y).to(device)                     # target to predict
A = torch.Tensor(A).to(device)                     # sensitive attribute
assert X[0, -1] == A[0, -1]                        # the last index of x is also the sensitive attr
assert X[0, -2] != A[0, -1]                        # ...and the second-to-last feature is not (first row only)
dim = 20
n, d = X.size()
n, K = Y.size()

N = 20000                                          # rows used for training
N_TEST = 2500                                      # rows held out for testing, taken right after the training rows
x, y, a = X[0:N,:], Y[0:N, :], A[0:N, :]
x_test, y_test, a_test = X[N:N+N_TEST,:], Y[N:N+N_TEST, :], A[N:N+N_TEST, :]

# Standardise both splits with the statistics of the TRAINING split only, so the
# test features go through exactly the same transform the model was trained on.
x_mean = x.mean(dim=0, keepdim=True)
x_std = x.std(dim=0, keepdim=True)
# A constant column has zero std; divide by 1 there so it becomes 0 instead of NaN.
x_std = torch.where(x_std > 0, x_std, torch.ones_like(x_std))
x = (x - x_mean)/x_std
x_test = (x_test - x_mean)/x_std
print('train', x.size(), y.size())
print('test', x_test.size(), y_test.size())

DATASET = 'USCensus'
dim_z = 4*dim

train torch.Size([20000, 32]) torch.Size([20000, 1])
test torch.Size([2500, 32]) torch.Size([2500, 1])


# Experiments

### N/A

Plain model that has no constraint on I(Z; T). 

In [88]:
estimator = 'NONE'                              # plain model: no constraint on I(Z; T)


class Hyperparams(utils_os.ConfigDict):
    """Settings for the unconstrained run."""

    def __init__(self):
        settings = {
            'estimator': estimator,
            'beta': 10,
            'lr': 5e-4,
            'bs': 500,
            'dim_learnt': dim_z,
            'dim_sensitive': 1,
        }
        # One attribute per entry, assigned in the order listed above.
        for key, value in settings.items():
            setattr(self, key, value)


hyperparams = Hyperparams()

net = FairNet(
    architecture=[d, 10*dim, 10*dim, 4*dim],
    dim_y=K,
    hyperparams=hyperparams,
)

# Fit on the training split, then evaluate on the held-out split.
train(hyperparams, estimator, x, y, net, DATASET)
test(hyperparams, estimator, x_test, y_test, net, DATASET)

t= 0 loss= -9.520732879638672 loss val= -5.054341793060303 adv_loss= tensor([0.], device='cuda:1') time= 0.0008563995361328125
t= 50 loss= -9.793397903442383 loss val= -9.495684623718262 adv_loss= tensor([0.], device='cuda:1') time= 0.0008587837219238281
t= 100 loss= -9.866706848144531 loss val= -9.397472381591797 adv_loss= tensor([0.], device='cuda:1') time= 0.0008535385131835938
t= 150 loss= -9.891358375549316 loss val= -9.36882495880127 adv_loss= tensor([0.], device='cuda:1') time= 0.0008504390716552734
t= 200 loss= -9.893041610717773 loss val= -9.363776206970215 adv_loss= tensor([0.], device='cuda:1') time= 0.0008552074432373047
t= 250 loss= -9.887506484985352 loss val= -9.363200187683105 adv_loss= tensor([0.], device='cuda:1') time= 0.0008544921875
t= 300 loss= -9.869596481323242 loss val= -9.3630952835083 adv_loss= tensor([0.], device='cuda:1') time= 0.0008604526519775391
t= 350 loss= -9.868520736694336 loss val= -9.363089561462402 adv_loss= tensor([0.], device='cuda:1') time= 0.

### Slice

The proposed slice-based method, which estimates a slice of I(Z; T).

In [27]:
estimator = 'SLICE'                             # slice-based estimate of I(Z; T)


class Hyperparams(utils_os.ConfigDict):
    """Settings for the slice-based run."""

    def __init__(self):
        settings = {
            'estimator': estimator,
            'beta': 0.75,
            'lr': 5e-4,
            'bs': 500,
            'n_slice': 100,
            'inner_epochs': 0,                  # 0: the slices are not optimised (optimising them is also possible)
            'dim_learnt': dim_z,
            'dim_sensitive': 1,
        }
        # One attribute per entry, assigned in the order listed above.
        for key, value in settings.items():
            setattr(self, key, value)


hyperparams = Hyperparams()

net = FairNet(
    architecture=[d, 10*dim, 10*dim, 4*dim],
    dim_y=K,
    hyperparams=hyperparams,
)

# Fit on the training split, then evaluate on the held-out split.
train(hyperparams, estimator, x, y, net, DATASET)
test(hyperparams, estimator, x_test, y_test, net, DATASET)

t= 0 loss= -0.6070358753204346 loss val= 0.8509424328804016 adv_loss= (0.9503340125083923, 0.9500017166137695) time per iter= 0.2262434959411621
t= 50 loss= -0.7124919891357422 loss val= -0.547791600227356 adv_loss= (0.2390560507774353, 0.14599670469760895) time per iter= 0.248460054397583
t= 100 loss= -0.645367443561554 loss val= -0.6072866320610046 adv_loss= (0.18458671867847443, 0.09757854044437408) time per iter= 0.2509024143218994
t= 150 loss= -0.670754611492157 loss val= -0.6361450552940369 adv_loss= (0.1766381859779358, 0.06653881818056107) time per iter= 0.27583789825439453
t= 200 loss= -0.6228980422019958 loss val= -0.6322997212409973 adv_loss= (0.15864570438861847, 0.06550151854753494) time per iter= 0.2328033447265625
t= 250 loss= -0.6714261770248413 loss val= -0.6992306709289551 adv_loss= (0.146018385887146, 0.004361852537840605) time per iter= 0.20165681838989258
t= 300 loss= -0.6040209531784058 loss val= -0.613516628742218 adv_loss= (0.1632571518421173, 0.0954404398798942

### Adversarial training

Baseline methods that use a neural network (i.e. the 'adversary') to estimate I(Z; T) or its proxies.

If the adversary is not trained sufficiently (controlled by `inner_epochs`), the learned Z will not be as fair.

In [12]:
estimator = 'RENYI'                             # neural Renyi correlation


class Hyperparams(utils_os.ConfigDict):
    """Settings for the neural-Renyi adversarial baseline."""

    def __init__(self):
        settings = {
            'estimator': estimator,
            'early_stop': False,                # no early stopping: the adversary is not mature at the early stage
            'beta': 0.15,
            'lr': 5e-4,
            'bs': 500,
            'inner_epochs': 2,                  # this adjusts the execution time
            'dim_learnt': dim_z,
            'dim_sensitive': 1,
        }
        # One attribute per entry, assigned in the order listed above.
        for key, value in settings.items():
            setattr(self, key, value)


hyperparams = Hyperparams()

net = FairNet(
    architecture=[d, 10*dim, 10*dim, 4*dim],
    dim_y=K,
    hyperparams=hyperparams,
)

# Fit on the training split, then evaluate on the held-out split.
train(hyperparams, estimator, x, y, net, DATASET)
test(hyperparams, estimator, x_test, y_test, net, DATASET)

t= 0 loss= -0.0809314101934433 loss val= 0.7632180452346802 adv_loss= -0.8326592445373535 time per iter= 0.694091796875
t= 50 loss= -0.14120878279209137 loss val= -0.09934559464454651 adv_loss= -0.053575098514556885 time per iter= 0.641768217086792
t= 100 loss= -0.06660383939743042 loss val= -0.03902415931224823 adv_loss= -0.10181170701980591 time per iter= 0.28841352462768555
t= 150 loss= 0.05179755389690399 loss val= 0.05810964107513428 adv_loss= -0.17670926451683044 time per iter= 0.2869553565979004
t= 200 loss= 0.15463517606258392 loss val= 0.1431826651096344 adv_loss= -0.28230786323547363 time per iter= 0.278562068939209
t= 250 loss= 0.36782699823379517 loss val= 0.1845712810754776 adv_loss= -0.31612634658813477 time per iter= 0.2640187740325928
t= 300 loss= 0.308788001537323 loss val= 0.22461208701133728 adv_loss= -0.35160624980926514 time per iter= 0.28175950050354004
t= 350 loss= 0.37271440029144287 loss val= 0.24415381252765656 adv_loss= -0.36288535594940186 time per iter= 0.2

In [32]:
estimator = 'TC'                                # neural total correlation


class Hyperparams(utils_os.ConfigDict):
    """Settings for the neural total-correlation adversarial baseline."""

    def __init__(self):
        settings = {
            'estimator': estimator,
            'early_stop': False,                # no early stopping: the adversary is not mature at the early stage
            'beta': 0.10,
            'lr': 5e-4,
            'bs': 500,
            'inner_epochs': 3,                  # this adjusts the execution time
            'dim_learnt': dim_z,
            'dim_sensitive': 1,
        }
        # One attribute per entry, assigned in the order listed above.
        for key, value in settings.items():
            setattr(self, key, value)


hyperparams = Hyperparams()

net = FairNet(
    architecture=[d, 10*dim, 10*dim, 4*dim],
    dim_y=K,
    hyperparams=hyperparams,
)

# Fit on the training split, then evaluate on the held-out split.
train(hyperparams, estimator, x, y, net, DATASET)
test(hyperparams, estimator, x_test, y_test, net, DATASET)

t= 0 loss= -0.7807778716087341 loss val= -0.7066316604614258 adv_loss= 0.6822531223297119 time per iter= 1.1141557693481445
t= 50 loss= -0.7901327013969421 loss val= -0.7851606607437134 adv_loss= 0.6905874609947205 time per iter= 1.1358845233917236
t= 100 loss= -0.7918014526367188 loss val= -0.7872661352157593 adv_loss= 0.6914178133010864 time per iter= 1.1185259819030762
t= 150 loss= -0.7863660454750061 loss val= -0.7877463102340698 adv_loss= 0.6915872097015381 time per iter= 1.1400105953216553
t= 200 loss= -0.7805623412132263 loss val= -0.7862117886543274 adv_loss= 0.6908053159713745 time per iter= 1.1238410472869873
t= 250 loss= -0.7739601731300354 loss val= -0.7857621908187866 adv_loss= 0.6910184621810913 time per iter= 0.8811068534851074
t= 300 loss= -0.7737019658088684 loss val= -0.791662335395813 adv_loss= 0.6955823302268982 time per iter= 1.1384270191192627
t= 350 loss= -0.7518147826194763 loss val= -0.7927985191345215 adv_loss= 0.6972127556800842 time per iter= 0.9584867954254