In [1]:
import re

import pandas as pd
import numpy as np
import sklearn
from sklearn.ensemble import GradientBoostingRegressor, RandomForestRegressor
from sklearn.model_selection import cross_val_score, KFold
from sklearn.inspection import permutation_importance
from sklearn.metrics import mean_absolute_error
from scipy import stats
import seaborn as sns
import matplotlib.pyplot as plt

import utils.dev_config as dev_conf
import utils.preprocessing as prep
import utils.optimization as opt
import utils.feature_selection as feat_sel

In [2]:
# Project directory locations are read from ../dev_paths.txt.
dirs = dev_conf.get_dev_directories("../dev_paths.txt")

# Candidate cohorts; `dset_idx` picks one of these by position.
unified_dsets = [
    "unified_cervical_data",
    "unified_uterine_data",
    "unified_uterine_endometrial_data",
]

# Path to the human matrisome master list.
matrisome_list = f"{dirs.data_dir}/matrisome/matrisome_hs_masterlist.tsv"

In [3]:
# Position in `unified_dsets` of the cohort to analyse (0 = "unified_cervical_data").
dset_idx = 0

In [4]:
# Load the matrisome master list.
# NOTE(review): `matrisome_df` is not referenced by any cell shown here; drop this
# load if nothing later in the notebook needs it.
matrisome_df = prep.load_matrisome_df(matrisome_list)

In [5]:
# Single seed for the whole notebook.
seed = 123
# Seed at construction (not just at the later explicit re-seed) so that anything
# drawing from `rand` is deterministic on every run, regardless of cell order.
rand = np.random.RandomState(seed)

# Load and filter survival data

In [6]:
# Mapping handed to the survival loader: "Alive" -> 0, "Dead" -> 1.
event_code = {"Alive": 0, "Dead": 1}
# Clinical covariates kept next to the outcome columns.
covariate_cols = ["figo_stage", "age_at_diagnosis", "race", "ethnicity"]
# Outcome columns.
dep_cols = ["vital_status", "survival_time"]
# Columns one-hot encoded below; "figo_chr" is presumably created by
# prep.decode_figo_stage (TODO confirm).
cat_cols = ["race", "ethnicity", "figo_chr"]

survival_path = f"{dirs.data_dir}/{unified_dsets[dset_idx]}/survival_data.tsv"
survival_df = prep.load_survival_df(survival_path, event_code)

In [7]:
# Keep complete rows, then only patients with vital_status == 1 (presumably "Dead"
# via event_code), so survival_time is an observed time rather than a censored one.
# vital_status is constant after that filter and is dropped.
survival_subset_df = survival_df[["sample_name"] + dep_cols + covariate_cols].dropna()
filtered_survival_df = (
    prep.decode_figo_stage(survival_subset_df, to="c")
    .query("vital_status == 1")
    .drop(columns=["vital_status"])
    .pipe(pd.get_dummies, columns=cat_cols)
    .reset_index(drop=True)
    # Replace spaces in column labels (e.g. from one-hot category values) with underscores.
    .rename(columns=lambda col: col.replace(" ", "_"))
)

print(filtered_survival_df.shape)

(66, 16)


# Load normalized matrisome count data

In [8]:
counts_path = f"{dirs.data_dir}/{unified_dsets[dset_idx]}/norm_matrisome_counts.tsv"
norm_matrisome_counts_df = pd.read_csv(counts_path, sep='\t')

# Keep only the samples that survived the survival-data filter, then transpose so
# rows are samples and columns are genes.
kept_cols = ["geneID", *filtered_survival_df.sample_name]
norm_filtered_matrisome_counts_t_df = prep.transpose_df(
    norm_matrisome_counts_df[kept_cols], "geneID", "sample_name"
)

print(norm_filtered_matrisome_counts_t_df.shape)

(66, 1009)


# Join survival and count data

In [9]:
# Inner-join (the merge default) outcomes/covariates with expression on sample_name,
# then index rows by sample.
joined_df = filtered_survival_df.merge(
    norm_filtered_matrisome_counts_t_df, on="sample_name"
).set_index("sample_name")

print(joined_df.shape)

(66, 1023)


# Build models

In [10]:
# One row per tuned model. Every column except the last is a GradientBoostingRegressor
# keyword argument; the last column is presumably the tuning score (TODO confirm).
gbr_h_param_df = pd.read_csv(f"{unified_dsets[dset_idx]}_opt_gbr_h_params.tsv", sep="\t")

# scikit-learn 1.0 renamed loss="lad" to "absolute_error" and 1.2 removed "lad".
# Pick the name the installed version understands (same L1 objective either way).
sklearn_version = tuple(int(part) for part in re.findall(r"\d+", sklearn.__version__)[:2])
gbr_loss = "absolute_error" if sklearn_version >= (1, 0) else "lad"

# to_dict(orient="records") keeps each column's own dtype. An .iloc row slice would
# upcast a mixed int/float row to float (e.g. n_estimators=100.0), which
# scikit-learn can reject for integer-valued parameters.
gbrs = [
    GradientBoostingRegressor(**h_params, loss=gbr_loss, random_state=rand)
    for h_params in gbr_h_param_df.iloc[:, :-1].to_dict(orient="records")
]

In [11]:
# One row per tuned model. Every column except the last is a RandomForestRegressor
# keyword argument; the last column is presumably the tuning score (TODO confirm).
rfr_h_param_df = pd.read_csv(f"{unified_dsets[dset_idx]}_opt_rfr_h_params.tsv", sep="\t")

# to_dict(orient="records") keeps each column's own dtype. An .iloc row slice would
# upcast a mixed int/float row to float (e.g. n_estimators=100.0), which
# scikit-learn can reject for integer-valued parameters.
rfrs = [
    RandomForestRegressor(**h_params, random_state=rand)
    for h_params in rfr_h_param_df.iloc[:, :-1].to_dict(orient="records")
]

# Collect cross-validated feature permutation results

In [12]:
def collect_feature_perm_results(models, x_df, y_df, r, gene_cols):
    """Run 5-fold cross-validated permutation importance for every model.

    Parameters
    ----------
    models : iterable of estimators
        Each is passed to ``opt.cv_permutation_importance`` with
        ``"neg_mean_absolute_error"`` scoring.
    x_df, y_df :
        Features and target, forwarded unchanged.
    r :
        Random state forwarded to ``opt.cv_permutation_importance``.
    gene_cols :
        Forwarded to ``feat_sel.gather_perm_res`` (presumably the gene feature
        columns to report).

    Returns
    -------
    tuple of three lists, one entry per model:
        - the per-feature permutation importance averaged over every fold's
          importance matrix,
        - the reference scores returned by ``cv_permutation_importance``,
        - the ``gather_perm_res`` frame, with its ``mean_imp`` and
          ``score_pct_improvement`` columns suffixed by the model's index.
    """
    mean_importances = []
    ref_score_lists = []
    res_frames = []

    for i, model in enumerate(models):
        fold_results, fold_ref_scores = opt.cv_permutation_importance(
            model, x_df, y_df, "neg_mean_absolute_error", k=5, random_state=r
        )
        # Place every fold's importance matrix side by side (presumably sklearn
        # permutation_importance results, shaped (n_features, n_repeats)) and
        # average per feature.
        stacked = np.concatenate([fold.importances for fold in fold_results], axis=1)
        feature_means = np.mean(stacked, axis=1)

        mean_importances.append(feature_means)
        ref_score_lists.append(fold_ref_scores)

        gathered = feat_sel.gather_perm_res(x_df, feature_means, np.mean(fold_ref_scores), gene_cols)
        suffixed = {"mean_imp": f"mean_imp_{i}", "score_pct_improvement": f"score_pct_improvement_{i}"}
        res_frames.append(gathered.rename(columns=suffixed))

    return mean_importances, ref_score_lists, res_frames


def merge_perm_results(perm_res_dfs):
    """Inner-join per-model permutation results on geneID and add consensus stats.

    Parameters
    ----------
    perm_res_dfs : list of DataFrames
        Each has a ``geneID`` column plus per-model columns whose names contain
        ``mean_imp`` (e.g. ``mean_imp_0``). Must be non-empty.

    Returns
    -------
    DataFrame
        Genes present in every input, with three added columns computed across
        all ``mean_imp`` columns: ``consensus_imp_mean``, ``consensus_imp_std``
        (sample std, ddof=1) and ``consensus_imp_cv`` (std / mean).
    """
    merged = perm_res_dfs[0]
    for other in perm_res_dfs[1:]:
        merged = merged.merge(other, on="geneID", how="inner")

    importance_cols = merged.filter(regex="mean_imp")
    imp_mean = importance_cols.mean(axis=1)
    imp_std = importance_cols.std(axis=1)

    return merged.assign(
        consensus_imp_mean=imp_mean,
        consensus_imp_std=imp_std,
        consensus_imp_cv=imp_std / imp_mean,
    )

In [13]:
# Re-seed the shared RandomState right before it is consumed, so the shuffle and
# everything downstream that receives `rand` (the models' random_state and the
# CV permutation runs) reproduce when run from this cell onward.
rand.seed(seed)
x_df, y_df = prep.shuffle_data(joined_df, rand)

## GBR

In [14]:
# Gene feature columns: every column of the transposed count table after the first
# (presumably the sample_name key column).
matrisome_gene_cols = norm_filtered_matrisome_counts_t_df.columns[1:]
gbr_mean_perm_res, gbr_ref_scores, gbr_perm_res_dfs = collect_feature_perm_results(
    models=gbrs, x_df=x_df, y_df=y_df, r=rand, gene_cols=matrisome_gene_cols
)

In [15]:
# Genes common to all GBRs, with consensus mean/std/CV of their importances.
gbr_merge_df = merge_perm_results(perm_res_dfs=gbr_perm_res_dfs)

## Do these GBRs do any better than just guessing the median?

In [16]:
# Baseline: MAE of always predicting the median survival time. The median is taken
# over the full data set (not per CV fold); as the in-sample MAE-optimal constant,
# it is a slightly optimistic, hard-to-beat reference.
survival_times = y_df.survival_time.to_numpy()
baseline = mean_absolute_error(
    survival_times, np.full(survival_times.shape[0], np.median(survival_times))
)
# Scoring was "neg_mean_absolute_error", so negate to get each model's mean CV MAE.
gbr_model_perf = -np.array(gbr_ref_scores).mean(axis=1)

In [17]:
# Relative change in MAE vs. the median baseline: negative = the model beats the
# baseline; positive = it does worse than always predicting the median.
gbr_pct_change = (gbr_model_perf - baseline) / baseline
print(f"Model % change: {gbr_pct_change}")
print(f"Avg. model % change: {gbr_pct_change.mean()}")

Model % change: [ 0.0335534   0.01799931 -0.017535    0.04377563  0.00991889]
Avg. model % change: 0.017542445406661698


In [18]:
# Most important genes (by consensus mean importance) first.
gbr_merge_df.sort_values(by="consensus_imp_mean", ascending=False)

Unnamed: 0,geneID,mean_imp_0,score_pct_improvement_0,mean_imp_1,score_pct_improvement_1,mean_imp_2,score_pct_improvement_2,mean_imp_3,score_pct_improvement_3,mean_imp_4,score_pct_improvement_4,consensus_imp_mean,consensus_imp_std,consensus_imp_cv
701,EGFL7,1.274956,0.237987,2.344971,0.444407,4.612491,0.905752,5.297876,0.979232,1.162500,0.222074,2.938559,1.913291,0.651098
763,ADAMTS14,4.013052,0.749088,1.514602,0.287040,3.866497,0.759262,0.459347,0.084903,1.973276,0.376957,2.365355,1.539375,0.650801
189,IL1B,3.239555,0.604705,0.594968,0.112755,2.144449,0.421104,-0.079658,-0.014724,2.743965,0.524183,1.728656,1.417904,0.820235
568,FAM20C,1.833005,0.342154,3.078460,0.583414,1.189653,0.233611,0.816079,0.150840,0.422618,0.080733,1.467963,1.039484,0.708113
311,LOXL2,-0.538700,-0.100555,0.214775,0.040703,-0.490758,-0.096370,2.578367,0.476572,4.670944,0.892296,1.286926,2.278266,1.770316
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
140,ANXA9,-1.713522,-0.319851,-0.003250,-0.000616,-1.204686,-0.236563,-1.064808,-0.196814,-1.293689,-0.247135,-1.055991,0.636258,-0.602522
770,FCN2,-0.457108,-0.085325,-2.097356,-0.397480,-3.491510,-0.685626,0.634042,0.117193,0.000000,0.000000,-1.082386,1.683946,-1.555772
720,SERPINB12,-4.669633,-0.871647,-0.141241,-0.026767,-1.601803,-0.314545,-0.184214,-0.034049,0.919151,0.175587,-1.135548,2.169252,-1.910313
124,IFNG,-2.104490,-0.392830,-2.352906,-0.445911,-0.175947,-0.034551,-1.258233,-0.232565,0.000000,0.000000,-1.178315,1.076687,-0.913751


## RFR

In [19]:
# Gene feature columns: every column of the transposed count table after the first
# (presumably the sample_name key column).
matrisome_gene_cols = norm_filtered_matrisome_counts_t_df.columns[1:]
rfr_mean_perm_res, rfr_ref_scores, rfr_perm_res_dfs = collect_feature_perm_results(
    models=rfrs, x_df=x_df, y_df=y_df, r=rand, gene_cols=matrisome_gene_cols
)

In [20]:
# Genes common to all RFRs, with consensus mean/std/CV of their importances.
rfr_merge_df = merge_perm_results(perm_res_dfs=rfr_perm_res_dfs)

## Do these RFRs do any better than just guessing the median?

In [21]:
# Baseline: MAE of always predicting the median survival time. The median is taken
# over the full data set (not per CV fold); as the in-sample MAE-optimal constant,
# it is a slightly optimistic, hard-to-beat reference.
survival_times = y_df.survival_time.to_numpy()
baseline = mean_absolute_error(
    survival_times, np.full(survival_times.shape[0], np.median(survival_times))
)
# Scoring was "neg_mean_absolute_error", so negate to get each model's mean CV MAE.
rfr_model_perf = -np.array(rfr_ref_scores).mean(axis=1)

In [22]:
# Relative change in MAE vs. the median baseline: negative = the model beats the
# baseline; positive = it does worse than always predicting the median.
rfr_pct_change = (rfr_model_perf - baseline) / baseline
print(f"Model % change: {rfr_pct_change}")
print(f"Avg. model % change: {rfr_pct_change.mean()}")

Model % change: [0.09025692 0.14862811 0.10603461 0.06880937 0.13704611]
Avg. model % change: 0.11015502352614444


In [23]:
# Most important genes (by consensus mean importance) first.
rfr_merge_df.sort_values(by="consensus_imp_mean", ascending=False)

Unnamed: 0,geneID,mean_imp_0,score_pct_improvement_0,mean_imp_1,score_pct_improvement_1,mean_imp_2,score_pct_improvement_2,mean_imp_3,score_pct_improvement_3,mean_imp_4,score_pct_improvement_4,consensus_imp_mean,consensus_imp_std,consensus_imp_cv
833,MUC4,0.850602,0.150518,1.378401,0.231519,1.129827,0.197076,1.128915,0.203775,1.918769,0.325563,1.281303,0.402325,0.313997
527,TNFSF10,0.536004,0.094848,0.552604,0.092817,1.137950,0.198493,2.731915,0.493126,0.292160,0.049572,1.050127,0.990264,0.942995
638,LGALS9C,0.649792,0.114984,1.162093,0.195188,0.256338,0.044713,1.457711,0.263125,1.144337,0.194163,0.934054,0.477111,0.510796
367,LAMB1,0.035714,0.006320,0.105250,0.017678,0.278090,0.048507,2.700786,0.487507,1.058968,0.179678,0.835762,1.119471,1.339462
13,ANGPT2,0.084429,0.014940,0.921257,0.154736,1.130410,0.197178,0.872174,0.157432,0.965252,0.163777,0.794704,0.408746,0.514337
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
425,THBS1,-0.360599,-0.063810,-0.329077,-0.055272,0.125110,0.021823,-1.152208,-0.207980,-1.956084,-0.331895,-0.734572,0.823081,-1.120490
506,CLEC4C,-0.054000,-0.009556,-0.812306,-0.136437,-0.457129,-0.079737,-0.845818,-0.152675,-1.523145,-0.258437,-0.738479,0.543218,-0.735590
160,THBS2,-0.715044,-0.126530,-0.527280,-0.088563,-0.245511,-0.042825,-0.882252,-0.159251,-1.556854,-0.264156,-0.785388,0.491828,-0.626222
460,IL26,0.074082,0.013109,-2.255935,-0.378912,0.417379,0.072804,-0.664633,-0.119970,-1.698015,-0.288107,-0.825425,1.138782,-1.379632
