# Organ aging model training

In [1]:
import pandas as pd
import numpy as np
from sklearn import preprocessing
from scipy import stats
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from matplotlib.pyplot import figure
import seaborn as sns
from scipy.stats import zscore
from sklearn import metrics
import json
import pickle
import time
import random 
import os
from adjustText import adjust_text

%matplotlib inline


In [5]:
# Limit Intel MKL to a single thread for this kernel process.
# NOTE(review): presumably done so that the multiprocessing pools used for the
# bootstraps below do not oversubscribe the CPU with nested MKL threads — confirm.
import mkl
mkl.set_num_threads(1)


16

# Train Model

In [71]:
import multiprocessing as mp
from sklearn.linear_model import Lasso
from sklearn.model_selection import GridSearchCV

def Train_all_tissue_aging_model(md_hot_train, df_prot_train,
                                 tissue_plist_dict, seed_list, 
                                 performance_CUTOFF, train_cohort,
                                 norm, agerange, NPOOL=15):
    """Train bootstrapped LASSO age models for every tissue and collect their coefficients.

    For each tissue in ``tissue_plist_dict`` that has a non-empty protein list, the
    tissue's protein columns are z-scored with a freshly fitted ``StandardScaler``
    (which is pickled to disk), joined with the ``"Sex_F"`` column when the metadata
    has one, and handed to ``Bootstrap_train`` once per seed through a
    multiprocessing pool.

    Parameters
    ----------
    md_hot_train : pandas.DataFrame
        Training metadata. Must contain an ``"Age"`` column; ``"Sex_F"`` is added as
        a feature when present. Its index is aligned with ``df_prot_train`` by
        ``pd.concat(axis=1)``.
    df_prot_train : pandas.DataFrame
        Training protein matrix whose columns include every protein named in
        ``tissue_plist_dict``.
    tissue_plist_dict : dict
        Maps tissue name -> list of protein column names. Tissues with an empty
        list are skipped.
    seed_list : sequence
        One bootstrap model is trained per seed.
    performance_CUTOFF : float
        Forwarded to ``Bootstrap_train`` (alpha selection heuristic).
    train_cohort, norm, agerange : str
        Only used to build the output paths
        ``../data/ml_models/<train_cohort>/<agerange>/<norm>/<tissue>/``.
    NPOOL : int, default 15
        Number of worker processes per tissue.

    Returns
    -------
    pandas.DataFrame
        One row per (tissue, seed) with columns ``tissue``, ``BS_Seed``, ``alpha``,
        ``y_intercept`` followed by one coefficient column per feature. Frames of
        different tissues are concatenated with an outer join, so features that a
        tissue does not use are NaN in its rows.

    Raises
    ------
    ValueError
        Raised by ``pd.concat`` when no tissue has a non-empty protein list.
    """
    # final list for output
    all_coef_dfs = []

    # bootstrap clocks for each tissue
    for tissue, plist in tissue_plist_dict.items():
        if len(plist) == 0:
            continue
        print(tissue)

        # Subset to tissue proteins
        df_prot_train_tissue = df_prot_train[plist]

        # zscore
        scaler = StandardScaler()
        scaler.fit(df_prot_train_tissue)
        df_prot_train_tissue = pd.DataFrame(scaler.transform(df_prot_train_tissue),
                                            index=df_prot_train_tissue.index,
                                            columns=df_prot_train_tissue.columns)

        # save the scaler; the directory is created on demand so a fresh checkout
        # does not fail here, and the handle is closed deterministically
        path = '../data/ml_models/'+train_cohort+'/'+agerange+'/'+norm+'/'+tissue
        fn = '/'+train_cohort+'_'+agerange+'_based_'+tissue+'_protein_zscore_scaler.pkl'
        os.makedirs(path, exist_ok=True)
        with open(path+fn, 'wb') as scaler_file:
            pickle.dump(scaler, scaler_file)

        # add sex
        if "Sex_F" in list(md_hot_train.columns):
            df_X_train = pd.concat([md_hot_train[["Sex_F"]], df_prot_train_tissue], axis=1)
        else:
            df_X_train = df_prot_train_tissue.copy()
        df_Y_train = md_hot_train[["Age"]].copy()

        # Bootstrap training: one task per seed
        input_list = [(df_X_train, df_Y_train, train_cohort,
                       tissue, performance_CUTOFF, norm, agerange, seed) for seed in seed_list]
        pool = mp.Pool(NPOOL)
        try:
            coef_list = pool.starmap(Bootstrap_train, input_list)
            pool.close()
        except BaseException:
            # do not leave worker processes behind when a bootstrap fails or is interrupted
            pool.terminate()
            raise
        finally:
            pool.join()

        df_tissue_coef = pd.DataFrame(coef_list, columns=["tissue", "BS_Seed", "alpha", "y_intercept"]+list(df_X_train.columns))
        all_coef_dfs.append(df_tissue_coef)

    dfcoef = pd.concat(all_coef_dfs, join="outer")
    return dfcoef
  
    
def Bootstrap_train(df_X_train, df_Y_train, train_cohort,
              tissue, performance_CUTOFF, norm, agerange, seed):
    """Fit one bootstrap LASSO age model, pickle it, and return its coefficients.

    The training rows are resampled with replacement, a 5-fold grid search over
    100 log-spaced alphas (1e-3 .. 1e1) scores each alpha by negative MSE,
    ``Plot_and_pick_alpha`` chooses the alpha, and a final ``Lasso`` is refit on
    the whole bootstrap sample with that alpha.

    Parameters
    ----------
    df_X_train : pandas.DataFrame
        Feature matrix (rows = samples).
    df_Y_train : pandas.DataFrame
        Target frame with the same number of rows as ``df_X_train``.
        NOTE(review): the caller passes a single ``"Age"`` column; the
        ``intercept_[0]`` indexing below relies on that 2-D target — confirm
        before passing a Series.
    train_cohort, tissue, norm, agerange : str
        Used for the output path; ``tissue`` is also stored in the result.
    performance_CUTOFF : float
        Forwarded to ``Plot_and_pick_alpha``.
    seed : int
        Bootstrap seed; also part of the model file name.

    Returns
    -------
    list
        ``[tissue, seed, best_alpha, intercept, coef_0, coef_1, ...]`` with the
        coefficients in the column order of ``df_X_train``.
    """
    # setup: the same seed is used for both frames so that X and Y are resampled
    # with the same random state and stay paired row by row
    X_train_sample = df_X_train.sample(frac=1, replace=True, random_state=seed).to_numpy()
    Y_train_sample = df_Y_train.sample(frac=1, replace=True, random_state=seed).to_numpy()

    # LASSO grid search over alpha
    lasso = Lasso(random_state=0, tol=0.01, max_iter=5000)
    alphas = np.logspace(-3, 1, 100)
    tuned_parameters = [{'alpha': alphas}]
    n_folds = 5
    clf = GridSearchCV(lasso, tuned_parameters, cv=n_folds, scoring="neg_mean_squared_error", refit=False)
    clf.fit(X_train_sample, Y_train_sample)
    gsdf = pd.DataFrame(clf.cv_results_)
    best_alpha = Plot_and_pick_alpha(gsdf, performance_CUTOFF, plot=False)   #pick best alpha

    # Retrain with the selected alpha
    lasso = Lasso(alpha=best_alpha, random_state=0, tol=0.01, max_iter=5000)
    lasso.fit(X_train_sample, Y_train_sample)

    # SAVE MODEL: create the directory if needed and close the file handle explicitly
    model_dir = "../data/ml_models/"+train_cohort+"/"+agerange+"/"+norm+"/"+tissue
    savefp = model_dir+"/"+train_cohort+"_"+agerange+"_"+norm+"_lasso_"+tissue+"_seed"+str(seed)+"_aging_model.pkl"
    os.makedirs(model_dir, exist_ok=True)
    with open(savefp, 'wb') as model_file:
        pickle.dump(lasso, model_file)

    # SAVE coefficients
    coef_list = [tissue, seed, best_alpha, lasso.intercept_[0]] + list(lasso.coef_)

    return coef_list
    


def Plot_and_pick_alpha(gsdf, performance_CUTOFF, plot=True):
    """Choose a LASSO alpha from grid-search results and optionally plot the choice.

    The mean test score is min-max normalised with ``NormalizeData``. Among the
    rows whose finite-difference derivative of that normalised score is negative,
    the alpha whose normalised score lies closest to ``performance_CUTOFF`` is
    returned. If no row has a negative derivative, a message is printed and all
    rows are considered instead.

    Parameters
    ----------
    gsdf : pandas.DataFrame
        Must contain ``mean_test_score`` and ``param_alpha``. The columns
        ``mean_test_score_norm``, ``mean_test_score_norm_minus95``,
        ``mean_test_score_norm_minus95_abs`` and ``derivative`` are added in place.
        NOTE(review): the derivative is taken along row order, so rows are
        presumably expected in ascending alpha order — confirm against the caller.
    performance_CUTOFF : float
        Target value of the normalised score.
    plot : bool, default True
        When true, draw the normalised score against alpha on two panels with the
        selected alpha highlighted.

    Returns
    -------
    The selected ``param_alpha`` value (first match when several rows tie).
    """
    # normalised score and its distance from the requested performance level
    gsdf["mean_test_score_norm"] = NormalizeData(gsdf["mean_test_score"])
    gsdf["mean_test_score_norm_minus95"] = gsdf["mean_test_score_norm"] - performance_CUTOFF
    gsdf["mean_test_score_norm_minus95_abs"] = gsdf["mean_test_score_norm_minus95"].abs()

    # finite-difference derivative of the normalised score along row order
    # (fixed step of 0.1; only the sign is used below)
    gsdf["derivative"] = np.gradient(gsdf["mean_test_score_norm"].to_numpy(), 0.1)

    # prefer rows where performance is already falling; otherwise use every row
    candidates = gsdf.loc[gsdf["derivative"] < 0]
    if len(candidates) == 0:
        print('no alpha with derivative <0')
        candidates = gsdf
    smallest_gap = np.min(candidates["mean_test_score_norm_minus95_abs"])
    closest_rows = candidates.loc[candidates["mean_test_score_norm_minus95_abs"] == smallest_gap]
    best_alpha = list(closest_rows["param_alpha"])[0]

    # PLOT
    if plot:
        fig, axs = plt.subplots(1, 2, figsize=(7, 3))
        picked = gsdf.loc[gsdf["param_alpha"] == best_alpha]
        for panel in axs:
            sns.scatterplot(data=gsdf, x="param_alpha", y="mean_test_score_norm", ax=panel)
            sns.scatterplot(data=picked, x="param_alpha", y="mean_test_score_norm", ax=panel)
        # left panel zooms in on the neighbourhood of the selected alpha
        zoomed = axs[0]
        zoomed.set_xlim(-0.02, best_alpha + 0.1)
        zoomed.set_ylim(0.8, 1.05)
        zoomed.axvline(0.008)
        zoomed.axhline(performance_CUTOFF)
        plt.tight_layout()
        plt.show()
    return best_alpha
    
    
def NormalizeData(data):
    """Min-max rescale ``data`` so its minimum maps to 0 and its maximum to 1.

    Works element-wise on anything supporting NumPy arithmetic (arrays, Series).
    When all values are equal the range is zero, so the result is a 0/0 division.
    """
    lowest = np.min(data)
    value_range = np.max(data) - lowest
    return (data - lowest) / value_range
    


In [None]:
# Run configuration. train_cohort / agerange / norm only determine where models are
# saved: ../data/ml_models/<train_cohort>/<agerange>/<norm>/<tissue>/
NPOOL=16
agerange="HC"
performance_CUTOFF=0.95
# round() before int(): a bare int() truncates float products such as 0.58*100
# (57.99999999999999 -> 57), which would mislabel the model directory
norm="Zprot_perf"+str(int(round(performance_CUTOFF*100)))
train_cohort="KADRC"

        
#95% performance
# NOTE: md_hot_train, df_prot_train, tissue_plist_dict and bs_seed_list must already
# exist in the kernel; they are not defined in the cells shown in this notebook section.
start_time = time.time()
dfcoef = Train_all_tissue_aging_model(md_hot_train, #meta data dataframe with age and sex (binary) as columns
                                       df_prot_train, #protein expression dataframe with SeqIds as columns
                                       tissue_plist_dict, #tissue:protein list dictionary
                                       bs_seed_list, #bootstrap seeds
                                       performance_CUTOFF=performance_CUTOFF, #heuristic for model simplification
                                       NPOOL=NPOOL, #parallelize
                                       
                                       train_cohort=train_cohort, #these three variables for file naming
                                       norm=norm, 
                                       agerange=agerange, 
                                       )
# elapsed training time in minutes
print((time.time() - start_time)/60)


# Test model
Save the LOWESS regression model of predicted vs. actual age.

In [None]:
import dill
from matplotlib import colors as clr
import statsmodels.api as sm
from scipy.interpolate import interp1d


def Test_all_tissue_aging_model(md_hot_train, df_prot_train, 
                                md_hot_test, df_prot_test, 
                                md_train, md_test,
                                 tissue_plist_dict, colormap_dict, seed_list, 
                                 performance_CUTOFF, train_cohort, norm, agerange, NPOOL=15):  #these variables used to pull filepath for ML model
    """Apply the saved bootstrap models of every tissue and derive age-gap measures.

    For each tissue with a non-empty protein list this function

    1. loads the tissue's saved protein z-score scaler and transforms the train
       and test protein matrices with it,
    2. predicts age with every saved bootstrap model (``Bootstrap_test``) and
       averages the predictions over bootstraps,
    3. fits and saves a LOWESS curve of predicted vs. actual age
       (``Plot_age_prediction_return_lowess_function``),
    4. computes the residual from that curve (``dage_resid``), standardises it
       with a ``StandardScaler`` fitted on rows where ``Cohort2 == train_cohort``
       and shifts it so that a zero residual maps to zero, and saves that scaler,
    5. adds the per-cohort quantities from ``Calculate_dage_per_cohort_one_tissue``.

    Parameters
    ----------
    md_hot_train, md_hot_test : pandas.DataFrame
        Metadata used for the optional ``"Sex_F"`` feature; ``md_hot_train.index``
        also identifies training rows for the LOWESS fit.
    df_prot_train, df_prot_test : pandas.DataFrame
        Protein matrices containing the columns named in ``tissue_plist_dict``.
    md_train, md_test : pandas.DataFrame
        Metadata carried into the output. This function reads ``"Age"`` and
        ``"Cohort2"``; the helpers it calls read ``"Cohort"`` and
        ``"Diagnosis_group"``. Predictions are attached positionally to
        ``pd.concat([md_train, md_test])``, so the rows must be in the same order
        as the corresponding feature matrices.
    tissue_plist_dict : dict
        Tissue name -> list of protein columns; empty lists are skipped.
    colormap_dict : dict
        Tissue name -> colour, used by the plotting helper.
    seed_list : sequence
        Seeds of the saved bootstrap models to load.
    performance_CUTOFF : float
        Only forwarded to ``Bootstrap_test``.
    train_cohort, norm, agerange : str
        Locate ``../data/ml_models/<train_cohort>/<agerange>/<norm>/<tissue>/``.
    NPOOL : int, default 15
        Number of worker processes per tissue.

    Returns
    -------
    pandas.DataFrame
        Concatenation over tissues of the per-tissue frames, with a ``tissue``
        column identifying the source model.
    """
    # final list for output
    all_Y = []

    # bootstrap clocks for each tissue
    for tissue, plist in tissue_plist_dict.items():
        if len(plist) == 0:
            continue
        print(tissue)

        # every artefact of this tissue lives in one directory
        model_dir = "../data/ml_models/"+train_cohort+"/"+agerange+"/"+norm+"/"+tissue
        model_prefix = model_dir+"/"+train_cohort+"_"+agerange+"_"+norm+"_lasso_"+tissue

        # Subset to tissue proteins
        df_prot_train_tissue = df_prot_train[plist]
        df_prot_test_tissue = df_prot_test[plist]

        # zscore with the scaler saved at training time
        # (pickle is only safe here because the file is produced by this project)
        scaler_fp = model_dir+'/'+train_cohort+'_'+agerange+'_based_'+tissue+'_protein_zscore_scaler.pkl'
        with open(scaler_fp, 'rb') as scaler_file:
            scaler = pickle.load(scaler_file)

        df_prot_train_tissue = pd.DataFrame(scaler.transform(df_prot_train_tissue),
                                            index=df_prot_train_tissue.index,
                                            columns=df_prot_train_tissue.columns)
        df_prot_test_tissue = pd.DataFrame(scaler.transform(df_prot_test_tissue),
                                           index=df_prot_test_tissue.index,
                                           columns=df_prot_test_tissue.columns)

        # add sex
        if "Sex_F" in list(md_hot_train.columns):
            df_X_train = pd.concat([md_hot_train[["Sex_F"]], df_prot_train_tissue], axis=1)
            df_X_test = pd.concat([md_hot_test[["Sex_F"]], df_prot_test_tissue], axis=1)
        else:
            df_X_train = df_prot_train_tissue.copy()
            df_X_test = df_prot_test_tissue.copy()
        df_Y_train = md_train[["Age"]].copy()
        df_Y_test = md_test[["Age"]].copy()

        # Bootstrap testing: one task per saved model
        input_list = [(df_X_train, df_Y_train, df_X_test, df_Y_test, train_cohort,
                       tissue, performance_CUTOFF, norm, agerange, seed) for seed in seed_list]
        pool = mp.Pool(NPOOL)
        try:
            predages = pool.starmap(Bootstrap_test, input_list)
            pool.close()
        except BaseException:
            # do not leave worker processes behind when a task fails or is interrupted
            pool.terminate()
            raise
        finally:
            pool.join()

        # Organize predicted age info
        predage = np.mean(predages, axis=0)   #mean of bootstraps
        df_Y_info = pd.concat([md_train, md_test]).copy()
        df_Y_info["Pred_Age"] = predage       # positional: train rows first, then test rows
        lowess_fit_int = Plot_age_prediction_return_lowess_function(df_Y_info, md_hot_train, colormap_dict, tissue)

        #save lowess model
        savefp_lowess = model_prefix+"_age_prediction_lowess.dill"
        with open(savefp_lowess, "wb") as dill_file:
            dill.dump(lowess_fit_int, dill_file)

        #calculate ∆age and rest of info
        # evaluate the LOWESS interpolator at every observed age and map it back by age value
        x_lowess = list(df_Y_info.Age)
        x_lowess = [np.min(x_lowess)-1] + x_lowess + [np.max(x_lowess)+1]
        y_lowess = lowess_fit_int(x_lowess)
        lowess_dict = dict(zip(x_lowess, y_lowess))

        df_Y_info["yhat_lowess"] = df_Y_info["Age"].map(lowess_dict)
        missing_fit = df_Y_info["yhat_lowess"].isna()
        if missing_fit.any():
            print("NA samples removed, n="+str(int(missing_fit.sum())))
            print(df_Y_info.loc[missing_fit])
            # list form of `subset` is accepted by every pandas version
            df_Y_info = df_Y_info.dropna(subset=["yhat_lowess"])
        df_Y_info["dage_resid"] = df_Y_info["Pred_Age"]-df_Y_info["yhat_lowess"]

        # standardise the residual using the training cohort only, then shift so
        # that a raw residual of 0 stays at 0 after scaling
        dagez_scaler = StandardScaler()
        dagez_scaler.fit(df_Y_info.loc[df_Y_info.Cohort2==train_cohort][["dage_resid"]].to_numpy())
        df_Y_info["dage_resid_zscored"] = dagez_scaler.transform(df_Y_info[["dage_resid"]].to_numpy()).flatten()
        df_Y_info["dage_resid_zscored"] = df_Y_info["dage_resid_zscored"] - dagez_scaler.transform([[0]]).flatten()[0]

        #save dage scaler
        savefp_dagez_scaler = model_prefix+"_agegap_zscore_scaler.pkl"
        with open(savefp_dagez_scaler, 'wb') as dagez_file:
            pickle.dump(dagez_scaler, dagez_file)

        #calculate ∆age per cohort
        df_Y_info = Calculate_dage_per_cohort_one_tissue(df_Y_info)
        df_Y_info["tissue"] = tissue
        all_Y.append(df_Y_info)

    df_all_Y = pd.concat(all_Y)
    return df_all_Y
            
            
def Bootstrap_test(df_X_train, df_Y_train, df_X_test, df_Y_test, train_cohort,
              tissue, performance_CUTOFF, norm, agerange, seed):
    """Load one saved bootstrap model and predict age for the train and test rows.

    Parameters
    ----------
    df_X_train, df_X_test : pandas.DataFrame
        Feature matrices passed to the model's ``predict``.
    df_Y_train, df_Y_test, performance_CUTOFF :
        Accepted for signature symmetry with the caller's argument list; not used here.
    train_cohort, tissue, norm, agerange : str
        Together with ``seed`` they identify the pickled model file under
        ``../data/ml_models/<train_cohort>/<agerange>/<norm>/<tissue>/``.
    seed : int
        Bootstrap seed of the model to load.

    Returns
    -------
    numpy.ndarray
        Predictions for the training rows followed by those for the test rows.
    """
    X_train = df_X_train.to_numpy()
    X_test = df_X_test.to_numpy()

    # LOAD MODEL (written by Bootstrap_train); the handle is closed right after reading.
    # pickle.load executes arbitrary code, so only load model files produced by this project.
    savefp="../data/ml_models/"+train_cohort+"/"+agerange+"/"+norm+"/"+tissue+"/"+train_cohort+"_"+agerange+"_"+norm+"_lasso_"+tissue+"_seed"+str(seed)+"_aging_model.pkl"
    with open(savefp, 'rb') as model_file:
        lasso = pickle.load(model_file)

    #Predict on train/test
    Ypred_train = lasso.predict(X_train)
    Ypred_test = lasso.predict(X_test)
    Ypred = np.concatenate([Ypred_train, Ypred_test])
    return Ypred


def Plot_age_prediction_return_lowess_function(dfY, md_hot_train,
                                               colormap_dict, tissue):
    """Plot predicted vs. actual age and return the LOWESS curve fitted on training rows.

    Left panel: density-coloured scatter of all samples with the LOWESS curve.
    Right panel: one regression line per cohort (healthy controls only, except a
    cohort named "Kaci", for which MCI samples are shown). The Pearson r of each
    cohort (all diagnoses) is printed.

    Parameters
    ----------
    dfY : pandas.DataFrame
        Needs the columns ``Age``, ``Pred_Age``, ``Cohort`` and ``Diagnosis_group``.
    md_hot_train : pandas.DataFrame
        Only its index is used: rows of ``dfY`` whose index is in it are the
        training rows on which the LOWESS curve is fitted.
    colormap_dict : dict
        Tissue name -> colour for the density colormap.
    tissue : str
        Tissue name (panel titles, colormap lookup).

    Returns
    -------
    scipy.interpolate.interp1d
        Linear interpolator through the LOWESS fit (extrapolating outside the
        training age range); maps age -> expected predicted age.
    """
    fig,axs=plt.subplots(1,2,figsize=(7,3))

    #density
    cmap = clr.LinearSegmentedColormap.from_list(tissue+"_colormap", ['whitesmoke', colormap_dict[tissue]], N=256)
    toplot=dfY.copy()
    x=toplot.Age.to_numpy()
    y=toplot.Pred_Age.to_numpy()
    xy = np.vstack([x,y])
    z = 2**np.log10(stats.gaussian_kde(xy)(xy))   # compressed KDE density, used only for colouring
    toplot["z"]=z
    sns.scatterplot(data=toplot, x="Age", y="Pred_Age", hue="z", palette=cmap, alpha=0.6, edgecolor=None, ax=axs[0])
    
    #Plot loess (fitted on training rows only)
    dfY_train = dfY.loc[dfY.index.isin(md_hot_train.index)]
    lowess = sm.nonparametric.lowess
    lowess_fit=lowess(dfY_train.Pred_Age.to_numpy(), dfY_train.Age.to_numpy(), frac=2/3, it=5)
    # NOTE(review): ages repeat across samples, so the knot vector handed to interp1d
    # may contain duplicate x values; that can presumably yield NaN at those knots and
    # may be the source of the NaN rows callers drop afterwards — confirm.
    lowess_fit_int = interp1d(lowess_fit[:,0], lowess_fit[:,1], bounds_error=False, kind='linear', fill_value='extrapolate') 

    x_lowess=np.arange(dfY.Age.min(), dfY.Age.max()+2, 0.1)
    y_lowess=lowess_fit_int(x_lowess)
    
    axs[0].plot(x_lowess, y_lowess, color="black", linewidth=3)
    axs[0].legend().remove()
    axs[0].set_title(tissue)
    axs[0].spines.right.set_visible(False)
    axs[0].spines.top.set_visible(False)    
    
    #Plot by dataset
    dcolormap={"Covance":"#E8DFCA","LonGenity":"#B8E8FC", 
           "SADRC":"#1C6758", "KADRC":"#86C8BC","SAMS":"#023e8a",}
    
    dsets=list(set(dfY.Cohort))
    dsets=[x for x in dsets if x!="SADRC_SAMS"]    
    dsets.sort()
    for d in dsets:
        # substring match so that combined labels such as "SADRC_SAMS" count for both cohorts;
        # regex=False keeps the cohort name from being interpreted as a pattern
        dfY_d = dfY.loc[dfY.Cohort.str.contains(d, regex=False)]    #include all samples for r calculation
        r,_=stats.pearsonr(dfY_d.Age, dfY_d.Pred_Age)
        print(d,len(dfY_d),r)
        shown_group = "MCI" if d=="Kaci" else "HC"   #Kaci's dataset is only MCI
        # .get(): cohorts without an entry in dcolormap (e.g. "Kaci") fall back to
        # seaborn's default colour instead of raising KeyError
        sns.regplot(data=dfY_d.loc[dfY_d.Diagnosis_group==shown_group],
                    x="Age", y="Pred_Age", color=dcolormap.get(d), label=d,
                    scatter_kws={"alpha":0.3, "edgecolor":None},
                    line_kws={"linewidth":5}, ax=axs[1])
      
    axs[1].spines.right.set_visible(False)
    axs[1].spines.top.set_visible(False)
    axs[1].legend(loc="upper left", bbox_to_anchor=(1,1))
    axs[1].set_title(tissue)
   
    plt.show()
    return lowess_fit_int


def Calculate_dage_per_cohort_one_tissue(dfY):
    """Add cohort-specific age-gap columns for one tissue.

    For each of the cohorts Covance, KADRC, LonGenity, SADRC and SAMS that occurs
    in ``dfY.Cohort``, a LOWESS curve of ``Pred_Age`` on ``Age`` is fitted on that
    cohort's healthy controls (``Diagnosis_group == "HC"``). The curve value at
    each sample's age becomes ``yhat_lowess_cohort`` and the residual becomes
    ``dage_resid_cohort``. The "SADRC" subset is selected by substring, so labels
    containing "SADRC" are included; the other cohorts are matched exactly.
    Rows of other cohorts, and rows without a curve value, are not returned.

    The residuals of all cohorts are then z-scored together and shifted so that
    healthy controls have mean zero (``dage_resid_zscored_cohort``).

    Parameters
    ----------
    dfY : pandas.DataFrame
        Needs the columns ``Cohort``, ``Diagnosis_group``, ``Age`` and ``Pred_Age``.

    Returns
    -------
    pandas.DataFrame
        Concatenated per-cohort copies with the three new columns.
    """
    fit_lowess = sm.nonparametric.lowess
    present_cohorts = list(dfY.Cohort)

    per_cohort_frames = []
    for cohort in ['Covance', 'KADRC', 'LonGenity', 'SADRC', 'SAMS']:
        if cohort not in present_cohorts:
            continue

        #subset to cohort
        if cohort == "SADRC":
            in_cohort = dfY.Cohort.str.contains("SADRC")
        else:
            in_cohort = (dfY.Cohort == cohort)
        dfY_c = dfY.loc[in_cohort].copy()

        #lowess on hc
        controls = dfY_c.loc[dfY_c.Diagnosis_group == "HC"].copy()
        fitted = fit_lowess(controls.Pred_Age.to_numpy(), controls.Age.to_numpy(), frac=1, it=5)
        lowess_fit_int = interp1d(fitted[:,0], fitted[:,1], bounds_error=False, kind="linear", fill_value='extrapolate')

        # evaluate the curve at every age in the cohort (plus one year beyond each end)
        ages = list(dfY_c.Age.to_numpy())
        eval_points = [np.min(ages)-1] + ages + [np.max(ages)+1]
        age_to_expected = dict(zip(eval_points, lowess_fit_int(eval_points)))

        dfY_c["yhat_lowess_cohort"] = dfY_c.Age.map(age_to_expected)
        dfY_c = dfY_c.dropna(subset="yhat_lowess_cohort")
        dfY_c["dage_resid_cohort"] = dfY_c["Pred_Age"] - dfY_c["yhat_lowess_cohort"]
        per_cohort_frames.append(dfY_c)

    #zscore all cohorts together
    dfYnew = pd.concat(per_cohort_frames)
    dfYnew["dage_resid_zscored_cohort"] = stats.zscore(dfYnew["dage_resid_cohort"])

    #make healthy controls mean 0
    hc_mean = dfYnew.loc[dfYnew.Diagnosis_group == "HC"].dage_resid_zscored_cohort.mean()
    dfYnew["dage_resid_zscored_cohort"] = dfYnew["dage_resid_zscored_cohort"] - hc_mean
    return dfYnew
    
    

In [None]:
# Apply the saved models to the train and test sets. The run configuration
# (performance_CUTOFF, train_cohort, norm, agerange, NPOOL) is reused from the
# training cell so the cutoff passed here cannot drift from the one encoded in `norm`.
# NOTE(review): training used `tissue_plist_dict`, this call uses
# `tissue_plist_dict_enr_keep`; every tissue listed here needs saved models — confirm.
dfY = Test_all_tissue_aging_model(md_hot_train, df_prot_train, 
                                  md_hot_test, df_prot_test,
                                  md_train, md_test,
                                  tissue_plist_dict_enr_keep, colormap_dict, bs_seed_list, 
                                  performance_CUTOFF=performance_CUTOFF,
                                  train_cohort=train_cohort,
                                  norm=norm, 
                                  agerange=agerange, 
                                  NPOOL=NPOOL)
