From 1a053ea0cd1191ab0b7bc63dc17141924a07f51c Mon Sep 17 00:00:00 2001 From: Fiona Victoria <37321147+fionavictoria@users.noreply.github.com> Date: Fri, 16 Sep 2022 11:08:56 -0700 Subject: [PATCH] Fix - Get notebooks in examples/sklearn/ to work in Google colab (Part 1/3) #378 Signed-off-by: fionavictoria --- ...rch_reduction_classification_sklearn.ipynb | 2060 ++++++++++------- 1 file changed, 1201 insertions(+), 859 deletions(-) diff --git a/examples/sklearn/demo_grid_search_reduction_classification_sklearn.ipynb b/examples/sklearn/demo_grid_search_reduction_classification_sklearn.ipynb index 1c7417f2..6e1e204e 100644 --- a/examples/sklearn/demo_grid_search_reduction_classification_sklearn.ipynb +++ b/examples/sklearn/demo_grid_search_reduction_classification_sklearn.ipynb @@ -1,865 +1,1207 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Sklearn compatible Grid Search for classification\n", - "\n", - "Grid search is an in-processing technique that can be used for fair classification or fair regression. For classification it reduces fair classification to a sequence of cost-sensitive classification problems, returning the deterministic classifier with the lowest empirical error subject to fair classification constraints among\n", - "the candidates searched. The code for grid search wraps the source class `fairlearn.reductions.GridSearch` available in the https://github.com/fairlearn/fairlearn library, licensed under the MIT Licencse, Copyright Microsoft Corporation." - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "import warnings\n", - "warnings.filterwarnings(\"ignore\", category=FutureWarning)" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "import numpy as np\n", - "import pandas as pd\n", - "\n", - "from sklearn.linear_model import LogisticRegression\n", - "from sklearn.metrics import accuracy_score\n", - "from sklearn.model_selection import train_test_split\n", - "\n", - "from aif360.sklearn.inprocessing import GridSearchReduction\n", - "\n", - "from aif360.sklearn.datasets import fetch_adult\n", - "from aif360.sklearn.metrics import average_odds_error" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Loading data" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Datasets are formatted as separate `X` (# samples x # features) and `y` (# samples x # labels) DataFrames. The index of each DataFrame contains protected attribute values per sample. Datasets may also load a `sample_weight` object to be used with certain algorithms/metrics. All of this makes it so that aif360 is compatible with scikit-learn objects.\n", - "\n", - "For example, we can easily load the Adult dataset from UCI with the following line:" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
ageworkclasseducationeducation-nummarital-statusoccupationrelationshipracesexcapital-gaincapital-losshours-per-weeknative-country
racesex
Non-whiteMale25.0Private11th7.0Never-marriedMachine-op-inspctOwn-childBlackMale0.00.040.0United-States
WhiteMale38.0PrivateHS-grad9.0Married-civ-spouseFarming-fishingHusbandWhiteMale0.00.050.0United-States
Male28.0Local-govAssoc-acdm12.0Married-civ-spouseProtective-servHusbandWhiteMale0.00.040.0United-States
Non-whiteMale44.0PrivateSome-college10.0Married-civ-spouseMachine-op-inspctHusbandBlackMale7688.00.040.0United-States
WhiteMale34.0Private10th6.0Never-marriedOther-serviceNot-in-familyWhiteMale0.00.030.0United-States
\n", - "
" + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "JaLHu46TbUmy" + }, + "source": [ + "# Sklearn compatible Grid Search for classification\n", + "\n", + "Grid search is an in-processing technique that can be used for fair classification or fair regression. For classification it reduces fair classification to a sequence of cost-sensitive classification problems, returning the deterministic classifier with the lowest empirical error subject to fair classification constraints among\n", + "the candidates searched. The code for grid search wraps the source class `fairlearn.reductions.GridSearch` available in the https://github.com/fairlearn/fairlearn library, licensed under the MIT Licencse, Copyright Microsoft Corporation." + ] + }, + { + "cell_type": "code", + "source": [ + "!pip install tk\n", + "!pip install 'aif360[LawSchoolGPA]'" ], - "text/plain": [ - " age workclass education education-num \\\n", - "race sex \n", - "Non-white Male 25.0 Private 11th 7.0 \n", - "White Male 38.0 Private HS-grad 9.0 \n", - " Male 28.0 Local-gov Assoc-acdm 12.0 \n", - "Non-white Male 44.0 Private Some-college 10.0 \n", - "White Male 34.0 Private 10th 6.0 \n", - "\n", - " marital-status occupation relationship race \\\n", - "race sex \n", - "Non-white Male Never-married Machine-op-inspct Own-child Black \n", - "White Male Married-civ-spouse Farming-fishing Husband White \n", - " Male Married-civ-spouse Protective-serv Husband White \n", - "Non-white Male Married-civ-spouse Machine-op-inspct Husband Black \n", - "White Male Never-married Other-service Not-in-family White \n", - "\n", - " sex capital-gain capital-loss hours-per-week \\\n", - "race sex \n", - "Non-white Male Male 0.0 0.0 40.0 \n", - "White Male Male 0.0 0.0 50.0 \n", - " Male Male 0.0 0.0 40.0 \n", - "Non-white Male Male 7688.0 0.0 40.0 \n", - "White Male Male 0.0 0.0 30.0 \n", - "\n", - " native-country \n", - "race sex \n", - "Non-white Male United-States \n", - "White Male United-States \n", - " Male United-States \n", - "Non-white Male United-States \n", - "White Male United-States " - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "X, y, sample_weight = fetch_adult()\n", - "X.head()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We can then map the protected attributes to integers," - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [], - "source": [ - "X.index = pd.MultiIndex.from_arrays(X.index.codes, names=X.index.names)\n", - "y.index = pd.MultiIndex.from_arrays(y.index.codes, names=y.index.names)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "and the target classes to 0/1," - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "y = pd.Series(y.factorize(sort=True)[0], index=y.index)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "split the dataset," - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [], - "source": [ - "(X_train, X_test,\n", - " y_train, y_test) = train_test_split(X, y, train_size=0.7, random_state=1234567)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We use Pandas for one-hot encoding for easy reference to columns associated with protected attributes, information necessary for grid search reduction." - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
ageeducation-numcapital-gaincapital-losshours-per-weekworkclass_Privateworkclass_Self-emp-not-incworkclass_Self-emp-incworkclass_Federal-govworkclass_Local-gov...native-country_Guatemalanative-country_Nicaraguanative-country_Scotlandnative-country_Thailandnative-country_Yugoslavianative-country_El-Salvadornative-country_Trinadad&Tobagonative-country_Perunative-country_Hongnative-country_Holand-Netherlands
racesex
1158.011.00.00.042.001000...0000000000
051.012.00.00.030.001000...0000000000
126.014.00.01887.040.010000...0000000000
144.03.00.00.040.010000...0000000000
133.06.00.00.040.010000...0000000000
\n", - "

5 rows × 102 columns

\n", - "
" + "metadata": { + "id": "oqUKUXH3dFVT", + "outputId": "1f00c58a-f2bc-4eda-a29c-c97e3f33fb09", + "colab": { + "base_uri": "https://localhost:8080/" + } + }, + "execution_count": 1, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", + "Collecting tk\n", + " Downloading tk-0.1.0-py3-none-any.whl (3.9 kB)\n", + "Installing collected packages: tk\n", + "Successfully installed tk-0.1.0\n", + "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", + "Requirement already satisfied: aif360[LawSchoolGPA] in /usr/local/lib/python3.7/dist-packages (0.5.0)\n", + "Requirement already satisfied: pandas>=0.24.0 in /usr/local/lib/python3.7/dist-packages (from aif360[LawSchoolGPA]) (1.3.5)\n", + "Requirement already satisfied: matplotlib in /usr/local/lib/python3.7/dist-packages (from aif360[LawSchoolGPA]) (3.2.2)\n", + "Requirement already satisfied: scikit-learn>=1.0 in /usr/local/lib/python3.7/dist-packages (from aif360[LawSchoolGPA]) (1.0.2)\n", + "Requirement already satisfied: numpy>=1.16 in /usr/local/lib/python3.7/dist-packages (from aif360[LawSchoolGPA]) (1.21.6)\n", + "Requirement already satisfied: scipy>=1.2.0 in /usr/local/lib/python3.7/dist-packages (from aif360[LawSchoolGPA]) (1.7.3)\n", + "Collecting tempeh\n", + " Downloading tempeh-0.1.12-py3-none-any.whl (39 kB)\n", + "Requirement already satisfied: python-dateutil>=2.7.3 in /usr/local/lib/python3.7/dist-packages (from pandas>=0.24.0->aif360[LawSchoolGPA]) (2.8.2)\n", + "Requirement already satisfied: pytz>=2017.3 in /usr/local/lib/python3.7/dist-packages (from pandas>=0.24.0->aif360[LawSchoolGPA]) (2022.2.1)\n", + "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.7/dist-packages (from python-dateutil>=2.7.3->pandas>=0.24.0->aif360[LawSchoolGPA]) (1.15.0)\n", + "Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.7/dist-packages (from scikit-learn>=1.0->aif360[LawSchoolGPA]) (3.1.0)\n", + "Requirement already satisfied: joblib>=0.11 in /usr/local/lib/python3.7/dist-packages (from scikit-learn>=1.0->aif360[LawSchoolGPA]) (1.1.0)\n", + "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.7/dist-packages (from matplotlib->aif360[LawSchoolGPA]) (0.11.0)\n", + "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->aif360[LawSchoolGPA]) (1.4.4)\n", + "Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->aif360[LawSchoolGPA]) (3.0.9)\n", + "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from kiwisolver>=1.0.1->matplotlib->aif360[LawSchoolGPA]) (4.1.1)\n", + "Requirement already satisfied: pytest in /usr/local/lib/python3.7/dist-packages (from tempeh->aif360[LawSchoolGPA]) (3.6.4)\n", + "Collecting memory-profiler\n", + " Downloading memory_profiler-0.60.0.tar.gz (38 kB)\n", + "Collecting shap\n", + " Downloading shap-0.41.0-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (569 kB)\n", + "\u001b[K |████████████████████████████████| 569 kB 7.5 MB/s \n", + "\u001b[?25hRequirement already satisfied: requests in /usr/local/lib/python3.7/dist-packages (from tempeh->aif360[LawSchoolGPA]) (2.23.0)\n", + "Requirement already satisfied: psutil in /usr/local/lib/python3.7/dist-packages (from memory-profiler->tempeh->aif360[LawSchoolGPA]) (5.4.8)\n", + "Requirement already satisfied: py>=1.5.0 in /usr/local/lib/python3.7/dist-packages (from pytest->tempeh->aif360[LawSchoolGPA]) (1.11.0)\n", + "Requirement already satisfied: more-itertools>=4.0.0 in /usr/local/lib/python3.7/dist-packages (from pytest->tempeh->aif360[LawSchoolGPA]) (8.14.0)\n", + "Requirement already satisfied: pluggy<0.8,>=0.5 in /usr/local/lib/python3.7/dist-packages (from pytest->tempeh->aif360[LawSchoolGPA]) (0.7.1)\n", + "Requirement already satisfied: atomicwrites>=1.0 in /usr/local/lib/python3.7/dist-packages (from pytest->tempeh->aif360[LawSchoolGPA]) (1.4.1)\n", + "Requirement already satisfied: setuptools in /usr/local/lib/python3.7/dist-packages (from pytest->tempeh->aif360[LawSchoolGPA]) (57.4.0)\n", + "Requirement already satisfied: attrs>=17.4.0 in /usr/local/lib/python3.7/dist-packages (from pytest->tempeh->aif360[LawSchoolGPA]) (22.1.0)\n", + "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests->tempeh->aif360[LawSchoolGPA]) (3.0.4)\n", + "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests->tempeh->aif360[LawSchoolGPA]) (1.24.3)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests->tempeh->aif360[LawSchoolGPA]) (2022.6.15)\n", + "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests->tempeh->aif360[LawSchoolGPA]) (2.10)\n", + "Requirement already satisfied: tqdm>4.25.0 in /usr/local/lib/python3.7/dist-packages (from shap->tempeh->aif360[LawSchoolGPA]) (4.64.1)\n", + "Requirement already satisfied: cloudpickle in /usr/local/lib/python3.7/dist-packages (from shap->tempeh->aif360[LawSchoolGPA]) (1.5.0)\n", + "Requirement already satisfied: packaging>20.9 in /usr/local/lib/python3.7/dist-packages (from shap->tempeh->aif360[LawSchoolGPA]) (21.3)\n", + "Collecting slicer==0.0.7\n", + " Downloading slicer-0.0.7-py3-none-any.whl (14 kB)\n", + "Requirement already satisfied: numba in /usr/local/lib/python3.7/dist-packages (from shap->tempeh->aif360[LawSchoolGPA]) (0.56.2)\n", + "Requirement already satisfied: importlib-metadata in /usr/local/lib/python3.7/dist-packages (from numba->shap->tempeh->aif360[LawSchoolGPA]) (4.12.0)\n", + "Requirement already satisfied: llvmlite<0.40,>=0.39.0dev0 in /usr/local/lib/python3.7/dist-packages (from numba->shap->tempeh->aif360[LawSchoolGPA]) (0.39.1)\n", + "Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata->numba->shap->tempeh->aif360[LawSchoolGPA]) (3.8.1)\n", + "Building wheels for collected packages: memory-profiler\n", + " Building wheel for memory-profiler (setup.py) ... \u001b[?25l\u001b[?25hdone\n", + " Created wheel for memory-profiler: filename=memory_profiler-0.60.0-py3-none-any.whl size=31284 sha256=433aefeb0eaa28520ae538e18cd0a293757e0ab19c4e6c61ea3e02ed0659e9e9\n", + " Stored in directory: /root/.cache/pip/wheels/67/2b/fb/326e30d638c538e69a5eb0aa47f4223d979f502bbdb403950f\n", + "Successfully built memory-profiler\n", + "Installing collected packages: slicer, shap, memory-profiler, tempeh\n", + "Successfully installed memory-profiler-0.60.0 shap-0.41.0 slicer-0.0.7 tempeh-0.1.12\n" + ] + } + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "id": "BmjX6xLmbUm0" + }, + "outputs": [], + "source": [ + "import warnings\n", + "warnings.filterwarnings(\"ignore\", category=FutureWarning)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "id": "6N6TJuzXbUm0" + }, + "outputs": [], + "source": [ + "import numpy as np\n", + "import pandas as pd\n", + "\n", + "from sklearn.linear_model import LogisticRegression\n", + "from sklearn.metrics import accuracy_score\n", + "from sklearn.model_selection import train_test_split\n", + "\n", + "from aif360.sklearn.inprocessing import GridSearchReduction\n", + "\n", + "from aif360.sklearn.datasets import fetch_adult\n", + "from aif360.sklearn.metrics import average_odds_error" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "39M3bDclbUm0" + }, + "source": [ + "### Loading data" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "cuRMM729bUm1" + }, + "source": [ + "Datasets are formatted as separate `X` (# samples x # features) and `y` (# samples x # labels) DataFrames. The index of each DataFrame contains protected attribute values per sample. Datasets may also load a `sample_weight` object to be used with certain algorithms/metrics. All of this makes it so that aif360 is compatible with scikit-learn objects.\n", + "\n", + "For example, we can easily load the Adult dataset from UCI with the following line:" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "id": "um_jEG9xbUm1", + "outputId": "e35cd307-a489-4d3b-ac86-81edfe4586e3", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 368 + } + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + " age workclass education education-num \\\n", + "race sex \n", + "Non-white Male 25.0 Private 11th 7.0 \n", + "White Male 38.0 Private HS-grad 9.0 \n", + " Male 28.0 Local-gov Assoc-acdm 12.0 \n", + "Non-white Male 44.0 Private Some-college 10.0 \n", + "White Male 34.0 Private 10th 6.0 \n", + "\n", + " marital-status occupation relationship race \\\n", + "race sex \n", + "Non-white Male Never-married Machine-op-inspct Own-child Black \n", + "White Male Married-civ-spouse Farming-fishing Husband White \n", + " Male Married-civ-spouse Protective-serv Husband White \n", + "Non-white Male Married-civ-spouse Machine-op-inspct Husband Black \n", + "White Male Never-married Other-service Not-in-family White \n", + "\n", + " sex capital-gain capital-loss hours-per-week \\\n", + "race sex \n", + "Non-white Male Male 0.0 0.0 40.0 \n", + "White Male Male 0.0 0.0 50.0 \n", + " Male Male 0.0 0.0 40.0 \n", + "Non-white Male Male 7688.0 0.0 40.0 \n", + "White Male Male 0.0 0.0 30.0 \n", + "\n", + " native-country \n", + "race sex \n", + "Non-white Male United-States \n", + "White Male United-States \n", + " Male United-States \n", + "Non-white Male United-States \n", + "White Male United-States " + ], + "text/html": [ + "\n", + "
\n", + "
\n", + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
ageworkclasseducationeducation-nummarital-statusoccupationrelationshipracesexcapital-gaincapital-losshours-per-weeknative-country
racesex
Non-whiteMale25.0Private11th7.0Never-marriedMachine-op-inspctOwn-childBlackMale0.00.040.0United-States
WhiteMale38.0PrivateHS-grad9.0Married-civ-spouseFarming-fishingHusbandWhiteMale0.00.050.0United-States
Male28.0Local-govAssoc-acdm12.0Married-civ-spouseProtective-servHusbandWhiteMale0.00.040.0United-States
Non-whiteMale44.0PrivateSome-college10.0Married-civ-spouseMachine-op-inspctHusbandBlackMale7688.00.040.0United-States
WhiteMale34.0Private10th6.0Never-marriedOther-serviceNot-in-familyWhiteMale0.00.030.0United-States
\n", + "
\n", + " \n", + " \n", + " \n", + "\n", + " \n", + "
\n", + "
\n", + " " + ] + }, + "metadata": {}, + "execution_count": 4 + } ], - "text/plain": [ - " age education-num capital-gain capital-loss hours-per-week \\\n", - "race sex \n", - "1 1 58.0 11.0 0.0 0.0 42.0 \n", - " 0 51.0 12.0 0.0 0.0 30.0 \n", - " 1 26.0 14.0 0.0 1887.0 40.0 \n", - " 1 44.0 3.0 0.0 0.0 40.0 \n", - " 1 33.0 6.0 0.0 0.0 40.0 \n", - "\n", - " workclass_Private workclass_Self-emp-not-inc \\\n", - "race sex \n", - "1 1 0 1 \n", - " 0 0 1 \n", - " 1 1 0 \n", - " 1 1 0 \n", - " 1 1 0 \n", - "\n", - " workclass_Self-emp-inc workclass_Federal-gov workclass_Local-gov \\\n", - "race sex \n", - "1 1 0 0 0 \n", - " 0 0 0 0 \n", - " 1 0 0 0 \n", - " 1 0 0 0 \n", - " 1 0 0 0 \n", - "\n", - " ... native-country_Guatemala native-country_Nicaragua \\\n", - "race sex ... \n", - "1 1 ... 0 0 \n", - " 0 ... 0 0 \n", - " 1 ... 0 0 \n", - " 1 ... 0 0 \n", - " 1 ... 0 0 \n", - "\n", - " native-country_Scotland native-country_Thailand \\\n", - "race sex \n", - "1 1 0 0 \n", - " 0 0 0 \n", - " 1 0 0 \n", - " 1 0 0 \n", - " 1 0 0 \n", - "\n", - " native-country_Yugoslavia native-country_El-Salvador \\\n", - "race sex \n", - "1 1 0 0 \n", - " 0 0 0 \n", - " 1 0 0 \n", - " 1 0 0 \n", - " 1 0 0 \n", - "\n", - " native-country_Trinadad&Tobago native-country_Peru \\\n", - "race sex \n", - "1 1 0 0 \n", - " 0 0 0 \n", - " 1 0 0 \n", - " 1 0 0 \n", - " 1 0 0 \n", - "\n", - " native-country_Hong native-country_Holand-Netherlands \n", - "race sex \n", - "1 1 0 0 \n", - " 0 0 0 \n", - " 1 0 0 \n", - " 1 0 0 \n", - " 1 0 0 \n", - "\n", - "[5 rows x 102 columns]" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "X_train, X_test = pd.get_dummies(X_train), pd.get_dummies(X_test)\n", - "X_train = X_train.drop(columns=['sex_Female'])\n", - "X_test = X_test.drop(columns=['sex_Female'])\n", - "X_train.head()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The protected attribute information is also replicated in the labels:" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "race sex\n", - "1 1 0\n", - " 0 1\n", - " 1 1\n", - " 1 0\n", - " 1 0\n", - "dtype: int64" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "y_train.head()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Running metrics" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "With the data in this format, we can easily train a scikit-learn model and get predictions for the test data:" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "0.8453600648632712\n" - ] - } - ], - "source": [ - "y_pred = LogisticRegression(solver='liblinear', random_state=1234).fit(X_train, y_train).predict(X_test)\n", - "lr_acc = accuracy_score(y_test, y_pred)\n", - "print(lr_acc)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We can assess how close the predictions are to equality of odds.\n", - "\n", - "`average_odds_error()` computes the (unweighted) average of the absolute values of the true positive rate (TPR) difference and false positive rate (FPR) difference, i.e.:\n", - "\n", - "$$ \\tfrac{1}{2}\\left(|FPR_{D = \\text{unprivileged}} - FPR_{D = \\text{privileged}}| + |TPR_{D = \\text{unprivileged}} - TPR_{D = \\text{privileged}}|\\right) $$" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "0.09356509680536546\n" - ] - } - ], - "source": [ - "lr_aoe = average_odds_error(y_test, y_pred, prot_attr='sex')\n", - "print(lr_aoe)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Grid Search" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Choose a base model for the candidate classifiers. Base models should implement a fit method that can take a sample weight as input. For details refer to the docs. " - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [], - "source": [ - "estimator = LogisticRegression(solver='liblinear', random_state=1234)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Determine the columns associated with the protected attribute(s). Grid search can handle more than one attribute but it is computationally expensive. A similar method with less computational overhead is exponentiated gradient reduction, detailed at [examples/sklearn/demo_exponentiated_gradient_reduction_sklearn.ipynb](sklearn/demo_exponentiated_gradient_reduction_sklearn.ipynb)." - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [], - "source": [ - "prot_attr = 'sex_Male'" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Search for the best classifier and observe test accuracy. Other options for `constraints` include \"DemographicParity\", \"TruePositiveRateParity\", \"FalsePositiveRateParity\", and \"ErrorRateParity\"." - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "0.8455074813886637\n" - ] - } - ], - "source": [ - "np.random.seed(0) #need for reproducibility\n", - "grid_search_red = GridSearchReduction(prot_attr=prot_attr, \n", - " estimator=estimator, \n", - " constraints=\"EqualizedOdds\",\n", - " grid_size=20,\n", - " drop_prot_attr=False)\n", - "grid_search_red.fit(X_train, y_train)\n", - "gs_acc = grid_search_red.score(X_test, y_test)\n", - "print(gs_acc)\n", - "\n", - "#Check if accuracy is comparable\n", - "assert abs(lr_acc-gs_acc)<0.03" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "0.06715455716850638\n" - ] - } - ], - "source": [ - "gs_aoe = average_odds_error(y_test, grid_search_red.predict(X_test), prot_attr='sex')\n", - "print(gs_aoe)\n", - "\n", - "#Check if average odds error improved\n", - "assert gs_aoe\n", + "
\n", + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
ageeducation-numcapital-gaincapital-losshours-per-weekworkclass_Privateworkclass_Self-emp-not-incworkclass_Self-emp-incworkclass_Federal-govworkclass_Local-gov...native-country_Guatemalanative-country_Nicaraguanative-country_Scotlandnative-country_Thailandnative-country_Yugoslavianative-country_El-Salvadornative-country_Trinadad&Tobagonative-country_Perunative-country_Hongnative-country_Holand-Netherlands
racesex
1158.011.00.00.042.001000...0000000000
051.012.00.00.030.001000...0000000000
126.014.00.01887.040.010000...0000000000
144.03.00.00.040.010000...0000000000
133.06.00.00.040.010000...0000000000
\n", + "

5 rows × 102 columns

\n", + "
\n", + " \n", + " \n", + " \n", + "\n", + " \n", + "
\n", + " \n", + " " + ] + }, + "metadata": {}, + "execution_count": 8 + } + ], + "source": [ + "X_train, X_test = pd.get_dummies(X_train), pd.get_dummies(X_test)\n", + "X_train = X_train.drop(columns=['sex_Female'])\n", + "X_test = X_test.drop(columns=['sex_Female'])\n", + "X_train.head()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "H_YLKxL2bUm3" + }, + "source": [ + "The protected attribute information is also replicated in the labels:" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "id": "0lZcS0bmbUm3", + "outputId": "5ffb9347-e863-4725-dc51-be4f20ade173", + "colab": { + "base_uri": "https://localhost:8080/" + } + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "race sex\n", + "1 1 0\n", + " 0 1\n", + " 1 1\n", + " 1 0\n", + " 1 0\n", + "dtype: int64" + ] + }, + "metadata": {}, + "execution_count": 9 + } + ], + "source": [ + "y_train.head()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "oTUVhotNbUm3" + }, + "source": [ + "### Running metrics" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "lQN64AqmbUm3" + }, + "source": [ + "With the data in this format, we can easily train a scikit-learn model and get predictions for the test data:" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "id": "GBhbEdWrbUm3", + "outputId": "f22a975f-0f2f-4a78-b4ca-d4b3fae1c870", + "colab": { + "base_uri": "https://localhost:8080/" + } + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "0.8453600648632712\n" + ] + } + ], + "source": [ + "y_pred = LogisticRegression(solver='liblinear', random_state=1234).fit(X_train, y_train).predict(X_test)\n", + "lr_acc = accuracy_score(y_test, y_pred)\n", + "print(lr_acc)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "zGU6upwubUm4" + }, + "source": [ + "We can assess how close the predictions are to equality of odds.\n", + "\n", + "`average_odds_error()` computes the (unweighted) average of the absolute values of the true positive rate (TPR) difference and false positive rate (FPR) difference, i.e.:\n", + "\n", + "$$ \\tfrac{1}{2}\\left(|FPR_{D = \\text{unprivileged}} - FPR_{D = \\text{privileged}}| + |TPR_{D = \\text{unprivileged}} - TPR_{D = \\text{privileged}}|\\right) $$" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "id": "1wRHNc1GbUm4", + "outputId": "9165efe3-0e50-4f16-e984-2ffb0ae79f7b", + "colab": { + "base_uri": "https://localhost:8080/" + } + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "0.09356509680536546\n" + ] + } + ], + "source": [ + "lr_aoe = average_odds_error(y_test, y_pred, prot_attr='sex')\n", + "print(lr_aoe)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "vzJDTOXebUm4" + }, + "source": [ + "### Grid Search" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "IuG-VsAnbUm4" + }, + "source": [ + "Choose a base model for the candidate classifiers. Base models should implement a fit method that can take a sample weight as input. For details refer to the docs. " + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": { + "id": "w8zUuKZcbUm4" + }, + "outputs": [], + "source": [ + "estimator = LogisticRegression(solver='liblinear', random_state=1234)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "vScj2bAhbUm4" + }, + "source": [ + "Determine the columns associated with the protected attribute(s). Grid search can handle more than one attribute but it is computationally expensive. A similar method with less computational overhead is exponentiated gradient reduction, detailed at [examples/sklearn/demo_exponentiated_gradient_reduction_sklearn.ipynb](sklearn/demo_exponentiated_gradient_reduction_sklearn.ipynb)." + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": { + "id": "zEthhIr4bUm4" + }, + "outputs": [], + "source": [ + "prot_attr = 'sex_Male'" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "F3cM_dipbUm4" + }, + "source": [ + "Search for the best classifier and observe test accuracy. Other options for `constraints` include \"DemographicParity\", \"TruePositiveRateParity\", \"FalsePositiveRateParity\", and \"ErrorRateParity\"." + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": { + "id": "Jxy3XOwubUm4", + "outputId": "a482a204-9b74-4415-83bb-53c118fea0cf", + "colab": { + "base_uri": "https://localhost:8080/" + } + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "0.8458760227021449\n" + ] + } + ], + "source": [ + "np.random.seed(0) #need for reproducibility\n", + "grid_search_red = GridSearchReduction(prot_attr=prot_attr, \n", + " estimator=estimator, \n", + " constraints=\"EqualizedOdds\",\n", + " grid_size=20,\n", + " drop_prot_attr=False)\n", + "grid_search_red.fit(X_train, y_train)\n", + "gs_acc = grid_search_red.score(X_test, y_test)\n", + "print(gs_acc)\n", + "\n", + "#Check if accuracy is comparable\n", + "assert abs(lr_acc-gs_acc)<0.03" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": { + "id": "73lRvF3CbUm4", + "outputId": "363d6088-d895-49c2-ed23-55685247f70b", + "colab": { + "base_uri": "https://localhost:8080/" + } + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "0.05787745779072595\n" + ] + } + ], + "source": [ + "gs_aoe = average_odds_error(y_test, grid_search_red.predict(X_test), prot_attr='sex')\n", + "print(gs_aoe)\n", + "\n", + "#Check if average odds error improved\n", + "assert gs_aoe