In [1]:
class Profile:
    """
    A Profile explores bias in a dataset w.r.t. a protected_attribute (e.g. race, gender,
    disability). It stratifies the data into the classes found under the protected_attribute feature.
    It then identifies differences between the strata and their output distributions.
    TODO: If given a specific protected_class, it will stratify the data into two sets, protected and non-protected.
    """

    def __init__(self, data, target, protected_attribute):
        """
        Parameters:
            data: DataFrame with features as column names.
            target: Array of labels of the dataset.
            protected_attribute: The feature of the data that is to be protected from discrimination. This column is assumed to be categorical.
        """
        self.data = data
        self.target = target
        self.protected_attribute = protected_attribute

        # partition data into a dict where keys are the category and the values are partitions of the dataset
        self.stratified_data = {}
        self.stratified_target = {}
        categories = set(self.data[self.protected_attribute])
        for category in categories:
            index_vector = (self.data[self.protected_attribute] == category)
            self.stratified_data[category] = self.data[index_vector]
            self.stratified_target[category] = self.target[index_vector]
        for key in self.stratified_data:
            self.stratified_data[key] = self.stratified_data[key].drop(protected_attribute, axis=1)

    def profile(self):
        # compare sets with: basic statistics (box plots), histograms (of frequencies for labels and for each category within a protected class), (aif360 visuals)
        # if label class is not given, highlight the strongest differences/correlations
        pass


In [194]:
import matplotlib.pyplot as plt
from sklearn import datasets
import pandas as pd
import numpy as np
import seaborn as sns

In [192]:
iris_raw = datasets.load_iris()
iris_data = pd.DataFrame(iris_raw.data, columns=iris_raw.feature_names)
iris_target = iris_raw.target
protected_attribute = iris.feature_names[0]
categories = list(set(iris_data[protected_attribute]))

In [193]:
pf = Profile(iris_data, iris_target, iris.feature_names[0])

In [190]:
sns.boxplot( )
for key in pf.stratified_data:
    print(pf.stratified_data[key])

    sepal width (cm)  petal length (cm)  petal width (cm)
2                3.2                1.3               0.2
29               3.2                1.6               0.2
    sepal width (cm)  petal length (cm)  petal width (cm)
33               4.2                1.4               0.2
36               3.5                1.3               0.2
53               2.3                4.0               1.3
80               2.4                3.8               1.1
81               2.4                3.7               1.0
89               2.5                4.0               1.3
90               2.6                4.4               1.2
    sepal width (cm)  petal length (cm)  petal width (cm)
4                3.6                1.4               0.2
7                3.4                1.5               0.2
25               3.0                1.6               0.2
26               3.4                1.6               0.4
35               3.2                1.2               0.2
40            