In [None]:
import os
import numpy as np
import pandas as pd
import seaborn as sns
import aequitas.plot as ap
from aequitas.bias import Bias
from aequitas.group import Group
from datetime import datetime
from matplotlib import pyplot as plt
from dateutil.relativedelta import relativedelta
from utils.constants import PREDICTIONS_DIR, CONFIGS_PATH
import postmodeling.analyze_labels as analyze_labels
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from postmodeling.evaluation import (
    get_evaluation,
    get_predictions,
    rank_models,
    get_best_modelsets,
    get_confusion_matrix,
    plot_confusion_matrix,
    get_features_test_pred,
    create_crosstabs,
    get_models_info,
    get_model_info_from_experiment_ids,
    get_test_pred_labels_from_csv,
    _calculate_metric
)
from postmodeling.plotting import (
    plot_pr_curve,
    plot_score_dist,
    plot_crosstabs_models,
    plot_feature_importance,
    plot_disparity
)
from utils.helpers import (
    get_database_connection,
    get_model_ids
)
from postmodeling.fairness import (
    get_score_attr_df,
    get_group_metrics,
    get_demographics_data,
    get_absolute_metrics,
    enrich_demographics
)
from postmodeling.analyze_labels import (
    get_preds_split_labels,
    get_label_aggregations,
    plot_split_labels
)
from altair_saver import save

In [None]:
import warnings
warnings.filterwarnings('ignore') # To ignore seaborn warnings

# Overview

This notebook runs the postmodeling pipeline for the best model of a given county and label group (both set in the configuration cell). Specifically, we:

    [x] choose the best model and best baseline for the particular county and a particular label group.
    [x] visualize the P/R curve for this model against the best baseline model.
    [x] visualize the feature importance for the best model.
    [x] show the cross tabs for the best model across sex, race, and age.
    [x] do the fairness audit for sex and race.
    [x] analyze which labels the model is picking up on.

In [None]:
# Analysis configuration -- change these to rerun the notebook for another setup.
COUNTY = 'joco'  # later cells branch on exactly 'joco' or 'doco'
COUNTY_k = 75  # rows with county_k <= COUNTY_k are treated as predicted positive
LABEL_GROUP = 'Potentially fatal'
MONTHS_FUTURE = 6  # length (in months) of the validation label window
MIN_DATES = 6 if COUNTY == 'doco' else 4  # min_dates passed to get_best_modelsets
FIGSIZE = (16, 12)

# Fail fast on a typo: downstream cells only know how to handle these two counties.
assert COUNTY in ('joco', 'doco'), f"COUNTY must be 'joco' or 'doco', got {COUNTY!r}"

db_conn = get_database_connection()

plt.rcParams["figure.figsize"] = FIGSIZE

# Model selection
Here we pick the best model and baseline for the particular county.

In [None]:
# Top-ranked model set (ranked on regret) per label group for COUNTY
best_models = get_best_modelsets(
    db_conn,
    county=COUNTY,
    rank_on='regret',
    top=1,
    months_future=MONTHS_FUTURE,
    min_dates=MIN_DATES,
)

In [None]:
# Show the selected model sets (one row per label group is filtered on later via 'label_group')
best_models

In [None]:
# Same ranking, but only over the baseline model types (FeatureRanker / LinearRanker),
# with no model types excluded
best_baselines = get_best_modelsets(
    db_conn,
    county=COUNTY,
    rank_on='regret',
    top=1,
    model_types=['FeatureRanker', 'LinearRanker'],
    exclude_types=None,
    months_future=MONTHS_FUTURE,
    min_dates=MIN_DATES,
)

In [None]:
# Show the selected baseline model sets
best_baselines

# Precision / Recall curve
Here we visualize the precision and recall curve for the best model and the best baseline for the latest validation split.

In [None]:
def get_best_model_ids(best_models, label_group=None):
    """Return the model ids of the best model set for a label group.

    Parameters
    ----------
    best_models : pandas.DataFrame
        Output of ``get_best_modelsets``; needs ``label_group`` and
        ``model_set_id`` columns.
    label_group : str, optional
        Label group to look up. Defaults to the notebook-level ``LABEL_GROUP``
        (read at call time, so later edits to ``LABEL_GROUP`` are honoured).

    Returns
    -------
    Whatever ``get_model_ids(db_conn, model_set_id)`` returns for that model set
    (presumably one model id per split -- TODO confirm).

    Raises
    ------
    ValueError
        If ``best_models`` has no row for ``label_group``.
    """
    if label_group is None:
        label_group = LABEL_GROUP

    model_set_ids = best_models.loc[best_models['label_group'] == label_group, 'model_set_id']
    if model_set_ids.empty:
        # Without this check `.values[0]` fails with an uninformative IndexError
        raise ValueError(f'No model set found for label group {label_group!r}')

    # Use the first matching model set
    best_model_set_id = model_set_ids.values[0]

    # All model ids belonging to that model set (not just the last split)
    return get_model_ids(db_conn, best_model_set_id)
    
def get_best_model_id_last(best_models):
    """Return one model id from the best model set for LABEL_GROUP.

    The smallest id among the set's model ids is chosen.
    NOTE(review): the name says "last" split; this relies on the most recent
    split having the lowest model id -- confirm against how ids are assigned.
    """
    all_ids = get_best_model_ids(best_models)
    smallest_id = np.min(all_ids)
    return smallest_id
    
def get_best_evaluation(best_models, model_type='Model'):
    """Return COUNTY's evaluation rows for one model of the best model set.

    Parameters
    ----------
    best_models : pandas.DataFrame
        Output of ``get_best_modelsets`` (see ``get_best_model_ids``).
    model_type : str, default 'Model'
        Value written to the ``Type`` column so best-model and baseline rows can
        be told apart after they are concatenated.

    Returns
    -------
    pandas.DataFrame
        A new frame with the rows of ``get_evaluation`` whose ``county`` equals
        ``COUNTY``, plus a constant ``Type`` column.
    """
    model_id = get_best_model_id_last(best_models)

    df_eval = get_evaluation(db_conn, model_id)
    # .loc + .assign build a new frame instead of writing a column into a filtered
    # slice of df_eval, which pandas flags with SettingWithCopyWarning.
    return df_eval.loc[df_eval['county'] == COUNTY].assign(Type=model_type)

In [None]:
# Evaluation rows for the best model, tagged Type='Model'
df_best = get_best_evaluation(best_models, model_type='Model')

In [None]:
# Evaluation rows for the best baseline, tagged Type='Baseline'
df_baseline = get_best_evaluation(best_models=best_baselines, model_type='Baseline')

In [None]:
# Stack best-model and baseline evaluations; the Type column tells them apart
df = pd.concat((df_best, df_baseline), axis=0)

In [None]:
# Precision/recall curves for the best model and the best baseline
p = plot_pr_curve(
    df,
    county=COUNTY,
    label_group=LABEL_GROUP,
    figsize=FIGSIZE,
);

# Score distribution
Here we visualize the score distribution for the best model.

In [None]:
def add_predictions(df_eval, k=75, county=None):
    """Return a copy of ``df_eval`` with a binary ``predictions`` column.

    A row is predicted positive (1) when it belongs to ``county`` and its
    ``county_k`` is at most ``k``; every other row gets 0.

    Parameters
    ----------
    df_eval : pandas.DataFrame
        Must contain ``county`` and ``county_k`` columns.
    k : int, default 75
        Inclusive cutoff on ``county_k``.
    county : str, optional
        County whose rows may be flagged. Defaults to the notebook-level
        ``COUNTY`` (read at call time).

    Returns
    -------
    pandas.DataFrame
        A new frame; ``df_eval`` itself is not modified.
    """
    if county is None:
        county = COUNTY

    df_eval = df_eval.copy()
    in_top_k = (df_eval['county'] == county) & (df_eval['county_k'] <= k)
    df_eval['predictions'] = np.where(in_top_k, 1, 0)
    return df_eval

In [None]:
# Keep only the precision rows: df_best repeats each entry once per metric
# (precision and recall), so this drops the duplicates before scoring.
df_score = add_predictions(df_best.loc[df_best['metric'].eq('precision')], k=COUNTY_k)

In [None]:
# Score distribution of the best model (df_score carries the 0/1 'predictions' column from add_predictions)
plot_score_dist(df_score);

# Feature importance
Here we plot the feature importance of the best model.

In [None]:
# One model id of the best model set (the smallest id; see get_best_model_id_last).
# Also reused by the crosstab, fairness and label-analysis cells below.
latest_model_id = get_best_model_id_last(best_models)
p = plot_feature_importance(db_conn, [latest_model_id]);

In [None]:
#p.get_figure().savefig('feature_importance_johnson.eps', dpi=200, bbox_inches='tight')

In [None]:
# Hand-written display names for the features of the best Johnson county model
# (random forest), meant for the `feature_names` argument of the commented-out
# plot_feature_importance call below.
# NOTE(review): presumably the order must match the feature order used by
# plot_feature_importance -- confirm before enabling the call.
feature_names = [
    'Age',
    'Any event (DSL)',
    'Ambulance run (DSL)',
    'Douglas service (DSL)',
    'JIMS charges (DSL)',
    'Jail booking (DSF)',
    'Diagnosis (DSL)',
    'JIMS prosecution charges (DSL)',
    'Jail booking (DSL)',
    'Admission MHC (DSL)',
    'Drug ambulance run (DSL)',
    'Overall ambulance runs',
    'Alcohol ambulance run (DSL)',
    'Johnson service (DSL)',
    'Ambulance runs last 5 years',
    'Overall jail bookings',
    'Diagnosis MHC (DSL)',
    'Discharges MHC (DSL)',
    'Ambulance runs last 2 years',
    'Drug ambulance runs 2 years',
]
#p = plot_feature_importance(db_conn, [latest_model_id], feature_names=feature_names);

# Confusion matrix and crosstabs
Here we look at the confusion matrix and crosstabs for the best model on the latest split.

## Confusion matrix
Here we see the confusion matrix for the best model and latest split.

In [None]:
# Confusion matrix of the true labels against the county_k-cutoff predictions from add_predictions
plot_confusion_matrix(confusion_matrix(y_true=df_score['label'], y_pred=df_score['predictions'])).plot();

## Crosstabs: Sex

In [None]:
# Only COUNTY gets the COUNTY_k cutoff; the other county's k is None.
if COUNTY == 'joco':
    doco_k, joco_k = None, COUNTY_k
elif COUNTY == 'doco':
    doco_k, joco_k = COUNTY_k, None
else:
    doco_k = joco_k = COUNTY_k

# Crosstabs input: categorical demographics joined with test predictions,
# with missing values shown as an explicit 'Missing' category
features_test_pred = get_features_test_pred(
    db_conn, latest_model_id, 'demographics_cat', doco_k=doco_k, joco_k=joco_k
).fillna('Missing')

In [None]:
# Label and prediction crosstabs (per the returned names) across values of dem_sex
label_crosstab, pred_crosstab = create_crosstabs(features_test_pred, "dem_sex")
print("SEX:")
# display() renders each crosstab as a rich table instead of print()'s plain text
display(label_crosstab, pred_crosstab)

## Crosstabs: Race

In [None]:
# Label and prediction crosstabs (per the returned names) across values of dem_race
label_crosstab, pred_crosstab = create_crosstabs(features_test_pred, "dem_race")
print("RACE:")
# display() renders each crosstab as a rich table instead of print()'s plain text
display(label_crosstab, pred_crosstab)

## Crosstabs: Age

In [None]:
# Crosstabs input: numeric demographics (age) joined with test predictions
features_test_pred_age = get_features_test_pred(
    db_conn, latest_model_id, 'demographics_num',
    doco_k=doco_k, joco_k=joco_k,
)

# Age splitting buckets: bin edges plus one label per bucket (len(names) == len(bins) - 1).
# NOTE(review): if create_crosstabs bins with pd.cut's default right-closed intervals,
# the first bucket is (0, 20], so age 20 lands under '1. <20' -- confirm the label.
bins = [0.0, 20.0, 40.0, 65.0, np.inf]
names = ['1. <20', '2. 21-40', '3. 41-65', '4. 66+']
split_tuple = (bins, names)

In [None]:
# Age crosstabs, bucketed with split_tuple.
# NOTE(review): this comment used to say "Both Counties", but features_test_pred_age was
# fetched with a k cutoff only for COUNTY (the other k is None) -- confirm which
# counties these rows cover.
label_crosstab, pred_crosstab = create_crosstabs(features_test_pred_age, "dem_age", split_tuple)
print("AGE:")
# display() renders each crosstab as a rich table instead of print()'s plain text
display(label_crosstab, pred_crosstab)

# Fairness analysis
Here we assess potential disparities of the best model on the latest split.

In [None]:
# Demographic attributes used for the fairness audit
df_dem = get_demographics_data(db_conn, attributes=['sex', 'race', 'ethnicity'])

In [None]:
# Reference group per attribute (per the variable name), passed to get_group_metrics.
# NOTE(review): df_dem is fetched with 'ethnicity' but the key here is 'hispanic';
# presumably a 'hispanic' column is derived downstream -- confirm.
attr_and_ref_groups = dict(sex='MALE', race='W', hispanic='YES')

In [None]:
# Score + attribute table for latest_model_id, then enriched with demographics
# from other tables (that do not have an event date)
df_fair = enrich_demographics(get_score_attr_df(df_dem, latest_model_id), df_dem)

In [None]:
# Restrict to the selected county before computing group metrics
df_fair_county = df_fair.loc[df_fair['county'].eq(COUNTY)]
xtab, df_metrics = get_group_metrics(df_fair_county, attr_and_ref_groups)

## Disparities across Sex

In [None]:
# Sex breakdown among rows with county_k <= COUNTY_k (the predicted-positive set)
df_fair_county.loc[df_fair_county['county_k'].le(COUNTY_k)].value_counts('sex')

In [None]:
# Sex breakdown over all of the county's rows, for comparison with the cell above
df_fair_county.value_counts('sex')

In [None]:
# Precision disparity across sex values
plot_disparity(df_metrics, 'precision', 'sex')

In [None]:
# True-positive-rate (recall) disparity across sex values
plot_disparity(df_metrics, 'tpr', 'sex')

## Disparities across Race

In [None]:
# Race breakdown among rows with county_k <= COUNTY_k (the predicted-positive set)
df_fair_county.loc[df_fair_county['county_k'].le(COUNTY_k)].value_counts('race')

In [None]:
# Race breakdown over all of the county's rows, for comparison with the cell above
df_fair_county.value_counts('race')

In [None]:
# Race rows only. The attribute_name filter matters: the value regex alone also
# matches other attributes' values (e.g. 'MALE'/'FEMALE' contain 'A', and any
# 'MISSING' value), which would leak sex/hispanic rows into the race plots.
# na=False keeps NaN attribute values from breaking the boolean mask.
df_metrics_r = df_metrics[
    (df_metrics['attribute_name'] == 'race')
    & df_metrics['attribute_value'].str.contains('B|W|I|A|MISSING', na=False)
]

In [None]:
# Precision disparity across the race values kept in df_metrics_r
plot_disparity(df_metrics_r, 'precision', 'race')

In [None]:
# True-positive-rate disparity across race values. Keep a handle on the chart
# (presumably an Altair chart, since it is exported with altair_saver's `save`) so the
# export below saves this chart; `p` at this point holds the feature-importance plot.
race_tpr_chart = plot_disparity(df_metrics_r, 'tpr', 'race')
#save(race_tpr_chart, 'race_johnson.png', dpi=600)
race_tpr_chart

In [None]:
# Per attribute value: group size, number of positive labels, and TP / FP counts
df_metrics.loc[:, ['attribute_name', 'attribute_value', 'group_label_pos', 'group_size', 'tp', 'fp']]

In [None]:
# Show all absolute group metrics derived from the crosstab returned by get_group_metrics
get_absolute_metrics(xtab)

# Analyzing labels
Here we look into which labels the model actually seems to be picking up: first within the validation window (the next `MONTHS_FUTURE` months), then over all time after the as-of date.

In [None]:
def get_label_counts(df_joco, df_doco):
    """Return label aggregations for the notebook's COUNTY.

    Parameters
    ----------
    df_joco, df_doco
        Per-county tables as returned by ``get_preds_split_labels``; only the
        one matching ``COUNTY`` is aggregated.

    Returns
    -------
    The result of ``get_label_aggregations`` on the selected table.

    Raises
    ------
    ValueError
        If ``COUNTY`` is neither ``'joco'`` nor ``'doco'`` (previously this
        surfaced as an UnboundLocalError).
    """
    tables_by_county = {'joco': df_joco, 'doco': df_doco}
    if COUNTY not in tables_by_county:
        raise ValueError(f"Unsupported COUNTY {COUNTY!r}; expected 'joco' or 'doco'")
    return get_label_aggregations(tables_by_county[COUNTY])

In [None]:
# All model ids of the best model set for LABEL_GROUP
best_model_ids = get_best_model_ids(best_models=best_models)

In [None]:
# Get tables for just the next MONTHS_FUTURE (the validation window); the third
# returned value is not used here
df_joco, df_doco, _ = get_preds_split_labels(db_conn, best_model_ids)

In [None]:
# Label counts for COUNTY within the validation window
print(f'COUNTS FOR THE VALIDATION PERIOD OF {MONTHS_FUTURE} MONTHS')
label_counts = get_label_counts(df_joco, df_doco)
display(label_counts)

In [None]:
# Plot just within the validation window (months_future=MONTHS_FUTURE):
plot_split_labels(label_counts, latest_model_id, months_future=MONTHS_FUTURE);

In [None]:
# Get tables for all time after the as-of date, read from the 'split_labels_all_time'
# table; the third returned value is not used here
df_joco_all, df_doco_all, _ = get_preds_split_labels(db_conn, best_model_ids, label_tablename='split_labels_all_time')

In [None]:
# Label counts for COUNTY over all time after the as-of date
print('\nCOUNTS FOR ALL TIME IN FUTURE OF THE AS OF DATE')
label_counts_all = get_label_counts(df_joco_all, df_doco_all)
display(label_counts_all)

In [None]:
# Since we look at the latest split, this is very similar to the first plot, maybe not necessary to show it at all
#plot_split_labels(label_counts_all, latest_model_id, months_future='any');