In [None]:
import numpy as np
import pandas as pd
import os
from pyspark.ml.feature import VectorAssembler, StringIndexer, OneHotEncoder
from pyspark.ml.regression import GeneralizedLinearRegression
from pyspark.sql.functions import col, count, mean, sum, avg, stddev, min, max, lit
from pyspark.ml.stat import Summarizer
from pyspark.ml.classification import LogisticRegression

In [None]:
# Directory (below the notebook's working directory) that receives every CSV
# exported by the cells further down. DataFrame.to_csv does not create missing
# parent directories, so create it up front; exist_ok makes re-runs harmless.
path_to_data = os.path.join(os.getcwd(),'data')
os.makedirs(path_to_data, exist_ok=True)
path_to_data

In [None]:
# NOTE(review): `spark` is never created in this notebook; presumably the
# kernel/platform injects a SparkSession under that name - confirm.
# Make this database the session default so the unqualified table names used
# in the following cells resolve against it.
spark.sql("use real_world_data_ed_omop_aug_2024")
# List the tables of the current database as a pandas frame (cell output).
spark.sql("show tables").toPandas()

In [None]:
# Drop any metadata/data Spark has cached for the table so the read below
# reflects its current state.
spark.catalog.refreshTable("delivery_elig_init_smm")
# Read the whole table and mark it for caching, since it is scanned many times
# by the summaries and models below. cache() is lazy: nothing is stored until
# an action runs.
delivery_elig_init_smm = spark.sql('''
    select * from delivery_elig_init_smm
''').cache()
# Show one row as a quick look at the schema.
# NOTE(review): this renders a record-level row in the notebook output - clear
# outputs before sharing if the table holds patient-level data.
delivery_elig_init_smm.first()

In [None]:
# Drop any metadata/data Spark has cached for the table so the read below
# reflects its current state.
spark.catalog.refreshTable("delivery_elig_smm")
# Read the whole table and mark it for caching (lazy until an action runs);
# it is used for the denominator counts further down.
delivery_elig_smm = spark.sql('''
    select * from delivery_elig_smm
''').cache()
# Show one row as a quick look at the schema.
# NOTE(review): this renders a record-level row in the notebook output - clear
# outputs before sharing if the table holds patient-level data.
delivery_elig_smm.first()

In [None]:
# import itertools
# def pt_freq_qry(df,stratified_by,n_way=1):
#     '''
#     generate patient and event counts for each stratification variable
#     '''
#     sql_str_lst = []
#     # overall count
#     sql_str_lst.append("select 'total' as summ_var,'N' as summ_cat, count(distinct person_id) as pat_cnt, count(distinct person_id || '_' || event_id) as evt_cnt from " + df)
    
#     # 1-way summary
#     for var_str in stratified_by:
#         sql_str_lst.append(
#             "select '" + var_str +"' as summ_var," 
#             + "cast(" + var_str +" as string) as summ_cat,"
#             + "count(distinct person_id) as pat_cnt, "
#             + "count(distinct person_id || '_' || event_id) as evt_cnt "
#             + "from "+ df + " group by "+ var_str
#         )
        
#     # up to n-way summary
#     if n_way > 1:
#         for L in range(2,n_way+1,1):
#             for var_str_comb in itertools.combinations(stratified_by, L):
#                 var_str_concat_by_pipe = "|".join(var_str_comb)
#                 var_str_concat_by_dpipe = "|| '||' ||".join(var_str_comb)
#                 var_str_concat_by_comma = ",".join(var_str_comb)
#                 sql_str_lst.append(
#                     "select 'by_" + var_str_concat_by_pipe +"' as summ_var," 
#                     + "cast(" + var_str_concat_by_dpipe +" as string) as summ_cat,"
#                     + "count(distinct person_id) as pat_cnt, "
#                     + "count(distinct person_id || '_' || event_id) as evt_cnt "
#                     + "from "+ df + " group by "+ var_str_concat_by_comma
#                 )
                
#     # union everything
#     sql_str = " union ".join(sql_str_lst)
#     return(sql_str)

In [None]:
# stratified_by = [
#     'agegrp_at_event',
#     'race_source_value',
#     'ethnicity_source_value',
#     'segment',
#     'speciality',
#     'bed_size',
#     'zip_code',
#     'death_ind',
#     'SMMANY_ind'
# ]
# get_pt_summ = pt_freq_qry(
#     'delivery_elig_tbl1',stratified_by,n_way=2
# )
# summ_stat_long = spark.sql(get_pt_summ).toPandas()
# summ_stat_long.to_csv(os.path.join(path_to_data,'summ_stat_2way.csv'),index=False)

In [None]:
# One-row denominator table: row and distinct-person counts of
# delivery_elig_smm, the row count of delivery_elig_init_smm, the earliest and
# latest event_start_date, and an approximate 90th percentile of `los`
# (relative error 0.01).
denom_df = pd.DataFrame(
    dict(
        N_delivery=delivery_elig_smm.count(),
        N_person=delivery_elig_smm.select('person_id').distinct().count(),
        N_init_delivery=delivery_elig_init_smm.count(),
        # A global aggregation yields exactly one row; take its only cell.
        Dt_first=delivery_elig_smm.agg({"event_start_date": "min"}).first()[0],
        Dt_last=delivery_elig_smm.agg({"event_start_date": "max"}).first()[0],
        # approxQuantile returns a one-element list. As the only non-scalar
        # value it is what gives the frame its single row; with scalars only,
        # pandas would demand an explicit index.
        LOS_90PCT=delivery_elig_smm.approxQuantile('los', [0.9], 0.01),
    )
)
denom_df.to_csv(
    os.path.join(path_to_data, 'denom_summ.csv'),
    index=False,
)

In [None]:
# one hot encoding
def ohe_with_map(
    df,              # spark dataframe
    cat_cols         # list of categorical columns for ohe
):
    """Index and one-hot encode categorical columns of a Spark DataFrame.

    For every column ``c`` in ``cat_cols`` two columns are added: ``c_index``
    (numeric label index from ``StringIndexer``) and ``c_ohe`` (vector from
    ``OneHotEncoder``).

    Labels are indexed with ``stringOrderType="frequencyAsc"``, i.e. the least
    frequent label gets index 0 and the most frequent label the highest index.
    ``OneHotEncoder`` is used with its defaults, which drop the last index, so
    the most frequent label ends up as the all-zeros reference level.

    Parameters
    ----------
    df : Spark DataFrame containing every column named in ``cat_cols``.
    cat_cols : list of str, names of the categorical columns to encode.

    Returns
    -------
    tuple
        ``(encoded_df, index_maps)`` where ``index_maps[c]`` is a dict mapping
        each integer index to the original label of column ``c``.
    """
    # Fit all indexers against the input frame before any column is added.
    # The loop variable is deliberately not called `col`, which is also the
    # name of a function imported from pyspark.sql.functions.
    indexers = [
        StringIndexer(
            inputCol=cat_col,
            outputCol=cat_col + "_index",
            stringOrderType="frequencyAsc",
        ).fit(df)
        for cat_col in cat_cols
    ]

    # Apply the indexers and record index -> label for later decoding.
    index_maps = {}
    for indexer in indexers:
        df = indexer.transform(df)
        index_maps[indexer.getInputCol()] = dict(enumerate(indexer.labels))

    # One-hot encode the indexed columns.
    for cat_col in cat_cols:
        encoder = OneHotEncoder(inputCol=cat_col + "_index", outputCol=cat_col + "_ohe")
        if hasattr(encoder, "fit"):
            # Spark >= 3.0: OneHotEncoder is an Estimator and has no
            # transform() of its own; it must be fit to obtain a model.
            df = encoder.fit(df).transform(df)
        else:
            # Spark 2.x: OneHotEncoder is a plain Transformer.
            df = encoder.transform(df)

    return df, index_maps

In [None]:
# Index + one-hot encode the categorical covariates of delivery_elig_init_smm.
# ohe_with_map returns a tuple (encoded frame, index -> label maps); the tuple
# is kept whole because the next cell displays element [1].
delivery_elig_init_smm_ohe = ohe_with_map(
    df=delivery_elig_init_smm,
    cat_cols=[
        "race_source_value", "ethnicity_source_value",
        "agegrp_at_event", "bed_size",
        "segment", "speciality",
        "zip_code", "delivery_type",
    ],
)
# Encoded frame used by all summaries and models below.
delivery_elig_init_smm2 = delivery_elig_init_smm_ohe[0]
delivery_elig_init_smm2.first()

In [None]:
# Display the per-column index -> label maps (second element of the tuple
# returned by ohe_with_map); they are needed to decode the positions of the
# `*_ohe` vectors.
delivery_elig_init_smm_ohe[1]

In [None]:
# Categorical covariates: one-hot encoded columns, length-of-stay threshold
# flags (los2up ... los7up) and the hist_* indicator flags.
cov_cat_lst = [
    *(f"{name}_ohe" for name in (
        'race_source_value', 'ethnicity_source_value', 'bed_size', 'speciality',
        'segment', 'zip_code', 'agegrp_at_event', 'delivery_type',
    )),
    *(f"los{days}up_ind" for days in range(2, 8)),
    *(f"hist_{code}_ind" for code in (
        'HEP', 'IHD', 'AST', 'LIV', 'AFIB', 'STR', 'CKD', 'COPD', 'HTN',
        'HF', 'AIDS', 'PVD', 'RA', 'AD', 'DM', 'COVID', 'SUB', 'ALC',
    )),
]

# Numeric covariates: age, length of stay, and one `hist_*_since_index` column
# for every `hist_*_ind` flag above (same order).
cov_num_lst = [
    'age_at_event',
    'los',
    *(
        f"{flag[:-len('_ind')]}_since_index"
        for flag in cov_cat_lst
        if flag.startswith('hist_')
    ),
]

cov_lst = cov_cat_lst + cov_num_lst

# Outcome flags. The trailing numbers are the author's original annotations
# (presumably event counts / prevalence - confirm).
out_lst = [
    'death_ind',         # 961
    'SMMANY_ind',        # 30267, 1.3%
    "SMMANY90PCT_ind",
    'COCR_ind',          # 386
    'BPT_ind',           # 26734, 1.1%
    'HYS_ind',           # 3188, 0.1%
    'TT_ind',            # 119
    'VEN_ind',           # 1946
]
# Flags deliberately left out of out_lst (each with an `_ind` suffix):
# AMI, ANE, ARF, ARDS, AFE, CAVF, DIC, ECL, HF, PCD, PE, SAC, SEP, SSH, SCC, ATE

since_index_lst = [
    'death_since_index',
    'SMMANY_since_index',
    'SMMANY90PCT_since_index',
]

In [None]:
def summ_gen(
    df, 
    cols,
    cat_switch = False,
    outcome = None
):
    """Summarise columns of a Spark DataFrame into one pandas DataFrame.

    Every column in ``cols`` is summarised by its own Spark aggregation; the
    per-column results are collected to pandas and stacked row-wise.

    Parameters
    ----------
    df : Spark DataFrame containing ``cols`` (and ``outcome`` when given).
    cols : iterable of str, names of the columns to summarise.
    cat_switch : bool, default False
        ``True`` treats the columns as categorical and reports, per category,
        ``count`` (number of rows) and ``prop`` (``count`` divided by the total
        row count of ``df``, not by the size of the outcome stratum).
        ``False`` treats them as numeric and reports ``mean``, ``stddev``,
        ``min`` and ``max``.
    outcome : str or None, default None
        Optional column to stratify by; when given, it appears as an extra
        column in the result.

    Returns
    -------
    pandas.DataFrame
        Categorical mode: columns ``cat``, [``outcome``], ``count``, ``prop``,
        ``var``. Numeric mode: [``outcome``], ``var``, ``mean``, ``stddev``,
        ``min``, ``max``. ``var`` holds the name of the summarised column.
        An empty frame is returned when ``cols`` is empty.
    """
    cols = list(cols)
    if not cols:
        # pd.concat raises on an empty list; an empty frame is the natural answer.
        return pd.DataFrame()

    # Optional stratification key(s) shared by both modes.
    strata = [outcome] if outcome else []

    # The total row count is only the denominator of "prop"; skip the extra
    # Spark job in numeric mode.
    N = df.count() if cat_switch else None

    res_dfs = []
    # The loop variable is deliberately not called `col`, which is also the
    # name of a function imported from pyspark.sql.functions.
    for col_name in cols:
        if cat_switch:
            # Count rows rather than count(<column>): the latter ignores NULLs,
            # so the NULL category would be reported with count 0 and prop 0.
            n_rows = count(lit(1))
            summary = (
                df.groupBy(col_name, *strata)
                .agg(
                    n_rows.alias("count"),
                    (n_rows / N).alias("prop"),
                )
                .withColumn("var", lit(col_name))
                .withColumnRenamed(col_name, "cat")
            )
        else:
            # groupBy() without keys aggregates over the whole frame, which is
            # what df.agg(...) is shorthand for.
            summary = df.groupBy(*strata).agg(
                lit(col_name).alias("var"),
                mean(col_name).alias("mean"),
                stddev(col_name).alias("stddev"),
                min(col_name).alias("min"),
                max(col_name).alias("max"),
            )

        res_dfs.append(summary.toPandas())

    # Concatenate the per-column pandas frames into a single frame.
    return pd.concat(res_dfs, ignore_index=True)

In [None]:
# Numeric summary (mean / stddev / min / max) of the numeric covariates,
# outcome flags and since-index columns within each level of SMMANY_IND.
# NOTE(review): the column is spelled 'SMMANY_ind' in the outcome list; the
# upper-case name relies on Spark resolving column names case-insensitively.
res_init_num = summ_gen(
    df=delivery_elig_init_smm2,
    cols=[*cov_num_lst, *out_lst, *since_index_lst],
    outcome="SMMANY_IND",
)
res_init_num.to_csv(
    os.path.join(path_to_data, 'summ_num_init_smm.csv'),
    index=False,
)

In [None]:
# Same numeric summary, stratified by SMMANY90PCT_IND instead.
# NOTE(review): the column is spelled 'SMMANY90PCT_ind' in the outcome list;
# this relies on Spark resolving column names case-insensitively.
res_init_num = summ_gen(
    df=delivery_elig_init_smm2,
    cols=[*cov_num_lst, *out_lst, *since_index_lst],
    outcome="SMMANY90PCT_IND",
)
res_init_num.to_csv(
    os.path.join(path_to_data, 'summ_num_init_smm90pct.csv'),
    index=False,
)

In [None]:
# Same numeric summary, stratified by death_ind.
res_init_num = summ_gen(
    df=delivery_elig_init_smm2,
    cols=[*cov_num_lst, *out_lst, *since_index_lst],
    outcome="death_ind",
)
res_init_num.to_csv(
    os.path.join(path_to_data, 'summ_num_init_dth.csv'),
    index=False,
)

In [None]:
# Categorical covariates under their raw (un-encoded) column names, for the
# frequency tables below. This rebinds cov_cat_lst, replacing the `*_ohe`
# variant defined earlier.
# NOTE(review): 'delivery_type' is one-hot encoded above but is not listed
# here - confirm the omission is intentional.
cov_cat_lst = [
    'race_source_value', 'ethnicity_source_value', 'bed_size', 'speciality',
    'segment', 'zip_code', 'agegrp_at_event',
    *(f"los{days}up_ind" for days in range(2, 8)),
    *(f"hist_{code}_ind" for code in (
        'HEP', 'IHD', 'AST', 'LIV', 'AFIB', 'STR', 'CKD', 'COPD', 'HTN',
        'HF', 'AIDS', 'PVD', 'RA', 'AD', 'DM', 'COVID', 'SUB', 'ALC',
    )),
]
# Unstratified frequency table (count and proportion per category) for the
# categorical covariates and the outcome flags.
res_init_cat = summ_gen(
    df=delivery_elig_init_smm2,
    cols=[*cov_cat_lst, *out_lst],
    cat_switch=True,
)
res_init_cat.to_csv(
    os.path.join(path_to_data, 'summ_init_cat.csv'),
    index=False,
)

In [None]:
# Frequency table of the categorical covariates within each level of
# SMMANY_IND.
# NOTE(review): the column is spelled 'SMMANY_ind' in the outcome list; the
# upper-case name relies on Spark resolving column names case-insensitively.
res_init_cat = summ_gen(
    df=delivery_elig_init_smm2,
    cols=cov_cat_lst,
    cat_switch=True,
    outcome="SMMANY_IND",
)
res_init_cat.to_csv(
    os.path.join(path_to_data, 'summ_init_cat_smm.csv'),
    index=False,
)

In [None]:
# Frequency table of the categorical covariates within each level of
# SMMANY90PCT_IND.
# NOTE(review): the column is spelled 'SMMANY90PCT_ind' in the outcome list;
# this relies on Spark resolving column names case-insensitively.
res_init_cat = summ_gen(
    df=delivery_elig_init_smm2,
    cols=cov_cat_lst,
    cat_switch=True,
    outcome="SMMANY90PCT_IND",
)
res_init_cat.to_csv(
    os.path.join(path_to_data, 'summ_init_cat_smm90pct.csv'),
    index=False,
)

In [None]:
# Frequency table of the categorical covariates within each level of death_ind.
res_init_cat = summ_gen(
    df=delivery_elig_init_smm2,
    cols=cov_cat_lst,
    cat_switch=True,
    outcome="death_ind",
)
res_init_cat.to_csv(
    os.path.join(path_to_data, 'summ_cat_init_dth.csv'),
    index=False,
)

In [None]:
# # quick single-var summary
# N = delivery_elig_smm2.count()
# agg_df = delivery_elig_smm2.groupBy("bed_size").agg(
#     count("bed_size").alias("count"),
# #     (count("bed_size_ohe")/N).alias('prop'),
#     sum("death_ind").alias("count_death"),
# #     (sum("death_ind")/count("bed_size_ohe")).alias("prop_death"),
#     sum("SMMANY_ind").alias("count_smm")
# #     (sum("SMMANY_ind")/count("bed_size_ohe")).alias("prop_summ")
    
# )
# agg_df.show()

In [None]:
def univar_analysis(
    df,              # spark dataframe
    covariate_cols,  # list of covariates (assume ohe already applied)
    outcome_cols,    # list of outcomes
    outcome_types,   # list of outcome types
    verbose = True   # report progress
):
    """Fit one single-covariate GLM per (outcome, covariate) pair.

    Parameters
    ----------
    df : Spark DataFrame holding all covariate and outcome columns.
    covariate_cols : list of str
        Covariate columns; each is assembled on its own into the feature
        vector, so a one-hot vector column contributes several coefficients.
    outcome_cols : list of str, label columns.
    outcome_types : list of str
        One entry per outcome, each one of ``"bin"`` (binomial), ``"con"``
        (gaussian), ``"dis"`` (poisson), ``"pos"`` (gamma) or ``"mix"``
        (tweedie).
    verbose : bool, default True
        Print a line after every fitted model.

    Returns
    -------
    dict
        Keyed by ``"{outcome}_{covariate}_{i}"`` (``i`` = coefficient position).
        Each value is a dict with ``outcome``, ``var``, ``encoded``, ``coef``,
        ``odds_ratio`` (``exp(coef)``), ``conf_lower`` / ``conf_upper``
        (``exp(coef -/+ 1.96 * SE)``) and ``pval``. ``exp(coef)`` is an odds
        ratio only for binomial outcomes (logit link); for the other families
        it is merely the exponentiated coefficient.

    Raises
    ------
    ValueError
        If ``outcome_cols`` and ``outcome_types`` differ in length, or an
        outcome type is not one of the supported codes.
    """
    # global glm family mapping based on outcome types
    family_map = {
        "bin": "binomial",
        "con": "gaussian",
        "dis": "poisson",
        "pos": "gamma",
        "mix": "tweedie"
    }
    # Explicit link per family. Spark rejects the logit link for the gaussian,
    # poisson and gamma families, and the tweedie family takes `linkPower`
    # instead of `link`; families not listed here use Spark's default link.
    link_map = {
        "binomial": "logit"
    }

    covariate_cols = list(covariate_cols)
    outcome_cols = list(outcome_cols)
    outcome_types = list(outcome_types)

    # Validate before launching any (expensive) Spark job.
    if len(outcome_cols) != len(outcome_types):
        raise ValueError(
            "outcome_cols and outcome_types must have the same length "
            f"(got {len(outcome_cols)} and {len(outcome_types)})"
        )
    unknown_types = [t for t in outcome_types if t not in family_map]
    if unknown_types:
        raise ValueError(
            f"unsupported outcome type(s) {unknown_types}; "
            f"expected one of {list(family_map)}"
        )

    odds_ratios = {}
    for outcome, outcome_type in zip(outcome_cols, outcome_types):
        family = family_map[outcome_type]
        glr_kwargs = {
            "family": family,
            "featuresCol": "features",
            "labelCol": outcome
        }
        if family in link_map:
            glr_kwargs["link"] = link_map[family]

        for covariate in covariate_cols:
            # Fit univariate glm
            vector_assembler = VectorAssembler(inputCols=[covariate], outputCol="features")
            df_assembled = vector_assembler.transform(df)
            model = GeneralizedLinearRegression(**glr_kwargs).fit(df_assembled)
            summary = model.summary

            coefs = model.coefficients
            if len(coefs):
                # Fetch the summary statistics once per model rather than once
                # per coefficient.
                std_errs = getattr(summary, 'coefficientStandardErrors', None)
                p_values = summary.pValues

            # Extract coefficients and calculate odds ratios
            for i, coef in enumerate(coefs):
                odds_ratio = np.exp(coef)
                # Without a standard error the interval is unknown: report NaN
                # bounds instead of a zero-width interval that looks exact.
                if std_errs is not None:
                    coefficient_standard_error = std_errs[i]
                else:
                    coefficient_standard_error = np.nan
                conf_lower = np.exp(coef - 1.96 * coefficient_standard_error)
                conf_upper = np.exp(coef + 1.96 * coefficient_standard_error)

                # gather results
                odds_ratios[f"{outcome}_{covariate}_{i}"] = {
                    'outcome': outcome,
                    'var': covariate,
                    "encoded": i,
                    "coef": coef,
                    "odds_ratio": odds_ratio,
                    "conf_lower": conf_lower,
                    "conf_upper": conf_upper,
                    "pval": p_values[i]
                }

            # report progress
            if verbose:
                print(f"processed:outcome={outcome};covariate={covariate} \n")

    return odds_ratios

In [None]:
# Covariates for the univariate models. This rebinds cov_cat_lst to the
# one-hot encoded column names and trims cov_num_lst to age and length of stay.
cov_cat_lst = [
    *(f"{name}_ohe" for name in (
        'race_source_value', 'ethnicity_source_value', 'bed_size', 'speciality',
        'segment', 'zip_code', 'agegrp_at_event', 'delivery_type',
    )),
    *(f"los{days}up_ind" for days in range(2, 8)),
    *(f"hist_{code}_ind" for code in (
        'HEP', 'IHD', 'AST', 'LIV', 'AFIB', 'STR', 'CKD', 'COPD', 'HTN',
        'HF', 'AIDS', 'PVD', 'RA', 'AD', 'DM', 'COVID', 'SUB', 'ALC',
    )),
]
cov_num_lst = ['age_at_event', 'los']

cov_lst = cov_cat_lst + cov_num_lst

# Outcomes to model. The trailing numbers are the author's original
# annotations (presumably event counts / prevalence - confirm).
out_lst = [
    'death_ind',         # 961
    'SMMANY_ind',        # 30267, 1.3%
    "SMMANY90PCT_ind",
]
# One 'bin' type code per outcome.
type_lst = ['bin' for _ in out_lst]

In [None]:
# Every outcome in out_lst is modelled with the 'bin' type code.
type_lst = ['bin' for _ in out_lst]

# One single-covariate model per (outcome, covariate) pair.
res_init = univar_analysis(
    delivery_elig_init_smm2,
    cov_lst,
    out_lst,
    type_lst,
)

# One row per fitted coefficient; the dict keys only serve as unique labels.
res_init_df = pd.json_normalize(res_init.values())
res_init_df.to_csv(
    os.path.join(path_to_data, 'univar_filter_init.csv'),
    index=False,
)