In [1]:
import pandas as pd
import numpy as np
import os
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.pipeline import make_pipeline, Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import cross_val_score, KFold
from sklearn.compose import ColumnTransformer, TransformedTargetRegressor
from skopt.space import Real, Integer, Categorical
from skopt import gp_minimize

import utils.dev_config as dev_conf
import utils.preprocessing as prep
import utils.optimization as opt

In [2]:
# Resolve the project directories from the developer path file.
dirs = dev_conf.get_dev_directories("../dev_paths.txt")

# Dataset folder names under the data directory; one is picked by index below.
unified_dsets = [
    "unified_cervical_data",
    "unified_uterine_data",
    "unified_uterine_endometrial_data",
]

# Path to the matrisome master-list TSV.
matrisome_list = f"{dirs.data_dir}/matrisome/matrisome_hs_masterlist.tsv"

In [3]:
# Position in `unified_dsets` of the dataset this notebook run operates on.
i = 0

In [4]:
# Load the matrisome master list from the `matrisome_list` path.
matrisome_df = prep.load_matrisome_df(matrisome_list)

In [5]:
# Fixed seed shared by the data shuffle, the hyper-parameter search and the model fits.
seed = 123
# Seed the generator when it is created so that it is deterministic from its very
# first use, even if a cell consuming it runs before the explicit re-seed further down.
rand = np.random.RandomState(seed)

# Load and filter survival data

In [6]:
# Numeric codes for the vital-status labels, handed to the survival loader.
event_code = {
    "Alive": 0,
    "Dead": 1,
}

# Column groups used by the filtering/encoding cell that follows.
covariate_cols = [
    "figo_stage",
    "age_at_diagnosis",
    "race",
    "ethnicity",
]
dep_cols = [
    "vital_status",
    "survival_time",
]
cat_cols = [
    "race",
    "ethnicity",
    "figo_chr",
]

# Survival table of the dataset selected by `i`.
survival_df = prep.load_survival_df(
    f"{dirs.data_dir}/{unified_dsets[i]}/survival_data.tsv",
    event_code,
)

In [7]:
# Build the modelling table from the survival data, one step per line:
#   1. keep the sample id, outcome and covariate columns and drop incomplete rows
#   2. decode the FIGO stage via the project helper
#      (presumably this yields the `figo_chr` column listed in `cat_cols` - confirm
#      in utils.preprocessing)
#   3. keep only rows with vital_status == 1, then drop that now-constant column
#   4. one-hot encode the categorical columns and renumber the rows
filtered_survival_df = (
    survival_df[["sample_name", *dep_cols, *covariate_cols]]
    .dropna()
    .pipe(prep.decode_figo_stage, to="c")
    .query("vital_status == 1")
    .drop(columns=["vital_status"])
    .pipe(pd.get_dummies, columns=cat_cols)
    .reset_index(drop=True)
)
# The dummy column names inherit spaces from the category labels; normalise them.
filtered_survival_df.columns = filtered_survival_df.columns.str.replace(' ', '_')

print(filtered_survival_df.shape)

(66, 16)


# Load normalized matrisome count data

In [8]:
# Normalized matrisome counts for the dataset selected by `i`.
norm_matrisome_counts_df = pd.read_csv(
    f"{dirs.data_dir}/{unified_dsets[i]}/norm_matrisome_counts.tsv",
    sep="\t",
)

# Keep the gene id plus one column for each sample that survived the filtering,
# then hand the frame to the project's transpose helper (the printed shape shows
# one row per sample afterwards).
norm_filtered_matrisome_counts_t_df = prep.transpose_df(
    norm_matrisome_counts_df[["geneID", *filtered_survival_df["sample_name"]]],
    "geneID",
    "sample_name",
)
print(norm_filtered_matrisome_counts_t_df.shape)

(66, 1009)


# Join survival and count data

In [9]:
# Attach the per-sample counts to the survival covariates and index by sample.
joined_df = (
    filtered_survival_df
    .merge(norm_filtered_matrisome_counts_t_df, on="sample_name")
    .set_index("sample_name")
)
print(joined_df.shape)

(66, 1023)


# Optimize model

In [10]:
# Reset the shared generator to the fixed seed so that re-running from this cell
# starts the shuffle and the search from the same generator state.
rand.seed(seed)
# Shuffle the joined table with that generator and obtain the x/y pair used by the
# optimizer below.
# NOTE(review): how shuffle_data separates predictors from the target is defined in
# utils.preprocessing - confirm there.
x_df, y_df = prep.shuffle_data(joined_df, rand)

In [11]:
def objective(h_params, X, y, r, verbose=True):
    """Cross-validated error of a gradient boosting regressor for one setting.

    Parameters
    ----------
    h_params : sequence
        Hyper-parameters read by position:
        ``(loss, learning_rate, n_estimators, max_depth, max_features, alpha)``.
    X, y
        Data forwarded unchanged to ``cross_val_score``.
    r
        Forwarded as the regressor's ``random_state``.
    verbose : bool, default True
        When true, print ``h_params`` before evaluating.

    Returns
    -------
    float
        The sign-flipped mean of the ``neg_mean_absolute_error`` scores over
        ``KFold(n_splits=5)``, i.e. a value where smaller is better.
    """
    if verbose:
        print(h_params)

    # Map the positional search-space values onto the estimator's keyword names.
    gbr_settings = {
        "loss": h_params[0],
        "learning_rate": h_params[1],
        "n_estimators": h_params[2],
        "max_depth": h_params[3],
        "max_features": h_params[4],
        "alpha": h_params[5],
    }
    model = GradientBoostingRegressor(random_state=r, **gbr_settings)

    fold_scores = cross_val_score(
        model,
        X,
        y,
        cv=KFold(n_splits=5),
        n_jobs=-1,
        scoring="neg_mean_absolute_error",
    )
    # The scorer reports negated errors; negate again so the caller can minimise.
    return -np.mean(fold_scores)

In [12]:
# Search space, listed in the positional order in which `objective` reads h_params.
space = [
    # "ls" is left out because the search optimizes mean absolute error.
    Categorical(["lad", "huber", "quantile"], name="loss"),
    Real(1e-3, 1e-1, name="learning_rate"),
    Integer(100, 1000, name="n_estimators"),
    Integer(2, 5, name="max_depth"),
    Categorical(["auto", "sqrt", "log2"], name="max_features"),
    Real(0.5, 0.9, name="alpha"),
]

# The evaluation budget scales with the number of categorical levels
# (dimensions 0 and 4 of the space): 10 initial points and 100 calls per level.
n_initial = 10 * sum(len(space[k].categories) for k in (0, 4))
n_calls = 100 * sum(len(space[k].categories) for k in (0, 4))

In [13]:
# File the per-iteration callback writes to while the search runs.
callback_file = f"{unified_dsets[i]}_opt_gbr_h_params.tsv"

# Start from a clean file. A missing file is the only expected failure; any other
# OSError (permissions, path is a directory, ...) now propagates instead of being
# silently ignored, so a stale results file is never kept by accident.
try:
    os.remove(callback_file)
except FileNotFoundError:
    pass


# Gaussian-process search over `space`, minimising the cross-validated error
# returned by `objective`.
# NOTE(review): the same `rand` instance drives both the optimizer and the model's
# random_state; the logged run shows an identical point scoring differently on
# consecutive iterations, so the objective is noisy - confirm this is intended.
res = gp_minimize(
    lambda h_ps: objective(h_ps, x_df, y_df, rand),
    space,
    verbose=True,
    random_state=rand,
    n_initial_points=n_initial,
    n_calls=n_calls,
    n_jobs=-1,
    callback=lambda x: opt.save_callback(x, callback_file, sep="\t"),
)

Iteration No: 1 started. Evaluating function at random point.
['quantile', 0.04860832066690728, 818, 3, 'log2', 0.5039029398683332]
Iteration No: 1 ended. Evaluation done at random point.
Time taken: 1.5380
Function value obtained: 577.2109
Current minimum: 577.2109
Iteration No: 2 started. Evaluating function at random point.
['lad', 0.06923792863971574, 882, 3, 'log2', 0.6073125230846899]
Iteration No: 2 ended. Evaluation done at random point.
Time taken: 1.2777
Function value obtained: 553.5123
Current minimum: 553.5123
Iteration No: 3 started. Evaluating function at random point.
['quantile', 0.027780951796428976, 577, 3, 'log2', 0.8564436383356643]
Iteration No: 3 ended. Evaluation done at random point.
Time taken: 1.1139
Function value obtained: 643.2673
Current minimum: 553.5123
Iteration No: 4 started. Evaluating function at random point.
['huber', 0.04993730401749343, 291, 4, 'auto', 0.8676227203709375]
Iteration No: 4 ended. Evaluation done at random point.
Time taken: 3.6785



Iteration No: 68 ended. Search finished for the next optimal point.
Time taken: 7.0969
Function value obtained: 520.7674
Current minimum: 491.7044
Iteration No: 69 started. Searching for the next optimal point.
['huber', 0.001, 100, 2, 'auto', 0.5]
Iteration No: 69 ended. Search finished for the next optimal point.
Time taken: 1.8693
Function value obtained: 516.4989
Current minimum: 491.7044
Iteration No: 70 started. Searching for the next optimal point.
['lad', 0.001, 100, 2, 'log2', 0.5]
Iteration No: 70 ended. Search finished for the next optimal point.
Time taken: 1.2063
Function value obtained: 521.8830
Current minimum: 491.7044
Iteration No: 71 started. Searching for the next optimal point.
['huber', 0.001, 1000, 2, 'sqrt', 0.5]
Iteration No: 71 ended. Search finished for the next optimal point.
Time taken: 2.2841
Function value obtained: 516.2416
Current minimum: 491.7044
Iteration No: 72 started. Searching for the next optimal point.
['huber', 0.1, 1000, 2, 'log2', 0.5]
Iterat



Iteration No: 120 ended. Search finished for the next optimal point.
Time taken: 2.0745
Function value obtained: 521.2409
Current minimum: 491.7044
Iteration No: 121 started. Searching for the next optimal point.
['lad', 0.001, 100, 4, 'sqrt', 0.5]




Iteration No: 121 ended. Search finished for the next optimal point.
Time taken: 2.0893
Function value obtained: 521.3124
Current minimum: 491.7044
Iteration No: 122 started. Searching for the next optimal point.
['lad', 0.001, 100, 4, 'sqrt', 0.5]




Iteration No: 122 ended. Search finished for the next optimal point.
Time taken: 1.9894
Function value obtained: 521.1126
Current minimum: 491.7044
Iteration No: 123 started. Searching for the next optimal point.
['lad', 0.001, 285, 2, 'sqrt', 0.5497564385691771]
Iteration No: 123 ended. Search finished for the next optimal point.
Time taken: 2.2580
Function value obtained: 520.3560
Current minimum: 491.7044
Iteration No: 124 started. Searching for the next optimal point.
['huber', 0.09780109382163404, 100, 2, 'auto', 0.9]
Iteration No: 124 ended. Search finished for the next optimal point.
Time taken: 2.4312
Function value obtained: 587.1083
Current minimum: 491.7044
Iteration No: 125 started. Searching for the next optimal point.
['lad', 0.001, 298, 2, 'sqrt', 0.5359519499362061]
Iteration No: 125 ended. Search finished for the next optimal point.
Time taken: 2.1264
Function value obtained: 519.1702
Current minimum: 491.7044
Iteration No: 126 started. Searching for the next optimal p



Iteration No: 169 ended. Search finished for the next optimal point.
Time taken: 3.3725
Function value obtained: 518.4981
Current minimum: 491.7044
Iteration No: 170 started. Searching for the next optimal point.
['lad', 0.001, 670, 2, 'sqrt', 0.53536068053286]
Iteration No: 170 ended. Search finished for the next optimal point.
Time taken: 3.4797
Function value obtained: 518.4502
Current minimum: 491.7044
Iteration No: 171 started. Searching for the next optimal point.
['lad', 0.001, 100, 5, 'sqrt', 0.5]
Iteration No: 171 ended. Search finished for the next optimal point.
Time taken: 3.0681
Function value obtained: 521.6534
Current minimum: 491.7044
Iteration No: 172 started. Searching for the next optimal point.
['lad', 0.007049808023349759, 572, 2, 'log2', 0.5237893784706554]
Iteration No: 172 ended. Search finished for the next optimal point.
Time taken: 3.4022
Function value obtained: 522.9896
Current minimum: 491.7044
Iteration No: 173 started. Searching for the next optimal poin



Iteration No: 181 ended. Search finished for the next optimal point.
Time taken: 3.7060
Function value obtained: 519.2412
Current minimum: 491.7044
Iteration No: 182 started. Searching for the next optimal point.
['lad', 0.001, 590, 2, 'sqrt', 0.5]
Iteration No: 182 ended. Search finished for the next optimal point.
Time taken: 3.9508
Function value obtained: 519.7914
Current minimum: 491.7044
Iteration No: 183 started. Searching for the next optimal point.
['lad', 0.017052503883895093, 187, 2, 'log2', 0.5359424449828781]
Iteration No: 183 ended. Search finished for the next optimal point.
Time taken: 3.7180
Function value obtained: 529.1970
Current minimum: 491.7044
Iteration No: 184 started. Searching for the next optimal point.
['lad', 0.1, 199, 2, 'sqrt', 0.756383615851575]
Iteration No: 184 ended. Search finished for the next optimal point.
Time taken: 3.6252
Function value obtained: 554.4894
Current minimum: 491.7044
Iteration No: 185 started. Searching for the next optimal point



Iteration No: 348 ended. Search finished for the next optimal point.
Time taken: 12.1635
Function value obtained: 514.1726
Current minimum: 491.7044
Iteration No: 349 started. Searching for the next optimal point.
['lad', 0.005774289950271251, 216, 2, 'log2', 0.5]
Iteration No: 349 ended. Search finished for the next optimal point.
Time taken: 11.9294
Function value obtained: 517.6622
Current minimum: 491.7044
Iteration No: 350 started. Searching for the next optimal point.
['lad', 0.005929810929642991, 219, 2, 'log2', 0.5]
Iteration No: 350 ended. Search finished for the next optimal point.
Time taken: 11.5072
Function value obtained: 517.2247
Current minimum: 491.7044
Iteration No: 351 started. Searching for the next optimal point.
['lad', 0.0057554277265534255, 219, 2, 'log2', 0.5]
Iteration No: 351 ended. Search finished for the next optimal point.
Time taken: 12.1894
Function value obtained: 515.7141
Current minimum: 491.7044
Iteration No: 352 started. Searching for the next optim



Iteration No: 482 ended. Search finished for the next optimal point.
Time taken: 22.1877
Function value obtained: 515.0747
Current minimum: 491.7044
Iteration No: 483 started. Searching for the next optimal point.
['lad', 0.005781426507983793, 219, 2, 'log2', 0.5]
Iteration No: 483 ended. Search finished for the next optimal point.
Time taken: 21.0785
Function value obtained: 519.1877
Current minimum: 491.7044
Iteration No: 484 started. Searching for the next optimal point.
['lad', 0.005779033236654437, 219, 2, 'log2', 0.5]
Iteration No: 484 ended. Search finished for the next optimal point.
Time taken: 20.9166
Function value obtained: 518.5552
Current minimum: 491.7044
Iteration No: 485 started. Searching for the next optimal point.
['lad', 0.005777292230471553, 219, 2, 'log2', 0.5]
Iteration No: 485 ended. Search finished for the next optimal point.
Time taken: 20.9207
Function value obtained: 516.1288
Current minimum: 491.7044
Iteration No: 486 started. Searching for the next optima



Iteration No: 546 ended. Search finished for the next optimal point.
Time taken: 28.3841
Function value obtained: 517.5407
Current minimum: 491.7044
Iteration No: 547 started. Searching for the next optimal point.
['lad', 0.0054771181169108, 227, 2, 'log2', 0.5]
Iteration No: 547 ended. Search finished for the next optimal point.
Time taken: 28.6384
Function value obtained: 513.2961
Current minimum: 491.7044
Iteration No: 548 started. Searching for the next optimal point.
['lad', 0.005623250737473681, 220, 2, 'log2', 0.5]
Iteration No: 548 ended. Search finished for the next optimal point.
Time taken: 28.9404
Function value obtained: 513.7558
Current minimum: 491.7044
Iteration No: 549 started. Searching for the next optimal point.
['lad', 0.005563201960541791, 221, 2, 'log2', 0.5]
Iteration No: 549 ended. Search finished for the next optimal point.
Time taken: 30.4003
Function value obtained: 514.1073
Current minimum: 491.7044
Iteration No: 550 started. Searching for the next optimal 