In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from lifelines import CoxPHFitter, KaplanMeierFitter
from sklearn.preprocessing import OneHotEncoder, StandardScaler
from sklearn.model_selection import train_test_split, KFold, GridSearchCV

# Step 1: Load the data
data = pd.read_csv('dataDIVAT3.csv')

# Step 2: Advanced Data Exploration
# Summary statistics and missing values
print(data.describe())
print(data.isnull().sum())

# Visualize distribution of key numerical features
for col in ['AGE', 'DIALYSIS_DURATION', 'time']:
    sns.histplot(data[col], kde=True)
    plt.title(f'Distribution of {col}')
    plt.show()

# Correlation matrix for numerical features.
# numeric_only=True restricts the matrix to numeric columns. The frame still
# holds the categorical columns (GENDER, DISEASE_TYPE are only one-hot encoded
# in Step 5), and on pandas >= 2.0 DataFrame.corr raises on non-numeric data
# unless they are excluded.
sns.heatmap(data.corr(numeric_only=True), annot=True, cmap='coolwarm')
plt.title('Correlation Matrix')
plt.show()

# Check survival status distribution
sns.countplot(x='event', data=data)
plt.title('Distribution of Survival Status')
plt.show()

# Step 3: Handle Missing Data
# Fill missing numerical data with the column mean.
# The result is assigned back instead of calling fillna(inplace=True) on
# data['AGE']. That form is chained assignment: pandas 2.x warns about it,
# and under Copy-on-Write (the default from pandas 3.0) it does not update
# `data` at all.
for num_col in ['AGE', 'DIALYSIS_DURATION']:
    data[num_col] = data[num_col].fillna(data[num_col].mean())

# Drop rows with missing critical values (the survival outcome itself)
data.dropna(subset=['time', 'event'], inplace=True)

# Step 4: Handle Imbalanced Data
# SMOTE oversampling is intentionally NOT applied to this survival data:
#   * The previous version dropped 'time' before resampling, so the resampled
#     frame had no duration column and the Cox model could not be fitted.
#   * SMOTE interpolates numeric features only. GENDER/DISEASE_TYPE are not
#     encoded until Step 5, so they would be passed in raw.
#   * Synthetic (time, event) records are not meaningful for censored data.
#     Resampling before the train/test split also leaks synthetic copies of
#     test patients into the training set.
# The Cox partial likelihood already uses every censored and uncensored
# subject, so the data is used as-is. reset_index gives a clean 0..n-1 index
# (rows may have been dropped in Step 3) so later index-based concatenations
# line up.
data_resampled = data.reset_index(drop=True)

# Step 5: One-Hot Encode Categorical Variables
categorical_cols = ['GENDER', 'DISEASE_TYPE']  # Update this list based on actual columns
# drop='first' removes one dummy per feature, so the remaining dummies are not
# collinear with the model's baseline.
encoder = OneHotEncoder(drop='first')
encoded_cats = encoder.fit_transform(data_resampled[categorical_cols]).toarray()
encoded_columns = encoder.get_feature_names_out(categorical_cols)
# Reuse data_resampled's index. pd.concat(axis=1) aligns on index labels, and
# a default 0..n-1 index would misalign rows (and introduce NaNs) whenever
# data_resampled's index is anything else.
encoded_df = pd.DataFrame(encoded_cats, columns=encoded_columns, index=data_resampled.index)
data_resampled = pd.concat([data_resampled, encoded_df], axis=1).drop(columns=categorical_cols)

# Steps 6-7: Train/Test Split, then Feature Scaling
# Split first and fit the scaler on the training rows only. Fitting it on the
# full dataset would leak the test set's mean/variance into training features.
numeric_cols = ['AGE', 'DIALYSIS_DURATION']
train_data, test_data = train_test_split(data_resampled, test_size=0.2, random_state=42)
# Explicit copies make both frames independent of data_resampled, so the
# column assignments below (and later ones on test_data) don't raise
# SettingWithCopyWarning.
train_data = train_data.copy()
test_data = test_data.copy()

scaler = StandardScaler()
train_data[numeric_cols] = scaler.fit_transform(train_data[numeric_cols])
test_data[numeric_cols] = scaler.transform(test_data[numeric_cols])

# Prepare the training data for the Cox model: duration, event flag, covariates
df_train_for_cox = train_data[['time', 'event'] + numeric_cols + list(encoded_columns)]

# Step 8: Fit the Cox Proportional Hazards Model with Cross-Validation
kf = KFold(n_splits=5, shuffle=True, random_state=42)
concordance_indices = []

for train_index, test_index in kf.split(df_train_for_cox):
    df_train, df_test = df_train_for_cox.iloc[train_index], df_train_for_cox.iloc[test_index]
    cph = CoxPHFitter()
    cph.fit(df_train, duration_col='time', event_col='event')
    # Score on the held-out fold. cph.concordance_index_ is the in-sample
    # C-index on df_train, which overstates generalization performance.
    concordance_indices.append(cph.score(df_test, scoring_method='concordance_index'))

print("Average Concordance Index:", np.mean(concordance_indices))

# Refit on the whole training set. Otherwise the `cph` used by later steps
# would be the last CV fold's model, fitted on only 80% of the training rows.
cph = CoxPHFitter()
cph.fit(df_train_for_cox, duration_col='time', event_col='event')

# Step 9: Hyperparameter Tuning (if applicable)
# lifelines' CoxPHFitter does expose tunable regularization parameters:
# `penalizer` (penalty strength) and `l1_ratio` (the elastic-net mix between L1 and L2).
# They could be tuned by repeating the Step 8 cross-validation loop over a grid of
# candidate values and keeping the setting with the best held-out concordance.
# This step is not implemented yet.

# Step 10: Compare with Kaplan-Meier Estimator
# Non-parametric survival curve for the training set, shown as a baseline
# next to the Cox model's covariate-specific predictions.
kmf = KaplanMeierFitter()
kmf.fit(durations=train_data['time'], event_observed=train_data['event'])
km_ax = kmf.plot_survival_function()
km_ax.set_title('Kaplan-Meier Survival Curve')
plt.show()

# Step 11: Predict Survival Probabilities on Test Data
df_test_for_cox = test_data[['AGE', 'DIALYSIS_DURATION'] + list(encoded_columns)]

# predict_survival_function returns a frame indexed by time, with one column
# per patient. The previous `.iloc[0]` picked the first *time* row, i.e. each
# patient's survival at the earliest timeline point. That is typically close
# to 1 and is not a meaningful per-patient prediction. Each patient's curve is
# now evaluated at one explicit horizon instead.
# NOTE(review): the horizon defaults to the median training follow-up time.
# Replace it with a clinically meaningful value, in the same units as 'time'.
survival_horizon = train_data['time'].median()
test_data = test_data.copy()  # independent frame: avoids SettingWithCopyWarning below
test_data['predicted_survival'] = (
    cph.predict_survival_function(df_test_for_cox, times=[survival_horizon]).iloc[0].values
)

# Visualize survival curve for an example patient
new_patient = df_test_for_cox.iloc[0:1]
survival_function = cph.predict_survival_function(new_patient)
plt.plot(survival_function.index, survival_function.iloc[:, 0])
plt.title('Survival Function for an Example Patient')
plt.xlabel('Time')
plt.ylabel('Survival Probability')
plt.show()

# Step 12: Model Interpretation
# Plot the fitted Cox model's coefficients, then title the axes it draws on.
coef_ax = cph.plot()
coef_ax.set_title('Cox Model Coefficients')
plt.show()

# Step 13: Save and Deploy the Model
# CoxPHFitter has no save_model()/load_model() methods. lifelines documents
# persisting fitted models with pickle, so use that inside a context manager.
import pickle

with open('cox_model_final_year.pkl', 'wb') as model_file:
    pickle.dump(cph, model_file)

# Load the model for future use. Only unpickle files from a trusted source:
# pickle.load can execute arbitrary code.
# with open('cox_model_final_year.pkl', 'rb') as model_file:
#     cph_loaded = pickle.load(model_file)

