# In [None]:
import xarray as xr # type: ignore
from pathlib import Path
import numpy as np # type: ignore
from affine import Affine # type: ignore
from typing import cast
import numpy.typing as npt # type: ignore
import pandas as pd # type: ignore
from typing import Literal, NamedTuple
import itertools
from rra_tools.shell_tools import mkdir # type: ignore
from idd_forecast_mbp import constants as rfc
from idd_forecast_mbp.helper_functions import read_parquet_with_integer_ids
import argparse

# Command-line interface: one invocation processes a single
# (ssp_scenario, dah_scenario, measure, draw) combination.
parser = argparse.ArgumentParser(
    description="Aggregate draw-level malaria forecasts up the location hierarchy and save them for upload"
)

# Define arguments
parser.add_argument("--ssp_scenario", type=str, required=True, help="SSP scenario (e.g., 'ssp126', 'ssp245', 'ssp585')")
parser.add_argument("--dah_scenario", type=str, required=True, help="DAH scenario (e.g., 'Baseline')")
parser.add_argument("--measure", type=str, required=True, help="measure (e.g., 'mortality', 'incidence')")
# Kept as a string so zero-padded draw ids (e.g. '001') survive into file names.
parser.add_argument("--draw", type=str, required=True, help="Draw number (e.g., '001', '002', etc.)")

# Parse arguments
args = parser.parse_args()

ssp_scenario = args.ssp_scenario
dah_scenario = args.dah_scenario
measure = args.measure
draw = args.draw



# Pipeline stage directories, all rooted at the model root.
PROCESSED_DATA_PATH = rfc.MODEL_ROOT / "02-processed_data"
FORECASTING_DATA_PATH = rfc.MODEL_ROOT / "04-forecasting_data"
UPLOAD_DATA_PATH = rfc.MODEL_ROOT / "05-upload_data"

# File-name templates. These are deliberately plain strings (NOT f-strings):
# the {placeholders} are filled later with str.format(...). Adjacent literals
# are concatenated at compile time, so each value is a single template string.
forecast_df_path = (
    "{FORECASTING_DATA_PATH}/as_malaria_measure_{measure}_ssp_scenario_{ssp_scenario}"
    "_dah_scenario_{dah_scenario}_draw_{draw}_with_predictions.parquet"
)
processed_forecast_df_path = (
    "{UPLOAD_DATA_PATH}/full_as_malaria_measure_{measure}_ssp_scenario_{ssp_scenario}"
    "_dah_scenario_{dah_scenario}_draw_{draw}_with_predictions.parquet"
)

# Location hierarchy, loaded once at import time; process_forecast_data reads
# its location_id, parent_id and level columns.
hierarchy_df_path = f"{PROCESSED_DATA_PATH}/full_hierarchy_lsae_1209.parquet"
hierarchy_df = read_parquet_with_integer_ids(hierarchy_df_path)




def _aggregate_to_parents(df, hierarchy_df):
    """Append parent-location rows to ``df`` by summing children, one level at a time.

    Passes run from the deepest ``level`` present in ``df`` up to level 1. Each
    pass aggregates every row currently at that level. That includes original
    rows and rows created by the previous (deeper) pass, so data whose
    most-detailed locations sit at different depths is rolled up exactly once.

    Parent rows carry only ``location_id``, ``year_id``, ``age_group_id``,
    ``sex_id``, ``count_pred``, ``population`` and ``level``. Any other column
    is NaN for them after the concat. Rows with a missing ``level`` (location
    not in ``hierarchy_df``) are never aggregated. Children whose ``parent_id``
    is missing are dropped by ``groupby`` (default ``dropna=True``).
    """
    parent_lookup = hierarchy_df[["location_id", "parent_id"]]
    level_lookup = hierarchy_df[["location_id", "level"]]
    group_cols = ["parent_id", "year_id", "age_group_id", "sex_id"]

    deepest = df["level"].max()
    if pd.isna(deepest):
        # Empty frame, or no row matched the hierarchy: nothing to aggregate.
        return df

    for level in range(int(deepest), 0, -1):
        print(f"Processing level {level}...")
        # Select children by level instead of reusing the previous pass's
        # output; otherwise mixed-depth leaves would be double-counted.
        children = df[df["level"] == level]
        if children.empty:
            continue
        children = children.merge(parent_lookup, on="location_id", how="left")
        parents = (
            children.groupby(group_cols)
            .agg({"count_pred": "sum", "population": "sum"})
            .reset_index()
            .rename(columns={"parent_id": "location_id"})
        )
        parents = parents.merge(level_lookup, on="location_id", how="left")
        df = pd.concat([df, parents], ignore_index=True)

    return df


def process_forecast_data(FORECASTING_DATA_PATH, measure, ssp_scenario, dah_scenario, draw, hierarchy_df, *, start_year=2022):
    """
    Load one draw of malaria forecasts and aggregate it up the location hierarchy.

    The input file path is built from the module-level ``forecast_df_path``
    template.

    Parameters:
    -----------
    FORECASTING_DATA_PATH : path-like
        Directory substituted into ``forecast_df_path``.
    measure : str
        Measure name (e.g. 'mortality', 'incidence').
    ssp_scenario : str
        SSP scenario name
    dah_scenario : str
        DAH scenario name
    draw : str
        Draw identifier
    hierarchy_df : pandas.DataFrame
        Hierarchy dataframe; must provide ``location_id``, ``parent_id`` and
        ``level`` columns.
    start_year : int, optional
        Rows with ``year_id`` below this are discarded (default 2022).

    Returns:
    --------
    pandas.DataFrame
        The filtered input rows plus one summed row (``count_pred``,
        ``population``) per parent location, year, age group and sex, for
        every ancestor level down to level 0.

    Raises:
    -------
    KeyError
        If any of the columns to be dropped is absent from the input file.
    """
    df = read_parquet_with_integer_ids(forecast_df_path.format(
        FORECASTING_DATA_PATH=FORECASTING_DATA_PATH,
        measure=measure,
        ssp_scenario=ssp_scenario,
        dah_scenario=dah_scenario,
        draw=draw
    ))

    df = df[df["year_id"] >= start_year]
    # Drop baseline / intermediate columns not carried into the output;
    # count_pred_1 is kept (as count_pred) while count_pred_2 is dropped.
    df = df.drop(columns=["malaria_mort_rate_baseline", "malaria_mort_rate_baseline_pred_1", "malaria_mort_rate_baseline_pred_2",
                          "reference_age_group_id", "reference_sex_id", "relative_risk_as", "rate_pred_1", "rate_pred_2",
                          "population_aa", "pop_fraction_aa", "count_pred_2", "fhs_location_id", "location_name"])
    df = df.rename(columns={
        "count_pred_1": "count_pred",
    })

    df = df.merge(hierarchy_df[["location_id", "level"]], on="location_id", how="left")

    return _aggregate_to_parents(df, hierarchy_df)

# Build this draw's full-hierarchy forecast, then write it to the upload area.
full_hierarchy_forecast_df = process_forecast_data(
    FORECASTING_DATA_PATH,
    measure,
    ssp_scenario,
    dah_scenario,
    draw,
    hierarchy_df,
)

output_path = processed_forecast_df_path.format(
    UPLOAD_DATA_PATH=UPLOAD_DATA_PATH,
    measure=measure,
    ssp_scenario=ssp_scenario,
    dah_scenario=dah_scenario,
    draw=draw,
)
full_hierarchy_forecast_df.to_parquet(output_path, index=False)