# Load Files

In [None]:
# Configuration for loading the admissions, d_hcpcs, d_icd_diagnoses,
# d_icd_procedures, diagnoses_icd, hcpcsevents, procedures_icd, and
# prescriptions tables.

# Per file: the columns to read and the prefix prepended to every column name
# after loading, so columns from different tables stay distinguishable.
file_columns = {
    'hcpcsevents.csv': dict(
        prefix='hcpcs_',
        cols=['subject_id', 'hadm_id', 'hcpcs_cd', 'chartdate', 'short_description'],
    ),
    'd_hcpcs.csv': dict(
        prefix='d_hcpcs_',
        cols=['code', 'category', 'long_description', 'short_description'],
    ),
    'diagnoses_icd.csv': dict(
        prefix='diag_',
        cols=['subject_id', 'hadm_id', 'icd_code', 'icd_version'],
    ),
    'd_icd_diagnoses.csv': dict(
        prefix='d_diag_',
        cols=['icd_code', 'icd_version', 'long_title'],
    ),
    'procedures_icd.csv': dict(
        prefix='proc_',
        cols=['subject_id', 'hadm_id', 'chartdate', 'icd_code', 'icd_version'],
    ),
    'd_icd_procedures.csv': dict(
        prefix='d_proc_',
        cols=['icd_code', 'icd_version', 'long_title'],
    ),
    'admissions.csv': dict(
        prefix='admit_',
        cols=['subject_id', 'hadm_id', 'admittime', 'dischtime'],
    ),
    'prescriptions.csv': dict(
        prefix='presc_',
        cols=['subject_id', 'hadm_id', 'drug', 'starttime'],
    ),
}

# Per file: explicit dtypes for the non-date columns (code-like columns are
# kept as 'object' so they are read as text).
file_dtypes = {
    'hcpcsevents.csv': dict(subject_id='int64', hadm_id='int64',
                            hcpcs_cd='object', short_description='object'),
    'd_hcpcs.csv': dict(code='object', category='object',
                        long_description='object', short_description='object'),
    'diagnoses_icd.csv': dict(subject_id='int64', hadm_id='int64',
                              icd_code='object', icd_version='int64'),
    'd_icd_diagnoses.csv': dict(icd_code='object', icd_version='int64',
                                long_title='object'),
    'procedures_icd.csv': dict(subject_id='int64', hadm_id='int64',
                               icd_code='object', icd_version='int64'),
    'd_icd_procedures.csv': dict(icd_code='object', icd_version='int64',
                                 long_title='object'),
    'admissions.csv': dict(subject_id='int64', hadm_id='int64'),
    'prescriptions.csv': dict(subject_id='int64', hadm_id='int64', drug='object'),
}

# Per file: columns to parse as datetimes. Files without an entry have no
# date columns among the selected ones.
date_columns = {
    'hcpcsevents.csv': ['chartdate'],
    'procedures_icd.csv': ['chartdate'],
    'admissions.csv': ['admittime', 'dischtime'],
    'prescriptions.csv': ['starttime'],
}

In [None]:
import os
import pandas as pd

# Load every configured CSV found in `folder_path` into `dataframe_dict`,
# keyed by file name, with each column renamed to "<prefix><column>".
# NOTE(review): `folder_path` is not defined in this section of the notebook;
# it is assumed to be set in an earlier cell - confirm.
dataframe_dict = {}
csv_files = os.listdir(folder_path)

for file in csv_files:
    if file in file_columns and file in file_dtypes:
        file_info = file_columns[file]
        file_path = os.path.join(folder_path, file)
        dtypes_info = file_dtypes[file]
        parse_dates_info = date_columns.get(file, [])

        try:
            if file in ['procedures_icd.csv', 'diagnoses_icd.csv']:
                # The integer dtypes are not enforced at read time for these two
                # files because 'subject_id' is cleaned below. The text columns
                # (the ICD codes) are still forced to strings: left to dtype
                # inference, an all-numeric code column would be parsed as
                # integers and lose leading zeros, so it would no longer match
                # the string codes in the dictionary tables.
                text_dtypes = {
                    col: dtype
                    for col, dtype in dtypes_info.items()
                    if dtype == 'object' and col in file_info['cols']
                }
                df = pd.read_csv(file_path,
                                 usecols=file_info['cols'],
                                 dtype=text_dtypes,
                                 low_memory=False,
                                 parse_dates=parse_dates_info)

                # Handle missing 'subject_id' by removing rows
                if 'subject_id' in df.columns:
                    df['subject_id'] = pd.to_numeric(df['subject_id'], errors='coerce')
                    df = df.dropna(subset=['subject_id']).copy()

                    # Convert 'subject_id' to int64 after removing NaNs
                    try:
                        df['subject_id'] = df['subject_id'].astype('int64')
                    except ValueError:
                        print(f"Could not convert 'subject_id' to int64 in {file} after dropping NaNs. Possible non-numeric values remaining.")

            else:
                # Read other files with specified dtypes and parse dates
                df = pd.read_csv(file_path,
                                 usecols=file_info['cols'],
                                 dtype=dtypes_info,
                                 parse_dates=parse_dates_info)

            # Prefix every column with the table's prefix so that columns from
            # different tables cannot collide after merging.
            rename_dict = {col: file_info['prefix'] + col for col in df.columns}
            df = df.rename(columns=rename_dict)
            dataframe_dict[file] = df
            print(f"Successfully loaded and processed {file}")

        except FileNotFoundError:
            print(f"File not found: {file_path}")
        except Exception as e:
            # Best-effort loading: report the failure and continue with the
            # remaining files; the summary below lists what is missing.
            print(f"An error occurred while processing {file}: {e}")

# Later cells index dataframe_dict by file name, so surface any configured
# table that is absent (not in the folder, or failed to load) right here
# instead of as a KeyError further down.
not_loaded = [name for name in file_columns if name not in dataframe_dict]
if not_loaded:
    print(f"Warning: the following configured files were not loaded: {not_loaded}")


# Merge descriptions files with their counterparts

## Merge diagnoses_icd and d_icd_diagnoses and Filter to CKD Datapoints

In [None]:
# Diagnosis events and their code dictionary, as loaded above.
diagnoses_icd_df = dataframe_dict['diagnoses_icd.csv']
d_icd_diagnoses_df = dataframe_dict['d_icd_diagnoses.csv']

# Attach the long title to every diagnosis row (matching on code + version),
# discard the duplicated key columns that came from the dictionary side, and
# give the title the same 'diag_' prefix as the rest of the table.
diagnoses_merged_df = (
    diagnoses_icd_df
    .merge(
        d_icd_diagnoses_df,
        how='left',
        left_on=['diag_icd_code', 'diag_icd_version'],
        right_on=['d_diag_icd_code', 'd_diag_icd_version'],
    )
    .drop(columns=['d_diag_icd_code', 'd_diag_icd_version'])
    .rename(columns={'d_diag_long_title': 'diag_long_title'})
)

print("Head of the merged diagnoses_merged_df:")
display(diagnoses_merged_df.head())

print("\nInfo of the merged diagnoses_merged_df:")
diagnoses_merged_df.info()

In [None]:
# ICD codes (both ICD-9 and ICD-10 style) that define the CKD cohort.
# The match below is on the code string only; icd_version is not consulted.
ckd_icd_codes = [
    '40300', '40301', '40310', '40311', '40390', '40391',
    '40400', '40401', '40402', '40403', '40410', '40411', '40412', '40413',
    '40490', '40491', '40492', '40493',
    '5851', '5852', '5853', '5854', '5855', '5859',
    'D631',
    'E0822', 'E0922', 'E1022', 'E1122', 'E1322',
    'I12', 'I120', 'I129', 'I13', 'I130', 'I131', 'I1310', 'I1311', 'I132',
    'N18', 'N181', 'N182', 'N183', 'N1830', 'N1831', 'N1832', 'N184', 'N185', 'N189',
    'O102', 'O1021', 'O10211', 'O10212', 'O10213', 'O10219', 'O1022', 'O1023',
    'O103', 'O1031', 'O10311', 'O10312', 'O10313', 'O10319', 'O1032', 'O1033',
]

# Keep only the diagnosis rows whose code is in the CKD list.
is_ckd_row = diagnoses_merged_df['diag_icd_code'].isin(ckd_icd_codes)
filtered_diagnoses_merged_df = diagnoses_merged_df.loc[is_ckd_row].copy()

print("Head of filtered_diagnoses_merged_df:")
display(filtered_diagnoses_merged_df.head())

print("\nInfo of filtered_diagnoses_merged_df:")
filtered_diagnoses_merged_df.info()

# Release the names of the unfiltered intermediates.
del diagnoses_icd_df, d_icd_diagnoses_df, diagnoses_merged_df

## Merge hcpcsevents and d_hcpcs

In [None]:
# HCPCS events and the HCPCS code dictionary, as loaded above.
hcpcsevents_df = dataframe_dict['hcpcsevents.csv']
d_hcpcs_df = dataframe_dict['d_hcpcs.csv']

# Look up each event's code in the dictionary, then:
#   - drop the dictionary's copy of the code and the event table's own short
#     description (the dictionary's short description is the one kept),
#   - rename the surviving dictionary columns to the 'hcpcs_' prefix and
#     collapse the doubled 'hcpcs_hcpcs_cd' name.
hcpcsevents_merged_df = (
    hcpcsevents_df
    .merge(d_hcpcs_df, how='left', left_on='hcpcs_hcpcs_cd', right_on='d_hcpcs_code')
    .drop(columns=['d_hcpcs_code', 'hcpcs_short_description'])
    .rename(columns={
        'd_hcpcs_category': 'hcpcs_category',
        'd_hcpcs_long_description': 'hcpcs_long_description',
        'd_hcpcs_short_description': 'hcpcs_short_description',
        'hcpcs_hcpcs_cd': 'hcpcs_cd',
    })
)

# Preferred column layout; any column not actually present is skipped rather
# than raising.
desired_column_order = [
    'hcpcs_subject_id',
    'hcpcs_hadm_id',
    'hcpcs_chartdate',
    'hcpcs_cd',
    'hcpcs_category',
    'hcpcs_short_description',
    'hcpcs_long_description'
]
existing_columns_in_order = [
    name for name in desired_column_order if name in hcpcsevents_merged_df.columns
]
hcpcsevents_merged_df = hcpcsevents_merged_df[existing_columns_in_order]

print("Head of the merged and cleaned hcpcsevents_merged_df:")
display(hcpcsevents_merged_df.head())

print("\nInfo of the merged and cleaned hcpcsevents_merged_df:")
hcpcsevents_merged_df.info()

# Release the names of the pre-merge inputs.
del hcpcsevents_df, d_hcpcs_df

## Merge procedures_icd and d_icd_procedures

In [None]:
# Procedure events and the procedure code dictionary, as loaded above.
procedures_icd_df = dataframe_dict['procedures_icd.csv']
d_icd_procedures_df = dataframe_dict['d_icd_procedures.csv']

# Attach the long title to every procedure row (matching on code + version),
# drop the dictionary's duplicate key columns, and rename the title to the
# 'proc_' prefix.
procedures_merged_df = (
    procedures_icd_df
    .merge(
        d_icd_procedures_df,
        how='left',
        left_on=['proc_icd_code', 'proc_icd_version'],
        right_on=['d_proc_icd_code', 'd_proc_icd_version'],
    )
    .drop(columns=['d_proc_icd_code', 'd_proc_icd_version'])
    .rename(columns={'d_proc_long_title': 'proc_long_title'})
)

# Fixed column layout; unlike the HCPCS cell, every listed column must exist.
desired_column_order = [
    'proc_subject_id',
    'proc_hadm_id',
    'proc_chartdate',
    'proc_icd_code',
    'proc_icd_version',
    'proc_long_title'
]
procedures_merged_df = procedures_merged_df[desired_column_order]

print("Head of the merged procedures_merged_df:")
display(procedures_merged_df.head())

print("\nInfo of the merged procedures_merged_df:")
procedures_merged_df.info()

# Release the names of the pre-merge inputs.
del procedures_icd_df, d_icd_procedures_df

# Merge Files with Sample

In [None]:
def _align_to_ckd_admissions(admission_keys, table, to_diag_names):
    """Left-join `table` onto the CKD (subject, admission) pairs.

    `to_diag_names` maps the table's own subject/admission column names to
    'diag_subject_id' / 'diag_hadm_id' so the join can use shared key names;
    the original names are restored on the result. Because the join is a left
    join from `admission_keys`, pairs with no match in `table` are kept with
    missing values in the table's columns.
    """
    back_to_table_names = {diag: own for own, diag in to_diag_names.items()}
    aligned = pd.merge(
        admission_keys,
        table.rename(columns=to_diag_names),
        on=['diag_subject_id', 'diag_hadm_id'],
        how='left',
    )
    return aligned.rename(columns=back_to_table_names)


# Inputs: the two raw tables plus working copies of the two merged tables.
admissions_df = dataframe_dict['admissions.csv']
prescriptions_df = dataframe_dict['prescriptions.csv']
hcpcs_events_merged_df_copy = hcpcsevents_merged_df.copy()
procedures_merged_df_copy = procedures_merged_df.copy()

# One row per (subject, admission) that carries a CKD diagnosis.
unique_ids = filtered_diagnoses_merged_df[['diag_subject_id', 'diag_hadm_id']].drop_duplicates()

# Admissions restricted to the CKD admissions.
admissions_df_filtered = _align_to_ckd_admissions(
    unique_ids, admissions_df,
    {'admit_subject_id': 'diag_subject_id', 'admit_hadm_id': 'diag_hadm_id'},
)
del admissions_df  # drops this name only; dataframe_dict still holds the table

# Prescriptions restricted to the CKD admissions.
prescriptions_df_filtered = _align_to_ckd_admissions(
    unique_ids, prescriptions_df,
    {'presc_subject_id': 'diag_subject_id', 'presc_hadm_id': 'diag_hadm_id'},
)
del prescriptions_df  # drops this name only; dataframe_dict still holds the table

# HCPCS events restricted to the CKD admissions.
# hcpcsevents_merged_df itself is intentionally left defined.
hcpcsevents_merged_df_filtered = _align_to_ckd_admissions(
    unique_ids, hcpcs_events_merged_df_copy,
    {'hcpcs_subject_id': 'diag_subject_id', 'hcpcs_hadm_id': 'diag_hadm_id'},
)

# Procedures restricted to the CKD admissions.
# procedures_merged_df itself is intentionally left defined.
procedures_merged_df_filtered = _align_to_ckd_admissions(
    unique_ids, procedures_merged_df_copy,
    {'proc_subject_id': 'diag_subject_id', 'proc_hadm_id': 'diag_hadm_id'},
)

In [None]:
# The CKD diagnosis rows are the base table that everything is joined onto.
merged_df = filtered_diagnoses_merged_df

# Add admission times: join on (subject, admission), then discard the
# admissions table's copy of the key columns.
merged_df = (
    merged_df
    .merge(
        admissions_df_filtered,
        how='left',
        left_on=['diag_subject_id', 'diag_hadm_id'],
        right_on=['admit_subject_id', 'admit_hadm_id'],
    )
    .drop(columns=['admit_subject_id', 'admit_hadm_id'])
)
del admissions_df_filtered

# Add prescriptions the same way.
# NOTE(review): prescriptions presumably have many rows per admission, so this
# join repeats each diagnosis row once per prescription - confirm that the
# downstream counts are meant to be computed on the expanded rows.
merged_df = (
    merged_df
    .merge(
        prescriptions_df_filtered,
        how='left',
        left_on=['diag_subject_id', 'diag_hadm_id'],
        right_on=['presc_subject_id', 'presc_hadm_id'],
    )
    .drop(columns=['presc_subject_id', 'presc_hadm_id'])
)
del prescriptions_df_filtered

# Inspect the intermediate result.
print("Head of the merged dataframe after adding admissions and prescriptions:")
display(merged_df.head())

print("\nInfo of the merged dataframe after adding admissions and prescriptions:")
merged_df.info()

In [None]:
# Add HCPCS events: join on (subject, admission), then discard the HCPCS
# table's copy of the key columns.
merged_df = (
    merged_df
    .merge(
        hcpcsevents_merged_df_filtered,
        how='left',
        left_on=['diag_subject_id', 'diag_hadm_id'],
        right_on=['hcpcs_subject_id', 'hcpcs_hadm_id'],
    )
    .drop(columns=['hcpcs_subject_id', 'hcpcs_hadm_id'])
)
del hcpcsevents_merged_df_filtered

# Inspect the intermediate result.
print("Head of the merged dataframe after adding hcpcsevents:")
display(merged_df.head())

print("\nInfo of the merged dataframe after adding hcpcsevents:")
merged_df.info()

In [None]:
# Add procedures: join on (subject, admission), then discard the procedure
# table's copy of the key columns.
merged_df = (
    merged_df
    .merge(
        procedures_merged_df_filtered,
        how='left',
        left_on=['diag_subject_id', 'diag_hadm_id'],
        right_on=['proc_subject_id', 'proc_hadm_id'],
    )
    .drop(columns=['proc_subject_id', 'proc_hadm_id'])
)
del procedures_merged_df_filtered

# Inspect the fully merged result.
print("Head of the final merged dataframe with sample:")
display(merged_df.head())

print("\nInfo of the final merged dataframe with sample:")
merged_df.info()

In [None]:
# The key and admission-time columns lose their table prefix, since from here
# on they describe the merged row rather than one source table.
column_renames = {
    "diag_subject_id": "subject_id",
    "diag_hadm_id": "hadm_id",
    "admit_admittime": "admittime",
    "admit_dischtime": "dischtime"
}
merged_df = merged_df.rename(columns=column_renames)

# Preferred layout: identifiers, admission times, diagnosis, HCPCS, procedure,
# prescription.
desired_column_order = [
    "subject_id",
    "hadm_id",
    "admittime",
    "dischtime",
    "diag_icd_code",
    "diag_icd_version",
    "diag_long_title",
    "hcpcs_chartdate",
    "hcpcs_cd",
    "hcpcs_category",
    "hcpcs_short_description",
    "hcpcs_long_description",
    "proc_chartdate",
    "proc_icd_code",
    "proc_icd_version",
    "proc_long_title",
    "presc_drug",
    "presc_starttime"
]

# Only columns that are actually present are selected, so a missing column is
# skipped instead of raising. Columns not named above are dropped.
existing_columns_in_order = [
    name for name in desired_column_order if name in merged_df.columns
]
merged_df = merged_df[existing_columns_in_order]

print("Head of the dataframe after renaming and reordering columns:")
display(merged_df.head())

print("\nInfo of the dataframe after renaming and reordering columns:")
merged_df.info()

# EDA

## Merged_df Descriptive Statistics & Missing Values

In [None]:
# First rows, for a quick visual check of the merged table.
print("Head of the merged_df:")
display(merged_df.head())

# Column dtypes and non-null counts.
print("\nInfo of the merged_df:")
merged_df.info()

# Summary statistics as produced by DataFrame.describe().
numeric_summary = merged_df.describe()
print("\nDescriptive statistics for numerical columns:")
display(numeric_summary)

# Number of missing entries in each column.
missing_per_column = merged_df.isna().sum()
print("\nMissing values per column:")
display(missing_per_column)

## Explore Categorical Columns and Visualize Distributions

In [None]:
def _show_top20(heading, column):
    """Print `heading`, then display the 20 most frequent values of merged_df[column]."""
    print(heading)
    display(merged_df[column].value_counts().head(20))


_show_top20("\nValue counts for 'diag_icd_code' (Top 20):", 'diag_icd_code')
_show_top20("\nValue counts for 'diag_long_title' (Top 20):", 'diag_long_title')
_show_top20("\nValue counts for 'presc_drug' (Top 20):", 'presc_drug')

In [None]:
# Visualize distributions of the main categorical columns.

import matplotlib.pyplot as plt
import seaborn as sns


def _plot_top10(column, title, xlabel):
    """Draw a bar chart of the 10 most frequent values of merged_df[column].

    Returns the plotted counts (a Series indexed by value) so the caller can
    keep them under a notebook-level name.
    """
    plt.figure(figsize=(12, 6))
    counts = merged_df[column].value_counts().nlargest(10)
    sns.barplot(x=counts.index, y=counts.values)
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel('Count')
    plt.xticks(rotation=45, ha='right')  # long labels would otherwise overlap
    plt.tight_layout()
    plt.show()
    return counts


# How many rows carry each ICD version.
plt.figure(figsize=(8, 5))
sns.countplot(data=merged_df, x='diag_icd_version')
plt.title('Distribution of diag_icd_version')
plt.xlabel('ICD Version')
plt.ylabel('Count')
plt.show()

# Ten most frequent values of each code/drug column.
top_codes = _plot_top10('diag_icd_code', 'Top 10 diag_icd_code Counts', 'ICD Code')
top_drugs = _plot_top10('presc_drug', 'Top 10 presc_drug Counts', 'Drug Name')
top_hcpcs = _plot_top10('hcpcs_cd', 'Top 10 hcpcs_cd Counts', 'HCPCS Code')
top_procs = _plot_top10('proc_icd_code', 'Top 10 proc_icd_code Counts', 'Procedure ICD Code')


## Explore Relationships Between Variables

In [None]:
def _pair_counts(first_column, second_column):
    """Count merged_df rows for every observed (first_column, second_column) pair."""
    return merged_df.groupby([first_column, second_column]).size().reset_index(name='count')


def _show_top10(pair_counts):
    """Display the 10 pairs with the highest row count."""
    display(pair_counts.sort_values(by='count', ascending=False).head(10))


# Diagnosis code vs. prescribed drug
print("\nRelationship between diag_icd_code and presc_drug (Top 10 combinations):")
diag_drug_counts = _pair_counts('diag_icd_code', 'presc_drug')
_show_top10(diag_drug_counts)

# Diagnosis code vs. HCPCS code
print("\nRelationship between diag_icd_code and hcpcs_cd (Top 10 combinations):")
diag_hcpcs_counts = _pair_counts('diag_icd_code', 'hcpcs_cd')
_show_top10(diag_hcpcs_counts)

# Diagnosis code vs. procedure code
print("\nRelationship between diag_icd_code and proc_icd_code (Top 10 combinations):")
diag_proc_counts = _pair_counts('diag_icd_code', 'proc_icd_code')
_show_top10(diag_proc_counts)

# Prescribed drug vs. HCPCS code
print("\nRelationship between presc_drug and hcpcs_cd (Top 10 combinations):")
drug_hcpcs_counts = _pair_counts('presc_drug', 'hcpcs_cd')
_show_top10(drug_hcpcs_counts)

# Prescribed drug vs. procedure code
print("\nRelationship between presc_drug and proc_icd_code (Top 10 combinations):")
drug_proc_counts = _pair_counts('presc_drug', 'proc_icd_code')
_show_top10(drug_proc_counts)


# CKD Guideline Recommended Treatment by Category & Dialysis ICD Codes

In [None]:
# CKD guideline-recommended treatment (GRT) drug names, grouped by drug class.
# Each list holds the spellings to recognise in the prescriptions table
# (generic names, brand names, and a few free-text variants). Later cells
# lower-case both sides before matching, so the case variants listed here are
# redundant but harmless.
#
# Fix: the ACE inhibitor list previously contained "Tramadol",
# "Tramadol Sodium" and "Tramadol-hydrochlorothiazide" next to trandolapril.
# Tramadol is an opioid analgesic, not an ACE inhibitor, so every tramadol
# prescription was being counted as guideline-recommended therapy. Those
# entries are removed and trandolapril's brand name (Mavik) is listed instead.
ckd_grt = {
    "SGLT2 inhibitor": ["empagliflozin", "Empagliflozin", "jardiance", "Jardiance", "Jardiance (empagliflozin)", "Jardiance (Empaglifozin)",
                        "Jardiance 10mg", "Jardiance(empagliflozin)", "Jardiance(Empagliflozin)", "empagliflozin-linagliptin (GLYXAMBI)", "INV-Empagliflozin",
                        "dapagliflozin", "Dapagliflozin", "farxiga", "Farxiga", "linagliptin-empagliflozin (GLYXAMBI)",
                        "canagliflozin", "Canagliflozin", "invokana", "Invokana", "invokana (canagliflozin)",
                        "ertugliflozin", "Ertugliflozin", "Steglatro", "steglatro", "steglatro (ertugliflozin)",
                        "bexagliflozin", "Bexagliflozin", "Brenzavvy", "brenzavvy", "brenzavvy (bexagliflozin)",
                        "sotagliflozin", "Sotagliflozin", "Inpefa", "inpefa", "inpefa (sotagliflozin)"],
    "ACE inhibitor": ["lisinopril", "Lisinopril", "lisinopril-hydrochlorothiazide", "qbrelis", "Qbrelis", "zestril", "Zestril", "Zestril (lisinopril)",
                      "ramipril", "Ramipril", "Altace", "altace",
                      "captopril", "Captopril", "Capoten", "capoten", "Capoten (captopril)", "captopril-hydrochlorothiazide",
                      "perindopril", "Perindopril", "Perindopril Sodium", "Perindopril-hydrochlorothiazide", "aceon", "Aceon", "coversyl", "Coversyl", "coversum", "Coversum", "prexum", "Prexum", "prestarium", "Prestarium",
                      "benazepril", "Benazepril", "Lotensin", "lotensin", "Lotensin Sodium", "Lotensin-hydrochlorothiazide", "Lotensin (benazepril)",
                      "fosinopril", "Fosinopril", "Fosinopril Sodium", "Fosinopril-hydrochlorothiazide", "Fosinopril (fosinopril sodium)",
                      "moexipril", "Moexipril", "Moexipril Sodium", "Moexipril-hydrochlorothiazide", "Moexipril (moexipril sodium)",
                      "quinapril", "Quinapril", "Quinapril Maleate", "Quinapril Sodium", "Quinapril Sodium (quinapril)", "INV-Quinapril",
                      "trandolapril", "Trandolapril", "Mavik", "mavik",
                      "enalapril", "Enalapril", "Enalapril Maleate", "enalapril maleate", "Enalaprilat", "INV-Enalapril", "Vasotec (enalapril)", "vasotec"],
    "ARBs": ["losartan", "Losartan", "Losartan Potassium", "Cozaar", "cozaar",
             "Hyzaar", "losartan-hydrochlorothiazide",
             "azilsartan", "Azilsartan", "Edarbi", "edarbi",
             "candesartan", "Candesartan", "Atacand", "atacand", "atacand (candesartan)",
             "irbesartan", "Irbesartan", "Avapro", "avapro",
             "olmesartan", "Olmesartan", "Benicar", "benicar",
             "telmisartan", "Telmisartan", "Micardis", "micardis",
             "valsartan", "Valsartan", "Diovan", "diovan", "valsartan sodium",
             "eprosartan", "Eprosartan", "Teveten", "teveten"],
    "Statin": ["Atorvastatin", "atorvastatin", "atorvastatin", "atorvastatin 40 mg",
               "Atorvastatin 40 mg", "Atorvastatin 40 mg", "atorvastatin 40 mg cap",
               "atorvastatin 40 mg capsule", "Atorvastatin 40 mg capsule", "atorvastatin 40 mg capsule",
               "Atorvastatin 40mg", "Atorvastatin 40mg Tab", "INV-Atorvastatin", "Lipitor", "lipitor",
               "Fluvastatin", "fluvastatin", "Fluvastatin Sodium",  "Lescol XL (fluvastatin XL)", "Lescol XL", "Lescol",
               "Lovastatin", "lovastatin", "Altoprev", "Lovastatin 10mg", "Lovastatin 20mg", "Mevacor",
               "pitavastatin", "pitavastatin calcium", "pitavastatin calcium (Livalo)", "Livalo (pitavastatin calcium)", "Livalo", "Zypitamag",
               "Pravastatin", "pravastatin", "Pravastatin Sodium", "Pravastatin Sodium 10mg",
               "rosuvastatin", "Rosuvastatin", "Rosuvastatin Calcium", "Rosuvastatin Calcium 10mg", "Crestor", "crestor",
               "Simvastatin", "simvastatin", "SImvastatin", "simvastatin", "Simvastatin", "simvastatin 40 mg", "Simvastatin 40 mg", "Zocor", "zocor"]
}

# Echo the category -> drug-name lookup so the lists can be reviewed in the
# notebook output.
print("Defined CKD Guideline Recommended Treatment drug categories:")
display(ckd_grt)

In [None]:
# Codes treated as dialysis events. The descriptions below are carried over
# from the original author's note:
#   ICD-9 procedure codes:  39.95 (Hemodialysis), 39.99 (Other extracorporeal procedures)
#   ICD-10 procedure codes: 5A1D000 (Hemodialysis), 5A1900W (Continuous renal replacement therapy [CRRT])
#   HCPCS codes:            G0490-G0499 (ESRD related services), 90935-90970 (Hemodialysis services)
# The lists also contain '5498' and 'G0257', which that note does not describe.
# NOTE(review): '5A1D000' and '5A1900W' do not look like the usual ICD-10-PCS
# dialysis codes (5A1D..Z) - verify every code against the d_icd_procedures /
# d_hcpcs tables before relying on the event definition.
dialysis_proc_icd_codes = [
    '3995',
    '3999',
    '5A1D000',
    '5A1900W',
    '5498',
]
dialysis_hcpcs_codes = [
    *(f'G049{i}' for i in range(10)),          # G0490 ... G0499
    *(f'909{i:02d}' for i in range(35, 71)),   # 90935 ... 90970
    'G0257',
]

In [None]:
# Grouping of the CKD-related ICD codes into analysis categories.
#
# Fixes relative to the previous version:
#   * '5852' (ICD-9 585.2, CKD Stage II) was listed under 'Stage 3 CKD'; it is
#     now under 'Stage 2 CKD'.
#   * 'N1831' (Stage 3a) was missing although 'N1830' and 'N1832' were mapped.
#   * The hypertensive heart-and-kidney codes '40400', '40410' (CKD stage I-IV)
#     and '40402', '40412', '40413' (CKD stage V / end stage) are part of the
#     cohort filter but were unmapped. They are added following the same
#     fifth-digit pattern as the 404x codes that were already mapped
#     (0/1 -> stages I-IV, 2/3 -> stage V / end stage).
# Codes without a category receive NaN in 'grouped_diagnosis', so a missing
# entry here removes those rows wherever NaN groups are dropped later.
# Codes from the cohort filter that are deliberately still unmapped: the
# parent codes 'I12', 'I13', 'I131', 'N18', plus 'D631' and the 'O10...' codes.
diagnosis_mapping = {
    'Hypertensive CKD (Stages 1-4)': ['I129', 'I130', '40390', 'I1310', '40310', '40491', '40300', '40490', '40411', '40401',
                                      '40400', '40410'],
    'Hypertensive CKD (Stages 5-End Stage)': ['40311', 'I120', '40391', 'I132', '40301', 'I1311', '40493', '40492', '40403',
                                              '40402', '40412', '40413'],
    'Diabetes with CKD': ['E1122', 'E1022', 'E0922', 'E1322', 'E0822'],
    'Stage 1 CKD': ['5851', 'N181'],
    'Stage 2 CKD': ['N182', '5852'],
    'Stage 3 CKD': ['N183', 'N1830', '5853', 'N1831', 'N1832'],
    'Stage 4 CKD': ['5854', 'N184'],
    'Stage 5 CKD': ['5855', 'N185'],
    'Unspecified CKD': ['N189', '5859']
}

# Reverse lookup: ICD code -> category. A code listed under two categories
# would silently take whichever category comes last, so refuse that outright.
code_to_category = {}
for category, codes in diagnosis_mapping.items():
    for code in codes:
        if code in code_to_category:
            raise ValueError(
                f"ICD code {code!r} is assigned to both "
                f"{code_to_category[code]!r} and {category!r}"
            )
        code_to_category[code] = category

# Translate every diagnosis code into its category; codes absent from
# code_to_category come out as NaN.
merged_df['grouped_diagnosis'] = merged_df['diag_icd_code'].map(code_to_category)

# Report the distinct codes that did not receive a category.
has_no_group = merged_df['grouped_diagnosis'].isna()
unmapped_codes = merged_df.loc[has_no_group, 'diag_icd_code'].unique()
if unmapped_codes.size > 0:
    print(f"\nWarning: The following diag_icd_codes were not found in the provided mapping and will have NaN in 'grouped_diagnosis': {list(unmapped_codes)}")

# Inspect the result.
print("\nMerged_df with new grouped_diagnosis column:")
display(merged_df.head())
merged_df.info()

# Number of rows in each category.
print("\nValue counts for the new grouped_diagnosis column in merged_df:")
display(merged_df['grouped_diagnosis'].value_counts())

# Survival Analysis Data Preparation


## Data preparation for survival analysis

Prepare the `merged_df` for survival analysis by defining the event (dialysis) and the time to event (time from treatment start to dialysis or end of observation). This will involve identifying the first occurrence of a CKD guideline-recommended treatment and the first occurrence of a dialysis event for each subject.


In [None]:
# Build `survival_analysis_df`: one row per (subject, grouped_diagnosis) with
# the first guideline-recommended treatment date, the first dialysis date (or
# the censoring date), the time between them in days, and an event indicator.

# 1. Flag rows whose procedure code or HCPCS code is a dialysis code.
merged_df['is_dialysis'] = 0
merged_df.loc[merged_df['proc_icd_code'].isin(dialysis_proc_icd_codes), 'is_dialysis'] = 1
merged_df.loc[merged_df['hcpcs_cd'].isin(dialysis_hcpcs_codes), 'is_dialysis'] = 1

print("Added 'is_dialysis' column to merged_df:")
display(merged_df[['proc_icd_code', 'hcpcs_cd', 'is_dialysis']].head())


# 2. Flatten the per-class drug lists into one list of all CKD GRT drug names.
ckd_grt_drugs = [drug for sublist in ckd_grt.values() for drug in sublist]

# 3. Keep only prescriptions of CKD GRT drugs (case-insensitive exact match).
# The full prescriptions table from dataframe_dict is used, not the
# admission-filtered one.
prescriptions_df = dataframe_dict['prescriptions.csv']
ckd_grt_prescriptions = prescriptions_df[prescriptions_df['presc_drug'].str.lower().isin([d.lower() for d in ckd_grt_drugs])].copy()


# 4. Earliest CKD GRT prescription start per subject.
treatment_start_times = ckd_grt_prescriptions.groupby('presc_subject_id')['presc_starttime'].min().reset_index()
treatment_start_times = treatment_start_times.rename(columns={'presc_subject_id': 'subject_id', 'presc_starttime': 'treatment_starttime'})

print("\nEarliest treatment start time for each subject:")
display(treatment_start_times.head())


# 5. Rows that carry at least one dialysis code.
dialysis_events_df = merged_df[merged_df['is_dialysis'] == 1].copy()

# 6. Date of the dialysis event on each row.
# Ensure chartdate columns are datetime objects.
dialysis_events_df['proc_chartdate'] = pd.to_datetime(dialysis_events_df['proc_chartdate'])
dialysis_events_df['hcpcs_chartdate'] = pd.to_datetime(dialysis_events_df['hcpcs_chartdate'])
# A row can be flagged through only one of its two codes while the other code
# belongs to an unrelated event. Only a chart date whose OWN code is a dialysis
# code may contribute; taking the plain minimum of both dates could pick up the
# unrelated event's date as the dialysis date.
proc_dialysis_dates = dialysis_events_df['proc_chartdate'].where(
    dialysis_events_df['proc_icd_code'].isin(dialysis_proc_icd_codes)
)
hcpcs_dialysis_dates = dialysis_events_df['hcpcs_chartdate'].where(
    dialysis_events_df['hcpcs_cd'].isin(dialysis_hcpcs_codes)
)
dialysis_events_df['event_date'] = pd.concat(
    [proc_dialysis_dates, hcpcs_dialysis_dates], axis=1
).min(axis=1)


# 7. Earliest dialysis date per subject.
dialysis_event_dates = dialysis_events_df.groupby('subject_id')['event_date'].min().reset_index()
dialysis_event_dates = dialysis_event_dates.rename(columns={'event_date': 'dialysis_event_date'})


print("\nEarliest dialysis event date for each subject:")
display(dialysis_event_dates.head())

# 8. Subject-level frame: every subject with a CKD diagnosis
# (filtered_diagnoses_merged_df still uses the 'diag_subject_id' name).
subject_survival_df = filtered_diagnoses_merged_df[['diag_subject_id']].drop_duplicates().rename(columns={'diag_subject_id': 'subject_id'})


# 9. Attach the treatment start time.
subject_survival_df = pd.merge(subject_survival_df, treatment_start_times, on='subject_id', how='left')

# 10. Attach the first dialysis date.
subject_survival_df = pd.merge(subject_survival_df, dialysis_event_dates, on='subject_id', how='left')

print("\nSubject-level survival DataFrame before calculating time to event:")
display(subject_survival_df.head())
subject_survival_df.info()

# 11. End of the observation period, used as the censoring date.
# NOTE(review): one global date (the latest admittime in merged_df) is applied
# to every subject. If the source data shifts dates independently per subject,
# as de-identified hospital extracts commonly do, a per-subject date such as
# the subject's last dischtime would be the appropriate choice - confirm.
end_of_observation = merged_df['admittime'].max() # Using admittime as a proxy for study end

# 12. Remember who actually has a dialysis date BEFORE filling in the censoring
# date. Deriving the event flag afterwards by comparing against
# end_of_observation would mislabel a real dialysis event that happens to fall
# on that date as censored.
had_dialysis = subject_survival_df['dialysis_event_date'].notna()

# 13. Censored subjects get the end-of-observation date.
subject_survival_df['dialysis_event_date'] = subject_survival_df['dialysis_event_date'].fillna(end_of_observation)

# 14. Time from treatment start to dialysis / censoring, in whole days.
subject_survival_df['time_to_event'] = (subject_survival_df['dialysis_event_date'] - subject_survival_df['treatment_starttime']).dt.days

# 15. Event indicator: 1 = dialysis observed, 0 = censored.
subject_survival_df['event'] = had_dialysis.astype(int)

# 16. Drop subjects without a treatment start (no GRT prescription found).
subject_survival_df = subject_survival_df.dropna(subset=['treatment_starttime']).copy()

# A negative duration means the dialysis (or censoring) date precedes the first
# GRT prescription. The rows are kept, but the count is reported so they are
# not fed into a survival model unnoticed.
negative_duration_count = int((subject_survival_df['time_to_event'] < 0).sum())
if negative_duration_count:
    print(f"\nWarning: {negative_duration_count} subject(s) have a negative time_to_event "
          "(dialysis or censoring date earlier than the first guideline-recommended prescription).")


print("\nSubject-level survival DataFrame with time_to_event and event:")
display(subject_survival_df.head())
subject_survival_df.info()


# 17. Final column selection for the survival analysis.
survival_analysis_df = subject_survival_df[['subject_id', 'treatment_starttime', 'dialysis_event_date', 'time_to_event', 'event']].copy()

# 18. Distinct (subject, grouped_diagnosis) pairs from merged_df.
subject_grouped_diagnoses = merged_df[['subject_id', 'grouped_diagnosis']].drop_duplicates()


# 19. Attach the grouped diagnoses. A subject with several distinct groups
# gets one row per group.
survival_analysis_df = pd.merge(survival_analysis_df, subject_grouped_diagnoses, on='subject_id', how='left')

# 20. Drop rows without a grouped diagnosis.
survival_analysis_df = survival_analysis_df.dropna(subset=['grouped_diagnosis']).copy()

# 21. Release the names of the intermediate dataframes.
del treatment_start_times
del dialysis_event_dates
del subject_survival_df
del ckd_grt_prescriptions
del dialysis_events_df
del subject_grouped_diagnoses
del prescriptions_df

## Feature engineering
Create features based on the CKD guideline-recommended treatments, considering whether a patient is on *any* such treatment, or perhaps indicators for specific classes of treatments (SGLT2 inhibitors, ACE inhibitors, etc.).


In [None]:
# 1. "Received ANY CKD guideline-recommended treatment (GRT)" needs no explicit column:
# rows without a 'treatment_starttime' were dropped when survival_analysis_df was built,
# so every remaining row belongs to a subject with at least one GRT.

# 2. For each GRT category, flag subjects who have at least one prescription from that
# category on or after their treatment_starttime.

# All GRT drug names, lower-cased for case-insensitive matching
ckd_grt_drugs_lower = [drug.lower() for sublist in ckd_grt.values() for drug in sublist]

# merged_df is only read here, so it is filtered directly (no deep copy of the whole
# frame) and only the columns this cell needs are kept.
ckd_grt_prescriptions_temp = merged_df.loc[
    merged_df['presc_drug'].str.lower().isin(ckd_grt_drugs_lower),
    ['subject_id', 'presc_drug', 'presc_starttime']
].copy()

# Lower-case the drug names once instead of once per category inside the loop
ckd_grt_prescriptions_temp['presc_drug_lower'] = ckd_grt_prescriptions_temp['presc_drug'].str.lower()

# survival_analysis_df may hold several rows per subject (one per grouped diagnosis);
# de-duplicate so the merge below does not multiply prescription rows.
subject_start_times = survival_analysis_df[['subject_id', 'treatment_starttime']].drop_duplicates()

# Inner merge keeps only subjects that are present in survival_analysis_df
category_prescriptions_after_start = pd.merge(
    ckd_grt_prescriptions_temp,
    subject_start_times,
    on='subject_id',
    how='inner'
)

# Keep prescriptions that start on or after the subject's treatment_starttime
category_prescriptions_after_start = category_prescriptions_after_start[
    category_prescriptions_after_start['presc_starttime'] >= category_prescriptions_after_start['treatment_starttime']
]

# One binary 'on_<category>' column per GRT category
treatment_feature_names = []
for category, drugs in ckd_grt.items():
    category_drugs_lower = [drug.lower() for drug in drugs]

    # Subjects with at least one qualifying prescription from this category
    in_category = category_prescriptions_after_start['presc_drug_lower'].isin(category_drugs_lower)
    subjects_on_category_treatment = category_prescriptions_after_start.loc[in_category, 'subject_id'].unique()

    feature_name = f'on_{category.lower().replace(" ", "_")}'
    survival_analysis_df[feature_name] = survival_analysis_df['subject_id'].isin(subjects_on_category_treatment).astype(int)
    treatment_feature_names.append(feature_name)

# 3. Display the head of survival_analysis_df after adding the new features.
print("Added binary features for CKD guideline-recommended treatment categories to survival_analysis_df:")
display(survival_analysis_df.head())

# 4. Display the info of survival_analysis_df after adding new features.
print("\nInfo of survival_analysis_df after adding treatment category features:")
survival_analysis_df.info()

# 5. Display value counts for each of the newly created binary treatment category features.
print("\nValue counts for new treatment category features:")
for feature_name in treatment_feature_names:
    print(f"\nValue counts for '{feature_name}':")
    display(survival_analysis_df[feature_name].value_counts())

# 6. Delete intermediate dataframes to free up memory. Only names that always exist are
# deleted, so this cannot fail when ckd_grt has no categories.
del ckd_grt_prescriptions_temp
del subject_start_times
del category_prescriptions_after_start

# Check SGLT2 binary entries

In [None]:
# Sanity check: list the SGLT2 inhibitor prescriptions present in the raw prescriptions table.

# 1. Reference the prescriptions table. It is only read in this cell, so it is filtered
#    directly; deep-copying the whole table first would duplicate it in memory for nothing.
raw_prescriptions_df = dataframe_dict['prescriptions.csv']

# 2. Extract the list of SGLT2 inhibitor drugs from the ckd_grt dictionary
sglt2_inhibitors = ckd_grt['SGLT2 inhibitor']
sglt2_inhibitors_lower = [drug.lower() for drug in sglt2_inhibitors]

# 3. Case-insensitive boolean mask for SGLT2 inhibitor drugs (missing drug names do not match)
sglt2_mask = raw_prescriptions_df['presc_drug'].str.lower().isin(sglt2_inhibitors_lower)

# 4. Keep an independent copy of the matching rows; later cells reuse sglt2_prescriptions_raw_df
sglt2_prescriptions_raw_df = raw_prescriptions_df[sglt2_mask].copy()

print("SGLT2 Inhibitor Prescriptions in Raw Data (Head):")
# 5. Print the head of this filtered DataFrame
display(sglt2_prescriptions_raw_df.head())

print("\nInfo of SGLT2 Inhibitor Prescriptions in Raw Data:")
# 6. Print the info of this filtered DataFrame
sglt2_prescriptions_raw_df.info()

print("\nValue Counts for SGLT2 Inhibitor Drugs in Raw Data:")
# 7. Display the value counts of the 'presc_drug' column
display(sglt2_prescriptions_raw_df['presc_drug'].value_counts())

# Drops only the local name; dataframe_dict still holds the table
del raw_prescriptions_df

In [None]:
import pandas as pd

# Cross-check the 'on_sglt2_inhibitor' flag in survival_analysis_df against the raw
# SGLT2 inhibitor prescriptions. Both sides of the comparison are counted in unique
# subjects so that repeated subject rows cannot produce a spurious mismatch.

# 1. Get unique subject_ids from the raw SGLT2 inhibitor prescriptions
sglt2_subjects_raw = sglt2_prescriptions_raw_df['presc_subject_id'].unique()

# 2. One (subject_id, treatment_starttime) row per matching subject. survival_analysis_df
#    may repeat a subject (one row per grouped diagnosis), so duplicates are removed to
#    keep the merge below from multiplying prescription rows.
filtered_survival_for_sglt2 = survival_analysis_df.loc[
    survival_analysis_df['subject_id'].isin(sglt2_subjects_raw),
    ['subject_id', 'treatment_starttime']
].drop_duplicates()

# 3. Bring the treatment_starttime onto each raw SGLT2 prescription row
merged_sglt2_check = pd.merge(
    sglt2_prescriptions_raw_df.rename(columns={'presc_subject_id': 'subject_id'}),
    filtered_survival_for_sglt2,
    on='subject_id',
    how='inner' # Keep only subjects present in both
)

# 4. Check if presc_starttime (of the SGLT2 inhibitor) is >= treatment_starttime
merged_sglt2_check['is_sglt2_after_treatment_start'] = (
    merged_sglt2_check['presc_starttime'] >= merged_sglt2_check['treatment_starttime']
)

print("Head of merged data for SGLT2 check with treatment_starttime:")
display(merged_sglt2_check.head())

# 5. Count unique subjects with at least one SGLT2 inhibitor prescription
#    at or after their overall GRT treatment_starttime.
subjects_with_valid_sglt2_post_start = merged_sglt2_check.loc[
    merged_sglt2_check['is_sglt2_after_treatment_start'], 'subject_id'
].nunique()

print(f"\nNumber of unique subjects with SGLT2 inhibitor prescription >= overall GRT treatment_starttime: {subjects_with_valid_sglt2_post_start}")

# 6. Count unique subjects flagged in survival_analysis_df. A plain .sum() of the flag
#    counts rows, which over-counts any subject that appears on more than one row.
actual_on_sglt2_count = survival_analysis_df.loc[
    survival_analysis_df['on_sglt2_inhibitor'] == 1, 'subject_id'
].nunique()

print(f"Number of subjects marked 'on_sglt2_inhibitor' in survival_analysis_df: {actual_on_sglt2_count}")

if subjects_with_valid_sglt2_post_start == actual_on_sglt2_count:
    print("\nThe counts match, indicating the logic for 'on_sglt2_inhibitor' is consistent with prescriptions after treatment_starttime.")
else:
    print("\nThe counts do NOT match, indicating a potential discrepancy in how 'on_sglt2_inhibitor' was created or filtered.")

# Clean up temporary DataFrame
del merged_sglt2_check
del filtered_survival_for_sglt2
del sglt2_subjects_raw

In [None]:
import pandas as pd

# SGLT2 inhibitor names from ckd_grt, lower-cased for case-insensitive comparison
sglt2_inhibitors = ckd_grt['SGLT2 inhibitor']
sglt2_inhibitors_lower = [drug.lower() for drug in sglt2_inhibitors]

# Rows of merged_df whose prescribed drug is an SGLT2 inhibitor
is_sglt2_row = merged_df['presc_drug'].str.lower().isin(sglt2_inhibitors_lower)
sglt2_in_merged_df = merged_df[is_sglt2_row].copy()

print("Head of merged_df rows containing SGLT2 inhibitors:")
display(sglt2_in_merged_df.head())

# Row count, then the number of distinct subjects behind those rows
sglt2_row_count = len(sglt2_in_merged_df)
print(f"\nNumber of rows in merged_df containing SGLT2 inhibitors: {sglt2_row_count}")

unique_subjects_with_sglt2_in_merged_df = sglt2_in_merged_df['subject_id'].nunique()
print(f"Number of unique subjects in merged_df with SGLT2 inhibitors: {unique_subjects_with_sglt2_in_merged_df}")


## Analyze SGLT2 Inhibitor Presence in ckd_grt_prescriptions


In [None]:
import pandas as pd

# ckd_grt_prescriptions was deleted in an earlier cell, so rebuild it from the
# prescriptions table: every prescription whose drug appears in any ckd_grt category.
prescriptions_df = dataframe_dict['prescriptions.csv']
ckd_grt_drugs = []
for sublist in ckd_grt.values():
    ckd_grt_drugs.extend(sublist)
is_grt_drug = prescriptions_df['presc_drug'].str.lower().isin([d.lower() for d in ckd_grt_drugs])
ckd_grt_prescriptions = prescriptions_df[is_grt_drug].copy()

# SGLT2 inhibitor names, lower-cased for the case-insensitive match below
sglt2_inhibitors = ckd_grt['SGLT2 inhibitor']
sglt2_inhibitors_lower = [drug.lower() for drug in sglt2_inhibitors]

# Subset of the GRT prescriptions that are SGLT2 inhibitors
is_sglt2_drug = ckd_grt_prescriptions['presc_drug'].str.lower().isin(sglt2_inhibitors_lower)
sglt2_prescriptions_filtered_ckd_grt = ckd_grt_prescriptions[is_sglt2_drug].copy()

# Head of the filtered frame
print("SGLT2 Inhibitor Prescriptions in ckd_grt_prescriptions (Head):")
display(sglt2_prescriptions_filtered_ckd_grt.head())

# Structure of the filtered frame
print("\nInfo of SGLT2 Inhibitor Prescriptions in ckd_grt_prescriptions:")
sglt2_prescriptions_filtered_ckd_grt.info()

# Prescription counts per drug name
print("\nValue Counts for SGLT2 Inhibitor Drugs in ckd_grt_prescriptions:")
display(sglt2_prescriptions_filtered_ckd_grt['presc_drug'].value_counts())

# Release the temporary names
del prescriptions_df
del ckd_grt_prescriptions


In [None]:
import pandas as pd

# Rebuild the prescriptions restricted to the (subject, admission) pairs found in
# filtered_diagnoses_merged_df, in case the earlier version was deleted.
prescriptions_df = dataframe_dict['prescriptions.csv'].copy()
unique_ids = filtered_diagnoses_merged_df[['diag_subject_id', 'diag_hadm_id']].drop_duplicates()

# The key columns are renamed to the diag_ names for the join, then back to presc_ names
presc_to_diag_keys = {'presc_subject_id': 'diag_subject_id', 'presc_hadm_id': 'diag_hadm_id'}
diag_to_presc_keys = {'diag_subject_id': 'presc_subject_id', 'diag_hadm_id': 'presc_hadm_id'}

prescriptions_df_filtered_recreated = unique_ids.merge(
    prescriptions_df.rename(columns=presc_to_diag_keys),
    on=['diag_subject_id', 'diag_hadm_id'],
    how='left'
)
prescriptions_df_filtered_recreated = prescriptions_df_filtered_recreated.rename(columns=diag_to_presc_keys)

# SGLT2 inhibitor names, lower-cased for case-insensitive comparison
sglt2_inhibitors = ckd_grt['SGLT2 inhibitor']
sglt2_inhibitors_lower = [drug.lower() for drug in sglt2_inhibitors]

# Rows of the rebuilt frame whose drug is an SGLT2 inhibitor
is_sglt2_drug = prescriptions_df_filtered_recreated['presc_drug'].str.lower().isin(sglt2_inhibitors_lower)
sglt2_in_prescriptions_filtered = prescriptions_df_filtered_recreated[is_sglt2_drug].copy()

print("SGLT2 Inhibitor Prescriptions in prescriptions_df_filtered (Head):")
display(sglt2_in_prescriptions_filtered.head())

sglt2_row_count = len(sglt2_in_prescriptions_filtered)
print(f"\nNumber of rows with SGLT2 Inhibitor Prescriptions in prescriptions_df_filtered: {sglt2_row_count}")

sglt2_subject_count = sglt2_in_prescriptions_filtered['presc_subject_id'].nunique()
print(f"\nUnique subjects with SGLT2 Inhibitor Prescriptions in prescriptions_df_filtered: {sglt2_subject_count}")

# Release the temporary names
del prescriptions_df
del prescriptions_df_filtered_recreated
del unique_ids

## Verify SGLT2 inhibitor presence in filtered_diagnoses_merged patients


In [None]:
import pandas as pd

# Inputs come from earlier cells: filtered_diagnoses_merged_df and sglt2_prescriptions_raw_df.

# Distinct (subject, admission) pairs in the diagnosis cohort, under neutral key names
unique_sampled_ids = (
    filtered_diagnoses_merged_df[['diag_subject_id', 'diag_hadm_id']]
    .drop_duplicates()
    .copy()
    .rename(columns={'diag_subject_id': 'subject_id', 'diag_hadm_id': 'hadm_id'})
)

# Distinct (subject, admission) pairs that carry an SGLT2 inhibitor prescription
sglt2_unique_presc_ids = (
    sglt2_prescriptions_raw_df[['presc_subject_id', 'presc_hadm_id']]
    .drop_duplicates()
    .copy()
    .rename(columns={'presc_subject_id': 'subject_id', 'presc_hadm_id': 'hadm_id'})
)

# Pairs present on both sides
merged_sglt2_in_sampled = unique_sampled_ids.merge(
    sglt2_unique_presc_ids,
    on=['subject_id', 'hadm_id'],
    how='inner'
)

print("--- SGLT2 Inhibitor Presence in Patients ---")

# How many cohort pairs have an SGLT2 inhibitor prescription
num_overlapping_pairs = merged_sglt2_in_sampled.shape[0]
print(f"Number of (subject, admission) pairs from filtered_diagnoses_merged_df with SGLT2 inhibitor prescriptions: {num_overlapping_pairs}")

# Show a sample of the overlap only when there is one
if num_overlapping_pairs <= 0:
    print("No overlapping (subject, admission) pairs found.")
else:
    print("Head of overlapping (subject, admission) pairs:")
    display(merged_sglt2_in_sampled.head())

# Distinct subjects behind the overlapping pairs
unique_subjects_with_sglt2_in_sampled = merged_sglt2_in_sampled['subject_id'].nunique()
print(f"Number of unique subjects from filtered_diagnoses_merged_df with SGLT2 inhibitor prescriptions: {unique_subjects_with_sglt2_in_sampled}")

# Release the temporary frames
del unique_sampled_ids
del sglt2_unique_presc_ids
del merged_sglt2_in_sampled

## Address multicollinearity and complete separation

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# Build the predictor matrix for the correlation analysis: the binary 'on_' treatment
# features plus one indicator column per grouped diagnosis.
treatment_cols = [col for col in survival_analysis_df.columns if col.startswith('on_')]

# Diagnosis indicators are named 'grouped_diagnosis_<category>'. If the raw
# 'grouped_diagnosis' column is still present, the indicators are built locally (without
# touching survival_analysis_df); otherwise the existing one-hot columns are used, leaving
# out the '<diagnosis>_x_on_<treatment>' interaction terms that share the same prefix.
if 'grouped_diagnosis' in survival_analysis_df.columns:
    diagnosis_indicators = pd.get_dummies(survival_analysis_df['grouped_diagnosis'], prefix='grouped_diagnosis')
else:
    diagnosis_indicators = survival_analysis_df[[
        col for col in survival_analysis_df.columns
        if col.startswith('grouped_diagnosis_') and '_x_on_' not in col
    ]]

correlation_df = pd.concat([survival_analysis_df[treatment_cols], diagnosis_indicators], axis=1)
predictor_cols = correlation_df.columns.tolist()

# Ensure all predictor columns are numeric (convert boolean to int if necessary)
for col in correlation_df.columns:
    if correlation_df[col].dtype == 'bool':
        correlation_df[col] = correlation_df[col].astype(int)


# Calculate the correlation matrix
correlation_matrix = correlation_df.corr()

print("Correlation matrix for predictor variables (including one-hot encoded diagnoses):")
display(correlation_matrix)

# Visualize the correlation matrix using a heatmap
plt.figure(figsize=(15, 12)) # Adjusted figure size for more predictors
sns.heatmap(correlation_matrix, annot=True, cmap='coolwarm', fmt=".2f", annot_kws={"size": 8})
plt.title('Correlation Matrix of Predictor Variables')
plt.tight_layout()
plt.show()

# Identify highly correlated pairs (e.g., absolute correlation > 0.7)
print("\nHighly correlated pairs (absolute correlation > 0.7):")
high_corr_pairs = {}
# Walk the upper triangle only, so each pair is reported once and the diagonal is skipped.
# Constant columns give NaN correlations, which never exceed the threshold.
matrix_cols = correlation_matrix.columns
for i in range(len(matrix_cols)):
    for j in range(i + 1, len(matrix_cols)):
        corr_value = correlation_matrix.iloc[i, j]
        if abs(corr_value) > 0.7:
            high_corr_pairs[(matrix_cols[i], matrix_cols[j])] = corr_value

if high_corr_pairs:
    for pair, corr_value in high_corr_pairs.items():
        print(f"  {pair[0]} and {pair[1]}: {corr_value:.2f}")
else:
    print("  No highly correlated pairs found (absolute correlation > 0.7).")

# Examine the 'on_arbs' column for complete separation
print("\nValue counts for 'on_arbs':")
arbs_value_counts = survival_analysis_df['on_arbs'].value_counts()
display(arbs_value_counts)

print("\nRelationship between 'on_arbs' and 'event' (for complete separation check):")
arbs_event_crosstab = pd.crosstab(survival_analysis_df['on_arbs'], survival_analysis_df['event'])
display(arbs_event_crosstab)

# Guarantee that both event columns (0 and 1) exist so the lookups below cannot raise
# KeyError when one outcome never occurs in the data.
arbs_event_counts = arbs_event_crosstab.reindex(columns=[0, 1], fill_value=0)

# Separation check for subjects with on_arbs == 1: either all of them had the event
# or none of them did.
if 1 in arbs_event_counts.index:
    treated_censored = arbs_event_counts.loc[1, 0]
    treated_events = arbs_event_counts.loc[1, 1]
    if treated_censored == 0 and treated_events > 0:
        print("\nComplete separation likely exists for 'on_arbs' (all events in category 1).")
    elif treated_events == 0 and treated_censored > 0:
         print("\nComplete separation likely exists for 'on_arbs' (no events in category 1).")
    else:
         print("\nComplete separation does not appear to exist for 'on_arbs'.")
elif 0 in arbs_event_counts.index:
     print("\n'on_arbs' only has one category present in the data (0), which will cause issues for modeling (complete separation).")
else:
    print("\nCould not determine complete separation for 'on_arbs' based on crosstab.")

del correlation_df # Clean up temporary dataframe

# Fit Standard Cox Proportional Hazards Model

In [None]:
!pip install lifelines

In [None]:
from lifelines import CoxPHFitter

# Modelling frame: everything in survival_analysis_df except the identifier, the two
# date columns and the raw grouped_diagnosis text column.
non_model_cols = ['subject_id', 'treatment_starttime', 'dialysis_event_date', 'grouped_diagnosis']
cox_data = survival_analysis_df.drop(non_model_cols, axis=1).copy()

# Boolean predictors become 0/1 integers
bool_cols = [name for name in cox_data.columns if cox_data[name].dtype == 'bool']
cox_data = cox_data.astype({name: int for name in bool_cols})


# Unpenalized Cox proportional hazards model
cph = CoxPHFitter()

try:
    cph.fit(cox_data, duration_col='time_to_event', event_col='event')

    print("\nStandard Cox Proportional Hazards Model Summary:")
    cph.print_summary()

    # Concordance index of the fitted model
    c_index_cox = cph.concordance_index_
    print(f"\nConcordance Index of the Standard Cox model: {c_index_cox:.4f}")

except Exception as exc:
    # Best effort: report the failure and let the notebook continue
    print(f"An error occurred during standard Cox model fitting: {exc}")
    print("The model may not have converged due to issues like complete separation or multicollinearity.")

# Explore alternative survival models


In [None]:
from lifelines import CoxPHFitter

# Modelling frame: everything in survival_analysis_df except the identifier, the two
# date columns and the raw grouped_diagnosis text column.
non_model_cols = ['subject_id', 'treatment_starttime', 'dialysis_event_date', 'grouped_diagnosis']
cox_data = survival_analysis_df.drop(non_model_cols, axis=1).copy()

# Boolean predictors become 0/1 integers
bool_cols = [name for name in cox_data.columns if cox_data[name].dtype == 'bool']
cox_data = cox_data.astype({name: int for name in bool_cols})


# Penalized Cox proportional hazards model (penalizer=0.1)
cph_penalized = CoxPHFitter(penalizer=0.1)

try:
    cph_penalized.fit(cox_data, duration_col='time_to_event', event_col='event')

    print("\nPenalized Cox Proportional Hazards Model Summary (L2 Regularization):")
    cph_penalized.print_summary()

    # Concordance index of the fitted model
    c_index_penalized = cph_penalized.concordance_index_
    print(f"\nConcordance Index of the Penalized Cox model: {c_index_penalized:.4f}")

except Exception as exc:
    # Best effort: report the failure and let the notebook continue
    print(f"An error occurred during penalized Cox model fitting: {exc}")

## Feature Engineering: Interaction Terms and Number of Treatments

In [None]:
# Binary treatment indicators used for both the count feature and the interaction terms
treatment_categories = ['on_sglt2_inhibitor', 'on_ace_inhibitor', 'on_arbs', 'on_statin']

# 1. Number of CKD guideline-recommended treatment categories each subject is on
#    (row-wise sum of the binary treatment columns).
survival_analysis_df['num_ckd_grt_categories'] = survival_analysis_df[treatment_categories].sum(axis=1)

print("\nAdded 'num_ckd_grt_categories' feature:")
display(survival_analysis_df[['subject_id', 'num_ckd_grt_categories']].head())

print("\nValue counts for 'num_ckd_grt_categories':")
display(survival_analysis_df['num_ckd_grt_categories'].value_counts())


# 2. One-hot encode the grouped CKD diagnosis. get_dummies removes the source column, so
#    the guard keeps this cell re-runnable instead of raising KeyError on a second run.
if 'grouped_diagnosis' in survival_analysis_df.columns:
    survival_analysis_df = pd.get_dummies(survival_analysis_df, columns=['grouped_diagnosis'], drop_first=False)


# Plain one-hot diagnosis columns only. Interaction columns created below share the
# 'grouped_diagnosis_' prefix and must be excluded, otherwise a re-run would build
# interactions of interactions.
diagnosis_categories = [
    col for col in survival_analysis_df.columns
    if col.startswith('grouped_diagnosis_') and '_x_on_' not in col
]

# Interaction term = diagnosis indicator * treatment indicator
for diag_col in diagnosis_categories:
    for treat_col in treatment_categories:
        interaction_col_name = f'{diag_col}_x_{treat_col}'
        survival_analysis_df[interaction_col_name] = survival_analysis_df[diag_col] * survival_analysis_df[treat_col]

print("\nAdded interaction terms to survival_analysis_df:")
display(survival_analysis_df.head())

print("\nInfo of survival_analysis_df after adding new features:")
survival_analysis_df.info()

## Fit Penalized Cox Model with Interaction Terms and Number of Treatments

In [None]:
from lifelines import CoxPHFitter

# Modelling frame with the engineered features: drop only the identifier and the two
# date columns (grouped_diagnosis was already replaced by its one-hot columns).
non_model_cols = ['subject_id', 'treatment_starttime', 'dialysis_event_date']
cox_data_interactions = survival_analysis_df.drop(non_model_cols, axis=1).copy()

# Boolean predictors become 0/1 integers
bool_cols = [name for name in cox_data_interactions.columns if cox_data_interactions[name].dtype == 'bool']
cox_data_interactions = cox_data_interactions.astype({name: int for name in bool_cols})


# Penalized Cox proportional hazards model (penalizer=0.1)
cph_penalized_interactions = CoxPHFitter(penalizer=0.1)

try:
    cph_penalized_interactions.fit(cox_data_interactions, duration_col='time_to_event', event_col='event')

    print("\nPenalized Cox Proportional Hazards Model Summary (with Interaction Terms and Number of Treatments):")
    cph_penalized_interactions.print_summary()

    # Concordance index of the fitted model
    c_index_penalized_interactions = cph_penalized_interactions.concordance_index_
    print(f"\nConcordance Index of the Penalized Cox model (with Interactions): {c_index_penalized_interactions:.4f}")

except Exception as exc:
    # Best effort: report the failure and let the notebook continue
    print(f"An error occurred during penalized Cox model fitting with interaction terms: {exc}")

## Examine Columns with Low Variance and Potential Complete Separation

In [None]:
# Columns singled out for a closer look at their variance
low_variance_cols = [
    'grouped_diagnosis_Hypertensive CKD (Stages 5-End Stage)_x_on_sglt2_inhibitor'
]

print("Variance of identified low-variance columns:")
display(survival_analysis_df[low_variance_cols].var())

# Cross-tabulate each column against 'event' to look for complete separation
print("\nRelationship between low-variance columns and 'event' (for complete separation check):")

for col in low_variance_cols:
    print(f"\nCrosstab for '{col}' and 'event':")
    crosstab_result = pd.crosstab(survival_analysis_df[col], survival_analysis_df['event'])
    display(crosstab_result)

    feature_levels = crosstab_result.index

    # Feature value 1 never occurs: either the column is constant at 0 or the check is moot
    if 1 not in feature_levels:
        if 0 in feature_levels:
            print(f"'{col}' only has one category present in the data (0), which will cause issues for modeling (complete separation).")
        else:
            print(f"Could not determine complete separation for '{col}' based on crosstab.")
        continue

    # Among rows with feature value 1: counts without and with the event
    without_event = crosstab_result.loc[1, 0]
    with_event = crosstab_result.loc[1, 1]

    if without_event == 0 and with_event > 0:
        print(f"Complete separation likely exists for '{col}' (all events in category 1).")
    elif with_event == 0 and without_event > 0:
        print(f"Complete separation likely exists for '{col}' (no events in category 1).")
    else:
        print(f"Complete separation does not appear to exist for '{col}'.")

## Remove Problematic Features and Refit Penalized Cox Model

In [None]:
# Interaction term(s) to leave out of the model because of complete separation
columns_to_drop = [
    'grouped_diagnosis_Hypertensive CKD (Stages 5-End Stage)_x_on_sglt2_inhibitor'
]

# Modelling frame without those columns (cox_data_interactions comes from the earlier fit)
cox_data_interactions_cleaned = cox_data_interactions.drop(columns_to_drop, axis=1).copy()

print("Dropped problematic columns due to complete separation.")
remaining_cols = list(cox_data_interactions_cleaned.columns)
print(f"Remaining columns for modeling: {remaining_cols}")


# Penalized Cox proportional hazards model (penalizer=0.1) on the reduced frame
cph_penalized_cleaned = CoxPHFitter(penalizer=0.1)

try:
    cph_penalized_cleaned.fit(cox_data_interactions_cleaned, duration_col='time_to_event', event_col='event')

    print("\nPenalized Cox Proportional Hazards Model Summary (Cleaned Data):")
    cph_penalized_cleaned.print_summary()

    # Concordance index of the fitted model
    c_index_penalized_cleaned = cph_penalized_cleaned.concordance_index_
    print(f"\nConcordance Index of the Penalized Cox model (Cleaned Data): {c_index_penalized_cleaned:.4f}")

except Exception as exc:
    # Best effort: report the failure and let the notebook continue
    print(f"An error occurred during penalized Cox model fitting with cleaned data: {exc}")

# Penalized Cox post cleaning
Exclude the previously identified problematic interaction terms ('grouped_diagnosis_Hypertensive CKD (Stages 5-End Stage)_x_on_sglt2_inhibitor', 'grouped_diagnosis_Stage 1 CKD_x_on_sglt2_inhibitor', 'grouped_diagnosis_Stage 5 CKD_x_on_sglt2_inhibitor').

Identify and remove any columns with zero or extremely low variance.

## Consolidate and Prepare Data for Modeling

In [None]:
import pandas as pd

# Columns left out of the modelling frame: the identifier, the two date columns and two
# SGLT2-inhibitor interaction terms.
columns_to_drop_modeling = [
    'subject_id',
    'treatment_starttime',
    'dialysis_event_date',
    'grouped_diagnosis_Hypertensive CKD (Stages 5-End Stage)_x_on_sglt2_inhibitor',
    'grouped_diagnosis_Stage 5 CKD_x_on_sglt2_inhibitor'
]

# Modelling frame = survival_analysis_df without those columns
cox_data_full = survival_analysis_df.drop(columns_to_drop_modeling, axis=1).copy()

# Boolean columns become 0/1 integers
bool_cols = [name for name in cox_data_full.columns if cox_data_full[name].dtype == 'bool']
cox_data_full = cox_data_full.astype({name: int for name in bool_cols})

# First rows of the prepared frame
print("Recreated cox_data_full DataFrame:")
display(cox_data_full.head())

# Structure and dtypes of the prepared frame
print("\nInfo of the recreated cox_data_full DataFrame:")
cox_data_full.info()

In [None]:
import numpy as np
from lifelines import CoxPHFitter

# Identify predictor columns (excluding 'time_to_event' and 'event')
predictor_cols = [col for col in cox_data_full.columns if col not in ['time_to_event', 'event']]

problematic_cols = []
print("\n--- Checking for low variance and complete separation ---")

# Flag predictors with (near-)zero variance and inspect them against 'event'
for col in predictor_cols:
    col_variance = cox_data_full[col].var()
    if not (col_variance < 1e-6):
        continue

    print(f"\nColumn '{col}' has very low variance: {col_variance:.2e}")

    # Cross-tabulate with 'event' to look for complete separation
    crosstab_result = pd.crosstab(cox_data_full[col], cox_data_full['event'])
    print(f"Crosstab for '{col}' and 'event':")
    display(crosstab_result)

    # Guarantee that both event columns (0 and 1) exist so the lookups below cannot
    # raise KeyError when one outcome never occurs in the data.
    event_counts = crosstab_result.reindex(columns=[0, 1], fill_value=0)

    if 1 in event_counts.index:
        without_event = event_counts.loc[1, 0]
        with_event = event_counts.loc[1, 1]
        # Separation: rows with feature value 1 either all had the event or none did
        if (without_event == 0 and with_event > 0) or (with_event == 0 and without_event > 0):
            print(f"Complete separation confirmed for '{col}'. Adding to problematic list.")
            problematic_cols.append(col)
        else:
            print(f"Low variance for '{col}' but complete separation is not evident from crosstab.")
    elif event_counts.shape[0] == 1:
        print(f"Column '{col}' has only one unique value. Adding to problematic list.")
        problematic_cols.append(col)
    else:
        print(f"Low variance for '{col}' but complete separation check inconclusive.")

# Remove duplicates while keeping a deterministic (first-seen) order
problematic_cols = list(dict.fromkeys(problematic_cols))

if problematic_cols:
    print(f"\nIdentified problematic columns to drop: {problematic_cols}")
    cox_data_final = cox_data_full.drop(columns=problematic_cols).copy()
else:
    print("\nNo additional problematic columns identified based on low variance/complete separation.")
    cox_data_final = cox_data_full.copy()

print("\nRemaining columns for modeling (cox_data_final):")
print(cox_data_final.columns.tolist())

# Replace zero or negative durations with a tiny positive value before fitting.
# The column is cast to float first: writing a float epsilon into an integer column is
# an incompatible-dtype assignment.
# NOTE(review): a negative duration means the dialysis date precedes treatment start;
# confirm whether such subjects should be excluded rather than clamped.
cox_data_final['time_to_event'] = cox_data_final['time_to_event'].astype(float)
zero_duration_mask = cox_data_final['time_to_event'] <= 0
if zero_duration_mask.any():
    epsilon = np.finfo(float).eps
    cox_data_final.loc[zero_duration_mask, 'time_to_event'] = epsilon
    print(f"Set {zero_duration_mask.sum()} zero or negative durations in 'time_to_event' to {epsilon}.")


# Penalized Cox proportional hazards model; same penalizer value as the earlier
# penalized fits in this notebook.
cph_final = CoxPHFitter(penalizer=0.1)

try:
    cph_final.fit(cox_data_final, duration_col='time_to_event', event_col='event')

    print("\n--- Penalized Cox Proportional Hazards Model Summary (Final Cleaned Data) ---")
    cph_final.print_summary()

    # Evaluate the fitted model using concordance index
    c_index_final = cph_final.concordance_index_
    print(f"\nConcordance Index of the Penalized Cox model (Final Cleaned Data): {c_index_final:.4f}")

except Exception as e:
    # Best effort: report the failure; c_index_final = None marks it for the summary below
    print(f"\nAn error occurred during penalized Cox model fitting with final cleaned data: {e}")
    c_index_final = None

# Summarize the process and results. The statements report what was actually done and
# measured; they do not assert an outcome the code has not checked.
print("\n--- Summary of Process and Results ---")
print("1. The `cox_data_full` DataFrame was prepared by excluding identifier/date columns and previously identified problematic interaction terms.")
print("2. Predictors with extremely low variance were flagged and compared against the 'event' column using crosstabs.")
print(f"3. The following columns were identified and removed due to low variance/complete separation: {problematic_cols if problematic_cols else 'None'}.")
print("4. Any zero or negative 'time_to_event' durations were replaced with a small positive epsilon value before fitting.")
print("5. A penalized CoxPHFitter model (L2 regularization, penalizer=0.1) was fitted using the `cox_data_final` dataset.")
if c_index_final is not None:
    print(f"6. The model was fitted without raising an error; Concordance Index = {c_index_final:.4f} (0.5 corresponds to random ordering, 1.0 to perfect concordance).")
    print("7. See the model summary above for the coefficients and hazard ratios of the remaining predictors.")
else:
    print("6. Model fitting raised an error (see the message above); further investigation into multicollinearity or data structure may be needed.")




# Simpler penalized cox model (main effects only)


In [None]:
from lifelines import CoxPHFitter

# 1. Assemble the column list for a main-effects-only penalized Cox model:
#    duration, event indicator and the four binary treatment flags.
main_effect_cols = [
    'time_to_event',
    'event',
    'on_sglt2_inhibitor',
    'on_ace_inhibitor',
    'on_arbs',
    'on_statin'
]

# Add the one-hot encoded grouped diagnosis columns. Interaction columns share the
# 'grouped_diagnosis_' prefix (they are named '<diagnosis>_x_<treatment>'), so they
# are filtered out explicitly; otherwise this would not be a main-effects-only model.
diagnosis_cols_one_hot = [
    col for col in survival_analysis_df.columns
    if col.startswith('grouped_diagnosis_') and '_x_' not in col
]
main_effect_cols.extend(diagnosis_cols_one_hot)

# Work on a copy so the dtype conversion below never touches survival_analysis_df.
cox_data_main_effects = survival_analysis_df[main_effect_cols].copy()

# 2. Ensure all selected predictor columns are numeric (convert boolean to int if necessary)
for col in cox_data_main_effects.columns:
    if cox_data_main_effects[col].dtype == 'bool':
        cox_data_main_effects[col] = cox_data_main_effects[col].astype(int)

print("Prepared data for main effects Cox model:")
display(cox_data_main_effects.head())
cox_data_main_effects.info()

# 3. Instantiate a CoxPHFitter object with a chosen penalizer value
# Based on the previous experimentation, penalizer=0.1 performed well.
chosen_penalizer = 0.1
cph_main_effects = CoxPHFitter(penalizer=chosen_penalizer)

# Defined up front so that a failed fit leaves an explicit None behind
# instead of an undefined (or stale) name.
c_index_main_effects = None

# 4. Fit the CoxPHFitter model
try:
    cph_main_effects.fit(cox_data_main_effects, duration_col='time_to_event', event_col='event')

    # 5. Print the summary of the fitted model
    print(f"\nPenalized Cox Proportional Hazards Model Summary (Main Effects Only, Penalizer={chosen_penalizer}):")
    cph_main_effects.print_summary()

    # 6. Calculate and print the concordance index
    c_index_main_effects = cph_main_effects.concordance_index_
    print(f"\nConcordance Index of the Penalized Cox model (Main Effects Only): {c_index_main_effects:.4f}")

except Exception as e:
    # Best-effort: report the failure and let the notebook continue.
    print(f"An error occurred during main effects Cox model fitting: {e}")

# 7. Delete the cox_data_main_effects DataFrame
del cox_data_main_effects

In [None]:
from lifelines import CoxPHFitter
import numpy as np

# Re-define the main effect predictor columns explicitly, excluding interaction terms and num_ckd_grt_categories
main_effect_cols = [
    'time_to_event',
    'event',
    'on_sglt2_inhibitor',
    'on_ace_inhibitor',
    'on_arbs',
    'on_statin'
]

# Add the one-hot encoded grouped diagnosis columns; names containing '_x_' are
# interaction terms and are left out.
diagnosis_cols_one_hot = [col for col in survival_analysis_df.columns
                          if col.startswith('grouped_diagnosis_') and '_x_' not in col]

main_effect_cols.extend(diagnosis_cols_one_hot)

# Select these columns from survival_analysis_df (copy, so the source frame is untouched)
cox_data_main_effects = survival_analysis_df[main_effect_cols].copy()

# Ensure all selected predictor columns are numeric (convert boolean to int if necessary)
for col in cox_data_main_effects.columns:
    if cox_data_main_effects[col].dtype == 'bool':
        cox_data_main_effects[col] = cox_data_main_effects[col].astype(int)

print("Prepared data for main effects Cox model (re-checked columns):")
display(cox_data_main_effects.head())
cox_data_main_effects.info()


predictor_cols_for_check = [col for col in cox_data_main_effects.columns if col not in ['time_to_event', 'event']]

problematic_main_effect_cols = []
print("\n--- Checking for low variance and complete separation in main effects model ---")

for col in predictor_cols_for_check:
    col_variance = cox_data_main_effects[col].var()
    # Only (near-)constant predictors pass this threshold and get inspected further.
    if col_variance < 1e-6:
        print(f"\nColumn '{col}' has very low variance: {col_variance:.2e}")

        # Perform crosstab with 'event' to confirm complete separation
        crosstab_result = pd.crosstab(cox_data_main_effects[col], cox_data_main_effects['event'])
        print(f"Crosstab for '{col}' and 'event':")
        display(crosstab_result)

        # Guarantee that both event outcomes exist as columns so the lookups below
        # cannot raise KeyError when one outcome is absent from the data.
        event_counts = crosstab_result.reindex(columns=[0, 1], fill_value=0)

        # Check for complete separation logic
        if 1 in event_counts.index:
            without_event = event_counts.loc[1, 0]
            with_event = event_counts.loc[1, 1]
            if (without_event == 0 and with_event > 0) or \
               (with_event == 0 and without_event > 0):
                print(f"Complete separation confirmed for '{col}'. Adding to problematic list.")
                problematic_main_effect_cols.append(col)
            elif crosstab_result.shape[0] == 1:
                print(f"Column '{col}' has only one unique value. Adding to problematic list.")
                problematic_main_effect_cols.append(col)
            else:
                print(f"Low variance for '{col}' but complete separation is not evident from crosstab.")
        elif crosstab_result.shape[0] == 1:
            print(f"Column '{col}' has only one unique value. Adding to problematic list.")
            problematic_main_effect_cols.append(col)
        else:
            print(f"Low variance for '{col}' but complete separation check inconclusive.")

# Remove duplicates while keeping discovery order, so the printed list and the
# drop call are deterministic across runs (a set would reorder them).
problematic_main_effect_cols = list(dict.fromkeys(problematic_main_effect_cols))

if problematic_main_effect_cols:
    print(f"\nIdentified problematic columns to drop for main effects model: {problematic_main_effect_cols}")
    cox_data_main_effects_cleaned = cox_data_main_effects.drop(columns=problematic_main_effect_cols).copy()
else:
    print("\nNo additional problematic columns identified based on low variance/complete separation for main effects model.")
    cox_data_main_effects_cleaned = cox_data_main_effects.copy()

print("\nRemaining columns for main effects modeling (cox_data_main_effects_cleaned):")
print(cox_data_main_effects_cleaned.columns.tolist())

# Zero or negative durations are overwritten (not incremented) with a tiny positive value.
# NOTE(review): a negative duration presumably means the event predates the start of
# follow-up; confirm that clamping, rather than excluding those rows, is intended.
zero_duration_mask_main_effects = cox_data_main_effects_cleaned['time_to_event'] <= 0
if zero_duration_mask_main_effects.any():
    epsilon = np.finfo(float).eps
    cox_data_main_effects_cleaned.loc[zero_duration_mask_main_effects, 'time_to_event'] = epsilon
    print(f"Replaced {zero_duration_mask_main_effects.sum()} zero or negative durations in 'time_to_event' with {epsilon} for main effects cleaned data.")


# Instantiate a CoxPHFitter object with a chosen penalizer value
# Based on the previous experimentation, penalizer=0.1 performed well.
chosen_penalizer = 0.1
cph_main_effects = CoxPHFitter(penalizer=chosen_penalizer)

# Defined up front so a failed fit leaves an explicit None rather than a stale
# value from an earlier cell.
c_index_main_effects = None

# Fit the CoxPHFitter model
try:
    cph_main_effects.fit(cox_data_main_effects_cleaned, duration_col='time_to_event', event_col='event')

    # Print the summary of the fitted model
    print(f"\nPenalized Cox Proportional Hazards Model Summary (Main Effects Only, Penalizer={chosen_penalizer}):")
    cph_main_effects.print_summary()

    # Calculate and print the concordance index
    c_index_main_effects = cph_main_effects.concordance_index_
    print(f"\nConcordance Index of the Penalized Cox model (Main Effects Only): {c_index_main_effects:.4f}")

except Exception as e:
    # Best-effort: report the failure and let the notebook continue.
    print(f"An error occurred during main effects Cox model fitting: {e}")


In [None]:
# Interpretation of the Main Effects Penalized Cox Model Summary.
# Prints the overall fit (likelihood-ratio test), a per-covariate reading of the
# hazard ratios, and the concordance index of `cph_main_effects`.

print("\nInterpretation of the Penalized Cox Model Summary (Main Effects Only):")

# Single significance level shared by the overall test and the per-covariate tests.
alpha = 0.05

# `params_` and `log_likelihood_` are only set by a successful fit.
model_is_fitted = hasattr(cph_main_effects, 'log_likelihood_') and hasattr(cph_main_effects, 'params_')

print("\nOverall Model Fit:")
if model_is_fitted:
    print(f"* Partial Log-Likelihood: {cph_main_effects.log_likelihood_}")

    try:
        # `summary` is indexed by covariate and has no likelihood-ratio row;
        # the test comes from log_likelihood_ratio_test().
        lr_result = cph_main_effects.log_likelihood_ratio_test()
        ll_ratio_test = float(lr_result.test_statistic)
        ll_ratio_p_value = float(lr_result.p_value)
        # Fall back to the number of fitted coefficients if the attribute is missing.
        ll_ratio_dof = getattr(lr_result, 'degrees_freedom', len(cph_main_effects.params_))

        print(f"* Log-likelihood ratio test: {ll_ratio_test:.2f} on {ll_ratio_dof} df, p-value = {ll_ratio_p_value:.3f}.")
        if ll_ratio_p_value < 0.005:
            print("This highly significant result indicates that the model with these predictors fits the data significantly better than a model without predictors.")
        elif ll_ratio_p_value < alpha:
            print("This significant result indicates that the model with these predictors fits the data better than a model without predictors.")
        else:
            print("This result suggests that the model with these predictors does not significantly improve the fit compared to a model without predictors.")

    except Exception as e:
        # Keep the cell usable even if the test cannot be computed.
        print(f"\nCould not compute the log-likelihood ratio test: {e}")
        print("\nFull model summary:")
        cph_main_effects.print_summary()
else:
    print("Model did not converge successfully, cannot provide overall model fit statistics.")


print("\nIndividual Predictor Interpretation (Hazard Ratios and Significance):")
print("Hazard ratios exp(coef) less than 1 indicate a decreased hazard (lower risk of dialysis).")
print("Hazard ratios exp(coef) greater than 1 indicate an increased hazard (higher risk of dialysis).")

if model_is_fitted:
    # One row per covariate; no overall-test rows to strip out.
    summary_df = cph_main_effects.summary.copy()

    for index, row in summary_df.iterrows():
        # Check if the row has expected columns before accessing
        if 'exp(coef)' in row and 'p' in row and 'exp(coef) lower 95%' in row and 'exp(coef) upper 95%' in row:
            covariate = index
            hazard_ratio = row['exp(coef)']
            p_value = row['p']
            ci_lower = row['exp(coef) lower 95%']
            ci_upper = row['exp(coef) upper 95%']

            significance = "Statistically significant" if p_value < alpha else "Not statistically significant"
            effect = "decreased" if hazard_ratio < 1 else "increased" if hazard_ratio > 1 else "unchanged"

            print(f"* {covariate}: Hazard Ratio = {hazard_ratio:.2f} (95% CI: {ci_lower:.2f} - {ci_upper:.2f}), p={p_value:.3f}. {significance}.")
            if significance == "Statistically significant":
                print(f"  Being in this category is associated with a significantly {effect} hazard of going on dialysis compared to the reference group (or other categories in a one-hot encoded set).")
            else:
                print(f"  This predictor is not statistically significantly associated with the hazard of going on dialysis in this model.")
        else:
            print(f"* Warning: Missing expected columns in summary row for {index}. Skipping interpretation.")
else:
    print("\nModel did not converge successfully, cannot interpret individual predictors.")

print("\nModel Performance:")
if hasattr(cph_main_effects, 'concordance_index_'):
    c_index_value = cph_main_effects.concordance_index_
    # Rule-of-thumb wording (0.5 = chance, 1.0 = perfect ranking) so the verdict
    # follows the actual value instead of being hard-coded.
    if c_index_value >= 0.7:
        c_index_verdict = "reasonably good"
    elif c_index_value >= 0.6:
        c_index_verdict = "moderate"
    else:
        c_index_verdict = "weak"
    print(f"* Concordance Index: {c_index_value:.4f}. A value of {c_index_value:.4f} suggests {c_index_verdict} discriminatory power.")
else:
    print("* Concordance Index not available as model did not converge.")

print("\nNote on Diagnosis Interpretation:")
print("Since one-hot encoding was performed without dropping the first category, there is no single explicit reference diagnosis group.")
print("The coefficients for the grouped diagnosis categories represent the difference in the log-hazard compared to a baseline where all diagnosis dummy variables are zero (which doesn't correspond to a real group).")
print("For easier interpretation, you could re-run the one-hot encoding specifying one category as the reference by using `drop_first=True`.")

# Final check on the subtask completion
print("\nSubtask 'Fit a penalized Cox model including only the main effects...' is completed.")


# Lasso & Elastic Net (Simplified)
Lasso (L1, `l1_ratio=1.0`, `penalizer=0.1`) and Elastic Net (`l1_ratio=0.5`, `penalizer=0.1`)

## Prepare Data for Simplified Regularized Cox Models

In [None]:
import numpy as np

# Simplified design matrix: duration, event flag and the four treatment indicators.
# Copied so the edits below never reach survival_analysis_df.
simplified_cox_data = survival_analysis_df.loc[:, [
    'time_to_event',
    'event',
    'on_sglt2_inhibitor',
    'on_ace_inhibitor',
    'on_arbs',
    'on_statin',
]].copy()

# Cast boolean columns to 0/1 integers; every other column is left as it is.
for col in simplified_cox_data.columns:
    if simplified_cox_data[col].dtype != 'bool':
        continue
    simplified_cox_data[col] = simplified_cox_data[col].astype(int)

# Durations that are zero or negative are overwritten with a tiny positive constant.
zero_duration_mask = simplified_cox_data['time_to_event'].le(0)
if zero_duration_mask.any():
    epsilon = 1e-6
    simplified_cox_data.loc[zero_duration_mask, 'time_to_event'] = epsilon
    print(f"Added {epsilon} to {zero_duration_mask.sum()} zero or negative durations in 'time_to_event'.")

# Preview of the prepared frame.
print("Prepared simplified_cox_data DataFrame:")
display(simplified_cox_data.head())

# Column dtypes and non-null counts.
print("\nInfo of the simplified_cox_data DataFrame:")
simplified_cox_data.info()

In [None]:
from lifelines import CoxPHFitter

# --- Lasso ---------------------------------------------------------------
# l1_ratio=1.0 makes the penalty pure L1; penalizer=0.1 sets its strength.
print("\n--- Fitting Lasso (L1) Penalized Cox Model ---")
cph_lasso = CoxPHFitter(penalizer=0.1, l1_ratio=1.0)
try:
    cph_lasso.fit(
        simplified_cox_data,
        event_col='event',
        duration_col='time_to_event',
    )

    print("\nLasso Penalized Cox Model Summary:")
    cph_lasso.print_summary()

    # Concordance index as reported by the fitted model.
    c_index_lasso = cph_lasso.concordance_index_
    print(f"\nConcordance Index of the Lasso Cox model: {c_index_lasso:.4f}")

    print("\nCoefficients of Lasso Cox Model:")
    display(cph_lasso.params_)

except Exception as lasso_error:
    # Report and continue so the Elastic Net fit below still runs.
    print(f"An error occurred during Lasso Cox model fitting: {lasso_error}")


# --- Elastic Net ---------------------------------------------------------
# l1_ratio=0.5 mixes the L1 and L2 penalties equally at the same strength.
print("\n\n--- Fitting Elastic Net Penalized Cox Model ---")
cph_elastic_net = CoxPHFitter(penalizer=0.1, l1_ratio=0.5)
try:
    cph_elastic_net.fit(
        simplified_cox_data,
        event_col='event',
        duration_col='time_to_event',
    )

    print("\nElastic Net Penalized Cox Model Summary:")
    cph_elastic_net.print_summary()

    # Concordance index as reported by the fitted model.
    c_index_elastic_net = cph_elastic_net.concordance_index_
    print(f"\nConcordance Index of the Elastic Net Cox model: {c_index_elastic_net:.4f}")

    print("\nCoefficients of Elastic Net Cox Model:")
    display(cph_elastic_net.params_)

except Exception as elastic_net_error:
    print(f"An error occurred during Elastic Net Cox model fitting: {elastic_net_error}")

## Fit Lasso (L1) Penalized Cox Model (Simplified)

In [None]:
from lifelines import CoxPHFitter

# Pure-Lasso Cox model (l1_ratio=1.0, penalizer=0.1) on the simplified frame.
# This rebinds `cph_lasso` with the same configuration as the combined
# Lasso / Elastic Net cell earlier in the notebook.
cph_lasso = CoxPHFitter(l1_ratio=1.0, penalizer=0.1)

try:
    cph_lasso.fit(
        simplified_cox_data,
        event_col='event',
        duration_col='time_to_event',
    )

    # Coefficient table.
    print("\nLasso Penalized Cox Model Summary:")
    cph_lasso.print_summary()

    # Concordance index as reported by the fitted model.
    c_index_lasso = cph_lasso.concordance_index_
    print(f"\nConcordance Index of the Lasso Cox model: {c_index_lasso:.4f}")

    # Raw coefficients, rendered with rich display.
    print("\nCoefficients of Lasso Cox Model:")
    display(cph_lasso.params_)

except Exception as lasso_error:
    print(f"An error occurred during Lasso Cox model fitting: {lasso_error}")

## Fit Elastic Net Penalized Cox Model (Simplified)


In [None]:
from lifelines import CoxPHFitter

# Elastic Net Cox model (l1_ratio=0.5, penalizer=0.1) on the simplified frame.
# This rebinds `cph_elastic_net` with the same configuration as the combined
# Lasso / Elastic Net cell earlier in the notebook.
cph_elastic_net = CoxPHFitter(l1_ratio=0.5, penalizer=0.1)

try:
    cph_elastic_net.fit(
        simplified_cox_data,
        event_col='event',
        duration_col='time_to_event',
    )

    # Coefficient table.
    print("\nElastic Net Penalized Cox Model Summary:")
    cph_elastic_net.print_summary()

    # Concordance index as reported by the fitted model.
    c_index_elastic_net = cph_elastic_net.concordance_index_
    print(f"\nConcordance Index of the Elastic Net Cox model: {c_index_elastic_net:.4f}")

    # Raw coefficients, rendered with rich display.
    print("\nCoefficients of Elastic Net Cox Model:")
    display(cph_elastic_net.params_)

except Exception as elastic_net_error:
    print(f"An error occurred during Elastic Net Cox model fitting: {elastic_net_error}")

# Parametric
Weibull, Log-Logistic, and Log-Normal

## Prepare Simplified Data for Parametric Models


In [None]:
import numpy as np

# Treatment-only frame for the parametric (AFT) models: duration, event flag
# and the four treatment indicators, copied from survival_analysis_df.
simplified_aft_data = survival_analysis_df.loc[:, [
    'time_to_event',
    'event',
    'on_sglt2_inhibitor',
    'on_ace_inhibitor',
    'on_arbs',
    'on_statin',
]].copy()

# Boolean columns become 0/1 integers; everything else is left untouched.
for col in simplified_aft_data.columns:
    if simplified_aft_data[col].dtype != 'bool':
        continue
    simplified_aft_data[col] = simplified_aft_data[col].astype(int)

# Zero or negative durations are overwritten with machine epsilon.
zero_duration_mask = simplified_aft_data['time_to_event'].le(0)
if zero_duration_mask.any():
    epsilon = np.finfo(float).eps
    simplified_aft_data.loc[zero_duration_mask, 'time_to_event'] = epsilon
    print(f"Added {epsilon} to {zero_duration_mask.sum()} zero or negative durations in 'time_to_event'.")

# Preview of the prepared frame.
print("Prepared simplified_aft_data DataFrame:")
display(simplified_aft_data.head())

# Column dtypes and non-null counts.
print("\nInfo of the simplified_aft_data DataFrame:")
simplified_aft_data.info()

## Simplified Models


In [None]:
from lifelines import WeibullAFTFitter

# Weibull AFT model on the treatment-only frame, with a light penalty (0.01).
weibull_aft = WeibullAFTFitter(penalizer=0.01)

try:
    weibull_aft.fit(
        simplified_aft_data,
        event_col='event',
        duration_col='time_to_event',
    )

    # Coefficient table of the fitted model.
    print("\nWeibull AFT Model Summary (Simplified):")
    weibull_aft.print_summary()

    # Concordance index as reported by the fitter.
    c_index_weibull = weibull_aft.concordance_index_
    print(f"\nConcordance Index of the Weibull AFT model: {c_index_weibull:.4f}")

except Exception as fit_error:
    # Report and continue so later cells can still run.
    print(f"An error occurred during Weibull AFT model fitting: {fit_error}")

In [None]:
from lifelines import LogLogisticAFTFitter

# Log-Logistic AFT model on the treatment-only frame, with a light penalty (0.01).
loglogistic_aft = LogLogisticAFTFitter(penalizer=0.01)

try:
    loglogistic_aft.fit(
        simplified_aft_data,
        event_col='event',
        duration_col='time_to_event',
    )

    # Coefficient table of the fitted model.
    print("\nLog-Logistic AFT Model Summary (Simplified):")
    loglogistic_aft.print_summary()

    # Concordance index as reported by the fitter.
    c_index_loglogistic = loglogistic_aft.concordance_index_
    print(f"\nConcordance Index of the Log-Logistic AFT model: {c_index_loglogistic:.4f}")

except Exception as fit_error:
    # Report and continue so later cells can still run.
    print(f"An error occurred during Log-Logistic AFT model fitting: {fit_error}")

In [None]:
from lifelines import LogNormalAFTFitter

# Log-Normal AFT model on the treatment-only frame, with a light penalty (0.01).
lognormal_aft = LogNormalAFTFitter(penalizer=0.01)

try:
    lognormal_aft.fit(
        simplified_aft_data,
        event_col='event',
        duration_col='time_to_event',
    )

    # Coefficient table of the fitted model.
    print("\nLog-Normal AFT Model Summary (Simplified):")
    lognormal_aft.print_summary()

    # Concordance index as reported by the fitter.
    c_index_lognormal = lognormal_aft.concordance_index_
    print(f"\nConcordance Index of the Log-Normal AFT model: {c_index_lognormal:.4f}")

except Exception as fit_error:
    # Report and continue so later cells can still run.
    print(f"An error occurred during Log-Normal AFT model fitting: {fit_error}")

## Comprehensive Models


In [None]:
import numpy as np

# Identifier / date columns that must not enter the model as covariates.
columns_to_drop_identifiers = [
    'subject_id',
    'treatment_starttime',
    'dialysis_event_date'
]

# Start from every column of survival_analysis_df except the identifiers;
# drop() returns a new frame, so the source DataFrame is not modified.
comprehensive_aft_data = survival_analysis_df.drop(columns=columns_to_drop_identifiers)

# Interaction terms previously excluded from the penalized Cox model because of
# complete separation or extremely low variance.
problematic_interaction_columns = [
    'grouped_diagnosis_Hypertensive CKD (Stages 5-End Stage)_x_on_sglt2_inhibitor',
    'grouped_diagnosis_Stage 1 CKD_x_on_sglt2_inhibitor',
    'grouped_diagnosis_Stage 5 CKD_x_on_sglt2_inhibitor'
]

# Only drop the ones actually present, so a missing column does not raise.
existing_problematic_cols = [col for col in problematic_interaction_columns if col in comprehensive_aft_data.columns]
if not existing_problematic_cols:
    print("No specified problematic columns found to drop.")
else:
    comprehensive_aft_data = comprehensive_aft_data.drop(columns=existing_problematic_cols)
    print(f"Dropped problematic columns: {existing_problematic_cols}")

# Boolean columns become 0/1 integers; everything else is left untouched.
for col in comprehensive_aft_data.columns:
    if comprehensive_aft_data[col].dtype != 'bool':
        continue
    comprehensive_aft_data[col] = comprehensive_aft_data[col].astype(int)

# Zero or negative durations are overwritten with machine epsilon.
zero_duration_mask = comprehensive_aft_data['time_to_event'].le(0)
if zero_duration_mask.any():
    epsilon = np.finfo(float).eps
    comprehensive_aft_data.loc[zero_duration_mask, 'time_to_event'] = epsilon
    print(f"Added {epsilon} to {zero_duration_mask.sum()} zero or negative durations in 'time_to_event'.")

# Preview of the prepared frame.
print("\nPrepared comprehensive_aft_data DataFrame:")
display(comprehensive_aft_data.head())

# Column dtypes and non-null counts.
print("\nInfo of the comprehensive_aft_data DataFrame:")
comprehensive_aft_data.info()

In [None]:
from lifelines import WeibullAFTFitter

# Unpenalized Weibull AFT model on the full comprehensive frame.
weibull_aft_comprehensive = WeibullAFTFitter()

try:
    weibull_aft_comprehensive.fit(
        comprehensive_aft_data,
        event_col='event',
        duration_col='time_to_event',
    )

    # Coefficient table of the fitted model.
    print("\nWeibull AFT Model Summary (Comprehensive):")
    weibull_aft_comprehensive.print_summary()

    # Concordance index as reported by the fitter.
    c_index_weibull_comprehensive = weibull_aft_comprehensive.concordance_index_
    print(f"\nConcordance Index of the Weibull AFT model (Comprehensive): {c_index_weibull_comprehensive:.4f}")

except Exception as fit_error:
    # Report and continue so later cells can still run.
    print(f"An error occurred during Weibull AFT model fitting with comprehensive data: {fit_error}")

In [None]:
from lifelines import WeibullAFTFitter

# Second attempt at the comprehensive Weibull AFT model: drop one more
# interaction column, rescale the duration, and add a small penalizer.

# Identify additional problematic columns (from the ConvergenceWarning)
additional_problematic_cols = [
    'grouped_diagnosis_Diabetes with CKD_x_on_sglt2_inhibitor'
]

# Remove these additional problematic columns from comprehensive_aft_data.
# Only columns that are actually present are dropped, so re-running the cell is safe.
existing_additional_problematic_cols = [col for col in additional_problematic_cols if col in comprehensive_aft_data.columns]
if existing_additional_problematic_cols:
    comprehensive_aft_data = comprehensive_aft_data.drop(columns=existing_additional_problematic_cols).copy()
    print(f"Dropped additional problematic columns: {existing_additional_problematic_cols}")
else:
    print("No additional specified problematic columns found to drop.")

# Scale the duration vector down (divide by 100) to improve convergence as suggested by the error message.
# NOTE: this adds 'time_to_event_scaled' to comprehensive_aft_data IN PLACE (no copy is made).
# Any fit on this frame must therefore drop either 'time_to_event' or 'time_to_event_scaled',
# otherwise the remaining duration column is treated as a covariate. Re-running the earlier
# unpenalized Weibull cell after this one would hit exactly that problem.
comprehensive_aft_data['time_to_event_scaled'] = comprehensive_aft_data['time_to_event'] / 100.0

# 1. Instantiate a WeibullAFTFitter object with a small penalizer
# A penalizer can help with convergence issues, similar to regularization in Cox models.
weibull_aft_comprehensive = WeibullAFTFitter(penalizer=0.01) # Using a small penalizer

# 2. Fit the model to the comprehensive_aft_data DataFrame using the scaled duration
# The unscaled duration is dropped for the fit so it is not used as a predictor.
try:
    weibull_aft_comprehensive.fit(comprehensive_aft_data.drop(columns=['time_to_event']),
                                  duration_col='time_to_event_scaled',
                                  event_col='event')

    # 3. Print the summary of the fitted Weibull AFT model
    print("\nWeibull AFT Model Summary (Comprehensive, Scaled, Penalized):")
    weibull_aft_comprehensive.print_summary()

    # 4. Calculate and print the concordance index
    c_index_weibull_comprehensive = weibull_aft_comprehensive.concordance_index_
    print(f"\nConcordance Index of the Weibull AFT model (Comprehensive, Scaled, Penalized): {c_index_weibull_comprehensive:.4f}")

except Exception as e:
    # Best-effort: report the failure and let the notebook continue.
    print(f"An error occurred during Weibull AFT model fitting with comprehensive data: {e}")


In [None]:
from lifelines import LogLogisticAFTFitter

# Penalized (0.01) Log-Logistic AFT model on the comprehensive frame.
loglogistic_aft_comprehensive = LogLogisticAFTFitter(penalizer=0.01)

try:
    # Fit on the scaled duration; the unscaled column is removed so it is not
    # picked up as a covariate.
    loglogistic_aft_comprehensive.fit(
        comprehensive_aft_data.drop(columns=['time_to_event']),
        event_col='event',
        duration_col='time_to_event_scaled',
    )

    # Coefficient table of the fitted model.
    print("\nLog-Logistic AFT Model Summary (Comprehensive, Scaled, Penalized):")
    loglogistic_aft_comprehensive.print_summary()

    # Concordance index as reported by the fitter.
    c_index_loglogistic_comprehensive = loglogistic_aft_comprehensive.concordance_index_
    print(f"\nConcordance Index of the Log-Logistic AFT model (Comprehensive, Scaled, Penalized): {c_index_loglogistic_comprehensive:.4f}")

except Exception as fit_error:
    # Report and continue so later cells can still run.
    print(f"An error occurred during Log-Logistic AFT model fitting with comprehensive data: {fit_error}")

In [None]:
from lifelines import LogNormalAFTFitter

# Penalized (0.01) Log-Normal AFT model on the comprehensive frame.
lognormal_aft_comprehensive = LogNormalAFTFitter(penalizer=0.01)

try:
    # Fit on the scaled duration; the unscaled column is removed so it is not
    # picked up as a covariate.
    lognormal_aft_comprehensive.fit(
        comprehensive_aft_data.drop(columns=['time_to_event']),
        event_col='event',
        duration_col='time_to_event_scaled',
    )

    # Coefficient table of the fitted model.
    print("\nLog-Normal AFT Model Summary (Comprehensive, Scaled, Penalized):")
    lognormal_aft_comprehensive.print_summary()

    # Concordance index as reported by the fitter.
    c_index_lognormal_comprehensive = lognormal_aft_comprehensive.concordance_index_
    print(f"\nConcordance Index of the Log-Normal AFT model (Comprehensive, Scaled, Penalized): {c_index_lognormal_comprehensive:.4f}")

except Exception as fit_error:
    # Report and continue so later cells can still run.
    print(f"An error occurred during Log-Normal AFT model fitting with comprehensive data: {fit_error}")

### Comprehensive Log-Logistic AFT Model: Coefficient Interpretation


In [None]:
# Re-print the fitted model's table and concordance index, then walk through the
# coefficients group by group.
print("\nSummary for Log-Logistic AFT Model (Comprehensive, Scaled, Penalized):\n")
loglogistic_aft_comprehensive.print_summary()

print(f"\nConcordance Index of the Log-Logistic AFT model (Comprehensive, Scaled, Penalized): {loglogistic_aft_comprehensive.concordance_index_:.4f}")

# Summary table; rows are addressed below as (parameter, covariate) pairs.
loglogistic_summary_df = loglogistic_aft_comprehensive.summary

print("\nInterpretation of Log-Logistic AFT Model Coefficients:")
print("For AFT models, exp(coef) > 1 means a longer expected event time (protective effect), and exp(coef) < 1 means a shorter expected event time (increased risk).")
print("Statistical significance is generally indicated by p < 0.05.")

# --- Treatment (GRT) main effects ---
grt_main_effects = ['on_sglt2_inhibitor', 'on_ace_inhibitor', 'on_arbs', 'on_statin']
print("\n--- Main Effects of CKD GRT Categories ---")
for effect in grt_main_effects:
    try:
        # Covariate rows are looked up under the 'alpha_' parameter.
        row = loglogistic_summary_df.loc[('alpha_', effect)]
        exp_coef = row['exp(coef)']
        p_value = row['p']
        interpretation = (
            f"Statistically significant. An {exp_coef:.2f}x {'longer' if exp_coef > 1 else 'shorter'} expected time to dialysis."
            if p_value < 0.05
            else "Not statistically significant."
        )
        print(f"* {effect}: exp(coef) = {exp_coef:.2f}, p = {p_value:.3f}. {interpretation}")
    except KeyError:
        print(f"* {effect}: Not found in summary or potentially dropped (e.g., due to multicollinearity or a common reference category).")

# --- Number of GRT categories ---
print("\n--- Main Effect of Number of CKD GRT Categories ---")
try:
    row = loglogistic_summary_df.loc[('alpha_', 'num_ckd_grt_categories')]
    exp_coef = row['exp(coef)']
    p_value = row['p']
    interpretation = (
        f"Statistically significant. An {exp_coef:.2f}x {'longer' if exp_coef > 1 else 'shorter'} expected time to dialysis for each additional GRT category."
        if p_value < 0.05
        else "Not statistically significant."
    )
    print(f"* num_ckd_grt_categories: exp(coef) = {exp_coef:.2f}, p = {p_value:.3f}. {interpretation}")
except KeyError:
    print(f"* num_ckd_grt_categories: Not found in summary or potentially dropped.")

# --- Grouped diagnosis main effects ---
# Diagnosis dummies that are not interaction terms ('_x_') and have an 'alpha_' row.
print("\n--- Main Effects of Grouped CKD Diagnosis Categories ---")
diagnosis_main_effects = [
    col
    for col in loglogistic_summary_df.index.get_level_values(1)
    if col.startswith('grouped_diagnosis_')
    and '_x_' not in col
    and ('alpha_', col) in loglogistic_summary_df.index
]
for effect in diagnosis_main_effects:
    try:
        row = loglogistic_summary_df.loc[('alpha_', effect)]
        exp_coef = row['exp(coef)']
        p_value = row['p']
        interpretation = (
            f"Statistically significant. An {exp_coef:.2f}x {'longer' if exp_coef > 1 else 'shorter'} expected time to dialysis compared to baseline (or other categories in one-hot encoding without `drop_first`)."
            if p_value < 0.05
            else "Not statistically significant."
        )
        print(f"* {effect}: exp(coef) = {exp_coef:.2f}, p = {p_value:.3f}. {interpretation}")
    except KeyError:
        print(f"* {effect}: Not found in summary or potentially dropped.")

# --- Diagnosis x treatment interaction terms ---
print("\n--- Interaction Terms: CKD Diagnosis Groups x GRT Categories ---")
interaction_terms = [
    col
    for col in loglogistic_summary_df.index.get_level_values(1)
    if '_x_' in col and ('alpha_', col) in loglogistic_summary_df.index
]
for term in interaction_terms:
    try:
        row = loglogistic_summary_df.loc[('alpha_', term)]
        exp_coef = row['exp(coef)']
        p_value = row['p']
        interpretation = (
            f"Statistically significant. Indicates that the effect of the GRT category on time to dialysis significantly changes within this specific CKD diagnosis group. Acceleration factor = {exp_coef:.2f}."
            if p_value < 0.05
            else "Not statistically significant."
        )
        print(f"* {term}: exp(coef) = {exp_coef:.2f}, p = {p_value:.3f}. {interpretation}")
    except KeyError:
        print(f"* {term}: Not found in summary or potentially dropped.")

# Optuna Hyperparameter

Install the Optuna library, then define and run Optuna studies to tune the hyperparameters of the comprehensive and main-effects penalized Cox proportional hazards models and of the three comprehensive parametric (AFT) models. Then re-train each model with its best hyperparameters. Finally, fit an XGBoost Cox model to look for a better-performing model in which more predictors are statistically significant.

## Install necessary libraries

In [None]:
!pip install optuna
!pip install xgboost

## Define Objective Functions for Optuna Studies


In [None]:
# Survival-model fitters used by the Optuna objective functions and the
# retraining cell: CoxPHFitter for the penalized Cox models and the three
# AFT fitters for the parametric models.
from lifelines import CoxPHFitter, WeibullAFTFitter, LogLogisticAFTFitter, LogNormalAFTFitter

print("Lifelines Fitter classes imported.")

In [None]:
import optuna
import numpy as np

# Objective function for Comprehensive Penalized CoxPH
def objective_cox_comprehensive(trial):
    """Optuna objective: C-index of a penalized CoxPH model fitted on ``cox_data_final``.

    Parameters
    ----------
    trial : optuna.trial.Trial
        Supplies ``penalizer`` (log-uniform in [1e-3, 1.0]) and ``l1_ratio``
        (uniform in [0.0, 1.0]).

    Returns
    -------
    float
        The fitted model's ``concordance_index_``, or 0.0 when fitting fails.

    Notes
    -----
    The model is scored on the same frame it was fitted on; no held-out data
    is involved. NOTE(review): an in-sample score tends to favour the weakest
    penalty - consider cross-validated scoring.
    """
    import warnings

    penalizer = trial.suggest_float('penalizer', 1e-3, 1.0, log=True)
    l1_ratio = trial.suggest_float('l1_ratio', 0.0, 1.0)

    cph = CoxPHFitter(penalizer=penalizer, l1_ratio=l1_ratio)
    try:
        cph.fit(cox_data_final, duration_col='time_to_event', event_col='event')
        return cph.concordance_index_
    except Exception as exc:
        # Keep the study running with the worst score, but make the failure
        # visible instead of silently hiding it (e.g. a misspelled column).
        warnings.warn(
            f"Comprehensive CoxPH trial {trial.number} failed and was scored 0.0: {exc!r}",
            RuntimeWarning,
        )
        trial.set_user_attr('fit_error', repr(exc))
        return 0.0

# Objective function for Main-Effects Penalized CoxPH
def objective_cox_main_effects(trial):
    """Optuna objective: C-index of a penalized CoxPH model fitted on ``cox_data_main_effects_cleaned``.

    Parameters
    ----------
    trial : optuna.trial.Trial
        Supplies ``penalizer`` (log-uniform in [1e-3, 1.0]) and ``l1_ratio``
        (uniform in [0.0, 1.0]).

    Returns
    -------
    float
        The fitted model's ``concordance_index_``, or 0.0 when fitting fails.

    Notes
    -----
    The model is scored on the same frame it was fitted on; no held-out data
    is involved. NOTE(review): an in-sample score tends to favour the weakest
    penalty - consider cross-validated scoring.
    """
    import warnings

    penalizer = trial.suggest_float('penalizer', 1e-3, 1.0, log=True)
    l1_ratio = trial.suggest_float('l1_ratio', 0.0, 1.0)

    cph = CoxPHFitter(penalizer=penalizer, l1_ratio=l1_ratio)
    try:
        cph.fit(cox_data_main_effects_cleaned, duration_col='time_to_event', event_col='event')
        return cph.concordance_index_
    except Exception as exc:
        # Keep the study running with the worst score, but make the failure
        # visible instead of silently hiding it.
        warnings.warn(
            f"Main-effects CoxPH trial {trial.number} failed and was scored 0.0: {exc!r}",
            RuntimeWarning,
        )
        trial.set_user_attr('fit_error', repr(exc))
        return 0.0

# Objective function for Comprehensive Weibull AFT
def objective_weibull_comprehensive(trial):
    """Optuna objective: C-index of a penalized Weibull AFT model.

    The model is fitted on ``comprehensive_aft_data`` with the unscaled
    ``time_to_event`` column dropped, using ``time_to_event_scaled`` as the
    duration and ``event`` as the event indicator.

    Parameters
    ----------
    trial : optuna.trial.Trial
        Supplies ``penalizer`` (log-uniform in [1e-3, 1.0]).

    Returns
    -------
    float
        The fitted model's ``concordance_index_``, or 0.0 when fitting fails.

    Notes
    -----
    The model is scored on the same frame it was fitted on.
    NOTE(review): consider cross-validated scoring.
    """
    import warnings

    penalizer = trial.suggest_float('penalizer', 1e-3, 1.0, log=True)

    weibull_aft = WeibullAFTFitter(penalizer=penalizer)
    try:
        weibull_aft.fit(comprehensive_aft_data.drop(columns=['time_to_event']), duration_col='time_to_event_scaled', event_col='event')
        return weibull_aft.concordance_index_
    except Exception as exc:
        # Keep the study running with the worst score, but make the failure
        # visible instead of silently hiding it.
        warnings.warn(
            f"Weibull AFT trial {trial.number} failed and was scored 0.0: {exc!r}",
            RuntimeWarning,
        )
        trial.set_user_attr('fit_error', repr(exc))
        return 0.0

# Objective function for Comprehensive Log-Logistic AFT
def objective_loglogistic_comprehensive(trial):
    """Optuna objective: C-index of a penalized Log-Logistic AFT model.

    The model is fitted on ``comprehensive_aft_data`` with the unscaled
    ``time_to_event`` column dropped, using ``time_to_event_scaled`` as the
    duration and ``event`` as the event indicator.

    Parameters
    ----------
    trial : optuna.trial.Trial
        Supplies ``penalizer`` (log-uniform in [1e-3, 1.0]).

    Returns
    -------
    float
        The fitted model's ``concordance_index_``, or 0.0 when fitting fails.

    Notes
    -----
    The model is scored on the same frame it was fitted on.
    NOTE(review): consider cross-validated scoring.
    """
    import warnings

    penalizer = trial.suggest_float('penalizer', 1e-3, 1.0, log=True)

    loglogistic_aft = LogLogisticAFTFitter(penalizer=penalizer)
    try:
        loglogistic_aft.fit(comprehensive_aft_data.drop(columns=['time_to_event']), duration_col='time_to_event_scaled', event_col='event')
        return loglogistic_aft.concordance_index_
    except Exception as exc:
        # Keep the study running with the worst score, but make the failure
        # visible instead of silently hiding it.
        warnings.warn(
            f"Log-Logistic AFT trial {trial.number} failed and was scored 0.0: {exc!r}",
            RuntimeWarning,
        )
        trial.set_user_attr('fit_error', repr(exc))
        return 0.0

# Objective function for Comprehensive Log-Normal AFT
def objective_lognormal_comprehensive(trial):
    """Optuna objective: C-index of a penalized Log-Normal AFT model.

    The model is fitted on ``comprehensive_aft_data`` with the unscaled
    ``time_to_event`` column dropped, using ``time_to_event_scaled`` as the
    duration and ``event`` as the event indicator.

    Parameters
    ----------
    trial : optuna.trial.Trial
        Supplies ``penalizer`` (log-uniform in [1e-3, 1.0]).

    Returns
    -------
    float
        The fitted model's ``concordance_index_``, or 0.0 when fitting fails.

    Notes
    -----
    The model is scored on the same frame it was fitted on.
    NOTE(review): consider cross-validated scoring.
    """
    import warnings

    penalizer = trial.suggest_float('penalizer', 1e-3, 1.0, log=True)

    lognormal_aft = LogNormalAFTFitter(penalizer=penalizer)
    try:
        lognormal_aft.fit(comprehensive_aft_data.drop(columns=['time_to_event']), duration_col='time_to_event_scaled', event_col='event')
        return lognormal_aft.concordance_index_
    except Exception as exc:
        # Keep the study running with the worst score, but make the failure
        # visible instead of silently hiding it.
        warnings.warn(
            f"Log-Normal AFT trial {trial.number} failed and was scored 0.0: {exc!r}",
            RuntimeWarning,
        )
        trial.set_user_attr('fit_error', repr(exc))
        return 0.0

print("Objective functions for Optuna defined.")


## Run Optuna Studies for Each Model



In [None]:
# Fixed sampler seed so that Restart & Run All reproduces the same trials,
# and a single place to change the trial budget.
OPTUNA_SEED = 42
N_OPTUNA_TRIALS = 50

# One entry per model: (key in study_results, label used in the banner,
# label used in the result lines, objective function).
_study_specs = [
    ('cox_comprehensive', 'Comprehensive Penalized CoxPH', 'Comprehensive CoxPH', objective_cox_comprehensive),
    ('cox_main_effects', 'Main-Effects Penalized CoxPH', 'Main-Effects CoxPH', objective_cox_main_effects),
    ('weibull_comprehensive', 'Comprehensive Weibull AFT', 'Comprehensive Weibull AFT', objective_weibull_comprehensive),
    ('loglogistic_comprehensive', 'Comprehensive Log-Logistic AFT', 'Comprehensive Log-Logistic AFT', objective_loglogistic_comprehensive),
    ('lognormal_comprehensive', 'Comprehensive Log-Normal AFT', 'Comprehensive Log-Normal AFT', objective_lognormal_comprehensive),
]

study_results = {}

for _key, _banner_label, _result_label, _objective in _study_specs:
    print(f"\n--- Running Optuna study for {_banner_label} ---")
    _study = optuna.create_study(
        direction='maximize',
        sampler=optuna.samplers.TPESampler(seed=OPTUNA_SEED),
    )
    _study.optimize(_objective, n_trials=N_OPTUNA_TRIALS, show_progress_bar=True)
    study_results[_key] = _study
    print(f"Best trial for {_result_label} (C-index): {_study.best_value:.4f}")
    print(f"Best hyperparameters for {_result_label}: {_study.best_params}")

# Keep the per-model study names available for any cell that refers to them.
study_cox_comprehensive = study_results['cox_comprehensive']
study_cox_main_effects = study_results['cox_main_effects']
study_weibull_comprehensive = study_results['weibull_comprehensive']
study_loglogistic_comprehensive = study_results['loglogistic_comprehensive']
study_lognormal_comprehensive = study_results['lognormal_comprehensive']

print("\nOptuna studies completed for all models.")

In [None]:
def _refit_and_report(fitter, load_frame, duration_col, summary_title, error_label):
    """Fit ``fitter`` and print its summary and C-index; on failure print the error.

    ``load_frame`` is a zero-argument callable so that building the training
    frame happens inside the same ``try`` as the fit itself.
    """
    try:
        fitter.fit(load_frame(), duration_col=duration_col, event_col='event')
        print(summary_title)
        fitter.print_summary()
        print(f"Concordance Index: {fitter.concordance_index_:.4f}")
    except Exception as e:
        print(f"Error retraining {error_label}: {e}")


# Winning hyperparameters of each finished Optuna study
(
    best_params_cox_comprehensive,
    best_params_cox_main_effects,
    best_params_weibull_comprehensive,
    best_params_loglogistic_comprehensive,
    best_params_lognormal_comprehensive,
) = (
    study_results[study_key].best_params
    for study_key in (
        'cox_comprehensive',
        'cox_main_effects',
        'weibull_comprehensive',
        'loglogistic_comprehensive',
        'lognormal_comprehensive',
    )
)

# The AFT models are fitted on the scaled duration, so the unscaled column is dropped.
_load_aft_frame = lambda: comprehensive_aft_data.drop(columns=['time_to_event'])

# --- Comprehensive penalized CoxPH --- #
print("\n--- Retraining Comprehensive Penalized CoxPH Model with Optimized Hyperparameters ---")
cph_comprehensive_optimized = CoxPHFitter(
    penalizer=best_params_cox_comprehensive['penalizer'],
    l1_ratio=best_params_cox_comprehensive['l1_ratio'],
)
_refit_and_report(
    cph_comprehensive_optimized,
    lambda: cox_data_final,
    'time_to_event',
    "Comprehensive Penalized CoxPH Model Summary (Optimized):",
    "Comprehensive CoxPH",
)

# --- Main-effects penalized CoxPH --- #
print("\n--- Retraining Main-Effects Penalized CoxPH Model with Optimized Hyperparameters ---")
cph_main_effects_optimized = CoxPHFitter(
    penalizer=best_params_cox_main_effects['penalizer'],
    l1_ratio=best_params_cox_main_effects['l1_ratio'],
)
_refit_and_report(
    cph_main_effects_optimized,
    lambda: cox_data_main_effects_cleaned,
    'time_to_event',
    "Main-Effects Penalized CoxPH Model Summary (Optimized):",
    "Main-Effects CoxPH",
)

# --- Comprehensive Weibull AFT --- #
print("\n--- Retraining Comprehensive Weibull AFT Model with Optimized Hyperparameters ---")
weibull_aft_comprehensive_optimized = WeibullAFTFitter(
    penalizer=best_params_weibull_comprehensive['penalizer']
)
_refit_and_report(
    weibull_aft_comprehensive_optimized,
    _load_aft_frame,
    'time_to_event_scaled',
    "Weibull AFT Model Summary (Optimized):",
    "Comprehensive Weibull AFT",
)

# --- Comprehensive Log-Logistic AFT --- #
print("\n--- Retraining Comprehensive Log-Logistic AFT Model with Optimized Hyperparameters ---")
loglogistic_aft_comprehensive_optimized = LogLogisticAFTFitter(
    penalizer=best_params_loglogistic_comprehensive['penalizer']
)
_refit_and_report(
    loglogistic_aft_comprehensive_optimized,
    _load_aft_frame,
    'time_to_event_scaled',
    "Log-Logistic AFT Model Summary (Optimized):",
    "Comprehensive Log-Logistic AFT",
)

# --- Comprehensive Log-Normal AFT --- #
print("\n--- Retraining Comprehensive Log-Normal AFT Model with Optimized Hyperparameters ---")
lognormal_aft_comprehensive_optimized = LogNormalAFTFitter(
    penalizer=best_params_lognormal_comprehensive['penalizer']
)
_refit_and_report(
    lognormal_aft_comprehensive_optimized,
    _load_aft_frame,
    'time_to_event_scaled',
    "Log-Normal AFT Model Summary (Optimized):",
    "Comprehensive Log-Normal AFT",
)

print("All models retrained with optimized hyperparameters.")

In [None]:
import xgboost as xgb
import pandas as pd
import numpy as np

# Feature matrix for the XGBoost Cox model: every column of the Cox
# modelling frame except the two outcome columns.
X_xgb = cox_data_final.drop(columns=['time_to_event', 'event'])

# Outcome as a structured array: field 'f0' holds the duration (float64),
# field 'f1' holds the event flag (bool).
outcome_pairs = list(
    cox_data_final[['time_to_event', 'event']].itertuples(index=False, name=None)
)
y_xgb = np.array(outcome_pairs, dtype=[('f0', '<f8'), ('f1', '?')])

print("Features (X_xgb) head:")
display(X_xgb.head())

print("\nTarget (y_xgb) first 5 rows:")
print(y_xgb[:5])

# Both shapes are shown so a row-count mismatch is easy to spot.
print(f"\nShape of X_xgb: {X_xgb.shape}")
print(f"Shape of y_xgb: {y_xgb.shape}")


In [None]:
import optuna
import xgboost as xgb
from sklearn.model_selection import KFold
from lifelines.utils import concordance_index

# Objective function for XGBoost Cox model
def objective_xgboost_cox(trial):
    """Optuna objective: mean 3-fold cross-validated C-index of an XGBoost Cox model.

    Parameters
    ----------
    trial : optuna.trial.Trial
        Supplies ``eta``, ``max_depth``, ``subsample``, ``colsample_bytree``,
        ``lambda``, ``alpha`` and ``n_estimators`` (number of boosting rounds).

    Returns
    -------
    float
        Mean C-index over the held-out folds of ``X_xgb`` / ``y_xgb``, or 0.0
        as soon as any fold fails to train or score.
    """
    import warnings

    params = {
        'objective': 'survival:cox',
        'eval_metric': 'cox-nloglik',
        'eta': trial.suggest_float('eta', 1e-3, 0.1, log=True),
        'max_depth': trial.suggest_int('max_depth', 3, 10),
        'subsample': trial.suggest_float('subsample', 0.5, 1.0),
        'colsample_bytree': trial.suggest_float('colsample_bytree', 0.5, 1.0),
        'lambda': trial.suggest_float('lambda', 1e-2, 1.0, log=True),
        'alpha': trial.suggest_float('alpha', 1e-2, 1.0, log=True),
        'tree_method': 'hist',
        'n_jobs': -1
    }

    n_estimators = trial.suggest_int('n_estimators', 50, 500)

    # Fixed random_state: every trial is scored on the same three splits.
    kf = KFold(n_splits=3, shuffle=True, random_state=42)
    c_indices = []

    for fold_number, (train_index, test_index) in enumerate(kf.split(X_xgb), start=1):
        X_train, X_test = X_xgb.iloc[train_index], X_xgb.iloc[test_index]
        y_train, y_test = y_xgb[train_index], y_xgb[test_index]

        # Label encoding used with survival:cox here: rows with an observed
        # event keep their duration, censored rows get the negated duration.
        dtrain = xgb.DMatrix(X_train, label=np.where(y_train['f1'] == 1, y_train['f0'], -y_train['f0']))
        dtest = xgb.DMatrix(X_test)

        try:
            model = xgb.train(params, dtrain, num_boost_round=n_estimators, verbose_eval=False)

            # Predict risk scores on the held-out fold
            test_risk_scores = model.predict(dtest)

            # The predictions are negated before being passed to lifelines'
            # concordance_index together with the durations and event flags.
            c_index = concordance_index(y_test['f0'], -test_risk_scores, y_test['f1'])
            c_indices.append(c_index)
        except Exception as exc:
            # Keep the study running with the worst score, but make the
            # failure visible instead of silently hiding it.
            warnings.warn(
                f"XGBoost Cox trial {trial.number} failed on fold {fold_number} "
                f"and was scored 0.0: {exc!r}",
                RuntimeWarning,
            )
            trial.set_user_attr('fit_error', repr(exc))
            return 0.0

    return np.mean(c_indices)

print("Objective function for XGBoost Cox model defined.")


In [None]:
print("\n--- Running Optuna study for XGBoost Cox Model (Corrected) ---")
# Seed the sampler so that re-running the notebook reproduces the same trials.
study_xgboost_cox = optuna.create_study(
    direction='maximize',
    sampler=optuna.samplers.TPESampler(seed=42),
)
study_xgboost_cox.optimize(objective_xgboost_cox, n_trials=50, show_progress_bar=True)
# Stored next to the lifelines studies; study_results must already exist.
study_results['xgboost_cox'] = study_xgboost_cox
print(f"Best trial for XGBoost Cox Model (C-index): {study_xgboost_cox.best_value:.4f}")
print(f"Best hyperparameters for XGBoost Cox Model: {study_xgboost_cox.best_params}")

print("Optuna study for XGBoost Cox model completed (Corrected).")

In [None]:
import xgboost as xgb
import numpy as np
from lifelines.utils import concordance_index

# Winning hyperparameters of the XGBoost Cox Optuna study
best_params_xgboost_cox = study_results['xgboost_cox'].best_params

# --- Retrain XGBoost Cox Model with Optimized Hyperparameters ---
print("\n--- Retraining XGBoost Cox Model with Optimized Hyperparameters ---")

# Booster parameters copied from the study; n_estimators is handled
# separately because it is passed to xgb.train as num_boost_round.
_tuned_booster_keys = ('eta', 'max_depth', 'subsample', 'colsample_bytree', 'lambda', 'alpha')
xgb_params_optimized = {
    'objective': 'survival:cox',
    'eval_metric': 'cox-nloglik',
    **{key: best_params_xgboost_cox[key] for key in _tuned_booster_keys},
    'tree_method': 'hist',
    'n_jobs': -1,
}

n_estimators_optimized = best_params_xgboost_cox['n_estimators']

# Full-dataset DMatrix: observed events keep their duration as the label,
# censored rows get the negated duration.
_event_observed = y_xgb['f1'] == 1
_signed_durations = np.where(_event_observed, y_xgb['f0'], -y_xgb['f0'])
dmatrix_full = xgb.DMatrix(X_xgb, label=_signed_durations)

# Final model, trained on every row
xgboost_cox_optimized = xgb.train(
    xgb_params_optimized,
    dmatrix_full,
    num_boost_round=n_estimators_optimized,
    verbose_eval=False,
)

# Risk scores for the same rows the model was trained on
final_risk_scores_xgb = xgboost_cox_optimized.predict(xgb.DMatrix(X_xgb))

# C-index via lifelines, with the risk scores negated as in the Optuna objective
c_index_xgboost_optimized = concordance_index(y_xgb['f0'], -final_risk_scores_xgb, y_xgb['f1'])

print(f"XGBoost Cox Model (Optimized) - Concordance Index: {c_index_xgboost_optimized:.4f}")

def _c_index_or_nan(model):
    """Return ``model.concordance_index_``, or NaN when the attribute is unavailable.

    The retraining cell only prints an error when a fit fails, so a model
    object may exist without ever having been fitted. Presumably such a
    model raises AttributeError here - NaN keeps this summary usable.
    """
    return getattr(model, 'concordance_index_', float('nan'))


# Store each optimized model together with its C-index for the summary.
# NOTE(review): the XGBoost value is computed on the rows the final model was
# trained on, so all figures in this table appear to be in-sample.
optimized_models = {
    'Comprehensive Penalized CoxPH': {'model': cph_comprehensive_optimized, 'c_index': _c_index_or_nan(cph_comprehensive_optimized)},
    'Main-Effects Penalized CoxPH': {'model': cph_main_effects_optimized, 'c_index': _c_index_or_nan(cph_main_effects_optimized)},
    'Comprehensive Weibull AFT': {'model': weibull_aft_comprehensive_optimized, 'c_index': _c_index_or_nan(weibull_aft_comprehensive_optimized)},
    'Comprehensive Log-Logistic AFT': {'model': loglogistic_aft_comprehensive_optimized, 'c_index': _c_index_or_nan(loglogistic_aft_comprehensive_optimized)},
    'Comprehensive Log-Normal AFT': {'model': lognormal_aft_comprehensive_optimized, 'c_index': _c_index_or_nan(lognormal_aft_comprehensive_optimized)},
    'XGBoost Cox Model': {'model': xgboost_cox_optimized, 'c_index': c_index_xgboost_optimized}
}

print("\n--- Comparative Summary of Optimized Models (C-index) ---")
performance_summary = pd.DataFrame({
    'Model': list(optimized_models),
    'C-index': [model_info['c_index'] for model_info in optimized_models.values()]
})
# Descending sort; NaN scores (models that were not fitted) end up last.
performance_summary = performance_summary.sort_values(by='C-index', ascending=False).reset_index(drop=True)
display(performance_summary)

# Identify the best performing model (first row after sorting)
best_model_name = performance_summary.loc[0, 'Model']
best_model_c_index = performance_summary.loc[0, 'C-index']

print(f"\nThe best performing model is the {best_model_name} with a C-index of {best_model_c_index:.4f}.")


In [None]:
print("\n--- Interpretation of Comprehensive Penalized CoxPH Model (Optimized) ---")

# Work on a copy of the optimized Comprehensive Penalized CoxPH summary
summary_cph_comprehensive_optimized = cph_comprehensive_optimized.summary.copy()

# Drop any overall-test rows so only individual predictors are interpreted
rows_to_remove = ['log-likelihood ratio test', 'Concordance', 'Partial AIC', 'log-likelihood ratio test (p=?)']
summary_cph_comprehensive_optimized = summary_cph_comprehensive_optimized[~summary_cph_comprehensive_optimized.index.isin(rows_to_remove)]

# Single source of truth for the significance threshold, also used in the messages.
alpha_level = 0.05

# Columns every summary row must provide to be interpreted
required_columns = ('exp(coef)', 'p', 'exp(coef) lower 95%', 'exp(coef) upper 95%')

print("Hazard ratios exp(coef) less than 1 indicate a decreased hazard (lower risk of dialysis).")
print("Hazard ratios exp(coef) greater than 1 indicate an increased hazard (higher risk of dialysis).")
print(f"Statistically significant predictors are those with p < {alpha_level}.")

# Only statistically significant predictors are printed; the rest are skipped silently.
for covariate, row in summary_cph_comprehensive_optimized.iterrows():
    if not all(column in row for column in required_columns):
        print(f"* Warning: Missing expected columns in summary row for {covariate}. Skipping interpretation.")
        continue

    hazard_ratio = row['exp(coef)']
    p_value = row['p']
    ci_lower = row['exp(coef) lower 95%']
    ci_upper = row['exp(coef) upper 95%']

    # Written as "not (p < alpha)" so a NaN p-value is skipped as well.
    if not (p_value < alpha_level):
        continue

    significance = f"Statistically significant (p < {alpha_level})"
    if hazard_ratio < 1:
        # A hazard ratio below 1 is a lower hazard, i.e. a protective effect;
        # the previous wording ("decreased protective effect") said the opposite.
        effect = "decreased"
        strength = f"risk of dialysis (protective effect) with a hazard ratio of {hazard_ratio:.2f}"
    else:
        effect = "increased"
        strength = f"risk of dialysis with a hazard ratio of {hazard_ratio:.2f}"
    print(f"* {covariate}: {significance}. Associated with {effect} {strength} (95% CI: {ci_lower:.2f} - {ci_upper:.2f}).")


# SHAP values for XGBoost model

In [None]:
!pip install shap

In [None]:
# Quick look at the feature matrix before it is handed to SHAP.
feature_preview = X_xgb.head()
print("Head of X_xgb DataFrame:")
display(feature_preview)

# DataFrame.info() prints its own report, so it is called as a bare statement.
print("\nInfo of X_xgb DataFrame:")
X_xgb.info()

In [None]:
import shap

# Tree explainer for the optimized XGBoost Cox booster, explaining its raw output
explainer = shap.TreeExplainer(
    xgboost_cox_optimized,
    model_output='raw',
)

# SHAP values for every row of the feature matrix
shap_values = explainer.shap_values(X_xgb)

print("SHAP values computed successfully.")
print("Shape of SHAP values:", shap_values.shape)

In [None]:
import matplotlib.pyplot as plt

# Generate a SHAP summary (beeswarm) plot.
# summary_plot's name for the beeswarm view is plot_type='dot'; 'beeswarm' is
# not one of its documented plot types. plot_size=None stops summary_plot from
# resizing the figure, so the figsize requested here is the one actually used.
plt.figure(figsize=(10, 8))
shap.summary_plot(shap_values, X_xgb, plot_type='dot', plot_size=None, show=False)
plt.title('SHAP Summary Plot for XGBoost Cox Model')
plt.tight_layout()
plt.show()

In [None]:
import pandas as pd

# Global importance of a feature = mean of |SHAP value| over all rows.
# shap_values has one row per instance and one column per feature of X_xgb.
mean_abs_shap = np.abs(shap_values).mean(axis=0)

# One row per feature, most important first
feature_importance = (
    pd.DataFrame({
        'Feature': X_xgb.columns,
        'Mean_Absolute_SHAP_Value': mean_abs_shap,
    })
    .sort_values(by='Mean_Absolute_SHAP_Value', ascending=False)
    .reset_index(drop=True)
)

print("\nSHAP Feature Importance Table (Top 20 Most Important Features):")
# Change the argument of .head() to show more or fewer features.
display(feature_importance.head(20))


SHAP Feature Importance Table (Top 20 Most Important Features):


Unnamed: 0,Feature,Mean_Absolute_SHAP_Value
0,grouped_diagnosis_Hypertensive CKD (Stages 5-E...,0.370502
1,num_ckd_grt_categories,0.151876
2,on_arbs,0.138018
3,grouped_diagnosis_Hypertensive CKD (Stages 5-E...,0.078273
4,grouped_diagnosis_Stage 5 CKD,0.066077
5,grouped_diagnosis_Stage 4 CKD,0.05794
6,on_ace_inhibitor,0.057384
7,on_statin,0.042422
8,grouped_diagnosis_Stage 3 CKD,0.034212
9,grouped_diagnosis_Stage 3 CKD_x_on_statin,0.025393
