In [1]:
import os, sys

# Move two levels up from the notebook's directory (presumably the project
# root, so that the project-level imports in the next cell resolve).
# The target is resolved to an absolute path only once per kernel session, so
# re-running this cell does not climb two more directories each time.
if '_PROJECT_ROOT' not in globals():
    _PROJECT_ROOT = os.path.abspath('../..')
os.chdir(_PROJECT_ROOT)

# Import

In [16]:
import os, sys, inspect, time

import numpy as np
import torch 
import matplotlib.pyplot as plt
import pandas as pd
from datasets import USCensus, UCICrime
from tasks.fairness.experiment import train, test
from tasks.fairness.model import FairNet
import utils_os

%load_ext autoreload
%autoreload 2

# Prefer the first CUDA GPU, but fall back to the CPU so the notebook still
# runs on machines without CUDA instead of failing at the first `.to(device)`.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Data preparation

In [17]:
# Load inputs X, targets Y and sensitive attribute A, then move them to `device`.
X, Y, A = UCICrime.read_dataset(name='adult')
X = torch.Tensor(X).to(device)                     # input data
Y = torch.Tensor(Y).to(device)                     # target to predict
A = torch.Tensor(A).to(device)                     # sensitive attribute
assert X[0, -1] == A[0, -1]                        # the last index of x is also the sensitive attr
assert X[0, -2] != A[0, -1]
dim = 20                                           # base width unit; dim_z below is 4*dim
n, d = X.size()
n, K = Y.size()

N = 20000                                          # number of training rows
N_TEST = 2500                                      # number of held-out rows taken right after the training rows
# Slicing past the end would silently shrink (or empty) the test split.
assert n >= N + N_TEST, f'dataset has {n} rows, need at least {N + N_TEST}'

x, y, a = X[0:N,:], Y[0:N, :], A[0:N, :]
x_test, y_test, a_test = X[N:N+N_TEST,:], Y[N:N+N_TEST, :], A[N:N+N_TEST, :]

# Standardise both splits with the TRAINING statistics only: using the test
# split's own mean/std would put train and test features on different scales
# and let test-set information leak into preprocessing.
x_mean = x.mean(dim=0, keepdim=True)
x_std = x.std(dim=0, keepdim=True)
x = (x - x_mean)/x_std
x_test = (x_test - x_mean)/x_std
print('train', x.size(), y.size())
print('test', x_test.size(), y_test.size())

DATASET = 'UCIAdult'
dim_z = 4*dim

train torch.Size([20000, 16]) torch.Size([20000, 1])
test torch.Size([2500, 16]) torch.Size([2500, 1])


# Experiments

### N/A

Plain model that has no constraint on I(Z; T).

In [132]:
estimator = 'NONE'                              # plain model: no constraint on I(Z; T)


class Hyperparams(utils_os.ConfigDict):
    """Hyperparameters for the 'NONE' estimator run."""

    def __init__(self):
        self.estimator = estimator
        self.beta = 10
        self.lr = 1e-3
        self.bs = 500
        self.dim_learnt = dim_z
        self.dim_sensitive = 1


hyperparams = Hyperparams()

# Layer widths handed to FairNet: input width d, two layers of 10*dim, then 4*dim.
layer_sizes = [d, 10*dim, 10*dim, 4*dim]
net = FairNet(architecture=layer_sizes, dim_y=K, hyperparams=hyperparams)

# Run the imported `train` routine on the training split, then `test` on the held-out split.
train(hyperparams, estimator, x, y, net, DATASET)
test(hyperparams, estimator, x_test, y_test, net, DATASET)

t= 0 loss= -9.857749938964844 loss val= -1.1813234090805054 adv_loss= tensor([0.], device='cuda:0') time= 0.0008447170257568359
t= 50 loss= -9.96342658996582 loss val= -9.973118782043457 adv_loss= tensor([0.], device='cuda:0') time= 0.0008435249328613281
t= 100 loss= -9.957209587097168 loss val= -9.987390518188477 adv_loss= tensor([0.], device='cuda:0') time= 0.0008449554443359375
t= 150 loss= -9.970044136047363 loss val= -9.990811347961426 adv_loss= tensor([0.], device='cuda:0') time= 0.0008385181427001953
t= 200 loss= -9.93735408782959 loss val= -9.993398666381836 adv_loss= tensor([0.], device='cuda:0') time= 0.0008389949798583984
t= 250 loss= -9.951817512512207 loss val= -9.972981452941895 adv_loss= tensor([0.], device='cuda:0') time= 0.0008325576782226562
t= 300 loss= -9.93704605102539 loss val= -9.986382484436035 adv_loss= tensor([0.], device='cuda:0') time= 0.0008449554443359375
t= 350 loss= -9.964373588562012 loss val= -9.990212440490723 adv_loss= tensor([0.], device='cuda:0') t

### Slice

The proposed slice-based method, which estimates a slice of I(Z; T).

In [48]:
estimator = 'SLICE'                             # slice-based estimate of I(Z; T)


class Hyperparams(utils_os.ConfigDict):
    """Hyperparameters for the 'SLICE' estimator run."""

    def __init__(self):
        self.estimator = estimator
        self.beta = 0.15
        self.lr = 5e-4
        self.bs = 500
        self.n_slice = 100
        self.inner_epochs = 0                   # <-- 0: the slices are not optimised; raise it to optimise them as well
        self.dim_learnt = dim_z
        self.dim_sensitive = 1


hyperparams = Hyperparams()

# Layer widths handed to FairNet: input width d, two layers of 10*dim, then 4*dim.
layer_sizes = [d, 10*dim, 10*dim, 4*dim]
net = FairNet(architecture=layer_sizes, dim_y=K, hyperparams=hyperparams)

# Run the imported `train` routine on the training split, then `test` on the held-out split.
train(hyperparams, estimator, x, y, net, DATASET)
test(hyperparams, estimator, x_test, y_test, net, DATASET)

t= 0 loss= -0.11936375498771667 loss val= 0.9364299774169922 adv_loss= (0.9919296503067017, 0.991300106048584) time per iter= 0.2506539821624756
t= 50 loss= -0.1354360729455948 loss val= 0.013980641961097717 adv_loss= (0.214186429977417, 0.14658018946647644) time per iter= 0.22774410247802734
t= 100 loss= -0.09420060366392136 loss val= -0.09858796000480652 adv_loss= (0.159359410405159, 0.049193620681762695) time per iter= 0.22551226615905762
t= 150 loss= -0.14069044589996338 loss val= -0.07006895542144775 adv_loss= (0.16730333864688873, 0.0755138099193573) time per iter= 0.23713040351867676
t= 200 loss= -0.016740724444389343 loss val= -0.12110413610935211 adv_loss= (0.14782661199569702, 0.028454702347517014) time per iter= 0.22893381118774414
t= 250 loss= 0.03495204448699951 loss val= -0.14713731408119202 adv_loss= (0.14595946669578552, 0.002224075375124812) time per iter= 0.271761417388916
t= 300 loss= 0.03112076222896576 loss val= -0.10401973128318787 adv_loss= (0.1542695313692093, 0

### Adversarial training

Baseline methods that use a neural network (i.e. the *adversary*) to estimate I(Z; T) or its proxies. 

If the adversary is not trained sufficiently (controlled by *hyperparams.inner_epochs*), the learned Z will not be so fair.

In [37]:
estimator = 'RENYI'                             # neural Renyi correlation


class Hyperparams(utils_os.ConfigDict):
    """Hyperparameters for the 'RENYI' adversarial run."""

    def __init__(self):
        self.estimator = estimator
        self.early_stop = False                 # <-- no early stopping: the adversary is still immature early in training
        self.beta = 0.35
        self.lr = 5e-4
        self.bs = 500
        self.inner_epochs = 3                   # <-- controls how much the adversary is trained; drives the execution time
        self.dim_learnt = dim_z
        self.dim_sensitive = 1


hyperparams = Hyperparams()

# Layer widths handed to FairNet: input width d, two layers of 10*dim, then 4*dim.
layer_sizes = [d, 10*dim, 10*dim, 4*dim]
net = FairNet(architecture=layer_sizes, dim_y=K, hyperparams=hyperparams)

# Run the imported `train` routine on the training split, then `test` on the held-out split.
train(hyperparams, estimator, x, y, net, DATASET)
test(hyperparams, estimator, x_test, y_test, net, DATASET)

t= 0 loss= 0.07830354571342468 loss val= 0.9647189974784851 adv_loss= -0.9817465543746948 time per iter= 0.8954341411590576
t= 50 loss= -0.18962064385414124 loss val= -0.3322595953941345 adv_loss= -0.021536797285079956 time per iter= 0.8899922370910645
t= 100 loss= -0.28281277418136597 loss val= -0.33801764249801636 adv_loss= -0.03141624107956886 time per iter= 0.8963000774383545
t= 150 loss= -0.14592097699642181 loss val= -0.22914168238639832 adv_loss= -0.11960452795028687 time per iter= 0.8935518264770508
t= 200 loss= -0.1158144623041153 loss val= -0.1758745163679123 adv_loss= -0.16743427515029907 time per iter= 0.8867254257202148
t= 250 loss= 0.11445313692092896 loss val= -0.10944932699203491 adv_loss= -0.22993695735931396 time per iter= 0.8841004371643066
t= 300 loss= 0.36939120292663574 loss val= -0.06791889667510986 adv_loss= -0.2741472125053406 time per iter= 0.8929531574249268
t= 350 loss= 0.37139758467674255 loss val= -0.04945150017738342 adv_loss= -0.29358190298080444 time pe

In [49]:
estimator = 'TC'                                # neural total correlation


class Hyperparams(utils_os.ConfigDict):
    """Hyperparameters for the 'TC' adversarial run."""

    def __init__(self):
        self.estimator = estimator
        self.early_stop = False                 # <-- no early stopping: the adversary is still immature early in training
        self.beta = 0.05
        self.lr = 5e-4
        self.bs = 500
        self.inner_epochs = 5                   # <-- controls how much the adversary is trained; drives the execution time
        self.dim_learnt = dim_z
        self.dim_sensitive = 1


hyperparams = Hyperparams()

# Layer widths handed to FairNet: input width d, two layers of 10*dim, then 4*dim.
layer_sizes = [d, 10*dim, 10*dim, 4*dim]
net = FairNet(architecture=layer_sizes, dim_y=K, hyperparams=hyperparams)

# Run the imported `train` routine on the training split, then `test` on the held-out split.
train(hyperparams, estimator, x, y, net, DATASET)
test(hyperparams, estimator, x_test, y_test, net, DATASET)

t= 0 loss= -0.6360607743263245 loss val= -0.5947480797767639 adv_loss= 0.5675549507141113 time per iter= 0.750281810760498
t= 50 loss= -0.7432966828346252 loss val= -0.7423279881477356 adv_loss= 0.6918755769729614 time per iter= 0.7195310592651367
t= 100 loss= -0.7386649250984192 loss val= -0.7399130463600159 adv_loss= 0.6906622648239136 time per iter= 1.4694995880126953
t= 150 loss= -0.7377430200576782 loss val= -0.7405813336372375 adv_loss= 0.692774772644043 time per iter= 1.3388607501983643
t= 200 loss= -0.6815589070320129 loss val= -0.7574877142906189 adv_loss= 0.7073756456375122 time per iter= 1.5055198669433594
t= 250 loss= -0.7017704248428345 loss val= -0.7684246301651001 adv_loss= 0.7195345759391785 time per iter= 1.406221628189087
t= 300 loss= -0.6817378401756287 loss val= -0.7648965120315552 adv_loss= 0.7175062894821167 time per iter= 1.3466072082519531
t= 350 loss= -0.6883325576782227 loss val= -0.7841498851776123 adv_loss= 0.7311117649078369 time per iter= 1.601251840591430