In [None]:
from copy import deepcopy
from joblib import Parallel, delayed
from pathlib import Path
from pprint import pprint 

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pickle

from sklearn import set_config
from sklearn.base import clone
from sklearn.utils import resample

from sksurv.metrics import concordance_index_censored
from sksurv.util import Surv

# Pandas display settings: cap printed rows/columns and widen the console output.
_display_options = {
    "display.max_rows": 100,
    "display.max_columns": 10,
    "display.max_colwidth": 100,
    "display.width": 110,
}
for _option_name, _option_value in _display_options.items():
    pd.set_option(_option_name, _option_value)

# scikit-learn: show estimators with their plain-text repr rather than the HTML diagram.
set_config(display="text")  # displays text representation of estimators

In [None]:
import warnings

warnings.simplefilter("ignore", UserWarning)
warnings.simplefilter("ignore", RuntimeWarning)

In [3]:
# Local directory layout, relative to the notebook's working directory.
data_root = Path('data')
data_dir_proc = data_root / 'preprocess'   # survival outcomes (survival.parquet)
data_dir_eda = data_root / 'eda'           # clinical tables and expression matrix
data_dir_wgcna = data_root / 'wgcna'       # module eigengenes (eigengenes.parquet)
results_dir = Path('results')              # saved model searches (cls_*.sav)

In [4]:
# Model family to evaluate. It picks both the saved search (results/cls_<method>*.sav)
# and the feature set: 'coxnet_pca' uses the expression matrix,
# 'coxnet_wgcna' uses the WGCNA eigengenes.
method = 'coxnet_wgcna'

In [5]:
event_col = 'OS'
time_col = 'OS.time'

# Survival outcome: built as a structured (event, time) array with Surv, then
# wrapped back into a DataFrame so rows can be selected by sample label with .loc.
survival = pd.read_parquet(data_dir_proc / 'survival.parquet')
y = Surv.from_dataframe(event=event_col, time=time_col, data=survival)
y = pd.DataFrame(y, index=survival.index)

clinical_cat = pd.read_parquet(data_dir_eda / 'clinical_cat.parquet')
clinical_num = pd.read_parquet(data_dir_eda / 'clinical_num.parquet')
clinical = clinical_cat.merge(clinical_num, how='inner', left_index=True, right_index=True)

# Pick the most recent fitted search for this method. This assumes file names end
# in a sortable timestamp (e.g. cls_coxnet_wgcna_20251204_002712.sav), so the
# lexicographic last is the newest.
cls_paths = sorted(results_dir.glob(f'cls_{method}*.sav'))
if not cls_paths:
    raise FileNotFoundError(f'no cls_{method}*.sav model found in {results_dir}')
cls_path = cls_paths[-1]
print(cls_path)
# NOTE: pickle.load can execute arbitrary code -- only load files this project wrote.
with open(cls_path, 'rb') as fh:
    cls = pickle.load(fh)

# Feature matrix: clinical covariates plus the method-specific molecular features.
# if/elif/else so an unknown method fails here instead of leaving X undefined
# (or silently reusing X from an earlier run).
if method == 'coxnet_pca':
    expression = pd.read_parquet(data_dir_eda / 'expression.parquet')
    X = clinical.merge(expression, how='inner', left_index=True, right_index=True)
elif method == 'coxnet_wgcna':
    eigengenes = pd.read_parquet(data_dir_wgcna / 'eigengenes.parquet')
    X = clinical.merge(eigengenes, how='inner', left_index=True, right_index=True)
else:
    raise ValueError(f"unknown method {method!r}; expected 'coxnet_pca' or 'coxnet_wgcna'")

# The bootstrap looks up y by X's labels, so every sample in X needs an outcome.
_missing_outcomes = X.index.difference(y.index)
if len(_missing_outcomes) > 0:
    raise KeyError(f'{len(_missing_outcomes)} samples in X have no survival outcome in y')

results/cls_coxnet_wgcna_20251204_002712.sav


In [6]:
def bootstrap_iteration(X, y, cls, random_state=None):
    """Refit the tuned model on one bootstrap resample and score it out-of-bag.

    Parameters
    ----------
    X : pandas.DataFrame
        Feature matrix. Its index labels are what get resampled.
    y : pandas.DataFrame
        Survival outcomes with ``event_col`` / ``time_col`` columns (notebook-level
        names). Must contain every label of ``X.index`` so ``.loc`` lookups succeed.
    cls : fitted search object
        Must expose ``best_estimator_`` and ``best_params_``. The estimator is cloned,
        so ``cls`` itself is never refit.
    random_state : int, numpy.random.RandomState or None, default None
        Forwarded to ``sklearn.utils.resample``. ``None`` keeps the unseeded behaviour;
        pass a distinct int per replicate for a reproducible bootstrap.

    Returns
    -------
    float or None
        Harrell's C-index on the out-of-bag (OOB) samples. Returns ``None`` when the
        resample left no OOB samples, or when predicting or scoring the OOB set raised.
    """
    # Presumably re-applied because joblib worker processes do not share the
    # notebook's warning filters -- TODO confirm for the backend in use.
    warnings.filterwarnings("ignore", category=UserWarning)
    warnings.filterwarnings("ignore", category=RuntimeWarning)

    # Sample labels with replacement; labels never drawn form the OOB evaluation set.
    indices = resample(X.index, replace=True, random_state=random_state)
    oob = np.setdiff1d(X.index, indices)

    # Nothing to evaluate on: bail out before paying for a model fit.
    if len(oob) == 0:
        return None

    X_boot, y_boot = X.loc[indices, :], y.loc[indices, :]
    X_oob, y_oob = X.loc[oob, :], y.loc[oob, :]

    # Fresh, unfitted copy of the tuned estimator. For scikit-learn searches,
    # best_estimator_ already carries best_params_, so set_params is normally a no-op
    # kept as a safeguard.
    model = clone(cls.best_estimator_)
    model.set_params(**cls.best_params_)
    model.fit(X_boot, Surv.from_dataframe(event=event_col, time=time_col, data=y_boot))

    try:
        risk = model.predict(X_oob)
        cindex_score = concordance_index_censored(y_oob[event_col], y_oob[time_col], risk)[0]
        return cindex_score
    except Exception as e:
        # Best effort: one unscorable replicate (e.g. an OOB set without comparable
        # pairs) is reported and dropped rather than aborting the whole bootstrap.
        print(f"skipped due to {e}")
        return None


In [7]:
# Run one bootstrap replicate in-process to check that the pipeline fits and scores
# before launching the full parallel run below.
bootstrap_iteration(X, y, cls)

np.float64(0.7399310595065312)

In [8]:
# Number of bootstrap replicates to draw.
B = 1000

# One delayed call per replicate, fanned out over every available core (n_jobs=-1).
# Each result is a C-index or None, as returned by bootstrap_iteration.
# NOTE(review): no seed reaches the resampling, so the replicates (and the interval
# derived from them) differ from run to run.
replicate_tasks = (delayed(bootstrap_iteration)(X, y, cls) for _ in range(B))
cindex_scores = Parallel(n_jobs=-1)(replicate_tasks)

In [9]:
# Keep only replicates that produced a score (None -> NaN in the Series) and show
# how many usable replicates remain.
raw_scores = pd.Series(cindex_scores)
cindex_scores = raw_scores[raw_scores.notna()]
cindex_scores.size

1000

In [10]:

# 95% percentile interval (2.5th and 97.5th percentiles) and the mean out-of-bag
# C-index across replicates.
interval_bounds = np.percentile(cindex_scores, [2.5, 97.5])
lower, upper = interval_bounds[0], interval_bounds[1]
mean_score = cindex_scores.mean()

In [11]:
# Mean OOB C-index followed by its percentile interval (3 significant digits).
summary_line = f'{method} bootstrap estimate : {mean_score:.3} ({lower:.3}, {upper:.3})'
print(summary_line)

coxnet_wgcna bootstrap estimate : 0.751 (0.701, 0.801)
