In [1]:
import pandas as pd
import numpy as np

# Assuming the CatSplitter and configuration classes have been imported correctly
from pydisagg.ihme.splitter import (
    CatSplitter,
    CatDataConfig,
    CatPatternConfig,
    CatPopulationConfig,
)

# Seed NumPy's legacy global RNG so every random draw below repeats run-to-run
np.random.seed(42)

# -------------------------------
# Example DataFrames
# -------------------------------

# Three aggregated observations for 2010. Each row carries the list of
# location_ids it covers, together with its mean and standard error.
pre_split = pd.DataFrame(
    dict(
        # Random study IDs drawn from [1000, 9999); fixed by the seed above
        study_id=np.random.randint(1000, 9999, size=3),
        year_id=[2010] * 3,
        location_id=[
            [1234, 1235, 1236],  # row 1 spans three locations
            [2345, 2346, 2347],  # row 2 spans three locations
            [3456],  # row 3 is a single location, so there is nothing to split
        ],
        mean=[0.2, 0.3, 0.4],
        std_err=[0.01, 0.02, 0.03],
    )
)

# Every location_id referenced by the pre-split rows, followed by two extra
# ones (4567, 5678) that no pre-split row mentions.
all_location_ids = [
    1234, 1235, 1236,
    2345, 2346, 2347,
    3456,
    4567, 5678,
]

# Pattern table: one random mean / std_err per location, all for year 2010.
# Keyword arguments are evaluated left to right, so "mean" is drawn from the
# global RNG before "std_err".
data_pattern = pd.DataFrame({"location_id": all_location_ids}).assign(
    year_id=2010,
    mean=np.random.uniform(0.1, 0.5, len(all_location_ids)),
    std_err=np.random.uniform(0.01, 0.05, len(all_location_ids)),
)

# Population table: one random population count per location, all for year 2010.
data_pop = pd.DataFrame({"location_id": all_location_ids}).assign(
    year_id=2010,
    population=np.random.randint(10000, 1000000, size=len(all_location_ids)),
)

# Show the three input tables. sep="\n" places each heading on its own line
# above its frame; the leading "\n" in the later headings leaves a blank line
# between consecutive tables.
print("Pre-split DataFrame:", pre_split, sep="\n")
print("\nPattern DataFrame:", data_pattern, sep="\n")
print("\nPopulation DataFrame:", data_pop, sep="\n")

# -------------------------------
# Configurations
# -------------------------------

# Column mapping for the pre-split data: "location_id" is the column holding
# the list of targets, and study_id is part of the index alongside year_id.
data_config = CatDataConfig(
    target="location_id",
    index=["study_id", "year_id"],
    val="mean",
    val_sd="std_err",
)

# Column mapping for the pattern table (indexed by year only).
pattern_config = CatPatternConfig(
    target="location_id",
    index=["year_id"],
    val="mean",
    val_sd="std_err",
)

# Column mapping for the population table (indexed by year only).
population_config = CatPopulationConfig(
    target="location_id",
    index=["year_id"],
    val="population",
)

# Build the splitter from the three column mappings above
splitter = CatSplitter(
    data=data_config,
    pattern=pattern_config,
    population=population_config,
)

# Perform the split.
# Only the split call itself is guarded, so the ValueError handler cannot
# mask an unrelated failure in the sorting or printing that follows.
try:
    final_split_df = splitter.split(
        data=pre_split,
        pattern=data_pattern,
        population=data_pop,
        model="rate",
        output_type="rate",
    )
except ValueError as e:
    print(f"Error: {e}")
    # Re-raise so the cell stops here. Swallowing the error would leave
    # `final_split_df` undefined (or stale from an earlier run of this cell),
    # and the next cell would then fail with a NameError or show old results.
    raise

# Order rows by study, then location. Reassigning instead of sorting in place
# keeps the cell idempotent and leaves the object returned by split() untouched.
final_split_df = final_split_df.sort_values(by=["study_id", "location_id"])
print("\nFinal Split DataFrame:")
print(final_split_df)


Pre-split DataFrame:
   study_id  year_id         location_id  mean  std_err
0      8270     2010  [1234, 1235, 1236]   0.2     0.01
1      1860     2010  [2345, 2346, 2347]   0.3     0.02
2      6390     2010              [3456]   0.4     0.03

Pattern DataFrame:
   location_id  year_id      mean   std_err
0         1234     2010  0.392798  0.048796
1         1235     2010  0.339463  0.043298
2         1236     2010  0.162407  0.018494
3         2345     2010  0.162398  0.017273
4         2346     2010  0.123233  0.017336
5         2347     2010  0.446470  0.022170
6         3456     2010  0.340446  0.030990
7         4567     2010  0.383229  0.027278
8         5678     2010  0.108234  0.021649

Population DataFrame:
   location_id  year_id  population
0         1234     2010      166730
1         1235     2010      880910
2         1236     2010      394681
3         2345     2010      159503
4         2346     2010      664811
5         2347     2010      537035
6         3456     2

In [2]:
# Display the split result produced by the previous cell as this cell's output
final_split_df

Unnamed: 0,mean,study_id,std_err,location_id,year_id,cat_pat_mean,cat_pat_std_err,population,split_result,split_result_se,split_flag,orig_group
3,0.3,1860,0.02,2345,2010,0.162398,0.017273,159503.0,0.190806,0.02444,1,"[2345, 2346, 2347]"
4,0.3,1860,0.02,2346,2010,0.123233,0.017336,664811.0,0.14479,0.019012,1,"[2345, 2346, 2347]"
5,0.3,1860,0.02,2347,2010,0.44647,0.02217,537035.0,0.52457,0.040101,1,"[2345, 2346, 2347]"
6,0.4,6390,0.03,3456,2010,0.340446,0.03099,658143.0,0.4,0.03,0,[3456]
0,0.2,8270,0.01,1234,2010,0.392798,0.048796,166730.0,0.264351,0.039018,1,"[1234, 1235, 1236]"
1,0.2,8270,0.01,1235,2010,0.339463,0.043298,880910.0,0.228457,0.015557,1,"[1234, 1235, 1236]"
2,0.2,8270,0.01,1236,2010,0.162407,0.018494,394681.0,0.1093,0.015518,1,"[1234, 1235, 1236]"
