In [None]:
# Pick up edits to imported project modules live, without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os
import warnings  
import csv
# ignore pandas FutureWarning below
warnings.simplefilter(action='ignore', category=FutureWarning)
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from utils.constants import PREDICTIONS_DIR
from dateutil.relativedelta import relativedelta
from datetime import date
from postmodeling.aggregate_lists import df_val_date, get_referral_lists, get_all_referred_joids, get_aggregated_referral_info
from postmodeling.analyze_labels import get_all_flagged_events
from utils.helpers import get_database_connection
from pipeline.matrix import make_str_array

In [None]:
# First and last validation dates passed to the referral aggregation.
min_val_date = date(2020, 6, 1)
max_val_date = date(2021, 5, 1)

In [None]:
# Experiment ids and model-set ids per county. Only the model-set ids are
# placed in `best_models`; the experiment ids are kept here for reference.
joco_exp_ids = [372, 231, 233, 363, 8]
joco_model_set_ids = [4739, 1924, 1926, 4733, 38]
doco_exp_ids = [149, 231, 233, 201, 371]
doco_model_set_ids = [792, 1924, 1926, 1487, 4737]

# county -> {'model_set_ids': [...]}
best_models = {
    county_key: {'model_set_ids': county_model_sets}
    for county_key, county_model_sets in (
        ('joco', joco_model_set_ids),
        ('doco', doco_model_set_ids),
    )
}

# Label groups passed alongside `best_models` to the aggregation call.
label_groups = (
    'death only',
    'potentially fatal',
    'suicide-related only',
    'drug-related only',
    'all behavioral crises',
)
# Connection handle that is passed to get_aggregated_referral_info.
db_conn = get_database_connection()

In [None]:
# Aggregate referral information for the configured models over the validation
# window; the result is unpacked into an event-level frame and a deaths frame.
referral_info = get_aggregated_referral_info(
    db_conn,
    best_models,
    label_groups,
    min_val_date,
    max_val_date,
)
concat_df, deaths_df = referral_info

In [None]:
# Peek at the first rows of the deaths frame.
deaths_df.head(n=5)

In [None]:
# Peek at the first rows of the event-level frame.
concat_df.head(n=5)

In [None]:
# Rows for Johnson County under the 'potentially fatal' label group.
is_joco = concat_df['county'] == 'joco'
is_potentially_fatal = concat_df['label_group'] == 'potentially fatal'
concat_df[is_joco & is_potentially_fatal]

In [None]:
# Keep every event type except 'death_flag'.
not_death_flag = concat_df['event_type'] != 'death_flag'
concat_df_no_death_flag = concat_df[not_death_flag]

In [None]:
# Peek at the first rows after dropping 'death_flag' events.
concat_df_no_death_flag.head(n=5)

In [None]:
# Horizontal bars of validation_period per event type, faceted by
# county (rows) and label group (columns).
g = sns.catplot(
    data=concat_df_no_death_flag,
    kind='bar',
    x='validation_period',
    y='event_type',
    row='county',
    col='label_group',
    orient='h',
    color='#33485E',
)

In [None]:
# Show the suicide-attempt rows of the 'potentially fatal' label group, one
# table per county. Each county's table is rendered exactly once via display().
for county in ['joco', 'doco']:
    display(concat_df[
        (concat_df['event_type'] == 'suicide_attempt_flag')
        & (concat_df['county'] == county)
        & (concat_df['label_group'] == 'potentially fatal')
    ])

In [None]:
# Bar plot built from the deaths frame, faceted by county (rows) and
# label group (columns).
sns.catplot(
    data=deaths_df,
    kind='bar',
    row='county',
    col='label_group',
    orient='h',
    color='#33485E',
)

### Prettier plots

In [None]:
def _plot_event_counts(county_events, county_title):
    """Draw one horizontal bar panel per label group for a single county.

    Args:
        county_events: rows of the cleaned event frame for one county; the
            plot reads its 'event_type', 'validation_period' and
            'label_group' columns.
        county_title: figure-level title (the county's display name).

    Returns:
        The grid object returned by ``sns.catplot``.
    """
    grid = sns.catplot(y='event_type', x='validation_period', col='label_group', kind='bar',
                       data=county_events, orient='h', color='#33485E')
    grid.set_titles("Model for {col_name}", size=20)
    grid.set(xticks=[0, 100, 200, 300, 400], xlabel='Counts', ylabel='Event Type')
    # y > 1 places the figure title above the per-panel titles.
    grid.fig.suptitle(county_title, y=1.05)
    return grid


sns.set(font_scale=2)

# Drop 'death_flag' events and turn the flag names into readable labels.
# .assign() builds a new frame, so nothing is written into a slice of
# `concat_df` (the original column assignment on the filtered frame was a
# chained assignment). regex=False makes the literal replacement explicit
# instead of depending on the pandas-version-specific default.
concat_df_no_death_flag = (
    concat_df[concat_df['event_type'] != 'death_flag']
    .assign(
        event_type=lambda events: (
            events['event_type']
            .str.replace('_flag', '', regex=False)
            .str.replace('_', ' ', regex=False)
        )
    )
)

concat_df_no_death_flag_joco = concat_df_no_death_flag[concat_df_no_death_flag['county'] == 'joco']
concat_df_no_death_flag_doco = concat_df_no_death_flag[concat_df_no_death_flag['county'] == 'doco']

_plot_event_counts(concat_df_no_death_flag_joco, 'Johnson County')
# `g` ends up bound to the Douglas County grid, as before.
g = _plot_event_counts(concat_df_no_death_flag_doco, 'Douglas County')

In [None]:
def _plot_death_counts(county_deaths, county_title):
    """Draw one horizontal bar panel per label group for one county's deaths.

    Args:
        county_deaths: rows of ``deaths_df`` for one county; the plot is
            faceted on its 'label_group' column.
        county_title: figure-level title (the county's display name).

    Returns:
        The grid object returned by ``sns.catplot``.
    """
    grid = sns.catplot(col='label_group', kind='bar', data=county_deaths, orient='h', color='#33485E')
    grid.set_titles("Model for {col_name}", size=20)
    grid.set(xticks=[0, 2, 4, 6, 8])
    # y > 1 places the figure title above the per-panel titles.
    grid.fig.suptitle(county_title, y=1.05)
    return grid


# Relabel the 'suic_or_od' column as 'both' for display; rename() is a no-op
# when the column has already been renamed, so re-running this cell is safe.
deaths_df = deaths_df.rename(columns={'suic_or_od': 'both'})
deaths_df_joco = deaths_df[deaths_df['county'] == 'joco']
deaths_df_doco = deaths_df[deaths_df['county'] == 'doco']

_plot_death_counts(deaths_df_joco, 'Johnson County')
# `g` ends up bound to the Douglas County grid, as before.
g = _plot_death_counts(deaths_df_doco, 'Douglas County')

### Get a sample list for each county

In [None]:
model_set_id = 1924
# NOTE: this rebinds min_val_date / max_val_date, which earlier cells set to a
# different window; cells run after this one see the single date below.
min_val_date = max_val_date = date(2021, 9, 1)

# List size (k) requested per county.
k_by_county = {'joco': 75, 'doco': 40}

ref_lists = {}
for county, low_k in k_by_county.items():
    # Lower and upper k are the same, so only one list size is requested.
    high_k = low_k
    all_lists = get_referral_lists(model_set_id, low_k, high_k, county, min_val_date, max_val_date, 1)
    # Results are indexed by k; keep the first entry stored under that k.
    ref_lists[county] = all_lists[high_k][0]

In [None]:
# Write joids to a csv
# Write each county's sample referral list to a one-column CSV (one joid per row).
for county in ['joco', 'doco']:
    print(len(ref_lists[county]))
    # newline='' lets the csv module control line endings itself; without it,
    # platforms that translate '\n' on write end up with '\r\r\n' row endings,
    # which read back as a blank line after every row.
    with open(f'sample_list_{county}.csv', 'w', newline='') as f:
        writer = csv.writer(f)
        writer.writerows([joid] for joid in ref_lists[county])

In [None]:
# Read the per-county joid CSVs back in as single-column frames,
# in the order joco, doco.
dfs = []
for county in ('joco', 'doco'):
    print(len(ref_lists[county]))
    with open(f'sample_list_{county}.csv', 'r') as csv_file:
        # The files have no header row; name the only column 'joid'.
        county_joids = pd.read_csv(csv_file, header=None, names=['joid'])
    dfs.append(county_joids)

In [None]:
table_names = ['joco', 'doco']
db_conn = get_database_connection()
# Upload each county's joid frame to a table of the same name in the
# 'sample_lists' schema.
# NOTE(review): no if_exists is passed, so pandas' default applies — confirm
# that re-running this cell against already-existing tables behaves as intended.
for df, tn in zip(dfs, table_names):
    df.to_sql(name=tn, con=db_conn, schema='sample_lists')