In [None]:
import numpy as np, pandas as pd, matplotlib.pyplot as plt
# pd.set_option('display.max_rows', 8)
# Shell commands: print the current date and the name of the user running the notebook
!date
!whoami

# Load IPython's autoreload extension in mode 2, so that edits to imported modules
# are picked up without restarting the kernel
%load_ext autoreload
%autoreload 2

In [None]:
import sqlns_summarizer as sqs

## Cell for editing input data

In [None]:
# Root folder holding the output files; each run is expected under '<location>/<run_date>/' below it.
base_directory = '/share/costeffectiveness/results/sqlns/presentation/'

# Run date (= subdirectory name) to load for each country.
# Entries that are commented out are not loaded.
locations_run_dates = {
    # 'Bangladesh': '2019_07_02_11_55_19',
    # 'Burkina_Faso': '2019_07_02_11_56_40',
    # 'Ethiopia': '2019_07_02_11_58_02',
    # 'India': '2019_07_02_11_58_29',
    'Nigeria': '2019_07_23_10_57_25',  # earlier run date: '2019_07_18_13_20_17'
}

# Location names, in the order they are listed in the mapping above.
locations = list(locations_run_dates)

# Short replacement names for the long intervention column names.
intervention_colname_mapper = {
    'sqlns.effect_on_child_stunting.permanent': 'stunting_permanent',
    'sqlns.effect_on_child_wasting.permanent': 'wasting_permanent',
    'sqlns.effect_on_iron_deficiency.permanent': 'iron_permanent',
    'sqlns.duration': 'duration',
    'sqlns.effect_on_iron_deficiency.mean': 'iron_mean',
    'sqlns.effect_on_iron_deficiency.sd': 'iron_sd',
    'sqlns.program_coverage': 'coverage',
}

## Typical list of data transformations we need for data analysis

**X** indicates that I have already implemented a version of the transformation, either below or in `sqlns_summarizer.py`.

1. **X** Load data from different locations and concatenate into single output file with location names added
2. **X** Verify column name categories
3. Verify random seed, input draw, and scenario counts compared to the number of rows in the output. If random seeds are missing (which often happens), it can lead to weird looking graphs and/or errors in the data processing code (particularly if an entire draw is missing for some scenario).
3. **Perhaps draw some simple graphs before any data processing, to verify the results as close to the raw data as possible**
4. **X** Rename the intervention columns with shorter names
5. **X** Sum over random seeds
6. **X** Parse column names to extract measure, cause, risk, sequela, year, sex, age, etc.
6. **Perhaps draw more graphs at this step, again to verify the results as close to the raw data as possible**
7. Do any desired case-specific aggregations or additions of derived measures (e.g. in mom_food model, sum over age groups and sexes and years, and sum over all neonatal causes).
    * Note: This step cannot be made generic, though perhaps there could be some wrapper functions for common tasks to make them easier.
8. Stack dataframes to put them in a form more convenient for analysis. Replace `NaN`s with 'all' where appropriate (e.g. for person time or total dalys). I think we want to stack before dividing or subtracting in order to make broadcasting easier.
9. For deaths, ylls, ylds, dalys, compute rate per person-year (i.e. divide by person time)
10. For risks and sequela, compute exposures/prevalences (i.e. percentages) in relevant categories. Allow specification of categories, e.g. using my `risk_mapper` module.
10. Do something with the 'diseases_at_end', 'disease_event_count', and 'population' columns if desired.
11. Subtract dataframes to compute averted/delta measures
12. **Draw some more graphs at this point to verify results at the draw level before aggregating**
12. Aggregate over draws to compute mean, upper and lower percentiles
13. Concatenate stacked dataframes to get final output
14. **Draw final graphs displaying desired results**

## Functions to perform data transformations

In [None]:
def load_by_location_and_rundate(base_directory: str, locations_run_dates: dict) -> pd.DataFrame:
    """Load and combine output.hdf files stored as 'base_directory/location/run_date/output.hdf'.

    Parameters
    ----------
    base_directory
        Directory containing one sub-folder per location. A trailing '/' is
        optional.
    locations_run_dates
        Maps each location name to the run date (= sub-folder name) to load.
        The location name is lower-cased to form the folder name.

    Returns
    -------
    pd.DataFrame
        The outputs of all locations concatenated row-wise, with an added
        'location' column holding the location name as spelled in the keys
        of `locations_run_dates`.

    Raises
    ------
    ValueError
        If `locations_run_dates` is empty (there is nothing to concatenate).
    """
    if not locations_run_dates:
        raise ValueError('locations_run_dates is empty: there are no outputs to load')

    # Drop trailing slashes so that a base directory written with a final '/'
    # does not produce paths containing '//'.
    root = base_directory.rstrip('/')

    # Map each location to the path of the Vivarium output to process
    # E.g. /share/costeffectiveness/results/sqlns/bangladesh/2019_06_21_00_09_53/output.hdf
    locations_paths = {location: f'{root}/{location.lower()}/{run_date}/output.hdf'
                       for location, run_date in locations_run_dates.items()}

    # Read each location's output and tag its rows with the location name
    outputs = []
    for location, path in locations_paths.items():
        output = pd.read_hdf(path)
        output['location'] = location
        outputs.append(output)

    return pd.concat(outputs, copy=False, sort=False)
    
def print_location_output_shapes(locations, all_output):
    """Print, for every location, its name and the shape of its rows in `all_output`.

    Comparing the printed shapes shows whether every location produced the same
    amount of data or whether some rows are missing. Rows are selected by
    comparing the 'location' column of `all_output` with each name in `locations`.
    """
    for name in locations:
        belongs_to_location = all_output.location == name
        location_rows = all_output.loc[belongs_to_location]
        print(name, location_rows.shape)
        
def negate_column(output, column_name):
    """Flip the sign of one column of `output` in place (written for 'sqlns_treated_days').

    The dataframe is modified directly and nothing is returned, so calling this
    twice on the same column restores the original values.
    """
    negated_values = -1 * output[column_name]
    output[column_name] = negated_values

In [None]:
def aggregate_and_reindex_subdataframes(output, intervention_colname_mapper):
    """
    Apply three of the transformations listed above to `output`, in this order:

    1. Rename the intervention columns with shorter names
    2. Sum over random seeds
    3. Parse column names to extract measure, cause, risk, sequela, year, sex, age, etc.,
       and use these to reindex the categorized subdataframes with MultiIndices

    The return values of the three method calls are discarded and this function
    returns None, so `output` is presumably updated in place -- TODO confirm
    against the summarizer class.
    """
    # Step 1: shorten the intervention column names using the supplied mapping
    output.rename_intervention_columns(intervention_colname_mapper)
    # Step 2: aggregate over random seeds
    output.sum_over_random_seeds()
    # Step 3: parse the column names and reindex the sub-dataframes
    output.parse_column_names_and_reindex()

## Load data and check shape of output

In [None]:
# Read the output of every configured location into one dataframe, then print the
# shape of each location's rows to check whether any data is missing
all_output = load_by_location_and_rundate(base_directory, locations_run_dates)
print_location_output_shapes(locations, all_output)

In [None]:
# Preview the first rows of the combined raw output
all_output.head()

## Fix the negative `'sqlns_treated_days'` column

In [None]:
# Oops, 'sqlns_treated_days' got subtracted in the wrong order
# Fix by replacing column with its negation
# NOTE: negate_column modifies all_output in place, so run this cell exactly once --
# re-running it without reloading the data flips the sign back.
negate_column(all_output, 'sqlns_treated_days')
all_output['sqlns_treated_days'].head()

## Create an OutputSummarizer from the data

Then check the column categories.

In [None]:
# Wrap the combined output in the summarizer class from sqlns_summarizer,
# then print its column report to check the column categories
output = sqs.SQLNSOutputSummarizer(all_output)
output.print_column_report()

In [None]:
# Inspect the summarizer's private _columns attribute
# NOTE(review): presumably maps each column category to its column names -- confirm in sqlns_summarizer
output._columns

In [None]:
# Look up the columns for the 'diseases_at_end' category
output.columns('diseases_at_end')

In [None]:
# Look up the columns for two categories at once
output.columns('disease_event_count', 'population')

In [None]:
# Same two categories, selected directly from the private attribute for comparison
output._columns[['disease_event_count', 'population']]

In [None]:
# Look up the columns for the 'population' category
output.columns('population')

In [None]:
# Preview the 'population' sub-dataframe
output.subdata['population'].head()

In [None]:
# Scratch arithmetic. NOTE(review): what these counts refer to is not recorded
# in the notebook -- TODO document or delete
9683+11685

In [None]:
# 10563 + 1122 = 11685, the second operand of the sum in the previous cell
10563+1122

In [None]:
import re

In [None]:
# Scratch test of a regex that splits '<category>_to/from_<cause>_event_count' column names
# into (category, cause, measure); category is '' when the name has no prefix.
# Earlier attempt, kept for reference (Python's re rejects the repeated group name 'category'):
# pattern = re.compile('^(?:(?P<category>susceptible)_to_|(?P<category>recovered)_from_|(?P<category>))(?P<cause>\w+)_(?P<measure>event_count)$')
# Raw string so that '\w' reaches the regex engine without an invalid-escape warning
pattern = re.compile(r'^(?P<category>susceptible|recovered|)(?:_to_|_from_|)(?P<cause>\w+)_(?P<measure>event_count)$')

# Try the pattern on one column name of each form
matches = [
    pattern.search(column_name)
    for column_name in (
        'susceptible_to_measles_event_count',
        'measles_event_count',
        'recovered_from_measles_event_count',
    )
]
for match in matches:
    print(match.groups())

In [None]:
# Scratch test of a regex that splits 'total_population[_<category>]' column names
# into (measure, category); category is '' when the name has no suffix.
# Raw string so that '\w' reaches the regex engine without an invalid-escape warning
pattern = re.compile(r'^total_(?P<measure>population)(?:_|)(?P<category>\w*)')

# Try the pattern on one column name of each form
matches = [
    pattern.search(column_name)
    for column_name in (
        'total_population',
        'total_population_living',
        'total_population_dead',
    )
]
for match in matches:
    print(match.groups())

## Rename the intervention columns, sum over random seeds, parse column names, and reindex subdataframes

In [None]:
# Rename intervention columns, sum over random seeds, and parse/reindex the sub-dataframes,
# then preview the 'intervention' sub-dataframe to see the result
aggregate_and_reindex_subdataframes(output, intervention_colname_mapper)
output.subdata['intervention'].head()

In [None]:
# Preview the 'random_seed' sub-dataframe
output.subdata['random_seed'].head()

In [None]:
# The same category selected from the full data table, for comparison with the sub-dataframe
output.data[output.columns('random_seed')].head()

In [None]:
# Preview the 'mortality' sub-dataframe
output.subdata['mortality'].head()

In [None]:
# Preview the 'categorical_risk' sub-dataframe
output.subdata['categorical_risk'].head()

In [None]:
# The same category selected from the full data table, for comparison with the sub-dataframe
output.data[output.columns('categorical_risk')].head()

In [None]:
# Preview the 'population' sub-dataframe
output.subdata['population'].head()

In [None]:
# # Uncomment the extraction regex for 'population' to see the results of this
# output.subdata['population'][('population','')][('Nigeria', False)].head()

In [None]:
# Look up the columns for the 'population' category
output.columns('population')

In [None]:
# Scratch experiment: a Series whose values are lists, to try out label indexing and flattening
s = pd.Series({0: [4,5,6], 1: [1,2], 2: [3]})
s

In [None]:
# Selecting with a list of labels returns a Series
s[[0,1]]

In [None]:
# Selecting a single label returns the list stored under it
s[1]

In [None]:
# A one-element list of labels still returns a Series
s[[1]]

In [None]:
# Flatten the selected lists into a single list: [4, 5, 6, 1, 2]
[x for l in s[[0,1]] for x in l]

## Check output for monotonicity with coverage

In [None]:
# Preview the full data table
output.data.head()

In [None]:
# IndexSlice builds the MultiIndex selection; ':' leaves one index level unrestricted.
# NOTE(review): presumably the levels are (location, stunting_permanent, wasting_permanent,
# iron_permanent, duration, iron_mean, iron_sd, coverage, input draw), so this shows
# every coverage value for one draw -- TODO confirm against output.data.index.names
idx = pd.IndexSlice

output.data.loc[idx['Nigeria', False, False, False, 365.25, 0.895, 0.0656, :, 357],
                ['death_due_to_other_causes', 'death_due_to_diarrheal_diseases', 'ylds_due_to_iron_deficiency']]

In [None]:
# Same selection with 55 as the last index value, looking at aggregate measures
# and the random seed count
output.data.loc[idx['Nigeria', False, False, False, 365.25, 0.895, 0.0656, :, 55],
                ['years_of_life_lost', 'years_lived_with_disability', 'random_seed_count']]

In [None]:
# Distinct values in the first column of the 'random_seed' sub-dataframe
output.subdata['random_seed'].iloc[:,0].unique()

In [None]:
# Rows where that first column equals 3
# NOTE(review): presumably rows with fewer random seeds than expected -- TODO confirm
output.data.loc[output.subdata['random_seed'].iloc[:,0]==3,
                ['years_of_life_lost', 'years_lived_with_disability', 'random_seed_count']]

In [None]:
# Scratch arithmetic: 18*5*108 = 9720
# NOTE(review): presumably an expected row count; the factors are not explained -- TODO confirm
18*5*108

In [None]:
# 9720 (the product above) minus 9683 (a count used in the scratch sum earlier in the notebook)
# NOTE(review): presumably the number of missing rows -- TODO confirm
9720-9683

In [None]:
# Show the summarizer's column categories
output.column_categories()

In [None]:
# Preview the 'mortality' sub-dataframe
output.subdata['mortality'].head()

In [None]:
# Preview the 'categorical_risk' sub-dataframe
output.subdata['categorical_risk'].head()

In [None]:
# Preview the 'graded_sequela' sub-dataframe
output.subdata['graded_sequela'].head()

## Reindex sub-dataframes to extract cause names

In [None]:
# Get the yld columns (every column of the raw output whose name contains 'yld')
yld_df = all_output.filter(regex='yld')

# Split each column name into measure, cause and the optional year / sex / age_group parts.
# Parts that are absent from a name come back as NaN.
# Raw string so that '\w' and '\d' reach the regex engine without invalid-escape warnings.
yld_decomp = yld_df.columns.str.extract(
    r'^(?P<measure>ylds)_due_to_(?P<cause>\w+?)(?:_in_(?P<year>\d{4})|)(?:_among_(?P<sex>male|female)|)(?:_in_age_group_(?P<age_group>\w+)|)$'
)
yld_decomp
# yld_df.columns = pd.MultiIndex.from_frame(yld_decomp.dropna(axis='columns', how='all'))
# yld_df.head()

In [None]:
# NOTE(review): aggregate_and_reindex_subdataframes (earlier in the notebook) calls
# output.parse_column_names_and_reindex(); confirm that reindex_sub_dataframes still exists
# on the summarizer and that reindexing again at this point is intended
output.reindex_sub_dataframes()
output.data.head()

In [None]:
# Preview the 'mortality' sub-dataframe
output.subdata['mortality'].head()

In [None]:
# Move the 'measure' column level into the row index
output.subdata['mortality'].stack(level='measure').head()

In [None]:
# Move every column level into the row index, then flatten the index into ordinary columns
output.subdata['mortality'].stack(level=output.subdata['mortality'].columns.names).reset_index().head()


In [None]:
# Preview the 'person_time' sub-dataframe
output.subdata['person_time'].head()

In [None]:
# Move the first column level into the row index
output.subdata['person_time'].stack(level=0).head()