In [None]:
# Work around the OpenMP runtime conflict (typically "OMP: Error #15:
# Initializing libiomp5..., but found libiomp5 already initialized") that
# aborts the process when two OpenMP runtimes are loaded together.
# NOTE: the OpenMP runtime itself describes KMP_DUPLICATE_LIB_OK=TRUE as an
# unsafe, unsupported workaround that may crash or silently produce incorrect
# results; prefer fixing the environment so only one runtime is installed.
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

# On Windows, additionally try to preload the MKL runtime DLL. This is
# best-effort: a missing DLL (OSError) or unavailable ctypes (ImportError) is
# ignored, but unrelated errors (and KeyboardInterrupt) are no longer swallowed.
if os.name == 'nt':
    try:
        import ctypes
        ctypes.CDLL('mkl_rt.dll')
    except (ImportError, OSError):
        pass

In [None]:
import importlib
import ukko.survival
importlib.reload(ukko.survival)

# Import required libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from lifelines import CoxPHFitter, LogLogisticAFTFitter, KaplanMeierFitter

# Import ukko survival functions
from ukko.survival import plot_KM, plot_loglogistic_hazard, generate_survival_data_LL

In [None]:
def generate_data(n_samples=1000, n_features=3, n_informative=1, 
                 shape=5, scale=10.0, censoring=0.3, random_seed=42):
    """Simulate log-logistic survival data and build the modelling frame.

    Seeds NumPy's global RNG with ``random_seed``, then delegates to
    ``generate_survival_data_LL`` with linear (``nonlinear=False``) effects.

    Returns
    -------
    survival_data : DataFrame
        The full frame produced by the generator.
    cox_data : DataFrame
        Only the ``feature_*`` columns plus ``observed_time`` and
        ``event_observed``.
    true_coefficients
        The coefficients returned by the generator, passed through unchanged.
    feature_cols : list of str
        Names of the ``feature_*`` columns, in frame order.
    """
    np.random.seed(random_seed)

    survival_data, true_coefficients = generate_survival_data_LL(
        n_samples, n_features, n_informative, shape, scale, censoring,
        nonlinear=False,
    )

    feature_cols = list(filter(lambda name: name.startswith('feature_'),
                               survival_data.columns))
    model_cols = [*feature_cols, 'observed_time', 'event_observed']
    cox_data = survival_data[model_cols]

    return survival_data, cox_data, true_coefficients, feature_cols

def fit_cox_model(cox_data):
    """Return a lifelines ``CoxPHFitter`` fitted to ``cox_data``.

    ``observed_time`` is used as the duration column and ``event_observed``
    as the event indicator.
    """
    return CoxPHFitter().fit(
        cox_data,
        duration_col='observed_time',
        event_col='event_observed',
    )

def plot_coefficient_comparison(model, true_coefficients, title):
    """Overlay the simulated (true) coefficients on a fitted model's coefficient plot.

    Parameters
    ----------
    model : fitted lifelines regression model
        Either a ``CoxPHFitter`` (flat ``params_`` index) or a parametric
        fitter such as ``LogLogisticAFTFitter`` whose ``params_`` is indexed
        by ``(parameter, covariate)`` pairs.
    true_coefficients : sequence of float
        True coefficients for ``feature_0``, ``feature_1``, ... in that order.
    title : str
        Plot title.

    Returns
    -------
    pandas.Series
        True coefficients reindexed to the model's ascending-coefficient
        order. Rows without a true value (e.g. AFT intercepts) are NaN.
    """
    _, ax = plt.subplots(figsize=(10, 7))
    model.plot(ax=ax)

    model_params = model.params_.copy()  # copy: the index is relabelled below
    feature_names = [f'feature_{i}' for i in range(len(true_coefficients))]

    if isinstance(model_params.index, pd.MultiIndex):
        # Parametric (AFT) fitters index params_ by (parameter, covariate);
        # flatten to "covariate: parameter" labels, e.g. "feature_0: alpha_".
        model_params.index = [f"{cov}: {param}" if cov else param
                              for param, cov in model.params_.index]
        # Map each feature to the parameter it was fitted under (alpha_ for
        # LogLogisticAFTFitter) rather than hard-coding the parameter name.
        # If a feature appears under several parameters, the last one wins.
        label_for = {cov: f"{cov}: {param}"
                     for param, cov in model.params_.index if cov}
        feature_names = [label_for.get(name, name) for name in feature_names]
        # AFT coefficients act on the (log) time scale, not on the log hazard.
        xlabel = 'Coefficient (log scale)'
    else:
        xlabel = 'Log Hazard Ratio'

    true_coef_series = pd.Series(true_coefficients, index=feature_names)
    # Assumes lifelines draws rows in ascending-coefficient order (its plot()
    # does so when no columns are given), so row k of the plot corresponds to
    # position k of the sorted parameters.
    ordered_true_coef = true_coef_series.reindex(model_params.sort_values().index)

    ax.scatter(ordered_true_coef.values, range(len(ordered_true_coef)),
               color='red', marker='o', s=80, zorder=5, label='True Coefficients')
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.legend()
    plt.show()

    return ordered_true_coef

def fit_aft_model(cox_data):
    """Fit a log-logistic accelerated failure time (AFT) model.

    Parameters
    ----------
    cox_data : DataFrame
        Must contain ``observed_time`` (durations) and ``event_observed``
        (event indicator).

    Returns
    -------
    LogLogisticAFTFitter
        The fitted model. Plotting is left to the caller (e.g.
        ``plot_coefficient_comparison``); no figure is created here, so no
        empty figure is left open for the inline backend to render.
    """
    aftll = LogLogisticAFTFitter()
    aftll.fit(cox_data, duration_col='observed_time', event_col='event_observed')
    return aftll

def stratify_and_plot(survival_data, model, feature_cols, n_groups=4, cmap='managua'):
    """Split subjects into quantile risk groups by predicted median survival and plot KM curves.

    Parameters
    ----------
    survival_data : DataFrame
        Must contain ``feature_cols``, ``observed_time`` and ``event_observed``.
        It is NOT modified; predictions are attached to a copy.
    model : fitted lifelines model
        Anything exposing ``predict_median``.
    feature_cols : list of str
        Covariate columns passed to ``model.predict_median``.
    n_groups : int
        Number of quantile groups. Q1 (shortest predicted median) is
        labelled highest risk; Q{n_groups} lowest risk.
    cmap : str
        Matplotlib colormap name. The default 'managua' requires a recent
        matplotlib (3.10+); pass e.g. 'viridis' on older versions.

    Returns
    -------
    numpy.ndarray
        The ``n_groups + 1`` quantile cutoffs of the predicted medians.
    """
    # assign() uses the same column-assignment semantics as before, but on a
    # copy, so the caller's frame is left untouched.
    scored = survival_data.assign(
        predicted_median_survival=model.predict_median(survival_data[feature_cols])
    )
    predicted = scored['predicted_median_survival']

    quantiles = np.linspace(0, 1, n_groups + 1)
    cutoffs = np.quantile(predicted, quantiles)

    _, ax = plt.subplots(figsize=(12, 8))
    colors = plt.get_cmap(cmap)(np.linspace(0, 1, n_groups))

    for i in range(n_groups):
        if i == 0:
            mask = predicted <= cutoffs[1]
            group_label = 'Highest Risk (Q1)'
        elif i == n_groups - 1:
            mask = predicted > cutoffs[-2]
            group_label = f'Lowest Risk (Q{n_groups})'
        else:
            mask = (predicted > cutoffs[i]) & (predicted <= cutoffs[i + 1])
            group_label = f'Q{i+1}'

        group = scored[mask]
        print(f"{group_label} size: {len(group)}")

        # Ties at a cutoff, or non-finite predictions (predict_median can
        # return inf when a curve never drops to 0.5), can leave a bin empty;
        # a KM curve for an empty group is meaningless, so skip it.
        if group.empty:
            print(f"  skipping {group_label}: no subjects in this quantile bin")
            continue

        kmf = KaplanMeierFitter()
        kmf.fit(durations=group['observed_time'],
                event_observed=group['event_observed'],
                label=group_label)
        kmf.plot_survival_function(ax=ax, show_censors=True, ci_show=False,
                                   color=colors[i])

    ax.set_title(f'Kaplan-Meier Survival Curves by Risk Group ({n_groups} groups): '
                 f'{type(model).__name__}')
    ax.set_xlabel('Time')
    ax.set_ylabel('Survival Probability')
    ax.grid(True)
    ax.legend()
    plt.show()

    return cutoffs

def main(shape=5, scale=10.0):
    """Run the demo end to end.

    Simulates log-logistic survival data, fits Cox PH and log-logistic AFT
    models, compares fitted vs. true coefficients, and plots Kaplan-Meier
    curves for Cox-predicted risk groups.

    Parameters
    ----------
    shape, scale : float
        Log-logistic parameters used BOTH to simulate the data and to draw
        the reference hazard curve, so the two cannot drift apart. Defaults
        match ``generate_data``'s defaults.

    Returns
    -------
    dict
        The data, fitted models and intermediate results, so later cells can
        inspect them instead of re-running the pipeline.
    """
    survival_data, cox_data, true_coefficients, feature_cols = generate_data(
        shape=shape, scale=scale
    )

    # Reference hazard curve for the distribution the data was drawn from.
    plot_loglogistic_hazard([shape], scale=scale, max_time=100)

    print("\nLog-Logistic AFT Survival Data Head:")
    print(survival_data.head(3))
    # Fraction of rows whose event_observed flag is not set.
    print(f"\nActual censoring: {1 - survival_data['event_observed'].mean():.2f}")

    plot_KM(survival_data)

    cph = fit_cox_model(cox_data)
    print("\nCox PH Model Summary:")
    cph.print_summary()
    ordered_true_coef_cox = plot_coefficient_comparison(
        cph, true_coefficients,
        'Cox PH Model Fitted vs. True Coefficients'
    )

    aftll = fit_aft_model(cox_data)
    print("\nAFT Log-logistic Model Summary:")
    aftll.print_summary()
    ordered_true_coef_aftll = plot_coefficient_comparison(
        aftll, true_coefficients,
        'AFT LL Model Fitted vs. True Coefficients'
    )

    cutoffs = stratify_and_plot(survival_data, cph, feature_cols, n_groups=5)

    return {
        'survival_data': survival_data,
        'cox_data': cox_data,
        'true_coefficients': true_coefficients,
        'cph': cph,
        'aftll': aftll,
        'ordered_true_coef_cox': ordered_true_coef_cox,
        'ordered_true_coef_aftll': ordered_true_coef_aftll,
        'cutoffs': cutoffs,
    }

# Run the demo when this code executes as the top-level module (a script, or
# a notebook cell, where the kernel's __name__ is "__main__").
if __name__ == "__main__":
    main()