# ACTIN-1289 Test Kaplan Meier assumptions
To use the Kaplan-Meier method, the data must satisfy a number of specific assumptions (see: https://docs.google.com/document/d/1s1TUogmw6y0wAti4xqoJpvNYCkkSjb6tJvVo_t-XNns/edit?usp=sharing).

In this notebook we look into the assumptions that can only be checked by investigating the data.

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import kruskal
import pymysql

In [None]:
# Load the complete knownPalliativeTreatedReference table into a DataFrame.
# Connection settings come from the 'RAnalysis' group of the MySQL option file.
db_connection = pymysql.connect(
    read_default_file='/home/jupyter/.my.cnf',
    read_default_group='RAnalysis',
    database='actin_personalization_v2'  # `db` is only a deprecated alias of `database`
)

query = "SELECT * FROM knownPalliativeTreatedReference "

try:
    km_df = pd.read_sql(query, db_connection)
finally:
    # Release the connection even when the query fails
    db_connection.close()

km_df.head()

## Assumption 5
Assumption 5 states:

> There should be no secular trends (also known as secular changes). 
> A characteristic of many studies that involve survival analysis is that:
> - (a) there is often a long time period between the start and end of the experiment; 
> - and (b) not all cases (e.g., participants) tend to start the experiment at the same time. 


In [None]:
# Leave out rows with diagnosis year 2022 before comparing years.
# NOTE(review): presumably 2022 is excluded because its follow-up is incomplete — confirm
diagnosed_before_or_after_2022 = km_df['diagnosisYear'].ne(2022)
df = km_df.loc[diagnosed_before_or_after_2022]

# Distinct diagnosis years that remain after the filter
years = pd.unique(df['diagnosisYear'])

def plot_survival_distribution_per_year(df, years, survival_column):
    """Test and plot how ``survival_column`` is distributed across diagnosis years.

    Runs a Kruskal-Wallis H-test over the per-year samples of ``survival_column``,
    prints the result, and draws one box plot per ``diagnosisYear``.

    Args:
        df: DataFrame containing a ``diagnosisYear`` column and ``survival_column``.
        years: Iterable of diagnosis years whose samples are compared in the test.
        survival_column: Name of the column holding the values to compare.
    """
    per_year_samples = [
        df.loc[df['diagnosisYear'] == year, survival_column].dropna()
        for year in years
    ]
    # A year without any observed value cannot contribute to the test, so it is
    # left out instead of being handed to kruskal as an empty sample.
    non_empty_samples = [sample for sample in per_year_samples if len(sample) > 0]

    if len(non_empty_samples) > 1:
        kruskal_test = kruskal(*non_empty_samples)
        print("Kruskal-Wallis test results:", kruskal_test)
    else:
        # kruskal needs at least two groups; report instead of raising
        print("Kruskal-Wallis test skipped: fewer than two years with data")

    plt.figure(figsize=(10, 6))
    sns.boxplot(x='diagnosisYear', y=survival_column, data=df)
    plt.title(f"{survival_column} Distribution by Start Year")
    plt.xticks(rotation=45)
    plt.show()
        

In [None]:
# Year-over-year comparison of the days between treatment start and progression
plot_survival_distribution_per_year(
    df,
    years,
    survival_column='daysBetweenTreatmentStartAndProgression',
)

In [None]:
# Year-over-year comparison of the survival days since treatment start
plot_survival_distribution_per_year(
    df,
    years,
    survival_column='survivalDaysSinceTreatmentStart',
)

In [None]:
# Distinct non-missing values of the first systemic treatment column
treatments = pd.unique(df['firstSystemicTreatmentAfterMetastaticDiagnosis'].dropna())

def plot_survival_distribution_per_year_per_treatment(treatments, years, survival_column, num_cols=4, data=None):
    """Run a per-treatment Kruskal-Wallis test across diagnosis years and plot each treatment.

    For every treatment, the values of ``survival_column`` are split by
    ``diagnosisYear``. Treatments with at least two non-empty yearly samples are
    tested and drawn as a box plot in a grid of subplots; all other treatments are
    skipped entirely (no test result, no panel).

    Args:
        treatments: Iterable of values of
            ``firstSystemicTreatmentAfterMetastaticDiagnosis`` to analyse.
        years: Iterable of diagnosis years to compare within each treatment.
        survival_column: Name of the column holding the values to compare.
        num_cols: Number of subplot columns in the grid.
        data: DataFrame to analyse. When omitted, the notebook-level ``df`` is
            used, which is what this function always read before.

    Returns:
        DataFrame with one row per tested treatment and the columns
        ``Treatment``, ``Kruskal-Wallis Statistic`` and ``p-value``.
    """
    if data is None:
        data = df  # backward-compatible fallback to the notebook-level frame

    kruskal_results = []
    plottable = []  # (treatment, rows of that treatment) for every tested treatment

    for treatment in treatments:
        treatment_data = data[data['firstSystemicTreatmentAfterMetastaticDiagnosis'] == treatment]
        yearly_samples = [
            treatment_data.loc[treatment_data['diagnosisYear'] == year, survival_column].dropna()
            for year in years
        ]
        non_empty_samples = [sample for sample in yearly_samples if len(sample) > 0]

        # kruskal needs at least two groups
        if len(non_empty_samples) < 2:
            continue

        kruskal_test = kruskal(*non_empty_samples)
        kruskal_results.append([treatment, kruskal_test.statistic, kruskal_test.pvalue])
        plottable.append((treatment, treatment_data))

    results = pd.DataFrame(kruskal_results, columns=['Treatment', 'Kruskal-Wallis Statistic', 'p-value'])

    # Nothing to draw: avoid creating a grid with zero rows
    if not plottable:
        return results

    # Size the grid for the treatments that are actually drawn (ceiling division)
    num_rows = (len(plottable) + num_cols - 1) // num_cols

    # squeeze=False keeps `axes` two-dimensional even for a 1x1 grid, so flatten() always works
    fig, axes = plt.subplots(nrows=num_rows, ncols=num_cols, figsize=(18, num_rows * 4), squeeze=False)
    axes = axes.flatten()

    for ax, (treatment, treatment_data) in zip(axes, plottable):
        sns.boxplot(x='diagnosisYear', y=survival_column, data=treatment_data, ax=ax)
        ax.set_title(f"{survival_column} Distribution by Year for {treatment}")
        ax.tick_params(axis='x', labelrotation=45)

    # Remove the unused panels at the end of the grid
    for unused_ax in axes[len(plottable):]:
        fig.delaxes(unused_ax)

    # Lay out after drawing so titles and tick labels are taken into account
    fig.tight_layout(pad=5.0)
    plt.show()

    return results

In [None]:
# Per-treatment yearly comparison of the days between treatment start and progression
pfs_kruskal_results = plot_survival_distribution_per_year_per_treatment(
    treatments,
    years,
    survival_column='daysBetweenTreatmentStartAndProgression',
)

# Show the table of test results as the cell output
pfs_kruskal_results

In [None]:
# Per-treatment yearly comparison of the survival days since treatment start
os_kruskal_results = plot_survival_distribution_per_year_per_treatment(
    treatments,
    years,
    survival_column='survivalDaysSinceTreatmentStart',
)

# Show the table of test results as the cell output
os_kruskal_results

## Assumption 6

Assumption 6 states: 
> There should be a similar amount and pattern of censorship per group. 
> One of the assumptions of the Kaplan-Meier method and the statistical tests for differences between group survival distributions (e.g., the log rank test, which we discuss later in the guide) is that censoring is similar in all groups tested. 

In [None]:
def plot_censoring_per_treatment(km_df, survival_event_column, exclude_small_sample_sizes=True, small_sample_size_threshold=50):
    """Plot the percentage of censored records per treatment and return the counts.

    Records are counted per treatment and per value of ``survival_event_column``
    (0 is treated as censored, 1 as event), using the non-null ``sourceId`` values.

    Args:
        km_df: DataFrame with the columns
            ``firstSystemicTreatmentAfterMetastaticDiagnosis``, ``sourceId`` and
            ``survival_event_column``.
        survival_event_column: Name of the event indicator column.
            NOTE(review): assumed to be coded as integers 0/1 — confirm.
        exclude_small_sample_sizes: When true, treatments with fewer than
            ``small_sample_size_threshold`` records are dropped.
        small_sample_size_threshold: Minimum number of records per treatment.

    Returns:
        DataFrame indexed by treatment with the count columns ``0`` and ``1`` plus
        ``censored_percentage``, ``total_events``, ``censor_events`` and
        ``progression_events``. The last column holds the count of value 1 for
        whichever event column was passed, not only for progression.
    """
    event_counts = km_df.pivot_table(index='firstSystemicTreatmentAfterMetastaticDiagnosis',
                                     columns=survival_event_column,
                                     values='sourceId',
                                     aggfunc='count',
                                     fill_value=0)

    # A value that never occurs in the data gets no pivot column; add it with a
    # zero count so the lookups below do not raise KeyError.
    for event_value in (0, 1):
        if event_value not in event_counts.columns:
            event_counts[event_value] = 0

    censored = event_counts[0]
    observed = event_counts[1]
    total = observed + censored

    event_counts['censored_percentage'] = censored / total * 100
    event_counts['total_events'] = total
    event_counts['censor_events'] = censored
    event_counts['progression_events'] = observed

    if exclude_small_sample_sizes:
        event_counts = event_counts[event_counts['total_events'] >= small_sample_size_threshold]

    # Nothing survives the filter: return the empty table without attempting to plot it
    if event_counts.empty:
        print("No treatment groups left to plot")
        return event_counts

    plt.figure(figsize=(10, 6))
    event_counts['censored_percentage'].plot(kind='bar')
    plt.title(f"Censoring Percentage by Treatment Group (Using {survival_event_column})")
    plt.ylabel("Censored Percentage")
    plt.xticks(rotation=45, ha="right")
    plt.tight_layout()

    plt.show()

    return event_counts

def _print_kruskal_for_treatments(label, treatment_rows):
    """Print the Kruskal-Wallis result over the per-treatment censoring percentages."""
    treatment_column = 'firstSystemicTreatmentAfterMetastaticDiagnosis'
    samples = [
        treatment_rows.loc[treatment_rows[treatment_column] == treatment, 'censored_percentage']
        for treatment in treatment_rows[treatment_column].unique()
    ]

    # kruskal raises ValueError with fewer than two groups; report instead
    if len(samples) < 2:
        print(f"Kruskal-Wallis Test for {label} treatments (filtered): skipped, fewer than two treatment groups")
        return

    print(f"Kruskal-Wallis Test for {label} treatments (filtered): {kruskal(*samples)}")


def kruskall_wallis_test_treatment_type(event_counts, exclude_small_sample_sizes=True, small_sample_size_threshold=50,
                                        immunotherapy_treatments=('PEMBROLIZUMAB', 'NIVOLUMAB')):
    """Print Kruskal-Wallis tests comparing censoring percentages between treatments.

    Treatments listed in ``immunotherapy_treatments`` are labelled immunotherapy and
    every other treatment chemotherapy. A test is printed for the chemotherapy
    treatments, for the immunotherapy treatments (only when small samples are not
    excluded) and for all treatments together. The input frame is not modified.

    NOTE(review): when ``event_counts`` comes from ``plot_censoring_per_treatment``
    it holds one row per treatment, so every group passed to the test contains a
    single percentage. A test on the underlying censored/event counts (for example
    a chi-square test of independence) may be more appropriate — confirm.

    Args:
        event_counts: DataFrame indexed by
            ``firstSystemicTreatmentAfterMetastaticDiagnosis`` with the columns
            ``censored_percentage`` and ``total_events``.
        exclude_small_sample_sizes: When true, treatments with fewer than
            ``small_sample_size_threshold`` records are dropped before testing.
        small_sample_size_threshold: Minimum number of records per treatment.
        immunotherapy_treatments: Treatment names counted as immunotherapy.
    """
    treatment_column = 'firstSystemicTreatmentAfterMetastaticDiagnosis'
    event_counts = event_counts.reset_index()

    if exclude_small_sample_sizes:
        event_counts = event_counts[event_counts['total_events'] >= small_sample_size_threshold]

    # Split with a boolean mask rather than writing a new column into a filtered slice
    is_immunotherapy = event_counts[treatment_column].isin(list(immunotherapy_treatments))

    _print_kruskal_for_treatments('Chemotherapy', event_counts[~is_immunotherapy])

    # NOTE(review): the immunotherapy test only runs when small samples are kept,
    # presumably because those groups fall below the threshold — confirm.
    if not exclude_small_sample_sizes:
        _print_kruskal_for_treatments('Immunotherapy', event_counts[is_immunotherapy])

    _print_kruskal_for_treatments('Combined', event_counts)

In [None]:
# Censoring per treatment based on the progression event indicator,
# followed by the Kruskal-Wallis comparison between treatments
pfs_event_counts = plot_censoring_per_treatment(
    km_df,
    survival_event_column='hadProgressionEvent',
)
kruskall_wallis_test_treatment_type(pfs_event_counts)

In [None]:
# Censoring per treatment based on the survival event indicator,
# followed by the Kruskal-Wallis comparison between treatments
os_event_counts = plot_censoring_per_treatment(
    km_df,
    survival_event_column='hadSurvivalEvent',
)
kruskall_wallis_test_treatment_type(os_event_counts)