In [1]:
import pandas as pd
import numpy as np
import os
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.model_selection import cross_val_score, KFold
from sklearn.metrics import f1_score
from skopt.space import Real, Integer, Categorical
from skopt import gp_minimize

import utils.dev_config as dev_conf
import utils.preprocessing as prep
import utils.optimization as opt

In [2]:
# Directory settings read from the local paths file; `dirs.data_dir` is used below
# when building the input file paths.
dirs = dev_conf.get_dev_directories("../dev_paths.txt")

# Names of the unified datasets; `dset_idx` picks which one this run processes.
unified_dsets = [
    "unified_cervical_data",
    "unified_uterine_data",
    "unified_uterine_endometrial_data",
]

In [3]:
# Position in `unified_dsets` of the dataset analysed by this run
# (0 selects "unified_cervical_data").
dset_idx = 0

In [4]:
# Single seed shared by every stochastic step in this notebook.
seed = 123
# Seed the generator at construction so it is deterministic from the moment it exists;
# later cells may still call `rand.seed(seed)` to reset the stream before each step.
rand = np.random.RandomState(seed)

# Load and filter survival data

In [5]:
# Status-to-integer mapping handed to the survival loader.
event_code = {"Alive": 0, "Dead": 1}

# Column groups used when subsetting the survival table in the next step.
dep_cols = ["figo_stage"]
covariate_cols = ["age_at_diagnosis", "race", "ethnicity"]
cat_cols = ["race", "ethnicity"]  # these are one-hot encoded with pd.get_dummies

survival_df = prep.load_survival_df(
    f"{dirs.data_dir}/{unified_dsets[dset_idx]}/survival_data.tsv",
    event_code,
)

In [6]:
# Build the modelling table from the survival data in one chain:
#   1. keep the sample id, the dependent column(s) and the covariates, dropping incomplete rows
#   2. decode the FIGO stage (presumably producing the `figo_num` column used below -- see
#      `prep.decode_figo_stage`)
#   3. one-hot encode the categorical covariates
#   4. move the id and label columns to the front
#   5. replace spaces in column names with underscores
filtered_survival_df = (
    survival_df[["sample_name", *dep_cols, *covariate_cols]]
        .dropna()
        .pipe(prep.decode_figo_stage, to="n")
        .pipe(pd.get_dummies, columns=cat_cols)
        .reset_index(drop=True)
        .pipe(prep.cols_to_front, ["sample_name", "figo_num"])
        .rename(columns=lambda col: col.replace(" ", "_"))
)
print(filtered_survival_df.shape)
# filtered_survival_df.head()

(255, 12)


# Load normalized matrisome count data

In [7]:
# Read the normalized matrisome counts (tab-separated; has a `geneID` column plus one
# column per sample, as implied by the selection below).
norm_matrisome_counts_df = pd.read_csv(
    f"{dirs.data_dir}/{unified_dsets[dset_idx]}/norm_matrisome_counts.tsv",
    sep="\t",
)

# Keep only the samples that survived filtering, then transpose via the project helper
# so that samples become rows keyed by `sample_name`.
norm_filtered_matrisome_counts_t_df = prep.transpose_df(
    norm_matrisome_counts_df[["geneID", *filtered_survival_df["sample_name"]]],
    "geneID",
    "sample_name",
)
print(norm_filtered_matrisome_counts_t_df.shape)
# norm_filtered_matrisome_counts_t_df.head()

(255, 1009)


# Join survival and count data

In [8]:
# Combine the covariates/labels with the per-sample counts on the shared sample id,
# then use that id as the row index.
joined_df = (
    filtered_survival_df
        .merge(norm_filtered_matrisome_counts_t_df, on="sample_name")
        .set_index("sample_name")
)
print(joined_df.shape)
# joined_df.head()

(255, 1019)


In [11]:
# Every column name of the transposed counts frame (`sample_name` followed by the gene symbols).
norm_filtered_matrisome_counts_t_df.columns.tolist()

['sample_name',
 'PGF',
 'TIMP4',
 'C1QTNF6',
 'TNC',
 'PRL',
 'OGN',
 'C1QL3',
 'FGB',
 'NDNF',
 'CCL22',
 'ELSPBP1',
 'CYR61',
 'ECM1',
 'ANGPT2',
 'SERPINF2',
 'SCUBE3',
 'CRELD2',
 'KITLG',
 'THSD4',
 'MEPE',
 'CELA2B',
 'CLEC4G',
 'ANGPTL7',
 'CSF3',
 'LOXL1',
 'CLEC18A',
 'MUC3A',
 'PXDNL',
 'FGF9',
 'SFTPD',
 'S100P',
 'SFRP1',
 'LGALS2',
 'SERPINA4',
 'VWCE',
 'DSPP',
 'COL10A1',
 'HMSD',
 'FGG',
 'FGF7',
 'LOX',
 'FGF18',
 'MUC8',
 'PDGFD',
 'IFNB1',
 'CSTA',
 'CXCL13',
 'REG1A',
 'ANGPTL2',
 'SERPINA5',
 'IGFBP1',
 'INHBE',
 'CSF2',
 'TGM6',
 'CILP',
 'EGFL6',
 'COMP',
 'ANXA7',
 'TLL2',
 'LMAN1L',
 'IL10',
 'CBLN4',
 'C1QTNF9',
 'MUC13',
 'IFNW1',
 'GDF15',
 'P4HA3',
 'SERPINB10',
 'SEMA5A',
 'CRHBP',
 'LTBP2',
 'CCL14',
 'IL25',
 'CXCL2',
 'MUC15',
 'FCN1',
 'LEFTY2',
 'ADAM30',
 'TGFB1',
 'TNFSF14',
 'CXCL14',
 'TNFSF18',
 'CTSC',
 'GPC6',
 'WNT5A',
 'CCL4L2',
 'SLIT2',
 'ADAM33',
 'VIT',
 'VEGFB',
 'BMP1',
 'ADAMTS15',
 'F12',
 'ANXA6',
 'KAL1',
 'COL23A1',
 'COL12A1',
 '

In [12]:
# Stage label followed by every count column; position 0 of the counts frame is
# `sample_name`, which is skipped because it is the index of `joined_df`.
joined_df.loc[
    :,
    ["figo_num", *norm_filtered_matrisome_counts_t_df.columns[1:]],
]

Unnamed: 0_level_0,figo_num,PGF,TIMP4,C1QTNF6,TNC,PRL,OGN,C1QL3,FGB,NDNF,...,PIK3IP1,C1QTNF2,PCSK5,ANXA1,HGF,VWA2,FGF3,POSTN,NTF3,S100A6
sample_name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
TCGA-C5-A1BF-01B-11R-A13Y-07,1,10.800637,6.228003,11.669331,13.002928,5.063964,4.869744,5.063964,8.834522,6.410767,...,9.013453,8.190325,9.503647,14.077995,6.569726,7.315604,4.602649,12.062300,5.649441,16.558407
TCGA-EK-A2RM-01A-21R-A18M-07,1,9.674879,7.277164,10.712783,13.003138,4.602649,5.086466,5.630820,5.086466,5.761877,...,10.854224,6.581217,8.437154,15.816261,7.644559,6.406766,4.998296,11.731128,6.028879,17.119594
TCGA-Q1-A73P-01A-11R-A32P-07,1,8.036801,5.247645,9.894159,13.321633,4.602649,5.769802,7.289183,6.336043,9.843850,...,10.854487,5.629541,9.602922,14.174748,6.987468,6.731154,4.602649,9.293089,4.893018,16.649488
TCGA-C5-A8YT-01A-11R-A37O-07,1,7.830611,5.733875,12.445548,13.765468,5.455125,13.049104,5.146455,5.074289,10.569544,...,9.453187,6.398956,12.288955,13.396332,10.228758,8.542025,4.602649,11.765396,5.318924,13.556322
TCGA-UC-A7PI-01A-11R-A42S-07,1,7.243036,5.328548,9.392965,14.243570,4.879491,5.583359,5.862713,5.377532,9.604209,...,10.655786,7.368694,8.444696,14.402125,5.940529,9.163491,4.602649,8.118925,5.889309,16.314001
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
TCGA-VS-A9V4-01A-12R-A42T-07,4,10.495288,7.433105,10.140340,10.482026,4.602649,4.602649,5.864623,9.485394,5.723060,...,9.101395,5.346437,10.017004,10.726536,7.231687,10.082923,5.035225,11.727620,5.458443,14.885224
TCGA-C5-A7X5-01A-11R-A36F-07,4,8.933151,7.254909,9.416965,13.486333,4.602649,4.885849,6.457759,7.473299,6.316378,...,8.613894,7.904529,9.314140,15.640314,5.567193,7.578045,4.602649,11.317061,5.166365,17.134436
TCGA-VS-A9UV-01A-11R-A42T-07,4,10.872615,5.368238,9.196958,10.896139,4.948236,5.723194,6.151946,4.602649,4.948236,...,10.327597,9.100672,6.384585,16.518829,5.620639,5.564539,4.602649,10.145029,5.368238,16.577061
TCGA-VS-A8EK-01A-12R-A37O-07,4,10.857763,5.172154,9.335402,10.540475,4.602649,5.096646,5.815866,6.192199,5.238361,...,11.782721,5.651758,11.139308,15.215973,6.474744,7.152241,4.602649,10.103185,5.006648,14.949027


# Optimize model

In [10]:
# Reset the shared generator so the shuffle below is the same on every run of this cell.
rand.seed(seed)
# Project helper: takes the joined table and the generator and returns the feature frame
# and the label frame (the label frame exposes a `figo_num` column, used further down).
# NOTE(review): the exact split/shuffle semantics live in `prep.shuffle_data` -- confirm there.
x_df, y_df = prep.shuffle_data(joined_df, rand)

## Get baselines

In [11]:
# Reset the shared generator so this cell gives the same result when re-run on its own.
rand.seed(seed)

# Class frequencies as a two-column frame (`label`, `n`), most frequent first.
# The columns are named explicitly with `rename_axis` / `reset_index(name=...)` instead of
# renaming whatever `value_counts().reset_index()` happens to produce: those default names
# differ between pandas releases, and the old rename silently produced no `label` column
# on newer ones.
label_value_counts_df = (
    y_df["figo_num"]
        .value_counts()
        .rename_axis("label")
        .reset_index(name="n")
        .sort_values("n", ascending=False)
)

# Baseline 1: always predict the most frequent label.
# Use positional access -- after sorting, the first ROW is the most frequent label, whereas
# `label[0]` would look up index label 0 and only be right by coincidence.
most_frequent_label = label_value_counts_df["label"].iloc[0]
most_frequent_baseline = f1_score(
    y_df.values.squeeze(),
    np.repeat(most_frequent_label, y_df.shape[0]),
    average="weighted",
)

# Baseline 2: Monte Carlo baseline from the project helper, given the labels and their
# empirical frequencies as weights; scored with the same weighted F1 metric.
# NOTE(review): `n=1001` is presumably the number of simulated draws -- confirm in
# `opt.mc_classification_baseline`.
mc_baseline = opt.mc_classification_baseline(
    y=y_df.values.squeeze(),
    labels=label_value_counts_df.label.values,
    weights=label_value_counts_df.n.values / label_value_counts_df.n.values.sum(),
    metric=lambda y, yhat: f1_score(y, yhat, average="weighted"),
    n=1001
)

print(f"Most frequent baseline: {most_frequent_baseline}")
print(f"Monte Carlo baseline: {mc_baseline.mean()}")

Most frequent baseline: 0.3665158371040725
Monte Carlo baseline: 0.36630101199745085


## SMBO (sequential model-based optimization)

In [None]:
def objective(h_params, X, y, loss_default, scoring_default, r, verbose=True):
    """Score one hyper-parameter setting of a gradient boosting classifier.

    Parameters
    ----------
    h_params : sequence
        Hyper-parameter values, read by position:
        0 learning_rate, 1 n_estimators, 2 max_depth, 3 max_features,
        4 min_samples_split, 5 min_samples_leaf.
    X : array-like
        Feature matrix handed to ``cross_val_score``.
    y : array-like
        Target labels. A single-column 2-D target (for example a one-column
        DataFrame) is flattened to 1-D before fitting.
    loss_default : str
        Value passed as the classifier's ``loss``.
    scoring_default : str
        Value passed as ``scoring`` to ``cross_val_score``.
    r : int or numpy.random.RandomState
        Passed as the classifier's ``random_state``.
    verbose : bool, default True
        When true, print ``h_params`` before evaluating.

    Returns
    -------
    float
        The negated mean of the 5-fold cross-validation scores, so that a
        minimiser maximises the underlying score.
    """
    if verbose:
        print(h_params)

    # The classifier expects a 1-D target; a column vector would otherwise trigger a
    # conversion warning on every single fit. Only single-column 2-D targets are
    # flattened; anything else is passed through untouched.
    y_values = np.asarray(y)
    if y_values.ndim == 2 and y_values.shape[1] == 1:
        y = y_values.ravel()

    model = GradientBoostingClassifier(
        loss=loss_default,
        learning_rate=h_params[0],
        n_estimators=h_params[1],
        max_depth=h_params[2],
        max_features=h_params[3],
        min_samples_split=h_params[4],
        min_samples_leaf=h_params[5],
        random_state=r
    )
    fold_scores = cross_val_score(
        model,
        X,
        y,
        cv=KFold(n_splits=5),
        n_jobs=-1,
        scoring=scoring_default
    )
    # Negate: the optimiser minimises, while higher scores are better.
    return -np.mean(fold_scores)

In [None]:
# Search space for the optimiser. The order must match the positional indexing of
# `h_params` inside `objective`.
# NOTE(review): recent scikit-learn releases may no longer accept "auto" for
# `max_features` -- confirm against the installed version before re-running.
space = [
    Real(1e-3, 1e-1, name="learning_rate"),
    Integer(100, 1000, name="n_estimators"),
    Integer(2, 5, name="max_depth"),
    Categorical(["auto", "sqrt", "log2"], name="max_features"),
    Integer(2, 6, name="min_samples_split"),
    Integer(1, 3, name="min_samples_leaf"),
]

# Evaluation budget scales with the number of dimensions being searched.
n_dims = len(space)
n_initial = 10 * n_dims
n_calls = 50 * n_dims

In [None]:
# NOTE(review): recent scikit-learn releases may have renamed/removed the "deviance"
# loss -- confirm against the installed version before re-running.
loss_default = "deviance"
scoring_default = "f1_weighted"
callback_file = f"{unified_dsets[dset_idx]}_opt_gbc_h_params_{scoring_default}.tsv"

# Start from a clean results file. Best-effort on purpose: the file usually does not
# exist on a first run, and a failed delete should not abort the optimisation.
try:
    os.remove(callback_file)
except OSError:
    pass

# Reset the shared generator, as the other stochastic cells do, so that re-running this
# cell restarts the same random stream instead of continuing wherever it was left.
rand.seed(seed)

# Gaussian-process optimisation of the negated cross-validated score returned by
# `objective`; intermediate results are handed to the project callback together with
# `callback_file`.
res = gp_minimize(
    lambda h_ps: objective(h_ps, x_df, y_df, loss_default, scoring_default, rand),
    space,
    verbose=True,
    random_state=rand,
    n_initial_points=n_initial,
    n_calls=n_calls,
    n_jobs=-1,
    callback=lambda x: opt.save_callback(x, callback_file, n = 5, sep="\t")
)