In [None]:
# script to test effect of demographics, SES, and comorbidities on probability of being annotated with a diagnosis
import pandas as pd
import os
import subprocess
from datetime import datetime
from matplotlib import pyplot
import matplotlib.pyplot as plt
import seaborn as sns
import re
import numpy as np

# Workspace bucket URI (None if WORKSPACE_BUCKET is unset). Note: the helper functions
# below re-read the environment variable themselves rather than using this global.
my_bucket = os.getenv('WORKSPACE_BUCKET')

# BigQuery curated data repository (CDR) dataset for the current workspace context.
# os.environ[...] raises KeyError if WORKSPACE_CDR is not set.
CDR = os.environ['WORKSPACE_CDR']

from google.cloud import bigquery
# Instantiate a BigQuery client
client = bigquery.Client()
# %pip install upsetplot   # only if missing; pin a version for reproducibility

from tqdm import tqdm

# Use full option names: pandas resolves partial keys (e.g. 'display.max_row') by pattern
# matching, which raises OptionError as soon as the pattern becomes ambiguous.
pd.set_option('display.max_columns', 5000)
pd.set_option('display.max_rows', 5000)

In [None]:
def getTable(table_name, folder):
    """
    Purpose: Download TSV from gs://<WORKSPACE_BUCKET>/data/<folder>/ and load as DataFrame (tab-delimited).
    Inputs: table_name (str) - file name in the bucket folder; also used as the local file name.
            folder (str) - sub-folder under <bucket>/data/.
    Returns: pandas DataFrame read from the downloaded file.
    Raises: RuntimeError if the gsutil copy fails. Previously a failed copy was ignored and a
            stale local file of the same name (from an earlier run) could be read silently.
    """
    my_bucket = os.getenv('WORKSPACE_BUCKET')

    # Argument list, no shell: file/folder names containing quotes or spaces cannot break the command.
    args = ["gsutil", "cp", f"{my_bucket}/data/{folder}/{table_name}", "."]
    output = subprocess.run(args, capture_output=True, text=True)
    if output.returncode != 0:
        raise RuntimeError(f"gsutil cp failed for {table_name}: {output.stderr.strip()}")

    print(f'[INFO] {table_name} is successfully downloaded into your working space')
    # Read the downloaded copy from the working directory (tab-separated).
    table_read = pd.read_csv(table_name, sep="\t")
    return table_read
def getFile(table_name, folder):
    """
    Purpose: Download a file from gs://<WORKSPACE_BUCKET>/data/<folder>/ into the working directory.
    Inputs: table_name (str) - file name in the bucket folder (kept as the local name);
            folder (str) - sub-folder under <bucket>/data/.
    Raises: RuntimeError if the gsutil copy fails, instead of printing a false success message.
    """
    my_bucket = os.getenv('WORKSPACE_BUCKET')

    # Argument list, no shell: names with quotes or spaces cannot break the command.
    args = ["gsutil", "cp", f"{my_bucket}/data/{folder}/{table_name}", "."]
    output = subprocess.run(args, capture_output=True, text=True)
    if output.returncode != 0:
        raise RuntimeError(f"gsutil cp failed for {table_name}: {output.stderr.strip()}")

    print(f'[INFO] {table_name} is successfully downloaded into your working space')
def saveToBucket(df, df_filename, data_folder):
    """
    Purpose: Save DataFrame as TSV and upload to GCS under data/<data_folder>/ using gsutil.
    Inputs: df (DataFrame), df_filename (str, local file name and object name),
            data_folder (str). Requires WORKSPACE_BUCKET.
    Side effects: writes df_filename in the working directory; prints gsutil's stderr
            (progress and any error text).
    Raises: RuntimeError if the upload fails.
    """
    df.to_csv(df_filename, sep = "\t", index=False)

    # get the bucket name
    my_bucket = os.getenv('WORKSPACE_BUCKET')

    # copy the file to the bucket
    args = ["gsutil", "cp", f"./{df_filename}", f"{my_bucket}/data/{data_folder}/"]
    output = subprocess.run(args, capture_output=True, text=True)

    # A bare `output.stderr` expression inside a function displays nothing, so print it explicitly.
    print(output.stderr)
    if output.returncode != 0:
        raise RuntimeError(f"gsutil upload of {df_filename} failed (exit code {output.returncode})")
def widen_comorbidity(pID_cond_table):
    """
    Purpose: Map ICD codes to comorbidity categories and pivot to a wide one-hot matrix by person_id.
    Inputs: pID_cond_table with columns person_id and source_concept_code. The input is not modified.
    Returns: DataFrame with person_id (one row per distinct non-null person_id, sorted ascending)
             followed by the standardized comorbidity columns as 0/1 ints. A person whose codes
             all map to no category still gets a row of zeros instead of being dropped.
    """
    std_cols = [
        "comorbidities_hemonc",
        "comorbidities_DM",
        "cmspec_DM_microvascular",
        "comorbidities_vitAdef",
        "comordbities_obesity",
        "comordbities_rare",
        "comordbities_neuropsych",
        "comordbities_ent",
        "comordbities_circulatory",
        "comordbities_HTN",
        "comordbities_ischemicHD",
        "comorb_GI_skin_MSK_GU",
        "comordbities_congenital",
        "comordbities_dactyly",
        "comordbities_external",
        "comorbidities_HLD"
    ]

    # assign_comorbidity may return the misspelled label "comordbities_vitAdef"; fold it onto
    # the standardized column so those diagnoses are not silently discarded by the reindex below.
    label_aliases = {"comordbities_vitAdef": "comorbidities_vitAdef"}

    # Private copy: callers often pass a slice, and adding a column to it would both mutate
    # the caller's data and trigger SettingWithCopyWarning.
    cond = pID_cond_table[['person_id', 'source_concept_code']].copy()
    cond['comorbidities'] = (
        cond['source_concept_code']
        .apply(assign_comorbidity)
        .map(lambda label: label_aliases.get(label, label))
    )

    # Count category occurrences per person (unmapped codes -> None/NaN are excluded).
    mapped = cond.dropna(subset=['comorbidities'])
    if mapped.empty:
        counts = pd.DataFrame(columns=std_cols)
    else:
        counts = pd.crosstab(mapped['person_id'], mapped['comorbidities'])

    # One row per distinct person and every standardized column present; absent cells -> 0.
    person_index = pd.Index(
        cond['person_id'].dropna().drop_duplicates().sort_values(), name='person_id'
    )
    counts = counts.reindex(index=person_index, columns=std_cols, fill_value=0)

    # Convert counts to binary flags: 1 if the category occurred at least once, else 0.
    wide = (counts >= 1).astype(int)
    wide.columns.name = None
    return wide.reset_index()
def assign_comorbidity(code):
    """
    Assign a comorbidity category label to one diagnosis source code.

    Lookup order:
    1. If `code` is a key of the module-level dict `icd9_map` (defined outside this section;
       presumably an ICD9CM code -> label mapping, TODO confirm), return its value.
    2. Non-string codes (e.g. None/NaN from missing data) return None.
    3. Otherwise parse a leading capital letter plus two digits (optionally ".<digits>") and
       apply these ICD10CM chapter rules:
       - E08-E13 -> "comorbidities_DM" (the "cmspec_DM_microvascular" split on the decimal
         part is disabled, so that label is never returned)
       - E50*    -> "comorbidities_vitAdef"
       - E65-E68 -> "comordbities_obesity"
       - E70-E75 -> "comordbities_rare"
       - E78     -> "comorbidities_HLD"
       - F*, G*  -> "comordbities_neuropsych"
       - H60-H95 -> "comordbities_ent"
       - I10-I15 -> "comordbities_HTN"; I20-I25 -> "comordbities_ischemicHD";
         other I*  -> "comordbities_circulatory"
       - K*, L*, M*, N* -> "comorb_GI_skin_MSK_GU"
       - Q69-Q74 -> "comordbities_dactyly"; other Q* -> "comordbities_congenital"
       - S*, T*-Z* -> "comordbities_external"
       - C*, D*  -> "comorbidities_hemonc"
    Returns None when nothing matches (e.g. A/B, J, O, P, R codes, H codes outside H60-H95).

    The "comordbities_" spellings must stay as they are: widen_comorbidity's column list and
    the downstream drop lists use the same names.
    NOTE(review): ICD9CM V/E codes that are not in icd9_map also match the letter+digits
    pattern and are classified by these ICD10 rules - confirm icd9_map covers them.
    """
    if code in icd9_map:
        return icd9_map[code]
    # Missing codes arrive as None/NaN; re.match would raise TypeError on them.
    if not isinstance(code, str):
        return None

    match = re.match(r"([A-Z])(\d{2})(?:\.(\d+))?", code)
    if not match:
        return None  # format is not letter + two digits
    letter, digits, _decimal = match.groups()
    num = int(digits)

    if letter == 'E':
        if 8 <= num <= 13:
            return "comorbidities_DM"
        if num == 50:
            # Was "comordbities_vitAdef", which matched no standardized column and was dropped.
            return "comorbidities_vitAdef"
        if 65 <= num <= 68:
            return "comordbities_obesity"
        if 70 <= num <= 75:
            return "comordbities_rare"
        if num == 78:
            return "comorbidities_HLD"
        return None

    if letter in ('F', 'G'):
        return "comordbities_neuropsych"

    if letter == 'H':
        return "comordbities_ent" if 60 <= num <= 95 else None

    if letter == 'I':
        if 10 <= num <= 15:
            return "comordbities_HTN"
        if 20 <= num <= 25:
            return "comordbities_ischemicHD"
        return "comordbities_circulatory"

    if letter in ('K', 'L', 'M', 'N'):
        return "comorb_GI_skin_MSK_GU"

    if letter == 'Q':
        return "comordbities_dactyly" if 69 <= num <= 74 else "comordbities_congenital"

    if letter == 'S' or 'T' <= letter <= 'Z':
        return "comordbities_external"

    if letter in ('C', 'D'):
        return "comorbidities_hemonc"

    return None
def pID_mobidity_master(pid_var_con_table):
    """
    Purpose: Build condition occurrence table (ICD9/10 only) for given people, annotate comorbidities, and produce wide+long outputs.
    Inputs: pid_var_con_table with a person_id column (duplicates and nulls are ignored).
    Returns: [wide_comorbidity_df, ds_occurrenceICD]
      - wide_comorbidity_df: output of widen_comorbidity (one-hot comorbidity flags by person_id).
      - ds_occurrenceICD: ds_condition_occurrence joined with concept (lower-cased column names),
        restricted to ICD10CM/ICD9CM source vocabularies, plus a 'comorbidities' label column.
    Raises: ValueError if no person_id is supplied (the query would contain an empty, invalid IN () list).
    """
    # De-duplicate IDs so the IN list is no longer than necessary.
    person_ids = pd.unique(pid_var_con_table["person_id"].dropna())
    if len(person_ids) == 0:
        raise ValueError("pid_var_con_table has no person_id values; nothing to query.")
    personid_list = ','.join(map(str, person_ids))

    ds_occurrence = pd.read_gbq(f'''
    SELECT
        dsc.*,
        c.*
    FROM
        `{CDR}.ds_condition_occurrence` dsc
        JOIN `{CDR}.concept` c on c.concept_id = dsc.condition_concept_id

    WHERE
        person_id IN ({personid_list})
        #AND dsc.ds_sleep_level IN (1585375, 1585370)


    ''',  progress_bar_type="tqdm_notebook")

    ds_occurrence.columns = ds_occurrence.columns.str.lower()
    # .copy(): the vocabulary filter yields a slice, and adding a column to a slice raises
    # SettingWithCopyWarning (and may not write through).
    ds_occurrenceICD = ds_occurrence[ds_occurrence["source_vocabulary"].isin(["ICD10CM", "ICD9CM"])].copy()
    ds_occurrenceICD['comorbidities'] = ds_occurrenceICD['source_concept_code'].apply(assign_comorbidity)

    ds_occurrence_minimal = ds_occurrenceICD[["person_id", "source_concept_code", "comorbidities"]].copy()
    wide = widen_comorbidity(ds_occurrence_minimal)

    return [wide, ds_occurrenceICD]
def assign_LS_smoke(answer_SES):
    """
    Map an answer string for the smoking question (question_concept_id 1585857 in
    pID_survey_master) to a binary flag by substring matching:

    - non-string (e.g. None/NaN), or containing "Prefer Not To Answer", "Dont Know" or "Skip" -> np.nan
    - containing "No"  -> 0
    - containing "Yes" -> 1
    - anything else    -> np.nan

    The non-response check must come first: "Prefer Not To Answer" contains "No" (inside "Not")
    and was previously scored as 0.
    """
    if not isinstance(answer_SES, str):
        return np.nan
    if any(token in answer_SES for token in ("Prefer Not To Answer", "Dont Know", "Skip")):
        return np.nan
    if "No" in answer_SES:
        return 0
    if "Yes" in answer_SES:
        return 1
    return np.nan
def assign_SES_married(answer_SES):
    """
    Map a marital-status answer string (question_concept_id 1585892 in pID_survey_master)
    to a binary partnered flag by substring matching:

    - non-string (e.g. None/NaN), or containing "Prefer Not To Answer", "Skip" or
      "Other Arrangement" -> np.nan
    - containing "Separated", "Widowed", "Divorced" or "Never Married" -> 0
    - containing "Married" or "Living With Partner" -> 1
    - anything else -> np.nan

    The 0 branch must be tested before the 1 branch because "Never Married" contains "Married".
    """
    if not isinstance(answer_SES, str):
        return np.nan
    if "Prefer Not To Answer" in answer_SES or "Skip" in answer_SES or "Other Arrangement" in answer_SES:
        return np.nan
    elif "Separated" in answer_SES or "Widowed" in answer_SES or "Divorced" in answer_SES or "Never Married" in answer_SES:
        return 0
    elif "Married" in answer_SES or "Living With Partner" in answer_SES:
        return 1
    else:
        return np.nan
def pID_survey_master(pid_var_con_table):
    """
    Purpose: Query ds_survey for the given people, score the SES / lifestyle questions with the
             assign_* functions, and assemble one wide table.
    Inputs: pid_var_con_table with a person_id column (duplicates and nulls are ignored).
    Returns: [out_table, pid_survey]
      - out_table: person_id, SES_education, SES_income, SES_healthins, SES_homeown, SES_married,
        LS_smoke. Rows come from the education question (1585940); the other questions are
        left-merged onto it, so a person with no education row is absent.
        NOTE(review): drop_duplicates is on (person_id, score), so a person with two different
        scores for one question produces multiple rows - confirm this cannot happen.
      - pid_survey: the long survey table (lower-cased columns) with the six score columns added
        (NaN on rows that belong to other questions).
    Raises: ValueError if no person_id is supplied (the query would contain an empty IN () list).
    """
    person_ids = pd.unique(pid_var_con_table["person_id"].dropna())
    if len(person_ids) == 0:
        raise ValueError("pid_var_con_table has no person_id values; nothing to query.")
    personid_list = ','.join(map(str, person_ids))

    pid_survey = pd.read_gbq(f'''
    SELECT
        dsc.*
    FROM
        `{CDR}.ds_survey` dsc
        #JOIN `{CDR}.concept` c on c.concept_id = dsc.condition_concept_id

    WHERE
        person_id IN ({personid_list})
        #AND dsc.ds_sleep_level IN (1585375, 1585370)


    ''',  progress_bar_type="tqdm_notebook")

    pid_survey.columns = pid_survey.columns.str.lower()

    # (output column, question_concept_id, answer-scoring function); list order fixes the
    # order of the score columns in both returned tables.
    questions = [
        ("SES_education", 1585940, assign_SES_education),
        ("SES_income", 1585375, assign_SES_income),
        ("SES_healthins", 1585386, assign_SES_healthins),
        ("SES_homeown", 1585370, assign_SES_homeown),
        ("SES_married", 1585892, assign_SES_married),
        ("LS_smoke", 1585857, assign_LS_smoke),
    ]

    # Score answers on the rows of each question; index alignment leaves other rows NaN.
    for col, qid, score in questions:
        is_question = pid_survey["question_concept_id"] == qid
        pid_survey[col] = pid_survey.loc[is_question, "answer"].apply(score)

    # One (person_id, score) table per question, left-merged onto the first (education).
    out_table = None
    for col, qid, _score in questions:
        per_question = pid_survey.loc[pid_survey["question_concept_id"] == qid, ["person_id", col]].drop_duplicates()
        if out_table is None:
            out_table = per_question
        else:
            out_table = pd.merge(out_table, per_question, on="person_id", how="left")

    return [out_table, pid_survey]
def assign_SES_healthins(answer_SES):
    """
    Map a health-insurance answer string (question_concept_id 1585386 in pID_survey_master)
    to a binary flag by substring matching:

    - non-string (e.g. None/NaN), or containing "Prefer Not To Answer", "Skip" or "Dont Know" -> np.nan
    - containing "Yes" -> 1
    - containing "No"  -> 0
    - anything else    -> np.nan

    The non-response check must come first: "Prefer Not To Answer" contains "No".
    """
    if not isinstance(answer_SES, str):
        return np.nan
    if "Prefer Not To Answer" in answer_SES or "Skip" in answer_SES or "Dont Know" in answer_SES:
        return np.nan
    elif "Yes" in answer_SES:
        return 1
    elif "No" in answer_SES:
        return 0
    else:
        return np.nan
def assign_SES_education(answer_SES):
    """
    Map a highest-education answer string (question_concept_id 1585940 in pID_survey_master)
    to an ordinal level by (case-sensitive) substring matching:

    - non-string (e.g. None/NaN), or containing "Prefer Not To Answer" or "Skip" -> np.nan
    - "One Through Four", "Five Through Eight", "Nine Through Eleven" -> 0
    - "Twelve Or GED" -> 1
    - "College One to Three" -> 2
    - "College Graduate" -> 3
    - "Advanced Degree" -> 4
    - anything else -> np.nan
    """
    if not isinstance(answer_SES, str):
        return np.nan
    if "Prefer Not To Answer" in answer_SES or "Skip" in answer_SES:
        return np.nan
    # NOTE(review): an answer for people who never attended school, if the survey has one,
    # falls through to np.nan rather than 0 - confirm that is intended.
    elif "One Through Four" in answer_SES or "Five Through Eight" in answer_SES or "Nine Through Eleven" in answer_SES:
        return 0
    elif "Twelve Or GED" in answer_SES:
        return 1
    elif "College One to Three" in answer_SES:
        return 2
    elif "College Graduate" in answer_SES:
        return 3
    elif "Advanced Degree" in answer_SES:
        return 4
    else:
        return np.nan
def assign_SES_income(answer_SES):
    """
    Map an annual-income answer string (question_concept_id 1585375 in pID_survey_master)
    to an ordinal band by substring matching:

    - non-string (e.g. None/NaN), or containing "Prefer Not To Answer" or "Skip" -> np.nan
    - "less 10k", "10k 25k"   -> 0
    - "25k 35k", "35k 50k"    -> 1
    - "50k 75k", "75k 100k"   -> 2
    - "100k 150k", "150k 200k" -> 3
    - "more 200k"             -> 4
    - anything else -> np.nan
    """
    if not isinstance(answer_SES, str):
        return np.nan
    if "Prefer Not To Answer" in answer_SES or "Skip" in answer_SES:
        return np.nan
    elif "less 10k" in answer_SES or "10k 25k" in answer_SES:
        return 0
    elif "25k 35k" in answer_SES or "35k 50k" in answer_SES:
        return 1
    elif "50k 75k" in answer_SES or "75k 100k" in answer_SES:
        return 2
    elif "100k 150k" in answer_SES or "150k 200k" in answer_SES:
        return 3
    elif "more 200k" in answer_SES:
        return 4
    else:
        return np.nan
def assign_SES_homeown(answer_SES):
    """
    Map a home-ownership answer string (question_concept_id 1585370 in pID_survey_master)
    to a binary flag by substring matching:

    - "Current Home Own: Own" -> 1
    - "Current Home Own: Rent" or "Current Home Own: Other Arrangement" -> 0
    - anything else, including non-string values (None/NaN) and skip/refusal answers -> np.nan
    """
    if not isinstance(answer_SES, str):
        return np.nan
    if "Current Home Own: Own" in answer_SES:
        return 1
    elif "Current Home Own: Rent" in answer_SES:
        return 0
    elif "Current Home Own: Other Arrangement" in answer_SES:
        return 0
    else:
        return np.nan

In [None]:
# Load the cohort tables from the workspace bucket.
# ALL_IRD comes from data/personID_variant/; each phenotype has an "in_" and a "notin_"
# table in data/personID_variant_concept/.
ALL_IRD = getTable("ALL_IRD_raceproc.tsv", "personID_variant")

in_retdys, notin_retdys = (
    getTable(name, "personID_variant_concept")
    for name in ("in_retdys.tsv", "notin_retdys.tsv")
)

in_retdegen, notin_retdegen = (
    getTable(name, "personID_variant_concept")
    for name in ("in_retdegen.tsv", "notin_retdegen.tsv")
)

in_screenretdys, notin_screenretdys = (
    getTable(name, "personID_variant_concept")
    for name in ("in_screenretdys.tsv", "notin_screenretdys.tsv")
)

In [None]:
# Within-genotype affected vs. unaffected comparisons: get comorbidities for each cohort table.
#
# Each *_full (long) table has rows whose source_concept_code is one of the screening ICD codes
# in ScreenRetDysICD removed (ScreenRetDysICD is loaded outside this section - TODO confirm it is
# defined before this cell). The *_matrix tables come straight from pID_mobidity_master and are
# built before this filter.
#
# Series.isin must receive the ICD_code column itself. The previous one-column DataFrame
# (ScreenRetDysICD[["ICD_code"]]) is iterated as its column labels, so only a code literally
# equal to "ICD_code" was excluded and no screening code was actually removed.
screen_icd_codes = ScreenRetDysICD["ICD_code"]


def _morbidity_tables(cohort):
    """Run pID_mobidity_master on one cohort; return (raw result, wide matrix, long table without screening codes)."""
    result = pID_mobidity_master(cohort)
    matrix, full = result[0], result[1]
    return result, matrix, full[~full.source_concept_code.isin(screen_icd_codes)]


in_retdys_result, in_retdys_matrix, in_retdys_full = _morbidity_tables(in_retdys)
notin_retdys_result, notin_retdys_matrix, notin_retdys_full = _morbidity_tables(notin_retdys)

in_retdegen_result, in_retdegen_matrix, in_retdegen_full = _morbidity_tables(in_retdegen)
notin_retdegen_result, notin_retdegen_matrix, notin_retdegen_full = _morbidity_tables(notin_retdegen)

in_screenretdys_result, in_screenretdys_matrix, in_screenretdys_full = _morbidity_tables(in_screenretdys)
notin_screenretdys_result, notin_screenretdys_matrix, notin_screenretdys_full = _morbidity_tables(notin_screenretdys)

In [None]:
# Per-person record counts ("NumberOfVisits") to attach to a comorbidity matrix
def countVisits(matrix, df_visits):
    """
    Purpose: Attach a per-person 'NumberOfVisits' column to `matrix`.

    The value is the number of rows per person_id in df_visits. Every condition record counts,
    including rows whose condition_start_datetime is missing, so this is not a count of
    distinct dates.

    Inputs: matrix (DataFrame with person_id); df_visits (long table with person_id and
            condition_start_datetime).
    Returns: new DataFrame from a RIGHT join on person_id. Every person_id in df_visits is kept
             (people missing from `matrix` get NaN in matrix's other columns), and people present
             only in `matrix` are dropped.
    NOTE(review): the right join drops people with no remaining rows in df_visits - confirm intended.
    """
    rows_per_person = df_visits.groupby('person_id')['condition_start_datetime'].size()
    visit_counts = rows_per_person.rename('NumberOfVisits').reset_index()
    return matrix.merge(visit_counts, on="person_id", how="right")

In [None]:
# Add NumberOfVisits to each comorbidity matrix. Start from the untouched pID_mobidity_master
# output (*_result[0]) rather than the *_matrix variable itself, so re-running this cell does
# not merge the counts a second time (which would create NumberOfVisits_x / NumberOfVisits_y).
in_retdys_matrix = countVisits(in_retdys_result[0], in_retdys_full)
in_retdegen_matrix = countVisits(in_retdegen_result[0], in_retdegen_full)
in_screenretdys_matrix = countVisits(in_screenretdys_result[0], in_screenretdys_full)
notin_retdys_matrix = countVisits(notin_retdys_result[0], notin_retdys_full)
notin_retdegen_matrix = countVisits(notin_retdegen_result[0], notin_retdegen_full)
notin_screenretdys_matrix = countVisits(notin_screenretdys_result[0], notin_screenretdys_full)

In [None]:
# Survey-derived SES / lifestyle covariates for each cohort table.
# pID_survey_master returns [wide per-person table, long survey table].
in_retdys_survey = pID_survey_master(in_retdys)
in_retdys_survey_matrix, in_retdys_survey_full = in_retdys_survey
notin_retdys_survey = pID_survey_master(notin_retdys)
notin_retdys_survey_matrix, notin_retdys_survey_full = notin_retdys_survey

in_retdegen_survey = pID_survey_master(in_retdegen)
in_retdegen_survey_matrix, in_retdegen_survey_full = in_retdegen_survey
notin_retdegen_survey = pID_survey_master(notin_retdegen)
notin_retdegen_survey_matrix, notin_retdegen_survey_full = notin_retdegen_survey

in_screenretdys_survey = pID_survey_master(in_screenretdys)
in_screenretdys_survey_matrix, in_screenretdys_survey_full = in_screenretdys_survey
notin_screenretdys_survey = pID_survey_master(notin_screenretdys)
notin_screenretdys_survey_matrix, notin_screenretdys_survey_full = notin_screenretdys_survey

In [None]:
# Functions that attach affected status and demographics (Age, Sex, race) to a survey/SES matrix
def _attach_demographics(matrix, source_df, affected, age_col):
    """
    Shared implementation of mergeSurveyforAnno / mergeSurveyforUNAnno.

    Returns a new DataFrame: `matrix` plus affected=<affected>, left-merged on person_id with
    source_df[[age_col, 'sex_at_birth_source_value', 'race']], with age_col renamed to 'Age' and
    sex_at_birth_source_value to 'Sex'. `matrix` itself is not modified.
    NOTE(review): if source_df has several rows per person_id, the left merge repeats that
    person's row - confirm source_df is one row per person.
    """
    demographics = source_df[["person_id", age_col, "sex_at_birth_source_value", "race"]]
    merged = pd.merge(matrix.assign(affected=affected), demographics, on="person_id", how="left")
    return merged.rename(columns={age_col: 'Age', 'sex_at_birth_source_value': "Sex"})


def mergeSurveyforAnno(matrix, source_df):
    """
    Purpose: Mark affected="Yes" and attach Age (from AgeAtVisit_Years), Sex, and race from source_df.
    Returns: new merged DataFrame with standardized 'Age' and 'Sex' columns; `matrix` is not modified.
    NOTE(review): annotated people get age at visit, while mergeSurveyforUNAnno uses age at
    consent, so 'Age' is measured differently in the two groups - confirm this is intended.
    """
    return _attach_demographics(matrix, source_df, "Yes", "AgeAtVisit_Years")


def mergeSurveyforUNAnno(matrix, source_df):
    """
    Purpose: Mark affected="No" and attach Age (from AgeAtConsent_Years), Sex, and race from source_df.
    Returns: new merged DataFrame with standardized 'Age' and 'Sex' columns; `matrix` is not modified.
    """
    return _attach_demographics(matrix, source_df, "No", "AgeAtConsent_Years")

In [None]:
# Attach affected status, Age, Sex and race to each survey matrix:
# annotated ("in_") cohorts use mergeSurveyforAnno, unannotated ("notin_") use mergeSurveyforUNAnno.
in_retdys_survey_matrix = mergeSurveyforAnno(matrix=in_retdys_survey[0], source_df=in_retdys)
notin_retdys_survey_matrix = mergeSurveyforUNAnno(matrix=notin_retdys_survey[0], source_df=notin_retdys)

in_retdegen_survey_matrix = mergeSurveyforAnno(matrix=in_retdegen_survey[0], source_df=in_retdegen)
notin_retdegen_survey_matrix = mergeSurveyforUNAnno(matrix=notin_retdegen_survey[0], source_df=notin_retdegen)

in_screenretdys_survey_matrix = mergeSurveyforAnno(matrix=in_screenretdys_survey[0], source_df=in_screenretdys)
notin_screenretdys_survey_matrix = mergeSurveyforUNAnno(matrix=notin_screenretdys_survey[0], source_df=notin_screenretdys)

In [None]:
# Prepare design matrices for the logistic regression models.
# person_id is an identifier, not a predictor, so remove it from each merged matrix
# (the *_matrix_merge tables are built outside this section - TODO confirm).
retdys_matrix_merge_glmnet = retdys_matrix_merge.drop(columns="person_id")
retdegen_matrix_merge_glmnet = retdegen_matrix_merge.drop(columns="person_id")
screenretdys_matrix_merge_glmnet = screenretdys_matrix_merge.drop(columns="person_id")

def makeSexNumeric(df):
    """
    Purpose: Map categorical Sex to a float 'sex_numeric' column and drop 'Sex'.
             'SexAtBirth_Female' -> 0.0, 'SexAtBirth_Male' -> 1.0, any other value or missing -> NaN.
    Inputs: df with a 'Sex' column. df itself is not modified (previously the new column was
            written into the caller's frame).
    Returns: new DataFrame with 'sex_numeric' appended and 'Sex' removed.
    """
    sex_map = {
        'SexAtBirth_Female': 0,
        'SexAtBirth_Male':   1
    }

    return (
        df.assign(sex_numeric=df['Sex'].map(sex_map).astype(float))
          .drop(columns="Sex")
    )

retdys_matrix_merge_glmnet = retdys_matrix_merge_glmnet.pipe(makeSexNumeric)
retdegen_matrix_merge_glmnet = retdegen_matrix_merge_glmnet.pipe(makeSexNumeric)
screenretdys_matrix_merge_glmnet = screenretdys_matrix_merge_glmnet.pipe(makeSexNumeric)

def dummyRace(df):
    """
    Purpose: One-hot encode the 'race' column, dropping the first category level as the
             reference (drop_first=True).
    Returns: new DataFrame where 'race' is replaced by race_<value> indicator columns,
             appended after the remaining columns.
    """
    return pd.get_dummies(df, columns=['race'], drop_first=True)

retdys_matrix_merge_glmnet = retdys_matrix_merge_glmnet.pipe(dummyRace)
retdegen_matrix_merge_glmnet = retdegen_matrix_merge_glmnet.pipe(dummyRace)
screenretdys_matrix_merge_glmnet = screenretdys_matrix_merge_glmnet.pipe(dummyRace)

In [None]:
#do both univariate and multivariate at once
import numpy as np
import pandas as pd
import statsmodels.api as sm
from statsmodels.tools.sm_exceptions import PerfectSeparationError
from statsmodels.stats.outliers_influence import variance_inflation_factor

# --- preprocessing helpers: outcome coding and predictor rescaling ---
def bin_outcome(df):
    """
    Add 'outcome_bin' coded from 'affected': 'Yes' -> 1, 'No' -> 0, anything else -> NaN.

    The column is written into `df` in place, and the same object is returned for chaining.
    """
    outcome_codes = {'Yes': 1, 'No': 0}
    df['outcome_bin'] = df['affected'].map(outcome_codes)
    return df

def divideNumVisit(df, amt = 10):
    """
    Rescale 'NumberOfVisits' by dividing it by `amt`, in place; a no-op when the column is absent.
    Returns `df` (the same object).
    """
    if "NumberOfVisits" not in df.columns:
        return df
    df["NumberOfVisits"] = df["NumberOfVisits"] / amt
    return df

def divideAge(df, amt = 1):
    """
    Rescale 'Age' by dividing it by `amt`, in place; a no-op when the column is absent.
    Returns `df` (the same object).
    """
    if "Age" not in df.columns:
        return df
    df["Age"] = df["Age"] / amt
    return df

# --- internal utility ---
def _fit_logit(y, X):
    """
    Fit a statsmodels Logit of y on X plus an added intercept column ('const').

    Returns (model, result) on success, or (None, None) if fitting raises one of the
    separation / numerical errors listed in `fit_errors`.
    NOTE(review): depending on the installed statsmodels version, perfect separation may only
    emit a warning instead of raising, and non-convergence does not raise - consider checking
    result.mle_retvals['converged'] (TODO confirm against the installed version).
    """
    fit_errors = (PerfectSeparationError, np.linalg.LinAlgError, ValueError,
                  ZeroDivisionError, FloatingPointError)
    try:
        design = sm.add_constant(X, has_constant='add')
        model = sm.Logit(y, design)
        return model, model.fit(disp=False)
    except fit_errors:
        return None, None

def _tidy_res(res, var_name):
    """Extract OR, CI, p for var_name from a LogitResults; return NaNs if not available."""
    if res is None:
        return dict(OR=np.nan, lo=np.nan, hi=np.nan, p=np.nan)
    try:
        summ = res.summary2().tables[1]
        if var_name not in summ.index:
            return dict(OR=np.nan, lo=np.nan, hi=np.nan, p=np.nan)
        coef = summ.loc[var_name, 'Coef.']
        lo   = summ.loc[var_name, '[0.025']
        hi   = summ.loc[var_name, '0.975]']
        p    = summ.loc[var_name, 'P>|z|']
        return dict(OR=float(np.exp(coef)),
                    lo=float(np.exp(lo)),
                    hi=float(np.exp(hi)),
                    p=float(p))
    except Exception:
        return dict(OR=np.nan, lo=np.nan, hi=np.nan, p=np.nan)

def _vif_table(X):
    """
    Variance inflation factor for each column of X, computed with an intercept in the design.

    statsmodels' variance_inflation_factor regresses one exog column on the other columns
    exactly as given and adds no intercept itself. Without a constant column the auxiliary
    R^2 is uncentered, which inflates the VIFs of non-centred predictors. The constant is
    therefore kept in the design, and only its own row is left out of the output.
    """
    design = sm.add_constant(X, has_constant='add')
    exog = design.values
    rows = [
        (col, variance_inflation_factor(exog, i))
        for i, col in enumerate(design.columns)
        if col != 'const'
    ]
    return pd.DataFrame(rows, columns=["feature", "VIF"])


def doLogit_both(df,
                 drop_cols=None,
                 visits_divisor=10,
                 age_divisor=1,
                 compute_vif=False):
    """
    Fit one multivariate logistic model and one univariate model per predictor, and tabulate both.

    Multivariate: use all predictors not in drop_cols (rows = complete cases across all).
    Univariate: for each predictor, RELOAD the full df and drop NaNs only for that predictor + outcome.

    Parameters
    ----------
    df : DataFrame containing 'outcome_bin' (see bin_outcome). Every other column except
         'affected' and drop_cols is a predictor and is coerced to float (non-numeric -> NaN).
    drop_cols : column names to exclude (names not present are ignored).
    visits_divisor, age_divisor : NumberOfVisits / Age are divided by these before fitting,
         so their ORs are per `visits_divisor` units of NumberOfVisits / `age_divisor` units of Age.
    compute_vif : also compute VIFs (with an intercept) on the multivariate complete-case design.

    Returns
    -------
      {
        "multi": (multi_model, multi_res, X_multi, y_multi),   # model/res are None if the fit failed
        "uni": {var: (model, res)},
        "combined_table": 9-col DataFrame,
        "vif": DataFrame (feature, VIF) or None
      }

    Raises
    ------
    ValueError if 'outcome_bin' is missing or no complete case remains for the multivariate fit.
    """
    df_full = df.copy()

    # Prepare the full outcome
    if 'outcome_bin' not in df_full.columns:
        raise ValueError("df must contain 'outcome_bin'; run bin_outcome first.")
    y_full = df_full['outcome_bin'].astype(float)

    # ---------- MULTIVARIATE ----------
    # Build X from full df, then drop requested columns
    X_multi = df_full.drop(columns=['affected', 'outcome_bin'], errors='ignore')
    if drop_cols:
        X_multi = X_multi.drop(columns=drop_cols, errors='ignore')

    # Coerce numeric and clean
    X_multi = X_multi.apply(pd.to_numeric, errors='coerce').astype(float)
    X_multi = X_multi.replace([np.inf, -np.inf], np.nan)

    # Optional scaling
    X_multi = divideNumVisit(X_multi, amt=visits_divisor)
    X_multi = divideAge(X_multi, amt=age_divisor)

    # Keep only complete cases across all multivariate predictors + outcome
    data_multi = pd.concat([X_multi, y_full], axis=1).dropna()
    if data_multi.empty:
        raise ValueError("No rows left for multivariate fit after NaN cleaning.")
    X_multi_cc = data_multi[X_multi.columns]
    y_multi_cc = data_multi[y_full.name]

    multi_model, multi_res = _fit_logit(y_multi_cc, X_multi_cc)

    # ---------- UNIVARIATE (reload df for each predictor) ----------
    uni = {}
    predictors = list(X_multi.columns)  # same predictor set as the multivariate model
    for var in predictors:
        # take ONLY this predictor from the full df (not the multivariate complete cases)
        X_uni = df_full[[var]].copy()

        # numeric conversion & clean for this var alone
        X_uni[var] = pd.to_numeric(X_uni[var], errors='coerce').astype(float)
        X_uni[var] = X_uni[var].replace([np.inf, -np.inf], np.nan)

        # same rescaling as the multivariate design
        if var == "NumberOfVisits":
            X_uni = divideNumVisit(X_uni, amt=visits_divisor)
        if var == "Age":
            X_uni = divideAge(X_uni, amt=age_divisor)

        # drop NaNs only in this var + outcome
        data_uni = pd.concat([X_uni, y_full], axis=1).dropna()
        if data_uni.empty:
            uni[var] = (None, None)
            continue

        Xu = data_uni[[var]]
        yu = data_uni[y_full.name]
        uni[var] = _fit_logit(yu, Xu)

    # ---------- Optional VIF (on multivariate design) ----------
    vif_df = None
    if compute_vif and multi_res is not None and X_multi_cc.shape[1] > 1:
        vif_df = _vif_table(X_multi_cc)

    # ---------- Build the 9-column combined table ----------
    rows = []
    for var in predictors:
        _u_model, u_res = uni[var]
        u = _tidy_res(u_res, var)
        m = _tidy_res(multi_res, var)
        rows.append({
            "variable": var,
            "uni_OR": u['OR'],
            "uni_CI_lo": u['lo'],
            "uni_CI_hi": u['hi'],
            "uni_p": u['p'],
            "multi_OR": m['OR'],
            "multi_CI_lo": m['lo'],
            "multi_CI_hi": m['hi'],
            "multi_p": m['p'],
        })

    combined = pd.DataFrame(rows).sort_values(
        by=["multi_p", "uni_p"], na_position="last"
    ).reset_index(drop=True)

    return {
        "multi": (multi_model, multi_res, X_multi_cc, y_multi_cc),
        "uni": uni,
        "combined_table": combined,
        "vif": vif_df
    }

def tidy_logit_result_both(results_dict, html_path=None, float_fmt="{:0.5f}"):
    """
    Format the 9-column table produced by doLogit_both for display.

    Numeric columns are rendered with `float_fmt`; missing values become empty strings. When
    `html_path` is truthy the HTML rendering is also written to that file.
    Returns (formatted DataFrame, HTML string). The input table is not modified.
    """
    def _fmt(value):
        return "" if pd.isna(value) else float_fmt.format(value)

    numeric_cols = (
        "uni_OR", "uni_CI_lo", "uni_CI_hi", "uni_p",
        "multi_OR", "multi_CI_lo", "multi_CI_hi", "multi_p",
    )
    table = results_dict["combined_table"].copy()
    for col in numeric_cols:
        table[col] = table[col].map(_fmt)

    html = table.to_html(index=False, border=0, escape=False)
    if html_path:
        with open(html_path, "w") as handle:
            handle.write(html)
    return table, html



In [None]:
# Columns removed before modeling (passed to doLogit_both as drop_cols).
# A commented-out entry is KEPT as a predictor; an active (uncommented) entry is DROPPED.
# Any column not listed at all (e.g. Age, if present in the matrix) is also kept.
# Names must match the matrix columns exactly, including the "comordbities_" spelling used
# for several categories: doLogit_both drops with errors='ignore', so a misspelled name is
# silently left in the model.
to_drop = [
    # comorbidity flags (columns produced by widen_comorbidity)
    'comorbidities_vitAdef',
    'comorb_GI_skin_MSK_GU',

    'comordbities_dactyly',
    'comordbities_external',
    'comordbities_congenital',
    'comordbities_neuropsych',
    'comordbities_ent',
    'comordbities_rare',
    'cmspec_DM_microvascular',
    'comordbities_circulatory',
    'comordbities_ischemicHD',
    'comorbidities_hemonc',
    # race indicator columns produced by dummyRace (race_<value>)
    'race_WhatRaceEthnicity_Black',
    'race_Unknown',
    'race_WhatRaceEthnicity_AIAN',
    'race_WhatRaceEthnicity_MENA',
    'race_WhatRaceEthnicity_White',
    # survey-derived SES / lifestyle scores (pID_survey_master)
    #'SES_education',
    'SES_income',
    'SES_homeown',
    'SES_healthins',
    'SES_married',
    #'comorbidities_DM',
    'comordbities_obesity',
    'comordbities_HTN',
    'comorbidities_HLD',
    #'LS_smoke',
    #'NumberOfVisits',
    #'sex_numeric',
    'race_WhatRaceEthnicity_Asian'
    
]



In [None]:
# Fit univariate and multivariate logistic models and write one combined HTML table.
# Change the input table / output file name to run the other cohorts.

# doLogit_both requires the binary 'outcome_bin' column; derive it from 'affected' if no
# earlier cell has done so (bin_outcome is idempotent, so this is safe on re-runs).
if 'outcome_bin' not in retdys_matrix_merge_glmnet.columns:
    retdys_matrix_merge_glmnet = bin_outcome(retdys_matrix_merge_glmnet)

res = doLogit_both(
    retdys_matrix_merge_glmnet,
    drop_cols=to_drop,
    visits_divisor=100,   # OR per 100 units of NumberOfVisits
    age_divisor=10,       # OR per 10 units of Age
    compute_vif=True  # optional
)

# Get the 9-column tidy table (and optional HTML)
table9, html_str = tidy_logit_result_both(res, html_path="table_retdys_both.html")

# Access the fitted multivariate model for a detailed summary; the fit result is None when
# _fit_logit caught a fitting error, so guard before calling .summary().
multi_model, multi_res, X_multi, y = res["multi"]
if multi_res is None:
    print("[WARN] Multivariate logit failed to fit; only univariate results are available in table9.")
else:
    print(multi_res.summary())