In [1]:
from util.gen_utils import *
from util.ml_utils import *
from util.plot_utils import *

%load_ext autoreload
%autoreload 2
%matplotlib inline

# Load training data

In [2]:
# Discovery cohort: the three input CSVs all live under one directory.
path_prefix = '../../data/rnaseq_stanford_all/discovery/'

# Positional order expected by rnaseq_and_meta_data is preserved:
# sample-meta file, logCPM file, TMM file.
discovery_csvs = [
    path_prefix + csv_name
    for csv_name in ("sample_meta_postQC.csv",
                     "logCPM_postQC_RemovedBatch.csv",
                     "TMM_postQC.csv")
]

train_all = rnaseq_and_meta_data(
    *discovery_csvs,
    counts_index_cols = [0],
    is_logCPM_path_R = True,
)

In [3]:
# Build the boolean sample masks that split the discovery cohort into a training
# subset and a held-out subset, both limited to samples with ga_at_collection <= 16.

# Fraction of samples requested for the training split.
# Original rationale: all PE_STAN data was used in the DE analysis that drives
# feature selection, so all of it should be used for training.
# NOTE(review): 0.8 does not put "all" samples in training -- confirm whether the
# rationale or the value is stale.
training_frac = 0.8

masks = data_masks(train_frac = training_frac, seed = 1041, label_col = 'case') 
# Flag samples collected at or before 16 weeks of gestation.
masks.add_mask('is_collected_pre17wks', (train_all.meta.ga_at_collection <= 16))
# Training membership is sampled by data_masks, given the pre-17-week mask label and
# 'subject' as the blocking column (presumably so that all samples from one subject
# land on the same side of the split -- TODO confirm in data_masks).
masks.add_mask('is_training', masks.get_sampled_mask(train_all.meta, addtnl_mask_label = 'is_collected_pre17wks', blocking_col = 'subject'))
# Flag post-partum samples so they can be filtered out.
masks.add_mask('is_pp', (train_all.meta.is_pp == 1))

# Logical combinations. Judging by the keys read below, each call registers the AND
# of the two masks and of their negations (e.g. 'not_is_training_and_is_collected_pre17wks').
masks.add_mask_logical_and_combinations('is_training', 'is_collected_pre17wks')
masks.add_mask_logical_and_combinations('is_training', 'is_pp')

# Filter train_all to relevant samples
train_mask = masks.masks['is_training_and_is_collected_pre17wks']   # training, <= 16 weeks
val_mask = masks.masks['not_is_training_and_is_collected_pre17wks'] # held out, <= 16 weeks
# Training samples collected after 16 weeks that are not post-partum.
# NOTE(review): not referenced again in the visible part of this notebook.
train_late_mask = np.logical_and(masks.masks['is_training_and_not_is_collected_pre17wks'], 
                                 masks.masks['is_training_and_not_is_pp'])

# inplace = False returns filtered copies and leaves train_all untouched here.
train_split = train_all.filter_samples(train_mask, inplace = False)
val_split = train_all.filter_samples(val_mask, inplace = False)

# Called without inplace = False, so train_all itself is presumably narrowed to the
# union of the training and held-out early samples -- later cells rely on that.
train_all.filter_samples(np.logical_or(train_mask, val_mask))

# Feature selection and model training
* Start with the features identified in the differential expression (DE) analysis and estimate their logFC changes pre and post 17 weeks
* Filter the initial list based on:
    * Coefficient of variation (CV) cutoff [want genes that appear to have a stable logFC]
    * logFC cutoff [want genes that appear to be significantly different between PE and control]
* To choose an appropriate cutoff for each filter, run a parameter sweep with a logistic regression (LR) model
* After model selection, choose an appropriate classification threshold using the training data

In [4]:
# Grids for the feature-selection parameter sweep.

# Coefficient-of-variation cutoffs. A cutoff of 10 effectively disables the CV
# filter, i.e. CV is not considered during feature selection.
cv_cutoff_vals = np.array([0.5, 1, 2, 10])

# |logFC| cutoffs: 0.25, 0.5, 0.75, 1.0. Earlier visualisation of the data showed
# mean |logFC| < 1.0, so the sweep stops at 1.0.
logFC_cutoff_vals = np.linspace(0.25, 1.0, num = 4)

In [5]:
# Differential-expression results that seed feature selection.
de_PE = de_data(
    "out/de/DE_PEspecific_onlyGA_changes_timeToPE_w_covar_bmi_fsex_w_batch.csv",
    de_type = 'PE preg changes',
    alpha = 0.05,
    to_round = False,
)

# Preprocessing options shared by the training pipeline and every ML_data(...) call.
# "Stable" genes are the rows of the DE table whose adjusted p-value exceeds 0.99.
ml_data_kwargs = dict(
    to_norm_to_stable_genes = True,
    stable_genes = de_PE.de.loc[de_PE.de.adj_pval > 0.99].index,
    to_center = True,
    to_scale = True,
    impute_dropout = False,
)

# Label of the logFC column/group used for the <= 16 week samples.
lfc_col_name = 'Pre 17 weeks'

In [6]:
# Container for per-gene logFC (and its CV) of the significant DE genes.
# The group flag True in column 'pre17_weeks' is labelled with lfc_col_name.
# The cutoffs given here are only the initial values; training_pipeline sweeps them.
logFC_pre17_sig = logFC_data_by_group(
    de_PE.sig_genes,
    {True : lfc_col_name},
    logFC_cutoff = 0.5,
    CV_cutoff = 1.0,
    group_col = 'pre17_weeks',
)

# Expression of the significant genes in the training split, plus the training
# metadata with the training mask attached under the grouping column name.
sig_gene_logCPM = train_split.rnaseq.logCPM.loc[de_PE.sig_genes]
meta_with_group = train_split.meta.join(train_mask.rename('pre17_weeks'))

logFC_pre17_sig.get_logFC_and_CI_by_group(sig_gene_logCPM, meta_with_group)

Now calculating logFC for Pre 17 weeks
Now estimating logFC confidence interval for Pre 17 weeks
1000 resampling iterations completed
2000 resampling iterations completed
Identifying when during gestation we observe changes


In [7]:
# Sweep every CV / logFC cutoff combination, fitting on train_split and scoring on
# val_split. The result is read later through its 'model' and 'features' entries.
best_fit = training_pipeline(
    train_split,
    logFC_pre17_sig,
    [lfc_col_name],
    logFC_cutoffs_to_try = logFC_cutoff_vals,
    cv_cutoffs_to_try = cv_cutoff_vals,
    val_rnaseq_meta = val_split,
    **ml_data_kwargs,
)

So far - Best score = 0.82, Best val score = 0.98 with 13 features and CV cutoff = 0.50, logFC cutoff = 0.25
So far - Best score = 0.97, Best val score = 0.99 with 23 features and CV cutoff = 1.00, logFC cutoff = 0.25
So far - Best score = 0.99, Best val score = 1.00 with 18 features and CV cutoff = 10.00, logFC cutoff = 0.25
Best score = 0.99, Best val score = 1.00 with 18 features and CV cutoff = 10.00, logFC cutoff = 0.25


In [8]:
# Per-feature summary of the selected model: raw coefficient, its exponential
# (odds, rounded to 2 d.p.), and the logFC / CV recorded for that gene.
# The logFC/CV column is looked up through lfc_col_name rather than a repeated
# string literal, so the table stays in sync if the label is ever renamed.
fitted_coefs = best_fit['model'].coef_[0, :]  # first (only) row of the coefficient matrix

pd.DataFrame(index = best_fit['features'],
             data = {'coef' : fitted_coefs,
                     'odds' : np.exp(fitted_coefs).round(2),
                     'logFC' : logFC_pre17_sig.logFC.loc[best_fit['features'], lfc_col_name],
                     'CV' : logFC_pre17_sig.CV.loc[best_fit['features'], lfc_col_name]}
            )

Unnamed: 0_level_0,Unnamed: 1_level_0,coef,odds,logFC,CV
gene_name,gene_num,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
TRIM21,ENSG00000132109,-0.127229,0.88,-0.68,1.35
Y_RNA,ENSG00000201412,0.617053,1.85,1.42,0.61
LRRC58,ENSG00000163428,-0.236199,0.79,-0.29,5.62
NDUFV3,ENSG00000160194,-0.417266,0.66,-0.63,2.73
KIAA1109,ENSG00000138688,-0.119836,0.89,-0.55,1.51
MYLIP,ENSG00000007944,-0.067027,0.94,-1.1,0.95
USB1,ENSG00000103005,-0.160827,0.85,-0.54,1.3
RNF149,ENSG00000163162,-0.204563,0.82,-1.26,1.33
TFIP11,ENSG00000100109,-0.108378,0.9,-0.49,2.47
CAMK2G,ENSG00000148660,-0.436554,0.65,-0.48,1.15


# Find optimal threshold
* Using training data

In [9]:
# Model-ready matrix restricted to the selected features.
# NOTE(review): train_all was narrowed earlier to the union of the training and
# held-out early samples, so "training data" here appears to include both -- confirm.
train_data = ML_data(
    rnaseq_meta = train_all,
    features = best_fit['features'],
    y_col = 'case',
    group_col = 'subject',
    only_gene_name = True,
    **ml_data_kwargs,
)

# ROC statistics for the fitted model on that data; the 'fpr' and
# 'roc_curve_tshlds' entries are used to choose the operating threshold.
roc_aucs_pr_train = get_auc_roc_CI(best_fit['model'], train_data)

In [10]:
# Choose the operating threshold from the training ROC curve.
# Candidate points are those whose FPR rounds (to 1 d.p.) to 0.1, i.e. roughly
# 0.05-0.15; .min() keeps the lowest cutoff among them.
# NOTE(review): the selected point can therefore sit above a 10% FPR (the training
# confusion matrix shows 9/61 false positives) -- confirm this is intended.
_fpr_near_10pct = np.where(roc_aucs_pr_train['fpr'].round(1) == 0.1)
if _fpr_near_10pct[0].size == 0:
    # Without this guard .min() fails on an empty selection with an opaque message.
    raise ValueError("No ROC point has an FPR that rounds to 0.1; cannot select a threshold")
tshld_at_10pct_fpr = roc_aucs_pr_train['roc_curve_tshlds'][_fpr_near_10pct].min()

# Snap DOWN to a multiple of 0.05 (the `//` floors; it does not round to nearest).
# The data is noisy, so an extremely specific threshold would be false precision.
tshld_at_10pct_fpr = ((tshld_at_10pct_fpr*100 // 5) * 5)/100
tshld_at_10pct_fpr

0.35

# Check model performance
* Check with training data
* Check with independent dataset - Del Vecchio et al
* Check with qPCR dataset from separate samples

In [11]:
# Report ROC AUC, classification report and confusion matrix on the data the
# threshold was chosen on.
get_classification_results(
    'Training',
    best_fit['model'],
    train_data,
    threshold = tshld_at_10pct_fpr,
)

Training results:
ROC AUC = 0.99
Report:
              precision    recall  f1-score   support

           0       1.00      0.85      0.92        61
           1       0.73      1.00      0.84        24

    accuracy                           0.89        85
   macro avg       0.86      0.93      0.88        85
weighted avg       0.92      0.89      0.90        85

Confusion matrix:
[[52  9]
 [ 0 24]]



In [12]:
# Separate Stanford validation cohort (distinct from the held-out discovery split).
val_data_prefix = "../../data/rnaseq_stanford_all/val/"
val_meta_csv, val_htseq_csv, val_tmm_csv = (
    val_data_prefix + csv_name
    for csv_name in ("sample_meta_postQC.csv", "htseq_postQC.csv", "TMM_postQC.csv")
)
val_data = rnaseq_and_meta_data(val_meta_csv, val_htseq_csv, val_tmm_csv)

# Keep only samples collected at or before 16 weeks, matching the training window.
val_data.filter_samples((val_data.meta.ga_at_collection <= 16))

# Same features and preprocessing as training; the scaler fitted on train_data is
# passed in via fitted_scaler.
val_ml_data = ML_data(
    val_data,
    features = best_fit['features'],
    y_col = 'case',
    only_gene_name = True,
    fitted_scaler = train_data.fitted_scaler,
    **ml_data_kwargs,
)

get_classification_results(
    'Validation',
    best_fit['model'],
    val_ml_data,
    threshold = tshld_at_10pct_fpr,
)

Validation results:
ROC AUC = 0.71
Report:
              precision    recall  f1-score   support

           0       0.91      0.57      0.70        35
           1       0.29      0.75      0.41         8

    accuracy                           0.60        43
   macro avg       0.60      0.66      0.56        43
weighted avg       0.79      0.60      0.65        43

Confusion matrix:
[[20 15]
 [ 2  6]]



In [13]:
# External check on the Del Vecchio et al. cohort: load it, keep term == 1 samples,
# drop three listed sample IDs, then score it with the already-fitted model.
delvecchio_all = rnaseq_and_meta_data("../../data/delvecchio_data/sample_meta_w_ga_col.csv", 
                                 "../../data/delvecchio_data/htseq_merged.csv", 
                                 "../../data/delvecchio_data/TMM.csv",  mL_plasma = 0.2)

# Sample IDs excluded below. The variable name suggests gestational-hypertension
# cases -- TODO confirm against the cohort metadata.
gest_ht_id = ['SRR12214586', 'SRR12214596', 'SRR12214601']

# ML data should only contain samples <= 16 weeks, which per Del Vecchio et al are T1 samples
delvecchio_all.filter_samples((delvecchio_all.meta.term == 1))
# The second filter builds its mask from delvecchio_all.meta after the first filter
# (presumably filter_samples works in place when inplace is not given -- TODO confirm).
delvecchio_all.filter_samples(~delvecchio_all.meta.index.isin(gest_ht_id))

# Same features, preprocessing options and fitted scaler as the training data.
# NOTE(review): unlike the other ML_data calls in this notebook, only_gene_name = True
# is not passed here -- confirm this is intentional for this dataset's gene index.
delvecchio_ml_data = ML_data(delvecchio_all, y_col = 'case',
                             features = best_fit['features'],
                             **ml_data_kwargs,
                             fitted_scaler = train_data.fitted_scaler,
                            )

# First report: every remaining sample (PE vs. normotensive and other complications).
get_classification_results('Delvecchio [PE vs Normotensive AND Other APOs]', best_fit['model'], delvecchio_ml_data, threshold = tshld_at_10pct_fpr)

# Second report: only samples labelled as uncomplicated or as
# preeclampsia/gestational hypertension.
only_pe_v_NT = delvecchio_all.meta.loc[delvecchio_all.meta.complication_during_pregnancy.isin(['No Complications', 'Preeclampsia/gestational hypertension'])].index
delvecchio_ml_data_only_pe_v_NT = delvecchio_ml_data.filter_samples(only_pe_v_NT)

get_classification_results('Delvecchio [PE vs Normotensive]', best_fit['model'], delvecchio_ml_data_only_pe_v_NT, threshold = tshld_at_10pct_fpr)

Delvecchio [PE vs Normotensive AND Other APOs] results:
ROC AUC = 0.74
Report:
              precision    recall  f1-score   support

           0       0.89      1.00      0.94        17
           1       1.00      0.60      0.75         5

    accuracy                           0.91        22
   macro avg       0.95      0.80      0.85        22
weighted avg       0.92      0.91      0.90        22

Confusion matrix:
[[17  0]
 [ 2  3]]

Delvecchio [PE vs Normotensive] results:
ROC AUC = 0.80
Report:
              precision    recall  f1-score   support

           0       0.80      1.00      0.89         8
           1       1.00      0.60      0.75         5

    accuracy                           0.85        13
   macro avg       0.90      0.80      0.82        13
weighted avg       0.88      0.85      0.84        13

Confusion matrix:
[[8 0]
 [2 3]]



In [14]:
# Attach model predictions to the Del Vecchio metadata, then summarise the
# 'score' column per pregnancy complication.
w_pred_dv = add_pred_to_meta(
    delvecchio_all.meta,
    best_fit['model'],
    delvecchio_ml_data,
    threshold = tshld_at_10pct_fpr,
)

score_summary_spec = {'score' : ['median', 'mean', 'std']}
w_pred_dv.groupby('complication_during_pregnancy').agg(score_summary_spec)

Unnamed: 0_level_0,score,score,score
Unnamed: 0_level_1,median,mean,std
complication_during_pregnancy,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2
Chronic hypertension,0.183657,0.183657,0.126258
Gestational Diabetes,0.144472,0.149283,0.084637
No Complications,0.071676,0.098006,0.07991
Preeclampsia/gestational hypertension,0.373341,0.332159,0.250379


In [15]:
# GAPPs cohort: load, build the model matrix with the training scaler, and score.
gapps = rnaseq_and_meta_data(
    "../../data/gapps/sample_meta_postQC.csv",
    "../../data/gapps/htseq_postQC.csv",
    "../../data/gapps/TMM_postQC.csv",
    mL_plasma = 1.0,
)

# Same features and preprocessing as training; the scaler fitted on train_data is
# passed in via fitted_scaler.
gapps_ml_data = ML_data(
    gapps,
    features = best_fit['features'],
    y_col = 'case',
    only_gene_name = True,
    fitted_scaler = train_data.fitted_scaler,
    **ml_data_kwargs,
)

get_classification_results(
    'GAPPs',
    best_fit['model'],
    gapps_ml_data,
    threshold = tshld_at_10pct_fpr,
)

GAPPs results:
ROC AUC = 0.72
Report:
              precision    recall  f1-score   support

           0       0.78      0.69      0.73        61
           1       0.46      0.57      0.51        28

    accuracy                           0.65        89
   macro avg       0.62      0.63      0.62        89
weighted avg       0.68      0.65      0.66        89

Confusion matrix:
[[42 19]
 [12 16]]



# Save model and features

In [16]:
# Pickle the fitted model, the sample masks and the selected threshold.
for _artifact, _pkl_path in (
    (best_fit['model'], 'out/ml/fitted_model.pkl'),
    (masks, 'out/ml/train_data_masks.pkl'),
    (tshld_at_10pct_fpr, 'out/ml/selected_tshld.pkl'),
):
    write_pkl(_artifact, _pkl_path)

# Write the selected features and the stable-gene list as CSV, without the row index.
for _gene_index, _csv_path in (
    (best_fit['features'], 'out/ml/fitted_model_features.csv'),
    (ml_data_kwargs['stable_genes'], 'out/ml/stable_genes.csv'),
):
    _gene_index.to_frame().to_csv(_csv_path, index = False)