In [1]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import json
import numpy as np
import base64
import os
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from sklearn import model_selection
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.inspection import permutation_importance
import shap
from statannot import add_stat_annotation
from xgboost.sklearn import XGBClassifier
from sklearn.metrics import roc_auc_score

In [127]:
# read data in
cytof_freq = pd.read_csv("../../data/TMA36_project/CyTOF/processed/Data_paper2/both/cytof_freq.csv", index_col=0)
cytof_medianprot = pd.read_csv("../../data/TMA36_project/CyTOF/processed/Data_paper2/both/cytof_medianprot.csv", index_col=0)
all_vars_scaled = pd.read_csv("../../data/TMA36_project/data_integration/cytof_rna_hm.csv", index_col=0)
all_vars_raw = pd.read_csv("../../data/TMA36_project/data_integration/cytof_rna_hm_raw.csv", index_col=0)
clusters_patients = pd.read_csv("../../data/TMA36_project/data_integration/clusters_patients.csv", index_col=0)
clusters_features = pd.read_csv("../../data/TMA36_project/data_integration/clusters_features.csv", index_col=0)
cde = pd.read_csv("../../data/TMA36_project/CDE/CDE_TMA36_2021SEPT21_DR_MF.csv", index_col=1)
rna_xcell = pd.read_csv("../../data/TMA36_project/RNA_Seq/deconvolution/output/rna_only/XCELL/xCell_rnaseq_tpm_xCell_1212060320.txt", index_col=0, sep="\t")
rna_sampleinfo = pd.read_csv("../../data/TMA36_project/RNA_Seq/processed/rnaseq_batchinfo.csv")

In [128]:
df_model = all_vars_scaled.T
pd.options.display.max_columns = 100
df_model.head()

Unnamed: 0,11938,13376,13436,8356,12994,12929,12924,13622,13771,13651,13074,11817,13536,11906,13276,13207,13317,12915,13769,11855,11851,11538,12889,12931,11813,11646,11759,13014,14855,11952,11561,11886,13724,14958,12281,12323,14955,15001,14048,15224,14965,15325,14962,15187,15506,14301,13538,15326,15569,14610,13988,13155,15083,11652,15002,12546,12890,15467,15741
ECC_3,0.1868,-0.631154,0.005564,-0.134184,-0.680289,-0.54291,3.58049,-0.583956,2.753695,2.435661,-0.680436,-0.287625,-0.095565,0.460352,-0.683008,-0.679394,-0.461542,2.104494,2.001386,-0.121782,-0.402633,-0.574554,-0.286718,-0.403456,1.678489,-0.506544,-0.259214,-0.517432,-0.347619,2.038249,1.311233,0.866032,-0.635349,-0.50471,0.361352,-0.49977,-0.60165,-0.362693,-0.348633,-0.647818,-0.581567,-0.632856,-0.509335,-0.438699,-0.648579,-0.679033,-0.292368,-0.35673,-0.6032,-0.3768,-0.457914,-0.683553,-0.538951,-0.335463,-0.134634,1.452327,-0.677761,-0.473312,-0.334732
ECC_5,0.307454,3.082364,2.927108,-0.507525,0.511481,-0.41101,-0.289044,-0.230105,-0.557147,-0.062763,-0.468206,2.868349,-0.171945,-0.049849,0.596095,-0.08711,-0.264221,-0.579679,-0.395031,-0.521292,-0.044152,0.024264,2.154958,4.03678,-0.44882,-0.575623,-0.600969,-0.340768,-0.006732,-0.306953,-0.538051,-0.502654,0.068645,-0.436453,-0.303971,-0.247562,-0.593408,-0.396968,0.061803,-0.56931,-0.425139,-0.392617,0.139315,-0.629383,-0.640711,-0.574171,1.125695,-0.578024,-0.243788,-0.223385,-0.600951,-0.192706,-0.503454,-0.590068,-0.454688,-0.572922,-0.432308,0.193056,-0.535733
fmes_3,0.30759,-0.496758,-0.671509,-0.020063,-0.667603,0.619594,-0.589446,-0.458566,-0.508911,0.182973,-0.668572,-0.242877,0.765386,0.884826,-0.662929,-0.674882,-0.649432,-0.39429,0.460608,-0.408942,0.528974,1.436831,-0.274457,-0.280292,-0.198217,1.641003,-0.128624,-0.503614,-0.382613,-0.509255,-0.216284,3.900234,-0.043778,4.973845,-0.497254,-0.54629,-0.567008,-0.446554,-0.329586,-0.374933,-0.283016,-0.431958,-0.3121,-0.551555,-0.606578,-0.615426,0.454904,-0.64166,0.319358,0.670903,-0.245634,-0.677165,0.339833,0.764809,0.107356,-0.160998,-0.66674,-0.308526,-0.444131
OtherI_4,0.656051,-0.588162,-0.634426,6.463594,-0.754413,0.889071,-0.334573,-0.617773,-0.159927,0.004934,-0.770057,0.095343,0.042088,0.004777,-0.75637,-0.714602,-0.582119,-0.396937,-0.091827,-0.486931,0.235172,0.036656,0.229782,0.241448,0.521,0.595265,0.679387,-0.708688,0.607122,1.068751,0.193798,0.563759,0.059703,0.685431,-0.349649,-0.457279,-0.540823,-0.354585,-0.414091,-0.343938,-0.210919,-0.525598,-0.29554,-0.53837,-0.498022,-0.690175,0.024667,-0.601079,-0.032184,-0.425829,0.347124,-0.768894,0.451955,1.230099,0.371382,-0.293839,-0.771624,0.029977,-0.619095
HLA DR,0.908159,-1.394724,-1.396901,2.0768,-1.41976,-0.21461,1.044602,-1.137505,1.017268,1.216033,-1.442311,-0.004341,0.274479,0.323914,-1.580919,-1.491276,-1.333701,1.275205,1.308717,0.254624,-0.010199,0.824075,-0.133009,-0.038215,1.50542,0.560615,-0.420794,-1.533449,-0.595232,1.059516,0.777863,1.556205,-0.667249,0.958461,1.350706,-0.218287,-0.151409,-0.438091,-0.330587,-0.56404,-0.478065,-1.212089,0.602124,-0.279069,0.903941,-1.00308,0.23984,-1.203248,0.566487,-0.057051,1.024275,-1.600645,0.696642,0.636704,0.726871,0.486592,-1.59288,-1.123442,0.89004


# Model training

In [129]:
# train-test-split
X_train, X_test, y_train, y_test = train_test_split(df_model, clusters_features['cluster'], test_size=0.2, random_state=66)

In [130]:
y_train

Anchoring of the basal body to the plasma membrane    4
AUTO_LARGEST_PLANAR_ORTHO_DIAMETER_MM                 2
Collagen formation                                    3
GPCR ligand binding                                   1
Signaling by Interleukins                             1
                                                     ..
Semaphorin interactions                               1
Cilium Assembly                                       4
SOLID_VOLUME_ML                                       2
COPI dependent Golgi to ER retrograde traffic         4
CORONAL_SHORT_AXIS_MM                                 2
Name: cluster, Length: 240, dtype: int64

In [131]:
# Hyperparameter tuning
rfc = RandomForestClassifier(random_state = 1)
n_estimators = [100, 300, 500, 800, 1200]
max_depth = [5, 8, 15, 25, 30]
min_samples_split = [2, 5, 10, 15, 100]
min_samples_leaf = [1, 2, 5, 10] 

hyperF = dict(n_estimators = n_estimators, max_depth = max_depth,  
              min_samples_split = min_samples_split, 
             min_samples_leaf = min_samples_leaf)

gridF = GridSearchCV(rfc, hyperF, 
                     scoring = 'f1_macro',
                     cv = 3, 
                     verbose = True, 
                     n_jobs = 10)
bestF = gridF.fit(X_train, y_train)
bestF

Fitting 3 folds for each of 500 candidates, totalling 1500 fits


GridSearchCV(cv=3, estimator=RandomForestClassifier(random_state=1), n_jobs=10,
             param_grid={'max_depth': [5, 8, 15, 25, 30],
                         'min_samples_leaf': [1, 2, 5, 10],
                         'min_samples_split': [2, 5, 10, 15, 100],
                         'n_estimators': [100, 300, 500, 800, 1200]},
             scoring='f1_macro', verbose=True)

In [140]:
# random forest model creation
rfc = RandomForestClassifier(random_state = 1, 
                             max_depth = bestF.best_params_['max_depth'],
                             n_estimators = bestF.best_params_['n_estimators'], 
                             min_samples_split = bestF.best_params_['min_samples_split'], 
                             min_samples_leaf = bestF.best_params_['min_samples_leaf'])
rfc.fit(X_train,y_train)
# predictions
rfc_predict = rfc.predict(X_test)
# Probabilities for each class
rfc_probs = rfc.predict_proba(X_test)[:, 1]

In [141]:
rfc_probs

array([0.        , 0.055     , 0.5525    , 0.        , 0.05      ,
       0.65316667, 0.        , 0.02839286, 0.00333333, 0.91333333,
       0.        , 0.02      , 0.01      , 0.063     , 0.        ,
       0.00666667, 0.        , 0.09666667, 0.        , 0.01      ,
       0.0125    , 0.01      , 0.        , 0.        , 0.02114286,
       0.90466667, 0.71727381, 0.02      , 0.        , 0.        ,
       0.95566667, 0.028     , 0.98666667, 0.        , 0.        ,
       0.        , 0.01      , 0.992     , 0.12083333, 0.04566667,
       0.07658333, 0.        , 0.        , 0.01880952, 0.        ,
       0.00658333, 0.0125    , 0.        , 0.        , 0.02      ,
       0.98666667, 0.        , 0.05666667, 0.        , 0.        ,
       0.46916667, 0.02      , 0.96130952, 0.91333333, 0.035     ,
       0.82319048])

In [142]:
print("=== Confusion Matrix ===")
print(confusion_matrix(y_test, rfc_predict))
print('\n')
print("=== Classification Report ===")
print(classification_report(y_test, rfc_predict))
print('\n')

=== Confusion Matrix ===
[[23  0  1  0]
 [ 0 13  0  0]
 [ 2  0  6  0]
 [ 0  0  0 16]]


=== Classification Report ===
              precision    recall  f1-score   support

           1       0.92      0.96      0.94        24
           2       1.00      1.00      1.00        13
           3       0.86      0.75      0.80         8
           4       1.00      1.00      1.00        16

    accuracy                           0.95        61
   macro avg       0.94      0.93      0.93        61
weighted avg       0.95      0.95      0.95        61





In [143]:
rna_sampleinfo.drop('Batch', axis=1, inplace=True)
dict_names = rna_sampleinfo.set_index('Vantage_ID').squeeze().to_dict()
rna_xcell.rename(columns=dict_names, inplace=True)

In [144]:
cytof_freq = cytof_freq.T
cytof_medianprot = cytof_medianprot.T

In [145]:
# Select only patients included in training
rna_xcell = rna_xcell[rna_xcell.columns.intersection(df_model.columns)][df_model.columns]
cytof_freq = cytof_freq[cytof_freq.columns.intersection(df_model.columns)][df_model.columns]
cytof_medianprot = cytof_medianprot[cytof_medianprot.columns.intersection(df_model.columns)][df_model.columns]

In [183]:
def predict_ft(model, data_to_predict, cutoff = 0.5):
    data_to_predict = (data_to_predict - data_to_predict.mean()) / data_to_predict.std()
    res = pd.concat([pd.DataFrame(model.predict(data_to_predict), columns= ['pred_cluster']),
           pd.DataFrame(model.predict_proba(data_to_predict), columns= ['1', '2', '3', '4'])],
         axis=1)
    res.index = data_to_predict.index
    res=res.loc[(res['1'] >cutoff) | (res['2'] >cutoff) | (res['3'] >cutoff) | (res['4'] >cutoff) ]
    return res

In [181]:
predict_ft(rfc, rna_xcell, cutoff = 0.5)

Unnamed: 0,pred_cluster,1,2,3,4
Basophils,1,0.513333,0.340476,0.060833,0.085357
CD4+ T-cells,1,0.689167,0.220667,0.029,0.061167
CD4+ Tem,1,0.6075,0.267833,0.036667,0.088
CD4+ memory T-cells,1,0.724167,0.177,0.023333,0.0755
CD8+ T-cells,1,0.625833,0.242,0.044,0.088167
CD8+ Tcm,1,0.680833,0.209,0.024,0.086167
CD8+ naive T-cells,1,0.574667,0.322667,0.0575,0.045167
CMP,2,0.2755,0.504119,0.154167,0.066214
Class-switched memory B-cells,1,0.5375,0.315333,0.050667,0.0965
DC,2,0.388333,0.504952,0.062833,0.043881


In [188]:
predict_ft(rfc, cytof_freq, cutoff = 0.5)

Unnamed: 0,pred_cluster,1,2,3,4
Stroma,1,0.572,0.217333,0.150333,0.060333
Endothelial,2,0.297833,0.537119,0.114667,0.050381
Fib_Mesenchymal,1,0.607667,0.218333,0.116333,0.057667
ECC_4,1,0.5485,0.311833,0.096333,0.043333
ECC_3,1,0.7645,0.163833,0.0175,0.054167
ECC_5,2,0.2695,0.564619,0.107167,0.058714
ECC_2,2,0.281929,0.506679,0.151071,0.060321
ECC_1,2,0.347833,0.524786,0.091667,0.035714
fmes_1,2,0.334167,0.509952,0.108,0.047881
fmes_3,1,0.724167,0.1865,0.071833,0.0175


In [190]:
predict_ft(rfc, cytof_medianprot, cutoff = 0.5)

Unnamed: 0,pred_cluster,1,2,3,4
EpCAM,1,0.5745,0.355595,0.031714,0.03819
CD4,1,0.528667,0.283583,0.107381,0.080369
Cytokeratin,1,0.510833,0.286333,0.089,0.113833
MET,1,0.513548,0.302,0.121452,0.063
CD90,2,0.256667,0.504952,0.185167,0.053214
