In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

from src.features.cleaning import clean_data, split_X_and_y_data
from src.utils import get_project_root

Pyarrow will become a required dependency of pandas in the next major release of pandas (pandas 3.0),
(to allow more performant data types, such as the Arrow string type, and better interoperability with other libraries)
but was not found to be installed on your system.
If this would cause problems for you,
please provide us feedback at https://github.com/pandas-dev/pandas/issues/54466
        
  import pandas as pd


## Pre-processing (TODO: recode tumor size and check for other known descriptors like the MGMT gene and prescence of genetic mutation)

In [3]:
dir_root = get_project_root()
dir_data = 'data'
filename_data_brain = os.path.join(
    dir_root,
    dir_data,
    'survival_brain_2000_to_2020_seer_2022_db.csv'
)
data_raw = pd.read_csv(filename_data_brain)

  data_raw = pd.read_csv(filename_data_brain)


In [4]:
data = clean_data(data_raw)

In [5]:
# Age groupings according to age standard for survival type II. One typically stratifies on these
# groups or uses the corresponding population weights to perform an age-adjusted analysis.
#
# The groupings were designed so as to minimize the difference between raw and age-standardized
# 5-year survival ratios in the EUROCARE-2 dataset.
#
# Hence, it may be best to use these groupings for our age categorical.
data['Age standard for survival'].value_counts()

Age standard for survival
15    18546
55    17306
65    15350
75    12612
45    12513
Name: count, dtype: int64

In [6]:
# Site recoding for tumors of adolescents and young adults (AYA), ages 15-39
data['AYA site recode 2020 Revision'].value_counts()

AYA site recode 2020 Revision
3.1.2.2 Glioblastoma - invasive                                45457
3.1.4.3 Other astrocytoma/astroglial - invasive                16247
3.1.1.2 Oligodendroglioma - invasive                            7509
3.10.2 Other and unspecified CNS - invasive                     2213
3.1.4.1 Pilocytic astrocytoma                                   1780
3.1.3.2 Ependymoma - invasive                                   1206
3.2 Medulloblastoma and other invasive embryonal CNS tumors     1108
7.3 Germ cell and trophoblastic - CNS                            242
5.2.2 Other                                                      189
3.4.2 Neuronal and mixed neuronal-glial - invasive                86
4.13 Chordoma                                                     73
3.6.2 Choroid plexus - invasive                                   37
3.3.2 Neuroblastoma/ganglioneuroblastoma - invasive               31
4.15 Other soft tissue sarcomas                                   24
4.2 

In [7]:
# Site recoding grouped by major histological categories for brain-specific tumors
data['SEER Brain and CNS Recode'].value_counts()

SEER Brain and CNS Recode
1.1.2 Glioblastoma                                      45457
1.1.1 Diffuse astrocytoma and anaplastic astrocytoma    11688
1.1.4 Oligodendroglioma                                  5442
1.1.9 Glioma, unspecified                                4040
1.6 Other Malignant Brain/ONS                            2831
1.1.6 Other astrocytic tumors                            2128
1.1.5 Oligoastrocytoma                                   1950
1.1.8 Ependymal tumors                                   1203
1.2 Embryonal tumors                                     1134
1.1.10 Other                                              210
1.1.3 Diffuse midline glioma, H3 K27M-mutant               85
1.5 Neuronal and mixed neuronal-glial tumors               81
1.1.7 Astroblastoma                                        41
1.4 Choroid plexus tumors                                  37
Name: count, dtype: int64

In [8]:
# Almost identical conceptually to grade recode
data['CS site-specific factor 1 (2004-2017 varying by schema)'].value_counts()

CS site-specific factor 1 (2004-2017 varying by schema)
040    24422
999    12264
030     6169
020     5754
998     2436
010     1258
Name: count, dtype: int64

In [9]:
data['Grade Recode (thru 2017)'].value_counts()

Grade Recode (thru 2017)
Unknown                                   50140
Undifferentiated; anaplastic; Grade IV    19676
Moderately differentiated; Grade II        3388
Poorly differentiated; Grade III           1999
Well differentiated; Grade I               1124
Name: count, dtype: int64

In [10]:
data['Brain Molecular Markers (2018+)'].value_counts()

Brain Molecular Markers (2018+)
Glioblastoma, IDH-wildtype (9440/3)                                                 6042
Not documented; No microscopic confirmation; Not assessed or unknown if assessed    2587
NA: Histology not 9400/3, 9401/3, 9440/3, 9450/3, 9451/3, 9471/3, 9478/3            1294
Oligodendroglioma, IDH-mutant and 1 p/19q co-deleted (9450/3)                        437
Diffuse astrocytoma, IDH-mutant (9400/3)                                             387
Anaplastic astrocytoma, IDH-mutant (9401/3)                                          376
Anaplastic astrocytoma, IDH-wildtype (9401/3)                                        362
Diffuse astrocytoma, IDH-wildtype (9400/3)                                           242
Anaplastic oligodendroglioma, IDH-mutant and 1 p/19q co-deleted (9451/3)             214
Medulloblastoma, SHH-activated and TP53-wildtype (9471/3)                             42
Test ordered, results not in chart                                            

In [11]:
data['Chromosome 19q: Loss of Heterozygosity (LOH) Recode (2010+)'].value_counts()

Chromosome 19q: Loss of Heterozygosity (LOH) Recode (2010+)
Not documented; Cannot be determined; Not assessed or unknown if assessed    34613
Chromosome 19q deletion/LOH not identified/not present                        4517
Chromosome 19q deletion/LOH present                                           2154
Not applicable: Information not collected for this case                       1874
Test ordered, results not in chart                                             126
Name: count, dtype: int64

In [12]:
data['Chromosome 1p: Loss of Heterozygosity (LOH) Recode (2010+)'].value_counts()

Chromosome 1p: Loss of Heterozygosity (LOH) Recode (2010+)
Not documented; Cannot be determined; Not assessed or unknown if assessed    34656
Chromosome 1p deletion/LOH not identified/not present                         4458
Chromosome 1p deletion/LOH identified/present                                 2169
Not applicable: Information not collected for this case                       1875
Test ordered, results not in chart                                             126
Name: count, dtype: int64

## Train-Test Split

In [13]:
from sklearn.model_selection import train_test_split
from sksurv.preprocessing import OneHotEncoder
from sklearn.compose import ColumnTransformer, make_column_transformer, make_column_selector
from sklearn.impute import SimpleImputer
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import RandomizedSearchCV, GridSearchCV, KFold
from sklearn.feature_selection import SelectKBest
from scipy import stats

In [14]:
data_train, data_test = train_test_split(
    data,
    test_size=0.2,
    random_state=2984,
    shuffle=True,
    stratify=data['Vital status recode (study cutoff used)']
)

X_train, y_train = split_X_and_y_data(data_train)
X_test, y_test = split_X_and_y_data(data_test)

In [15]:
X_train.info(verbose=True, show_counts=True)

<class 'pandas.core.frame.DataFrame'>
Index: 61061 entries, 36640 to 46191
Data columns (total 128 columns):
 #    Column                                                                 Non-Null Count  Dtype   
---   ------                                                                 --------------  -----   
 0    Sex                                                                    61061 non-null  category
 1    Year of diagnosis                                                      61061 non-null  int64   
 2    PRCDA 2020                                                             61061 non-null  object  
 3    Race recode (W, B, AI, API)                                            61061 non-null  category
 4    Origin recode NHIA (Hispanic, Non-Hisp)                                61061 non-null  category
 5    Race and origin recode (NHW, NHB, NHAIAN, NHAPI, Hispanic)             61061 non-null  category
 6    TNM 7/CS v0204+ Schema (thru 2017)                                   

In [16]:
X_train['Sex'].value_counts()

Sex
Male      34659
Female    26402
Name: count, dtype: int64

## Model Training

In [17]:
from sksurv.nonparametric import kaplan_meier_estimator
from sksurv.linear_model import CoxPHSurvivalAnalysis, CoxnetSurvivalAnalysis
from sksurv.metrics import concordance_index_censored

from src.transformers.transformers import DataFrameTransformer

In [18]:
impute_ohe_scale = make_pipeline(
    SimpleImputer(strategy='constant', add_indicator=True).set_output(
        transform="pandas"),
    DataFrameTransformer(),
    OneHotEncoder(),
    StandardScaler()
)
impute_mean_scale = make_pipeline(
    SimpleImputer(strategy='mean', add_indicator=True),
    StandardScaler()
)

column_transformer = make_column_transformer(
    (StandardScaler(), [
        # 'Age recode with <1 year olds',
        'Year of diagnosis',
        'Median household income inflation adj to 2021 (thousands USD)',
        'Total number of in situ/malignant tumors for patient',  # should we restrict the cohort to N=1?
        'Combined Tumor Size',
        'No tumor found',
        'Unknown tumor size',
    ]),
    (impute_ohe_scale, [
        'Age standard for survival',  # handle standardized age groupings categorically
        'Sex',
        'Race recode (W, B, AI, API)',
        'Race and origin recode (NHW, NHB, NHAIAN, NHAPI, Hispanic)',
        'Marital status at diagnosis',
        # 'AYA site recode 2020 Revision',
        'SEER Brain and CNS Recode',
        'Primary Site - labeled',
        'Histologic Type ICD-O-3',
        'Grade Recode (thru 2017)',
        'Diagnostic Confirmation',
        'Histology recode - broad groupings',
        # 'SEER Combined Summary Stage 2000 (2004-2017)',
        'Chromosome 19q: Loss of Heterozygosity (LOH) Recode (2010+)',
        'Chromosome 1p: Loss of Heterozygosity (LOH) Recode (2010+)',
        'Brain Molecular Markers (2018+)',
    ]),
    (impute_mean_scale, [
        'Months from diagnosis to treatment',
    ]),
    # (, ['']),
    remainder='drop',
    n_jobs=-1,
)

In [19]:
ct = column_transformer.fit(X_train)
feature_names = ct.get_feature_names_out()
feature_names

array(['standardscaler__Year of diagnosis',
       'standardscaler__Median household income inflation adj to 2021 (thousands USD)',
       'standardscaler__Total number of in situ/malignant tumors for patient',
       'standardscaler__Combined Tumor Size',
       'standardscaler__No tumor found',
       'standardscaler__Unknown tumor size',
       'pipeline-1__Age standard for survival=45',
       'pipeline-1__Age standard for survival=55',
       'pipeline-1__Age standard for survival=65',
       'pipeline-1__Age standard for survival=75', 'pipeline-1__Sex=Male',
       'pipeline-1__Race recode (W, B, AI, API)=Asian or Pacific Islander',
       'pipeline-1__Race recode (W, B, AI, API)=Black',
       'pipeline-1__Race recode (W, B, AI, API)=Unknown',
       'pipeline-1__Race recode (W, B, AI, API)=White',
       'pipeline-1__Race and origin recode (NHW, NHB, NHAIAN, NHAPI, Hispanic)=Non-Hispanic American Indian/Alaska Native',
       'pipeline-1__Race and origin recode (NHW, NHB, NHAIA

In [20]:
pipeline = make_pipeline(
    column_transformer,
    CoxnetSurvivalAnalysis(verbose=True, fit_baseline_model=False)
)

In [21]:
pipeline.fit(X_train, y_train.to_records(index=False))

In [22]:
X_transformed = column_transformer.fit_transform(X_train)
X_transformed

array([[ 0.40049512, -1.36485533, -0.15788384, ...,  0.43276706,
        -0.42146307, -0.42691093],
       [ 1.07362015, -1.36485533, -0.15788384, ...,  0.43276706,
        -0.42146307, -0.42691093],
       [-0.27262991, -1.36485533, -0.15788384, ...,  0.43276706,
         1.43958292, -0.42691093],
       ...,
       [ 0.90533889, -0.92908403, -0.15788384, ...,  0.43276706,
         2.37010592, -0.42691093],
       [ 0.23221386,  0.81400118, -0.15788384, ...,  0.43276706,
        -0.42146307, -0.42691093],
       [-0.27262991, -1.36485533, -0.15788384, ...,  0.43276706,
        -0.42146307, -0.42691093]])

In [23]:
pipeline.score(X_train, y_train.to_records(index=False))

0.7760636962367231

In [24]:
coxnet = CoxnetSurvivalAnalysis(l1_ratio=0.8, verbose=True, fit_baseline_model=False)
coxnet.fit(X_transformed, y_train.to_records(index=False))
coxnet.score(X_transformed, y_train.to_records(index=False))

0.7760598085921208

In [25]:
def build_2D_gridsearch(pipeline, cv, l1_ratios, alphas):
    # argument alphas of CoxnetSurvivalAnalysis must be an array at each grid point
    wrapped_alphas = [[alpha] for alpha in alphas]
    cv_search_dist = {
        'l1_ratio': l1_ratios,
        'alphas': wrapped_alphas,
    }
    cv_search = GridSearchCV(
        pipeline,
        cv_search_dist,
        cv=cv,
        error_score=0.5,  # replaces fit errors with random models
        n_jobs=-1,
        verbose=4,
    )
    return cv_search

In [26]:
# Grid search hyperparameters
n_splits = 3
l1_ratios = np.linspace(0.8, 1.0, 6)
alphas = 10.0 ** np.linspace(-6, -1, 6)

In [27]:
# Make the K-fold split deterministic for now
cv = KFold(n_splits=n_splits, shuffle=True, random_state=2984)

# Build the grid search metaestimator
coxnet = CoxnetSurvivalAnalysis(verbose=True, fit_baseline_model=False)
cv_search = build_2D_gridsearch(coxnet, cv, l1_ratios, alphas)

In [None]:
cv_search.fit(X_transformed, y_train.to_records(index=False))

In [None]:
cv_results = pd.DataFrame(cv_search.cv_results_)
cv_results

In [None]:
cv_search.best_params_

In [None]:
cv_results.param_alphas

In [None]:
cv_results.mean_test_score

In [None]:
l1s = []
mean = []
std = []
for i, alphas in enumerate(cv_results.param_alphas):
    if alphas != cv_search.best_params_["alphas"]:
        continue
    l1s.append(cv_results.param_l1_ratio[i])
    mean.append(cv_results.mean_test_score[i])
    std.append(cv_results.std_test_score[i])

fig, ax = plt.subplots(figsize=(9, 6))
ax.plot(l1s, mean, 'o-')
ax.fill_between(l1s, np.asarray(mean) - np.asarray(std), np.asarray(mean) + np.asarray(std), alpha=0.15)
ax.set_xscale("log")
ax.set_ylabel("concordance index")
ax.set_xlabel("l1")
ax.axvline(cv_search.best_params_["l1_ratio"], c="C1")
ax.axhline(0.5, color="grey", linestyle="--")
ax.grid(True)

In [None]:
alphas = []
mean = []
std = []
for i, l1_ratio in enumerate(cv_results.param_l1_ratio):
    if l1_ratio != cv_search.best_params_["l1_ratio"]:
        continue
    alphas.append(cv_results.param_alphas[i][0])
    mean.append(cv_results.mean_test_score[i])
    std.append(cv_results.std_test_score[i])

fig, ax = plt.subplots(figsize=(9, 6))
ax.plot(alphas, mean)
ax.fill_between(alphas, np.asarray(mean) - np.asarray(std),
                np.asarray(mean) + np.asarray(std), alpha=0.15)
ax.set_xscale("log")
ax.set_ylabel("concordance index")
ax.set_xlabel("alpha")
ax.axvline(cv_search.best_params_["alphas"][0], c="C1")
ax.axhline(0.5, color="grey", linestyle="--")
ax.grid(True)

In [None]:
def build_1D_gridsearch(pipeline, cv, alphas):
    # argument alphas of CoxnetSurvivalAnalysis must be an array at each grid point
    wrapped_alphas = [[alpha] for alpha in alphas]
    cv_search_dist = {'alphas': wrapped_alphas}
    cv_search_1D = GridSearchCV(
        pipeline,
        cv_search_dist,
        cv=cv,
        error_score=0.5,  # replaces fit errors with random models
        n_jobs=-1,
        verbose=4,
    )
    return cv_search_1D

In [None]:
# Grid search hyperparameters
n_splits = 3
log10_best_alpha = np.log10(cv_search.best_params_["alphas"])
# log10_best_alpha = np.log10(0.00035938136638046257)
alphas = 10 ** np.linspace(log10_best_alpha - 3, log10_best_alpha + 3, 50)
alphas

In [None]:
# Make the K-fold split deterministic for now
cv = KFold(n_splits=n_splits, shuffle=True, random_state=2984)

# Build the grid search metaestimator
best_l1_ratio = cv_search.best_params_["l1_ratio"]
# best_l1_ratio = 0.78
coxnet = CoxnetSurvivalAnalysis(
    l1_ratio=best_l1_ratio, verbose=True, fit_baseline_model=False)
cv_search = build_1D_gridsearch(coxnet, cv, alphas)

In [None]:
cv_search.fit(X_transformed, y_train.to_records(index=False))

In [None]:
best_model = cv_search.best_estimator_
best_coeffs = pd.DataFrame(best_model.coef_, index=feature_names, columns=['coefficient'])

non_zero = np.sum(best_coeffs.iloc[:, 0] != 0)
print(f"Number of non-zero coefficients: {non_zero}")

non_zero_coefs = best_coeffs.query("coefficient != 0")
coef_order = non_zero_coefs.abs().sort_values("coefficient").index

_, ax = plt.subplots(figsize=(6, 20))
non_zero_coefs.loc[coef_order].plot.barh(ax=ax, legend=False)
ax.set_xlabel("coefficient")
ax.grid(True)

In [None]:
coxnet_pred = CoxnetSurvivalAnalysis(verbose=True, fit_baseline_model=True)
coxnet_pred.set_params(**cv_search.best_params_)
coxnet_pred.fit(X_transformed, y_train.to_records(index=False))

In [None]:
X_train.loc[:, 'Sex']

In [None]:
impute_ohe = make_pipeline(
    SimpleImputer(strategy='constant', add_indicator=True).set_output(
        transform="pandas"),
    DataFrameTransformer(),
    OneHotEncoder(),
)

In [None]:
# Why is the scaled and one-hot version so inhomogeneous compared to the original categorical?
idx_male = np.where(feature_names == "pipeline__Sex=Male")[0].item()
for Xt, count in np.unique(X_transformed[:, idx_male], return_counts=True):
    print(f'Counts pipeline__Sex=Male = {Xt}: {count}')
print()
print(X_train['Sex'].value_counts())

In [None]:
surv_fns = coxnet_pred.predict_survival_function(X_transformed)
idx_male = np.where(feature_names == "pipeline__Sex=Male")[0].item()

coeffs_to_categories = {
    0: "Female",
    15: "Male",
}

time_points = np.quantile(y_train["Survival months"], np.linspace(0, 0.6, 100))
legend_handles = []
legend_labels = []
_, ax = plt.subplots(figsize=(9, 6))
for i, (fn, male_val) in enumerate(zip(surv_fns, X_transformed[:, idx_male])):
    if i > 50:
        break
    category = coeffs_to_categories[round(male_val)]
    if category == "Male":
        label = 0
    else:
        label = 1
    (line,) = ax.step(time_points, fn(time_points),
                      where="post", color=f"C{label}", alpha=0.5)
# TODO: add legend
ax.set_xlabel("Time")
ax.set_ylabel("Survival probability")
ax.grid(True)