# Assignment 4 Notebook
## Kaplan-Meier Analysis

In [None]:
import pandas as pd 
import numpy as np
import matplotlib.pyplot as plt
from lifelines import KaplanMeierFitter

# Load the clinical spreadsheet (path is relative to the notebook location).
data = pd.read_excel("../Data/RADCURE_Clinical_v04_20241219.xlsx")

# Fraction of missing entries in each column.
missing_data = data.isnull().mean()

# Columns that are more than 60% empty are discarded outright.
drop_cols = missing_data.index[missing_data > 0.6]

# Build an independent cleaned frame: remove the sparse columns, then keep
# only the rows that are complete across every remaining column.
clean_data = data.drop(columns=drop_cols).dropna().copy()

# Covariates, event indicator and follow-up time used by the Kaplan-Meier
# and Cox models. An explicit .copy() makes this an independent frame, so the
# column assignments below are guaranteed to land in relevant_data (writing
# through a selection of clean_data is unreliable and is a no-op under pandas
# Copy-on-Write).
relevant_data = clean_data[["Age",
                           "Sex",
                           "ECOG PS",
                           "Smoking Status",
                           "Ds Site",
                           "Path",
                           "Status",
                           "Length FU"]].copy()

# The event indicator must be 1 = event observed, 0 = censored. Coding it by
# order of first appearance (as the generic loop below does) would flip the
# meaning whenever the first row happens to be an event.
# NOTE(review): assumes the raw labels are "Alive"/"Dead" - confirm against
# the spreadsheet. Any other label raises instead of being silently miscoded.
if not pd.api.types.is_numeric_dtype(relevant_data["Status"]):
    status_labels = relevant_data["Status"].astype(str).str.strip().str.casefold()
    unexpected = sorted(set(status_labels) - {"alive", "dead"})
    if unexpected:
        raise ValueError(f"Unrecognised Status labels: {unexpected}")
    relevant_data["Status"] = (status_labels == "dead").astype(int)

# Replace each remaining text column with integer codes (0, 1, 2, ... in
# order of first appearance) and assign the result back to the frame.
# NOTE: these codes impose an arbitrary ordering on nominal variables such as
# "Ds Site" and "Path"; the Cox model below treats them as numeric.
for col in relevant_data.select_dtypes(include=[np.object_]).columns:
    codes, _ = pd.factorize(relevant_data[col])
    relevant_data[col] = codes.astype(int)

# Fit the Kaplan-Meier estimator on follow-up time and the event indicator.
kmf = KaplanMeierFitter()
kmf.fit(durations=relevant_data["Length FU"],
        event_observed=relevant_data["Status"])

# Draw the estimated survival curve and label the axes it was drawn on.
km_ax = kmf.plot_survival_function()
km_ax.set_title('Kaplan-Meier Curve')
km_ax.set_xlabel('Time (Years)')
km_ax.set_ylabel('Survival Probability')
plt.show()

## Cox Proportional Hazards Model

In [None]:
from lifelines import CoxPHFitter

# Fit a Cox proportional hazards model: every column of relevant_data other
# than the duration and event columns enters the model as a covariate.
cph = CoxPHFitter()
cph.fit(df=relevant_data,
        duration_col="Length FU",
        event_col="Status")

# Tabulated coefficients, hazard ratios and fit statistics.
cph.print_summary()

# Coefficient plot for the fitted covariates.
coef_ax = cph.plot()
coef_ax.set_title('Cox Regression Coefficients')
plt.show()

# Test the proportional hazards assumption at the 5% level, with
# diagnostic plots for the covariates.
cph.check_assumptions(training_df=relevant_data,
                      show_plots=True,
                      p_value_threshold=0.05)

## Random Survival Forest

In [None]:
from sksurv.ensemble import RandomSurvivalForest
from sksurv.preprocessing import OneHotEncoder
from sklearn.inspection import permutation_importance

# Feature matrix: the same predictors that are selected for the Cox model.
# "Status" and "Length FU" together form the survival outcome, so neither may
# appear among the features (leaving the follow-up time in would leak the
# target into the forest).
feature_cols = ["Age", "Sex", "ECOG PS", "Smoking Status", "Ds Site", "Path"]
data_x = clean_data[feature_cols].copy()

# Mark text columns as categorical so the one-hot encoder expands them.
text_cols = data_x.select_dtypes(include=[np.object_]).columns
data_x = data_x.astype({col: "category" for col in text_cols})

# Event indicator: True = event observed, False = censored.
# NOTE(review): assumes the raw labels are "Alive"/"Dead" - confirm against
# the spreadsheet. Any other label raises instead of being silently miscoded.
status = clean_data["Status"]
if pd.api.types.is_numeric_dtype(status):
    event = status.astype(bool).to_numpy()
else:
    status_labels = status.astype(str).str.strip().str.casefold()
    unexpected = sorted(set(status_labels) - {"alive", "dead"})
    if unexpected:
        raise ValueError(f"Unrecognised Status labels: {unexpected}")
    event = (status_labels == "dead").to_numpy()

# scikit-survival expects the outcome as a structured array whose first field
# is the boolean event flag and whose second field is the observed time.
data_y = np.empty(len(clean_data), dtype=[("event", bool), ("time", float)])
data_y["event"] = event
data_y["time"] = clean_data["Length FU"].to_numpy(dtype=float)

# Encode categorical variables
encoder = OneHotEncoder()
data_x = encoder.fit_transform(data_x)

# Train a Random Survival Forest model
rsf = RandomSurvivalForest(random_state=27)
rsf.fit(data_x, data_y)

# Permutation importance: drop in the model's score when a single feature is
# shuffled, averaged over n_repeats shuffles.
# please refer to https://scikit-survival.readthedocs.io/en/stable/user_guide/random-survival-forest.html
# NOTE: importance is measured on the training data itself; a held-out split
# would give a less optimistic picture.
result = permutation_importance(rsf,
                                data_x,
                                data_y,
                                n_repeats=15,
                                random_state=27)

# One row per encoded feature, most important first (sorted once).
feature_importance = pd.DataFrame(
    {
        k: result[k]
        for k in (
            "importances_mean",
            "importances_std",
        )
    },
    index=data_x.columns,
).sort_values(by="importances_mean",
              ascending=False)

# Horizontal bar chart with one standard deviation as the error bar.
plt.figure(figsize=(10, 6))
plt.title('Feature Importances')
plt.barh(feature_importance.index,
         feature_importance['importances_mean'],
         xerr=feature_importance['importances_std'],
         align='center')
plt.xlabel('Mean Importance')
plt.ylabel('Features')
plt.gca().invert_yaxis()  # largest importance at the top
plt.tight_layout()
plt.show()