In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px
import geopandas as gpd
import statsmodels.formula.api as smf
import statsmodels.api as sm
from pathlib import Path



### Binning

In [None]:
def bin_column(df, col, nbins, label_strategy = None, bin_strategy = 'quantiles'):
    """Bin the values of ``df[col]`` and return the binned result.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame holding the column to bin.
    col : str
        Name of the numeric column to bin.
    nbins : int
        Requested number of bins. Not used by the ``'readable_5'`` strategy,
        whose edges are derived from the data range instead.
    label_strategy : None, False, sequence of labels, or ``'means'``
        Forwarded as ``labels=`` to ``pd.qcut`` / ``pd.cut``. The special
        value ``'means'`` is currently a placeholder that produces generic
        ``'<col>_bin_<i>'`` names (it does not yet compute bin means).
    bin_strategy : str
        One of:

        * ``'quantiles'``  - ``pd.qcut`` into ``nbins`` quantiles (duplicate
          edges are dropped, so fewer bins may result).
        * ``'equal'``      - ``pd.cut`` into ``nbins`` equal-width bins.
        * ``'readable_5'`` - left-closed bins with edges
          ``[min, 20, 22, 24, ...]`` in steps of 2 up to the column maximum.
        * ``'0_1_more'``   - left-closed bins ``[0, 1)``, ``[1, q0)`` followed
          by ``nbins - 2`` quantile bins of the values greater than 1.
        * ``'0_more'``     - left-closed bin ``[0, q0)`` followed by
          ``nbins - 2`` quantile bins of the values greater than or equal to 1.

    Returns
    -------
    The object returned by ``pd.qcut`` / ``pd.cut`` for ``df[col]``.

    Raises
    ------
    ValueError
        If ``bin_strategy`` is not one of the supported strategies.
    """
    valid_strategies = ('quantiles', 'equal', 'readable_5', '0_1_more', '0_more')
    if bin_strategy not in valid_strategies:
        allowed = ', '.join(valid_strategies)
        raise ValueError(f'bin_strategy must be one of {allowed}; got {bin_strategy!r}')

    if label_strategy == 'means':
        # Placeholder: generic names rather than actual bin means.
        # NOTE(review): the list always has nbins entries, which will not match
        # strategies that end up producing a different number of bins.
        label_strategy = [f'{col}_bin_{i}' for i in range(nbins)]

    values = df[col]

    if bin_strategy == 'quantiles':
        return pd.qcut(values, q=nbins, labels=label_strategy, duplicates='drop', precision = 1)

    if bin_strategy == 'equal':
        return pd.cut(values, bins=nbins, labels=label_strategy, include_lowest = True)

    if bin_strategy == 'readable_5':
        # Everything below 20 goes into one bin, then 2-unit-wide bins follow.
        # NOTE(review): despite the name the step is 2, not 5. The edges are
        # only increasing when the column minimum is below 20, and a maximum
        # that falls exactly on the last edge is left unbinned (right=False).
        upper = 2 + int(np.ceil(values.max()))
        bins = [values.min(), 20] + list(range(22, upper, 2))
        return pd.cut(values, bins = bins, labels = label_strategy, include_lowest = True, right=False)

    if bin_strategy == '0_1_more':
        _ , otherbins = pd.qcut(df.loc[values > 1, col], q=nbins-2, retbins = True, precision = 1)
        # Bins are left-closed, so widen the last edge to keep the maximum inside it.
        otherbins[-1] = otherbins[-1]+1
        bins = [0,1] + list(otherbins)
        return pd.cut(values, bins=bins, labels = label_strategy, include_lowest = True, right=False)

    # Remaining strategy: '0_more'.
    # NOTE(review): q=nbins-2 plus the single leading [0, q0) bin gives
    # nbins - 1 bins in total, one fewer than '0_1_more' - confirm intended.
    _ , otherbins = pd.qcut(df.loc[values >= 1, col], q=nbins-2, retbins = True, precision = 1)
    # Bins are left-closed, so widen the last edge to keep the maximum inside it.
    otherbins[-1] = otherbins[-1]+1
    bins = [0,] + list(otherbins)
    return pd.cut(values, bins=bins, labels = label_strategy, include_lowest = True, right=False)

def group_and_bin_column(df, group_cols, bin_col, nbins, bin_strategy = 'quantiles', label_strategy = None, result_column = None, keep_count = False):
    """Bin ``bin_col`` over its unique (group, value) combinations and attach the bins to ``df``.

    The frame is first collapsed to one row per unique combination of
    ``group_cols + [bin_col]``, so each group/value pair counts once when the
    bin edges are computed (rather than once per original row). The resulting
    bin is then joined back onto every row of ``df``.

    Parameters
    ----------
    df : pandas.DataFrame
        Input rows.
    group_cols : sequence of str
        Columns that identify the unit being binned.
    bin_col : str
        Column whose values are binned.
    nbins, bin_strategy, label_strategy
        Forwarded to ``bin_column``.
    result_column : str, optional
        Name of the new bin column; defaults to ``bin_col + '_bin'``.
    keep_count : bool
        If True, also keep the per-combination row count in a ``size`` column.

    Returns
    -------
    pandas.DataFrame
        ``df`` left-joined with the bin column (and ``size`` if requested).
        Rows whose key columns contain NaN receive no bin, because
        ``groupby`` drops NaN keys by default.
    """
    if result_column is None:
        result_column = bin_col + '_bin'

    merge_keys = list(group_cols) + [bin_col]

    grouped_df = df.groupby(merge_keys, as_index=False).size()
    if not keep_count:
        grouped_df = grouped_df.drop(columns=['size'])
    grouped_df[result_column] = bin_column(grouped_df, bin_col, nbins, label_strategy, bin_strategy)

    # Join on the grouping keys explicitly. Relying on "all common columns"
    # would silently add a pre-existing result column to the join keys and
    # could null out bins on a re-run, so any stale copy is replaced instead.
    base_df = df.drop(columns=[result_column], errors='ignore')
    return pd.merge(base_df, grouped_df, how='left', on=merge_keys)
    
def group_and_bin_column_definition(df, bin_col, bin_category, nbins, bin_strategy = None, result_column = None):
    """Bin ``bin_col`` using the grouping columns associated with ``bin_category``.

    Parameters
    ----------
    df : pandas.DataFrame
        Input rows.
    bin_col : str
        Column whose values are binned.
    bin_category : str
        Selects the grouping columns:
        ``'household'`` -> ``nid, hh_id, psu, year_start``;
        ``'location'``  -> ``lat, long``;
        ``'country'``   -> ``iso3``.
    nbins : int
        Requested number of bins.
    bin_strategy : str, optional
        Binning strategy forwarded to ``group_and_bin_column``;
        ``None`` means ``'quantiles'``.
    result_column : str, optional
        Name of the new bin column (forwarded unchanged).

    Returns
    -------
    pandas.DataFrame
        Whatever ``group_and_bin_column`` returns.

    Raises
    ------
    ValueError
        If ``bin_category`` is not a known category.
    """
    group_cols_by_category = {
        'household': ['nid', 'hh_id', 'psu', 'year_start'],
        'location': ['lat', 'long'],
        'country': ['iso3'],
    }
    if bin_category not in group_cols_by_category:
        known = ', '.join(group_cols_by_category)
        raise ValueError(f'bin_category must be one of {known}; got {bin_category!r}')

    # Copy so the callee can never mutate the lookup table's list.
    group_cols = list(group_cols_by_category[bin_category])
    if bin_strategy is None:
        bin_strategy = 'quantiles'

    return group_and_bin_column(df, group_cols, bin_col, nbins, bin_strategy = bin_strategy, result_column = result_column)

# Maps DataFrame column names to the human-readable labels that plot_heatmap
# uses for its axis names and figure title.
printable_names = {
    'income_per_day_bin': 'Income Per Day (2010 USD)',
    'temp_bin': 'Yearly Temperature (Mean)',
    'temp_bin_quants': 'Mean Yearly Temperature',
    'temp_avg_bin': 'Mean Temperature (5 year)',
    'precip_bin': 'Yearly Precipitation',
    'precip_avg_bin': 'Mean Yearly Precipitation (5 year avg)',
    'over30_bin': 'Days over 30 C',
    'over30_avg_bin': 'Days over 30 C (5 year avg)',
    'stunting': 'Stunting',
    'wasting': 'Wasting',
    'underweight': 'Underweight',
    'over30_avgperyear_bin': 'Avg Days over 30 in life',
    'over30_birth_bin': 'Days over 30 in year of birth',
    'temp_diff_birth_bin': 'temp diff birth',
    'bmi': 'BMI',
    'low_adult_bmi': 'Low adult BMI',
}


In [None]:
# Directory containing the processed inputs for this notebook.
OUT_ROOT = Path("/mnt/team/rapidresponse/pub/population/data/02-processed-data/cgf_bmi")

# Load the processed BMI table; the binning cell below works on a copy of it.
merged_df = pd.read_parquet(OUT_ROOT.joinpath("bmi_processed.parquet"))

In [None]:
# Number of bins requested for every binned column in this cell.
nbins = 10

merged_binned_df = merged_df.copy()
merged_binned_df = group_and_bin_column_definition(merged_binned_df, 'income_per_day', 'household', nbins)
merged_binned_df = group_and_bin_column_definition(merged_binned_df, 'over30', 'location', nbins, bin_strategy = '0_more')
merged_binned_df = group_and_bin_column_definition(merged_binned_df, 'temp', 'location', nbins, bin_strategy = 'readable_5')
merged_binned_df = group_and_bin_column_definition(merged_binned_df, 'temp', 'location', nbins, bin_strategy = 'quantiles', result_column = 'temp_bin_quants')
merged_binned_df = group_and_bin_column_definition(merged_binned_df, 'precip', 'location', nbins, bin_strategy = 'quantiles')
#merged_binned_df = group_and_bin_column_definition(merged_binned_df, 'over30_avg', 'location', nbins, bin_strategy = '0_1_more')
#merged_binned_df = group_and_bin_column_definition(merged_binned_df, 'temp_avg', 'location', nbins, bin_strategy = 'quantiles')
#merged_binned_df = group_and_bin_column_definition(merged_binned_df, 'precip_avg', 'location', nbins, bin_strategy = 'quantiles')

# Each binning step left-joins the bins back onto the rows, so the row count
# must be unchanged; a difference means a join duplicated or dropped rows.
assert len(merged_binned_df) == len(merged_df), (
    f'binning changed the row count: {len(merged_df)} -> {len(merged_binned_df)}'
)

# Every source value and its bin must be present.
cols_to_verify = ['over30', 'over30_bin', 'temp', 'temp_bin', 'precip', 'precip_bin', 'income_per_day', 'income_per_day_bin',]
        #'over30_avg', 'over30_avg_bin', 'temp_avg', 'temp_avg_bin', 'precip_avg', 'precip_avg_bin']
na_counts = merged_binned_df[cols_to_verify].isna().sum()
assert (na_counts == 0).all(), f'missing values after binning: {na_counts[na_counts > 0].to_dict()}'

In [None]:
def plot_heatmap(df, temp_col, wealth_col = 'income_per_day_bin', country = None, year = None, margins = False, filter = None, value_col = 'cgf_value'):
    """Draw a heatmap of the mean of ``value_col`` by wealth bin (rows) and ``temp_col`` bin (columns).

    Each cell is annotated with the mean and, underneath it in parentheses,
    the number of observations that went into that mean.

    Parameters
    ----------
    df : pandas.DataFrame
        Data containing ``temp_col``, ``wealth_col`` and ``value_col``.
    temp_col : str
        Column used for the heatmap columns.
    wealth_col : str
        Column used for the heatmap rows.
    country : str, optional
        If given, keep only rows whose ``iso3`` equals this value.
    year : optional
        If given, keep only rows whose ``year_start`` equals this value.
    margins : bool
        Forwarded to ``DataFrame.pivot_table`` to add row/column totals.
    filter : str, optional
        Expression passed to ``DataFrame.query`` before any other filtering.
    value_col : str
        Column that is averaged and counted in each cell.

    Column names are displayed via ``printable_names``; a column without an
    entry there is shown under its raw name instead of raising ``KeyError``.
    """
    plot_df = df

    if filter is not None:
        plot_df = plot_df.query(filter)
    if country:
        plot_df = plot_df[plot_df['iso3'] == country]
    if year:
        plot_df = plot_df[plot_df['year_start'] == year]

    # Resolve display labels, falling back to the raw column name.
    value_label = printable_names.get(value_col, value_col)
    wealth_label = printable_names.get(wealth_col, wealth_col)
    temp_label = printable_names.get(temp_col, temp_col)

    plot_df = plot_df.rename(columns = {
        value_col: value_label,
        wealth_col: wealth_label,
        temp_col: temp_label,
    })

    pivot_table_mean = plot_df.pivot_table(values=value_label, index=wealth_label,
        columns=temp_label, aggfunc='mean', dropna=False, margins=margins)
    pivot_table_count = plot_df.pivot_table(values=value_label, index=wealth_label,
        columns=temp_label, aggfunc='count', dropna=False, margins=margins)

    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(pivot_table_mean, annot=True, fmt=".2f", cmap='RdYlBu_r', ax=ax)

    # Overlay the counts just below each cell's mean annotation.
    counts = pivot_table_count.values
    n_rows, n_cols = pivot_table_mean.shape
    for i in range(n_rows):
        for j in range(n_cols):
            ax.text(j+0.5, i+0.6, f'\n({counts[i][j]})',
                    ha="center", va="center")

    location_label = country if country else 'All Locations'
    ax.set_title(f'{wealth_label} x {temp_label} x {value_label} '
                 f'(mean Proportion & Count), {location_label}')
    plt.show()

# Low adult BMI by income bin and temperature quantile bin.
plot_heatmap(
    merged_binned_df,
    temp_col='temp_bin_quants',
    value_col='low_adult_bmi',
)

In [None]:
# Low adult BMI by income bin and days-over-30 bin.
plot_heatmap(
    merged_binned_df,
    temp_col='over30_bin',
    value_col='low_adult_bmi',
)