In [1]:
%matplotlib inline

import pandas as pd
import numpy as np
import random
import matplotlib.pyplot as plt
import math

import pyemd
from itertools import combinations
from itertools import chain
from scipy.stats import wasserstein_distance
from IPython.display import display

In [2]:
random.seed(9001)

## Generation of Datasets
Two datasets of 500 and 7300 samples
Each person in the datasets has 6 protected attributes:
* Gender                  = {Male, Female}
* Country                 = {America, India, Other}
* Year of Birth           = [1950, 2009]
* Language                = {English, Indian, Other}
* Ethnicity               = {White, African-American, Indian, Other}
* Years of Experience     = [0,30]

And two observed attributes:
* Language Test = [25,100]
* Approval rate = [25,100]

Task Qualification Function:

$f = \alpha b_1 + (1-\alpha)b_2$

Where $b_1$ is the *language test* and $b_2$ is *approval rate* and the $\alpha \in \{0,0.3,0.5,0.7,1\}$

In [3]:
# the protected columns
protected_attrs = {
    'gender' : ['male', 'female'],
    'country' : ['america', 'india', 'other'],
    'year_birth' : list(range(1950, 2009+1)),
    'language' : ['english', 'india', 'other'],
    'ethnicity' : ['white', 'african-american', 'indian', 'other'],
    'year_experience' : list(range(0,30+1))
}
# the observed columns
observed_attrs = {
    'language_test' : list(range(25,100+1)),
    'approval_rate' : list(range(25,100+1))
}

In [4]:
def generate_dataset(n):
    '''Generates the dataset accordinly the parameter n that represents the amount of people'''
    # define the dataset structure
    dataset = []
    # generate the samples
    for i in range(n):
        sample_protected = [v[random.randint(0,len(v)-1)] for k,v in protected_attrs.items()]
        sample_observed  = [v[random.randint(0,len(v)-1)] for k,v in observed_attrs.items()]
        sample = sample_protected + sample_observed
        dataset.append(sample)
        
    columns = list(protected_attrs.keys()) + list(observed_attrs.keys())
    return pd.DataFrame(dataset, columns=columns)

In [5]:
small_dataset = generate_dataset(500)
big_dataset = generate_dataset(7300)

# The algorithm

In [20]:
class BalancedAlgorithm:
    def __init__(self, attributes, bins=list(range(25,100+1))):
        self.attributes = attributes.copy()
        self.bins = bins
        
    def worst_attribute(self,W,f,A):
        worst_attr = ''
        highest_emd = float('-inf')
        
        for column, possible_values in A.items():
            histograms = []
            for value in possible_values:
                query_string = '{} == "{}"'.format(column, value)
                partition = W.query(query_string) # query by attribute value
                
                if partition.empty:
                    continue
                
                h,b = np.histogram([f(row) for i,row in partition.iterrows()], bins=self.bins)
                histograms.append(h)
            
            # we need more than 1 attr-value to compare the histograms
            if len(histograms) <= 1:
                return ''
            # we need to make the pairwise EMD
            pairs = combinations(histograms, 2)
            emd_list = []
            for pair in pairs:
                emd_value = wasserstein_distance(pair[0], pair[1])
                emd_list.append(emd_value)
                
            avg_emd = np.mean(emd_list)
            if avg_emd > highest_emd:
                highest_emd = avg_emd
                worst_attr = column
            
        return worst_attr, highest_emd
        
    def split(self,W,a):
        if type(W) is list:
            array = []
            for w in W:
                array += [df for _, df in w.groupby(a)]
            return array
                
        return [df for _, df in W.groupby(a)]

    def average_emd(self,W,f):
        histograms = []
        emd_list = []
        for partition in W:
            h,b = np.histogram([f(row) for i,row in partition.iterrows()], bins=self.bins)
            histograms.append(h)

        if len(histograms) <= 1:
            return 0
        pairs = combinations(histograms, 2)
        for pair in pairs:
            emd_value = wasserstein_distance(pair[0], pair[1])
            emd_list.append(emd_value)

        return np.mean(emd_list)

    def run(self,W,f,A):
        removal_list = []
        A = A.copy()
        a, emd_val = self.worst_attribute(W,f,A)
        removal_list.append(a)
        A.pop(a) # line 2 of the pseudo code
        current = self.split(W,a)
        current_avg = self.average_emd(current,f)

        while len(A) > 0:
            worst = [self.worst_attribute(c,f,A) for c in current]
            max_index = np.argmax([t[1] for t in worst])
            a = worst[max_index][0]
            
            removal_list.append(a)
            A.pop(a)
            children = self.split(current[max_index],a)
            
            # add the others partitions not splitted
            for i, partition in enumerate(current):
                if i == max_index:
                    continue
                children += [current[i]]
            
            children_avg = self.average_emd(children,f)
            if current_avg >= children_avg:
                break
            else:
                current = children
                current_avg = children_avg

        return current, current_avg, removal_list

In [7]:
class ScoringFunction:
    def __init__(self, alpha, b1_name, b2_name):
        self.a = alpha
        self.b1_name = b1_name
        self.b2_name = b2_name
        
    def f(self,row):
        b1 = row[self.b1_name]
        b2 = row[self.b1_name]
        return (self.a*b1 + (1-self.a)*b2)

In [8]:
alpha = [0,0.3,0.5,0.7,1]

f3 = ScoringFunction(alpha[2], 'language_test', 'approval_rate').f

In [21]:
balanced = BalancedAlgorithm(protected_attrs)
result = balanced.run(small_dataset.copy(), f3, protected_attrs)

In [22]:
print('Final EMD Value = {}\nSplitted on: {}'.format(result[1],result[2]))

Final EMD Value = 1.0336507936507935
Splitted on: ['gender', 'country', 'ethnicity', 'language']
