In [1]:
from idd_forecast_mbp.helper_functions import check_column_for_problematic_values
from idd_forecast_mbp.rake_and_aggregate_functions import make_aa_df_square, prep_df

In [7]:
import numpy as np

import pandas as pd
from pathlib import Path
from datetime import datetime
from rra_tools.shell_tools import mkdir  # type: ignore
from idd_forecast_mbp import constants as rfc
from idd_forecast_mbp.helper_functions import check_column_for_problematic_values
from idd_forecast_mbp.parquet_functions import read_parquet_with_integer_ids, write_parquet
from idd_forecast_mbp.cause_processing_functions import format_aa_gbd_df, process_lsae_df
from idd_forecast_mbp.rake_and_aggregate_functions import rake_aa_count_lsae_to_gbd, make_aa_full_rate_df_from_aa_count_df, check_concordance, aggregate_aa_rate_lsae_to_gbd


# Root directories, taken from the shared constants module.
PROCESSED_DATA_PATH = rfc.PROCESSED_DATA_PATH

GBD_DATA_PATH = rfc.GBD_DATA_PATH
LSAE_INPUT_PATH = rfc.LSAE_INPUT_PATH

# Parquet locations for the malaria / dengue "aa_full" frames (built with the `/` path operator).
aa_full_malaria_df_path = PROCESSED_DATA_PATH / "aa_full_malaria_df.parquet"
aa_full_dengue_df_path = PROCESSED_DATA_PATH / "aa_full_dengue_df.parquet"

################################################################
#### Hierarchy Paths, loading, and cleaning
################################################################

# NOTE(review): these paths are plain strings (f-strings) while the two above use `/`;
# kept as-is because the loaders' accepted path types are not visible here.
aa_full_population_df_path = f"{PROCESSED_DATA_PATH}/aa_2023_full_population.parquet"
full_2023_hierarchy_path = f"{PROCESSED_DATA_PATH}/full_hierarchy_2023_lsae_1209.parquet"

hierarchy_df = read_parquet_with_integer_ids(full_2023_hierarchy_path)
aa_full_population_df = pd.read_parquet(aa_full_population_df_path)

# GBD reference CSVs, one per cause.
aa_gbd_malaria_df_path = f"{GBD_DATA_PATH}/gbd_2023_malaria_aa.csv"
aa_gbd_dengue_df_path = f"{GBD_DATA_PATH}/gbd_2023_dengue_aa.csv"

measure_map = rfc.measure_map
problematic_rule_map = rfc.problematic_rule_map
# Backward-compatible alias for the original misspelled name, which later cells still read.
ploblematic_rule_map = problematic_rule_map

################################################################
###  MALARIA DATA PROCESSING AND RAKING
################################################################

###----------------------------------------------------------###
### 1. Data Loading - Get raw data sources
### Imports raw malaria prevalence data (PfPR) and reference GBD datasets needed for raking.
### These foundational datasets provide the base inputs for all subsequent processing.
###----------------------------------------------------------###

# PfPR (parasite prevalence): process the LSAE input, then hand the result to
# aggregate_aa_rate_lsae_to_gbd, asking for the full frame back.
aa_lsae_malaria_pfpr_df = process_lsae_df(
    "malaria",
    "pfpr",
    aa_full_population_df,
    hierarchy_df,
)
aa_full_malaria_pfpr_df = aggregate_aa_rate_lsae_to_gbd(
    rate_variable="malaria_pfpr",
    hierarchy_df=hierarchy_df,
    aa_lsae_rate_df=aa_lsae_malaria_pfpr_df,
    aa_full_population_df=aa_full_population_df,
    return_full_df=True,
)

# GBD malaria reference table used for raking
aa_gbd_malaria_df = pd.read_csv(aa_gbd_malaria_df_path, low_memory=False)

###----------------------------------------------------------###
### 2. Incidence Processing & Raking
### This section processes malaria incidence data, including counts and rates.
### It formats the GBD data, processes the LSAE data, and then rakes the incidence counts
### to the GBD data. Finally, it calculates incidence rates from the counts.
###----------------------------------------------------------###
# Naming pieces for the cause/measure being processed
cause = 'malaria'
measure = 'incidence'
metric = 'count'
short_measure = measure_map[measure]['short']
count_variable = f'{cause}_{short_measure}_{metric}'
rate_variable = f'{cause}_{short_measure}_rate'

# Thresholds for this cause/measure, looked up in the rule map
problematic_rules = ploblematic_rule_map[cause][measure]

# GBD side: incidence counts, then mortality counts and population joined on location-year
aa_inc_gbd_count_df = format_aa_gbd_df(cause, 'incidence', 'count', aa_gbd_malaria_df)
aa_mort_gbd_count_df = format_aa_gbd_df(cause, 'mortality', 'count', aa_gbd_malaria_df)
aa_gbd_count_df = (
    aa_inc_gbd_count_df
    .merge(aa_mort_gbd_count_df, how="left", on=["location_id", "year_id"])
    .merge(aa_full_population_df, how="left", on=["location_id", "year_id"])
)

# LSAE side: incidence frame plus only the mortality count column from the mortality frame
aa_inc_lsae_count_df = process_lsae_df(cause, 'incidence', aa_full_population_df, hierarchy_df)
aa_mort_lsae_count_df = process_lsae_df(cause, 'mortality', aa_full_population_df, hierarchy_df)
aa_lsae_count_df = aa_inc_lsae_count_df.merge(
    aa_mort_lsae_count_df[["location_id", "year_id", "malaria_mort_count"]],
    how="left",
    on=["location_id", "year_id"],
)

In [None]:
# aa_full_malaria_inc_count_df = rake_aa_count_lsae_to_gbd(count_variable = count_variable, 
#                                                  hierarchy_df =hierarchy_df, 
#                                                  aa_gbd_count_df = aa_gbd_count_df, 
#                                                  aa_lsae_count_df = aa_lsae_count_df,
#                                                  problematic_rules = problematic_rules,
#                                                  aa_full_count_df_path = None, return_full_df=True)

In [12]:
# Names of the two count columns that the cells below process together.
# hierarchy_df, aa_gbd_count_df and aa_lsae_count_df are used as already defined above;
# the former `x = x` self-assignments were no-ops and have been removed.
inc_count_variable = 'malaria_inc_count'
mort_count_variable = 'malaria_mort_count'

In [14]:
count_variables = [inc_count_variable, mort_count_variable]
aa_gbd_count_df = prep_df(aa_gbd_count_df, hierarchy_df)
# Case-fatality ratio = mortality count / incidence count, forced to 0 wherever mortality is 0
# (this also covers the 0/0 case, which would otherwise be NaN).
aa_gbd_count_df['cfr'] = (
    aa_gbd_count_df[mort_count_variable] / aa_gbd_count_df[inc_count_variable]
).mask(aa_gbd_count_df[mort_count_variable] == 0, 0)

In [17]:
# Independent copy of the GBD rows at hierarchy levels 0 through 3
aa_gbd_count_0_to_3_df = aa_gbd_count_df.loc[aa_gbd_count_df['level'] <= 3].copy()

# Prep the LSAE frame and pass it through make_aa_df_square for levels 3-5
aa_lsae_count_df = make_aa_df_square(
    count_variables,
    prep_df(aa_lsae_count_df, hierarchy_df),
    hierarchy_df,
    level_start=3,
    level_end=5,
)
# Same CFR definition as on the GBD side: mortality / incidence, 0 where mortality is 0
aa_lsae_count_df['cfr'] = (
    aa_lsae_count_df[mort_count_variable] / aa_lsae_count_df[inc_count_variable]
).mask(aa_lsae_count_df[mort_count_variable] == 0, 0)

In [18]:
gbd_variables = [f'{variable}_gbd' for variable in count_variables + ['cfr']]

# Give every count/CFR column a "_gbd"-suffixed twin and flag all GBD rows as set by GBD
aa_gbd_count_df = aa_gbd_count_df.assign(
    **{f'{variable}_gbd': aa_gbd_count_df[variable] for variable in count_variables + ['cfr']},
    set_by_gbd=True,
)

# Bring the GBD values and the flag onto the LSAE rows; unmatched rows get missing values
aa_lsae_count_df = aa_lsae_count_df.merge(
    aa_gbd_count_df[['location_id', 'year_id', 'set_by_gbd', *gbd_variables]],
    how='left',
    on=['location_id', 'year_id'],
)

In [19]:
# Track which locations are already set by GBD.
# After the left merge, rows with no GBD match have a missing `set_by_gbd`;
# mark those False, then store the flag with pandas' nullable 'boolean' dtype.
mask = aa_lsae_count_df['set_by_gbd'].isna()
aa_lsae_count_df.loc[mask, 'set_by_gbd'] = False
aa_lsae_count_df['set_by_gbd'] = aa_lsae_count_df['set_by_gbd'].astype('boolean')

In [20]:
# For the counts and CFR, the merged GBD value wins; the LSAE value is kept only where GBD is missing
for var in count_variables + ['cfr']:
    gbd_col = f'{var}_gbd'
    aa_lsae_count_df[var] = aa_lsae_count_df[gbd_col].fillna(aa_lsae_count_df[var])
# The "_gbd" helper columns are no longer needed once the values are folded in
aa_lsae_count_df = aa_lsae_count_df.drop(columns=[f'{var}_gbd' for var in count_variables + ['cfr']])

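Start here: work in progress on raking level 4 (LSAE) to level 3 (GBD) for the incidence count, mortality count, and CFR together.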

In [63]:
# Parent side: the level-3 rows of the GBD frame
level_m1_df = aa_gbd_count_0_to_3_df.loc[aa_gbd_count_0_to_3_df['level'] == 3].copy()
# Child side: the level-4 rows of the LSAE frame
level_df = aa_lsae_count_df.loc[aa_lsae_count_df['level'] == 4].copy()

In [64]:
# Map the parent frame's columns to "parent_"-prefixed names so they can sit beside the child columns
rename_dict = {
    'location_id': 'parent_id',
    **{var: f'parent_{var}' for var in count_variables + ['cfr']},
}
parent_columns = list(rename_dict.values())

level_m1_df = level_m1_df.rename(columns=rename_dict)

# Attach each child location's parent_id from the hierarchy
level_df = level_df.merge(
    hierarchy_df[['location_id', 'parent_id']],
    how='left',
    on='location_id',
)

In [65]:
# Current (pre-raking) rates per child row: numerator column / denominator column
rate_definitions = {
    'current_mortality_rate': (mort_count_variable, 'population'),
    'current_incidence_rate': (inc_count_variable, 'population'),
    'current_cfr': (mort_count_variable, inc_count_variable),
}
current_rate_cols = list(rate_definitions)
for col, (numerator, denominator) in rate_definitions.items():
    # NaN and +/-Inf (from zero denominators) are replaced with 0
    level_df[col] = (level_df[numerator] / level_df[denominator]).replace([np.inf, -np.inf, np.nan], 0)

# A row contributes no effective population if any of its three current rates is 0
has_zero_rate = (level_df[current_rate_cols] == 0).any(axis=1)
level_df['effective_population'] = np.where(has_zero_rate, 0, level_df['population'])

In [66]:
# Sum population, effective population, and every count variable within each parent-year
agg_dict = {
    'population': 'sum',
    'effective_population': 'sum',
    **{var: 'sum' for var in count_variables},
}

level_m1_agg_df = (
    level_df
    .groupby(['parent_id', 'year_id'])
    .agg(agg_dict)
    .reset_index()
)


In [67]:
# The summed child populations become the parent-level totals
level_m1_agg_df = level_m1_agg_df.rename(columns={
    'population': 'parent_population',
    'effective_population': 'parent_effective_population',
})

In [68]:
# Put the parent's own values (parent_* columns) next to the sums of its children
level_m1_agg_df = level_m1_agg_df.merge(
    level_m1_df[['year_id', *parent_columns]],
    how='left',
    on=['year_id', 'parent_id'],
)
# CFR implied by the summed child counts
level_m1_agg_df['cfr'] = (
    level_m1_agg_df[mort_count_variable] / level_m1_agg_df[inc_count_variable]
)

In [71]:
level_m1_agg_df

Unnamed: 0,parent_id,year_id,parent_population,parent_effective_population,malaria_inc_count,malaria_mort_count,parent_malaria_inc_count,parent_malaria_mort_count,parent_cfr,cfr
0,6,2000,1.258333e+09,5.011522e+07,6.075997e+04,22.050010,6.075997e+04,22.050010,0.000363,0.000363
1,6,2001,1.263712e+09,5.053049e+07,4.580781e+04,25.828507,4.580781e+04,25.828507,0.000564,0.000564
2,6,2002,1.269927e+09,5.098043e+07,5.283493e+04,44.653627,5.283493e+04,44.653627,0.000845,0.000845
3,6,2003,1.276870e+09,5.145764e+07,5.823846e+04,34.619070,5.823846e+04,34.619070,0.000594,0.000594
4,6,2004,1.284501e+09,5.195527e+07,5.329673e+04,39.579923,5.329673e+04,39.579923,0.000743,0.000743
...,...,...,...,...,...,...,...,...,...,...
4687,522,2018,3.480359e+07,3.480359e+07,1.733528e+06,2102.198346,2.199947e+06,2743.228159,0.001247,0.001213
4688,522,2019,3.560610e+07,3.560610e+07,1.996085e+06,2201.875475,2.524225e+06,2777.029989,0.001100,0.001103
4689,522,2020,3.614354e+07,3.614354e+07,2.198147e+06,2546.554656,2.681114e+06,3198.802707,0.001193,0.001159
4690,522,2021,3.656424e+07,3.656424e+07,2.186278e+06,2686.098845,2.673778e+06,3210.596273,0.001201,0.001229


In [72]:
# Ratio of the summed-children value to the parent's own value, for mortality, incidence, and CFR.
# NOTE(review): this is aggregated / parent, the inverse of `count_raking_factor`
# (parent / aggregated) in rake_level_at_once — confirm the intended direction.
for var in (mort_count_variable, inc_count_variable, 'cfr'):
    level_m1_agg_df[f'{var}_rakingfactor'] = level_m1_agg_df[var] / level_m1_agg_df[f'parent_{var}']

In [73]:
# Parent value per unit of population, using the full and the effective parent population.
# NOTE(review): the loop also divides parent_cfr (a ratio) by population — confirm that is intended.
for var in count_variables + ['cfr']:
    parent_values = level_m1_agg_df[f'parent_{var}']
    level_m1_agg_df[f'full_population_{var}_rakingfactor'] = parent_values / level_m1_agg_df['parent_population']
    level_m1_agg_df[f'population_{var}_rakingfactor'] = parent_values / level_m1_agg_df['parent_effective_population']

In [74]:
for var in count_variables + ['cfr']:
    # Where the parent value is 0, all three factors for this variable are forced to 0
    parent_is_zero = level_m1_agg_df[f'parent_{var}'] == 0
    factor_columns = [
        f'{var}_rakingfactor',
        f'full_population_{var}_rakingfactor',
        f'population_{var}_rakingfactor',
    ]
    level_m1_agg_df.loc[parent_is_zero, factor_columns] = 0

In [75]:
# Columns to carry from the parent-level frame back onto the child rows:
# the two merge keys, then four columns per variable (parent value + three factors)
merge_cols = ['year_id', 'parent_id']
for var in count_variables + ['cfr']:
    merge_cols.extend([
        f'parent_{var}',
        f'{var}_rakingfactor',
        f'full_population_{var}_rakingfactor',
        f'population_{var}_rakingfactor',
    ])

In [None]:
# Attach the parent values and factors to every child row of the same parent-year
level_df = level_df.merge(level_m1_agg_df.loc[:, merge_cols], how='left', on=['year_id', 'parent_id'])

In [None]:



# NOTE(review): scratch copy of the single-variable raking statements (the same statements
# appear, indented, in the body of rake_level_at_once). As written it reads `level`, which is
# not defined at module level in the visible cells, and columns the cells above do not create
# ('population_raking_factor', 'full_population_raking_factor', 'count_raking_factor',
# 'current_rate'); the cells above create f'{var}_rakingfactor'-style columns instead.
# Confirm and adapt before running.

# replace population_raking_factor with full_population_raking_factor if population_raking_factor is inf or the rate we will get if we use effective will be too high
level_df['used_full_population_raking_factor'] = np.where(level_df['population_raking_factor'] > problematic_rules['rate_max'][level], True, False)
level_df['population_raking_factor'] = np.where(level_df['population_raking_factor'] > problematic_rules['rate_max'][level], level_df['full_population_raking_factor'], level_df['population_raking_factor'])
# drop full_population_raking_factor now that it has been folded in
level_df = level_df.drop(columns=['full_population_raking_factor'])

# Setting up which population to use for raking and multiplying
effective_population_mask = level_df['used_full_population_raking_factor'] == False
level_df['population_to_use'] = level_df['population']
level_df.loc[effective_population_mask,'population_to_use'] = level_df.loc[effective_population_mask,'effective_population']

# Candidate raked counts under the two strategies
level_df['count_based_count'] = level_df[count_variable] * level_df['count_raking_factor']
level_df['population_based_count'] = level_df['population_to_use'] * level_df['population_raking_factor']

# Always use the actual population here!
level_df['count_based_rate'] = level_df['count_based_count'] / level_df['population']
level_df['population_based_rate'] = level_df['population_based_count'] / level_df['population']
# String key "<parent_id>_<year_id>" used to flag whole parent-years below
level_df['parent_year_id'] = level_df['parent_id'].astype(str).str.cat(level_df['year_id'].astype(str), sep='_')

# Which level_df rows have either count_raking_factor = inf or count_based_rate is big
problematic_rows = level_df[(level_df['count_raking_factor'] > problematic_rules['count_raking_factor_max']) | 
                            (level_df['count_based_rate'] > problematic_rules['rate_max'][level]) | 
                            ((level_df['count_raking_factor'] > problematic_rules['count_raking_factor_conditional']) & (level_df['count_based_rate'] > problematic_rules['rate_max_conditional']))]
# Keep only the flagged rows where the count-based rate exceeds the population-based rate
problematic_rows = problematic_rows[problematic_rows['count_based_rate'] > problematic_rows['population_based_rate']]
# Rows with an infinite count raking factor are always treated as problematic
npinf_rows = level_df[(level_df['count_raking_factor'] == np.inf)]
problematic_rows = pd.concat([problematic_rows, npinf_rows]).drop_duplicates()

# One problematic row switches its entire parent-year to the population-based method
problematic_parent_years = problematic_rows['parent_year_id'].drop_duplicates().reset_index(drop=True).to_frame(name='parent_year_id')
problematic_parent_years['use_population'] = True
level_df = level_df.merge(
    problematic_parent_years,
    on='parent_year_id',
    how='left'
)

# Parent-years that were not flagged get use_population = False
mask = level_df['use_population'].isna()
level_df.loc[mask, 'use_population'] = False
level_df['use_population'] = level_df['use_population'].astype('boolean')

# Rows already set by GBD are excluded from both masks and therefore left unchanged
population_mask = (level_df['use_population'] == True) & (level_df['set_by_gbd'] == False)
count_mask = (level_df['use_population'] == False) & (level_df['set_by_gbd'] == False)
# Row tallies per method (not used further in this cell)
left_alone = len(count_mask[count_mask == True])
changed = len(population_mask[population_mask == True])
total = left_alone + changed
# Note we use 'population_to_use' here!!!
level_df.loc[count_mask, count_variable] = level_df.loc[count_mask, count_variable] * level_df.loc[count_mask, 'count_raking_factor']
level_df.loc[population_mask, count_variable] = level_df.loc[population_mask, 'population_to_use'] * level_df.loc[population_mask, 'population_raking_factor']
    # Apply the raking factor to the count variable
    
# Helper columns to remove: anything with 'based' or 'raking' in its name plus the listed bookkeeping columns
drop_cols = [col for col in level_df.columns if 'based' in col or 'raking' in col] + ['parent_id', f'parent_{count_variable}', 'use_population', 'used_full_population_raking_factor', 'population_to_use', 'effective_population', 'parent_year_id', 'current_rate']
# Drop the raking factor
level_df = level_df.drop(columns=drop_cols)
#

In [None]:
def rake_level_at_once(variables, level_df, level_m1_df, problematic_rules, hierarchy_df, level):
    '''
    Rakes the level DataFrame to the next level up using the hierarchy DataFrame.

    Child rows are grouped by (parent_id, year_id). By default each child count is
    multiplied by a count-based factor (parent count / summed child counts). If any
    child row of a parent-year is judged problematic by `problematic_rules`, the whole
    parent-year is instead set from population times a population-based factor.
    Rows with `set_by_gbd == True` are never modified.

    NOTE(review): the column being raked is read from the module-level name
    `count_variable`; the `variables` argument is accepted but not used yet.

    Parameters
    ----------
    variables : currently unused (see note above).
    level_df : DataFrame of child rows; the code reads 'location_id', 'year_id',
        'population', 'set_by_gbd' and the `count_variable` column.
    level_m1_df : DataFrame of parent rows; the code reads 'location_id', 'year_id'
        and the `count_variable` column.
    problematic_rules : mapping with keys 'rate_max' (indexed by `level`),
        'count_raking_factor_max', 'count_raking_factor_conditional' and
        'rate_max_conditional'.
    hierarchy_df : DataFrame providing 'location_id' -> 'parent_id'.
    level : key into problematic_rules['rate_max']; also printed in the summary.

    Returns
    -------
    DataFrame : `level_df` with the raked `count_variable` and all helper columns dropped.
    '''
    # Change the name of the count variable and prep for matching by parent_id from level
    level_m1_df = level_m1_df.rename(columns={
        count_variable: f'parent_{count_variable}',
        'location_id': 'parent_id'})
    # Prep the level df
    # Add in parent_id
    level_df = level_df.merge(
        hierarchy_df[['location_id', 'parent_id']],
        on='location_id',
        how='left'
    )
    level_df['current_rate'] = level_df[count_variable] / level_df['population']
    # Make a column called 'effective_population' that is equal to population where count_variable > 0, else 0
    level_df['effective_population'] = np.where(level_df[count_variable] > 0, level_df['population'], 0)
    # Aggregate the count variable by parent_id
    level_m1_agg_df= level_df.groupby(['parent_id', 'year_id']).agg({
        count_variable: 'sum',
        'population': 'sum',
        'effective_population': 'sum'
    }).reset_index()
    level_m1_agg_df = level_m1_agg_df.rename(columns={'population': 'parent_population'})
    level_m1_agg_df = level_m1_agg_df.rename(columns={'effective_population': 'parent_effective_population'})
    # Merge in the level - 1 df
    level_m1_agg_df = level_m1_agg_df.merge(
        level_m1_df[['year_id', 'parent_id', f'parent_{count_variable}']],
        on=['year_id', 'parent_id'],
        how='left'
    )
    # Calculate the raking factors: parent count relative to the summed child count,
    # the full parent population, and the effective parent population
    level_m1_agg_df['count_raking_factor'] = level_m1_agg_df[f'parent_{count_variable}'] / level_m1_agg_df[count_variable]
    level_m1_agg_df['full_population_raking_factor'] = level_m1_agg_df[f'parent_{count_variable}'] / level_m1_agg_df['parent_population']
    level_m1_agg_df['population_raking_factor'] = level_m1_agg_df[f'parent_{count_variable}'] / level_m1_agg_df['parent_effective_population']
    # Set every raking factor to 0 where the parent count variable is 0
    level_m1_agg_df.loc[level_m1_agg_df[f'parent_{count_variable}'] == 0, 'count_raking_factor'] = 0
    level_m1_agg_df.loc[level_m1_agg_df[f'parent_{count_variable}'] == 0, 'full_population_raking_factor'] = 0
    level_m1_agg_df.loc[level_m1_agg_df[f'parent_{count_variable}'] == 0, 'population_raking_factor'] = 0


    level_df = level_df.merge(
        level_m1_agg_df[['year_id', 'parent_id', f'parent_{count_variable}','count_raking_factor', 'full_population_raking_factor', 'population_raking_factor']],
        on=['year_id', 'parent_id'],
        how='left'
    )

    # replace population_raking_factor with full_population_raking_factor if population_raking_factor is inf or the rate we will get if we use effective will be too high
    level_df['used_full_population_raking_factor'] = np.where(level_df['population_raking_factor'] > problematic_rules['rate_max'][level], True, False)
    level_df['population_raking_factor'] = np.where(level_df['population_raking_factor'] > problematic_rules['rate_max'][level], level_df['full_population_raking_factor'], level_df['population_raking_factor'])
    # drop full_population_raking_factor now that it has been folded in
    level_df = level_df.drop(columns=['full_population_raking_factor'])

    # Setting up which population to use for raking and multiplying
    effective_population_mask = level_df['used_full_population_raking_factor'] == False
    level_df['population_to_use'] = level_df['population']
    level_df.loc[effective_population_mask,'population_to_use'] = level_df.loc[effective_population_mask,'effective_population']

    # Candidate raked counts under the two strategies
    level_df['count_based_count'] = level_df[count_variable] * level_df['count_raking_factor']
    level_df['population_based_count'] = level_df['population_to_use'] * level_df['population_raking_factor']

    # Always use the actual population here!
    level_df['count_based_rate'] = level_df['count_based_count'] / level_df['population']
    level_df['population_based_rate'] = level_df['population_based_count'] / level_df['population']
    # String key "<parent_id>_<year_id>" used to flag whole parent-years below
    level_df['parent_year_id'] = level_df['parent_id'].astype(str).str.cat(level_df['year_id'].astype(str), sep='_')

    # Which level_df rows have either count_raking_factor = inf or count_based_rate is big
    problematic_rows = level_df[(level_df['count_raking_factor'] > problematic_rules['count_raking_factor_max']) | 
                                (level_df['count_based_rate'] > problematic_rules['rate_max'][level]) | 
                                ((level_df['count_raking_factor'] > problematic_rules['count_raking_factor_conditional']) & (level_df['count_based_rate'] > problematic_rules['rate_max_conditional']))]
    # Keep only flagged rows where the count-based rate exceeds the population-based rate
    problematic_rows = problematic_rows[problematic_rows['count_based_rate'] > problematic_rows['population_based_rate']]
    # Rows with an infinite count raking factor are always treated as problematic
    npinf_rows = level_df[(level_df['count_raking_factor'] == np.inf)]
    problematic_rows = pd.concat([problematic_rows, npinf_rows]).drop_duplicates()

    # One problematic row switches its entire parent-year to the population-based method
    problematic_parent_years = problematic_rows['parent_year_id'].drop_duplicates().reset_index(drop=True).to_frame(name='parent_year_id')
    problematic_parent_years['use_population'] = True
    level_df = level_df.merge(
        problematic_parent_years,
        on='parent_year_id',
        how='left'
    )
    
    # Parent-years that were not flagged get use_population = False
    mask = level_df['use_population'].isna()
    level_df.loc[mask, 'use_population'] = False
    level_df['use_population'] = level_df['use_population'].astype('boolean')

    # Rows already set by GBD are excluded from both masks and therefore left unchanged
    population_mask = (level_df['use_population'] == True) & (level_df['set_by_gbd'] == False)
    count_mask = (level_df['use_population'] == False) & (level_df['set_by_gbd'] == False)
    left_alone = len(count_mask[count_mask == True])
    changed = len(population_mask[population_mask == True])
    total = left_alone + changed
    # Note we use 'population_to_use' here!!!
    level_df.loc[count_mask, count_variable] = level_df.loc[count_mask, count_variable] * level_df.loc[count_mask, 'count_raking_factor']
    level_df.loc[population_mask, count_variable] = level_df.loc[population_mask, 'population_to_use'] * level_df.loc[population_mask, 'population_raking_factor']

    # Drop every helper column: anything with 'based' or 'raking' in its name plus the bookkeeping columns
    drop_cols = [col for col in level_df.columns if 'based' in col or 'raking' in col] + ['parent_id', f'parent_{count_variable}', 'use_population', 'used_full_population_raking_factor', 'population_to_use', 'effective_population', 'parent_year_id', 'current_rate']
    level_df = level_df.drop(columns=drop_cols)

    # Additional mortality-specific checks: summarize the raked rate distribution
    mort_stats = {}
    if 'malaria_mort_count' in count_variable.lower():
        tmp_df = level_df.copy()
        tmp_df['rate'] = tmp_df[count_variable] / tmp_df['population']
        mort_stats = {
            'mean': tmp_df['rate'].mean(),
            'max': tmp_df['rate'].max(),
            'count_gt_0_001': (tmp_df['rate'] > 0.001).sum(),
            'count_gt_0_01': (tmp_df['rate'] > 0.01).sum(),
            'count_gt_0_1': (tmp_df['rate'] > 0.1).sum(),
            'count_gt_1': (tmp_df['rate'] > 1).sum()
        }
        # Guard the percentage: total is 0 when every row is set by GBD (or the frame is empty)
        pct_changed = 100 * (changed / total) if total else 0.0
        print(f"\nMortality stats for {count_variable} at level {level}:")
        print(f"Mortality Statistics:")
        print(f"Mean: {mort_stats['mean']:.6f}")
        print(f"Max: {mort_stats['max']:.6f}")
        print(f"Count > 0.001: {mort_stats['count_gt_0_001']}")
        print(f"Count > 0.01: {mort_stats['count_gt_0_01']}")
        print(f"Count > 0.1: {mort_stats['count_gt_0_1']}")
        print(f"Count > 1: {mort_stats['count_gt_1']}")
        print(f"Changed {changed} ({pct_changed:.1f}%)rows, left alone {left_alone} rows, total {total} rows")
    return level_df

In [None]:












# NOTE(review): scratch copy of the tail of a raking function. `return` is only valid inside a
# function, so this cell cannot run as-is; it also relies on `rake_level`,
# `aa_full_count_df_path` and `return_full_df`, none of which are defined or imported in the
# visible cells.
level_4_df = rake_level(count_variable, level_df, level_m1_df, problematic_rules, hierarchy_df, level = 4)
# Rake 5 to 4: the raked level-4 frame becomes the parent for level 5
level_m1_df = level_4_df.copy()
level_df = aa_lsae_count_df[aa_lsae_count_df['level'] == 5].copy()
level_df = make_aa_df_square(count_variable, level_df, hierarchy_df, 5, 5)
level_5_df = rake_level(count_variable, level_df, level_m1_df, problematic_rules, hierarchy_df, level = 5)
# Make aa_full_df: GBD levels 0-3 stacked on the raked levels 4 and 5
aa_full_count_df = pd.concat([
    aa_gbd_count_0_to_3_df,
    level_4_df,
    level_5_df
], ignore_index=True)
# Drop level column if it exists
if 'level' in aa_full_count_df.columns:
    aa_full_count_df = aa_full_count_df.drop(columns=['level'])

# Save the aa_full_df
if aa_full_count_df_path is not None:
    write_parquet(aa_full_count_df, aa_full_count_df_path)
# Return the full DataFrame if requested, return nothing otherwise
if return_full_df:
    return aa_full_count_df

In [None]:
def rake_aa_counts_lsae_to_gbd(inc_count_variable, mort_count_variable, hierarchy_df, 
                               aa_gbd_count_df, aa_lsae_count_df, 
                               problematic_rules,
                               aa_full_count_df_path=None, return_full_df=True):
    '''
    Rakes the LSAE age-aggregated data to match the GBD age-aggregated data.

    Steps performed here:
      1. Prep both frames and add a 'cfr' column (mortality count / incidence count,
         forced to 0 where the mortality count is 0).
      2. Keep GBD rows at levels 0-3 as-is.
      3. Overwrite LSAE incidence, mortality and CFR with the GBD value wherever a GBD
         row exists for the same (location_id, year_id), and flag those rows `set_by_gbd`.
      4. Rake level 4 to level 3, then level 5 to the raked level 4.
      5. Concatenate levels 0-3, 4 and 5; optionally write and/or return the result.

    NOTE(review): step 4 still calls `rake_level` with the module-level name
    `count_variable`; neither is defined or imported in the visible cells. The
    multi-variable raking step is unfinished — confirm which function and variable(s)
    belong there.

    Parameters
    ----------
    inc_count_variable, mort_count_variable : names of the incidence and mortality count columns.
    hierarchy_df : location hierarchy passed to the prep/square/rake helpers.
    aa_gbd_count_df : GBD counts; must contain both count columns.
    aa_lsae_count_df : LSAE counts; must contain both count columns.
    problematic_rules : rule mapping forwarded to the raking step.
    aa_full_count_df_path : optional path; when not None the combined frame is written there.
    return_full_df : when truthy the combined frame is returned, otherwise None is returned.
    '''
    count_variables = [inc_count_variable, mort_count_variable]
    override_variables = count_variables + ['cfr']

    aa_gbd_count_df = prep_df(aa_gbd_count_df, hierarchy_df)
    aa_gbd_count_df['cfr'] = aa_gbd_count_df[mort_count_variable] / aa_gbd_count_df[inc_count_variable]
    # .loc is required for a (row mask, column) assignment; plain [] raised a TypeError here
    aa_gbd_count_df.loc[aa_gbd_count_df[mort_count_variable] == 0, 'cfr'] = 0

    aa_gbd_count_0_to_3_df = aa_gbd_count_df[aa_gbd_count_df['level'] <= 3].copy()

    aa_lsae_count_df = prep_df(aa_lsae_count_df, hierarchy_df)
    # The frame to square must be passed explicitly (it was missing from the original call)
    aa_lsae_count_df = make_aa_df_square(count_variables, aa_lsae_count_df, hierarchy_df, level_start = 3, level_end = 5)
    aa_lsae_count_df['cfr'] = aa_lsae_count_df[mort_count_variable] / aa_lsae_count_df[inc_count_variable]
    aa_lsae_count_df.loc[aa_lsae_count_df[mort_count_variable] == 0, 'cfr'] = 0


    # Copy each count/CFR column to a "_gbd"-suffixed twin so it survives the merge below
    for variable in override_variables:
        aa_gbd_count_df[f'{variable}_gbd'] = aa_gbd_count_df[variable]
    
    gbd_variables = [f'{variable}_gbd' for variable in override_variables]

    aa_gbd_count_df['set_by_gbd'] = True
    
    aa_lsae_count_df = aa_lsae_count_df.merge(
        aa_gbd_count_df[['location_id', 'year_id', 'set_by_gbd'] + gbd_variables],
        on=['location_id', 'year_id'],
        how='left'
    )


    # Track which locations are already set by GBD
    mask = aa_lsae_count_df['set_by_gbd'].isna()
    aa_lsae_count_df.loc[mask, 'set_by_gbd'] = False
    aa_lsae_count_df['set_by_gbd'] = aa_lsae_count_df['set_by_gbd'].astype('boolean')

    # GBD values win for every count variable and for CFR; LSAE values fill the gaps
    for var in override_variables:
        aa_lsae_count_df[var] = aa_lsae_count_df[f'{var}_gbd'].fillna(aa_lsae_count_df[var])
        aa_lsae_count_df = aa_lsae_count_df.drop(columns=[f'{var}_gbd'])

    # Rake 4 to 3
    # NOTE(review): `rake_level` and `count_variable` come from outside this function — see docstring.
    level_m1_df = aa_gbd_count_0_to_3_df[aa_gbd_count_0_to_3_df['level'] == 3].copy()
    level_df = aa_lsae_count_df[aa_lsae_count_df['level'] == 4].copy()
    level_df = make_aa_df_square(count_variable, level_df, hierarchy_df, 4, 4)
    level_4_df = rake_level(count_variable, level_df, level_m1_df, problematic_rules, hierarchy_df, level = 4)
    # Rake 5 to 4: the raked level-4 frame becomes the parent for level 5
    level_m1_df = level_4_df.copy()
    level_df = aa_lsae_count_df[aa_lsae_count_df['level'] == 5].copy()
    level_df = make_aa_df_square(count_variable, level_df, hierarchy_df, 5, 5)
    level_5_df = rake_level(count_variable, level_df, level_m1_df, problematic_rules, hierarchy_df, level = 5)
    # Make aa_full_df
    aa_full_count_df = pd.concat([
        aa_gbd_count_0_to_3_df,
        level_4_df,
        level_5_df
    ], ignore_index=True)
    # Drop level column if it exists
    if 'level' in aa_full_count_df.columns:
        aa_full_count_df = aa_full_count_df.drop(columns=['level'])

    # Save the aa_full_df
    if aa_full_count_df_path is not None:
        write_parquet(aa_full_count_df, aa_full_count_df_path)
    # Return the full DataFrame if requested, return nothing otherwise
    if return_full_df:
        return aa_full_count_df