In [None]:
import pandas as pd
import json
import urllib.request as urlreq
import os
from pyspark.ml.feature import VectorAssembler, StringIndexer, OneHotEncoder
from pyspark.ml.classification import LogisticRegression
from pyspark.sql.functions import col
import numpy as np

In [None]:
# Folder, relative to the kernel's working directory, where code-list extracts are written and read.
data_subdirs = ('University of Missouri', 'DRIVERS', 'data')
path_to_data = os.path.join(os.getcwd(), *data_subdirs)
path_to_data

In [None]:
# Unqualified table names in the following SQL cells resolve against this database.
omop_database = "real_world_data_ed_omop_dec_2023"
spark.sql("use " + omop_database)

In [None]:
# OMOP concept metadata for the delivery-identifying codes (DRG, CPT4, ICD10PCS),
# saved to CSV so later cells can rebuild the concept_id IN-lists from it.
# A LOINC branch used to be stubbed here with an empty "in ( )" list, which Spark SQL
# cannot parse (it failed the whole cell); re-add it once LOINC codes are chosen.
cd_meta = spark.sql('''
    select concept_id,concept_name,concept_code,vocabulary_id,domain_id
    from concept
    where vocabulary_id = 'DRG' and
          concept_code in (
            '765','766','767','768',
            '774','775',
            '783','784','785','786', '787','788',
            '796','797','798',
            '805','806','807'
          )
    union all
    select concept_id,concept_name,concept_code,vocabulary_id,domain_id
    from concept
    where vocabulary_id = 'CPT4' and
          concept_code in (
            '59409','59514', '59612','59620'
          )
    union all
    select concept_id,concept_name,concept_code,vocabulary_id,domain_id
    from concept
    where vocabulary_id = 'ICD10PCS' and
          concept_code in (
            '10D00Z0','10D00Z1','10D00Z2','10D07Z3','10D07Z4', '10D07Z5', '10D07Z6','10D07Z7','10D07Z8',
            '10E0XZZ'
          )
''').toPandas()
cd_meta.to_csv(os.path.join(path_to_data,'cd_meta_omop.csv'),index=False)

In [None]:
def cd_where_clause(meta_tbl,voc_id):
    '''
    Build the body of a SQL ``IN (...)`` list from the concept_ids of one vocabulary.

    Parameters
    ----------
    meta_tbl : pandas.DataFrame
        Concept metadata with at least ``vocabulary_id`` and ``concept_id`` columns
        (the layout written to cd_meta_omop.csv).
    voc_id : str
        Vocabulary to select, e.g. 'DRG', 'CPT4', 'ICD10PCS'.

    Returns
    -------
    str
        Comma-separated, single-quoted concept_ids, e.g. "'1','2'".
        When the vocabulary has no rows, returns 'NULL': ``x in (NULL)`` is valid SQL
        that matches no row, whereas an empty ``in ()`` list is a parse error.
    '''
    cd_lst = meta_tbl.loc[meta_tbl['vocabulary_id'] == voc_id, 'concept_id'].tolist()
    if not cd_lst:
        return 'NULL'
    return ",".join("'" + str(code) + "'" for code in cd_lst)

# Concept-id IN-lists for each delivery-code vocabulary, rebuilt from the saved metadata CSV.
meta_tbl = pd.read_csv(os.path.join(path_to_data, 'cd_meta_omop.csv'))
drg_where, cpt4_where, icd10_where = [
    cd_where_clause(meta_tbl, vocab) for vocab in ('DRG', 'CPT4', 'ICD10PCS')
]

In [None]:
# Candidate delivery records from three code sources, each tagged with event_source:
#   'DRG'      - observation rows whose observation_concept_id is in drg_where
#   'CPT4'     - procedure_occurrence rows whose procedure_concept_id is in cpt4_where
#   'ICD10PCS' - procedure_occurrence rows whose procedure_concept_id is in icd10_where
# The *_where strings are the quoted concept_id lists produced by cd_where_clause.
delivery_init = spark.sql('''
    select person_id,
           visit_occurrence_id, 
           visit_detail_id, 
           observation_date as event_date, 
           observation_concept_id as event_identifier,
           'DRG' as event_source
    from observation
    where observation_concept_id in ('''+ drg_where +''')
    union all
    select person_id,
           visit_occurrence_id, 
           visit_detail_id, 
           procedure_date as event_date, 
           procedure_concept_id as event_identifier,
           'CPT4' as event_source
    from procedure_occurrence
    where procedure_concept_id in ('''+ cpt4_where +''')
    union all
    select person_id,
           visit_occurrence_id, 
           visit_detail_id, 
           procedure_date as event_date, 
           procedure_concept_id as event_identifier,
           'ICD10PCS' as event_source
    from procedure_occurrence
    where procedure_concept_id in ('''+ icd10_where +''')
''').cache()
# Register for the SQL cells below; .first() runs a Spark action and shows one sample row.
delivery_init.createOrReplaceTempView("delivery_init")
delivery_init.first()

In [None]:
# Sanity check: number of distinct patients with at least one candidate delivery record, per code source.
spark.sql('''
    select event_source, count(distinct person_id)
    from delivery_init
    group by event_source 
''').toPandas()

In [None]:
# Keep only candidate records whose visit has visit_concept_id 9201 (IP) or 262 (ER to IP),
# one distinct row per (person, visit) carrying that visit's start/end dates and care site.
delivery_ip = spark.sql('''
    select distinct
           a.person_id, 
           a.visit_occurrence_id,
           v.visit_start_date,
           v.visit_end_date,
           v.care_site_id
    from delivery_init a 
    join visit_occurrence v 
    on a.person_id = v.person_id and 
       a.visit_occurrence_id = v.visit_occurrence_id
    where v.visit_concept_id in (
        9201, -- IP
        262   -- ER to IP
    )
''').cache()
# Register for the SQL cells below; .first() runs a Spark action and shows one sample row.
delivery_ip.createOrReplaceTempView("delivery_ip")
delivery_ip.first()

In [None]:
# For each inpatient delivery visit: the first (F_*) and last (L_*) record date per code source,
# counting only records dated from 3 days before visit start through 3 days after visit end.
# rn_asc / rn_desc pick the earliest / latest record per (person, visit, source); the pivots
# spread those dates into one column per source.
# The left joins keep every delivery_ip visit; sources with no in-window record give null dates.
delivery_consolidate = spark.sql('''
    with cd_filter as (
        select v.person_id, 
               v.visit_occurrence_id,
               v.care_site_id,
               v.visit_start_date,
               v.visit_end_date,
               a.event_source,
               a.event_date,
               row_number() over (partition by v.person_id, v.visit_occurrence_id, a.event_source order by a.event_date) as rn_asc,
               row_number() over (partition by v.person_id, v.visit_occurrence_id, a.event_source order by a.event_date desc) as rn_desc
        from delivery_ip v
        join delivery_init a 
        on v.person_id = a.person_id and 
           v.visit_occurrence_id = a.visit_occurrence_id
        where a.event_date between date_sub(v.visit_start_date,3) and date_add(v.visit_end_date,3)
    ), f_pvt as (
        select * 
        from (
            select person_id, visit_occurrence_id,
                   event_source, event_date
            from cd_filter
            where rn_asc = 1       
        )
        pivot (
            min(event_date) for event_source in (
                'DRG' as F_DRG_DT,'CPT4' as F_CPT_DT,'ICD10PCS' as F_ICD_DT
            )
        )
    ), l_pvt as (
        select * 
        from (
            select person_id, visit_occurrence_id,
                   event_source, event_date
            from cd_filter
            where rn_desc = 1       
        )
        pivot (
            max(event_date) for event_source in (
                'DRG' as L_DRG_DT,'CPT4' as L_CPT_DT,'ICD10PCS' as L_ICD_DT
            )
        )
    )
    select a.person_id, 
           a.visit_occurrence_id,
           a.visit_start_date,
           a.visit_end_date,
           a.care_site_id,
           f.F_DRG_DT,f.F_CPT_DT,f.F_ICD_DT,
           l.L_DRG_DT,l.L_CPT_DT,l.L_ICD_DT
    from delivery_ip a 
    left join f_pvt f on a.person_id = f.person_id and a.visit_occurrence_id = f.visit_occurrence_id
    left join l_pvt l on a.person_id = l.person_id and a.visit_occurrence_id = l.visit_occurrence_id
''').cache()
# Register for the SQL cells below; .first() runs a Spark action and shows one sample row.
delivery_consolidate.createOrReplaceTempView("delivery_consolidate")
delivery_consolidate.first()

In [None]:
# Collapse inpatient delivery visits into delivery events ("sessions") per person.
# Event start/end: the DRG date when present, otherwise the visit date (ICD/CPT dates are
# only used if both of those are null).
# A new event begins at a person's first visit and at every visit starting more than 211 days
# after the previous visit's start; event_id is the running count of such starts (from 1).
# The flag uses lag() so it lands on the visit that OPENS a session. The previous lead()-based
# flag landed on the visit before the gap (and on every person's last visit), so the running
# sum pushed the last visit of each session into the following event id.
delivery_elig = spark.sql('''
    with date_consolid as (
        select distinct
               person_id,
               visit_occurrence_id,
               care_site_id,
               coalesce(F_DRG_DT,visit_start_date,F_ICD_DT,F_CPT_DT) as event_start_dt,
               coalesce(L_DRG_DT,visit_end_date,L_ICD_DT,L_CPT_DT) as event_end_dt
        from delivery_consolidate
    ), visit_diffs as (
        select a.*,
               lag(a.event_start_dt, 1) over (partition by a.person_id order by a.event_start_dt) as prev_event_start_dt
        from date_consolid a
    ), visit_session as (
        select b.*,
               case
                   when b.prev_event_start_dt is null then 1
                   when datediff(b.event_start_dt,b.prev_event_start_dt) > 211 then 1
                   else 0
               end as new_session_flag
        from visit_diffs b
    ), sessions as (
        select d.*,
               sum(d.new_session_flag) over (partition by d.person_id order by d.event_start_dt) as event_id
        from visit_session d
    ), session_order as (
        select e.*,
               row_number() over (partition by e.person_id, e.event_id order by e.event_start_dt, e.visit_occurrence_id) as rn,
               max(e.event_end_dt) over (partition by e.person_id, e.event_id) as event_end_date
        from sessions e
    )
    select s.person_id,
           s.event_id,
           s.visit_occurrence_id,
           cs.care_site_source_value,
           s.event_start_dt as event_start_date,
           s.event_end_date
    from session_order s
    join care_site cs on s.care_site_id = cs.care_site_id
    where s.rn = 1
    order by s.person_id, s.event_id
''').cache()
# NOTE(review): the inner join to care_site drops events whose first visit has no matching
# care_site row -- confirm that exclusion is intended.
delivery_elig.createOrReplaceTempView("delivery_elig")
delivery_elig.first()

In [None]:
# Cohort table: one row per delivery event joined to person demographics, tenant attributes
# (matched on care_site_source_value = tenant) and death records.
#   los          - days from event start to event end; 1 when datediff is null
#   age_at_event - calendar-year difference (event year minus year_of_birth)
#   death_ind    - 1 when a death_date was found for the person, else 0
# Only events with age_at_event between 10 and 55 are kept.
# NOTE(review): if death can hold more than one row per person, the left join duplicates events -- confirm.
delivery_elig_tbl1 = spark.sql('''
    select d.person_id,
           d.event_id,
           d.event_start_date, 
           d.event_end_date,
           coalesce(datediff(d.event_end_date,d.event_start_date),1) as los, 
           d.visit_occurrence_id,
           p.year_of_birth,
           year(d.event_start_date) - p.year_of_birth as age_at_event,
           p.month_of_birth,
           p.day_of_birth,
           p.race_source_value,
           p.ethnicity_source_value,
           p.location_id,
           p.care_site_id,
           d.care_site_source_value,
           tnt.bed_size,
           tnt.speciality,
           tnt.segment,
           tnt.zip_code,
           dth.death_date,
           case when dth.death_date is not null then 1 else 0 end as death_ind
    from delivery_elig d
    join person p on d.person_id = p.person_id
    left join tenant_attributes tnt on d.care_site_source_value = tnt.tenant
    left join death dth on d.person_id = dth.person_id 
    where year(d.event_start_date) - p.year_of_birth between 10 and 55
''').cache()
# Register for the SQL cells below; .first() runs a Spark action and shows one sample row.
delivery_elig_tbl1.createOrReplaceTempView("delivery_elig_tbl1")
delivery_elig_tbl1.first()

In [None]:
# load SMM code list and get omop concept_id
# The context manager closes the HTTP response once the JSON body has been read
# (the bare urlopen() call left it open).
with urlreq.urlopen('https://raw.githubusercontent.com/RWD2E/phecdm/main/res/valueset_curated/vs-mmm.json') as json_url:
    smm_json = json.loads(json_url.read())
# One concept-lookup SQL query per (SMM group, code system); filled by the loop below.
qry_lst = []
def add_quote(lst):
    '''
    Render each value as a single-quoted Spark SQL string literal.

    Each value is converted with str(). Backslashes and single quotes are backslash-escaped
    (Spark SQL's default string-literal escaping), so a value that contains a quote,
    e.g. from the downloaded value-set JSON, cannot break out of the literal.

    Returns a new list of quoted strings in the input order.
    '''
    return ["'" + str(x).replace('\\', '\\\\').replace("'", "\\'") + "'" for x in lst]
# Build one concept-lookup query per (SMM group, code system) from the value-set JSON.
# Each query tags matching concepts with SMM_GRP (the JSON key) and SMM_GRP_LONG (its 'long' label).
for k, v in smm_json.items():
    # exclude delivery codes
    if k.startswith('d_'):
        continue
    grp_sql = add_quote([k])[0]
    grp_long_sql = add_quote([v['long']])[0]
    for cd, sig in v.items():
        if cd == 'long':
            continue
        # NOTE(review): vocabulary_id is the upper-cased JSON key and must equal
        # concept.vocabulary_id exactly (case-sensitive) -- confirm for every code system.
        vocab_sql = add_quote([cd.upper()])[0]

        if 'icd' in cd and 'pcs' not in cd:
            # ICD (non-PCS) codes match by prefix: lev0 = text before the first '.',
            # lev1 = first 5 characters, lev2 = first 6 characters. Missing or empty levels are skipped.
            conditions = []
            for lev_key, code_expr in (
                ('lev0', "substring_index(concept_code,'.',1)"),
                ('lev1', "substring(concept_code,1,5)"),
                ('lev2', "substring(concept_code,1,6)"),
            ):
                lev_codes = sig.get(lev_key) or []
                if lev_codes:
                    conditions.append(code_expr + ' in (' + ','.join(add_quote(lev_codes)) + ')')
        else:
            # Exact codes plus every integer code in each inclusive 'lo-hi' range, collected
            # into a fresh list so the downloaded JSON is not mutated.
            codes = [str(c) for c in (sig.get('exact') or [])]
            for code_range in sig.get('range') or []:
                bounds = code_range.split('-')
                codes.extend(str(c) for c in range(int(bounds[0]), int(bounds[1]) + 1))
            conditions = ['concept_code in (' + ','.join(add_quote(codes)) + ')'] if codes else []

        # Nothing to match for this code system: an empty IN () list / empty predicate
        # would be a Spark SQL parse error, so skip instead of emitting it.
        if not conditions:
            continue

        where_sql = ' or '.join(conditions)
        qry = '''
            select ''' + grp_sql + ''' as SMM_GRP,
                   ''' + grp_long_sql + ''' as SMM_GRP_LONG,
                   concept_id,concept_name,concept_code,vocabulary_id,domain_id
            from concept
            where vocabulary_id = ''' + vocab_sql + ''' and
            (
                 ''' + where_sql + '''
            )
        '''
        qry_lst.append(qry)

        
# qry_final = ' union all '.join(qry_lst)
# print(qry_final)
# smm_omop_cd = spark.sql(' union all '.join(qry_lst)).toPandas()
# smm_omop_cd.to_csv(os.path.join(path_to_data,'cd_meta_omop_smm.csv'),index=False)

# Stack the per-group concept lookups into one cached concept table and expose it to SQL.
smm_union_sql = ' union all '.join(qry_lst)
smm_omop_meta = spark.sql(smm_union_sql).cache()
smm_omop_meta.createOrReplaceTempView("smm_omop_meta")
smm_omop_meta.first()

In [None]:
# Link SMM concepts to patient records, keeping each record's date and SMM group:
# procedure-vocabulary concepts against procedure_occurrence, diagnosis-vocabulary
# concepts (ICD9CM / ICD10CM) against condition_occurrence.
# NOTE(review): standard OMOP vocabularies name ICD-9 procedures 'ICD9Proc'; 'ICD9PR' only
# matches if the local concept table uses that id -- verify, or ICD-9 procedure codes drop out.
smm_init = spark.sql('''
    select px.person_id,
           px.visit_occurrence_id, 
           px.procedure_date as event_date,
           m.SMM_GRP
    from procedure_occurrence px
    join smm_omop_meta m
    on px.procedure_concept_id = m.concept_id
    where m.vocabulary_id in ('CPT4','HCPCS','ICD9PR','ICD10PCS')
    union all
    select person_id,
           visit_occurrence_id, 
           condition_start_date as event_date,
           m.SMM_GRP
    from condition_occurrence dx
    join smm_omop_meta m
    on dx.condition_concept_id = m.concept_id
    where m.vocabulary_id in ('ICD9CM','ICD10CM')
''').cache()
# Register for the SQL cells below; .first() runs a Spark action and shows one sample row.
smm_init.createOrReplaceTempView("smm_init")
smm_init.first()

In [None]:
# Pair each SMM record with every delivery event of the same person that started 0-365 days
# before the record; days_since_index is that gap in days.
# NOTE(review): the join is on person_id only, so one SMM record can count toward two
# deliveries less than a year apart -- confirm this attribution is intended.
smm_post_delivery = spark.sql('''
    select a.person_id, b.event_id,
           a.SMM_GRP,
           b.event_start_date,a.event_date,
           datediff(a.event_date,b.event_start_date) AS days_since_index
    from smm_init a 
    join delivery_elig_tbl1 b 
    on a.person_id = b.person_id
    where datediff(a.event_date,b.event_start_date) between 0 and 365
''').cache()
# Register for the SQL cells below; .first() runs a Spark action and shows one sample row.
smm_post_delivery.createOrReplaceTempView("smm_post_delivery")
smm_post_delivery.first()

In [None]:
# Wide layout: one row per (person_id, event_id) with the smallest days_since_index for each
# listed SMM group as <GRP>_since_index; null when the event has no record in that group.
# SMM_GRP values not listed in the IN clause are ignored by the pivot.
smm_post_delivery_wide = spark.sql('''
    select *
    from (
        select person_id, event_id, SMM_GRP,days_since_index
        from smm_post_delivery
     )
    pivot 
    (
        min(days_since_index) for SMM_GRP in (
            'ami' as AMI_since_index,
            'ane' as ANE_since_index,
            'arf' as ARF_since_index,
            'ards' as ARDS_since_index,
            'afe' as AFE_since_index,
            'cavf' as CAVF_since_index,
            'cocr' as COCR_since_index,
            'dic' as DIC_since_index,
            'ecl' as ECL_since_index,
            'hf' as HF_since_index,
            'pcd' as PCD_since_index,
            'pe' as PE_since_index,
            'sac' as SAC_since_index,
            'sep' as SEP_since_index,
            'ssh' as SSH_since_index,
            'scc' as SCC_since_index,
            'ate' as ATE_since_index,
            'bpt' as BPT_since_index,
            'hys' as HYS_since_index,
            'tt' as TT_since_index,
            'ven' as VEN_since_index
        )
    )
''').cache()
# Register for the SQL cells below; .first() runs a Spark action and shows one sample row.
smm_post_delivery_wide.createOrReplaceTempView("smm_post_delivery_wide")
smm_post_delivery_wide.first()

In [None]:
# Per-delivery analysis table: cohort attributes + age group + LOS thresholds + SMM outcomes.
# SMM group prefixes, in output-column order (they match the smm_post_delivery_wide aliases).
smm_grp_cols = [
    'AMI', 'ANE', 'ARF', 'ARDS', 'AFE', 'CAVF', 'COCR', 'DIC', 'ECL', 'HF', 'PCD',
    'PE', 'SAC', 'SEP', 'SSH', 'SCC', 'ATE', 'BPT', 'HYS', 'TT', 'VEN',
]
# los1up_ind .. los7up_ind: 1 when the length of stay is at least that many days, else 0.
los_ind_sql = ',\n           '.join(
    f"case when e.los >= {n_days} then 1 else 0 end as los{n_days}up_ind"
    for n_days in range(1, 8)
)
# Per SMM group: <GRP>_since_index (days to first record, null if none) and a 0/1 <GRP>_ind.
smm_ind_sql = ',\n           '.join(
    f"s.{grp}_since_index,\n           IF(s.{grp}_since_index IS NOT NULL, 1, 0) {grp}_ind"
    for grp in smm_grp_cols
)
# Age groups: 10-19, 20-29, 30-39, then 40+ (the cohort is already limited to ages 10-55).
# The second bound previously read "between 10 and 29"; it only behaved as 20-29 because
# the first branch catches 10-19 -- spelled out explicitly now (results unchanged).
delivery_elig_smm = spark.sql('''
    with smm_any as (
        select person_id, event_id, 1 as SMMANY_ind,
               min(days_since_index) as SMMANY_since_index
        from smm_post_delivery
        group by person_id, event_id
    )
    select e.*,
           case when e.age_at_event between 10 and 19 then 'agegrp1'
                when e.age_at_event between 20 and 29 then 'agegrp2'
                when e.age_at_event between 30 and 39 then 'agegrp3'
                else 'agegrp4'
           end as agegrp_at_event,
           ''' + los_ind_sql + ''',
           a.SMMANY_since_index,
           coalesce(a.SMMANY_ind,0) as SMMANY_ind,
           ''' + smm_ind_sql + '''
    from delivery_elig_tbl1 e
    left join smm_any a
    on e.person_id = a.person_id and e.event_id = a.event_id
    left join smm_post_delivery_wide s
    on e.person_id = s.person_id and e.event_id = s.event_id
''').cache()
delivery_elig_smm.createOrReplaceTempView("delivery_elig_smm")
delivery_elig_smm.first()

In [None]:
# save elig table as parquet file
# mode('overwrite') keeps the cell re-runnable: Spark's default save mode (errorifexists)
# raises once the path exists, which broke Restart & Run All on every run after the first.
delivery_elig_tbl1.write.mode('overwrite').parquet('delivery_elig_tbl1.parquet')

In [None]:
# Reload the saved cohort. Reading as parquet explicitly matches how it was written instead of
# depending on the session's spark.sql.sources.default setting.
delivery_elig_tbl1 = spark.read.parquet("delivery_elig_tbl1.parquet")
delivery_elig_tbl1.first()

In [None]:
# import itertools
# def pt_freq_qry(df,stratified_by,n_way=1):
#     '''
#     generate total patient counts for each stratified variables
#     '''
#     sql_str_lst = []
#     # overall count
#     sql_str_lst.append("select 'total' as summ_var,'N' as summ_cat, count(distinct person_id) as pat_cnt, count(distinct person_id || '_' || event_id) as evt_cnt from " + df)
    
#     # 1-way summary
#     for var_str in stratified_by:
#         sql_str_lst.append(
#             "select '" + var_str +"' as summ_var," 
#             + "cast(" + var_str +" as string) as summ_cat,"
#             + "count(distinct person_id) as pat_cnt, "
#             + "count(distinct person_id || '_' || event_id) as evt_cnt "
#             + "from "+ df + " group by "+ var_str
#         )
        
#     # up to n-way summary
#     if n_way > 1:
#         for L in range(2,n_way+1,1):
#             for var_str_comb in itertools.combinations(stratified_by, L):
#                 var_str_concat_by_pipe = "|".join(var_str_comb)
#                 var_str_concat_by_dpipe = "|| '||' ||".join(var_str_comb)
#                 var_str_concat_by_comma = ",".join(var_str_comb)
#                 sql_str_lst.append(
#                     "select 'by_" + var_str_concat_by_pipe +"' as summ_var," 
#                     + "cast(" + var_str_concat_by_dpipe +" as string) as summ_cat,"
#                     + "count(distinct person_id) as pat_cnt, "
#                     + "count(distinct person_id || '_' || event_id) as evt_cnt "
#                     + "from "+ df + " group by "+ var_str_concat_by_comma
#                 )
                
#     # union everything
#     sql_str = " union ".join(sql_str_lst)
#     return(sql_str)

In [None]:
# stratified_by = [
#     'agegrp_at_event',
#     'race_source_value',
#     'ethnicity_source_value',
#     'segment',
#     'speciality',
#     'bed_size',
#     'zip_code',
#     'death_ind',
#     'SMMANY_ind'
# ]
# get_pt_summ = pt_freq_qry(
#     'delivery_elig_tbl1',stratified_by,n_way=2
# )
# summ_stat_long = spark.sql(get_pt_summ).toPandas()
# summ_stat_long.to_csv(os.path.join(path_to_data,'summ_stat_2way.csv'),index=False)

In [None]:
def _assemble_features(data, covariate, is_categorical):
    '''
    Add a "features" vector column for one covariate; return (frame, level names, reference level).

    A string covariate is StringIndexer-encoded (default frequencyDesc: most frequent level
    gets index 0) and one-hot encoded with the default dropLast=True. The last, least frequent,
    level is therefore the omitted reference, and coefficient i belongs to levels[i].
    Returns an empty level list when there are fewer than two levels (nothing to contrast).
    A non-string covariate is used directly as a single numeric feature.
    '''
    if not is_categorical:
        assembler = VectorAssembler(inputCols=[covariate], outputCol="features")
        return assembler.transform(data), [covariate], None

    indexer_model = StringIndexer(inputCol=covariate, outputCol=covariate + "_indexed").fit(data)
    labels = list(indexer_model.labels)
    if len(labels) < 2:
        return data, [], None
    indexed_df = indexer_model.transform(data)

    encoder = OneHotEncoder(inputCol=covariate + "_indexed", outputCol=covariate + "_encoded")
    # Spark >= 3.0: OneHotEncoder is an Estimator (fit, then transform); calling .transform on
    # it directly fails. Spark 2.x's OneHotEncoder was a plain Transformer with no fit().
    if hasattr(encoder, 'fit'):
        encoder = encoder.fit(indexed_df)
    encoded_df = encoder.transform(indexed_df)

    assembler = VectorAssembler(inputCols=[covariate + "_encoded"], outputCol="features")
    return assembler.transform(encoded_df), labels[:-1], labels[-1]


def calculate_univariate_odds_ratios(
    df,             # spark dataframe
    covariate_cols, # list of covariates
    outcome_cols    # list of outcomes
):
    '''
    Fit one unadjusted logistic model per (outcome, covariate) pair and return odds ratios
    with Wald 95% confidence intervals.

    Parameters
    ----------
    df : pyspark.sql.DataFrame
        Table containing every column named in covariate_cols and outcome_cols.
    covariate_cols : list of str
        Predictors. String columns are one-hot encoded against their least frequent level;
        other columns enter as a single numeric term.
    outcome_cols : list of str
        Binary (0/1) outcome columns.

    Returns
    -------
    dict
        Keyed "<covariate>_<i>" (i = coefficient index). When more than one outcome is given,
        keys are prefixed "<outcome>:" so results for different outcomes do not overwrite
        each other. Each value holds 'odds_ratio', 'conf_lower', 'conf_upper', plus 'level'
        (the category, or the covariate name for numeric covariates) and 'reference'
        (the omitted category, or None).

    Rows with a null covariate or outcome are dropped per model (complete-case analysis),
    and covariates with fewer than two observed levels are skipped.
    '''
    # The LogisticRegression training summary has no coefficientStandardErrors, so the old
    # code fell back to SE=0 and reported zero-width intervals. A binomial GLR with logit link
    # fits the same (unpenalized) logistic model and its summary does expose standard errors.
    from pyspark.ml.regression import GeneralizedLinearRegression

    z_value = 1.96  # two-sided 95% standard-normal quantile
    col_types = dict(df.dtypes)
    prefix_outcome = len(outcome_cols) > 1
    odds_ratios = {}
    for outcome in outcome_cols:
        key_prefix = outcome + ":" if prefix_outcome else ""
        for covariate in covariate_cols:
            # StringIndexer/VectorAssembler reject nulls by default (handleInvalid='error').
            model_df = df.select(outcome, covariate).dropna()
            assembled_df, levels, reference = _assemble_features(
                model_df, covariate, col_types[covariate] == 'string'
            )
            if not levels:
                continue

            glr = GeneralizedLinearRegression(
                family="binomial", link="logit",
                featuresCol="features", labelCol=outcome
            )
            model = glr.fit(assembled_df)
            # With fitIntercept (default True) the intercept's SE is last, so indices
            # 0..k-1 line up with model.coefficients.
            std_errors = model.summary.coefficientStandardErrors

            for i, coef in enumerate(model.coefficients):
                se = std_errors[i]
                odds_ratios[f"{key_prefix}{covariate}_{i}"] = {
                    'odds_ratio': np.exp(coef),
                    'conf_lower': np.exp(coef - z_value * se),
                    'conf_upper': np.exp(coef + z_value * se),
                    'level': levels[i],
                    'reference': reference,
                }
    return odds_ratios

In [None]:
# Unadjusted odds ratios of any SMM by race; displayed as a table, one row per coefficient.
cov_lst = ['race_source_value']
out_lst = ['SMMANY_ind']
results = calculate_univariate_odds_ratios(
    df = delivery_elig_smm,
    covariate_cols = cov_lst,
    outcome_cols = out_lst
)
pd.DataFrame.from_dict(results, orient='index')