In [1]:
import pandas as pd
import numpy as np
import os
from sklearn.ensemble import RandomForestRegressor
from sklearn.pipeline import make_pipeline, Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import cross_val_score, KFold
from sklearn.compose import ColumnTransformer, TransformedTargetRegressor
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score, explained_variance_score
from skopt.space import Real, Integer, Categorical
from skopt import gp_minimize

import utils.dev_config as dev_conf
import utils.preprocessing as prep
import utils.optimization as opt

In [2]:
# Names of the unified datasets; `dset_idx` picks one of them for a run.
unified_dsets = [
    "unified_cervical_data",
    "unified_uterine_data",
    "unified_uterine_endometrial_data",
]

# Project directories are read from a developer-specific paths file.
dirs = dev_conf.get_dev_directories("../dev_paths.txt")
# Location of the matrisome master list inside the data directory.
matrisome_list = f"{dirs.data_dir}/matrisome/matrisome_hs_masterlist.tsv"

In [3]:
# Index into `unified_dsets` selecting the dataset analysed in this run
# (0 -> "unified_cervical_data").
dset_idx = 0

In [4]:
# Load the matrisome master list through the project helper.
# NOTE(review): `matrisome_df` is not referenced again in the visible part of
# this notebook -- confirm it is still needed.
matrisome_df = prep.load_matrisome_df(matrisome_list)

In [5]:
# Single seed shared by every stochastic step of the notebook.
seed = 123
# Seed the generator at construction so that it is deterministic from the
# moment it exists, even if a cell uses it before the explicit `rand.seed(seed)`
# call in the optimisation section (or cells are run out of order).
rand = np.random.RandomState(seed)

# Load and filter survival data

In [6]:
# Numeric codes for the vital-status labels; passed to the survival-data loader.
event_code = {"Alive": 0, "Dead": 1}

# Column groups used when filtering and encoding the survival table.
dep_cols = ["vital_status", "survival_time"]
covariate_cols = ["figo_stage", "age_at_diagnosis", "race", "ethnicity"]
# "figo_chr" is not among the raw covariates; presumably it is created by
# prep.decode_figo_stage(..., to="c") in the filtering cell -- verify.
cat_cols = ["race", "ethnicity", "figo_chr"]

# Survival table of the dataset selected by `dset_idx`.
survival_data_path = f"{dirs.data_dir}/{unified_dsets[dset_idx]}/survival_data.tsv"
survival_df = prep.load_survival_df(survival_data_path, event_code)

In [7]:
# Keep only complete cases for the identifier, outcome and covariate columns.
survival_cols = ["sample_name"] + dep_cols + covariate_cols
complete_survival_df = survival_df[survival_cols].dropna()

# Decode the FIGO stage with the project helper, then keep the rows whose
# vital_status equals 1 (the code given to "Dead" in `event_code`); the
# now-constant status column is dropped.
decoded_survival_df = prep.decode_figo_stage(complete_survival_df, to="c")
deceased_survival_df = decoded_survival_df.query("vital_status == 1").drop(["vital_status"], axis=1)

# One-hot encode the categorical covariates and renumber the rows.
filtered_survival_df = pd.get_dummies(deceased_survival_df, columns=cat_cols).reset_index(drop=True)
# Replace spaces in column names with underscores.
filtered_survival_df.columns = filtered_survival_df.columns.str.replace(' ', '_')

print(filtered_survival_df.shape)
# filtered_survival_df.head()

(66, 16)


# Load normalized matrisome count data

In [8]:
# Normalised matrisome counts of the selected dataset (tab-separated file).
norm_counts_path = f"{dirs.data_dir}/{unified_dsets[dset_idx]}/norm_matrisome_counts.tsv"
norm_matrisome_counts_df = pd.read_csv(norm_counts_path, sep='\t')

# Restrict the count table to the gene identifier plus the samples that
# survived the filtering above, then transpose it with the project helper.
kept_count_cols = ["geneID", *filtered_survival_df.sample_name]
norm_filtered_matrisome_counts_t_df = prep.transpose_df(
    norm_matrisome_counts_df[kept_count_cols],
    "geneID",
    "sample_name",
)
print(norm_filtered_matrisome_counts_t_df.shape)
# norm_filtered_matrisome_counts_t_df.head()

(66, 1009)


# Join survival and count data

In [9]:
# Combine survival covariates and expression counts per sample, then index
# the result by sample name.
joined_df = filtered_survival_df.merge(
    norm_filtered_matrisome_counts_t_df,
    on="sample_name",
).set_index("sample_name")
print(joined_df.shape)
# joined_df.head()

(66, 1023)


# Optimize model

In [10]:
# Re-seed the shared generator immediately before shuffling so the sample
# order does not depend on how much randomness earlier cells consumed.
rand.seed(seed)
# Project helper returning a (features, target) pair built from the joined table.
# NOTE(review): exact shuffling/splitting semantics live in utils.preprocessing -- verify there.
x_df, y_df = prep.shuffle_data(joined_df, rand)

In [11]:
# Reference scores of constant predictors, used to judge the tuned models.
baseline_y = y_df.values
baseline_n = y_df.shape[0]

# Constant predictions: the sample mean and the sample median of the target.
baseline_mean_pred = np.repeat(np.mean(baseline_y.squeeze()), baseline_n)
baseline_median_pred = np.repeat(np.median(baseline_y.squeeze()), baseline_n)

# Squared-error, R2 and explained-variance baselines use the mean predictor;
# the absolute-error baseline uses the median predictor.
mean_baseline = mean_squared_error(baseline_y, baseline_mean_pred)
median_baseline = mean_absolute_error(baseline_y, baseline_median_pred)
r2_baseline = r2_score(baseline_y, baseline_mean_pred)
expl_var_baseline = explained_variance_score(baseline_y, baseline_mean_pred)

print(f"L2 baseline: {mean_baseline}")
print(f"L1 baseline: {median_baseline}")
print(f"R2 baseline: {r2_baseline}")
print(f"explained variance baseline: {expl_var_baseline}")

L2 baseline: 641687.6988062444
L1 baseline: 518.3333333333334
R2 baseline: 0.0
explained variance baseline: 0.0


In [12]:
def objective(h_params, X, y, scoring_default, r, verbose=True):
    """Score one random-forest hyper-parameter setting for a minimiser.

    Parameters
    ----------
    h_params : sequence
        Positional hyper-parameters, in this order: ``n_estimators``,
        ``max_depth``, ``max_features``, ``min_samples_split``,
        ``min_samples_leaf``, ``bootstrap``. Extra trailing entries are ignored.
    X : array-like
        Feature matrix passed to ``cross_val_score``.
    y : array-like
        Target values. A single-column 2-D target is flattened to 1-D before
        fitting so the forest does not emit a ``DataConversionWarning`` on
        every fold.
    scoring_default : str
        Name of the scikit-learn scorer used for cross-validation.
    r : int or numpy.random.RandomState
        Forwarded as the forest's ``random_state``.
    verbose : bool, default True
        When true, print ``h_params`` before evaluating them.

    Returns
    -------
    float
        The negated mean of the 5-fold cross-validation scores, so that
        minimising this value maximises the scorer.
    """
    if verbose:
        print(h_params)

    (
        n_estimators,
        max_depth,
        max_features,
        min_samples_split,
        min_samples_leaf,
        bootstrap,
    ) = h_params[:6]

    # "auto" was deprecated in scikit-learn 1.1 and is rejected from 1.3 on.
    # For a regressor it meant "use every feature", which None expresses in
    # all versions.
    if isinstance(max_features, str) and max_features == "auto":
        max_features = None

    # Flatten an (n_samples, 1) target; other shapes are passed through untouched.
    y_fit = y
    y_values = np.asarray(y)
    if y_values.ndim == 2 and y_values.shape[1] == 1:
        y_fit = y_values.ravel()

    model = RandomForestRegressor(
        n_estimators=n_estimators,
        max_depth=max_depth,
        max_features=max_features,
        min_samples_split=min_samples_split,
        min_samples_leaf=min_samples_leaf,
        bootstrap=bootstrap,
        n_jobs=-1,
        random_state=r
    )
    # Unshuffled 5-fold split: the rows are used in the order they arrive.
    fold_scores = cross_val_score(
        model,
        X,
        y_fit,
        cv=KFold(n_splits=5),
        n_jobs=-1,
        scoring=scoring_default
    )
    return -np.mean(fold_scores)

In [13]:
# Search space of the random-forest hyper-parameters. The order of the
# dimensions must match the positional indexing of `h_params` in `objective`.
# NOTE(review): recent scikit-learn releases no longer accept "auto" for
# max_features -- confirm against the installed version.
space = [
    Integer(100, 1000, name="n_estimators"),
    Integer(10, 100, name="max_depth"),
    Categorical(["auto", "sqrt", "log2"], name="max_features"),
    Integer(2, 4, name="min_samples_split"),
    Integer(1, 3, name="min_samples_leaf"),
    Categorical([True, False], name="bootstrap"),
]

# Evaluation budget scales with the number of dimensions:
# 10 initial points and 50 total calls per dimension.
n_dims = len(space)
n_initial = 10 * n_dims
n_calls = 50 * n_dims

In [14]:
# Scorer optimised in this run; it is also embedded in the checkpoint file name.
scoring_default = "explained_variance"
callback_file = f"{unified_dsets[dset_idx]}_opt_rfr_h_params_{scoring_default}.tsv"

# Start from a clean checkpoint file; a missing file (or any other OSError)
# is deliberately ignored.
try:
    os.remove(callback_file)
except OSError:
    pass


# Gaussian-process optimisation over `space`. `objective` returns the negated
# mean CV score, so minimising it maximises the scorer. The same `rand`
# instance is handed to both the optimiser and the forest.
# NOTE(review): opt.save_callback presumably checkpoints results to
# `callback_file` (n = 5) -- verify in utils.optimization.
res = gp_minimize(
    func=lambda h_ps: objective(h_ps, x_df, y_df, scoring_default, rand),
    dimensions=space,
    n_calls=n_calls,
    n_initial_points=n_initial,
    random_state=rand,
    callback=lambda x: opt.save_callback(x, callback_file, n = 5, sep="\t"),
    verbose=True,
    n_jobs=-1,
)

Iteration No: 1 started. Evaluating function at random point.
[776, 53, 'sqrt', 3, 2, False]
Iteration No: 1 ended. Evaluation done at random point.
Time taken: 3.0097
Function value obtained: 0.0395
Current minimum: 0.0395
Iteration No: 2 started. Evaluating function at random point.
[407, 72, 'sqrt', 3, 2, False]
Iteration No: 2 ended. Evaluation done at random point.
Time taken: 1.5592
Function value obtained: 0.0326
Current minimum: 0.0326
Iteration No: 3 started. Evaluating function at random point.
[853, 34, 'log2', 2, 2, True]
Iteration No: 3 ended. Evaluation done at random point.
Time taken: 3.2801
Function value obtained: -0.0122
Current minimum: -0.0122
Iteration No: 4 started. Evaluating function at random point.
[262, 54, 'auto', 3, 1, True]
Iteration No: 4 ended. Evaluation done at random point.
Time taken: 2.3538
Function value obtained: 0.1392
Current minimum: -0.0122
Iteration No: 5 started. Evaluating function at random point.
[995, 53, 'log2', 3, 2, True]
Iteration N



Iteration No: 63 ended. Search finished for the next optimal point.
Time taken: 3.5442
Function value obtained: -0.0157
Current minimum: -0.0272
Iteration No: 64 started. Searching for the next optimal point.
[100, 100, 'log2', 2, 3, True]
Iteration No: 64 ended. Search finished for the next optimal point.
Time taken: 1.6280
Function value obtained: -0.0467
Current minimum: -0.0467
Iteration No: 65 started. Searching for the next optimal point.
[115, 11, 'log2', 2, 3, True]
Iteration No: 65 ended. Search finished for the next optimal point.
Time taken: 1.3398
Function value obtained: -0.0313
Current minimum: -0.0467
Iteration No: 66 started. Searching for the next optimal point.
[119, 98, 'log2', 4, 3, True]
Iteration No: 66 ended. Search finished for the next optimal point.
Time taken: 1.7902
Function value obtained: -0.0212
Current minimum: -0.0467
Iteration No: 67 started. Searching for the next optimal point.
[100, 100, 'log2', 2, 3, True]




Iteration No: 67 ended. Search finished for the next optimal point.
Time taken: 1.9997
Function value obtained: -0.0141
Current minimum: -0.0467
Iteration No: 68 started. Searching for the next optimal point.
[101, 14, 'sqrt', 4, 3, True]
Iteration No: 68 ended. Search finished for the next optimal point.
Time taken: 1.9020
Function value obtained: 0.0298
Current minimum: -0.0467
Iteration No: 69 started. Searching for the next optimal point.
[997, 81, 'sqrt', 2, 2, True]
Iteration No: 69 ended. Search finished for the next optimal point.
Time taken: 4.0840
Function value obtained: 0.0067
Current minimum: -0.0467
Iteration No: 70 started. Searching for the next optimal point.
[981, 12, 'log2', 2, 3, True]
Iteration No: 70 ended. Search finished for the next optimal point.
Time taken: 4.6103
Function value obtained: -0.0146
Current minimum: -0.0467
Iteration No: 71 started. Searching for the next optimal point.
[964, 98, 'sqrt', 2, 3, True]
Iteration No: 71 ended. Search finished for th



Iteration No: 73 ended. Search finished for the next optimal point.
Time taken: 2.2731
Function value obtained: -0.0640
Current minimum: -0.0640
Iteration No: 74 started. Searching for the next optimal point.
[112, 96, 'log2', 4, 1, False]
Iteration No: 74 ended. Search finished for the next optimal point.
Time taken: 2.2578
Function value obtained: 0.0854
Current minimum: -0.0640
Iteration No: 75 started. Searching for the next optimal point.
[989, 25, 'sqrt', 4, 2, True]
Iteration No: 75 ended. Search finished for the next optimal point.
Time taken: 3.6558
Function value obtained: -0.0045
Current minimum: -0.0640
Iteration No: 76 started. Searching for the next optimal point.
[101, 16, 'log2', 4, 3, True]
Iteration No: 76 ended. Search finished for the next optimal point.
Time taken: 2.1407
Function value obtained: -0.0175
Current minimum: -0.0640
Iteration No: 77 started. Searching for the next optimal point.
[992, 91, 'log2', 4, 3, True]
Iteration No: 77 ended. Search finished for 



Iteration No: 174 ended. Search finished for the next optimal point.
Time taken: 6.8210
Function value obtained: -0.0083
Current minimum: -0.0640
Iteration No: 175 started. Searching for the next optimal point.
[132, 99, 'log2', 4, 3, True]
Iteration No: 175 ended. Search finished for the next optimal point.
Time taken: 4.7054
Function value obtained: -0.0413
Current minimum: -0.0640
Iteration No: 176 started. Searching for the next optimal point.
[988, 100, 'sqrt', 2, 3, True]
Iteration No: 176 ended. Search finished for the next optimal point.
Time taken: 7.6983
Function value obtained: -0.0263
Current minimum: -0.0640
Iteration No: 177 started. Searching for the next optimal point.
[117, 98, 'log2', 4, 3, False]
Iteration No: 177 ended. Search finished for the next optimal point.
Time taken: 5.4224
Function value obtained: -0.0164
Current minimum: -0.0640
Iteration No: 178 started. Searching for the next optimal point.
[972, 99, 'sqrt', 4, 3, True]
Iteration No: 178 ended. Search fi



Iteration No: 192 ended. Search finished for the next optimal point.
Time taken: 4.6566
Function value obtained: 0.0499
Current minimum: -0.0640
Iteration No: 193 started. Searching for the next optimal point.
[992, 99, 'log2', 2, 3, True]
Iteration No: 193 ended. Search finished for the next optimal point.
Time taken: 6.9438
Function value obtained: -0.0220
Current minimum: -0.0640
Iteration No: 194 started. Searching for the next optimal point.
[982, 20, 'sqrt', 2, 3, True]
Iteration No: 194 ended. Search finished for the next optimal point.
Time taken: 7.0324
Function value obtained: -0.0190
Current minimum: -0.0640
Iteration No: 195 started. Searching for the next optimal point.
[995, 26, 'log2', 4, 1, False]
Iteration No: 195 ended. Search finished for the next optimal point.
Time taken: 7.1119
Function value obtained: 0.0173
Current minimum: -0.0640
Iteration No: 196 started. Searching for the next optimal point.
[102, 14, 'sqrt', 4, 1, True]
Iteration No: 196 ended. Search finis



Iteration No: 219 ended. Search finished for the next optimal point.
Time taken: 6.6501
Function value obtained: 0.0299
Current minimum: -0.0640
Iteration No: 220 started. Searching for the next optimal point.
[107, 95, 'sqrt', 4, 3, True]
Iteration No: 220 ended. Search finished for the next optimal point.
Time taken: 5.9954
Function value obtained: 0.0104
Current minimum: -0.0640
Iteration No: 221 started. Searching for the next optimal point.
[995, 10, 'log2', 4, 2, False]
Iteration No: 221 ended. Search finished for the next optimal point.
Time taken: 8.5608
Function value obtained: 0.0123
Current minimum: -0.0640
Iteration No: 222 started. Searching for the next optimal point.
[105, 93, 'log2', 2, 2, True]
Iteration No: 222 ended. Search finished for the next optimal point.
Time taken: 6.4910
Function value obtained: -0.0103
Current minimum: -0.0640
Iteration No: 223 started. Searching for the next optimal point.
[977, 10, 'log2', 2, 3, True]
Iteration No: 223 ended. Search finish



Iteration No: 258 ended. Search finished for the next optimal point.
Time taken: 9.9882
Function value obtained: -0.0176
Current minimum: -0.0640
Iteration No: 259 started. Searching for the next optimal point.
[980, 100, 'sqrt', 4, 3, True]
Iteration No: 259 ended. Search finished for the next optimal point.
Time taken: 10.4204
Function value obtained: -0.0019
Current minimum: -0.0640
Iteration No: 260 started. Searching for the next optimal point.
[111, 13, 'log2', 4, 3, True]
Iteration No: 260 ended. Search finished for the next optimal point.
Time taken: 8.3076
Function value obtained: -0.0148
Current minimum: -0.0640
Iteration No: 261 started. Searching for the next optimal point.
[100, 11, 'log2', 2, 3, False]
Iteration No: 261 ended. Search finished for the next optimal point.
Time taken: 8.4841
Function value obtained: 0.0195
Current minimum: -0.0640
Iteration No: 262 started. Searching for the next optimal point.
[100, 97, 'log2', 4, 2, False]
Iteration No: 262 ended. Search f

In [15]:
# Scorer optimised in this run; it is also embedded in the checkpoint file name.
scoring_default = "neg_mean_absolute_error"
callback_file = f"{unified_dsets[dset_idx]}_opt_rfr_h_params_{scoring_default}.tsv"

# Start from a clean checkpoint file; a missing file (or any other OSError)
# is deliberately ignored.
try:
    os.remove(callback_file)
except OSError:
    pass


# Gaussian-process optimisation over `space`. `objective` negates the mean CV
# score, so with a negated-error scorer the minimised value is the mean
# absolute error itself. This assignment overwrites any earlier `res`.
# NOTE(review): opt.save_callback presumably checkpoints results to
# `callback_file` (n = 5) -- verify in utils.optimization.
res = gp_minimize(
    func=lambda h_ps: objective(h_ps, x_df, y_df, scoring_default, rand),
    dimensions=space,
    n_calls=n_calls,
    n_initial_points=n_initial,
    random_state=rand,
    callback=lambda x: opt.save_callback(x, callback_file, n = 5, sep="\t"),
    verbose=True,
    n_jobs=-1,
)

Iteration No: 1 started. Evaluating function at random point.
[815, 50, 'log2', 2, 1, False]
Iteration No: 1 ended. Evaluation done at random point.
Time taken: 1.3804
Function value obtained: 592.5586
Current minimum: 592.5586
Iteration No: 2 started. Evaluating function at random point.
[825, 22, 'auto', 3, 2, False]
Iteration No: 2 ended. Evaluation done at random point.
Time taken: 7.8265
Function value obtained: 738.5091
Current minimum: 592.5586
Iteration No: 3 started. Evaluating function at random point.
[255, 29, 'sqrt', 3, 3, False]
Iteration No: 3 ended. Evaluation done at random point.
Time taken: 0.6771
Function value obtained: 585.8501
Current minimum: 585.8501
Iteration No: 4 started. Evaluating function at random point.
[918, 15, 'log2', 4, 2, False]
Iteration No: 4 ended. Evaluation done at random point.
Time taken: 1.9115
Function value obtained: 576.2952
Current minimum: 576.2952
Iteration No: 5 started. Evaluating function at random point.
[586, 99, 'log2', 2, 3, Tr



Iteration No: 76 ended. Search finished for the next optimal point.
Time taken: 3.8873
Function value obtained: 585.2548
Current minimum: 565.8574
Iteration No: 77 started. Searching for the next optimal point.
[1000, 63, 'log2', 4, 3, False]
Iteration No: 77 ended. Search finished for the next optimal point.
Time taken: 4.3319
Function value obtained: 571.8044
Current minimum: 565.8574
Iteration No: 78 started. Searching for the next optimal point.
[1000, 83, 'log2', 3, 3, True]
Iteration No: 78 ended. Search finished for the next optimal point.
Time taken: 4.6933
Function value obtained: 569.4142
Current minimum: 565.8574
Iteration No: 79 started. Searching for the next optimal point.
[1000, 100, 'sqrt', 4, 3, False]




Iteration No: 79 ended. Search finished for the next optimal point.
Time taken: 3.9469
Function value obtained: 582.0696
Current minimum: 565.8574
Iteration No: 80 started. Searching for the next optimal point.
[100, 10, 'sqrt', 2, 3, True]
Iteration No: 80 ended. Search finished for the next optimal point.
Time taken: 1.6086
Function value obtained: 584.0014
Current minimum: 565.8574
Iteration No: 81 started. Searching for the next optimal point.
[1000, 100, 'sqrt', 4, 3, True]
Iteration No: 81 ended. Search finished for the next optimal point.
Time taken: 3.5341
Function value obtained: 581.4830
Current minimum: 565.8574
Iteration No: 82 started. Searching for the next optimal point.
[1000, 65, 'log2', 4, 3, False]
Iteration No: 82 ended. Search finished for the next optimal point.
Time taken: 3.7003
Function value obtained: 575.5319
Current minimum: 565.8574
Iteration No: 83 started. Searching for the next optimal point.
[128, 10, 'log2', 4, 3, False]
Iteration No: 83 ended. Search 



Iteration No: 84 ended. Search finished for the next optimal point.
Time taken: 3.1380
Function value obtained: 574.8183
Current minimum: 565.8574
Iteration No: 85 started. Searching for the next optimal point.
[138, 10, 'log2', 4, 3, False]
Iteration No: 85 ended. Search finished for the next optimal point.
Time taken: 1.6526
Function value obtained: 566.8364
Current minimum: 565.8574
Iteration No: 86 started. Searching for the next optimal point.
[804, 38, 'log2', 4, 3, False]
Iteration No: 86 ended. Search finished for the next optimal point.
Time taken: 2.7670
Function value obtained: 569.5374
Current minimum: 565.8574
Iteration No: 87 started. Searching for the next optimal point.
[100, 40, 'log2', 4, 3, False]
Iteration No: 87 ended. Search finished for the next optimal point.
Time taken: 1.6949
Function value obtained: 572.2396
Current minimum: 565.8574
Iteration No: 88 started. Searching for the next optimal point.
[1000, 10, 'sqrt', 4, 3, True]
Iteration No: 88 ended. Search f



Iteration No: 93 ended. Search finished for the next optimal point.
Time taken: 3.9949
Function value obtained: 571.5696
Current minimum: 565.8574
Iteration No: 94 started. Searching for the next optimal point.
[1000, 10, 'log2', 4, 3, False]




Iteration No: 94 ended. Search finished for the next optimal point.
Time taken: 4.6460
Function value obtained: 570.6951
Current minimum: 565.8574
Iteration No: 95 started. Searching for the next optimal point.
[962, 10, 'log2', 4, 3, False]
Iteration No: 95 ended. Search finished for the next optimal point.
Time taken: 3.8108
Function value obtained: 572.6602
Current minimum: 565.8574
Iteration No: 96 started. Searching for the next optimal point.
[1000, 10, 'log2', 4, 3, False]




Iteration No: 96 ended. Search finished for the next optimal point.
Time taken: 4.3315
Function value obtained: 575.8609
Current minimum: 565.8574
Iteration No: 97 started. Searching for the next optimal point.
[966, 10, 'log2', 2, 3, False]
Iteration No: 97 ended. Search finished for the next optimal point.
Time taken: 3.5974
Function value obtained: 570.5405
Current minimum: 565.8574
Iteration No: 98 started. Searching for the next optimal point.
[1000, 10, 'log2', 2, 3, True]
Iteration No: 98 ended. Search finished for the next optimal point.
Time taken: 4.3418
Function value obtained: 571.4341
Current minimum: 565.8574
Iteration No: 99 started. Searching for the next optimal point.
[1000, 10, 'log2', 4, 3, False]




Iteration No: 99 ended. Search finished for the next optimal point.
Time taken: 4.4445
Function value obtained: 571.5476
Current minimum: 565.8574
Iteration No: 100 started. Searching for the next optimal point.
[1000, 10, 'log2', 4, 3, False]




Iteration No: 100 ended. Search finished for the next optimal point.
Time taken: 4.3576
Function value obtained: 574.8489
Current minimum: 565.8574
Iteration No: 101 started. Searching for the next optimal point.
[600, 10, 'log2', 4, 3, False]
Iteration No: 101 ended. Search finished for the next optimal point.
Time taken: 3.4039
Function value obtained: 569.2937
Current minimum: 565.8574
Iteration No: 102 started. Searching for the next optimal point.
[832, 10, 'log2', 4, 3, False]
Iteration No: 102 ended. Search finished for the next optimal point.
Time taken: 3.5988
Function value obtained: 574.9826
Current minimum: 565.8574
Iteration No: 103 started. Searching for the next optimal point.
[522, 10, 'log2', 4, 3, False]
Iteration No: 103 ended. Search finished for the next optimal point.
Time taken: 3.4522
Function value obtained: 571.9806
Current minimum: 565.8574
Iteration No: 104 started. Searching for the next optimal point.
[859, 10, 'log2', 4, 3, False]
Iteration No: 104 ended.



Iteration No: 107 ended. Search finished for the next optimal point.
Time taken: 4.6609
Function value obtained: 571.3573
Current minimum: 565.8574
Iteration No: 108 started. Searching for the next optimal point.
[978, 10, 'log2', 4, 3, False]
Iteration No: 108 ended. Search finished for the next optimal point.
Time taken: 4.7156
Function value obtained: 578.0823
Current minimum: 565.8574
Iteration No: 109 started. Searching for the next optimal point.
[116, 10, 'log2', 4, 3, False]
Iteration No: 109 ended. Search finished for the next optimal point.
Time taken: 2.1449
Function value obtained: 563.9186
Current minimum: 563.9186
Iteration No: 110 started. Searching for the next optimal point.
[102, 15, 'log2', 2, 3, False]
Iteration No: 110 ended. Search finished for the next optimal point.
Time taken: 2.9302
Function value obtained: 574.2480
Current minimum: 563.9186
Iteration No: 111 started. Searching for the next optimal point.
[167, 11, 'log2', 4, 3, True]
Iteration No: 111 ended. 