In [1]:
# /*==========================================================================================*\
# **                        _           _ _   _     _  _         _                            **
# **                       | |__  _   _/ | |_| |__ | || |  _ __ | |__                         **
# **                       | '_ \| | | | | __| '_ \| || |_| '_ \| '_ \                        **
# **                       | |_) | |_| | | |_| | | |__   _| | | | | | |                       **
# **                       |_.__/ \__,_|_|\__|_| |_|  |_| |_| |_|_| |_|                       **
# \*==========================================================================================*/


# -----------------------------------------------------------------------------------------------
# Author: Bùi Tiến Thành - Tien-Thanh Bui (@bu1th4nh)
# Title: playground_survival.ipynb
# Date: 2025/02/05 17:30:55
# Description: 
# 
# (c) 2025 bu1th4nh. All rights reserved. 
# Written with dedication in the University of Central Florida, EPCOT and the Magic Kingdom.
# -----------------------------------------------------------------------------------------------

import os
import warnings 
warnings.filterwarnings("ignore") 
os.environ["PYTHONWARNINGS"] = "ignore::UserWarning"



import logging
import numpy as np
import pandas as pd
from tqdm import tqdm
from typing import List, Dict, Any, Tuple, Union, Literal

import pymongo
from s3fs import S3FileSystem

from sklearn import set_config
from sklearn.model_selection import GridSearchCV, KFold, train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.exceptions import FitFailedWarning

from sksurv.linear_model import CoxnetSurvivalAnalysis
import lifelines

import matplotlib.pyplot as plt




# --- Connection settings -----------------------------------------------------------------------
# Credentials and endpoints are read from environment variables so they can be supplied without
# editing the notebook. The fallbacks reproduce the previous hard-coded local-development values.
# FIXME(security): remove the literal fallbacks (and rotate the secrets) before sharing or
# committing this notebook; credentials should not live in source.
key = os.environ.get('S3_ACCESS_KEY', 'bu1th4nh')
secret = os.environ.get('S3_SECRET_KEY', 'ariel.anna.elsa')
endpoint_url = os.environ.get('S3_ENDPOINT_URL', 'http://localhost:9000')

# Filesystem handle used for listing objects (s3.ls) in later cells.
s3 = S3FileSystem(
    anon=False, 
    endpoint_url=endpoint_url,
    key=key,
    secret=secret,
    use_ssl=False
)

# Same credentials in the form expected by pandas' `storage_options=` argument.
storage_options = {
    'key': key,
    'secret': secret,
    'endpoint_url': endpoint_url,
}


# MongoDB client used in later cells to read hyper-parameter runs.
mongo = pymongo.MongoClient(
    host=os.environ.get('MONGO_HOST', 'mongodb://localhost'),
    port=int(os.environ.get('MONGO_PORT', '27017')),
    username=os.environ.get('MONGO_USERNAME', 'bu1th4nh'),
    password=os.environ.get('MONGO_PASSWORD', 'ariel.anna.elsa'),
)




In [3]:
def data_selection(omics_mode, disease):
    """Resolve the MongoDB database name, dataset id and S3 paths for one run.

    Parameters
    ----------
    omics_mode : str
        Either ``"2omics"`` or ``"3omics"``.
    disease : str
        One of ``"brca"``, ``"luad"``, ``"ov"`` or ``"test"``.

    Returns
    -------
    tuple of str
        ``(mongo_db_name, dataset_id, DATA_PATH, RESULT_PRE_PATH, SA_TARG_PATH)`` where
        ``DATA_PATH`` is the omics data folder, ``RESULT_PRE_PATH`` the result folder for the
        disease, and ``SA_TARG_PATH`` the survival-analysis target folder.

    Raises
    ------
    ValueError
        If ``omics_mode`` or ``disease`` is not one of the supported values (previously this
        surfaced as an ``UnboundLocalError`` further down the function).
    """
    base_data_path = 's3://datasets'
    base_result_path = 's3://results'

    # omics_mode -> (mongo db name / result sub-folder, omics data folder, survival target folder)
    omics_settings = {
        "3omics": (
            'SimilarSampleCrossOmicNMF_3Omics',
            'processed_3_omics_mRNA_miRNA_methDNA',
            'survivalanalysis_testdata_3_omics_mRNA_miRNA_methDNA',
        ),
        "2omics": (
            'SimilarSampleCrossOmicNMF',
            'processed_2_omics_mRNA_miRNA',
            'survivalanalysis_testdata_2_omics_mRNA_miRNA',
        ),
    }

    # disease -> (dataset id, data folder, result folder)
    disease_settings = {
        "brca": ('BRCA', 'BreastCancer', 'brca'),
        "luad": ('LUAD', 'LungCancer', 'luad'),
        "ov": ('OV', 'OvarianCancer', 'ov'),
        # The "test" dataset reuses the breast-cancer data folder.
        "test": ('test', 'BreastCancer', 'test'),
    }

    if omics_mode not in omics_settings:
        raise ValueError(
            f"Unknown omics_mode {omics_mode!r}; expected one of {sorted(omics_settings)}"
        )
    if disease not in disease_settings:
        raise ValueError(
            f"Unknown disease {disease!r}; expected one of {sorted(disease_settings)}"
        )

    mongo_db_name, omic_folder, surv_target_folder = omics_settings[omics_mode]
    dataset_id, disease_data_folder, disease_result_folder = disease_settings[disease]

    # Aggregate
    SA_TARG_PATH = f'{base_data_path}/{disease_data_folder}/{surv_target_folder}'
    DATA_PATH = f'{base_data_path}/{disease_data_folder}/{omic_folder}'
    # The result sub-folder is named after the MongoDB database.
    RESULT_PRE_PATH = f'{base_result_path}/{mongo_db_name}/{disease_result_folder}'

    return mongo_db_name, dataset_id, DATA_PATH, RESULT_PRE_PATH, SA_TARG_PATH

In [None]:
from downstream.survival import surv_analysis

# Upper bound on retries per (classification target, survival target) pair. The previous
# `while True` loop never terminated when `surv_analysis` kept failing or never reached
# significance.
MAX_SA_ATTEMPTS = 10
# Threshold below which a survival-analysis result is accepted.
SIGNIFICANCE_LEVEL = 0.05

for omics_mode in ['2omics', '3omics']:
    for disease in ['brca', 'luad', 'ov']:
        mongo_db_name, dataset_id, DATA_PATH, RESULT_PRE_PATH, SA_TARG_PATH = data_selection(omics_mode, disease)

        # Obtain db and collections. The survival collection gets its own name so that it is
        # not overwritten by the analysis result `surv_result` below.
        mongo_db = mongo[mongo_db_name]
        hparams_runs = mongo_db['HPARAMS_OPTS']
        surv_result_collection = mongo_db['SURVIVAL_ANALYSIS']


        # Obtain survival analysis targets: one parquet file per target, keyed by file stem
        surv_targets_data = {}
        surv_target_folder = [f's3://{a}' for a in s3.ls(SA_TARG_PATH)]
        for tar in tqdm(surv_target_folder, desc='Preloading target data'):
            target_id = str(tar.split('/')[-1]).split('.')[0]
            try:
                surv_targets_data[target_id] = pd.read_parquet(tar, storage_options=storage_options)
                print(surv_targets_data[target_id].columns)

            except FileNotFoundError:
                logging.error(f"Target {tar} not found. Skipping...")   



        # Get target id for each disease => find the best hparams for each target
        classification_target_ids_for_disease = hparams_runs.find(
            {'dataset': dataset_id},
        ).distinct('target_id')



        # Get the best hparams for each target and run SA
        for classification_target_id in classification_target_ids_for_disease:
            # Mean AUROC per config over all recorded runs for this dataset/target.
            Ariel = (
                pd.DataFrame.from_records(
                    hparams_runs
                    .find(
                        {
                            "dataset": dataset_id,
                            "target_id": classification_target_id,
                        },
                        {
                            "_id": 0,
                            "test_id": 1,   
                            "config": 1,
                            "AUROC": 1,
                        }
                    ).to_list()
                )[['config', 'AUROC']]
                .groupby('config')
                .mean()
            )
            # idxmax skips NaN means, whereas np.argmax would select a NaN entry if one exists.
            best_cfg = Ariel['AUROC'].idxmax()
            H = pd.read_parquet(f'{RESULT_PRE_PATH}/{best_cfg}/H.parquet', storage_options=storage_options)


            # Get all survival analysis targets
            for surv_target_id, survival in surv_targets_data.items():
                # NOTE(review): the split is unseeded and drawn once per target, so every retry
                # below reuses the same split -- confirm whether resampling per attempt and/or
                # a fixed random_state is intended.
                train_sample_ids, test_sample_ids = train_test_split(survival.index, test_size=0.2)

                if surv_target_id == 'survival': 
                    event_label = 'Overall Survival Status'
                    time_label = 'Overall Survival (Months)'
                else:
                    event_label = 'Disease Free Status'
                    time_label = 'Disease Free (Months)'

                for attempt in range(1, MAX_SA_ATTEMPTS + 1):
                    print(f'SA for {dataset_id} with {classification_target_id}, config {best_cfg} and {surv_target_id}, attempt {attempt}')

                    try:
                        surv_result = surv_analysis(
                            H,
                            train_sample_ids,
                            survival.loc[train_sample_ids],
                            test_sample_ids,
                            survival.loc[test_sample_ids],
                            event_label,
                            time_label,
                        )
                    except Exception:
                        # Catch Exception (not a bare except) so KeyboardInterrupt still stops
                        # the cell, and log the traceback instead of hiding the cause.
                        logging.exception(
                            f'SA failed for {dataset_id} / {classification_target_id} / {surv_target_id} (attempt {attempt})'
                        )
                        continue

                    if surv_result['p_value'] < SIGNIFICANCE_LEVEL:
                        print(surv_result)
                        break
                    print(f'p-value > {SIGNIFICANCE_LEVEL}:', surv_result['p_value'])
                else:
                    # Loop finished without `break`: no significant result within the budget.
                    logging.warning(
                        f'No result with p < {SIGNIFICANCE_LEVEL} for {dataset_id} / {classification_target_id} / {surv_target_id} after {MAX_SA_ATTEMPTS} attempts'
                    )

            




Preloading target data: 100%|██████████| 2/2 [00:00<00:00, 14.29it/s]


Index(['Disease Free Status', 'Disease Free (Months)'], dtype='object')
Index(['Overall Survival Status', 'Overall Survival (Months)'], dtype='object')
SA for BRCA with ER, config k-100-alpha-0.01-beta-0.01-gamma-overridden and diseasefree, attempt 1
Error
SA for BRCA with ER, config k-100-alpha-0.01-beta-0.01-gamma-overridden and diseasefree, attempt 2


In [None]:

# Plot the Kaplan-Meier curves for the most recent survival-analysis result.
# `surv_result` is the dict returned by the last `surv_analysis(...)` call in the loop cell;
# `Ariel` (used here before) is the AUROC-by-config DataFrame and has none of these keys.
# NOTE(review): only the 'p_value' key is confirmed by the loop cell; the remaining keys are
# assumed to be part of the `surv_analysis` return value -- verify against downstream.survival.
print(surv_result['results'])

km_curve = surv_result['kaplan_meier_curve']
X_low = km_curve['X_low']
Y_low = km_curve['Y_low']
X_high = km_curve['X_high']
Y_high = km_curve['Y_high']
censor_low = km_curve['censor_low']
censor_high = km_curve['censor_high']
censor_low_pred = km_curve['censor_low_pred']
censor_high_pred = km_curve['censor_high_pred']
low_risk_ids = surv_result['test_low_risk_ids']
high_risk_ids = surv_result['test_high_risk_ids']
p_value = surv_result['p_value']


# Plot survival functions
fig, ax = plt.subplots(figsize=(10, 6))
ax.step(X_low, Y_low, where="post", label=f"low-risk ({len(low_risk_ids)})", color="blue", linestyle="--")
ax.step(X_high, Y_high, where="post", label=f"high-risk ({len(high_risk_ids)})", color="red", linestyle="-")

# Plot censor points
ax.scatter(censor_high, censor_high_pred, marker='+', color='black')
ax.scatter(censor_low, censor_low_pred, marker='+', color='black')

# Add labels and legend
ax.set_title("Kaplan-Meier Curves")
ax.set_xlabel("Time (Months)")
ax.set_ylabel("Survival Probability")
ax.legend()

# Add p-value in a box on the bottom left of the plot (axes-relative coordinates)
ax.text(0.03, 0.05, f'p-value: {p_value:.4f}', transform=ax.transAxes, bbox=dict(facecolor='white', alpha=0.5, edgecolor='black'))


# Show plot
plt.show()