# Bounty Hunting Project Notebook

Author: Francesca Marini

# Set Up

First, we load in some helper files.

In [None]:
# pip installs
!pip install folktables
!pip install dill

In [2]:
# mount google drive
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
# directory structure
%cd drive/MyDrive/CIS\ 523/Bias\ Bounty/ClassBiasBounties/
!pwd
%ls

/content/drive/MyDrive/CIS 523/Bias Bounty/ClassBiasBounties
/content/drive/MyDrive/CIS 523/Bias Bounty/ClassBiasBounties
 19-Project.ipynb       [0m[01;34mdata[0m/                     [01;34m__pycache__[0m/
 acsData.py             [01;34mdontlook[0m/                [01;34m'__pycache__ (1)'[0m/
 bountyHuntData.py      groupID_ghSubmission.py   README.txt
 bountyHuntWrapper.py   groupSettings.py          updater.py
 cscUpdater.py          model.py                  verifier.py


In [53]:
# imports 
import numpy as np
from sklearn import metrics
from sklearn.tree import DecisionTreeClassifier
import pandas as pd
from folktables import ACSDataSource, ACSEmployment, ACSIncome, ACSPublicCoverage, ACSMobility, ACSTravelTime
from sklearn.model_selection import train_test_split
import sys
import warnings
warnings.filterwarnings("ignore")
import dill as pickle
from pprint import pprint
import json

## Model

In [5]:
# Data structure for the decision list
# Simple decision lists have been tested. pointer version has not been. ###

class DecisionlistNode:
    """

        Node of decision lists are objects with the following form

            -----------------        ------------------
            |   predicate   | --0--> | right_child Node|
            -----------------        ------------------
                    |
                    1
                    |
                    v
                  Leaf
    """

    def __init__(self, predicate=None, leaf=None, right_child=None):
        """

        predicate: function from X to {0,1}. If not defined in input to initialization, will be defined as function that
         always returns 1
        leaf: classification function from X to {0,1}. If not defined in the initialization, will be defined as function
         that always returns 1.
        right_child: Node object
        """
        self.predicate = predicate
        self.leaf = leaf
        self.right_child = right_child

        if self.predicate is None:
            self.predicate = lambda x: 1

        if self.leaf is None:
            self.leaf = lambda x: 1

class DecisionList:
    """
    A simple decision list object with standard traversal

            -----------------        ------------------                  ------------------
            |   predicate 1  | --0--> |   predicate 2  | --0--> ... ---> |   predicate n  |  ---> 0
            -----------------        ------------------                  ------------------
                    |						  |                                  |
                    1                         1                                  1
                    |                         |                                  |
                    v                         v                                  v
                  Leaf 1                    Leaf 2                             Leaf n

    """

    def __init__(self, initial_model):
        self.head = DecisionlistNode(leaf=initial_model)
        self.curr_node = self.head
        self.predicates = [self.head.predicate]  # list of predicate functions for ease of access
        self.leaves = [self.head.leaf]  # list of leaf functions for eas of access

    def predict(self, x):
        """
        Predicts the label of x according to traversal of the decision list.
        """
        if self.curr_node.predicate(x) == 1:
            # set output to the leaf function evaluated at x
            out = self.curr_node.leaf(x)
            # reset current node to head
            self.curr_node = self.head
            return out

        elif self.curr_node.right_child is not None:
            self.curr_node = self.curr_node.right_child
            return self.predict(x)
        else:
            # reset current node to head and output 0 with a warning
            print("Warning: reached end of pDL with no predicate succeeding, returning 0.")
            return 0

    def prepend(self, node):
        """
        Prepends a new node to the head of the decision list.

        node we want to                          current
        prepend                               decision list
        ---------                              -------------      -----------------------
        | node | --> node.right_child = None    | self.head | ---> | self.head.right_child| --> ...
        --------                               -------------      -----------------------
            |						               |
            1                                      1
            |                                      |
            v                                      v
          leaf                                self.head.leaf
        """
        node.right_child = self.head
        self.predicates.append(node.predicate)
        self.leaves.append(node.leaf)
        self.head = node
        self.curr_node = self.head

class PointerDecisionListNode(DecisionlistNode):
    """

        Node of pointer decision lists are objects with the following form if catch_node = False

            -----------------        ------------------
            |   predicate   | --0--> | right_child Node|
            -----------------        ------------------
                    |
                    1
                    |
                    v
                  Leaf

        Or, if catch_node = True:

            -----------------        ------------------
            |   predicate   | --0--> | right_child Node|
            -----------------        ------------------
                    |
                    1
                    |
                    v
            -----------------
            | right_main_node |
            -----------------
    """

    def __init__(self, predicate=None, leaf=None, right_child=None, catch_node=False, right_main_node=None,
                 pred_name=None):
        DecisionlistNode.__init__(self, predicate=predicate, leaf=leaf, right_child=right_child)

        self.catch_node = catch_node
        self.right_main_node = right_main_node
        self.predName = pred_name

        if catch_node:
            # assert(right_main_node.right_child is not None)
            self.leaf = None

class PointerDecisionList:
    """
                     -------------------------------------------------------------------------------
                    |                         |                         						   |
                    1						  1                         						   v
                      |                         |                  right_main_node		    right_main_node.right_child
            -----------------        ------------------         ------------------          ---------------           -----------------
            |   predicate 1  | --0--> |   predicate 2  | --0--> |   predicate 3  |  --0-->  | predicate 3 |    ...    |  predicate n  |  ---> 0
            -----------------        ------------------         ------------------          ---------------           ------------------
                catch_node    			 catch_node	                    |                          |                          |
                                                                        1                          1                          1
                                                                        |                          |                          |
                                                                        v                          v                          v
                                                                           Leaf                       Leaf                        Leaf

    """

    def __init__(self, initial_model, all_groups=None):
        """
        initial_model is some fit function e.g. a .fit method output by scikit.learn,
        which takes as input a dataframe x and outputs labels y for those values of x.

        If you already know the groups we'll be passing in, they can be specified with all_groups
        """
        self.all_groups = None
        if all_groups is None:
            all_groups = []
        self.head = PointerDecisionListNode(leaf=initial_model, pred_name='Total')
        self.curr_node = self.head
        self.predicates = [self.head.predicate]  # list of predicate functions for ease of access
        self.pred_names = ['Total']
        self.leaves = [self.head.leaf]  # list of leaf functions for ease of access.

        # keeping track of the group errors so far here (including groups that haven't been introduced yet, which
        # wouldn't happen IRL!) for computational efficiency.
        # only relevant if we already know the groups we're computing things on.
        if all_groups:
            n = len(all_groups)
            self.test_errors = np.empty(shape=(n, n))
            self.test_errors[:] = np.NaN  # filling with NaNs to avoid confusion
            self.train_errors = np.empty(shape=(n, n))
            self.train_errors[:] = np.NaN  # filling with NaNs to avoid confusion
        else:
            self.test_errors = []
            self.train_errors = []

        # keeping track of the number of rounds so far
        self.num_rounds = 0

        # keeping track of the final node belonging to each update so that we can point to those
        self.update_nodes = [self.head]
        self.track_rejects = [1]
        self.update_node_indices_tracking_rejects = [0]

    def init_groups(self):
        n = len(self.all_groups)
        self.test_errors = np.empty(shape=(n, n))
        self.test_errors[:] = np.NaN  # filling with NaNs to avoid confusion
        self.train_errors = np.empty(shape=(n, n))
        self.train_errors[:] = np.NaN  # filling with NaNs to avoid confusion

    def predict(self, x):
        """
        Predicts the label of x according to traversal of the pointer decision list.
        Note that this is NOT a vectorized function as it is currently written: only single values of
        x may be passed in at a time. This is very much not ideal for a pandas dataframe.
        """
        # if catch_node evaluates to true, move to the right child of the main node and continue traversal
        if self.curr_node.catch_node is True:
            if self.curr_node.predicate(x) == 1:
                self.curr_node = self.curr_node.right_main_node
                return self.predict(x)
            else:
                self.curr_node = self.curr_node.right_child
                return self.predict(x)
        # otherwise, traverse as usual
        else:
            if self.curr_node.predicate(x) == 1:
                # set output to the leaf function evaluated at x
                # because we are passing in single rows of a larger dataframe at a time (which is not good in long run)
                # have to reshape x to be something that predict will accept.
                out = self.curr_node.leaf(np.array(x).reshape(1, -1))
                # reset current node to head
                self.curr_node = self.head
                # output
                return out[0]
            elif self.curr_node.right_child is not None:
                self.curr_node = self.curr_node.right_child
                return self.predict(x)
            else:
                # reset current node to head and output 0 with a warning
                print("Warning: reached end of pDL with no predicate succeeding, returning 0.")
                return 0

    def prepend(self, node):
        """
        Prepends a new node to the head of the decision list
        """
        if node.catch_node is True:
            # if prepending to another catch_node
            if self.head.catch_node is True:
                node.right_child = self.head
                self.head = node
                self.curr_node = self.head
            else:
                node.right_child = self.head
                self.head = node
                self.curr_node = self.head
        else:
            node.right_child = self.head
            self.predicates.append(node.predicate)
            self.pred_names.append(node.predName)
            self.leaves.append(node.leaf)
            self.head = node
            self.curr_node = self.head

    def pop(self):
        """
        Removes the head node from the decision list
        """
        self.head = self.head.right_child
        self.curr_node = self.head

## Verifier

In [6]:
def verify(curr_model, test_x, test_y, h, g, alpha=0.001):
    """
    Updates the curr_model model object to incorporate the (h,g).

    Inputs:
    curr_model: model object that is to be updated
    holdout_data: data to test the proposed new model on
    h_t: new model
    g_t: function from X -> {0,1} which returns 1 if x is in identified group and 0 else.

    Return: None
    """
    # pull the x and y values that belong to g
    indices = test_x.apply(g, axis=1) == 1
    xs = test_x[indices]
    ys = test_y[indices]

    # get predicted ys from current model and proposed h
    curr_model_preds = xs.apply(curr_model.predict, axis=1)
    h_preds = h(xs)

    # measure the error of current model and proposed h
    curr_model_error = metrics.zero_one_loss(ys, curr_model_preds)
    h_error = metrics.zero_one_loss(ys, h_preds)

    # determine if (g,h) should be accepted or not
    group_weight = sum(indices) / float(len(test_x))
    improvement = curr_model_error - h_error

    if group_weight * improvement >= alpha:
        return True

    return False

def is_proposed_group_good(curr_model, test_x, test_y, h, g, train_x, train_y):
    """
    Checks that the group error of g on h isn't worse than g on f.
    Doesn't worry about weight of group
    Inputs:
    curr_model: model object that is to be updated
    holdout_data: data to test the proposed new model on
    h_t: new model
    g_t: function from X -> {0,1} which returns 1 if x is in identified group and 0 else.

    Return: None
    """
    # pull the x and y values that belong to g
    indices = test_x.apply(g, axis=1) == 1
    xs = test_x[indices]
    ys = test_y[indices]

    # get predicted ys from current model and proposed h
    curr_model_preds = xs.apply(curr_model.predict, axis=1)
    h_preds = h(xs)

    # measure the error of current model and proposed h
    curr_model_error = metrics.zero_one_loss(ys, curr_model_preds)
    h_error = metrics.zero_one_loss(ys, h_preds)

    print("Error of current model on proposed group: %s" % curr_model_error)
    print("Error of h trained on proposed group: %s" % h_error)
    print("Group size in test set: %s" % len(xs))
    print("Group weight in test set: %s" % (len(xs)/len(test_x)))

    ## REMOVE BEFORE DEPLOYMENT
    t_indices = train_x.apply(g, axis=1) == 1
    t_xs = train_x[t_indices]
    t_ys = train_y[t_indices]

    # get predicted ys from current model and proposed h
    t_curr_model_preds = t_xs.apply(curr_model.predict, axis=1)
    t_h_preds = h(t_xs)

    # measure the error of current model and proposed h
    t_curr_model_error = metrics.zero_one_loss(t_ys, t_curr_model_preds)
    t_h_error = metrics.zero_one_loss(t_ys, t_h_preds)

    print("Training Error of current model on proposed group: %s" % t_curr_model_error)
    print("Training Error of h trained on proposed group: %s" % t_h_error)
    print("Group size in training set: %s" % len(t_xs))
    print("Group weight in training set: %s" % (len(t_xs)/len(train_x)))

    if h_error >= curr_model_error:
        return False

    else:
        return True

def is_proposed_group_good_csc(curr_model, test_x, test_y, h, g):
    """
    Checks that the group error of g on h isn't worse than g on f.
    Doesn't worry about weight of group
    Inputs:
    curr_model: model object that is to be updated
    holdout_data: data to test the proposed new model on
    h_t: new model
    g_t: function from X -> {0,1} which returns 1 if x is in identified group and 0 else.

    Return: None
    """
    # pull the x and y values that belong to g
    indices = test_x.apply(g, axis=1) == 1
    xs = test_x[indices]
    ys = test_y[indices]

    # get predicted ys from current model and proposed h
    curr_model_preds = xs.apply(curr_model.predict, axis=1)

    # some dumb reshaping to mesh with how h takes inputs
    def _h(x):
        _x = np.array(x).reshape(1, -1)
        return h(_x)[0]

    h_preds = xs.apply(_h, axis=1)

    # measure the error of current model and proposed h
    curr_model_error = metrics.zero_one_loss(ys, curr_model_preds)
    h_error = metrics.zero_one_loss(ys, h_preds)

    print("Error of current model on proposed group: %s" % curr_model_error)
    print("Error of h trained on proposed group: %s" % h_error)

    if h_error >= curr_model_error:
        return False

    else:
        return True

def check_group_sizes(test_x, group):
    # Returns True if the group has more than 0 elements in test_x
    indices = test_x.apply(group, axis=1) == 1
    if sum(indices) >= 1:
        return True
    else:
        return False

## CSC Updater

In [7]:
# Global parameter that determines depth of the decision tree trained
dt_depth = 7

# noinspection SpellCheckingInspection
def training_r0_r1(train_x, true_ys, pred_ys):
    """
    The cost sensitive classification approach to the bias bounties problems takes in the training data and the
    predictions of the current model and builds two models, one for the cost of predicting 0 for a datapoint
    and one for the cost of predicting 1. This function generates these two functions, r0 and r1. It also
    outputs some additional information for debugging and visualization purposes.

    Inputs:
    train_x: the training data
    train_y: the labels for the training data
    pred_y: the current model's predictions on the training data

    """
    # training r0
    # initialize cost vector to 0s
    cost_pred0 = np.zeros(len(true_ys))
    # find the locations where the current model predicts 1
    curr_predicts_1 = np.where(pred_ys == 1)
    # label the points where the current model and the true labels are both 1 w cost of 1
    cost_pred0[np.intersect1d(curr_predicts_1, np.where(true_ys == 1))] = 1
    # label the points where current model predicts 1 and true label is 0 w cost of -1
    cost_pred0[np.intersect1d(curr_predicts_1, np.where(true_ys == 0))] = -1
    # learn a regression model
    r0 = DecisionTreeClassifier(max_depth=dt_depth, random_state=0)
    r0.fit(train_x, cost_pred0)

    # training r1. Same thing but flipped
    cost_pred1 = np.zeros(len(true_ys))
    curr_predicts_0 = np.where(pred_ys == 0)
    cost_pred1[np.intersect1d(curr_predicts_0, np.where(true_ys == 0))] = 1
    cost_pred1[np.intersect1d(curr_predicts_0, np.where(true_ys == 1))] = -1
    r1 = DecisionTreeClassifier(max_depth=dt_depth, random_state=0)
    r1.fit(train_x, cost_pred1)

    return [r0, r1, np.array([true_ys, pred_ys, cost_pred0, cost_pred1])]

def _g(r0, r1):
    """
    Next in the CSC approach, we build a function g which decides on input x whether or not it should be included in the
    new group g that our algorithmic bounty hunter will be updating the model with or if it should send that point
    to the previous model.It does this by checking what the values of our two cost functions are at input x, and
    returns 1 if these costs are both negative.Otherwise, it returns 1. This function _g constructs such a g given
    the two cost functions r0 and r1, and outputs the g.
    """

    def g(x):
        x = np.array(x).reshape(1, -1)
        cost0 = r0.predict(x)
        cost1 = r1.predict(x)
        if cost0 < 0 or cost1 < 0:
            return 1
        else:
            return 0

    return g

# noinspection SpellCheckingInspection
def _h(r0, r1):
    """
    Next, we build a model h for the elements in g. We return a 1 or True for them if the cost of predicting True is
    less than the cost of predicting 1, and otherwise return 0 or False. h's return values are in brackets to match
    the way that scikit learn usually is done over batches of data points (which we cannot do here unless we are sure
    that all of the datapoints will end up with the same path through our pointer decision list.)
    """

    def h(x):
        cost0 = r0.predict(x)
        cost1 = r1.predict(x)
        if cost0 < cost1:
            return [False]  # have to stick these in brackets to match syntax of sklearn prediction functions
        else:
            return [True]

    return h

#################################################################################################
# The rest of the functions in this file are for generating the actual updates to our model.
# They are slight variations on the functions in Updater.py, because while there we were running
# experiments on groups that were pre-defined, here we are generating the groups on the fly.
##################################################################################################

# noinspection SpellCheckingInspection
def measure_group_errors(model, X, y):
    """
    Helper function that measures the group errors of groups defined in model over test data X with true
    labels y

    Inputs:
    model: DecisionList or PointerDecisionList object
    X: n x m dataframe of test data
    y: dataframe of n true labels (or optimal predictions) of points in X
    """
    indices = [X.apply(g, axis=1) == 1 for g in model.predicates]
    xs = [X[i] for i in indices]
    ys = [y[i] for i in indices]
    group_errors = []
    for i in range(len(model.predicates)):
        pred_ys = xs[i].apply(model.predict, axis=1)
        group_errors.append(metrics.zero_one_loss(np.array(ys[i]), np.array(pred_ys)))
    return group_errors

# noinspection SpellCheckingInspection
def measure_group_error(model, group, X, y):
    """
    Function to measure group errors of a specific group

    NOTE THIS WILL BREAK IF YOU PASS IN AN EMPTY GROUP
    """

    indices = X.apply(group, axis=1) == 1
    xs = X[indices]
    ys = y[indices]
    pred_ys = xs.apply(model.predict, axis=1).to_numpy()
    group_errors = metrics.zero_one_loss(ys, pred_ys)

    return group_errors


# noinspection SpellCheckingInspection
def all_group_errors(curr_model, group_pred, X, y):  # gets errors for each model for a given group
    # go through all the updates and get each group error for that group
    true_head = curr_model.head
    errs = []
    for i in range(len(curr_model.update_nodes)):
        curr_model.head = curr_model.update_nodes[i]
        curr_model.curr_node = curr_model.head
        errs.append(measure_group_error(curr_model, group_pred, X, y))
    curr_model.head = true_head
    curr_model.curr_node = true_head
    return errs

def update_errors(model, group, train_x, train_y, test_x, test_y):
    # measure the group's train and test error over all previous versions of the PDF
    # then append as column to the model's train and test error arrays

    # first, store where the current model is at
    true_head = model.head
    test_errs, train_errs = [], []
    # next, iterate through all previous models by changing the head node
    for i in range(len(model.update_nodes)):
        model.head = model.update_nodes[i]
        model.curr_node = model.head
        # calculate the train and test errors
        train_errs.append(measure_group_error(model, group, train_x, train_y))
        test_errs.append(measure_group_error(model, group, test_x, test_y))
    # reset the current node and head to the newest model
    model.head = true_head
    model.curr_node = true_head

    # now, append the train and test errors to their trackers.
    [model.test_errors[i].append(test_errs[i]) for i in range(len(model.test_errors))]
    [model.train_errors[i].append(train_errs[i]) for i in range(len(model.train_errors))]

    return None

def find_next_problem_node(curr_model, new_errors):
    """
    Finds a node in the PDL which, after an update, now has worse group error than it had previously.

    curr_model: the current PDL
    new_errors: the errors introduced on each group by the newest update

    Return: [returnIndex, return_model] where returnIndex is the index of the node that had worse error and
    return_model is the model that performed best on that node.
    """
    # initialize the returns to -1 so we can easily tell if they weren't updated
    return_index = -1
    return_model = -1

    for node_index in range(len(curr_model.update_nodes)):
        # pull out the column of errors that corresponds to this nodes error at each round of updates
        nodes_errors = [curr_model.test_errors[i][node_index] for i in range(len(curr_model.test_errors))]
        # find the round that minimizes the error
        indices_min_round = np.nanargmin(nodes_errors)
        # grab the value of the minimal error
        min_val = curr_model.test_errors[indices_min_round][node_index]
        # check if the minimal value is less than the new model's error on that group
        if min_val < new_errors[node_index]:
            # if the min error of a previously found group is better, report the node and
            # the model that was best for it, and break out of the loop.
            return_index = node_index
            return_model = indices_min_round
            break

    return [return_index, return_model]

# noinspection SpellCheckingInspection
def iterative_update(curr_model, h_t, g_t, train_X, train_y, test_X, test_y, group_name):
    """
    Updates the curr_model to incorporate (g_t, h_t) in a way that preserves group error
    monotonicity over the sample data X with labels y

    Inputs:
    curr_model: PointerDecisionList object that is to be updated
    h_t: new model that performs better than curr_model on points for which g_t returns 1
    g_t: function from X -> {0,1} which returns 1 if x is in identified group and 0 else.

    Return: None
    """
    # add a round to the round tracker
    curr_model.num_rounds += 1

    # prepend the node
    new_node = PointerDecisionListNode(predicate=g_t, leaf=h_t, pred_name=group_name)
    curr_model.prepend(new_node)

    # add a column to the train and test errors to track how the new group did at all previous rounds
    print("updating errors")
    update_errors(curr_model, g_t, train_X, train_y, test_X, test_y)

    # measure new group errors and compare to old
    print("getting new errors")
    new_errors = measure_group_errors(curr_model, test_X, test_y)

    # recursively check for new errors

    [problem_node_index, problem_node_model_index] = find_next_problem_node(curr_model, new_errors)
    problem_node_tracking = []

    while True:
        print("prob node", problem_node_index)
        print("new errs", new_errors)

        if problem_node_model_index == -1:
            break

        else:
            # add node to tracker so we can visualize PDL
            problem_node_tracking.append([curr_model.pred_names[problem_node_index], problem_node_model_index])
            # build a node that points to that model
            new_node = PointerDecisionListNode(predicate=curr_model.predicates[problem_node_index], catch_node=True,
                                               right_main_node=curr_model.update_nodes[problem_node_model_index])
            # prepend that node to the model
            curr_model.prepend(new_node)
            # the group of new model will change w new node appended so check those
            new_errors = measure_group_errors(curr_model, test_X, test_y)
            # check for further/new problem nodes
            [problem_node_index, problem_node_model_index] = find_next_problem_node(curr_model, new_errors)

    if new_errors is None:
        curr_model.pop()  # remove the new model from the head of the pDL
        return "Could not calculate all group errors and cannot update"

    # now that all of the updates have happened, add the final node of the update to the model
    curr_model.update_nodes.append(new_node)

    curr_model.train_errors.append(measure_group_errors(curr_model, train_X, train_y))
    curr_model.test_errors.append(measure_group_errors(curr_model, test_X, test_y))

    return [curr_model.train_errors, curr_model.test_errors]

## Bounty Hunt Wrapper

In [8]:
def build_model(x, y, group_function, dt_depth):
    print("building h")
    # learn the indices first, since this is an inefficient operation
    indices = x.apply(group_function, axis=1) == 1

    # then pull the particular rows from the dataframe
    training_xs = x[indices]
    training_ys = y[indices]

    dt = DecisionTreeClassifier(max_depth=dt_depth, random_state=0)  # setting random state for replicability
    dt.fit(training_xs.values, training_ys) # added .values to get warning to not print out
    print("finished building h")
    return dt.predict

def build_initial_pdl(initial_model, train_x, train_y, validation_x, validation_y):
    f = PointerDecisionList(initial_model.predict) #model.PointerDecisionList(initial_model.predict)
    # manually stick in the train and test errors because I'm dumb and have codependencies in the files
    f.test_errors.append(measure_group_errors(f, validation_x, validation_y)) #cscUpdater.measure_group_errors(f, validation_x, validation_y))
    f.train_errors.append(measure_group_errors(f, train_x, train_y)) #cscUpdater.measure_group_errors(f, train_x, train_y))
    return f

def verify_size(x, group):
    # helper function that checks that the discovered group isn't too small to run on
    indices = x.apply(group, axis=1) == 1
    xs = x[indices]
    if len(xs) == 0:
        return False
    else:
        return True

def run_checks(f, validation_x, validation_y, g, h, train_x, train_y):
    size_check = verify_size(validation_x, g)
    if not size_check:  # Remove before deployment:
        print("Group has 0 weight in test set")
        indices = train_x.apply(g, axis=1) == 1
        xs = train_x[indices]
        ys = train_y[indices]

        # get predicted ys from current model and proposed h
        curr_model_preds = xs.apply(f.predict, axis=1)
        h_preds = h(xs)

        # measure the error of current model and proposed h
        curr_model_error = metrics.zero_one_loss(ys, curr_model_preds)
        h_error = metrics.zero_one_loss(ys, h_preds)

        print("Training Error of current model on proposed group: %s" % curr_model_error)
        print("Training Error of h trained on proposed group: %s" % h_error)
        print("Group size in training set: %s" % len(xs))
        print("Group weight in training set: %s" % (len(xs)/len(train_x)))
    if size_check:
        improvement_check = is_proposed_group_good(f, validation_x, validation_y, h, g, #verifier.is_proposed_group_good(f, validation_x, validation_y, h, g,
                                                            train_x, train_y)
        if improvement_check:
            print("Passed checks.")
            return True
        else:
            print("Failed improvement check.")
            return False
    else:
        print("Failed group size check.")
        return False

def measure_group_error(model, group, X, y):
    """
    Function to measure group errors of a specific group

    NOTE THIS WILL BREAK IF YOU PASS IN AN EMPTY GROUP
    """

    indices = X.apply(group, axis=1) == 1
    xs = X[indices]
    ys = y[indices]
    pred_ys = xs.apply(model.predict, axis=1).to_numpy()
    group_errors = metrics.zero_one_loss(ys, pred_ys)

    return group_errors

def run_updates(f, g, h, train_x, train_y, validation_x, validation_y, group_name="g"):
    iterative_update(f, h, g, train_x, train_y, validation_x, validation_y, group_name) #cscUpdater.iterative_update(f, h, g, train_x, train_y, validation_x, validation_y, group_name)

Next, we load in the data. You should use `train_x` and `train_y` to train your models. The second set of data (`validation_x` and `validation_y`) is for testing your models, to ensure that you aren't overfitting. It is also what will be passed to the updater in order to determine if a proposed update should be accepted and if repairs are needed. Since you have access to this data, you could overfit to it and get a bunch of updates accepted. However, a) we'll be able to tell you did this and b) your updates will fail on the holdout set that only we have access to, so doing this is not in your best interest.

## Group Settings

In [9]:
group_id = 19
acs_task = 'income'
acs_states = ['CA']

## Bounty Hunt Data

In [10]:
sys.path.append("..")
sys.path.append(".")

# I did not edit any of the logic of the code, and I did not really look at what it was doing. 
# for the sake of making everything work in Colab though, I did slightly edit the function headers to take in the group settings parameters passed in

#### DO NOT LOOK AT THIS FILE PLEASE :) FOR RESEARCHY REASONS WE NEED TO PRETEND YOU DON'T HAVE ACCESS TO IT BUT
#### IRA COULDN'T FIGURE OUT HOW TO ENCRYPT IT IN A WAY THAT WOULD WORK ON EVERYONE'S SYSTEMS

def get_data(acs_task, acs_states):
    acs_task = acs_task #groupSettings.acs_task
    acs_states = acs_states #groupSettings.acs_states
    test_size = 0.3
    acs_year = 2018
    acs_horizon = '1-Year'
    acs_survey = 'person'
    row_start = 0
    row_end = -1
    col_start = 0
    col_end = -1
    data_source = ACSDataSource(survey_year=acs_year, horizon=acs_horizon, survey=acs_survey)
    columns, features, label, group = [],[],[],[]
    # this pulls in the raw data
    acs_data = data_source.get_data(states=acs_states, download=True)
    # columns of the feature vector
    if acs_task == 'employment':
        # label is True if adult is employed

        # columns of the feature vector
        columns = [
            'AGEP',
            'SCHL',
            'MAR',
            'RELP',
            'DIS',
            'ESP',
            'CIT',
            'MIG',
            'MIL',
            'ANC',
            'NATIVITY',
            'DEAR',
            'DEYE',
            'DREM',
            'SEX',
            'RAC1P',
        ]
        features, label, group = ACSEmployment.df_to_numpy(acs_data)
    elif acs_task == 'income':
        # label is True if US working adults’ yearly income is above $50,000

        # columns of the feature vector
        columns = [
            'AGEP',
            'COW',
            'SCHL',
            'MAR',
            'OCCP',
            'POBP',
            'RELP',
            'WKHP',
            'SEX',
            'RAC1P',
        ]
        features, label, group = ACSIncome.df_to_numpy(acs_data)
    elif acs_task == 'public_coverage':
        # label True if low-income individual, not eligible for Medicare, has coverage from public health insurance.

        # coluns of the feature vector
        columns = [
            'AGEP',
            'SCHL',
            'MAR',
            'SEX',
            'DIS',
            'ESP',
            'CIT',
            'MIG',
            'MIL',
            'ANC',
            'NATIVITY',
            'DEAR',
            'DEYE',
            'DREM',
            'PINCP',
            'ESR',
            'ST',
            'FER',
            'RAC1P',
        ]
        features, label, group = ACSPublicCoverage.df_to_numpy(acs_data)
    elif acs_task == 'mobility':

        columns = [
            'AGEP',
            'SCHL',
            'MAR',
            'SEX',
            'DIS',
            'ESP',
            'CIT',
            'MIL',
            'ANC',
            'NATIVITY',
            'RELP',
            'DEAR',
            'DEYE',
            'DREM',
            'RAC1P',
            'GCL',
            'COW',
            'ESR',
            'WKHP',
            'JWMNP',
            'PINCP',
        ]
        # label True if a young adult moved addresses in the last year.
        features, label, group = ACSMobility.df_to_numpy(acs_data)
    elif acs_task == 'travel_time':

        columns = [
            'AGEP',
            'SCHL',
            'MAR',
            'SEX',
            'DIS',
            'ESP',
            'MIG',
            'RELP',
            'RAC1P',
            'PUMA',
            'ST',
            'CIT',
            'OCCP',
            'JWTR',
            'POWPUMA',
            'POVPIP',
        ]
        # label True if a working adult has a travel time to work of greater than 20 minutes
        features, label, group = ACSTravelTime.df_to_numpy(acs_data)
    else:
        print("Invalid task")

    features = features[row_start:row_end, col_start:col_end]
    label = label[row_start:row_end]
    group = group[row_start:row_end]
    div = 2
    X_train, all_test_X, y_train, all_test_y, group_train, group_test = train_test_split(features, label, group,
                                                                                 test_size=test_size, random_state=10)
    X_train = np.hstack((X_train, group_train[:, np.newaxis]))
    all_test_X = np.hstack((all_test_X, group_test[:, np.newaxis]))

    # making the training data into an actual pandas dataframe so that we can actually read it and see what things mean.
    X_train = pd.DataFrame(X_train, columns=columns)
    all_test_X = pd.DataFrame(all_test_X, columns=columns)

    validation_x = all_test_X.iloc[:len(all_test_X)//div]
    validation_y = all_test_y[:len(all_test_y)//div]

    return [X_train, y_train, validation_x, validation_y]

In [11]:
[train_x, train_y, validation_x, validation_y] = get_data(acs_task=acs_task, acs_states=acs_states) #bountyHuntData.get_data()

# Build Decision Stump Initial Model

The model that you'll be building off of is a decision stump, i.e. a very stupid decision list with only one node. **Warning: do not rerun the next code block unless you want to completely restart building your PDL, as it will re-initialize it to just the decision stump!**

In [12]:
initial_model = DecisionTreeClassifier(max_depth = 1, random_state=0)

In [13]:
initial_model.fit(train_x.values, train_y) # added .values to train_x and this fixed a warning printout I was getting

DecisionTreeClassifier(max_depth=1, random_state=0)

In [14]:
f = build_initial_pdl(initial_model, train_x, train_y, validation_x, validation_y) #bountyHuntWrapper.build_initial_pdl(initial_model, train_x, train_y, validation_x, validation_y)

# Bounty Hunting

Here's where the bulk of the work you'll be doing will live. Your job is to generate groups g such that there is some h that does better than the current model f on that group. Here, we generate an example group function, which identifies African American individuals.

## Defining Groups

You might also imaging making a group function that tries to learn what regions the current algorithm performs poorly on in an adaptive way, instead of just guessing ad-hoc that it will do poorly on a particular subgroup. In order to generate such a g, you will need to generate a constructor that takes as input a current model and the training data, and outputs a function g. A template for doing this is provided below. The example version returns a very silly function g which looks at the predictions the current model makes, and returns a group function where the group it has learned is all the points that the PDL labels as a 1. It completely ignores the true labels (train_y), so is probably not a very good group function.

### Adaptive Groups

In [None]:
'''
def g_(f, train_x, train_y):
    # f is the current PDL
    preds = train_x.apply(f.predict, axis=1)
    merged_train = train_x.copy()
    merged_train['train_y'] = train_y
    merged_train['preds'] = preds
    # get all the values that we currently mislabel - could use this to find groups that have high error and try those
    mistakes = merged_train[merged_train['preds'] != merged_train['train_y']]
    xs = mistakes.drop(columns=['train_y', 'preds'])
    ys = mistakes['train_y']
    #xs = train_x[preds == 1] -- old lines
    #ys = train_y[preds == 1] -- old lines
    dt = DecisionTreeClassifier(max_depth = 10, random_state=0)
    dt.fit(xs, ys)
    def g(x):
        # g should take as input a SINGLE x and return 0 or 1 for it.
        # if we call dt.predict on x it will break because the dimensions of x are wrong, so we have to reshape it and reshape the output.
        # this is not particularly efficient, so if you have better ways of doing this go for it. :)
        y = dt.predict(np.array(x).reshape(1, -1))
        return y[0]
    return g
'''

# if you wanted to build a particular g using the above, you could use the following line.
#g = g_(f, train_x, train_y)

In [15]:
# possible group definitions

# class of worker
def for_profit_worker(x):
  if x['COW'] == 1:
      return 1
  else:
      return 0

def non_profit_worker(x):
  if x['COW'] == 2:
      return 1
  else:
      return 0

def local_govt_worker(x):
  if x['COW'] == 3:
      return 1
  else:
      return 0

def self_employed_worker(x):
  if x['COW'] == 6:
      return 1
  else:
      return 0

def cow_minorities(x):
  if x['COW'] == 4:
      return 1
  elif x['COW'] == 5:
      return 1
  elif x['COW'] == 7:
      return 1
  elif x['COW'] == 8:
      return 1
  else:
      return 0

# marital status
def married(x):
  if x['MAR'] == 1:
      return 1
  else:
      return 0

def divorced(x):
  if x['MAR'] == 3:
      return 1
  else:
      return 0

def mar_minorities(x):
  if x['MAR'] == 2:
      return 1
  elif x['MAR'] == 4:
      return 1
  else:
      return 0

def single(x):
  if x['MAR'] == 5:
      return 1
  else:
      return 0

# sex
def male(x):
  if x['SEX'] == 1:
      return 1
  else:
      return 0

def female(x):
  if x['SEX'] == 2:
      return 1
  else:
      return 0

# race
def white(x):
  if x['RAC1P'] == 1:
      return 1
  else:
      return 0

def black(x):
  if x['RAC1P'] == 2:
      return 1
  else:
      return 0

def other_race(x):
  if x['RAC1P'] == 9:
      return 1
  else:
      return 0

def asian(x):
  if x['RAC1P'] == 6:
      return 1
  else:
      return 0

def multiple_race(x):
  if x['RAC1P'] == 8:
      return 1
  else:
      return 0

def native_american(x):
  if x['RAC1P'] == 3:
      return 1
  elif x['RAC1P'] == 4:
      return 1
  elif x['RAC1P'] == 5:
      return 1
  elif x['RAC1P'] == 7:
      return 1
  else:
      return 0

# age groups
def twenties(x):
  if x['AGEP'] < 30:
      return 1
  else:
      return 0

def thirties(x):
  if x['AGEP'] < 40 and x['AGEP'] >= 30:
      return 1
  else:
      return 0

def forties(x):
  if x['AGEP'] < 50 and x['AGEP'] >= 40:
      return 1
  else:
      return 0

def fifties(x):
  if x['AGEP'] < 60 and x['AGEP'] >= 50:
      return 1
  else:
      return 0

def sixties(x):
  if x['AGEP'] < 70 and x['AGEP'] >= 60:
      return 1
  else:
      return 0

def elderly(x):
  if x['AGEP'] >= 70:
      return 1
  else:
      return 0

# education level
def no_school(x):
  if x['SCHL'] == 1:
    return 1
  else:
    return 0

def some_school(x):
  if x['SCHL'] > 1 and x['SCHL'] <= 15:
    return 1
  else:
    return 0

def high_school_grad(x):
  if x['SCHL'] > 15 and x['SCHL'] <= 19:
    return 1
  else:
    return 0

def assoc_degree(x):
  if x['SCHL'] == 20:
    return 1
  else:
    return 0

def assoc_degree(x):
  if x['SCHL'] == 20:
    return 1
  else:
    return 0

def bachelor_degree(x):
  if x['SCHL'] == 21:
    return 1
  else:
    return 0

def advanced_degree(x):
  if x['SCHL'] > 21:
    return 1
  else:
    return 0

# work hours
def part_time(x):
  if x['WKHP'] < 30:
      return 1
  else:
      return 0

def full_time(x):
  if x['WKHP'] < 60 and x['WKHP'] >= 30:
      return 1
  else:
      return 0

def over_time(x):
  if x['WKHP'] >= 60:
      return 1
  else:
      return 0

# occupation
def MGR(x):
  if x['OCCP'] <= 440:
    return 1
  else:
    return 0

def BUS(x):
  if x['OCCP'] >= 500 and x['OCCP'] <= 750:
    return 1
  else:
    return 0

def FIN(x):
  if x['OCCP'] >= 800 and x['OCCP'] <= 960:
    return 1
  else:
    return 0 

def CMM(x):
  if x['OCCP'] >= 1005 and x['OCCP'] <= 1240:
    return 1
  else:
    return 0

def ENG(x):
  if x['OCCP'] >= 1305 and x['OCCP'] <= 1560:
    return 1
  else:
    return 0

def SCI(x):
  if x['OCCP'] >= 1600 and x['OCCP'] <= 1980:
    return 1
  else:
    return 0

def CMS(x):
  if x['OCCP'] >= 2001 and x['OCCP'] <= 2060:
    return 1
  else:
    return 0

def LGL(x):
  if x['OCCP'] >= 2105 and x['OCCP'] <= 2180:
    return 1
  else:
    return 0

def EDU(x):
  if x['OCCP'] >= 2205 and x['OCCP'] <= 2555:
    return 1
  else:
    return 0

def ENT(x):
  if x['OCCP'] >= 2600 and x['OCCP'] <= 2920:
    return 1
  else:
    return 0

def MED(x):
  if x['OCCP'] >= 3000 and x['OCCP'] <= 3550:
    return 1
  else:
    return 0

def HLS(x):
  if x['OCCP'] >= 3601 and x['OCCP'] <= 3655:
    return 1
  else:
    return 0

def PRT(x):
  if x['OCCP'] >= 3700 and x['OCCP'] <= 3960:
    return 1
  else:
    return 0

def EAT(x):
  if x['OCCP'] >= 4000 and x['OCCP'] <= 4160:
    return 1
  else:
    return 0

def CLN(x):
  if x['OCCP'] >= 4200 and x['OCCP'] <= 4255:
    return 1
  else:
    return 0

def PRS(x):
  if x['OCCP'] >= 4330 and x['OCCP'] <= 4655:
    return 1
  else:
    return 0

def SAL(x):
  if x['OCCP'] >= 4700 and x['OCCP'] <= 4965:
    return 1
  else:
    return 0

def OFF(x):
  if x['OCCP'] >= 5000 and x['OCCP'] <= 5940:
    return 1
  else:
    return 0

def FFF(x):
  if x['OCCP'] >= 6005 and x['OCCP'] <= 6130:
    return 1
  else:
    return 0

def CON(x):
  if x['OCCP'] >= 6200 and x['OCCP'] <= 6765:
    return 1
  else:
    return 0

def EXT(x):
  if x['OCCP'] >= 6800 and x['OCCP'] <= 6950:
    return 1
  else:
    return 0

def RPR(x):
  if x['OCCP'] >= 7000 and x['OCCP'] <= 7640:
    return 1
  else:
    return 0

def PRD(x):
  if x['OCCP'] >= 7700 and x['OCCP'] <= 8990:
    return 1
  else:
    return 0

def TRN(x):
  if x['OCCP'] >= 9005 and x['OCCP'] <= 9760:
    return 1
  else:
    return 0

def MIL(x):
  if x['OCCP'] >= 9800 and x['OCCP'] <= 9830:
    return 1
  else:
    return 0

# place of birth continents
def us_born(x):
  if x['POBP'] < 100:
    return 1
  else: 
    return 0

def european(x):
  if x['POBP'] < 200 and x['POBP'] >= 100:
    return 1
  else: 
    return 0

def asian(x):
  if x['POBP'] < 300 and x['POBP'] >= 200:
    return 1
  else: 
    return 0

def american(x):
  if x['POBP'] < 400 and x['POBP'] >= 300:
    return 1
  else: 
    return 0

def african(x):
  if x['POBP'] < 500 and x['POBP'] >= 400:
    return 1
  else: 
    return 0

def oceanian(x):
  if x['POBP'] < 600 and x['POBP'] >= 500:
    return 1
  else: 
    return 0

In [16]:
group_functions = [(for_profit_worker, 'for profit worker'), 
                   (non_profit_worker, 'non profit worker'), 
                   (local_govt_worker, 'local govt worker'), 
                   (self_employed_worker, 'self employed worker'), 
                   (cow_minorities, 'cow minorities'), 
                   (married, 'married'), 
                   (divorced, 'divorced'), 
                   (mar_minorities, 'mar minorities'), 
                   (single, 'single'), 
                   (male, 'male'), 
                   (female, 'female'), 
                   (white, 'white'), 
                   (black, 'black'), 
                   (other_race, 'other race'),
                   (asian, 'asian'),
                   (multiple_race, 'multiple race'), 
                   (native_american, 'native american')]

In [17]:
group_functions2 = [
          (twenties, 'twenties'),
          (thirties, 'thirties'),
          (forties, 'forties'),
          (fifties, 'fifties'),
          (sixties, 'sixties'),
          (elderly, 'elderly'),
          (part_time, 'part time'),
          (full_time, 'full time'),
          (over_time, 'over time'),
          (no_school, 'no school'),
          (some_school, 'some school'),
          (high_school_grad, 'high school grad'),
          (assoc_degree, 'associates degree'),
          (bachelor_degree, 'bachelors degree'),
          (advanced_degree, 'advanced degree')          
]

In [18]:
occupation_functions = [
            (MGR, 'MGR'),
            (BUS, 'BUS'),
            (FIN, 'FIN'),
            (CMM, 'CMM'),
            (ENG, 'ENG'),
            (SCI, 'SCI'),
            (CMS, 'CMS'),
            (LGL, 'LGL'),
            (EDU, 'EDU'),
            (ENT, 'ENT'),
            (MED, 'MED'),
            (HLS, 'HLS'),
            (PRT, 'PRT'),
            (EAT, 'EAT'),
            (CLN, 'CLN'),
            (PRS, 'PRS'),
            (SAL, 'SAL'),
            (OFF, 'OFF'),
            (FFF, 'FFF'),
            (CON, 'CON'),
            (EXT, 'EXT'),
            (RPR, 'RPR'),
            (PRD, 'PRD'),
            (TRN, 'TRN'),
            (MIL, 'MIL')            
]

In [19]:
continent_functions = [
                       (us_born, 'US born'),
                       (european, 'european'),
                       (asian, 'asian'),
                       (american, 'american'),
                       (african, 'african'),
                       (oceanian, 'oceanian')
]

In [20]:
# rank each of these groups according to their error from the decision stump
funcs = [group_functions, group_functions2, occupation_functions, continent_functions]
ranking = []
for func in funcs:
  for fun, name in func:
    err = measure_group_error(f, fun, train_x, train_y)
    ranking.append((fun, name, err))
ranking.sort(key = lambda x: x[2])
ranking.reverse()

In [21]:
ranking

[(<function __main__.EDU>, 'EDU', 0.5546695662679715),
 (<function __main__.CMS>, 'CMS', 0.5438673068529027),
 (<function __main__.ENT>, 'ENT', 0.5312569521690768),
 (<function __main__.PRT>, 'PRT', 0.5147113594040968),
 (<function __main__.EXT>, 'EXT', 0.4534883720930233),
 (<function __main__.LGL>, 'LGL', 0.437984496124031),
 (<function __main__.RPR>, 'RPR', 0.42611894543225015),
 (<function __main__.local_govt_worker>,
  'local govt worker',
  0.38288740754269013),
 (<function __main__.assoc_degree>, 'associates degree', 0.3556405353728489),
 (<function __main__.bachelor_degree>, 'bachelors degree', 0.3521507025832975),
 (<function __main__.SCI>, 'SCI', 0.34404761904761905),
 (<function __main__.cow_minorities>, 'cow minorities', 0.34098577571948396),
 (<function __main__.non_profit_worker>,
  'non profit worker',
  0.33964226289517474),
 (<function __main__.self_employed_worker>,
  'self employed worker',
  0.33831230148510605),
 (<function __main__.BUS>, 'BUS', 0.33788187372708756

### Manual Groups

In the following cell(s), generate group functions that you think will make improvements, and then try to run the updates as explained in the subsequent section. In the final version of your code that you turn in, the groups that you generated and their corresponding models h, and the order in which you did updates, should be obvious and re-generating your final PDL should be completely reproducible just by running the code blocks in this notebook.

In [22]:
# manually defined groups that tend to be systematically mistreated and would be educated guesses to test on the model

def black_men(x):
  if x['RAC1P'] == 2 and x['SEX'] == 1:
    return 1
  else: 
    return 0

def black_women(x):
  if x['RAC1P'] == 2 and x['SEX'] == 2:
    return 1
  else: 
    return 0

def asian_men(x):
  if x['RAC1P'] == 6 and x['SEX'] == 1:
    return 1
  else: 
    return 0

def asian_women(x):
  if x['RAC1P'] == 6 and x['SEX'] == 2:
    return 1
  else: 
    return 0

def white_men(x):
  if x['RAC1P'] == 1 and x['SEX'] == 1:
    return 1
  else: 
    return 0

def white_women(x):
  if x['RAC1P'] == 1 and x['SEX'] == 2:
    return 1
  else: 
    return 0

def native_american_men(x):
  if x['RAC1P'] == 3 and x['SEX'] == 1:
      return 1
  elif x['RAC1P'] == 4 and x['SEX'] == 1:
      return 1
  elif x['RAC1P'] == 5 and x['SEX'] == 1:
      return 1
  elif x['RAC1P'] == 7 and x['SEX'] == 1:
      return 1
  else:
      return 0

def native_american_women(x):
  if x['RAC1P'] == 3 and x['SEX'] == 2:
      return 1
  elif x['RAC1P'] == 4 and x['SEX'] == 2:
      return 1
  elif x['RAC1P'] == 5 and x['SEX'] == 2:
      return 1
  elif x['RAC1P'] == 7 and x['SEX'] == 2:
      return 1
  else:
      return 0

def middle_aged_black(x):
  if x['RAC1P'] == 2 and x['AGEP'] >= 40 and x['AGEP'] < 70:
      return 1
  else:
      return 0

def middle_aged_asian(x):
  if x['RAC1P'] == 6 and x['AGEP'] >= 40 and x['AGEP'] < 70:
      return 1
  else:
      return 0

def middle_aged_white(x):
  if x['RAC1P'] == 1 and x['AGEP'] >= 40 and x['AGEP'] < 70:
      return 1
  else:
      return 0

def middle_aged_native_american(x):
  if x['RAC1P'] == 3 and x['AGEP'] >= 40 and x['AGEP'] < 70:
      return 1
  elif x['RAC1P'] == 4 and x['AGEP'] >= 40 and x['AGEP'] < 70:
      return 1
  elif x['RAC1P'] == 5 and x['AGEP'] >= 40 and x['AGEP'] < 70:
      return 1
  elif x['RAC1P'] == 7 and x['AGEP'] >= 40 and x['AGEP'] < 70:
      return 1
  else:
      return 0

def overtime_white(x):
  if x['RAC1P'] == 1 and x['WKHP'] >= 60:
      return 1
  else:
      return 0

def overtime_black(x):
  if x['RAC1P'] == 2 and x['WKHP'] >= 60:
      return 1
  else:
      return 0

def overtime_asian(x):
  if x['RAC1P'] == 6 and x['WKHP'] >= 60:
      return 1
  else:
      return 0

def overtime_native_american(x):
  if x['RAC1P'] == 3 and x['WKHP'] >= 60:
      return 1
  elif x['RAC1P'] == 4 and x['WKHP'] >= 60:
      return 1
  elif x['RAC1P'] == 5 and x['WKHP'] >= 60:
      return 1
  elif x['RAC1P'] == 7 and x['WKHP'] >= 60:
      return 1
  else:
      return 0

def married_men(x):
  if x['MAR'] == 1 and x['SEX'] == 1:
    return 1
  else: 
    return 0

def married_women(x):
  if x['MAR'] == 1 and x['SEX'] == 2:
    return 1
  else: 
    return 0

def divorced_men(x):
  if x['MAR'] == 3 and x['SEX'] == 1:
    return 1
  else: 
    return 0

def divorced_women(x):
  if x['MAR'] == 3 and x['SEX'] == 2:
    return 1
  else: 
    return 0

def single_men(x):
  if x['MAR'] == 5 and x['SEX'] == 1:
    return 1
  else: 
    return 0

def single_women(x):
  if x['MAR'] == 5 and x['SEX'] == 2:
    return 1
  else: 
    return 0

def divorced_black_women(x):
  if x['MAR'] == 3 and x['SEX'] == 2 and x['RAC1P'] == 2:
    return 1
  else: 
    return 0   

def single_black_women(x):
  if x['MAR'] == 5 and x['SEX'] == 2 and x['RAC1P'] == 2:
    return 1
  else: 
    return 0

def divorced_white_women(x):
  if x['MAR'] == 3 and x['SEX'] == 2 and x['RAC1P'] == 1:
    return 1
  else: 
    return 0   

def single_white_women(x):
  if x['MAR'] == 5 and x['SEX'] == 2 and x['RAC1P'] == 1:
    return 1
  else: 
    return 0

def divorced_asian_women(x):
  if x['MAR'] == 3 and x['SEX'] == 2 and x['RAC1P'] == 6:
    return 1
  else: 
    return 0   

def single_asian_women(x):
  if x['MAR'] == 5 and x['SEX'] == 2 and x['RAC1P'] == 6:
    return 1
  else: 
    return 0

def divorced_native_american_women(x):
  if x['MAR'] == 3 and x['SEX'] == 2 and (x['RAC1P'] == 3 or x['RAC1P'] == 4 or x['RAC1P'] == 5 or x['RAC1P'] == 7):
    return 1
  else: 
    return 0   

def single_native_american_women(x):
  if x['MAR'] == 5 and x['SEX'] == 2 and (x['RAC1P'] == 3 or x['RAC1P'] == 4 or x['RAC1P'] == 5 or x['RAC1P'] == 7):
    return 1
  else: 
    return 0

In [23]:
manual_functions = [
  (black_men, 'black men'),
  (black_women, 'black women'),
  (white_men, 'white men'),
  (white_women, 'white women'),
  (asian_men, 'asian men'),
  (asian_women, 'asian_women'),
  (native_american_men, 'native american men'),
  (native_american_women, 'natie american women'),
  (middle_aged_black, 'middle aged black'),
  (middle_aged_white, 'middle aged white'),
  (middle_aged_asian, 'middle aged asian'),
  (middle_aged_native_american, 'middle aged native american'),
  (overtime_white, 'overtime white'),
  (overtime_black, 'overtime black'),
  (overtime_asian, 'overtime asian'),
  (overtime_native_american, 'overtime native american'),
  (married_men, 'married men'),
  (married_women, 'married women'),
  (divorced_men, 'divorced women'),
  (single_men, 'single men'),
  (single_women, 'single women'),
  (divorced_black_women, 'divorced black women'),
  (single_black_women, 'single black women'),
  (divorced_white_women, 'divorced white women'),
  (single_white_women, 'single white women'),
  (divorced_asian_women, 'divorced asian women'),
  (single_asian_women, 'single asian women'),
  (divorced_native_american_women, 'divorced_native_american_women'),
  (single_native_american_women, 'single native american women')
]

## Updating the Function

Once you've found a promising group g, you can run the following updater code. Here, we define two different update functions. The first, `simple_updater`, only requires that you find some group g that you think f might do poorly on. Then, it automatically trains a decision list of depth 10 on the training data restricted to your g, and it passes that model and g along to the updater.

You might want to do something a bit fancier than a decision tree to make your model, in which case you can run the second updater, which takes as input a group g and model h, and then updates f accordingly.

Every time you run the update function, it will tell you if your (g,h) passed the validation checks, i.e. if a) your group existed in the validation data and b) it made an improvement compared to f. If it did pass the validation checks, then the model f is updated to include your g and h. **Note that this means that as you run updates, it will be increasingly difficult to find groups that make improvements.**

In [24]:
def simple_updater(g, group_name="g"):
    # if you want to change how h is trained, you can edit the below line.
    print('Testing Group: ' + group_name)
    h = build_model(train_x, train_y, g, dt_depth=10) #bountyHuntWrapper.build_model(train_x, train_y, g, dt_depth=10)
    # do not change anything beyond this point.
    if run_checks(f, validation_x, validation_y, g, h, train_x=train_x, train_y=train_y): #bountyHuntWrapper.run_checks(f, validation_x, validation_y, g, h, train_x=train_x, train_y=train_y):
        print("Running Update")
        run_updates(f, g, h, train_x, train_y, validation_x, validation_y, group_name=group_name) #bountyHuntWrapper.run_updates(f, g, h, train_x, train_y, validation_x, validation_y, group_name=group_name)

def updater(g, h, group_name="g"):
    # do not alter this code
    if run_checks(f, validation_x, validation_y, g, h, train_x=train_x, train_y=train_y): #bountyHuntWrapper.run_checks(f, validation_x, validation_y, g, h, train_x=train_x, train_y=train_y):
        print("Running Update")
        run_updates(f, g, h, train_x, train_y, validation_x, validation_y, group_name=group_name) #bountyHuntWrapper.run_updates(f, g, h, train_x, train_y, validation_x, validation_y, group_name=group_name)

In the below block, provide a script that builds *the entire final PDL that you come up with*. We will run this on the initial version of f (the decision stump) in order to evaluate your code. (Note: it is fine for the group functions g and the hs to be defined in above code blocks, just make sure that everything runs as you expect if you run everything from a clean kernel)

In [25]:
# updating with algorithmically ordered defined groups
for func, name, _ in ranking:
  simple_updater(func, name)

# updating with manually defined groups
for func, name in manual_functions:
  simple_updater(func, name)

Testing Group: EDU
building h
finished building h
Error of current model on proposed group: 0.5519516217702034
Error of h trained on proposed group: 0.18856514568444205
Group size in test set: 1819
Group weight in test set: 0.06197614991482112
Training Error of current model on proposed group: 0.5546695662679715
Training Error of h trained on proposed group: 0.11537996858765254
Group size in training set: 8277
Group weight in training set: 0.060431938319558426
Passed checks.
Running Update
updating errors
getting new errors
prob node -1
new errs [0.2571039182282794, 0.18856514568444205]
Testing Group: CMS
building h
finished building h
Error of current model on proposed group: 0.5603271983640081
Error of h trained on proposed group: 0.32515337423312884
Group size in test set: 489
Group weight in test set: 0.01666098807495741
Training Error of current model on proposed group: 0.5438673068529027
Training Error of h trained on proposed group: 0.12701876909646448
Group size in training set

# Saving Your Model

We'd like to output the PDL to some permanent location for grading purposes. The lines below do this.

In [26]:
# you will probably need to install dill, which you do w pip install dill in your command line
with open('pdl.pkl', 'wb') as pickle_file:
    pickle.dump(f, pickle_file)

If you saved your PDL to pdl.pkl and you want to reload it, you can do so as follows (instead of re-building it from scratch every time you shut down your kernel). Just be sure that your final PDL is fully replicable in the final version of your code, so that we can re-build it just given your gs and hs.

In [27]:
with open('pdl.pkl', 'rb') as pickle_file:
    content = pickle.load(pickle_file)

In [37]:
pprint(vars(f))

{'all_groups': None,
 'curr_node': <__main__.PointerDecisionListNode object at 0x7f4c97d98f90>,
 'head': <__main__.PointerDecisionListNode object at 0x7f4c97d98f90>,
 'leaves': [<bound method BaseDecisionTree.predict of DecisionTreeClassifier(max_depth=1, random_state=0)>,
            <bound method BaseDecisionTree.predict of DecisionTreeClassifier(max_depth=10, random_state=0)>,
            <bound method BaseDecisionTree.predict of DecisionTreeClassifier(max_depth=10, random_state=0)>,
            <bound method BaseDecisionTree.predict of DecisionTreeClassifier(max_depth=10, random_state=0)>,
            <bound method BaseDecisionTree.predict of DecisionTreeClassifier(max_depth=10, random_state=0)>,
            <bound method BaseDecisionTree.predict of DecisionTreeClassifier(max_depth=10, random_state=0)>,
            <bound method BaseDecisionTree.predict of DecisionTreeClassifier(max_depth=10, random_state=0)>,
            <bound method BaseDecisionTree.predict of DecisionTreeClassi

In [36]:
pprint(vars(content))

{'all_groups': None,
 'curr_node': <__main__.PointerDecisionListNode object at 0x7f4c97f6f510>,
 'head': <__main__.PointerDecisionListNode object at 0x7f4c97f6f510>,
 'leaves': [<bound method predict of DecisionTreeClassifier(max_depth=1, random_state=0)>,
            <bound method predict of DecisionTreeClassifier(max_depth=10, random_state=0)>,
            <bound method predict of DecisionTreeClassifier(max_depth=10, random_state=0)>,
            <bound method predict of DecisionTreeClassifier(max_depth=10, random_state=0)>,
            <bound method predict of DecisionTreeClassifier(max_depth=10, random_state=0)>,
            <bound method predict of DecisionTreeClassifier(max_depth=10, random_state=0)>,
            <bound method predict of DecisionTreeClassifier(max_depth=10, random_state=0)>,
            <bound method predict of DecisionTreeClassifier(max_depth=10, random_state=0)>,
            <bound method predict of DecisionTreeClassifier(max_depth=10, random_state=0)>,
       

In [28]:
#f = content

# Analysis of Your Final Model

## Part 1

1. How does your final model perform? On both the validation set and the training data, calculate f's error rates on each of the groups you identified, calculate the error rates of the initial model on each of the groups you identified, and compare them by taking their difference. Hint: you can use the helper function `bountyHuntWrapper.measure_group_error(model, g, x, y)` to get the error of f on x and y restricted to just the datapoints in a group g, and you can use `metrics.zero_one_loss` for the initial model (which is just a DL so you can directly use the scikit.learn functions on it).

My final model achieves an overall accuracy of 0.8454 on the training data and 0.8148 on the validation data. To see the breakdown of the error rates per group, please see the relevant cells below in the Data Exploration section of this notebook.

#### Error Output from Training Data

~~~~~~~~~~~~~~~~~
EDU
model error: 0.1355563610003625
stump error: 0.5546695662679715
error difference: -0.419113205267609
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
PRT
model error: 0.1389199255121043
stump error: 0.5147113594040968
error difference: -0.37579143389199254
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
CMS
model error: 0.1903099083369707
stump error: 0.5438673068529027
error difference: -0.35355739851593204
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
ENT
model error: 0.19933259176863183
stump error: 0.5312569521690768
error difference: -0.33192436040044493
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
LGL
model error: 0.1647286821705426
stump error: 0.437984496124031
error difference: -0.2732558139534884
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
oceanian
model error: 0.05250596658711215
stump error: 0.3174224343675418
error difference: -0.26491646778042965
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
local govt worker
model error: 0.14080905853346726
stump error: 0.38288740754269013
error difference: -0.24207834900922287
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
SCI
model error: 0.1255952380952381
stump error: 0.34404761904761905
error difference: -0.21845238095238095
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
MIL
model error: 0.08424908424908428
stump error: 0.2985347985347986
error difference: -0.2142857142857143
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
part time
model error: 0.06750504211650254
stump error: 0.26780559180606633
error difference: -0.2003005496895638
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
overtime white
model error: 0.15157661810805823
stump error: 0.3459339848792181
error difference: -0.1943573667711599
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
associates degree
model error: 0.1615223527269416
stump error: 0.3556405353728489
error difference: -0.1941181826459073
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
RPR
model error: 0.2375843041079092
stump error: 0.42611894543225015
error difference: -0.18853464132434095
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
over time
model error: 0.14849999999999997
stump error: 0.328375
error difference: -0.179875
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
non profit worker
model error: 0.16035773710482526
stump error: 0.33964226289517474
error difference: -0.17928452579034948
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
overtime black
model error: 0.1498559077809798
stump error: 0.32564841498559083
error difference: -0.17579250720461104
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
SAL
model error: 0.16530194472876147
stump error: 0.3360871472437491
error difference: -0.17078520251498763
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
cow minorities
model error: 0.17174991730069467
stump error: 0.34098577571948396
error difference: -0.1692358584187893
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
bachelors degree
model error: 0.18454316745413268
stump error: 0.3521507025832975
error difference: -0.16760753512916482
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
divorced black women
model error: 0.2351598173515982
stump error: 0.4018264840182648
error difference: -0.16666666666666663
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
single asian women
model error: 0.13140643623361148
stump error: 0.2949940405244339
error difference: -0.16358760429082242
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
african
model error: 0.1428571428571429
stump error: 0.30434782608695654
error difference: -0.16149068322981364
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
single black women
model error: 0.1062124248496994
stump error: 0.2645290581162325
error difference: -0.1583166332665331
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
twenties
model error: 0.0792342133703875
stump error: 0.2327312027173074
error difference: -0.1534969893469199
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
other race
model error: 0.13979979288919575
stump error: 0.29288919571971006
error difference: -0.1530894028305143
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
elderly
model error: 0.18570410828781758
stump error: 0.3360246972215626
error difference: -0.150320588933745
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
black men
model error: 0.1585704371963914
stump error: 0.30707841776544065
error difference: -0.14850798056904924
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
black
model error: 0.15614506380120885
stump error: 0.3038952316991269
error difference: -0.14775016789791806
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
black women
model error: 0.1538711776187378
stump error: 0.3009108653220559
error difference: -0.1470396877033181
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
overtime asian
model error: 0.1293402777777778
stump error: 0.27604166666666663
error difference: -0.14670138888888884
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
european
model error: 0.1607629427792916
stump error: 0.3073569482288828
error difference: -0.14659400544959122
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
single white women
model error: 0.1104080987029421
stump error: 0.255773489402088
error difference: -0.1453653906991459
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
BUS
model error: 0.19307535641547857
stump error: 0.33788187372708756
error difference: -0.14480651731160898
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
US born
model error: 0.15926198304204875
stump error: 0.3037614639210936
error difference: -0.14449948087904485
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
white men
model error: 0.16199960342814335
stump error: 0.30648394985569194
error difference: -0.1444843464275486
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
single women
model error: 0.10784092447975446
stump error: 0.2517943393671286
error difference: -0.14395341488737412
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
CON
model error: 0.19461591220850483
stump error: 0.3357338820301783
error difference: -0.14111796982167346
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
married men
model error: 0.17857500500700985
stump error: 0.3182705788103345
error difference: -0.13969557380332465
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
self employed worker
model error: 0.19864366040003434
stump error: 0.33831230148510605
error difference: -0.1396686410850717
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
middle aged black
model error: 0.2038523274478331
stump error: 0.34253611556982344
error difference: -0.13868378812199034
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
white
model error: 0.1600009446550238
stump error: 0.2980504681946462
error difference: -0.1380495235396224
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
divorced women
model error: 0.20487019521030392
stump error: 0.33930368283356815
error difference: -0.13443348762326424
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
middle aged white
model error: 0.18341244725738393
stump error: 0.3175545007032349
error difference: -0.13414205344585095
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
single
model error: 0.11327407778334164
stump error: 0.24632782507093975
error difference: -0.1330537472875981
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
male
model error: 0.1587538527456428
stump error: 0.29157855454658543
error difference: -0.13282470180094264
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
native american men
model error: 0.13808975834292292
stump error: 0.2692750287686997
error difference: -0.13118527042577677
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
white women
model error: 0.1576925034352893
stump error: 0.288309837650771
error difference: -0.13061733421548172
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
married
model error: 0.17462287036778856
stump error: 0.30458968853693225
error difference: -0.12996681816914368
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
Total
model error: 0.1545734645600304
stump error: 0.28344674513010715
error difference: -0.12887328057007674
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
EXT
model error: 0.32558139534883723
stump error: 0.4534883720930233
error difference: -0.12790697674418605
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
female
model error: 0.14989243650658535
stump error: 0.2743410768730751
error difference: -0.12444864036648973
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
single men
model error: 0.1179437439379244
stump error: 0.24162948593598443
error difference: -0.12368574199806004
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
sixties
model error: 0.19452415112386423
stump error: 0.3180296508847441
error difference: -0.12350549976087988
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
FIN
model error: 0.19705974350954014
stump error: 0.3203002815139193
error difference: -0.12324053800437917
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
high school grad
model error: 0.1614503816793893
stump error: 0.28396946564885495
error difference: -0.12251908396946565
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
middle aged native american
model error: 0.1701346389228886
stump error: 0.2925336597307222
error difference: -0.1223990208078336
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
forties
model error: 0.16650678866587953
stump error: 0.2885182998819362
error difference: -0.12201151121605669
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
divorced_native_american_women
model error: 0.12195121951219512
stump error: 0.24390243902439024
error difference: -0.12195121951219512
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
native american
model error: 0.14217156568686262
stump error: 0.2639472105578884
error difference: -0.12177564487102577
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
fifties
model error: 0.18364688856729383
stump error: 0.303328509406657
error difference: -0.11968162083936318
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
married women
model error: 0.16965578000125858
stump error: 0.2873953810332893
error difference: -0.11773960103203074
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
divorced
model error: 0.19510624597553128
stump error: 0.31270122343850615
error difference: -0.11759497746297487
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
single native american women
model error: 0.10738255033557043
stump error: 0.22483221476510062
error difference: -0.1174496644295302
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
thirties
model error: 0.1723421926910299
stump error: 0.2888289036544851
error difference: -0.1164867109634552
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
asian men
model error: 0.15625823451910403
stump error: 0.2718489240228371
error difference: -0.11559068950373308
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
asian_women
model error: 0.15853027427979993
stump error: 0.27298602725547694
error difference: -0.114455752975677
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
advanced degree
model error: 0.12800548455021787
stump error: 0.24083051760442686
error difference: -0.11282503305420899
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
natie american women
model error: 0.14661654135338342
stump error: 0.25814536340852134
error difference: -0.11152882205513792
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
full time
model error: 0.1762782487919211
stump error: 0.28379486289148026
error difference: -0.10751661409955915
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
divorced white women
model error: 0.19161327897495628
stump error: 0.29702970297029707
error difference: -0.10541642399534079
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
asian
model error: 0.16258448229251143
stump error: 0.26774804001081376
error difference: -0.10516355771830233
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
asian
model error: 0.16258448229251143
stump error: 0.26774804001081376
error difference: -0.10516355771830233
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
for profit worker
model error: 0.14701058761839947
stump error: 0.24843528611114207
error difference: -0.1014246984927426
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
mar minorities
model error: 0.16230045026606632
stump error: 0.2627916496111339
error difference: -0.10049119934506756
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
MGR
model error: 0.17568294980487154
stump error: 0.27526577849549183
error difference: -0.0995828286906203
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
middle aged asian
model error: 0.1702953052014884
stump error: 0.26458712690998343
error difference: -0.09429182170849504
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
overtime native american
model error: 0.1839080459770115
stump error: 0.27586206896551724
error difference: -0.09195402298850575
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
MED
model error: 0.15573199609973531
stump error: 0.24571667363142502
error difference: -0.0899846775316897
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
divorced asian women
model error: 0.18008948545861303
stump error: 0.26957494407158833
error difference: -0.0894854586129753
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
multiple race
model error: 0.1276367986902588
stump error: 0.21245513506706126
error difference: -0.08481833637680247
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
PRD
model error: 0.18551150269211947
stump error: 0.26692772067221404
error difference: -0.08141621798009457
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
CMM
model error: 0.13857170505128702
stump error: 0.2144377782078576
error difference: -0.07586607315657057
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
american
model error: 0.12841646534981765
stump error: 0.20259580313580594
error difference: -0.07417933778598829
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
TRN
model error: 0.12634578847371758
stump error: 0.19896559003588765
error difference: -0.07261980156217007
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
ENG
model error: 0.12488493402884315
stump error: 0.19300398895366677
error difference: -0.06811905492482362
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
OFF
model error: 0.21330379829304835
stump error: 0.28014854387908006
error difference: -0.0668447455860317
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
some school
model error: 0.09503908366225522
stump error: 0.14031421716585402
error difference: -0.0452751335035988
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
no school
model error: 0.0972972972972973
stump error: 0.1409563409563409
error difference: -0.043659043659043606
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
CLN
model error: 0.08427523299623241
stump error: 0.11183819155264718
error difference: -0.027562958556414774
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
EAT
model error: 0.035506241331484056
stump error: 0.05395284327323158
error difference: -0.01844660194174752
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
PRS
model error: 0.10044486068836334
stump error: 0.11613205338328259
error difference: -0.01568719269491925
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
HLS
model error: 0.09564628919467633
stump error: 0.1026392961876833
error difference: -0.006993006993006978
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
FFF
model error: 0.0598111227701994
stump error: 0.06400839454354668
error difference: -0.0041972717733472775
~~~~~~~~~~~~~~~~~

#### Error Output from Validation Data

~~~~~~~~~~~~~~~~~
EDU
model error: 0.17372182517866963
stump error: 0.5519516217702034
error difference: -0.3782297965915338
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
PRT
model error: 0.15656565656565657
stump error: 0.48484848484848486
error difference: -0.3282828282828283
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
ENT
model error: 0.24163969795037754
stump error: 0.5544768069039914
error difference: -0.31283710895361383
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
CMS
model error: 0.29243353783231085
stump error: 0.5603271983640081
error difference: -0.2678936605316973
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
overtime native american
model error: 0.17142857142857137
stump error: 0.4285714285714286
error difference: -0.25714285714285723
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
local govt worker
model error: 0.194639175257732
stump error: 0.385979381443299
error difference: -0.19134020618556702
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
single native american women
model error: 0.09523809523809523
stump error: 0.2698412698412699
error difference: -0.17460317460317465
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
LGL
model error: 0.23076923076923073
stump error: 0.40384615384615385
error difference: -0.17307692307692313
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
overtime white
model error: 0.1831470335339639
stump error: 0.352536543422184
error difference: -0.1693895098882201
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
EXT
model error: 0.25
stump error: 0.41666666666666663
error difference: -0.16666666666666663
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
part time
model error: 0.10673024023473321
stump error: 0.2684760682193288
error difference: -0.16174582798459558
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
RPR
model error: 0.27326440177252587
stump error: 0.431314623338257
error difference: -0.15805022156573112
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
non profit worker
model error: 0.1865234375
stump error: 0.33642578125
error difference: -0.14990234375
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
oceanian
model error: 0.21111111111111114
stump error: 0.3555555555555555
error difference: -0.14444444444444438
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
over time
model error: 0.2029478458049887
stump error: 0.3378684807256236
error difference: -0.13492063492063489
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
MIL
model error: 0.1826923076923077
stump error: 0.3173076923076923
error difference: -0.13461538461538458
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
SCI
model error: 0.19075144508670516
stump error: 0.32369942196531787
error difference: -0.1329479768786127
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
associates degree
model error: 0.2318238471125883
stump error: 0.36352305774823435
error difference: -0.13169921063564605
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
cow minorities
model error: 0.21148415437715717
stump error: 0.342955757765924
error difference: -0.13147160338876684
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
bachelors degree
model error: 0.21492792921364345
stump error: 0.3462252033680605
error difference: -0.13129727415441705
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
SAL
model error: 0.19096509240246407
stump error: 0.3189596167008898
error difference: -0.12799452429842573
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
natie american women
model error: 0.18497109826589597
stump error: 0.3121387283236994
error difference: -0.1271676300578034
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
black men
model error: 0.20341614906832295
stump error: 0.327639751552795
error difference: -0.12422360248447206
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
overtime black
model error: 0.20731707317073167
stump error: 0.3292682926829268
error difference: -0.12195121951219512
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
twenties
model error: 0.10664538595473017
stump error: 0.22576900754497964
error difference: -0.11912362159024947
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
single white women
model error: 0.13270676691729322
stump error: 0.25037593984962403
error difference: -0.11766917293233081
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
single asian women
model error: 0.17861975642760486
stump error: 0.29364005412719896
error difference: -0.1150202976995941
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
native american
model error: 0.17528735632183912
stump error: 0.29022988505747127
error difference: -0.11494252873563215
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
US born
model error: 0.18503937007874016
stump error: 0.2994144962648899
error difference: -0.11437512618614976
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
european
model error: 0.18652849740932642
stump error: 0.30051813471502586
error difference: -0.11398963730569944
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
BUS
model error: 0.20969245107176138
stump error: 0.32339235787511644
error difference: -0.11369990680335507
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
white men
model error: 0.18994933305759487
stump error: 0.30255402750491156
error difference: -0.11260469444731669
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
self employed worker
model error: 0.24286815728604472
stump error: 0.35389360061680797
error difference: -0.11102544333076325
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
single women
model error: 0.13267243197637624
stump error: 0.24277578569921954
error difference: -0.1101033537228433
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
black
model error: 0.20092024539877296
stump error: 0.30828220858895705
error difference: -0.1073619631901841
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
white
model error: 0.18734566207539927
stump error: 0.29457279262470504
error difference: -0.10722713054930577
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
married men
model error: 0.21121562090725088
stump error: 0.3168234313608763
error difference: -0.10560781045362544
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
middle aged white
model error: 0.21049933353839845
stump error: 0.31518507126012507
error difference: -0.10468573772172662
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
native american men
model error: 0.1657142857142857
stump error: 0.26857142857142857
error difference: -0.10285714285714287
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
white women
model error: 0.1844013096351731
stump error: 0.2855472404115996
error difference: -0.10114593077642653
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
married
model error: 0.20583436663838672
stump error: 0.3052274358807022
error difference: -0.0993930692423155
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
middle aged native american
model error: 0.2325581395348837
stump error: 0.33139534883720934
error difference: -0.09883720930232565
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
single
model error: 0.14089848308051345
stump error: 0.23842862699338774
error difference: -0.09753014391287429
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
male
model error: 0.19039875105704807
stump error: 0.28660638782280623
error difference: -0.09620763676575816
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
high school grad
model error: 0.1817295188556567
stump error: 0.27673927178153446
error difference: -0.09500975292587777
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
Total
model error: 0.1851788756388416
stump error: 0.2796252129471891
error difference: -0.0944463373083475
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
other race
model error: 0.17191489361702128
stump error: 0.26553191489361705
error difference: -0.09361702127659577
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
FIN
model error: 0.2627627627627628
stump error: 0.3558558558558559
error difference: -0.0930930930930931
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
female
model error: 0.17943764756385494
stump error: 0.2719467696930672
error difference: -0.09250912212921225
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
fifties
model error: 0.21246411079209593
stump error: 0.30434048302651584
error difference: -0.09187637223441991
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
married women
model error: 0.1993067590987868
stump error: 0.29116117850953205
error difference: -0.09185441941074524
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
black women
model error: 0.1984848484848485
stump error: 0.2893939393939394
error difference: -0.09090909090909094
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
forties
model error: 0.1923549009293355
stump error: 0.28195686480799576
error difference: -0.08960196387866026
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
single black women
model error: 0.14984709480122327
stump error: 0.23853211009174313
error difference: -0.08868501529051986
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
divorced women
model error: 0.23463687150837986
stump error: 0.3230912476722533
error difference: -0.08845437616387342
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
single men
model error: 0.14793433158939207
stump error: 0.234710445607072
error difference: -0.08677611401767993
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
thirties
model error: 0.1955307262569832
stump error: 0.2807661612130886
error difference: -0.08523543495610542
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
sixties
model error: 0.23776610450649704
stump error: 0.3184959911528892
error difference: -0.08072988664639213
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
MGR
model error: 0.20781696854146803
stump error: 0.28852875754687
error difference: -0.08071178900540199
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
CON
model error: 0.24189723320158107
stump error: 0.32252964426877473
error difference: -0.08063241106719365
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
asian_women
model error: 0.19291014014839236
stump error: 0.2732893652102226
error difference: -0.08037922506183026
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
asian men
model error: 0.1868962620747585
stump error: 0.2654346913061739
error difference: -0.07853842923141541
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
full time
model error: 0.2030904079880721
stump error: 0.2777300862964803
error difference: -0.07463967830840823
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
elderly
model error: 0.273109243697479
stump error: 0.34663865546218486
error difference: -0.07352941176470584
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
advanced degree
model error: 0.16821777570963237
stump error: 0.23825034899953468
error difference: -0.07003257328990231
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
divorced
model error: 0.2278860569715142
stump error: 0.29760119940029983
error difference: -0.06971514242878563
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
middle aged black
model error: 0.25113464447806355
stump error: 0.3207261724659607
error difference: -0.06959152798789714
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
for profit worker
model error: 0.1716066191872644
stump error: 0.23936950146627567
error difference: -0.06776288227901128
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
asian
model error: 0.2030893897189162
stump error: 0.2699417574069385
error difference: -0.06685236768802227
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
asian
model error: 0.2030893897189162
stump error: 0.2699417574069385
error difference: -0.06685236768802227
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
middle aged asian
model error: 0.20376055257099002
stump error: 0.2705295471987721
error difference: -0.06676899462778207
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
divorced white women
model error: 0.2198198198198198
stump error: 0.2828828828828829
error difference: -0.06306306306306309
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
african
model error: 0.18840579710144922
stump error: 0.25120772946859904
error difference: -0.06280193236714982
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
divorced asian women
model error: 0.23469387755102045
stump error: 0.29591836734693877
error difference: -0.061224489795918324
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
PRD
model error: 0.19192688499619193
stump error: 0.2490479817212491
error difference: -0.05712109672505716
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
mar minorities
model error: 0.20837209302325577
stump error: 0.2641860465116279
error difference: -0.05581395348837215
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
divorced_native_american_women
model error: 0.2777777777777778
stump error: 0.33333333333333337
error difference: -0.05555555555555558
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
overtime asian
model error: 0.2449799196787149
stump error: 0.29718875502008035
error difference: -0.052208835341365445
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
divorced black women
model error: 0.32098765432098764
stump error: 0.37037037037037035
error difference: -0.04938271604938271
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
multiple race
model error: 0.1669052390495276
stump error: 0.20870312052676787
error difference: -0.041797881477240284
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
CMM
model error: 0.14674441205053446
stump error: 0.1885325558794947
error difference: -0.041788143828960234
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
ENG
model error: 0.14204545454545459
stump error: 0.18181818181818177
error difference: -0.03977272727272718
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
OFF
model error: 0.22432762836185816
stump error: 0.2634474327628362
error difference: -0.039119804400978064
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
TRN
model error: 0.17830540037243947
stump error: 0.2094972067039106
error difference: -0.031191806331471117
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
american
model error: 0.16924778761061943
stump error: 0.1975663716814159
error difference: -0.02831858407079646
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
MED
model error: 0.21038790269559504
stump error: 0.23668639053254437
error difference: -0.026298487836949325
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
some school
model error: 0.12393465909090906
stump error: 0.14666193181818177
error difference: -0.022727272727272707
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
PRS
model error: 0.09820485744456176
stump error: 0.11087645195353746
error difference: -0.012671594508975703
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
FFF
model error: 0.04134366925064603
stump error: 0.05167958656330751
error difference: -0.01033591731266148
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
no school
model error: 0.1216216216216216
stump error: 0.12355212355212353
error difference: -0.0019305019305019266
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
CLN
model error: 0.11742777260018644
stump error: 0.11742777260018644
error difference: 0.0
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
EAT
model error: 0.059175531914893664
stump error: 0.05452127659574468
error difference: 0.004654255319148981
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
HLS
model error: 0.13844621513944222
stump error: 0.12151394422310757
error difference: 0.016932270916334646
~~~~~~~~~~~~~~~~~

## Part 2

2. Say instead you used bootstrapped fairness to postprocess equal error rates on the initial model over the groups you discovered (assuming you had a way to identify those groups ahead of time). How much would you need to inflate each groups' error to get them equal?

The group that I identified with the highest group error from the initial decision stump model were the people who have occupations in the field of education (group name EDU). It had a group error of 0.5546695662679715. This means that I would have to inflate each identified group's error by the following amounts detailed below.

~~~~~~~~~~~~~~~~~
EAT
error inflation: 0.5007167229947399
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
FFF
error inflation: 0.4906611717244248
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
HLS
error inflation: 0.4520302700802882
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
CLN
error inflation: 0.4428313747153243
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
PRS
error inflation: 0.4385375128846889
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
some school
error inflation: 0.41435534910211747
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
no school
error inflation: 0.4137132253116306
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
ENG
error inflation: 0.3616655773143047
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
TRN
error inflation: 0.35570397623208383
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
american
error inflation: 0.35207376313216554
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
multiple race
error inflation: 0.3422144312009102
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
CMM
error inflation: 0.3402317880601139
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
single native american women
error inflation: 0.32983735150287086
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
twenties
error inflation: 0.3219383635506641
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
advanced degree
error inflation: 0.3138390486635446
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
single men
error inflation: 0.31304008033198705
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
divorced_native_american_women
error inflation: 0.31076712724358124
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
MED
error inflation: 0.30895289263654646
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
single
error inflation: 0.3083417411970317
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
for profit worker
error inflation: 0.3062342801568294
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
single women
error inflation: 0.3028752269008429
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
single white women
error inflation: 0.2988960768658835
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
natie american women
error inflation: 0.29652420285945014
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
mar minorities
error inflation: 0.2918779166568376
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
native american
error inflation: 0.2907223557100831
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
single black women
error inflation: 0.290140508151739
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
middle aged asian
error inflation: 0.29008243935798805
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
PRD
error inflation: 0.28774184559575744
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
asian
error inflation: 0.2869215262571577
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
asian
error inflation: 0.2869215262571577
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
part time
error inflation: 0.28686397446190515
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
native american men
error inflation: 0.2853945374992718
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
divorced asian women
error inflation: 0.28509462219638315
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
asian men
error inflation: 0.28282064224513437
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
asian_women
error inflation: 0.28168353901249454
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
female
error inflation: 0.2803284893948964
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
MGR
error inflation: 0.27940378777247965
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
overtime native american
error inflation: 0.27880749730245424
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
overtime asian
error inflation: 0.27862789960130485
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
OFF
error inflation: 0.2745210223888914
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
Total
error inflation: 0.27122282113786433
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
full time
error inflation: 0.2708747033764912
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
high school grad
error inflation: 0.27070010061911653
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
married women
error inflation: 0.26727418523468216
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
white women
error inflation: 0.2663597286172005
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
forties
error inflation: 0.26615126638603526
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
thirties
error inflation: 0.2658406626134864
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
male
error inflation: 0.26309101172138605
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
middle aged native american
error inflation: 0.2621359065372493
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
other race
error inflation: 0.2617803705482614
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
single asian women
error inflation: 0.2596755257435376
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
divorced white women
error inflation: 0.2576398632976744
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
white
error inflation: 0.2566190980733253
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
MIL
error inflation: 0.2561347677331729
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
black women
error inflation: 0.25375870094591557
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
fifties
error inflation: 0.25134105686131447
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
US born
error inflation: 0.2509081023468779
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
black
error inflation: 0.25077433456884457
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
african
error inflation: 0.25032174018101494
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
married
error inflation: 0.25007987773103924
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
white men
error inflation: 0.24818561641227954
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
black men
error inflation: 0.24759114850253083
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
european
error inflation: 0.24731261803908866
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
divorced
error inflation: 0.24196834282946533
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
oceanian
error inflation: 0.23724713190042968
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
middle aged white
error inflation: 0.2371150655647366
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
sixties
error inflation: 0.23663991538322737
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
married men
error inflation: 0.236398987457637
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
FIN
error inflation: 0.23436928475405217
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
overtime black
error inflation: 0.22902115128238065
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
over time
error inflation: 0.2262945662679715
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
CON
error inflation: 0.2189356842377932
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
elderly
error inflation: 0.2186448690464089
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
SAL
error inflation: 0.21858241902422237
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
BUS
error inflation: 0.21678769254088393
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
self employed worker
error inflation: 0.21635726478286543
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
divorced women
error inflation: 0.21536588343440333
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
non profit worker
error inflation: 0.21502730337279674
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
cow minorities
error inflation: 0.21368379054848752
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
middle aged black
error inflation: 0.21213345069814804
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
SCI
error inflation: 0.21062194722035243
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
overtime white
error inflation: 0.20873558138875337
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
bachelors degree
error inflation: 0.20251886368467398
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
associates degree
error inflation: 0.19902903089512258
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
local govt worker
error inflation: 0.17178215872528135
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
divorced black women
error inflation: 0.15284308224970666
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
RPR
error inflation: 0.12855062083572133
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
LGL
error inflation: 0.11668507014394047
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
EXT
error inflation: 0.1011811941749482
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
PRT
error inflation: 0.039958206863874635
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
ENT
error inflation: 0.023412614098894724
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
CMS
error inflation: 0.010802259415068738
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
EDU
error inflation: 0.0
~~~~~~~~~~~~~~~~~

# Data Exploration

In order to find promising groups, you may find it helpful to do some data exploration. Please include any code or visualizations that you did to do so here. To get you started, here are some things you may find useful:

1. How to grab the predictions of the current PDL on the training data: because f.predict takes a single value as input, you have to use an apply function for this.

In [48]:
# decision stump to replicate our initial model
initial_model = DecisionTreeClassifier(max_depth = 1, random_state=0)
initial_model.fit(train_x.values, train_y) # added .values to train_x and this fixed a warning printout I was getting
stump = build_initial_pdl(initial_model, train_x, train_y, validation_x, validation_y) #bountyHuntWrapper.build_initial_pdl(initial_model, train_x, train_y, validation_x, validation_y)

In [39]:
# if loading the function from storage, then replace f with content since f refers to the model we trained (and content is the name we give to the model when we load it)
preds = train_x.apply(f.predict, axis=1)

In [40]:
val_preds = validation_x.apply(f.predict, axis=1)

2. Getting the zero-one loss of a model restricted to a group you have defined.

In [45]:
# here are all the groups that successfully led to updates of the PDL (there are 28 of them)
f.pred_names

['Total',
 'EDU',
 'CMS',
 'ENT',
 'PRT',
 'LGL',
 'RPR',
 'local govt worker',
 'associates degree',
 'bachelors degree',
 'SCI',
 'cow minorities',
 'non profit worker',
 'self employed worker',
 'BUS',
 'SAL',
 'CON',
 'over time',
 'FIN',
 'oceanian',
 'black',
 'US born',
 'MIL',
 'white',
 'TRN',
 'PRS',
 'CLN',
 'FFF',
 'EAT']

In [49]:
# get all the group functions and their names into a single list
groups = []
for func, name, _ in ranking:
  groups.append({'function': func, 'name': name})
for func, name in manual_functions:
  groups.append({'function': func, 'name': name})

In [None]:
# TRAIN
# get the group based breakdown of the current model error, initial model error, and their differences
error_data_train = []
for d in groups:
  model_error = measure_group_error(f, d['function'], train_x, train_y)
  stump_error = measure_group_error(stump, d['function'], train_x, train_y)
  error_diff = model_error - stump_error
  error_data_train.append({'function': d['function'], 'name': d['name'], 'model_error': model_error, 'stump_error': stump_error, 'error_diff': error_diff})

error_data_train

In [None]:
# VALIDATION
# do the same for validation
error_data_val = []
for d in groups:
  model_error = measure_group_error(f, d['function'], validation_x, validation_y)
  stump_error = measure_group_error(stump, d['function'], validation_x, validation_y)
  error_diff = model_error - stump_error
  error_data_val.append({'function': d['function'], 'name': d['name'], 'model_error': model_error, 'stump_error': stump_error, 'error_diff': error_diff})

error_data_val

In [56]:
# total error of our model
g = lambda x: 1 #here we define a group that just is all the data, replace as you see fit.

model_error_train_total = measure_group_error(f, g, train_x, train_y)
stump_error_train_total = measure_group_error(stump, g, train_x, train_y)
model_error_val_total = measure_group_error(f, g, validation_x, validation_y)
stump_error_val_total = measure_group_error(stump, g, validation_x, validation_y)

error_data_train.append({'function': g, 'name': 'Total', 'model_error': model_error_train_total, 'stump_error': stump_error_train_total, 'error_diff': model_error_train_total - stump_error_train_total})
error_data_val.append({'function': g, 'name': 'Total', 'model_error': model_error_val_total, 'stump_error': stump_error_val_total, 'error_diff': model_error_val_total - stump_error_val_total})

In [65]:
error_data_tr = [{'name': d['name'], 'model_error': d['model_error'], 'stump_error': d['stump_error'], 'error_diff': d['error_diff']} for d in error_data_train]
error_data_v = [{'name': d['name'], 'model_error': d['model_error'], 'stump_error': d['stump_error'], 'error_diff': d['error_diff']} for d in error_data_val]              

In [67]:
# write these to file for future reference

with open('train_data_errors.json', 'w') as f1:
  json.dump(error_data_tr, f1)

with open('validation_data_errors.json', 'w') as f2:
  json.dump(error_data_v, f2)

In [None]:
error_data_tr.sort(key= lambda d: d['error_diff'])
for d in error_data_tr:
  print('~~~~~~~~~~~~~~~~~')
  print(d['name'])
  print('model error: ' + str(d['model_error']))
  print('stump error: ' + str(d['stump_error']))
  print('error difference: ' + str(d['error_diff']))
  print('~~~~~~~~~~~~~~~~~')

In [None]:
error_data_v.sort(key= lambda d: d['error_diff'])
for d in error_data_v:
  print('~~~~~~~~~~~~~~~~~')
  print(d['name'])
  print('model error: ' + str(d['model_error']))
  print('stump error: ' + str(d['stump_error']))
  print('error difference: ' + str(d['error_diff']))
  print('~~~~~~~~~~~~~~~~~')

In [71]:
error_data_tr.sort(key= lambda d: d['stump_error'])

In [73]:
error_data_tr[len(error_data_tr)-1]

{'error_diff': -0.419113205267609,
 'model_error': 0.1355563610003625,
 'name': 'EDU',
 'stump_error': 0.5546695662679715}

In [74]:
for d in error_data_tr:
  print('~~~~~~~~~~~~~~~~~')
  print(d['name'])
  print('error inflation: ' + str(0.5546695662679715 - d['stump_error']))
  print('~~~~~~~~~~~~~~~~~')

~~~~~~~~~~~~~~~~~
EAT
error inflation: 0.5007167229947399
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
FFF
error inflation: 0.4906611717244248
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
HLS
error inflation: 0.4520302700802882
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
CLN
error inflation: 0.4428313747153243
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
PRS
error inflation: 0.4385375128846889
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
some school
error inflation: 0.41435534910211747
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
no school
error inflation: 0.4137132253116306
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
ENG
error inflation: 0.3616655773143047
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
TRN
error inflation: 0.35570397623208383
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
american
error inflation: 0.35207376313216554
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
multiple race
error inflation: 0.3422144312009102
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
CMM
error inflation: 0.3402317880601139
~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
single native american women
error inf

In [None]:
'''
# get errors from initial model
preds = train_x.apply(f.predict, axis=1)
merged_train = train_x.copy()
merged_train['train_y'] = train_y
merged_train['preds'] = preds
# get all the values that we currently mislabel
mistakes = merged_train[merged_train['preds'] != merged_train['train_y']]
xs = mistakes.drop(columns=['train_y', 'preds'])

categories1 = {
    #'AGEP': xs['AGEP'].unique(), # self-defined categories
    'COW': xs['COW'].unique(),
    #'SCHL': xs['SCHL'].unique(), # self-defined categories
    'MAR': xs['MAR'].unique(),
    #'OCCP': xs['OCCP'].unique(), # self-defined categories
    #'POBP': xs['POBP'].unique(), # self-defined categories
    #'WKHP': xs['WKHP'].unique(), # self-defined categories
    'SEX': xs['SEX'].unique(),
    'RAC1P': xs['RAC1P'].unique()
}

initial_error_rates = {}
for key in categories1.keys():
  for val in categories1[key]:
    initial_error_rates[(key, val)] = {
        'train_x count' : len(train_x[train_x[key] == val]),
        'error count' : len(xs[xs[key] == val]),
        'group error rate' : len(xs[xs[key] == val]) / len(train_x[train_x[key] == val])}

#initial_error_rates
'''

3. You can view the training data by calling `train_x`. If you want to only view the data for a single group defined by your group function, you can run the following:

In [None]:
# replace g with whatever your group is
indices = train_x.apply(g, axis=1) == 1
xs = train_x[indices]
ys = train_y[indices]

4. Inspecting the existing PDL: The PDL is stored as an object, and tracks its training errors, validation set errors, and the group functions that are used in lists where the ith element is the group errors of all groups discovered so far on the ith node in the PDL. If you are more curious about the implementation, you can look at the model.py file in the codebase, which doesn't contain anything you can use to adaptively modify your code. (But lives in the same folder as the rest of the codebase just to make importing things easier)

In [None]:
# f is the current model
print(f.train_errors) # group errors on training set.
print(f.train_errors[0]) # this is the group error of each group on the initial PDL. The ith element of f.train_errors is the group error of each group on the ith version of the PDL.
print(f.test_errors) # group errors on validation set
print(f.predicates) # all of the group functions that have been appended so far
print(f.leaves) # all of the h functions appended so far
print(f.pred_names) # the names you passed in for each of the group functions, to more easily understand which are which.

[[0.28344674513010715]]
[0.28344674513010715]
[[0.2796252129471891]]
[<function DecisionlistNode.__init__.<locals>.<lambda> at 0x7f6112a78560>]
[<bound method BaseDecisionTree.predict of DecisionTreeClassifier(max_depth=1, random_state=0)>]
['Total']


5. Looking at the group error of the ith group over each round of updates: Say you found a group at round 5 and you want to know how its group error looked at previous or subsequent rounds. To do so, you can pull `f.train_errors` or `f.test_errors` and look at the ith element of each list as follows:

In [None]:
target_group = 0 # this sets the group whose error you want to look at at each round to the initial model. If I wanted to look at the 1st group introduced, would change to a 1, e.g.
group_errs = [f.train_errors[i][target_group] for i in range(len(f.train_errors))]
group_errs

[0.28344674513010715]

# Submission File

For details about my groups, please see the section of the notebook titled Bounty Hunting. It should be very straightforward how this was achieved. I generated groups, ordered them according to how the initial model performed on them (those with the highest error being considered first), and then passed them into the simple update function. I did this for single-variable constraints. I also additionally thought of some more complex intersectional groups that were likely to be treated unfairly and passed them into the simple update function as well. 

I have included the definition of the groups in the submissions file for completeness, but everything is really straightforwardly laid out in the notebook, and I don't have separate models h defined for each group since I was using simple update. 

In [None]:
[train_x, train_y, _, _] = get_data() #bountyHuntData.get_data()

# Here, I define all of the g's tested that actually were integrated into my PDL
# possible group definitions

# class of worker
def for_profit_worker(x):
  if x['COW'] == 1:
      return 1
  else:
      return 0

def non_profit_worker(x):
  if x['COW'] == 2:
      return 1
  else:
      return 0

def local_govt_worker(x):
  if x['COW'] == 3:
      return 1
  else:
      return 0

def self_employed_worker(x):
  if x['COW'] == 6:
      return 1
  else:
      return 0

def cow_minorities(x):
  if x['COW'] == 4:
      return 1
  elif x['COW'] == 5:
      return 1
  elif x['COW'] == 7:
      return 1
  elif x['COW'] == 8:
      return 1
  else:
      return 0

# marital status
def married(x):
  if x['MAR'] == 1:
      return 1
  else:
      return 0

def divorced(x):
  if x['MAR'] == 3:
      return 1
  else:
      return 0

def mar_minorities(x):
  if x['MAR'] == 2:
      return 1
  elif x['MAR'] == 4:
      return 1
  else:
      return 0

def single(x):
  if x['MAR'] == 5:
      return 1
  else:
      return 0

# sex
def male(x):
  if x['SEX'] == 1:
      return 1
  else:
      return 0

def female(x):
  if x['SEX'] == 2:
      return 1
  else:
      return 0

# race
def white(x):
  if x['RAC1P'] == 1:
      return 1
  else:
      return 0

def black(x):
  if x['RAC1P'] == 2:
      return 1
  else:
      return 0

def other_race(x):
  if x['RAC1P'] == 9:
      return 1
  else:
      return 0

def asian(x):
  if x['RAC1P'] == 6:
      return 1
  else:
      return 0

def multiple_race(x):
  if x['RAC1P'] == 8:
      return 1
  else:
      return 0

def native_american(x):
  if x['RAC1P'] == 3:
      return 1
  elif x['RAC1P'] == 4:
      return 1
  elif x['RAC1P'] == 5:
      return 1
  elif x['RAC1P'] == 7:
      return 1
  else:
      return 0

# age groups
def twenties(x):
  if x['AGEP'] < 30:
      return 1
  else:
      return 0

def thirties(x):
  if x['AGEP'] < 40 and x['AGEP'] >= 30:
      return 1
  else:
      return 0

def forties(x):
  if x['AGEP'] < 50 and x['AGEP'] >= 40:
      return 1
  else:
      return 0

def fifties(x):
  if x['AGEP'] < 60 and x['AGEP'] >= 50:
      return 1
  else:
      return 0

def sixties(x):
  if x['AGEP'] < 70 and x['AGEP'] >= 60:
      return 1
  else:
      return 0

def elderly(x):
  if x['AGEP'] >= 70:
      return 1
  else:
      return 0

# education level
def no_school(x):
  if x['SCHL'] == 1:
    return 1
  else:
    return 0

def some_school(x):
  if x['SCHL'] > 1 and x['SCHL'] <= 15:
    return 1
  else:
    return 0

def high_school_grad(x):
  if x['SCHL'] > 15 and x['SCHL'] <= 19:
    return 1
  else:
    return 0

def assoc_degree(x):
  if x['SCHL'] == 20:
    return 1
  else:
    return 0

def assoc_degree(x):
  if x['SCHL'] == 20:
    return 1
  else:
    return 0

def bachelor_degree(x):
  if x['SCHL'] == 21:
    return 1
  else:
    return 0

def advanced_degree(x):
  if x['SCHL'] > 21:
    return 1
  else:
    return 0

# work hours
def part_time(x):
  if x['WKHP'] < 30:
      return 1
  else:
      return 0

def full_time(x):
  if x['WKHP'] < 60 and x['WKHP'] >= 30:
      return 1
  else:
      return 0

def over_time(x):
  if x['WKHP'] >= 60:
      return 1
  else:
      return 0

# occupation
def MGR(x):
  if x['OCCP'] <= 440:
    return 1
  else:
    return 0

def BUS(x):
  if x['OCCP'] >= 500 and x['OCCP'] <= 750:
    return 1
  else:
    return 0

def FIN(x):
  if x['OCCP'] >= 800 and x['OCCP'] <= 960:
    return 1
  else:
    return 0 

def CMM(x):
  if x['OCCP'] >= 1005 and x['OCCP'] <= 1240:
    return 1
  else:
    return 0

def ENG(x):
  if x['OCCP'] >= 1305 and x['OCCP'] <= 1560:
    return 1
  else:
    return 0

def SCI(x):
  if x['OCCP'] >= 1600 and x['OCCP'] <= 1980:
    return 1
  else:
    return 0

def CMS(x):
  if x['OCCP'] >= 2001 and x['OCCP'] <= 2060:
    return 1
  else:
    return 0

def LGL(x):
  if x['OCCP'] >= 2105 and x['OCCP'] <= 2180:
    return 1
  else:
    return 0

def EDU(x):
  if x['OCCP'] >= 2205 and x['OCCP'] <= 2555:
    return 1
  else:
    return 0

def ENT(x):
  if x['OCCP'] >= 2600 and x['OCCP'] <= 2920:
    return 1
  else:
    return 0

def MED(x):
  if x['OCCP'] >= 3000 and x['OCCP'] <= 3550:
    return 1
  else:
    return 0

def HLS(x):
  if x['OCCP'] >= 3601 and x['OCCP'] <= 3655:
    return 1
  else:
    return 0

def PRT(x):
  if x['OCCP'] >= 3700 and x['OCCP'] <= 3960:
    return 1
  else:
    return 0

def EAT(x):
  if x['OCCP'] >= 4000 and x['OCCP'] <= 4160:
    return 1
  else:
    return 0

def CLN(x):
  if x['OCCP'] >= 4200 and x['OCCP'] <= 4255:
    return 1
  else:
    return 0

def PRS(x):
  if x['OCCP'] >= 4330 and x['OCCP'] <= 4655:
    return 1
  else:
    return 0

def SAL(x):
  if x['OCCP'] >= 4700 and x['OCCP'] <= 4965:
    return 1
  else:
    return 0

def OFF(x):
  if x['OCCP'] >= 5000 and x['OCCP'] <= 5940:
    return 1
  else:
    return 0

def FFF(x):
  if x['OCCP'] >= 6005 and x['OCCP'] <= 6130:
    return 1
  else:
    return 0

def CON(x):
  if x['OCCP'] >= 6200 and x['OCCP'] <= 6765:
    return 1
  else:
    return 0

def EXT(x):
  if x['OCCP'] >= 6800 and x['OCCP'] <= 6950:
    return 1
  else:
    return 0

def RPR(x):
  if x['OCCP'] >= 7000 and x['OCCP'] <= 7640:
    return 1
  else:
    return 0

def PRD(x):
  if x['OCCP'] >= 7700 and x['OCCP'] <= 8990:
    return 1
  else:
    return 0

def TRN(x):
  if x['OCCP'] >= 9005 and x['OCCP'] <= 9760:
    return 1
  else:
    return 0

def MIL(x):
  if x['OCCP'] >= 9800 and x['OCCP'] <= 9830:
    return 1
  else:
    return 0

# place of birth continents
def us_born(x):
  if x['POBP'] < 100:
    return 1
  else: 
    return 0

def european(x):
  if x['POBP'] < 200 and x['POBP'] >= 100:
    return 1
  else: 
    return 0

def asian(x):
  if x['POBP'] < 300 and x['POBP'] >= 200:
    return 1
  else: 
    return 0

def american(x):
  if x['POBP'] < 400 and x['POBP'] >= 300:
    return 1
  else: 
    return 0

def african(x):
  if x['POBP'] < 500 and x['POBP'] >= 400:
    return 1
  else: 
    return 0

def oceanian(x):
  if x['POBP'] < 600 and x['POBP'] >= 500:
    return 1
  else: 
    return 0

# manually defined groups that tend to be systematically mistreated and would be educated guesses to test on the model

def black_men(x):
  if x['RAC1P'] == 2 and x['SEX'] == 1:
    return 1
  else: 
    return 0

def black_women(x):
  if x['RAC1P'] == 2 and x['SEX'] == 2:
    return 1
  else: 
    return 0

def asian_men(x):
  if x['RAC1P'] == 6 and x['SEX'] == 1:
    return 1
  else: 
    return 0

def asian_women(x):
  if x['RAC1P'] == 6 and x['SEX'] == 2:
    return 1
  else: 
    return 0

def white_men(x):
  if x['RAC1P'] == 1 and x['SEX'] == 1:
    return 1
  else: 
    return 0

def white_women(x):
  if x['RAC1P'] == 1 and x['SEX'] == 2:
    return 1
  else: 
    return 0

def native_american_men(x):
  if x['RAC1P'] == 3 and x['SEX'] == 1:
      return 1
  elif x['RAC1P'] == 4 and x['SEX'] == 1:
      return 1
  elif x['RAC1P'] == 5 and x['SEX'] == 1:
      return 1
  elif x['RAC1P'] == 7 and x['SEX'] == 1:
      return 1
  else:
      return 0

def native_american_women(x):
  if x['RAC1P'] == 3 and x['SEX'] == 2:
      return 1
  elif x['RAC1P'] == 4 and x['SEX'] == 2:
      return 1
  elif x['RAC1P'] == 5 and x['SEX'] == 2:
      return 1
  elif x['RAC1P'] == 7 and x['SEX'] == 2:
      return 1
  else:
      return 0

def middle_aged_black(x):
  if x['RAC1P'] == 2 and x['AGEP'] >= 40 and x['AGEP'] < 70:
      return 1
  else:
      return 0

def middle_aged_asian(x):
  if x['RAC1P'] == 6 and x['AGEP'] >= 40 and x['AGEP'] < 70:
      return 1
  else:
      return 0

def middle_aged_white(x):
  if x['RAC1P'] == 1 and x['AGEP'] >= 40 and x['AGEP'] < 70:
      return 1
  else:
      return 0

def middle_aged_native_american(x):
  if x['RAC1P'] == 3 and x['AGEP'] >= 40 and x['AGEP'] < 70:
      return 1
  elif x['RAC1P'] == 4 and x['AGEP'] >= 40 and x['AGEP'] < 70:
      return 1
  elif x['RAC1P'] == 5 and x['AGEP'] >= 40 and x['AGEP'] < 70:
      return 1
  elif x['RAC1P'] == 7 and x['AGEP'] >= 40 and x['AGEP'] < 70:
      return 1
  else:
      return 0

def overtime_white(x):
  if x['RAC1P'] == 1 and x['WKHP'] >= 60:
      return 1
  else:
      return 0

def overtime_black(x):
  if x['RAC1P'] == 2 and x['WKHP'] >= 60:
      return 1
  else:
      return 0

def overtime_asian(x):
  if x['RAC1P'] == 6 and x['WKHP'] >= 60:
      return 1
  else:
      return 0

def overtime_native_american(x):
  if x['RAC1P'] == 3 and x['WKHP'] >= 60:
      return 1
  elif x['RAC1P'] == 4 and x['WKHP'] >= 60:
      return 1
  elif x['RAC1P'] == 5 and x['WKHP'] >= 60:
      return 1
  elif x['RAC1P'] == 7 and x['WKHP'] >= 60:
      return 1
  else:
      return 0

def married_men(x):
  if x['MAR'] == 1 and x['SEX'] == 1:
    return 1
  else: 
    return 0

def married_women(x):
  if x['MAR'] == 1 and x['SEX'] == 2:
    return 1
  else: 
    return 0

def divorced_men(x):
  if x['MAR'] == 3 and x['SEX'] == 1:
    return 1
  else: 
    return 0

def divorced_women(x):
  if x['MAR'] == 3 and x['SEX'] == 2:
    return 1
  else: 
    return 0

def single_men(x):
  if x['MAR'] == 5 and x['SEX'] == 1:
    return 1
  else: 
    return 0

def single_women(x):
  if x['MAR'] == 5 and x['SEX'] == 2:
    return 1
  else: 
    return 0

def divorced_black_women(x):
  if x['MAR'] == 3 and x['SEX'] == 2 and x['RAC1P'] == 2:
    return 1
  else: 
    return 0   

def single_black_women(x):
  if x['MAR'] == 5 and x['SEX'] == 2 and x['RAC1P'] == 2:
    return 1
  else: 
    return 0

def divorced_white_women(x):
  if x['MAR'] == 3 and x['SEX'] == 2 and x['RAC1P'] == 1:
    return 1
  else: 
    return 0   

def single_white_women(x):
  if x['MAR'] == 5 and x['SEX'] == 2 and x['RAC1P'] == 1:
    return 1
  else: 
    return 0

def divorced_asian_women(x):
  if x['MAR'] == 3 and x['SEX'] == 2 and x['RAC1P'] == 6:
    return 1
  else: 
    return 0   

def single_asian_women(x):
  if x['MAR'] == 5 and x['SEX'] == 2 and x['RAC1P'] == 6:
    return 1
  else: 
    return 0

def divorced_native_american_women(x):
  if x['MAR'] == 3 and x['SEX'] == 2 and (x['RAC1P'] == 3 or x['RAC1P'] == 4 or x['RAC1P'] == 5 or x['RAC1P'] == 7):
    return 1
  else: 
    return 0   

def single_native_american_women(x):
  if x['MAR'] == 5 and x['SEX'] == 2 and (x['RAC1P'] == 3 or x['RAC1P'] == 4 or x['RAC1P'] == 5 or x['RAC1P'] == 7):
    return 1
  else: 
    return 0

# I used simple udate, so all of my h's followed this pattern
# h = build_model(train_x, train_y, g, dt_depth=10) #bountyHuntWrapper.build_model(train_x, train_y, g1, dt_depth=10)  # change this to be your first h