In [1]:
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from scipy.stats.stats import pearsonr

In [2]:
#Stratification functions

def create_strat_var(some_data, strat_col_names):
    """Creates a variable allowing stratification.

    By concatenating the string values of each variable you want to use to
    stratify (joined with ``'_'``), we create a variable with one level per
    extant combination of values. When passed to a stratification algorithm
    (e.g. the ``stratify`` argument of ``train_test_split``), this produces
    subsamples with closely matched joint -- and hence marginal -- proportions.

    Args:
        some_data (pandas data frame): The data frame containing the columns you want to use to stratify.
        strat_col_names (str or list of strings): Column name(s) in `some_data` to use for
            stratification. A single string is treated as a one-element list.

    Returns:
        A list of strings, one per row of `some_data` in row order, with the
        rowwise-concatenated values.

    Note:
        Values are converted with ``astype(str)``, so missing values become a
        string such as ``'nan'`` and form their own level. Because ``'_'`` is
        the separator, values that themselves contain ``'_'`` can produce the
        same key for different combinations (``'a_b'`` + ``'c'`` vs
        ``'a'`` + ``'b_c'``).
    """
    # A bare column name would otherwise select a Series, and the row-wise
    # join below needs a 2-D frame.
    if isinstance(strat_col_names, str):
        strat_col_names = [strat_col_names]
    string_data = some_data[strat_col_names].astype(str)
    return ['_'.join(row) for row in string_data.to_numpy()]

def get_holdout(some_data, strat_col_names, test_size=.2, random_state=None):
    """Splits the rows of a data frame into a stratified train and test set.

    The strata are the row-wise combinations of the values in
    `strat_col_names` (built by `create_strat_var`), so the two splits keep
    roughly the same proportions of those columns. Providing a data set with
    just the indices and stratification columns returns a train and test set
    of indices.

    Args:
        some_data (pandas data frame): The rows to split; must contain `strat_col_names`.
        strat_col_names (list of strings): Column names in `some_data` whose value
            combinations define the strata.
        test_size (float): Forwarded unchanged to `train_test_split` as its
            `test_size` argument. Defaults to .2.
        random_state (int or None): Forwarded unchanged to `train_test_split`;
            pass a fixed value for a reproducible split. Defaults to None.

    Returns:
        A ``(train, test)`` tuple of data frames.

    NOTE(review): `train_test_split` presumably raises if a stratum has too few
    rows to appear in both splits -- confirm against the sklearn docs.
    """
    train_rows, test_rows = train_test_split(
        some_data,
        stratify=create_strat_var(some_data, strat_col_names),
        test_size=test_size,
        random_state=random_state,
    )
    return train_rows, test_rows

def do_cross_tabs(some_data, index_name, col_names):
    """Produces a cross tab of joint proportions.

    Useful for checking stratification results: run it on the full data and
    on each split, then compare the cell proportions.

    Args:
        some_data (pandas data frame): The data frame containing the columns to tabulate.
        index_name (str): Name of the column in `some_data` used for the crosstab rows.
        col_names (str or list of strings): Column name(s) in `some_data` used for the
            crosstab columns (one column level per name). A single string is treated
            as a one-element list.

    Returns:
        A pandas data frame of cell counts divided by the number of rows in
        `some_data`, including the "All" row and column margins.

    NOTE(review): rows with missing values in a tabulated column are presumably
    excluded from the counts by `pd.crosstab` but still counted in the
    denominator, in which case the "All"/"All" cell falls below 1.0 -- confirm.
    """
    # A bare string would otherwise be iterated character by character.
    if isinstance(col_names, str):
        col_names = [col_names]
    data_cols = [some_data[col_name].values for col_name in col_names]
    counts = pd.crosstab(index = some_data[index_name],
                         columns = data_cols,
                         colnames = col_names,
                         margins = True)
    # Convert counts to proportions of all rows in `some_data`.
    return counts / some_data.shape[0]

In [3]:
#Test stratification

# NOTE(review): absolute, machine-specific path -- consider moving to a
# configurable DATA_DIR in a config cell.
NCS_CSV_PATH = "/users/jflournoy/otherhome/data/ncs-a/ncs_psaq_full.csv"
# Columns whose joint levels define the strata; the first one is also the
# crosstab row variable when checking the split below.
STRAT_COLS = ["Age_bin", "birth_order_cat", "peduc_cat"]

ncs = pd.read_csv(NCS_CSV_PATH, encoding = "ISO-8859-1", low_memory=False)

demo_dict = {
    "Age" : "SC1",
    "Handedness" : "SC5_1",
    "Smoker": "SC7"
}

# Drop rows whose smoker response is "D" or "R" (presumably don't know /
# refused -- TODO confirm against the codebook).
ncs_ = ncs
ncs = ncs_.loc[~ncs_[demo_dict["Smoker"]].isin(["D", "R"])]
ncs = ncs.assign(
    # pd.cut with an integer splits the observed age range into 3 equal-width bins.
    Age_bin = pd.cut(ncs[demo_dict["Age"]], 3, labels = ["young", "medium", "old"]).astype(str),
    Educ = pd.cut(ncs.educ, bins=pd.IntervalIndex.from_breaks([0, 7, 10, 13])),
)

X_train, X_test = get_holdout(ncs, test_size = .2,
                              strat_col_names=STRAT_COLS,
                              random_state=123121)

# Every filtered row must land in exactly one of the two splits.
assert len(X_train) + len(X_test) == len(ncs)

# Proportion tables for the full data, then train, then test: matching cells
# indicate the stratification preserved the joint distribution.
for frame in (ncs, X_train, X_test):
    display(do_cross_tabs(frame, STRAT_COLS[0], STRAT_COLS[1:]))

birth_order_cat,1,1,1,1,2,2,2,2,3,3,3,3,All
peduc_cat,<HS,College Grad,HS,Some College,<HS,College Grad,HS,Some College,<HS,College Grad,HS,Some College,Unnamed: 13_level_1
Age_bin,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2
medium,0.016277,0.052481,0.038177,0.026142,0.014896,0.041334,0.031962,0.019828,0.027622,0.039262,0.048732,0.028115,0.384828
old,0.01095,0.031863,0.025846,0.018941,0.007004,0.022196,0.017559,0.013909,0.017559,0.021308,0.029891,0.01677,0.233797
young,0.021604,0.045477,0.036599,0.026339,0.016968,0.039262,0.030778,0.01973,0.034034,0.038473,0.044885,0.027227,0.381375
All,0.048831,0.129821,0.100621,0.071422,0.038868,0.102792,0.0803,0.053467,0.079215,0.099043,0.123508,0.072112,1.0


birth_order_cat,1,1,1,1,2,2,2,2,3,3,3,3,All
peduc_cat,<HS,College Grad,HS,Some College,<HS,College Grad,HS,Some College,<HS,College Grad,HS,Some College,Unnamed: 13_level_1
Age_bin,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2
medium,0.016278,0.052411,0.038106,0.026144,0.014922,0.041312,0.03194,0.019854,0.027624,0.039216,0.048711,0.028117,0.384634
old,0.010975,0.031817,0.025897,0.018991,0.007029,0.022198,0.017511,0.013935,0.017511,0.021334,0.029843,0.016771,0.233814
young,0.021581,0.045505,0.036626,0.02639,0.017018,0.039216,0.03083,0.019731,0.034036,0.038476,0.044888,0.027254,0.381551
All,0.048835,0.129732,0.100629,0.071525,0.038969,0.102725,0.080281,0.053521,0.079171,0.099026,0.123443,0.072142,1.0


birth_order_cat,1,1,1,1,2,2,2,2,3,3,3,3,All
peduc_cat,<HS,College Grad,HS,Some College,<HS,College Grad,HS,Some College,<HS,College Grad,HS,Some College,Unnamed: 13_level_1
Age_bin,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2
medium,0.016272,0.052761,0.038462,0.026134,0.014793,0.04142,0.032051,0.019724,0.027613,0.039448,0.048817,0.028107,0.385602
old,0.010848,0.032051,0.025641,0.018738,0.006903,0.022189,0.017751,0.013807,0.017751,0.021203,0.030079,0.016765,0.233728
young,0.021696,0.045365,0.036489,0.026134,0.016765,0.039448,0.030572,0.019724,0.034024,0.038462,0.044872,0.02712,0.380671
All,0.048817,0.130178,0.100592,0.071006,0.038462,0.103057,0.080375,0.053254,0.079389,0.099112,0.123767,0.071992,1.0
