# Visualisations and familiarizations

This file contains our data visualisations, the cleaning of the data and some
exploratory analysis to find relationships between columns.

In [None]:
# Put the imports here, this makes it easy to create a requirements.txt file
# later, which can be used by whoever is grading us to install everything!
import numpy as np
from scipy.stats import chi2
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import seaborn as sns
import pandas as pd
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score
import statsmodels.api as sm
from statsmodels.stats.outliers_influence import variance_inflation_factor
import pandas as pd

# decision tree
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score


In [None]:
# Load the dataset through the project's data_reader helper. The result is
# used like a dict of columns in the rest of the notebook
# (e.g. data['Diagnosis'], data.items()).
from data_reader import get_data_dict
data = get_data_dict('./data/alzheimers_disease_data.csv')

## Internal and external factors

The dataset contains all sorts of different metrics for a patient. These metrics can be divided into three categories:
- Metadata: `PatientID`, `Diagnosis` and `DoctorInCharge` (since this one is just the same value everywhere)
- Internal factors: `Age`, `Gender`, `Ethnicity`, `BMI`, `FamilyHistoryAlzheimers`, `CardiovascularDisease`, `Diabetes`, `Depression`, `Hypertension`, `SystolicBP`, `DiastolicBP`, `CholesterolTotal`, `CholesterolLDL`, `CholesterolHDL`, `CholesterolTriglycerides`, `MMSE`, `FunctionalAssessment`, `MemoryComplaints`, `BehavioralProblems`, `ADL`, `Confusion`, `Disorientation`, `PersonalityChanges`, `DifficultyCompletingTasks`, `Forgetfulness` and `HeadInjury`
- External factors: `EducationLevel`, `Smoking`, `AlcoholConsumption`, `PhysicalActivity`, `DietQuality` and `SleepQuality`

The metadata is not important for the tests and experiments on the data, as this says nothing about the condition of the patient. The only column here that is important is `Diagnosis`, as it tells us whether or not a patient actually has Alzheimer's disease. This is, for example, the target vector we can use for a regression model.

The internal and external factors can be examined, since these actually tell us something about the health of the patient. These could have an effect on whether or not someone has Alzheimer's disease and therefore they can be used to, for example, make predictions.

In [None]:
# Split the dataset into the three categories described above:
# NOTE(review): the three key lists below are defined here but are not passed
# to split_data() in this cell; presumably split_data() uses its own copy of
# these lists — confirm against data_reader, otherwise the lists are dead code.
metadata_keys = ["PatientID", "Diagnosis", "DoctorInCharge"]
internal_factors_keys = [
    "Age", "Gender", "Ethnicity", "BMI", "FamilyHistoryAlzheimers",
    "CardiovascularDisease", "Diabetes", "Depression", "Hypertension",
    "SystolicBP", "DiastolicBP", "CholesterolTotal", "CholesterolLDL",
    "CholesterolHDL", "CholesterolTriglycerides", "MMSE", "FunctionalAssessment",
    "MemoryComplaints", "BehavioralProblems", "ADL", "Confusion",
    "Disorientation", "PersonalityChanges", "DifficultyCompletingTasks",
    "Forgetfulness", "HeadInjury"
]
external_factors_keys = ["EducationLevel", "Smoking", "AlcoholConsumption",
                         "PhysicalActivity", "DietQuality", "SleepQuality"]

from data_reader import split_data
metadata, internal_factors, external_factors = split_data()

# Show which columns ended up in which category.
print("Metadata:", list(metadata.keys()))
print("Internal Factors:", list(internal_factors.keys()))
print("External Factors:", list(external_factors.keys()))

In [None]:
# Plot one histogram per internal factor in a grid with three columns
num_plots = len(internal_factors)
columns = 3
rows = (num_plots + columns - 1) // columns  # ceiling division

fig, axes = plt.subplots(rows, columns, figsize=(10, rows * 4))
fig.suptitle('The internal factors in histograms', fontsize=16)
axes = axes.flatten()

# Proxy artists: they are never drawn themselves, they only provide the
# entries of the single figure-level legend created below.
mean_line = Line2D([0], [0], color='red', linestyle='dashed', label='Mean')
median_line = Line2D([0], [0], color='green', linestyle='dashed', label='Median')
std_line = Line2D([0], [0], color='orange', linestyle='dashed', label='Mean ±1 std.')

for i, (key, values) in enumerate(internal_factors.items()):
    axes[i].hist(values, bins=20, color='skyblue', edgecolor='blue')
    mean = np.mean(values)
    std = np.std(values)
    median = np.median(values)

    # Mark the mean, the median and one standard deviation around the mean.
    axes[i].axvline(mean, color='red', linestyle='dashed')
    axes[i].axvline(median, color='green', linestyle='dashed')
    axes[i].axvline(mean + std, color='orange', linestyle='dashed')
    axes[i].axvline(mean - std, color='orange', linestyle='dashed')

    axes[i].set_title(key)
    axes[i].set_xlabel('Value')
    axes[i].set_ylabel('Frequency')

# Hide the grid cells that are left over after the last histogram.
for j in range(i + 1, len(axes)):
    axes[j].axis('off')

fig.legend(handles=[mean_line, median_line, std_line], loc='upper center',
           ncol=3, bbox_to_anchor=(0.5, 0.975))

# Keep the top strip of the figure free for the title and the legend.
plt.tight_layout(rect=[0, 0, 1, 0.975])
plt.show()

In [None]:
# Plot one histogram per external factor in a grid with three columns
num_plots = len(external_factors)
columns = 3
rows = (num_plots + columns - 1) // columns  # ceiling division

fig, axes = plt.subplots(rows, columns, figsize=(10, rows * 4))
# This figure shows the external factors, so the title must say so.
fig.suptitle('The external factors in histograms', fontsize=16)
axes = axes.flatten()

# Proxy artists: they are never drawn themselves, they only provide the
# entries of the single figure-level legend created below.
mean_line = Line2D([0], [0], color='red', linestyle='dashed', label='Mean')
median_line = Line2D([0], [0], color='green', linestyle='dashed', label='Median')
std_line = Line2D([0], [0], color='orange', linestyle='dashed', label='Mean ±1 std.')

for i, (key, values) in enumerate(external_factors.items()):
    axes[i].hist(values, bins=20, color='skyblue', edgecolor='blue')
    mean = np.mean(values)
    std = np.std(values)
    median = np.median(values)

    # Mark the mean, the median and one standard deviation around the mean.
    axes[i].axvline(mean, color='red', linestyle='dashed')
    axes[i].axvline(median, color='green', linestyle='dashed')
    axes[i].axvline(mean + std, color='orange', linestyle='dashed')
    axes[i].axvline(mean - std, color='orange', linestyle='dashed')

    axes[i].set_title(key)
    axes[i].set_xlabel('Value')
    axes[i].set_ylabel('Frequency')

# Hide the grid cells that are left over after the last histogram.
for j in range(i + 1, len(axes)):
    axes[j].axis('off')

fig.legend(handles=[mean_line, median_line, std_line], loc='upper center',
           ncol=3, bbox_to_anchor=(0.5, 0.955))

# Keep the top strip of the figure free for the title and the legend.
plt.tight_layout(rect=[0, 0, 1, 0.955])
plt.show()

### Findings
There are a few columns that seem to follow just a few categories:
- Internal factors: `Gender`, `Ethnicity`, `FamilyHistoryAlzheimers`,
`CardiovascularDisease`, `Diabetes`, `Depression`, `Hypertension`,
`MemoryComplaints`, `BehavioralProblems`, `Confusion`, `Disorientation`,
`PersonalityChanges`, `DifficultyCompletingTasks`, `Forgetfulness` and
`HeadInjury`
- External factors: `EducationLevel` and `Smoking`

Out of these, `Ethnicity` and `EducationLevel` have more than two categories,
while the other ones only have two. Something that stands out in the columns
with only two categories, is that usually the 'No' bar (corresponding to
value 0) is much larger than the 'Yes' bar (corresponding to value 1). The only
categorical data where this is not the case is the `Gender` column, which shows
that the two genders are represented roughly equally. A next step for cleaning
the columns mentioned above is to attach a more meaningful label to the numbers.

For all columns that are not categorical, we see that the median and
mean are roughly at the same value. Besides that, none of these columns seems to
follow a recognisable distribution. However, it also does not seem like each
symptom is uniformly distributed among the population.

In [None]:
# Attach the correct labels to categorical data columns
yesno_cols = ['FamilyHistoryAlzheimers', 'CardiovascularDisease', 'Diabetes',
              'Depression', 'Hypertension', 'MemoryComplaints', 'BehavioralProblems',
              'Confusion', 'Disorientation', 'PersonalityChanges', 'DifficultyCompletingTasks',
              'Forgetfulness', 'HeadInjury', 'Smoking']
labeled_data = {}
# Binary columns: 0 becomes 'no', every other value becomes 'yes'.
for key in yesno_cols:
    labeled_data[key] = np.where(
        np.array(data[key]) == 0, 'no', 'yes').tolist()
labeled_data['Gender'] = np.where(
    np.array(data['Gender']) == 0, 'male', 'female').tolist()
# Columns with more than two categories are translated through a lookup table.
mapping = {0: "Caucasian", 1: "African American", 2: "Asian", 3: "Other"}
labeled_data['Ethnicity'] = np.vectorize(
    mapping.get)(data['Ethnicity']).tolist()
mapping = {0: "none", 1: "high school", 2: "bachelor", 3: "higher"}
labeled_data['EducationLevel'] = np.vectorize(
    mapping.get)(data['EducationLevel']).tolist()

# Preview the first five labels of every column. The f-string uses double
# quotes on the outside because reusing the same quote character inside the
# replacement field is a SyntaxError before Python 3.12.
for key in labeled_data.keys():
    print(f"{key}: {labeled_data[key][:5] + ['...']}")

### Finding outliers
Using boxplots it is possible to find out if any of the columns have outliers, which would need to be removed.

In [None]:
def detect_outliers(data):
    """
    Detect outliers in numerical columns using the Interquartile Range (IQR) method.

    A column is treated as numerical when it is non-empty and every value is
    an ``int`` or a ``float``. A value is an outlier when it lies below
    ``Q1 - 1.5 * IQR`` or above ``Q3 + 1.5 * IQR``.

    Parameters
    ----------
    data : dict
        Mapping of column name to a sequence of values.

    Returns
    -------
    dict
        One entry per numerical column, each a dict with the keys
        ``'outliers'`` (list of ``(index, value)`` tuples), ``'lower_bound'``,
        ``'upper_bound'``, ``'total_outliers'`` and ``'outlier_percentage'``.
        Empty and non-numerical columns are left out.
    """
    outliers = {}

    for column in data.keys():
        values = data[column]

        # An empty column has no quartiles and would cause a division by zero
        # in the percentage below, so there is nothing to report for it.
        if len(values) == 0:
            continue

        # Only purely numerical columns can be analysed with the IQR method.
        if not all(isinstance(x, (int, float)) for x in values):
            continue

        # Calculate Q1, Q3, and IQR
        Q1 = np.percentile(values, 25)
        Q3 = np.percentile(values, 75)
        IQR = Q3 - Q1

        # Define outlier boundaries
        lower_bound = Q1 - 1.5 * IQR
        upper_bound = Q3 + 1.5 * IQR

        # Find outliers, remembering their position in the column
        column_outliers = [
            (i, val) for i, val in enumerate(values)
            if val < lower_bound or val > upper_bound
        ]

        outliers[column] = {
            'outliers': column_outliers,
            'lower_bound': lower_bound,
            'upper_bound': upper_bound,
            'total_outliers': len(column_outliers),
            'outlier_percentage': (len(column_outliers) / len(values)) * 100
        }

    return outliers

def visualize_outliers(data, outliers):
    """
    Draw one box plot per numerical column so outliers can be spotted visually.

    A column counts as numerical when all of its values are ``int`` or
    ``float``. The ``outliers`` argument is accepted but not used by the plot.
    """

    def _is_numeric(column_values):
        # Same criterion as the outlier detection: only ints and floats.
        return all(isinstance(item, (int, float)) for item in column_values)

    numeric_names = [name for name in data.keys() if _is_numeric(data[name])]
    series = [data[name] for name in numeric_names]

    plt.figure(figsize=(15, 6))
    plt.title('Outliers in Numerical Columns', fontsize=16)
    sns.boxplot(data=series)
    # Label every box with the name of the column it belongs to.
    plt.xticks(range(len(numeric_names)), numeric_names, rotation=45, ha='right')
    plt.tight_layout()
    plt.show()

def print_outlier_summary(outliers):
    """
    Print the outlier statistics of every column in ``outliers``.

    For each column the total count, the percentage, both bounds and at most
    the first five ``(index, value)`` outliers are printed.
    """
    preview_size = 5

    print("\nOutlier Detection Summary:")
    for column, info in outliers.items():
        print(f"\n{column}:")
        print(f"  Total Outliers: {info['total_outliers']}")
        print(f"  Percentage of Outliers: {info['outlier_percentage']:.2f}%")
        print(f"  Lower Bound: {info['lower_bound']}")
        print(f"  Upper Bound: {info['upper_bound']}")

        found = info['outliers']
        if not found:
            continue

        # Show only a small sample and summarise the remainder.
        print("  Sample Outliers (index, value):")
        for position, value in found[:preview_size]:
            print(f"    {position}: {value}")
        hidden = len(found) - preview_size
        if hidden > 0:
            print(f"    ... and {hidden} more")

In [None]:
# PatientID is left out of the outlier analysis; build the filtered view once
# and use it for both the detection and the box plots.
data_without_id = {key: value for key, value in data.items() if key != 'PatientID'}
outliers = detect_outliers(data_without_id)
print_outlier_summary(outliers)
visualize_outliers(data_without_id, outliers)

### Comparisons

Some variables might show a significant difference between people with and without Alzheimer's disease. Based on whether or not a patient got a diagnosis, the data is split into two groups. Those two groups are compared in the plots below.

In [None]:
# Compare every categorical column between patients with and without an
# Alzheimer's diagnosis. The numeric codes from `data` are plotted; the tick
# labels below translate them back to their meaning.
labels = list(labeled_data.keys())
num_labels = len(labels)

max_cols = 3
num_rows = -(-num_labels // max_cols)  # ceiling division

fig, axes = plt.subplots(num_rows, max_cols, figsize=(15, num_rows * 5))

axes = axes.flatten()

# The diagnosis mask is identical for every column, so build it once instead
# of rebuilding an index list and scanning it for every single row.
is_alz = np.array(data['Diagnosis']) == 1

for idx, label in enumerate(labels):
    values = np.asarray(data[label])
    alz_data = values[is_alz]
    normal_data = values[~is_alz]

    ax = axes[idx]
    # density=True normalises both groups so they can be compared despite
    # their different sizes; the bars are shifted left/right to sit side by side.
    ax.hist(alz_data, color='orange', align='right', rwidth=0.5, density=True, label='Alzheimer\'s')
    ax.hist(normal_data, color='blue', align='left', rwidth=0.5, density=True, label='Normal')
    ax.axvline(x=np.mean(normal_data), label='Mean for normal', color='skyblue', lw=3, ls='-.')
    ax.axvline(x=np.mean(alz_data), label='Mean for Alzheimer\'s', color='red', lw=3, ls='--')
    ax.set_title(label)
    ax.legend()
    if label == 'Gender':
        ax.set_xticks([0, 1], ['Male', 'Female'])
    elif label == 'EducationLevel':
        ax.set_xticks([0, 1, 2, 3], ['None', 'High school', 'Bachelor\'s', 'Higher'])
    elif label == 'Ethnicity':
        ax.set_xticks([0, 1, 2, 3], ['Caucasian', 'African American', 'Asian', 'Other'])
    else:
        ax.set_xticks([0, 1], ['No', 'Yes'])
    ax.set_xlabel('Value')
    ax.set_ylabel('Frequency')

# Hide the grid cells that are left over after the last plot.
for idx in range(len(labels), len(axes)):
    axes[idx].set_visible(False)

plt.tight_layout(rect=[0, 0, 1, 0.955])
plt.suptitle('All categorical data columns compared (0 indicates no, 1 indicates yes)', fontsize=16)
plt.show()

In [None]:
# Compare every numerical column between patients with and without an
# Alzheimer's diagnosis. Numerical means: not one of the labelled categorical
# columns and not one of the metadata columns.
excluded_labels = list(labeled_data.keys()) + ['Diagnosis', 'PatientID', 'DoctorInCharge']
labels = [label for label in data.keys() if label not in excluded_labels]
num_labels = len(labels)

max_cols = 3
num_rows = -(-num_labels // max_cols)  # ceiling division

fig, axes = plt.subplots(num_rows, max_cols, figsize=(15, num_rows * 5))

axes = axes.flatten()

# The diagnosis mask is identical for every column, so build it once instead
# of rebuilding an index list and scanning it for every single row.
is_alz = np.array(data['Diagnosis']) == 1

for idx, label in enumerate(labels):
    values = np.asarray(data[label])
    alz_data = values[is_alz]
    normal_data = values[~is_alz]

    # Stack both groups in one histogram so their share per bin is visible.
    data_comb = [normal_data, alz_data]
    axes[idx].hist(data_comb, bins=20, color=['blue', 'orange'], rwidth=0.8, label=["Normal", "Alzheimer's"], stacked=True)
    axes[idx].axvline(x=np.mean(normal_data), label='Mean for normal', color='skyblue', lw=3, ls='-.')
    axes[idx].axvline(x=np.mean(alz_data), label='Mean for Alzheimer\'s', color='red', lw=3, ls='--')
    axes[idx].set_title(label)
    axes[idx].legend()
    axes[idx].set_xlabel('Value')
    axes[idx].set_ylabel('Frequency')

# Hide the grid cells that are left over after the last plot.
for idx in range(len(labels), len(axes)):
    axes[idx].set_visible(False)

plt.tight_layout(rect=[0, 0, 1, 0.955])
plt.suptitle('All numerical data columns compared', fontsize=16)
plt.show()

From these comparisons, we can conclude that there are a few parameters that show different behavior based on the diagnosis.
- Categorical:
    - `MemoryComplaints` are more common among people with Alzheimer's
    - `BehavioralProblems` are more common among people with Alzheimer's
    - `EducationLevel` seems to be a little bit lower on average for people with Alzheimer's
- Numerical:
    - `MMSE` tends to be lower on average for people with Alzheimer's
    - `FunctionalAssessment` tends to be lower for people with Alzheimer's
    - `ADL` tends to be lower for people with Alzheimer's

### Correlations

We checked for correlations between numerical variables, using a heatmap.  

In [None]:
# correlations between numerical columns

# DoctorInCharge is left out of the DataFrame (the introduction describes it
# as the same value for every patient).
df = pd.DataFrame({key: value for key, value in data.items() if key != 'DoctorInCharge'})

# Columns treated as numerical in the correlation and quadratic-fit analyses.
numerical_variables = ['Age', 'EducationLevel', 'BMI','AlcoholConsumption', 'PhysicalActivity', 'DietQuality',
       'SleepQuality', 'SystolicBP',
       'DiastolicBP', 'CholesterolTotal', 'CholesterolLDL', 'CholesterolHDL',
       'CholesterolTriglycerides', 'MMSE', 'FunctionalAssessment', 'ADL']

# Pairwise correlations (DataFrame.corr default method), drawn on a fixed
# -1..1 colour scale so the colours are comparable between plots.
correlation_matrix = df[numerical_variables].corr()
sns.heatmap(correlation_matrix, cmap='coolwarm', vmin=-1, vmax=1)
plt.show()

In [None]:
# alternate plot of the heatmap: every column of df, with the upper triangle
# (including the diagonal) masked out because the matrix is symmetric
full_corr = df.corr()
mask = np.triu(np.ones_like(full_corr, dtype=bool))
plt.figure(figsize=(12, 10))
sns.heatmap(full_corr, cmap="coolwarm", cbar_kws={"shrink": .5}, mask=mask)
plt.show()

We checked whether any pair of variables shows a quadratic relationship. For example, both very low and very high blood pressure have negative health outcomes - perhaps there are similar curvilinear relationships with the cholesterol variables.

In [None]:
# One R^2 value per ordered pair (predictor row, target column).
r2_matrix = pd.DataFrame(index=numerical_variables, columns=numerical_variables)

# Polynomial feature transformation - quadratic
poly = PolynomialFeatures(degree=2)

# fit quadratic models and calculate R^2
for col1 in numerical_variables:
    for col2 in numerical_variables:
        # The diagonal (a variable against itself) is skipped and stays empty.
        if col1 != col2:
            X = df[[col1]]
            y = df[col2]
            X_poly = poly.fit_transform(X)
            model = LinearRegression().fit(X_poly, y)
            # R^2 is computed on the same data the model was fitted on.
            y_pred = model.predict(X_poly)
            r2 = r2_score(y, y_pred)
            r2_matrix.loc[col1, col2] = r2

# The frame was created without a dtype, so convert the cells to numbers
# before plotting.
r2_matrix = r2_matrix.apply(pd.to_numeric)
sns.heatmap(r2_matrix, cmap='coolwarm', vmin=-1, vmax=1)
plt.title('G.o.F for Quadratic Fits')
plt.show()

No quadratic relationships between any two variables were found. 

The plots below show the mean of several variables for patients with and without a diagnosis, grouped by ethnicity and by gender.

In [None]:
df = pd.DataFrame(data)

# LIST OF SIGNIFICANT-LOOKING IVs THAT INTERACT WITH ETHNICITY 
vars_ethnicity = ['Gender', 'EducationLevel', 'Smoking', 'AlcoholConsumption', 'PhysicalActivity',
                  'DietQuality', 'FamilyHistoryAlzheimers', 'CardiovascularDisease', 'Diabetes',
                  'Depression', 'HeadInjury', 'Hypertension', 'CholesterolHDL',
                  'BehavioralProblems', 'Confusion', 'Disorientation']

# LIST OF SIGNIFICANT-LOOKING IVs THAT INTERACT WITH GENDER 
vars_gender = ['AlcoholConsumption', 'DietQuality', 'FamilyHistoryAlzheimers', 'CardiovascularDisease',
               'Depression', 'HeadInjury', 'Hypertension', 'CholesterolTriglycerides','ADL',
               'Confusion', 'Disorientation', 'PersonalityChanges']

# plotting (EDA)
# NOTE: from here on df['Ethnicity'] holds text labels instead of numeric codes.
df['Ethnicity'] = df['Ethnicity'].map({0: 'white', 1: 'black', 2: 'asian', 3: 'other'})
gender_palette = {0: 'blue', 1: 'pink'}
ethnicity_palette = {'white': 'pink', 'black': 'black','asian': 'red','other': 'brown'}

# The legend labels handed to adjust_plot are hard-coded in a fixed order, so
# the bars must be drawn in exactly that order. Without an explicit hue_order
# seaborn chooses the order itself and labels can end up on the wrong colour.
ethnicity_order = ['white', 'black', 'asian', 'other']
gender_order = [0, 1]


def adjust_plot(ax, var, category, labels, palette):
    """
    Apply the shared title, axis and legend styling to one bar plot.

    ax       -- the axes to style
    var      -- name of the plotted variable (used in title and y label)
    category -- name of the grouping variable (used in title and legend title)
    labels   -- legend labels, in the same order as the plotted hue levels
    palette  -- accepted for symmetry with the barplot call; not used here
    """
    ax.set_title(f'{var} vs Diagnosis by {category}', fontsize=7)
    ax.set_ylabel(f'mean {var}')
    ax.set_xticks([0, 1])
    ax.set_xticklabels(['No Diagnosis', 'Diagnosis'], size=7)
    # Reuse the handles seaborn created, but replace their labels.
    handles, _ = ax.get_legend_handles_labels()
    ax.legend(handles, labels, title=category, bbox_to_anchor=(1.05, 1), loc='upper left', facecolor='white')
    ax.set_facecolor('white')


# ethnicity plots
n_cols = 4
n_rows = len(vars_ethnicity) // n_cols + (1 if len(vars_ethnicity) % n_cols != 0 else 0)

fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 3 * n_rows))
axes = axes.flatten()

for i, var in enumerate(vars_ethnicity):
    sns.barplot(x='Diagnosis', y=var, hue='Ethnicity', hue_order=ethnicity_order,
                data=df, ax=axes[i], palette=ethnicity_palette)
    adjust_plot(axes[i], var, 'Ethnicity', ['White', 'Black', 'Asian', 'Other'], ethnicity_palette)

plt.tight_layout()
plt.show()

# gender plots
n_rows = len(vars_gender) // n_cols + (1 if len(vars_gender) % n_cols != 0 else 0)

fig, axes = plt.subplots(n_rows, n_cols, figsize=(13, 2.5 * n_rows))
axes = axes.flatten()

for i, var in enumerate(vars_gender):
    sns.barplot(x='Diagnosis', y=var, hue='Gender', hue_order=gender_order,
                data=df, ax=axes[i], palette=gender_palette)
    adjust_plot(axes[i], var, 'Gender', ['Male', 'Female'], gender_palette)

plt.tight_layout()
plt.show()

In [None]:
# ASSUMPTION CHECK FOR LOGISTIC REGRESSION 

'''
ASSUMPTIONS OF THE LOGISTIC REGRESSION MODEL

1. Binary dependent variable 
2. No multicollinearity
3. Linearity of the logit
'''

# 1. Binary dependent variable
def is_binary(df, target_column):
    """
    Return True when ``df[target_column]`` has exactly two distinct values.

    Always returns a real boolean, so False (rather than None) is returned for
    columns that are not binary.
    """
    unique_values = df[target_column].unique()
    return len(unique_values) == 2

# Report the outcome of the binary-DV check as a single status line.
binary_status = "PASSED" if is_binary(df, 'Diagnosis') else "FAILED"
print(f"binary DV assumption: {binary_status}")


# 2. no multicollinearity (correlations between predictors) 
def check_multicollinearity(df, predictor_columns):
    """
    Plot the correlation heatmap of ``df`` and compute the VIF of each predictor.

    VIF (Variance Inflation Factor) indicates how much a predictor is
    influenced by the other predictors by measuring how much the variability
    of a regression coefficient is increased due to correlation with other
    predictors. High VIF (>10) suggests that a predictor is highly correlated
    with others, which can make the model unstable or unreliable.

    * the import from statsmodels could be replaced by calculating VIF as
      1 / (1 - r_squared), where r_squared comes from regressing the predictor
      on all other predictors.

    df                -- DataFrame holding the predictors
    predictor_columns -- list of column names to compute the VIF for; these
                         columns must be numeric

    Returns ``(corr_matrix, vif_data)``: the correlation matrix of all
    numeric columns of ``df`` and a DataFrame with the columns 'Feature' and
    'VIF'.
    """
    # Correlations are only defined for numeric columns; df can also hold text
    # columns, which recent pandas versions refuse to correlate.
    corr_matrix = df.select_dtypes(include=['number', 'bool']).corr()
    sns.heatmap(corr_matrix, annot=True, cmap='coolwarm')
    plt.show()

    # variance_inflation_factor does not add an intercept itself. Without one
    # the auxiliary regressions are forced through the origin and the VIFs no
    # longer match 1 / (1 - r_squared), so add a constant column and skip it
    # (index 0) when collecting the results.
    exog = sm.add_constant(df[predictor_columns], has_constant='add').values
    vif_data = pd.DataFrame()
    vif_data['Feature'] = predictor_columns
    vif_data['VIF'] = [variance_inflation_factor(exog, i) for i in range(1, exog.shape[1])]
    return corr_matrix, vif_data 

# can edit this list to include whatever variables we keep in the final analysis
# NOTE(review): the plotting cell above maps df['Ethnicity'] to text labels,
# while the VIF computation needs numeric predictors — confirm how Ethnicity
# should be encoded here.
predictor_columns = ['DietQuality', 'CholesterolHDL', 'Ethnicity']
corr_matrix, vif_data = check_multicollinearity(df, predictor_columns)

# No missing values, and |correlation| < 0.9 between *different* variables.
# The diagonal holds every variable's correlation with itself, which is always
# 1, so it has to be excluded or the check could never pass.
off_diagonal = ~np.eye(len(corr_matrix), dtype=bool)
pairwise_corr = corr_matrix.abs().where(off_diagonal, 0)
if corr_matrix.isnull().sum().sum() == 0 and (pairwise_corr < 0.9).all().all():
    print("Multicollinearity: PASSED")
else:
    print("Multicollinearity: FAILED")


def check_linearity(df, predictor_columns, DV):
    """
    Fit one single-predictor logistic regression per column and plot its fitted values.

    The linearity of the logit assumption: each predictor should have a linear
    relationship with the log odds of the DV. Non-linearity can lead to biased
    results.

    df                -- DataFrame holding the predictors and the DV
    predictor_columns -- columns to fit, one model each
    DV                -- name of the dependent-variable column

    Returns a dict mapping each predictor to the summary of its fitted model.

    NOTE(review): every model contains only a constant and one predictor, so
    if ``fittedvalues`` is the linear predictor (as statsmodels documents for
    its discrete models) the plotted points lie on a straight line by
    construction; the plot alone cannot reveal a violation of the assumption.
    """
    linearity_results = {}

    for col in predictor_columns:

        X = sm.add_constant(df[col])
        model = sm.Logit(df[DV], X).fit()
        linearity_results[col] = model.summary()

        # The plot shows the model's fitted values against the predictor (not
        # partial residuals), so the title names what is actually drawn.
        plt.scatter(df[col], model.fittedvalues)
        plt.title(f"Fitted values for {col}")
        plt.xlabel(col)
        plt.ylabel('Fitted values')
        plt.show()

    return linearity_results

linearity_results = check_linearity(df, predictor_columns, 'Diagnosis')
# NOTE(review): PASSED is printed for every predictor regardless of the model
# summary; no criterion is evaluated here, so this is not an actual test of the
# assumption.
for col, result in linearity_results.items():
    print(f" For {col}: Linearity assumption - PASSED")


In [None]:
# SURVIVAL ANALYSIS 

# Survival analysis (via Kaplan-Meier) calculates the probability of surviving
# beyond specific time points (here, each age). Here, survival refers to the
# absence of Alzheimer's. We see how several risk factors affect the onset age
# of Alzheimer's. For each time point, the function considers the number of
# individuals who experience the event of interest (Diagnosis=1) and the number
# still at risk (everyone observed at that age or later). Individuals without a
# diagnosis are censored at their age: they count as at risk up to that age and
# drop out of the risk set afterwards.
def kaplan_meier_estimator(df, time_col, event_col, group_col):
    '''
    Calculate the Kaplan-Meier survival curve for each category of a categorical variable.

    df        -- DataFrame with one row per individual
    time_col  -- column holding the observation time (here: age)
    event_col -- column that is 1 when the event (diagnosis) occurred
    group_col -- categorical column; one curve is estimated per unique value

    Returns a dict mapping each group to a list of ``(time, survival)`` tuples,
    sorted by time, where survival is the product-limit estimate
    ``prod(1 - d_i / n_i)`` over all time points up to and including ``time``.
    '''

    # Kaplan-Meier estimates per group
    km_data = {}

    # loop over each unique value (group) in the categorical variable
    for group in df[group_col].unique():

        # filter the data for the current group
        group_data = df[df[group_col] == group]
        times = group_data[time_col]
        is_event = group_data[event_col] == 1

        # running product of the conditional survival probabilities
        survival = 1.0
        km_estimates = []

        # Iterate over each unique time point (i.e. each age)
        for t in sorted(times.unique()):

            # individuals still under observation at t; at least one, because
            # t itself is an observed time in this group
            at_risk = int((times >= t).sum())

            # number of events (diagnoses) at the current time point
            d = int(((times == t) & is_event).sum())

            # Only those at risk can fail; dividing by the risk set instead of
            # the group size is what accounts for censored individuals.
            survival *= 1 - d / at_risk

            km_estimates.append((t, survival))

        km_data[group] = km_estimates

    return km_data


# The log-rank test determines whether survival curves for different groups
# (e.g. the different categories of a categorical variable) are statistically
# different. It compares observed and expected event counts at each event time
# across groups. Expected values are based on the overall event distribution,
# assuming no group differences. The test aggregates these differences into a
# chi-square statistic, quantifying how much the observed data deviate from the
# null hypothesis (no difference between groups).

from scipy.stats import chi2
def log_rank_test(df, time_col, event_col, group_col):
    '''
    Perform the log-rank test for comparing the survival curves of several groups.

    df        -- DataFrame with one row per individual
    time_col  -- column holding the observation time (here: age)
    event_col -- column that is 1 when the event (diagnosis) occurred
    group_col -- categorical column defining the groups to compare

    Returns ``(chi_square, p)``: the test statistic and its p-value from a
    chi-square distribution with ``number of groups - 1`` degrees of freedom.
    With fewer than two groups there is nothing to compare and ``(nan, nan)``
    is returned.
    '''
    groups = df[group_col].unique()
    k = len(groups)
    if k < 2:
        return float('nan'), float('nan')

    times = df[time_col].to_numpy()
    is_event = (df[event_col] == 1).to_numpy()
    # membership[g, row] is True when the row belongs to group g
    membership = np.array([(df[group_col] == group).to_numpy() for group in groups])

    # Accumulated (observed - expected) events per group and the covariance
    # matrix of these differences under the null hypothesis.
    obs_minus_exp = np.zeros(k)
    covariance = np.zeros((k, k))

    # Time points without events contribute nothing, so only event times are visited.
    for t in np.unique(times[is_event]):
        at_risk = (membership & (times >= t)).sum(axis=1).astype(float)
        events = (membership & (is_event & (times == t))).sum(axis=1).astype(float)
        total_at_risk = at_risk.sum()
        total_events = events.sum()
        if total_at_risk == 0 or total_events == 0:
            continue

        obs_minus_exp += events - at_risk * total_events / total_at_risk

        # With a single individual at risk the outcome is fully determined, so
        # the variance contribution is zero (the formula would divide by zero).
        if total_at_risk > 1:
            share = at_risk / total_at_risk
            scale = total_events * (total_at_risk - total_events) / (total_at_risk - 1)
            covariance += (np.diag(share) - np.outer(share, share)) * scale

    # The differences of all groups sum to zero, so one group is redundant:
    # drop the last one and use the quadratic form z' V^-1 z. Simply adding
    # (O - E)^2 / V over all groups would count the same deviation repeatedly
    # (twice in the two-group case). pinv keeps degenerate cases (e.g. no
    # events at all) finite.
    z = obs_minus_exp[:-1]
    v = covariance[:-1, :-1]
    chi_square = float(z @ np.linalg.pinv(v) @ z)
    p = float(chi2.sf(chi_square, df=k - 1))

    return chi_square, p

def plot_kaplan_meier(df, time_col, event_col, group_col, ax):
    """Draw one Kaplan-Meier step curve per group on *ax*, with the log-rank p-value in the title.

    Args:
        df: Data handed unchanged to ``kaplan_meier_estimator`` and ``log_rank_test``.
        time_col: Name of the time column; also used as the x-axis label.
        event_col: Name of the event-indicator column.
        group_col: Name of the grouping column; one curve is drawn per group.
        ax: Matplotlib axes to draw on.

    Returns:
        None. ``ax`` is modified in place.
    """
    # Per-group sequences of (time, survival estimate) pairs
    km_data = kaplan_meier_estimator(df, time_col, event_col, group_col)
    # Only the p-value is displayed; the chi-square statistic is not needed here
    _, p = log_rank_test(df, time_col, event_col, group_col)

    for group in km_data:
        # Split the (time, estimate) pairs into two parallel sequences
        times, estimates = zip(*km_data[group])

        # 'post' keeps each estimate constant until the next time point
        ax.step(times, estimates, where='post', label=f'{group_col} {group}')

    # Label the x-axis with the actual time column instead of a hard-coded 'Age'
    ax.set_xlabel(time_col)
    ax.set_ylabel('Survival Probability')
    ax.set_title(f'Kaplan-Meier survival curves: {group_col} \n p={p:.3f}')
    ax.legend()


# One row of four Kaplan-Meier panels, one grouping column per panel.
fig, axes = plt.subplots(1, 4, figsize=(20, 4))
km_columns = ['MemoryComplaints', 'BehavioralProblems', 'Disorientation', 'CardiovascularDisease']
for km_ax, km_column in zip(axes, km_columns):
    plot_kaplan_meier(df, 'Age', 'Diagnosis', km_column, km_ax)
plt.tight_layout()
plt.show()

# One row of four Kaplan-Meier panels. Panels paired with True additionally get
# a red note marking the result as unexpectedly non-significant.
fig, axes = plt.subplots(1, 4, figsize=(20, 4))
km_panels = [
    ('FamilyHistoryAlzheimers', True),
    ('Hypertension', False),
    ('DifficultyCompletingTasks', True),
    ('Forgetfulness', True),
]
for km_ax, (km_column, km_flagged) in zip(axes, km_panels):
    plot_kaplan_meier(df, 'Age', 'Diagnosis', km_column, km_ax)
    if km_flagged:
        # Position is given in axes coordinates (0-1), hence transAxes
        km_ax.text(0.4, 0.2, 'Surprisingly non-significant', transform=km_ax.transAxes,
                   ha='center', va='center', fontsize=10, color='red', fontweight='bold')
plt.tight_layout()
plt.show()

print()
print('continuous variables, manually categorised:')
fig, axes = plt.subplots(1, 4, figsize=(20, 4))

# BMI: left-closed bands [0, 18.5), [18.5, 25), [25, 30), [30, inf)
bmi_edges = [0, 18.5, 25, 30, np.inf]
bmi_names = ['Underweight', 'Normal', 'Overweight', 'Obesity']
df['BMI_Category'] = pd.cut(df['BMI'], bins=bmi_edges, labels=bmi_names, right=False)
plot_kaplan_meier(df, 'Age', 'Diagnosis', 'BMI_Category', axes[0])

# Alcohol: quantile split into the bottom 25%, middle 50% and top 25%.
# Not robust - the result is non-significant with 2 or 4 categories.
alcohol_quantiles = [0, 0.25, 0.75, 1.0]
alcohol_names = ['Low', 'Medium', 'High']
df['AlcoholCategory'] = pd.qcut(df['AlcoholConsumption'], q=alcohol_quantiles, labels=alcohol_names)
plot_kaplan_meier(df, 'Age', 'Diagnosis', 'AlcoholCategory', axes[1])
axes[1].text(0.42, 0.1, 'not robust, non-sig with 2 or 4 categories', transform=axes[1].transAxes,
             ha='center', va='center', fontsize=9, color='red')

# ADL (activities of daily living, e.g. self-grooming, feeding oneself):
# right-closed bands, with include_lowest so that a value of exactly 0 is kept.
adl_edges = [0, 3, 6, 9, 10]
adl_names = ['Low', 'Moderate', 'High', 'Very High']
df['ADL_Category'] = pd.cut(df['ADL'], bins=adl_edges, labels=adl_names, include_lowest=True)
plot_kaplan_meier(df, 'Age', 'Diagnosis', 'ADL_Category', axes[2])

# Sleep quality: left-closed bands [0, 5), [5, 7), [7, 9), [9, inf)
sleep_edges = [0, 5, 7, 9, np.inf]
sleep_names = ['Poor', 'Fair', 'Good', 'Excellent']
df['sleep_quality'] = pd.cut(df['SleepQuality'], bins=sleep_edges, labels=sleep_names, right=False)
plot_kaplan_meier(df, 'Age', 'Diagnosis', 'sleep_quality', axes[3])

fig, axes = plt.subplots(1, 3, figsize=(18, 3))

# MMSE: left-closed bands [0, 10), [10, 20), [20, 30)
# NOTE(review): with right=False a value of exactly 30 falls outside the last
# band and becomes NaN - confirm the data never reaches 30.
mmse_edges = [0, 10, 20, 30]
mmse_names = ['Severe Cognitive Impairment', 'Mild Cognitive Impairment', 'Normal']
df['MMSE_cat'] = pd.cut(df['MMSE'], bins=mmse_edges, labels=mmse_names, right=False)
plot_kaplan_meier(df, 'Age', 'Diagnosis', 'MMSE_cat', axes[0])

# HDL cholesterol: left-closed bands [0, 40), [40, 60), [60, inf)
hdl_edges = [0, 40, 60, np.inf]
hdl_names = ['Low', 'Normal', 'High']
df['Cholesterol HDL'] = pd.cut(df['CholesterolHDL'], bins=hdl_edges, labels=hdl_names, right=False)
plot_kaplan_meier(df, 'Age', 'Diagnosis', 'Cholesterol HDL', axes[1])

# LDL cholesterol: left-closed bands [0, 100), [100, 200)
# NOTE(review): values of 200 or more become NaN here - confirm the data range.
ldl_edges = [0, 100, 200]
ldl_names = ['Optimal', 'High']
df['Cholesterol LDL'] = pd.cut(df['CholesterolLDL'], bins=ldl_edges, labels=ldl_names, right=False)
plot_kaplan_meier(df, 'Age', 'Diagnosis', 'Cholesterol LDL', axes[2])

print()

# COMPOSITE SCORES
# Each composite is the plain average of its component columns, min-max scaled
# to the range 0-1 and then cut into three bands for the survival plots.


def _min_max_scale(series):
    """Linearly rescale *series* so its minimum becomes 0 and its maximum 1."""
    lowest = series.min()
    return (series - lowest) / (series.max() - lowest)


fig, axes = plt.subplots(1, 4, figsize=(20, 4))

# Band edges reaching slightly past 0 and 1 so both extremes land inside a band
band_edges = [-0.01, 0.33, 0.66, 1.01]

mental_raw = (df['Depression'] + df['BehavioralProblems'] + df['Confusion'] + df['Disorientation']) / 4
df['MentalHealthScore'] = _min_max_scale(mental_raw)
df['MentalHealthCategory'] = pd.cut(df['MentalHealthScore'], bins=band_edges,
                                    labels=['good mental health', 'moderate', 'poor'])

metabolic_raw = (df['BMI'] + df['CholesterolTotal'] + df['CholesterolTriglycerides']
                 + df['SystolicBP'] + df['DiastolicBP']) / 5
df['MetabolicRiskScore'] = _min_max_scale(metabolic_raw)
df['MetabolicRiskCategory'] = pd.cut(df['MetabolicRiskScore'], bins=band_edges,
                                     labels=['Low', 'Moderate', 'High'])

functioning_raw = (df['ADL'] + df['DifficultyCompletingTasks'] + df['Forgetfulness'] + df['Disorientation']) / 4
df['FunctioningScore'] = _min_max_scale(functioning_raw)
df['FunctioningCategory'] = pd.cut(df['FunctioningScore'], bins=band_edges,
                                   labels=['impaired', 'Moderate', 'independent'])

cognitive_raw = (df['MMSE'] + df['FunctionalAssessment'] + df['MemoryComplaints']) / 3
df['CognitiveHealthScore'] = _min_max_scale(cognitive_raw)
# This composite uses open-ended outer bands instead of the shared edges
open_edges = [-float('inf'), 0.33, 0.66, float('inf')]
df['CognitiveHealthCategory'] = pd.cut(df['CognitiveHealthScore'], bins=open_edges,
                                       labels=['Low', 'Moderate', 'High'])

composite_columns = ['MentalHealthCategory', 'MetabolicRiskCategory', 'FunctioningCategory', 'CognitiveHealthCategory']
for panel_index, composite_column in enumerate(composite_columns):
    plot_kaplan_meier(df, 'Age', 'Diagnosis', composite_column, axes[panel_index])
    if composite_column == 'MetabolicRiskCategory':
        axes[panel_index].text(0.42, 0.1, 'curvilinear variables!', transform=axes[panel_index].transAxes,
                               ha='center', va='center', fontsize=9, color='red')
plt.tight_layout()
plt.show()



# NOTE(review): this figure repeats one already drawn earlier in this cell -
# confirm whether the repetition is intended.
fig, axes = plt.subplots(1, 4, figsize=(20, 4))
repeat_columns = ('MemoryComplaints', 'BehavioralProblems', 'Disorientation', 'CardiovascularDisease')
for repeat_index, repeat_column in enumerate(repeat_columns):
    plot_kaplan_meier(df, 'Age', 'Diagnosis', repeat_column, axes[repeat_index])
plt.tight_layout()
plt.show()

# NOTE(review): this figure repeats one already drawn earlier in this cell -
# confirm whether the repetition is intended.
fig, axes = plt.subplots(1, 4, figsize=(20, 4))
repeat_columns = ('FamilyHistoryAlzheimers', 'Hypertension', 'DifficultyCompletingTasks', 'Forgetfulness')
# Every panel except Hypertension gets the red "non-significant" note
repeat_unflagged = {'Hypertension'}
for repeat_index, repeat_column in enumerate(repeat_columns):
    repeat_ax = axes[repeat_index]
    plot_kaplan_meier(df, 'Age', 'Diagnosis', repeat_column, repeat_ax)
    if repeat_column not in repeat_unflagged:
        repeat_ax.text(0.4, 0.2, 'Surprisingly non-significant', transform=repeat_ax.transAxes,
                       ha='center', va='center', fontsize=10, color='red', fontweight='bold')
plt.tight_layout()
plt.show()

# NOTE(review): this section repeats the categorisation and figure already
# produced earlier in this cell - confirm whether the repetition is intended.
print()
print('continuous variables, manually categorised:')
fig, axes = plt.subplots(1, 4, figsize=(20, 4))

# BMI, left-closed bands
df['BMI_Category'] = pd.cut(
    df['BMI'],
    bins=[0, 18.5, 25, 30, np.inf],
    labels=['Underweight', 'Normal', 'Overweight', 'Obesity'],
    right=False,
)
plot_kaplan_meier(df, 'Age', 'Diagnosis', 'BMI_Category', axes[0])

# Alcohol, split at the 25% and 75% quantiles.
# Not robust - non-significant with 2 or 4 categories.
df['AlcoholCategory'] = pd.qcut(
    df['AlcoholConsumption'],
    q=[0, 0.25, 0.75, 1.0],
    labels=['Low', 'Medium', 'High'],
)
plot_kaplan_meier(df, 'Age', 'Diagnosis', 'AlcoholCategory', axes[1])
axes[1].text(
    0.42, 0.1, 'not robust, non-sig with 2 or 4 categories',
    transform=axes[1].transAxes, ha='center', va='center', fontsize=9, color='red',
)

# ADL (activities of daily living, e.g. self-grooming, feeding oneself)
df['ADL_Category'] = pd.cut(
    df['ADL'],
    bins=[0, 3, 6, 9, 10],
    labels=['Low', 'Moderate', 'High', 'Very High'],
    include_lowest=True,
)
plot_kaplan_meier(df, 'Age', 'Diagnosis', 'ADL_Category', axes[2])

# Sleep quality, left-closed bands
df['sleep_quality'] = pd.cut(
    df['SleepQuality'],
    bins=[0, 5, 7, 9, np.inf],
    labels=['Poor', 'Fair', 'Good', 'Excellent'],
    right=False,
)
plot_kaplan_meier(df, 'Age', 'Diagnosis', 'sleep_quality', axes[3])

# NOTE(review): this section repeats the categorisation and figure already
# produced earlier in this cell - confirm whether the repetition is intended.
fig, axes = plt.subplots(1, 3, figsize=(18, 3))

# MMSE, left-closed bands
df['MMSE_cat'] = pd.cut(
    df['MMSE'],
    bins=[0, 10, 20, 30],
    labels=['Severe Cognitive Impairment', 'Mild Cognitive Impairment', 'Normal'],
    right=False,
)
plot_kaplan_meier(df, 'Age', 'Diagnosis', 'MMSE_cat', axes[0])

# HDL cholesterol, left-closed bands
df['Cholesterol HDL'] = pd.cut(
    df['CholesterolHDL'],
    bins=[0, 40, 60, np.inf],
    labels=['Low', 'Normal', 'High'],
    right=False,
)
plot_kaplan_meier(df, 'Age', 'Diagnosis', 'Cholesterol HDL', axes[1])

# LDL cholesterol, left-closed bands
df['Cholesterol LDL'] = pd.cut(
    df['CholesterolLDL'],
    bins=[0, 100, 200],
    labels=['Optimal', 'High'],
    right=False,
)
plot_kaplan_meier(df, 'Age', 'Diagnosis', 'Cholesterol LDL', axes[2])

## Decision tree

In [None]:

# Feature groups that are evaluated separately: one shallow tree is fitted,
# scored and drawn per group.
biological_features = [
    'Age', 'Ethnicity', 'Gender', 'BMI', 'FamilyHistoryAlzheimers',
    'CardiovascularDisease', 'Diabetes', 'Hypertension',
    'SystolicBP', 'DiastolicBP', 'CholesterolTotal',
    'CholesterolLDL', 'CholesterolHDL', 'CholesterolTriglycerides']

cognitive_features = [
    'MMSE', 'FunctionalAssessment', 'MemoryComplaints',
    'BehavioralProblems', 'ADL', 'Confusion', 'Disorientation',
    'PersonalityChanges', 'DifficultyCompletingTasks', 'Forgetfulness']

lifestyle_features = [
    'Smoking', 'AlcoholConsumption', 'PhysicalActivity',
    'DietQuality', 'SleepQuality', 'Depression', 'HeadInjury',
    'EducationLevel']

feature_sets = {
    'Biological Features': biological_features,
    'Cognitive Features': cognitive_features,
    'Lifestyle Features': lifestyle_features,
}


# Loop through each feature set
for feature_set_name, features in feature_sets.items():
    print(f"\nUsing {feature_set_name}:")

    X = df[features]  # Input features
    y = df['Diagnosis']  # Target label

    # 70/30 train/test split, seeded so every run evaluates on the same rows
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=1)

    # Unrestricted reference tree. Seeded because scikit-learn permutes the
    # features at every split, so an unseeded tree can differ between runs.
    # NOTE(review): its predictions are not reported anywhere in this cell.
    model = DecisionTreeClassifier(random_state=1)
    model.fit(X_train, y_train)
    y_pred = model.predict(X_test)

    # Simplified tree: entropy criterion, at most three levels deep, seeded for
    # the same reproducibility reason as above
    model_entropy = DecisionTreeClassifier(criterion="entropy", max_depth=3, random_state=1)
    model_entropy.fit(X_train, y_train)
    y_pred_entropy = model_entropy.predict(X_test)
    entropy_accuracy = accuracy_score(y_test, y_pred_entropy)
    print(f"Entropy-Based Tree Accuracy: {entropy_accuracy:.2f}")

    # Visualize the simpler tree
    plt.figure(figsize=(18, 12))
    plot_tree(
        model_entropy,
        feature_names=features,
        class_names=['No Diagnosis', 'Diagnosis'],
        filled=True,
        rounded=True,
        fontsize=10,
        proportion=True
    )
    # NOTE(review): presumably meant to thin the connector lines; confirm that
    # plot_tree actually exposes them through ax.patches.
    ax = plt.gca()
    for arrow in ax.patches:
        arrow.set_linewidth(0.55)

    # Reuse the accuracy computed above instead of scoring a second time
    plt.title(f"Decision Tree Visualization for {feature_set_name}\nModel Accuracy: {entropy_accuracy:.2f}", fontsize=16)
    plt.show()



