In [1]:
import numpy as np
import pandas as pd
from sklearn.preprocessing import LabelEncoder
import matplotlib.pyplot as plt

from fado.preprocessing import MetricOptimizer

In [2]:
# Column roles, declared up front so the lines below can refer to them by name
protected_attribute = 'species'
label = 'petal_length'

iris = pd.read_csv('https://raw.githubusercontent.com/mwaskom/seaborn-data/master/iris.csv')

# Label-encode every column (DataFrame.apply runs the encoder column by column),
# so the numeric measurements become integer codes as well, not just `species`.
le = LabelEncoder()
iris_transform = iris.apply(le.fit_transform)

# Binary protected attribute: True where the encoded species code is 0
iris_transform[protected_attribute] = iris_transform[protected_attribute].eq(0)

# Binary label: True where the encoded petal_length code exceeds 9
# (the threshold applies to the encoder's integer code, not to the raw measurement)
iris_transform[label] = iris_transform[label].gt(9)

In [3]:
# Initialize MetricOptimizer
# Keep 75% of the whole dataset
# Column names are taken from `protected_attribute` / `label` (set in the data
# preparation cell) rather than repeated as string literals, so this optimizer
# and the HeuristicWrapper further down are always configured on the same columns.
preproc = MetricOptimizer(frac=0.75,
                          protected_attribute=protected_attribute,
                          label=label)

In [4]:
# Fit the optimizer on the encoded iris frame. `preproc` is rebound to whatever
# fit() returns -- presumably the fitted optimizer itself (sklearn-style); TODO confirm.
preproc = preproc.fit(iris_transform)
# Remove samples to yield a fair dataset.
# transform() is called without arguments, so it presumably works on the data
# remembered from fit() -- verify against the fado API.
iris_fair = preproc.transform()

### Genetic Algorithm (Solver)

In [5]:
from fado.metrics import statistical_parity_absolute_difference
# heuristic test
from fado.preprocessing.solvers.geneticalgorithm import genetic_algorithm_uniform_method
from fado.preprocessing import HeuristicWrapper

In [6]:
# Discrimination measure and solver that get plugged into the wrapper
disc_measure = statistical_parity_absolute_difference
heuristic = genetic_algorithm_uniform_method

# Wrapper configured on the same protected attribute / label columns as above
preproc_heuristics = HeuristicWrapper(
    heuristic,
    disc_measure=disc_measure,
    protected_attribute=protected_attribute,
    label=label,
)

In [7]:
# Fit the heuristic wrapper (instantiated in the previous cell) on the encoded iris frame
preproc_heuristics.fit(iris_transform)
# Remove samples to yield a fair dataset
# NOTE(review): genetic algorithms are typically stochastic and no random seed is
# set anywhere in view, so this result may differ between runs -- confirm how
# fado seeds its solver before relying on the exact numbers.
iris_fair_heuristic = preproc_heuristics.transform()

In [15]:
def _parity_gap(frame):
    """Evaluate `statistical_parity_absolute_difference` on one dataset.

    The label column and the protected-attribute column are looked up through
    the notebook-level `label` and `protected_attribute` variables, so the
    score is always computed on the same columns the pre-processors were
    configured with.

    Parameters
    ----------
    frame : pandas.DataFrame
        Dataset containing both the `label` and `protected_attribute` columns.

    Returns
    -------
    The value returned by `statistical_parity_absolute_difference`.
    """
    return statistical_parity_absolute_difference(frame[label],
                                                  frame[protected_attribute])


# Same measure on the three datasets: both pre-processed variants and the original
disc_fair = _parity_gap(iris_fair)
disc_fair_heuristic = _parity_gap(iris_fair_heuristic)
disc_orig = _parity_gap(iris_transform)

In [17]:
# Report all three scores in a single call; sep='\n' keeps one message per line
print(
    f'The original dataset has a statistical disparity (absolute) value of: {disc_orig}',
    f'The pre-processed fair dataset has a statistical disparity (absolute) value of: {disc_fair}',
    f'The pre-processed fair dataset with GA has a statistical disparity (absolute) value of: {disc_fair_heuristic}',
    '(Lower is better.)',
    sep='\n',
)

The original dataset has a statistical disparity (absolute) value of: 0.99
The pre-processed fair dataset has a statistical disparity (absolute) value of: 0.9841269841269841
The pre-processed fair dataset with GA has a statistical disparity (absolute) value of: 0.0
(Lower is better.)
