In [1]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn import set_config

# Notebook-wide presentation settings.
# Render scikit-learn estimators as diagrams.
set_config(display="diagram")

# Seaborn theme applied to every figure: paper context, larger fonts, ticks, grid on.
sns.set_theme(
    context="paper",
    font_scale=1.5,
    style="ticks",
    rc={"axes.grid": True},
)

# Show up to 85 columns when a DataFrame is displayed.
pd.set_option("display.max_columns", 85)

# AdaBoost: Scoring by Nested Cross-Validation

### Load the data

In [2]:
# Read the data (NEW from PP); the first CSV column is used as the row index.
df = pd.read_csv("../data/new_abnormal_writeout_noscale.data.csv", index_col=0)

# Keep only rows without any missing value.
df = df.dropna()

# The two columns that must come first in the feature table.
lead_cols = ["occ_total_sum", "oldest_phylostratum"]

# `conf`: every remaining column except the response.
conf = df.drop(columns=["response", *lead_cols])

# Feature table = leading columns followed by all `conf` columns, so the
# leading columns sit at positions 0 and 1 of the feature matrix.
features_df = pd.concat([df[lead_cols], conf], axis=1)

# Plain arrays: feature matrix and label vector.
X = features_df.to_numpy()
y = df["response"].to_numpy()

features_df

Unnamed: 0,occ_total_sum,oldest_phylostratum,cds_length,dnase_gene,dnase_cds,H3k4me1_gene,H3k4me3_gene,H3k27ac_gene,H3k4me1_cds,H3k4me3_cds,H3k27ac_cds,lamin_gene,repli_gene,nsome_gene,nsome_cds,transcription_gene,repeat_gene,repeat_cds,recomb_gene,AAA_freq,AAC_freq,AAG_freq,AAT_freq,ACA_freq,ACC_freq,ACG_freq,ACT_freq,AGA_freq,AGC_freq,AGG_freq,AGT_freq,ATA_freq,ATC_freq,ATG_freq,ATT_freq,CAA_freq,CAC_freq,CAG_freq,CAT_freq,CCA_freq,CCC_freq,CCG_freq,CCT_freq,CGA_freq,CGC_freq,CGG_freq,CGT_freq,CTA_freq,CTC_freq,CTG_freq,CTT_freq,GAA_freq,GAC_freq,GAG_freq,GAT_freq,GCA_freq,GCC_freq,GCG_freq,GCT_freq,GGA_freq,GGC_freq,GGG_freq,GGT_freq,GTA_freq,GTC_freq,GTG_freq,GTT_freq,TAA_freq,TAC_freq,TAG_freq,TAT_freq,TCA_freq,TCC_freq,TCG_freq,TCT_freq,TGA_freq,TGC_freq,TGG_freq,TGT_freq,TTA_freq,TTC_freq,TTG_freq
1,33,12,1488,0.612230,0.758065,0.561429,1.000000,0.216855,0.661290,1.000000,0.198925,0.0,0.041809,0.809254,0.706453,6.798234,0.040516,0.0,0.000000,0.004755,0.008152,0.007473,0.002717,0.011549,0.026495,0.010870,0.008152,0.010190,0.028533,0.019701,0.009511,0.000679,0.006114,0.010870,0.002038,0.009511,0.019022,0.028533,0.007473,0.027174,0.031250,0.025136,0.029891,0.015625,0.027174,0.019701,0.009511,0.007473,0.017663,0.044837,0.013587,0.008832,0.021739,0.031250,0.008152,0.016984,0.033967,0.027853,0.034647,0.023777,0.030571,0.029212,0.013587,0.000679,0.012908,0.027174,0.003397,0.000000,0.008152,0.000000,0.001359,0.008832,0.021739,0.009511,0.010190,0.020380,0.027174,0.029212,0.010870,0.000679,0.013587,0.005435
10,28,1,873,0.086769,0.195876,0.657839,0.000000,0.000000,0.000000,0.000000,0.000000,1.0,-0.007148,0.828752,1.097018,0.061963,0.002809,0.0,2.043350,0.025258,0.019518,0.021814,0.024110,0.025258,0.018370,0.003444,0.012629,0.035591,0.009185,0.016073,0.006889,0.016073,0.017222,0.010333,0.033295,0.019518,0.011481,0.020666,0.022962,0.017222,0.008037,0.002296,0.021814,0.003444,0.001148,0.004592,0.002296,0.008037,0.019518,0.022962,0.019518,0.033295,0.013777,0.019518,0.011481,0.014925,0.006889,0.000000,0.012629,0.018370,0.011481,0.017222,0.018370,0.005741,0.008037,0.012629,0.012629,0.012629,0.014925,0.006889,0.017222,0.017222,0.016073,0.005741,0.022962,0.020666,0.012629,0.027555,0.011481,0.021814,0.017222,0.026406
100,36,1,1092,0.479295,0.611722,0.851369,0.354628,0.618954,0.754579,0.030220,0.086996,0.0,0.040463,1.249600,1.354306,6.081620,0.028404,0.0,0.868383,0.018727,0.012172,0.023408,0.003745,0.017790,0.024345,0.007491,0.014981,0.024345,0.020599,0.025281,0.011236,0.003745,0.013109,0.019663,0.004682,0.017790,0.016854,0.029963,0.017790,0.034644,0.022472,0.010300,0.028090,0.005618,0.010300,0.014045,0.003745,0.015918,0.015918,0.033708,0.011236,0.014981,0.022472,0.026217,0.009363,0.015918,0.031835,0.007491,0.025281,0.028090,0.029026,0.021536,0.013109,0.008427,0.010300,0.016854,0.003745,0.006554,0.012172,0.005618,0.008427,0.014981,0.016854,0.009363,0.008427,0.014981,0.019663,0.029026,0.010300,0.004682,0.010300,0.004682
1000,126,1,2800,0.171524,0.280357,0.554023,0.052420,0.278492,0.270357,0.021429,0.151429,0.0,-0.022495,0.921420,1.382249,2.254471,0.014520,0.0,1.143060,0.022054,0.014823,0.022415,0.024946,0.022054,0.014100,0.006146,0.015546,0.024946,0.016992,0.012292,0.015907,0.013377,0.021330,0.026392,0.017715,0.026392,0.011931,0.027477,0.017354,0.023861,0.016992,0.006508,0.019161,0.005785,0.003977,0.007954,0.003977,0.006146,0.010846,0.025307,0.015907,0.022415,0.022777,0.016269,0.018800,0.015184,0.016992,0.004700,0.014461,0.017354,0.010484,0.010123,0.011931,0.009400,0.007231,0.020607,0.011931,0.013738,0.008315,0.006146,0.016631,0.022054,0.018077,0.004700,0.009038,0.031092,0.019523,0.019523,0.016992,0.016269,0.014100,0.015907
10000,55,1,1484,0.143843,0.030997,0.400789,0.106455,0.457949,0.708221,0.030997,0.659704,0.0,-0.000387,0.960747,1.196871,1.080241,0.009545,0.0,4.217000,0.039835,0.015797,0.030220,0.025412,0.024038,0.012363,0.002747,0.019918,0.048077,0.006868,0.015797,0.009615,0.020604,0.009615,0.032280,0.023352,0.019918,0.012363,0.021978,0.015797,0.015110,0.003434,0.004121,0.013049,0.005495,0.001374,0.002060,0.002060,0.013736,0.014423,0.014423,0.013736,0.034341,0.017857,0.024725,0.024725,0.016484,0.006868,0.002747,0.006181,0.022665,0.013049,0.010302,0.008242,0.009615,0.004808,0.013736,0.011676,0.018544,0.012363,0.008242,0.019231,0.015110,0.012363,0.002060,0.015797,0.024038,0.010989,0.026099,0.018544,0.014423,0.015797,0.019231
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
999,208,1,2649,0.313496,0.427709,0.721323,0.380132,0.560000,0.371461,0.147603,0.375613,0.0,0.051321,1.156640,1.677763,12.624956,0.019776,0.0,0.549588,0.016813,0.015285,0.020634,0.015667,0.029041,0.018342,0.012228,0.012992,0.022163,0.018724,0.019488,0.009171,0.007642,0.014903,0.015667,0.015667,0.020634,0.025984,0.026366,0.012992,0.026748,0.023691,0.009935,0.021016,0.009171,0.007642,0.008789,0.007642,0.009935,0.017577,0.030187,0.012610,0.024838,0.019870,0.022545,0.016431,0.010699,0.021781,0.005732,0.018724,0.022927,0.014520,0.011846,0.012992,0.004968,0.011846,0.017577,0.007642,0.006878,0.011846,0.003057,0.007642,0.019488,0.017195,0.005732,0.017577,0.028659,0.015667,0.021781,0.011464,0.006496,0.015667,0.014138
9990,88,1,4035,0.159518,0.305328,0.618466,1.000000,0.379258,0.538290,1.000000,0.578686,0.0,0.032907,0.952004,1.596068,4.338614,0.013269,0.0,2.271970,0.019613,0.013327,0.021624,0.012824,0.021624,0.018607,0.004275,0.016093,0.018607,0.016847,0.015841,0.013830,0.007292,0.016595,0.025396,0.016847,0.019361,0.019864,0.020619,0.023133,0.024642,0.014835,0.009052,0.018104,0.009303,0.004275,0.007040,0.003017,0.012069,0.015590,0.021121,0.022630,0.020116,0.014332,0.019361,0.016595,0.016595,0.015590,0.004275,0.018356,0.019613,0.015841,0.016093,0.014835,0.008046,0.011064,0.019361,0.010561,0.008801,0.011567,0.006035,0.012321,0.019864,0.018104,0.006035,0.018858,0.022630,0.017098,0.027910,0.016093,0.011315,0.019613,0.018858
9991,37,2,2043,0.164623,0.025453,0.748995,0.710461,0.872609,0.785120,0.786099,1.000000,0.0,0.045040,0.865913,1.245576,7.591840,0.014049,0.0,2.458350,0.024463,0.007988,0.020469,0.024463,0.015477,0.016975,0.002996,0.011982,0.017474,0.017474,0.008987,0.014978,0.012481,0.019970,0.024463,0.019471,0.017474,0.015976,0.024963,0.018472,0.021468,0.012481,0.003994,0.025462,0.004493,0.004493,0.003994,0.003495,0.014978,0.020969,0.024963,0.031952,0.021468,0.008987,0.010484,0.014978,0.013979,0.013979,0.001498,0.021468,0.014978,0.011982,0.009486,0.007489,0.004993,0.010484,0.012981,0.012981,0.013979,0.014478,0.005991,0.016475,0.025961,0.019970,0.006490,0.033450,0.018972,0.015976,0.021468,0.014978,0.018472,0.034448,0.011483
9992,14,12,372,0.166620,0.572581,0.857123,0.861899,1.000000,1.000000,1.000000,1.000000,0.0,0.017871,1.277585,1.767925,0.136402,0.020090,0.0,2.001840,0.018919,0.013514,0.027027,0.024324,0.027027,0.013514,0.008108,0.018919,0.032432,0.013514,0.013514,0.008108,0.002703,0.027027,0.027027,0.018919,0.029730,0.018919,0.013514,0.024324,0.035135,0.016216,0.002703,0.016216,0.005405,0.005405,0.005405,0.005405,0.013514,0.010811,0.024324,0.016216,0.035135,0.016216,0.021622,0.010811,0.008108,0.016216,0.005405,0.010811,0.018919,0.016216,0.005405,0.010811,0.010811,0.010811,0.013514,0.008108,0.000000,0.018919,0.005405,0.013514,0.016216,0.024324,0.005405,0.018919,0.029730,0.005405,0.027027,0.018919,0.010811,0.016216,0.016216


*** 
## Nested CV on AdaBoost

### The Model and its Parameter Space

In [14]:
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import AdaBoostClassifier
from sklearn.compose import ColumnTransformer
from sklearn.decomposition import PCA

# Confounder PCA: columns 0 and 1 are passed through untouched, while PCA is
# applied to all remaining columns (index 2 up to the last column of X).
confpca = ColumnTransformer([
    ("ots+of", "passthrough", [0, 1]),
    ("conf", PCA(), slice(2, X.shape[1])),
])

# Parameter grid for the boosting step (keys use the pipeline step name "gb").
main_params = {
    "gb__learning_rate": [0.5, 1, 1.5],
    "gb__n_estimators": [50, 100, 200],
}

# Two alternative sub-grids: either run the PCA column transformer with
# different `n_components` settings, or skip the "pca" step entirely.
pca_on = {"pca": [confpca], "pca__conf__n_components": [None, 0.01, 0.95]}
pca_off = {"pca": ["passthrough"]}

param_grid = [{**main_params, **pca_on}, {**main_params, **pca_off}]

# Define the model to be tuned: scaling -> (optional) PCA -> AdaBoost.
# The step is still called "gb" so that the "gb__*" grid keys above resolve.
# `random_state` is fixed on the estimator itself so every fit is seeded
# explicitly: GridSearchCV may run fits in parallel workers, and a global
# `np.random.seed` call is not a dependable way to seed those.
adab_clf = Pipeline([
    ("scaler", StandardScaler()),
    ("pca", confpca),
    ("gb", AdaBoostClassifier(random_state=0)),
])

adab_clf

### Nested CV

In [15]:
from sklearn.exceptions import ConvergenceWarning, FitFailedWarning
from sklearn.metrics import f1_score, precision_score, recall_score, roc_auc_score
from sklearn.model_selection import GridSearchCV, KFold, cross_val_score, RepeatedKFold
from sklearn.utils._testing import ignore_warnings
from sklearn.model_selection import train_test_split


# Nested cross-validation set-up: the outer folds estimate performance, the
# inner folds pick hyper-parameters.
np.random.seed(3)
model = adab_clf
k_outer, k_inner = 10, 3
cv_outer = KFold(n_splits=k_outer, shuffle=True, random_state=1)
cv_inner = KFold(n_splits=k_inner, shuffle=True, random_state=3)

# Per-outer-fold bookkeeping: held-out ROC-AUC and the parameters that won
# the corresponding inner search.
roc_results = []
found_params = []

print(f"Performing nested-cv with {k_outer} outer-folds and {k_inner} inner-folds.\n")
print("OUTER CV | BEST OF INNER CV | CHOSEN PARAMS")

for train_ix, test_ix in cv_outer.split(X):

    # Outer split: the training part drives the search, the test part is held out.
    X_tr, y_tr = X[train_ix, :], y[train_ix]
    X_te, y_te = X[test_ix, :], y[test_ix]

    # Inner loop: exhaustive grid search scored by ROC-AUC on the inner folds.
    search = GridSearchCV(
        estimator=model,
        param_grid=param_grid,
        scoring="roc_auc",
        cv=cv_inner,
        n_jobs=4,
    )
    result = search.fit(X_tr, y_tr)

    # Best configuration, refit by the search on the full outer-training part.
    best_model = result.best_estimator_

    # Score the held-out fold using the predicted probability in column 1.
    yhat = best_model.predict_proba(X_te)[:, 1]
    roc_auc = roc_auc_score(y_te, yhat)

    # Record this fold's outcome.
    roc_results.append(roc_auc)
    found_params.append(result.best_params_)

    # Progress line: outer score, best inner score, chosen parameters.
    print(">roc-auc=%.3f, est=%.3f, params=%s" % (roc_auc, result.best_score_, result.best_params_))

# Overall estimate: mean and standard deviation across the outer folds.
print("ROC-AUC: %.3f (std = %.3f)" % (np.mean(roc_results), np.std(roc_results)))

Performing nested-cv with 10 outer-folds and 5 inner-folds.

OUTER CV | BEST OF INNER CV | CHOSEN PARAMS


KeyboardInterrupt: 

In [8]:
# One row per outer fold: its ROC-AUC side by side with the parameters
# selected by the inner search for that fold.
scores_frame = pd.DataFrame(roc_results, columns=['roc_auc'])
params_frame = pd.DataFrame(found_params)
ncv_df = pd.concat([scores_frame, params_frame], axis=1)
ncv_df

Unnamed: 0,roc_auc,gb__learning_rate,pca__apply_PCA,pca__n_components
0,0.676146,0.1,False,0.95
1,0.703986,0.05,False,0.95
2,0.6836,0.05,False,0.95
3,0.693924,0.05,False,0.95
4,0.688165,0.05,False,0.95
5,0.691082,0.05,False,0.95
6,0.685935,0.05,False,0.95
7,0.656577,0.05,False,0.95
8,0.711407,0.05,False,0.95
9,0.667781,0.05,False,0.95


In [9]:
# Average ROC-AUC over all outer folds.
ncv_df.loc[:, "roc_auc"].mean()

0.6858603438495539

In [10]:
# Persist the per-fold nested-CV results.
# The file is named after AdaBoost, the model tuned in this notebook; the
# previous name ("gb_ncv.csv") would have overwritten results saved under the
# gradient-boosting name.
# NOTE(review): the input data is read from "../data" while this writes to
# "./data" - confirm the output directory is intended.
ncv_df.to_csv("./data/adab_ncv.csv")