In [1]:
import sys, os
sys.path.append(os.path.abspath(os.path.join('../..')))

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch import distributions
from torch.nn.parameter import Parameter
import torch.utils.data as data_utils
from collections import namedtuple
import functools

%load_ext autoreload
%autoreload 2

In [2]:
from facl.independence.density_estimation.pytorch_kde import kde
from facl.independence.hgr import chi_2, hgr, binary_renyi2_differentiable


def chi_squared_kde(X, Y):
    """Return the dependence penalty computed from ``X`` and ``Y``.

    Thin wrapper: both arguments are forwarded unchanged to
    ``binary_renyi2_differentiable`` and its result is returned as is.
    NOTE(review): the name suggests a chi-squared dependence measure, but the
    exact quantity is defined by the imported helper -- confirm there.
    """
    penalty = binary_renyi2_differentiable(X, Y)
    return penalty

#def hgr_kde(X, Y):
#    return hgr(X, Y, kde)

We download and preprocess the UCI Adult dataset, following https://github.com/jmikko/fair_ERM.

In [3]:
from examples.data_loading import read_dataset

# read_dataset returns four objects: the training frame, the protected
# attribute for the training rows, and the same pair for the test split.
(encoded_data, to_protect,
 encoded_data_test, to_protect_test) = read_dataset(name='adult', fold=1)

# Preview the first rows of the training frame.
encoded_data.head()



Unnamed: 0,0,1,2,3,4,5,6,7,8,10,11,12,13,Target
0,0.034201,2.917717,-1.062295,-0.344074,1.128753,0.942936,-1.482624,-0.258387,0.38411,0.142888,-0.21878,-0.07812,0.262999,0.0
1,0.866417,1.873997,-1.007438,-0.344074,1.128753,-0.390005,-0.737534,-0.884479,0.38411,-0.146733,-0.21878,-2.326738,0.262999,0.0
2,-0.041455,-0.213443,0.245284,0.179902,-0.438122,-1.722946,-0.240806,-0.258387,0.38411,-0.146733,-0.21878,-0.07812,0.262999,0.0
3,1.093385,-0.213443,0.425853,-2.439977,-1.221559,-0.390005,-0.240806,-0.884479,-2.018744,-0.146733,-0.21878,-0.07812,0.262999,0.0
4,-0.798015,-0.213443,1.407393,-0.344074,1.128753,-0.390005,0.752648,2.245982,-2.018744,-0.146733,-0.21878,-0.07812,-5.3293,0.0


We define a very simple neural network.

In [4]:
# Hyper-parameters of the experiment
num_classes = 2                              # number of output scores of the network
input_size = encoded_data.shape[1] - 1       # every column except 'Target' is a feature
num_epochs = 20                              # passes over the training set
batch_size = 128                             # mini-batch size for the prediction loss
batchRenyi = 128.                            # expected size of the fairness-penalty batch
learning_rate = 5e-3                         # step size handed to the optimizer
# Penalty weight, rescaled by the ratio of the two batch sizes.
lambda_renyi = 4. * batchRenyi / batch_size


class NetRegression(nn.Module):
    """Feed-forward network with a single SELU-activated hidden layer.

    Args:
        input_size: number of input features per sample.
        num_classes: number of output units (one raw score per class).
        hidden_size: width of the hidden layer. Defaults to 100, the value
            that was previously hard-coded, so existing two-argument calls
            build exactly the same architecture.
    """

    def __init__(self, input_size, num_classes, hidden_size=100):
        super(NetRegression, self).__init__()
        # Attribute names are kept so saved state_dict keys stay valid.
        self.first = nn.Linear(input_size, hidden_size)
        self.last = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Return the raw (un-normalised) class scores for the batch ``x``."""
        hidden = F.selu(self.first(x))
        return self.last(hidden)
    
# Immutable record bundling the model class with every hyper-parameter,
# so the training code reads all settings from a single object.
cfg_factory = namedtuple(
    'Config',
    'model  batch_size num_epochs lambda_renyi batchRenyi learning_rate input_size num_classes',
)

# Fields are passed by keyword so each value is visibly tied to its field name.
config = cfg_factory(
    model=NetRegression,
    batch_size=batch_size,
    num_epochs=num_epochs,
    lambda_renyi=lambda_renyi,
    batchRenyi=batchRenyi,
    learning_rate=learning_rate,
    input_size=input_size,
    num_classes=num_classes,
)


A few helper functions to compute performance metrics

In [5]:

def EntropyToProba(entropy): #Only for X Tensor of dimension 2
    """Convert per-class scores into the probability assigned to class 1.

    Args:
        entropy: 2-D tensor of raw scores, one row per sample and one column
            per class (column 1 is the class whose probability is returned).

    Returns:
        1-D tensor holding, for each row, the softmax probability of column 1.

    The softmax is computed with ``torch.softmax`` rather than a manual
    ``exp(x) / sum(exp(x))``: the manual form overflows to ``inf / inf = nan``
    as soon as a score is large, while ``torch.softmax`` stays finite and is
    otherwise mathematically identical (and equally differentiable).
    """
    return torch.softmax(entropy, dim=1)[:, 1]

def calc_accuracy(outputs,Y): #Care outputs are going to be in dimension 2
    """Return the fraction of rows whose highest-scoring column equals ``Y``.

    ``outputs`` holds one row of scores per sample; the index of the largest
    score along dim 1 is taken as the prediction and compared with ``Y``.
    """
    _, predicted = outputs.max(dim=1)
    n_correct = predicted.eq(Y).sum().numpy()
    n_samples = predicted.shape[0]
    return n_correct / n_samples

def results_on_test(model, criterion, encoded_data_test, to_protect_test):
    """Evaluate ``model`` on a held-out split and report accuracy and fairness metrics.

    Args:
        model: callable mapping a float32 feature tensor to per-class scores.
        criterion: loss, called as ``criterion(outputs, target)``.
        encoded_data_test: DataFrame with a 'Target' column; every other
            column is fed to the model as a feature.
        to_protect_test: protected attribute, one value per row; rows are
            split into the groups ``== 0`` and ``== 1``.

    Returns:
        dict with the keys
        'loss'         -- criterion value on the whole split,
        'accuracy'     -- overall accuracy,
        'balanced_acc' -- mean of the accuracies of the two protected groups,
        'di'           -- ratio of positive-prediction rates, group 1 / group 0,
        'deo'          -- absolute gap between the two groups' positive-prediction
                          rates among samples whose target is 1.
    """
    # np.int64 instead of np.long: the np.long alias is not available in
    # every NumPy release, and the tensor is cast to int64 anyway.
    target = torch.tensor(encoded_data_test['Target'].values.astype(np.int64)).long()
    pt = torch.Tensor(to_protect_test)
    data = torch.tensor(encoded_data_test.drop('Target', axis = 1).values.astype(np.float32))

    # detach(): evaluation only, no gradient bookkeeping needed downstream.
    outputs = model(data).detach()
    loss = criterion(outputs, target)
    p = EntropyToProba(outputs)

    in_group0 = pt == 0
    in_group1 = pt == 1

    ans = {}

    balanced_acc = (calc_accuracy(outputs[in_group0], target[in_group0]) +
                    calc_accuracy(outputs[in_group1], target[in_group1])) / 2

    ans['loss'] = loss.item()
    ans['accuracy'] = calc_accuracy(outputs, target)
    ans['balanced_acc'] = balanced_acc

    f = 0.5  # decision threshold on the class-1 probability
    predicted_pos = p > f
    actual_pos = target == 1

    # Positive-prediction rate inside each protected group.
    p1 = (in_group1 & predicted_pos).sum().float() / in_group1.sum().float()
    p0 = (in_group0 & predicted_pos).sum().float() / in_group0.sum().float()
    # Same rate restricted to samples whose target is 1.
    o1 = (in_group1 & predicted_pos & actual_pos).sum().float() / (in_group1 & actual_pos).sum().float()
    o0 = (in_group0 & predicted_pos & actual_pos).sum().float() / (in_group0 & actual_pos).sum().float()

    di = p1 / p0
    deo = (o1 - o0).abs()
    ans['di'] = di.item()
    ans['deo'] = deo.item()

    return ans


In [6]:
verbose = True

model = config.model(config.input_size, config.num_classes)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate, weight_decay=0)

# np.int64 instead of np.long: the np.long alias is not available in every
# NumPy release, and the tensor is cast to int64 anyway.
train_target = torch.tensor(encoded_data['Target'].values.astype(np.int64)).long()
train_data = torch.tensor(encoded_data.drop('Target', axis = 1).values.astype(np.float32))
train_protect = torch.tensor(to_protect).float()
train_tensor = data_utils.TensorDataset(train_data, train_target)
train_loader = data_utils.DataLoader(dataset = train_tensor, batch_size = config.batch_size, shuffle = True)

# Probability for each training row to enter the regularization batch, chosen
# so that the batch holds config.batchRenyi rows on average (loop-invariant).
frac = config.batchRenyi / train_data.shape[0]

for epoch in range(config.num_epochs):
    for i, (x, y) in enumerate(train_loader):
        optimizer.zero_grad()

        outputs = model(x)

        #Select a renyi regularization mini batch and compute the value of the model on it
        # Boolean mask: indexing with a uint8 (.byte()) mask is deprecated in PyTorch.
        foo = torch.bernoulli(frac*torch.ones(train_data.shape[0])).bool()
        br = train_data[foo, : ]
        pr = train_protect[foo]
        yr = train_target[foo].float()
        ren_outs = model(br)

        #Compute the usual loss of the prediction
        loss = criterion(outputs, y)

        #Compute the fairness penalty for positive labels only since we optimize for DEO
        delta = EntropyToProba(ren_outs[yr==1.])
        r2 = chi_squared_kde( delta, pr[yr==1.])

        # The penalty is only monitored here; uncomment to train with it.
        #loss += config.lambda_renyi * r2

        #In Adam we trust
        loss.backward()
        optimizer.step()
    if verbose:
        # Reports the values of the last mini-batch of the epoch.
        print ('Epoch: [%d/%d], Batch: [%d/%d], Loss: %.4f, Accuracy: %.4f, Fairness penalty: %.4f'  % (epoch+1, config.num_epochs, i, len(encoded_data)//config.batch_size,
                loss.item(),calc_accuracy(outputs,y),
                r2.item()
                 ))
        #print( results_on_test(model, criterion, encoded_data_test, to_protect_test) )

print("Results on test set")
results_on_test(model, criterion, encoded_data_test, to_protect_test)

Epoch: [1/20], Batch: [254/254], Loss: 0.3301, Accuracy: 0.8163, Fairness penalty: 0.0304
Epoch: [2/20], Batch: [254/254], Loss: 0.1850, Accuracy: 0.9184, Fairness penalty: 0.0032
Epoch: [3/20], Batch: [254/254], Loss: 0.3635, Accuracy: 0.7959, Fairness penalty: 0.0006
Epoch: [4/20], Batch: [254/254], Loss: 0.3028, Accuracy: 0.8776, Fairness penalty: 0.0020
Epoch: [5/20], Batch: [254/254], Loss: 0.3593, Accuracy: 0.8776, Fairness penalty: 0.0111
Epoch: [6/20], Batch: [254/254], Loss: 0.3544, Accuracy: 0.8163, Fairness penalty: 0.0000
Epoch: [7/20], Batch: [254/254], Loss: 0.2754, Accuracy: 0.8980, Fairness penalty: 0.0008
Epoch: [8/20], Batch: [254/254], Loss: 0.4581, Accuracy: 0.7755, Fairness penalty: 0.0040
Epoch: [9/20], Batch: [254/254], Loss: 0.3021, Accuracy: 0.8367, Fairness penalty: 0.0987
Epoch: [10/20], Batch: [254/254], Loss: 0.3315, Accuracy: 0.8776, Fairness penalty: 0.0148
Epoch: [11/20], Batch: [254/254], Loss: 0.3047, Accuracy: 0.8163, Fairness penalty: 0.0096
Epoch: [

{'loss': 0.3318299353122711,
 'accuracy': 0.8447989890214043,
 'balanced_acc': 0.8640701942626562,
 'di': 0.44910478591918945,
 'deo': 0.0397072434425354}