In [1]:
import time

import numpy as np
import pandas as pd

from splitting.disaggregate import split_datapoint
from splitting.models import RateMultiplicativeModel
from splitting.models import LMO_model
from splitting.models import LogOdds_model

## Building a test dataset

In [2]:
# Labels of the fine-grained groups that aggregate observations are split into.
groups_to_split_into = [0, 1, 2, 3]

# Baseline prevalence in each group, one row per baseline_id.
# Columns are set explicitly so they always line up with groups_to_split_into
# (previously they only matched because the default column labels happen to be 0..3).
baseline_patterns = pd.DataFrame.from_dict(
    {
        0: np.array([0.4, 0.2, 0.3, 0.75]),
        1: np.array([0.1, 0.15, 0.2, 0.7]),
    },
    orient='index',
    columns=groups_to_split_into,
)
baseline_patterns.index.name = 'baseline_id'


# Population ids correspond to the overall aggregated population we are splitting.
# For example, this would be the id for a country.

# Size of each group within each population, keyed by pop_id.
pop_group_sizes = {
    0: [10, 10, 20, 20],
    1: [10, 20, 20, 13],
    2: [20, 20, 10, 10],
    3: [5, 40, 30, 20],
}
population_sizes = pd.DataFrame.from_dict(pop_group_sizes, orient='index', columns=groups_to_split_into)
population_sizes.index.name = 'pop_id'

# Which baseline pattern is used when splitting each population's observations.
pop_df = pd.DataFrame(
    {
        'pop_id': [0, 1, 2, 3],
        'baseline_id': [0, 0, 1, 1],
    }
)

# For each population, the tuples of groups that were observed together as one aggregate.
group_partitions = {
    0: [(0, 1), (2, 3)],
    1: [(0, 1, 2), (3,)],
    2: [(0, 1, 2, 3)],
    3: [(0,), (1,), (2, 3)],
}

# Sanity checks on the example: every population has sizes, a baseline and a
# partition; every referenced baseline exists; and each example partition covers
# every group exactly once.
assert set(pop_df['pop_id']) == set(population_sizes.index) == set(group_partitions)
assert set(pop_df['baseline_id']) <= set(baseline_patterns.index)
assert all(
    sorted(group for part in partition for group in part) == groups_to_split_into
    for partition in group_partitions.values()
)



In [3]:
def get_dummies(partition, splitting_groups):
    '''
    One-hot encode which groups each element of a partition contains.

    Returns a list with one row per element of `partition`; each row holds an
    integer 0/1 flag for every label in `splitting_groups`, in that order.
    '''
    rows = []
    for members in partition:
        rows.append([1 if label in members else 0 for label in splitting_groups])
    return rows

def build_dummy_df(dummies, id, splitting_groups):
    '''
    Wrap one population's group-membership dummies in a dataframe tagged with its pop_id.

    dummies: list of rows of 0/1 flags, one row per observed aggregate and one
        flag per label in splitting_groups
    id: pop_id value written into every row's 'pop_id' column
    splitting_groups: group labels used as the dummy column names

    Returns a dataframe with one column per splitting group plus 'pop_id'.
    '''
    # Use the caller-supplied labels; this previously read the notebook-global
    # groups_to_split_into and silently ignored splitting_groups.
    df = pd.DataFrame(dummies, columns=splitting_groups)
    df['pop_id'] = id
    return df


In [4]:
# Stack one membership-dummy frame per population. The row index is left as-is
# (it repeats across populations); the merge on pop_id below does not use it.
per_population_frames = []
for pop_id, partition in group_partitions.items():
    membership = get_dummies(partition, groups_to_split_into)
    per_population_frames.append(build_dummy_df(membership, pop_id, groups_to_split_into))
splitting_df = pd.concat(per_population_frames)


In [5]:
# Observed aggregate values, in the same row order as the merged membership table below.
observations = [7, 20, 10, 11, 10, 0.8, 8, 22]

# Standard error of each observation, aligned element-for-element with `observations`.
SE_vals = [1, 2, 3, 1, 1.5, 0.1, 1, 3]

# Attach each population's baseline_id to its membership rows, then the observations.
data_df = (
    pop_df
    .merge(splitting_df, on='pop_id')
    .assign(obs=observations, obs_se=SE_vals)
)

## Example input data

In [6]:
# Show the assembled input: one row per observed aggregate, with its pop_id,
# baseline_id, the 0/1 group-membership columns, obs and obs_se.
data_df

Unnamed: 0,pop_id,baseline_id,0,1,2,3,obs,obs_se
0,0,0,1,1,0,0,7.0,1.0
1,0,0,0,0,1,1,20.0,2.0
2,1,0,1,1,1,0,10.0,3.0
3,1,0,0,0,0,1,11.0,1.0
4,2,1,1,1,1,1,10.0,1.5
5,3,1,1,0,0,0,0.8,0.1
6,3,1,0,1,0,0,8.0,1.0
7,3,1,0,0,1,1,22.0,3.0


In [7]:
# Baseline prevalence per group, one row per baseline_id.
baseline_patterns

Unnamed: 0_level_0,0,1,2,3
baseline_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
0,0.4,0.2,0.3,0.75
1,0.1,0.15,0.2,0.7


In [9]:
# Size of each group within each population, one row per pop_id.
population_sizes

Unnamed: 0_level_0,0,1,2,3
pop_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
0,10,10,20,20
1,10,20,20,13
2,20,20,10,10
3,5,40,30,20


In [10]:
def _split_observation(row):
    '''Split one observation row across the groups it aggregates (multiplicative model).'''
    # row.name is the pop_id because the frame is indexed by pop_id below.
    member_sizes = population_sizes.loc[row.name] * row[groups_to_split_into]
    baseline = baseline_patterns.loc[row['baseline_id']]
    return split_datapoint(row['obs'], member_sizes, baseline, model=RateMultiplicativeModel())

# Split every observation, then add up the per-group estimates within each population.
per_observation = data_df.set_index('pop_id').apply(_split_observation, axis=1)
split_result = per_observation.reset_index().groupby('pop_id').sum()
split_result['total'] = split_result.sum(axis=1)
split_result

Unnamed: 0_level_0,0,1,2,3,total
pop_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
0,4.666667,2.333333,5.714286,14.285714,27.0
1,2.857143,2.857143,4.285714,11.0,21.0
2,1.428571,2.142857,1.428571,5.0,10.0
3,0.8,8.0,6.6,15.4,30.8


In [11]:
# Observed total per population. The split estimates above are expected to add
# back up to these, so check that explicitly rather than by eye, then display.
observed_totals = data_df.groupby('pop_id')['obs'].sum()
pd.testing.assert_series_equal(split_result['total'], observed_totals, check_names=False)
observed_totals

pop_id
0    27.0
1    21.0
2    10.0
3    30.8
Name: obs, dtype: float64

In [14]:
def split_dataframe(
    groups_to_split_into,
    observation_group_membership_df,
    population_sizes,
    baseline_patterns,
    use_se=False,
    model=None,
    ):
    '''
    Disaggregates datapoints and pivots observations into estimates for each group per pop id.

    groups_to_split_into: list of groups to disaggregate observations into

    observation_group_membership_df: dataframe with columns pop_id, baseline_id, obs,
        and columns for each of the groups_to_split_into
        with dummy variables that represent whether or not
        each group is included in the observations for that row.
        This also optionally contains a obs_se column which will be used if use_se is True
        pop_id represents the population that the observation comes from
        baseline_id gives the baseline that should be used for splitting

    population_sizes: dataframe with pop_id as the index containing the
        size of each group within each population (given the pop_id)

    baseline_patterns: dataframe with baseline_id as the index, and columns
        for each of the groups_to_split_into where the entries represent the baseline
        prevalence in the given group to use for splitting.

    use_se: Boolean, whether or not to report standard errors along with estimates.
        Standard-error propagation is not implemented yet: passing a truthy value
        raises NotImplementedError.

    model: splitting model passed through to split_datapoint. When omitted, a new
        LogOdds_model() is created for each call.

    Returns a dataframe indexed by pop_id with one column per group, holding the
    split estimates summed over all observations of that population, plus a
    'total' column with the row sums.

    Raises NotImplementedError if use_se is truthy.
    '''
    if use_se:
        # Previously this path fell through and crashed with UnboundLocalError on `result`.
        raise NotImplementedError('use_se=True is not implemented yet')
    if model is None:
        # Build the default per call instead of once at definition time, so no
        # model instance (and any state it may hold) is shared between calls.
        model = LogOdds_model()

    def split_row(row):
        # row.name is the pop_id because the frame is indexed by pop_id below.
        return split_datapoint(
            row['obs'],
            population_sizes.loc[row.name] * row[groups_to_split_into],
            baseline_patterns.loc[row['baseline_id']],
            model=model,
        )

    # set_index returns a new frame, so the caller's dataframe is never mutated.
    result = (
        observation_group_membership_df
        .set_index('pop_id')
        .apply(split_row, axis=1)
        .reset_index()
        .groupby('pop_id')
        .sum()
    )
    result['total'] = result.sum(axis=1)
    return result


In [15]:
# Time split_dataframe on the small example with its default model.
# The previous loop discarded every result and reported nothing; the results are
# still discarded, but the mean wall-clock time per call is now shown.
N_TIMING_RUNS = 5
timing_start = time.perf_counter()
for _ in range(N_TIMING_RUNS):
    split_dataframe(
        groups_to_split_into,
        data_df,
        population_sizes,
        baseline_patterns
    )
mean_seconds_per_call = (time.perf_counter() - timing_start) / N_TIMING_RUNS
mean_seconds_per_call

In [16]:
# Scale the example up: N_COPIES copies of the small dataset, with each copy's
# pop_ids shifted by a multiple of POP_ID_STRIDE so they stay distinct.
N_COPIES = 50
POP_ID_STRIDE = 10
# Shifted ids can only collide if an original pop_id reaches the stride.
assert population_sizes.index.max() < POP_ID_STRIDE

# Build every copy first and concatenate once. Growing a frame with concat
# inside a loop re-copies all previous rows on every iteration (quadratic).
big_df = pd.concat(
    [data_df.assign(pop_id=data_df['pop_id'] + POP_ID_STRIDE * i) for i in range(N_COPIES)],
    ignore_index=True,
)
big_pops = pd.concat(
    [population_sizes.set_axis(population_sizes.index + POP_ID_STRIDE * i) for i in range(N_COPIES)]
)
assert big_pops.index.is_unique

In [17]:
# Split the replicated dataset with the LMO model.
# NOTE(review): the meaning of LMO_model's argument (2) is defined in
# splitting.models — TODO document it here.
lmo_model = LMO_model(2)
split_dataframe(
    groups_to_split_into=groups_to_split_into,
    observation_group_membership_df=big_df,
    population_sizes=big_pops,
    baseline_patterns=baseline_patterns,
    model=lmo_model,
)

Unnamed: 0_level_0,0,1,2,3,total
pop_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
0,4.597089,2.402911,5.472041,14.527959,27.0
1,2.951925,2.775020,4.273055,11.000000,21.0
2,1.225946,1.853404,1.249221,5.671428,10.0
3,0.800000,8.000000,7.134617,14.865383,30.8
10,4.597089,2.402911,5.472041,14.527959,27.0
...,...,...,...,...,...
483,0.800000,8.000000,7.134617,14.865383,30.8
490,4.597089,2.402911,5.472041,14.527959,27.0
491,2.951925,2.775020,4.273055,11.000000,21.0
492,1.225946,1.853404,1.249221,5.671428,10.0


In [19]:
# The replicated input: 50 copies of data_df, each copy's pop_ids offset by a
# multiple of 10.
big_df

Unnamed: 0,pop_id,baseline_id,0,1,2,3,obs,obs_se
0,0,0,1,1,0,0,7.0,1.0
1,0,0,0,0,1,1,20.0,2.0
2,1,0,1,1,1,0,10.0,3.0
3,1,0,0,0,0,1,11.0,1.0
4,2,1,1,1,1,1,10.0,1.5
...,...,...,...,...,...,...,...,...
395,491,0,0,0,0,1,11.0,1.0
396,492,1,1,1,1,1,10.0,1.5
397,493,1,1,0,0,0,0.8,0.1
398,493,1,0,1,0,0,8.0,1.0


In [18]:
# Split the replicated dataset with the log-odds model, for comparison with the LMO run.
log_odds_model = LogOdds_model()
split_dataframe(
    groups_to_split_into=groups_to_split_into,
    observation_group_membership_df=big_df,
    population_sizes=big_pops,
    baseline_patterns=baseline_patterns,
    model=log_odds_model,
)

Unnamed: 0_level_0,0,1,2,3,total
pop_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
0,4.587886,2.412114,5.485838,14.514162,27.0
1,2.976175,2.742213,4.281612,11.000000,21.0
2,1.195737,1.834592,1.251663,5.718007,10.0
3,0.800000,8.000000,7.121403,14.878597,30.8
10,4.587886,2.412114,5.485838,14.514162,27.0
...,...,...,...,...,...
483,0.800000,8.000000,7.121403,14.878597,30.8
490,4.587886,2.412114,5.485838,14.514162,27.0
491,2.976175,2.742213,4.281612,11.000000,21.0
492,1.195737,1.834592,1.251663,5.718007,10.0
