In [1]:
from dotenv import load_dotenv
from phmlondon.snow_utils import SnowflakeConnection
import pandas as pd
from statsmodels.formula.api import logit
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import altair as alt
from tqdm import tqdm

# Raise pandas' display limits so longer tables and wider frames are shown in cell output.
pd.set_option('display.max_rows', 200)
pd.set_option('display.max_columns', 100)

# CONFIG

In [2]:
# Phenotypes to model.
# Keys are the short names used as the outcome column header (see get_modeling_data);
# values are the source column names selected from the PERSON_5YEAR_PHENOTYPE table.
PHENOTYPES_OF_INTEREST = {
    "ASTHMA": "Asthma diagnoses simple reference set",
    "COPD": "Chronic obstructive pulmonary disorder, emphysema, and associated lung diseases simple reference set",
    "DIABETES_ANY": "Diabetes diagnoses simple reference set",
    "DIABETES_T2": "Diabetes type 2 diagnoses simple reference set",
    "DIABETES_T1": "Diabetes type 1 diagnoses simple reference set",
    "HYPERTENSION": "Systemic hypertension diagnoses simple reference set",
    "ANGINA_CHD": "Angina and coronary heart disease diagnoses simple reference set",
    "MYOCARDIAL_INFARCTION": "Myocardial infarction diagnoses simple reference set",
    "TIA": "Transient ischaemic attack diagnoses simple reference set",
    "NON_HAEMORRHAGIC_STROKE": "Non haemorrhagic strokes simple reference set",
    "CKD_1": "Chronic kidney disease 1",
    "CKD_3": "Chronic kidney disease 3",
    "DEPRESSION": "Depression diagnoses simple reference set",
    "PSYCHOSIS_SCHIZOPHRENIA_BIPOLAR": "Psychosis, schizophrenia and bipolar affective disorder simple reference set"
}

# FUNCTIONS

## Helper

In [3]:
def clean_column_name(col):
    """
    Make a column name usable as a term in a Patsy formula.

    Hyphens become '_minus_', plus signs become '_plus_' and spaces become '_'.
    """
    substitutions = str.maketrans({'-': '_minus_', '+': '_plus_', ' ': '_'})
    return col.translate(substitutions)

def create_dummies(df, col, ref):
    """
    One-hot encode ``df[col]`` and leave out the indicator for the reference level.

    Args:
        df:
            DataFrame holding the categorical column
        col:
            Column to encode; also used as the prefix of the indicator columns
        ref:
            Reference level; its ``{col}_{ref}`` indicator is dropped when present

    Returns:
        DataFrame of indicator columns without the reference level
    """
    encoded = pd.get_dummies(df[col], prefix=col, drop_first=False)
    # errors='ignore' keeps every column when the reference level never occurs
    return encoded.drop(columns=[f"{col}_{ref}"], errors='ignore')

## Data

In [4]:
def get_modeling_data(
        snowsesh,
        phenotype_name
        ):
    """
    Retrieves modeling dataset combining active patients and phenotype data.

    Args:
        snowsesh:
            Snowflake session (must provide ``execute_query_to_df``)
        phenotype_name:
            The specific phenotype to analyse; must be a key of PHENOTYPES_OF_INTEREST

    Returns:
        DataFrame containing features and the phenotype flag for modeling

    Raises:
        ValueError: if ``phenotype_name`` is not a key of PHENOTYPES_OF_INTEREST
        Exception: any error raised while running the query is printed and re-raised
    """
    # dynamic SQL to pull phenotype; only whitelisted names reach the query text
    if phenotype_name not in PHENOTYPES_OF_INTEREST:
        raise ValueError(f"Invalid phenotype name. Must be one of {list(PHENOTYPES_OF_INTEREST.keys())}")
    phenotype_db_name = PHENOTYPES_OF_INTEREST[phenotype_name]
    phenotype_col = f'COALESCE(ph."{phenotype_db_name}", 0) as "{phenotype_name}"'

    # NOTE(review): a NULL DATE_OF_BIRTH gives a NULL age, which falls through to the
    # ELSE branch and is banded as '75+' — confirm whether such rows can occur.
    query = f"""
    WITH person_age AS (
        SELECT
            PERSON_ID,
            -- DATEDIFF(year, ...) only counts calendar-year boundaries crossed,
            -- so subtract one when this year's birthday has not happened yet
            DATEDIFF(year, DATE_OF_BIRTH, CURRENT_DATE())
                - IFF(
                    DATEADD(year, DATEDIFF(year, DATE_OF_BIRTH, CURRENT_DATE()), DATE_OF_BIRTH) > CURRENT_DATE(),
                    1,
                    0
                ) as AGE_YEARS
        FROM INTELLIGENCE_DEV.AI_CENTRE_FEATURE_STORE.PERSON_NEL_MASTER_INDEX
    ),
    age_categories AS (
        SELECT
            PERSON_ID,
            CASE
                WHEN AGE_YEARS < 18 THEN '0-17'
                WHEN AGE_YEARS < 25 THEN '18-24'
                WHEN AGE_YEARS < 35 THEN '25-34'
                WHEN AGE_YEARS < 45 THEN '35-44'
                WHEN AGE_YEARS < 55 THEN '45-54'
                WHEN AGE_YEARS < 65 THEN '55-64'
                WHEN AGE_YEARS < 75 THEN '65-74'
                ELSE '75+'
            END as AGE_BAND,
            AGE_YEARS
        FROM person_age
    )
    SELECT
        -- Features
        p.PERSON_ID,
        p.PATIENT_LSOA_2011,
        -- Use Other/Mixed/NotStated/Unknown as REFERENCE group
        CASE
            WHEN p.ETHNIC_AIC_CATEGORY IN ('Other', 'Not Stated', 'Mixed') THEN 'Unknown'
            ELSE COALESCE(p.ETHNIC_AIC_CATEGORY, 'Unknown') -- fold into reference
        END as ETHNIC_AIC_CATEGORY,
        -- Use Female/Other as REFERENCE group
        CASE
            WHEN p.GENDER IN ('Male', 'Female') THEN p.GENDER
            ELSE 'Female'  -- fold Other into reference
        END as GENDER,
        ac.AGE_BAND,
        ac.AGE_YEARS,
        -- Normalise IMD rank to 0-1 scale
        CAST(p.LONDON_IMD_RANK as FLOAT) / (
            SELECT MAX(LONDON_IMD_RANK)
            FROM INTELLIGENCE_DEV.AI_CENTRE_FEATURE_STORE.PERSON_NEL_MASTER_INDEX
        ) as NORMALISED_IMD_RANK,
        -- p.IMD_QUINTILE, -- exclude as we are using normalised rank
        -- Phenotype flags (coalesce to 0 for patients without records)
        {phenotype_col}
    FROM INTELLIGENCE_DEV.AI_CENTRE_FEATURE_STORE.PERSON_NEL_MASTER_INDEX p
    INNER JOIN age_categories ac
        ON p.PERSON_ID = ac.PERSON_ID
    LEFT JOIN INTELLIGENCE_DEV.AI_CENTRE_FEATURE_STORE.PERSON_5YEAR_PHENOTYPE ph
        ON p.PERSON_ID = ph.PERSON_ID
    WHERE p.PATIENT_STATUS = 'ACTIVE'
    AND p.LONDON_IMD_RANK IS NOT NULL -- Exclude missing IMD for demo run
    AND p.INCLUDE_IN_LIST_SIZE_FLAG = 1
    """

    try:
        df = snowsesh.execute_query_to_df(query)
        print(f"Retreived columns: {df.columns}")
        return df
    except Exception as e:
        print(f"Error retrieving modeling data: {e}")
        raise  # bare raise keeps the original traceback

def prepare_modeling_data(
        df,
        phenotype_name
        ):
    """
    Builds the modeling frame: dummy-encodes the categorical features, cleans the
    column names for Patsy, and keeps the feature columns plus the outcome column.

    Args:
        df:
            Input DataFrame (left unmodified; a copy is encoded)
        phenotype_name:
            Name of the phenotype, i.e. the outcome column header in ``df``

    Returns:
        Tuple of (modeling DataFrame, metadata dict). The metadata dict holds
        'feature_cols', 'outcome_col' and 'categorical_refs'.
    """
    # Reference level of each categorical feature (its dummy column is left out)
    cat_refs = {
        'ETHNIC_AIC_CATEGORY': 'Unknown',
        'GENDER': 'Female',
        'AGE_BAND': '45-54',  # Middle age as reference
    }

    # Append the indicator columns of every categorical feature
    encoded = df.copy()
    for cat_col, ref_level in cat_refs.items():
        indicator_cols = create_dummies(encoded, cat_col, ref_level)
        encoded = pd.concat([encoded, indicator_cols], axis=1)

    # Patsy-safe column names
    encoded = encoded.rename(columns={name: clean_column_name(name) for name in encoded.columns})

    # Features = every column carrying a categorical prefix, minus the raw
    # categorical columns themselves, plus the continuous IMD rank
    prefixes = tuple(cat_refs.keys())
    feature_cols = [name for name in encoded.columns if name.startswith(prefixes)]
    for cat_col in cat_refs:
        feature_cols.remove(cat_col)
    feature_cols.append('NORMALISED_IMD_RANK') # add other columns of interest here

    # Features and outcome side by side in a single frame
    features = encoded[feature_cols]
    outcome = encoded[phenotype_name]
    model_df = pd.concat([features, outcome], axis=1)

    metadata = {
        'feature_cols': feature_cols,
        'outcome_col': phenotype_name,
        'categorical_refs': cat_refs
    }

    return model_df, metadata

## Fit Models

In [5]:
def fit_logistic_models(
        model_df,
        metadata,
        phenotype_name
        ):
    """
    Fits a logistic regression model with interaction terms and extracts odds ratios.
    Interactions: ethnicity x age; ethnicity x deprivation

    Args:
        model_df:
            Prepared modeling DataFrame
        metadata:
            Dictionary containing feature information (reads 'feature_cols')
        phenotype_name:
            Name of the phenotype/outcome column

    Returns:
        Dictionary with keys:
        - 'model': the fitted statsmodels result
        - 'odds_ratios': DataFrame of exponentiated coefficients with 95% CI bounds
        - 'pseudo_r_squared': the model's pseudo R-squared
        - 'formula': the formula string used for the fit (for debugging)

    Raises:
        Exception: any error raised by the fit is re-raised after printing the formula
    """

    # list features for interactions
    feature_cols = metadata['feature_cols']
    ethnic_cols = [col for col in feature_cols if col.startswith('ETHNIC_AIC_CATEGORY_')]
    age_cols = [col for col in feature_cols if col.startswith('AGE_BAND_')]
    interaction_col = "NORMALISED_IMD_RANK"

    # dynamic formula generation including interaction terms
    # e = ethnic subgroup, a = age band
    interaction_terms = [f'{e}:{a}' for e in ethnic_cols for a in age_cols]
    interaction_terms += [f'{e}:{interaction_col}' for e in ethnic_cols]

    # Join all terms in one pass so an empty group cannot leave a dangling ' + '
    rhs_terms = list(feature_cols) + interaction_terms
    rhs = ' + '.join(rhs_terms) if rhs_terms else '1'  # intercept-only fallback
    formula = f"{phenotype_name} ~ {rhs}"

    # fit!
    try:
        # BFGS was chosen by the original author as more robust to
        # low variance/singular matrix errors
        model = logit(formula, data=model_df).fit(method='bfgs')
    except Exception:
        print(f"Formula that caused error: {formula}")
        raise

    conf_int = model.conf_int()  # columns 0 / 1 hold the lower / upper bounds
    odds_ratios = pd.DataFrame({
        'odds_ratio': np.exp(model.params),
        'lower_ci': np.exp(conf_int[0]),
        'upper_ci': np.exp(conf_int[1])
    })

    return {
        'model': model,
        'odds_ratios': odds_ratios,
        'pseudo_r_squared': model.prsquared,
        'formula': formula # store the formula for debugging
    }

def print_model_summary(model_dict):
    """
    Prints model fit statistics and builds a per-coefficient summary table.

    Args:
        model_dict:
            Dictionary containing the fitted model under the 'model' key.

    Returns:
        DataFrame (rounded to 4 decimals) with coefficient, standard error,
        z-value, p-value, odds ratio with 95% CI bounds and a 'Significance'
        column ('*' p<0.05, '**' p<0.01, '***' p<0.001).
    """
    model = model_dict['model']
    conf_int = model.conf_int()  # columns 0 / 1 hold the lower / upper bounds

    summary_df = pd.DataFrame({
        'Coefficient': model.params,
        'Std Error': model.bse,
        'z-value': model.tvalues,
        'P>|z|': model.pvalues,
        'OR': np.exp(model.params),
        '[0.025': np.exp(conf_int[0]),
        '0.975]': np.exp(conf_int[1])
    })

    # Keep the unrounded p-values: classifying on rounded values mislabels
    # borderline cases (e.g. 0.00096 rounds to 0.001 and would lose its '***').
    exact_p = summary_df['P>|z|'].copy()
    summary_df = summary_df.round(4)

    summary_df['Significance'] = ''
    for threshold, stars in ((0.05, '*'), (0.01, '**'), (0.001, '***')):
        summary_df.loc[exact_p < threshold, 'Significance'] = stars

    print("Model Fit Statistics:")
    print(f"Number of observations: {model.nobs}")
    print(f"Pseudo R-squared: {model.prsquared:.4f}")
    print(f"Log-Likelihood: {model.llf:.4f}")
    print(f"AIC: {model.aic:.4f}")
    print(f"BIC: {model.bic:.4f}")
    print("Likelihood Ratio Test:")
    print(f"Chi2: {model.llr:.4f}")
    print(f"p-value: {model.llr_pvalue:.4f}")

    return summary_df


## Risk Calculations

In [6]:
def calculate_individual_risks(
        df,
        model_dict,
        metadata
        ):
    """
    Scores every row of ``df`` with the fitted model.

    Args:
        df:
            Original dataframe containing patient data
        model_dict:
            Dictionary containing the fitted model under 'model'
        metadata:
            Dictionary containing feature information (reads 'outcome_col')

    Returns:
        Copy of ``df`` with three extra columns: 'predicted_risk' (probability),
        'linear_predictor' (log-odds) and 'predicted_case' (1 when the
        probability is at least 0.5, else 0)
    """
    fitted = model_dict['model']

    # Re-encode the data the same way the training frame was built
    pred_df, _ = prepare_modeling_data(df, metadata['outcome_col'])

    log_odds = fitted.predict(pred_df, linear=True)
    risk = fitted.predict(pred_df)

    # 0.5 is an uncalibrated default threshold (per the original author) and can be adjusted
    return df.assign(
        predicted_risk=risk,
        linear_predictor=log_odds,
        predicted_case=(risk >= 0.5).astype(int),
    )

def analyse_geographic_risk(
        risk_df,
        phenotype_name,
        grouping_level='LSOA'):
    """
    Compares expected (model-predicted) and actual case counts per geographic area.

    Args:
        risk_df:
            DataFrame with individual risks ('predicted_risk'), 'PERSON_ID',
            'PATIENT_LSOA_2011' and the actual phenotype status
        phenotype_name:
            Name of the phenotype column holding the actual status
        grouping_level:
            Geographic level for aggregation (only 'LSOA' is supported)

    Returns:
        DataFrame with one row per area: the raw aggregates plus expected/actual
        cases, population, their difference, a scaled difference, an interval
        around the difference and a 'significant_under_diagnosis' flag

    Raises:
        ValueError: if ``grouping_level`` is anything other than 'LSOA'
    """
    if grouping_level != 'LSOA':
        raise ValueError(f"Unsupported geographic level: {grouping_level}")
    geo_col = 'PATIENT_LSOA_2011'

    aggregations = {
        'predicted_risk': ['count', 'mean', 'sum'],  # sum gives expected number of cases
        phenotype_name: ['sum'],  # actual cases
        'PERSON_ID': 'count'  # = pop size
    }
    geo_analysis = risk_df.groupby(geo_col).agg(aggregations).reset_index()

    # Flatten (column, statistic) pairs to 'column_statistic'; the grouping key
    # has an empty statistic level and keeps its plain name
    geo_analysis.columns = [
        '_'.join(part for part in col if part != '')
        for col in geo_analysis.columns
    ]

    expected = geo_analysis['predicted_risk_sum']
    actual = geo_analysis[f'{phenotype_name}_sum']
    population = geo_analysis['PERSON_ID_count']
    difference = actual - expected

    # NOTE(review): sqrt(population) is used as the scale of the difference. The
    # variance of a sum of independent Bernoulli risks would be sum(p * (1 - p)),
    # which is smaller — confirm this wide interval is intended.
    root_population = np.sqrt(population)
    margin = 1.96 * root_population

    geo_analysis['expected_cases'] = expected
    geo_analysis['actual_cases'] = actual
    geo_analysis['population'] = population
    geo_analysis['case_difference'] = difference
    geo_analysis['standardized_difference'] = difference / root_population
    geo_analysis['difference_ci_lower'] = difference - margin
    geo_analysis['difference_ci_upper'] = difference + margin

    # Under-diagnosis: even the upper bound of (actual - expected) is below zero
    geo_analysis['significant_under_diagnosis'] = geo_analysis['difference_ci_upper'] < 0

    return geo_analysis

def summarise_risk_analysis(
        geo_analysis,
        phenotype_name
        ):
    """
    Rolls the per-area risk analysis up into headline figures.

    Args:
        geo_analysis:
            Output from analyse_geographic_risk
        phenotype_name:
            Name of the phenotype being analyzed

    Returns:
        Dictionary with the phenotype name, population and case totals, the
        number of areas, and the count and percentage of areas flagged as
        significantly under-diagnosed
    """
    flagged = geo_analysis['significant_under_diagnosis']
    expected_total = geo_analysis['expected_cases'].sum()
    difference_total = geo_analysis['case_difference'].sum()

    summary = {}
    summary['phenotype'] = phenotype_name
    summary['total_population'] = int(geo_analysis['population'].sum())
    summary['total_actual_cases'] = int(geo_analysis['actual_cases'].sum())
    summary['total_expected_cases'] = int(round(expected_total))
    summary['total_case_difference'] = int(round(difference_total))
    summary['areas_analyzed'] = len(geo_analysis)
    summary['areas_under_diagnosed'] = int(flagged.sum())
    summary['percent_areas_under_diagnosed'] = round(100 * flagged.mean(), 1)
    return summary

## Effect Modification

In [7]:
def _main_effect_terms(term_names, prefix):
    """Map each main-effect term starting with ``prefix`` to its display label."""
    return {
        term: term.replace(prefix, '').replace('[T.True]', '')
        for term in term_names
        if term.startswith(prefix) and ':' not in term
    }


def _effect_with_bounds(coef, conf_int, term):
    """Return (estimate, lower, upper) for ``term``; all zeros when the term is absent."""
    if term not in coef.index or term not in conf_int.index:
        return 0, 0, 0
    bounds = conf_int.loc[term]
    return coef[term], bounds.iloc[0], bounds.iloc[1]


def create_effect_modification_df(
        model_dict,
        metadata
        ):
    """
    Creates effect modification dataframe from model coefficients showing odds ratios
    across different strata.

    Term names are taken from the fitted model as they are, so the function works
    whether the dummy columns were treated as categorical (names ending in
    '[T.True]') or as numeric (no suffix).

    Args:
        model_dict:
            Dictionary containing the fitted model under 'model'
        metadata:
            Dictionary containing feature information (currently unused)

    Returns:
        DataFrame with columns 'effect_modifier' ('age_band' or 'imd_rank'),
        'modifier_value', 'ethnic_group', 'odds_ratio', 'lower_ci', 'upper_ci'

    Raises:
        KeyError: if ethnic main effects exist but the model has no
            'NORMALISED_IMD_RANK' term
    """
    model = model_dict['model']
    records = []

    # Extract coefficients and confidence intervals
    coef = model.params
    conf_int = model.conf_int()

    # Actual model term -> display label, for ethnicity and age main effects
    ethnic_terms = _main_effect_terms(coef.index, 'ETHNIC_AIC_CATEGORY_')
    age_terms = _main_effect_terms(coef.index, 'AGE_BAND_')

    # NOTE(review): the interval bounds below are the sums of the individual
    # coefficients' bounds, which ignores the covariance between estimates —
    # confirm this approximation is acceptable for the published intervals.

    # Process age band interactions
    for ethnic_term, ethnic in ethnic_terms.items():
        ethnic_effect, ethnic_lower, ethnic_upper = _effect_with_bounds(coef, conf_int, ethnic_term)

        # adjusted effect at each age band = ethnicity + age + interaction
        for age_term, age in age_terms.items():
            age_effect, age_lower, age_upper = _effect_with_bounds(coef, conf_int, age_term)
            inter_effect, inter_lower, inter_upper = _effect_with_bounds(
                coef, conf_int, f"{ethnic_term}:{age_term}")

            records.append({
                'effect_modifier': 'age_band',
                'modifier_value': age,
                'ethnic_group': ethnic,
                'odds_ratio': np.exp(ethnic_effect + age_effect + inter_effect),
                'lower_ci': np.exp(ethnic_lower + age_lower + inter_lower),
                'upper_ci': np.exp(ethnic_upper + age_upper + inter_upper)
            })

    # Process IMD interactions
    imd_values = np.linspace(0, 1, 20)  # Take 20 points across IMD range

    for ethnic_term, ethnic in ethnic_terms.items():
        ethnic_effect, ethnic_lower, ethnic_upper = _effect_with_bounds(coef, conf_int, ethnic_term)

        # The IMD main effect is required: a missing term raises KeyError
        imd_effect = coef['NORMALISED_IMD_RANK']
        imd_ci = conf_int.loc['NORMALISED_IMD_RANK']
        imd_lower, imd_upper = imd_ci.iloc[0], imd_ci.iloc[1]

        inter_effect, inter_lower, inter_upper = _effect_with_bounds(
            coef, conf_int, f"{ethnic_term}:NORMALISED_IMD_RANK")

        for imd in imd_values:
            # imd and interaction effects as log-odds change per unit of imd
            records.append({
                'effect_modifier': 'imd_rank',
                'modifier_value': imd,
                'ethnic_group': ethnic,
                'odds_ratio': np.exp(ethnic_effect + (imd_effect * imd) + (inter_effect * imd)),
                'lower_ci': np.exp(ethnic_lower + (imd_lower * imd) + (inter_lower * imd)),
                'upper_ci': np.exp(ethnic_upper + (imd_upper * imd) + (inter_upper * imd))
            })

    return pd.DataFrame(records)

## Visualisation

In [8]:
## seaborn / matplotlib

def plot_effect_modifications(
        effect_df,
        phenotype_name
        ):
    """
    Draws a two-panel matplotlib figure of odds ratios by ethnic group:
    left panel across age bands, right panel across normalised IMD rank.

    Args:
        effect_df:
            Output of create_effect_modification_df
        phenotype_name:
            Label used in both panel titles

    Returns:
        The matplotlib Figure
    """
    sns.set_theme()
    sns.set_palette("deep")

    fig, (age_ax, imd_ax) = plt.subplots(1, 2, figsize=(15, 6))

    # Left panel: one line (with shaded interval) per ethnic group over age bands
    by_age = effect_df[effect_df['effect_modifier'] == 'age_band']
    for group in by_age['ethnic_group'].unique():
        rows = by_age[by_age['ethnic_group'] == group]
        x_vals = rows['modifier_value']
        age_ax.plot(x_vals, rows['odds_ratio'], marker='o', label=group)
        age_ax.fill_between(x_vals, rows['lower_ci'], rows['upper_ci'], alpha=0.2)

    age_ax.set(
        title=f'Age Effect Modification for {phenotype_name}',
        xlabel='Age Band',
        ylabel='OR',
    )
    age_ax.grid(True, alpha=0.3)
    age_ax.tick_params(axis='x', rotation=45)

    # Right panel: same layout over the numeric IMD rank, sorted along x
    by_imd = effect_df[effect_df['effect_modifier'] == 'imd_rank']
    for group in by_imd['ethnic_group'].unique():
        rows = (
            by_imd[by_imd['ethnic_group'] == group]
            .assign(modifier_value=lambda d: pd.to_numeric(d['modifier_value']))
            .sort_values('modifier_value')
        )
        x_vals = rows['modifier_value']
        imd_ax.plot(x_vals, rows['odds_ratio'], label=group)
        imd_ax.fill_between(x_vals, rows['lower_ci'], rows['upper_ci'], alpha=0.2)

    imd_ax.set(
        title=f'IMD Rank Effect Modification for {phenotype_name}',
        xlabel='Normalized IMD Rank (0 = Most Deprived, 1 = Least Deprived)',
        ylabel='OR',
    )
    imd_ax.grid(True, alpha=0.3)

    # Single shared legend, built from the IMD panel's lines
    handles, labels = imd_ax.get_legend_handles_labels()
    fig.legend(handles, labels, bbox_to_anchor=(1.02, 0.5), loc='center left', title="Ethnic Group")

    plt.tight_layout()

    return fig

## altair - use this code in Streamlit

def plot_effect_modifications_altair(
        effect_df,
        phenotype_name
        ):
    """
    Basic Altair line charts of odds ratios by ethnic group: one across age
    bands and one across IMD rank, placed side by side.

    Args:
        effect_df:
            Output of create_effect_modification_df
        phenotype_name:
            Label used in both chart titles

    Returns:
        The horizontally concatenated Altair chart
    """
    def _line_chart(data, x_encoding, title):
        # Both charts share everything except the data, x encoding and title
        return alt.Chart(data).mark_line().encode(
            x=x_encoding,
            y='odds_ratio:Q',
            color='ethnic_group:N'
        ).properties(
            width=400,
            height=300,
            title=title
        )

    age_rows = effect_df[effect_df['effect_modifier'] == 'age_band']

    # IMD values must be numeric for the quantitative x axis
    imd_rows = effect_df[effect_df['effect_modifier'] == 'imd_rank'].copy()
    imd_rows['modifier_value'] = pd.to_numeric(imd_rows['modifier_value'])

    return alt.hconcat(
        _line_chart(age_rows, 'modifier_value:N', f'Age Effect for {phenotype_name}'),
        _line_chart(imd_rows, 'modifier_value:Q', f'IMD Effect for {phenotype_name}'),
    )


# TEST PIPELINE

In [None]:
# Read variables from a .env file into the environment
# (presumably the Snowflake credentials — confirm)
load_dotenv()

# Open the Snowflake session and point it at the feature store database/schema
snowsesh = SnowflakeConnection()
snowsesh.use_database("INTELLIGENCE_DEV")
snowsesh.use_schema("AI_CENTRE_FEATURE_STORE")

In [None]:
# Get data for a specific phenotype (must be a key of PHENOTYPES_OF_INTEREST)
phenotype_name = "HYPERTENSION"
df = get_modeling_data(snowsesh, phenotype_name)
df.head()

In [None]:
# Prepare data for modelling: dummy-encode the categorical features and keep
# only the feature columns plus the outcome column
model_df, metadata = prepare_modeling_data(df, phenotype_name)
model_df.head()

In [None]:
# Logit model: fit, then build the coefficient summary
# (print_model_summary also prints the fit statistics)
model_results = fit_logistic_models(model_df, metadata, phenotype_name)
model_summary = print_model_summary(model_results)

# Per-coefficient table with odds ratios and significance markers
print("Model Summary:")
print(model_summary)

# Exponentiated coefficients with their confidence bounds
print("Odds Ratios:")
print(model_results['odds_ratios'])

print("Pseudo R-squared:")
print(model_results['pseudo_r_squared'])

# Formula string that was passed to the fit, kept for debugging
print("Formula:")
print(model_results['formula'])

In [None]:
# Score every patient, then aggregate expected vs actual cases per LSOA
risk_df = calculate_individual_risks(df, model_results, metadata)
geo_analysis = analyse_geographic_risk(risk_df, phenotype_name)

# Headline totals across all areas
summary_stats = summarise_risk_analysis(geo_analysis, phenotype_name)

print(f"Summary for {phenotype_name}:")
for key, value in summary_stats.items():
    print(f"{key}: {value}")

In [None]:
# Odds ratios by ethnic group across age bands and IMD rank, from the fitted coefficients
effect_df = create_effect_modification_df(model_results, metadata)
effect_df.info()

In [None]:
# Preview the first effect-modification records
effect_df.head()

In [None]:
# Static (seaborn/matplotlib) effect-modification figure for the phenotype fitted above.
# seaborn, matplotlib and pandas are already imported in the notebook's first cell,
# so they are not re-imported here.
fig = plot_effect_modifications(effect_df, "Hypertension")
plt.show()

In [None]:
# Select Altair's 'default' renderer, then build the two-panel chart;
# the bare last expression displays it as the cell output
alt.renderers.enable('default')
chart = plot_effect_modifications_altair(effect_df, "Hypertension")
chart

# MAIN (LOOP)

In [None]:
def main():
    """
    Runs the full analysis for every phenotype in PHENOTYPES_OF_INTEREST.

    For each phenotype the data is pulled, a model is fitted, the effect
    modification table is built and the geographic risk analysis is run. The
    per-phenotype results are combined into two DataFrames and an effect
    modification figure is shown for each phenotype.

    Uses the module-level ``snowsesh`` Snowflake session, which must already
    exist. Nothing is written to disk or uploaded to Snowflake here; the CSV
    export lines are commented out.

    Returns:
        Tuple of (phenotype_effects, combined_risk_analysis) DataFrames

    Raises:
        Exception: any error during data retrieval, modelling or combining is
            printed and re-raised. Plotting errors are reported per phenotype
            and do not stop the run.
    """
    load_dotenv()

    # Columns of the geographic analysis kept in the combined output
    key_cols = [
        'PATIENT_LSOA_2011',
        'phenotype',
        'population',
        'actual_cases',
        'expected_cases',
        'case_difference',
        'standardized_difference',
        'significant_under_diagnosis'
    ]

    try:
        effect_dfs = []
        risk_dfs = []

        # Process each phenotype
        for phenotype in tqdm(PHENOTYPES_OF_INTEREST.keys(), desc="Processing phenotypes"):
            print(f"Now modelling {phenotype}")

            # Get and process data
            df = get_modeling_data(snowsesh, phenotype)
            model_df, metadata = prepare_modeling_data(df, phenotype)
            model_results = fit_logistic_models(model_df, metadata, phenotype)

            # Generate and store effects with phenotype label
            effect_df = create_effect_modification_df(model_results, metadata)
            effect_df['phenotype'] = phenotype
            effect_dfs.append(effect_df)
            print(f"{phenotype} effects appended")

            # Generate and store predicted risks with phenotype label
            risk_df = calculate_individual_risks(df, model_results, metadata)
            geo_analysis = analyse_geographic_risk(risk_df, phenotype)
            geo_analysis['phenotype'] = phenotype
            risk_dfs.append(geo_analysis[key_cols])

        # Combine all effects
        phenotype_effects = pd.concat(effect_dfs, ignore_index=True)
        #phenotype_effects.to_csv('phenotype_effects.csv', index=False)
        print(f"Combined effects for {len(PHENOTYPES_OF_INTEREST)} phenotypes")

        # Combine all calculated risks
        combined_risk_analysis = pd.concat(risk_dfs, ignore_index=True)
        #combined_risk_analysis.to_csv('risk_analysis.csv', index=False)
        print(f"Combined risk analysis for {len(PHENOTYPES_OF_INTEREST)} phenotypes")

    except Exception as e:
        print(f"Error in main process: {e}")
        raise  # bare raise keeps the original traceback

    # Create effect visualisations (best effort: one failed plot must not lose the results)
    for phenotype in PHENOTYPES_OF_INTEREST.keys():
        try:
            phenotype_data = phenotype_effects[phenotype_effects['phenotype'] == phenotype]
            fig = plot_effect_modifications(phenotype_data, phenotype)
            print(f"SHOWING: {phenotype}")
            plt.show()
            plt.close(fig)  # release the figure; one is created per phenotype
        except Exception as e:
            # Report instead of silently swallowing the failure
            print(f"Could not plot effects for {phenotype}: {e}")

    return phenotype_effects, combined_risk_analysis

if __name__ == "__main__":
    # Keep the combined results available to later cells
    phenotype_effects, combined_risk_analysis = main()

## Save to Snowflake

In [None]:
# snowsesh.load_csv_as_table(
#     csv_path='phenotype_effects.csv',
#     table_name='PHENOTYPE_ADJUSTED_EFFECTS',
#     mode="overwrite"
# )

# snowsesh.load_csv_as_table(
#     csv_path='risk_analysis.csv',
#     table_name='PHENOTYPE_GEOSPATIAL_RISK',
#     mode="overwrite"
# )