SET PARAMETERS

In [None]:
# --- general run switches ---
DEBUG = True          # when True, DEBUG-level logging is sent to stdout
saveReport = False    # True: figures are written to disk; False: they are shown
toPrint = True
reportName = 'notebook'
txt_label = "Classification of integrated c1 and c2 CNVs samples"

# --- train/test split ---
split_train_size = 40     # forwarded as train_size to train_test_split
split_random_state = 0

# --- cross-validated classification ---
classification_args = dict(n_splits=10, random_state=0)

# --- plotting ---
function_dict = None
with_swarm = False
highRes = False
# vector output for high-resolution figures, raster otherwise
img_ext = '.pdf' if highRes else '.png'
cmap_custom = None
vmin = -2
vmax = +2
plot_kwargs = dict(mincol="red", midcol="white", maxcol="blue")

In [None]:
# data file
# (all *_fpath values below are later passed through check_path_integrity,
#  which prefixes MainDataDir when the path does not exist as written)
data_fpath = "output/headneck/integrate_cohorts/c1c2/CNV_mapped_filt/integrated_data.csv"

# sample_info file
sample_info_fpath = "output/headneck/integrate_cohorts/c1c2/integrated_sample_info.csv"
# column of sample_info that holds the class label used as ground truth
sample_class_column = "Relapsed"
# class_labels[i] is the display name of the class encoded as class_values[i]
class_labels = ["relapsed","NOTrelapsed"]
class_values = [1,0]

# genes_info file
genes_info_fpath = "output/headneck/setup_c1_oncoscan_byNexus/genes_info.csv"
# column names handed to get_chr_ticks (chr_col= / id_col=)
chr_col = 'chr_int'
gene_id_col = 'gene'

# output dir
output_directory = "output/headneck/classification/"+reportName

In [None]:
# keyword arguments forwarded to load_clinical() when reading the sample_info file
sample_info_read_csv_kwargs = dict(
    sep="\t",
    header=0,
    col_as_index="patientID",
)

In [None]:
# select features
# gene-panel table; its column names are used as one feature set
genepanel_path = "output/headneck/setup_c1_genepanel/process_select_primary/data_processed.csv"
# _dirs[i] is the feature_selection sub-directory whose featsel_results.csv is
# registered under the key _key_names[i] (the two lists are zipped, so they
# must stay the same length and in the same order)
_dirs = ['c1_prmr_OncFltNxEx', 'c2_ExcvFltNxEx', 'c1_prmr_mapped_c2_CnvNxEx', 'c1_prmr_mapped_c2_Cnv', 'c1_prmr_mapped_c2_CnvMixedNxEx']
_key_names = ['c1_OncFltNxEx', 'c2_ExcvFltNxEx', 'c3_CnvNxEx', 'c3_Cnv', 'c3_CnvMixedNxEx']

SET ENVIRONMENT

In [None]:
# custom imports
from omics_processing.io import (
    set_directory, load_clinical
)
from omics_processing.remove_duplicates import (
    remove_andSave_duplicates
)
from gene_signatures.core import (
    custom_div_cmap,
    get_chr_ticks,
    choose_samples,
    parse_arg_type,
    boxplot,
    set_heatmap_size,
    set_cbar_ticks,
    edit_names_with_duplicates,
    plot_confusion_matrix
)

# basic imports
import os, sys
import numpy as np
import pandas as pd
import json
from scipy.spatial.distance import pdist, squareform
from natsort import natsorted, index_natsorted
import math
import logging
from sklearn import linear_model
from sklearn import svm
from distutils.util import strtobool
from scipy.stats import binom_test
from sklearn.externals import joblib
from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.metrics import confusion_matrix
from sklearn.calibration import CalibratedClassifierCV
from sklearn.metrics import roc_curve, auc
from scipy import interp

# plotting imports
%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
# seaborn appearance: 'white' style and 'talk' context for every figure
sns.set_style('white')
sns.set_context('talk')

# working directory at start-up; anchor for the default MainDataDir
script_path = os.getcwd()
# module-level logger used by the helper functions and cells below
logger = logging.getLogger(__name__)

Functions

In [None]:
def check_path_integrity(f, rootDir=None, name="", force=False):
    """Return a usable path for ``f``, anchoring it under ``rootDir`` if needed.

    If ``f`` already exists it is returned unchanged. Otherwise its
    '/'-separated components are re-joined with the OS separator and, when
    ``rootDir`` is given, prefixed with it.

    Parameters
    ----------
    f : str
        Path written with '/' separators.
    rootDir : str or None
        Directory to prepend when ``f`` does not exist. ``None`` (the
        default) leaves the path relative instead of raising a TypeError.
    name : str
        Short label used only in the debug log message.
    force : bool
        When True the resolved path is passed through ``set_directory``
        (presumably creating the directory - confirm against
        omics_processing.io) and its return value is used.

    Returns
    -------
    str
        ``f`` itself if it exists, otherwise the resolved path.
    """
    if os.path.exists(f):
        return f

    # translate the '/'-separated notation into an OS-specific relative path
    resolved = os.path.join(*f.rsplit('/'))
    if rootDir is not None:
        resolved = os.path.join(rootDir, resolved)
    if force:
        resolved = set_directory(resolved)
    logger.debug("set "+name+" fpath:\n"+resolved)
    return resolved

def save_image(saveReport=False, output_directory="", img_name="figure", img_ext=".png", plt_obj=None):
    """Save a figure to ``<output_directory>/Fig_<img_name><img_ext>`` or show it.

    Parameters
    ----------
    saveReport : bool
        True: write the figure to disk and close it. False: display it.
    output_directory : str
        Directory that receives the image file.
    img_name : str
        File name stem; the file is named ``'Fig_' + img_name + img_ext``.
    img_ext : str
        File extension including the leading dot.
    plt_obj : matplotlib Figure, the pyplot module, or None
        Object providing ``savefig``. ``None`` means the current figure.
    """
    if plt_obj is None:
        plt_obj = plt.gcf()

    # A Figure has no ``close`` method (the pyplot module does), so closing
    # and showing have to go through pyplot when a Figure was handed in.
    is_pyplot_like = hasattr(plt_obj, "close")

    if saveReport:
        logger.info('Save distplot')
        # ``frameon`` is no longer accepted by savefig; transparent=True
        # already makes the figure background transparent.
        plt_obj.savefig(
            os.path.join(output_directory, 'Fig_'+img_name+img_ext),
            transparent=True, bbox_inches='tight',
            pad_inches=0.1)
        if is_pyplot_like:
            plt_obj.close("all")
        else:
            plt.close(plt_obj)
    elif is_pyplot_like:
        plt_obj.show()
    else:
        plt.show()

def extract_gene_set(df):
    """Collect every gene name referenced by a feature-selection table.

    Parameters
    ----------
    df : pandas.DataFrame
        If it has a ``'dupl_genes'`` column, every string cell of that column
        is parsed as a Python literal (expected: a list of gene names) and
        its items are collected; names in an optional ``'cleanName'`` column
        are added as well. Non-string cells (e.g. NaN) are skipped.
        Without a ``'dupl_genes'`` column the index values are used.

    Returns
    -------
    set
        The collected gene names.

    Raises
    ------
    ValueError, SyntaxError
        If a ``'dupl_genes'`` cell is not a valid Python literal.
    """
    # The cells come from a file on disk: parse them as literals only,
    # never execute them as code.
    import ast

    if 'dupl_genes' not in df.columns:
        return set(df.index.values)

    gene_set = set()
    for cell in df['dupl_genes']:
        if isinstance(cell, str):
            gene_set.update(ast.literal_eval(cell))

    if 'cleanName' in df.columns:
        gene_set.update(df['cleanName'].values)

    return gene_set

def plot_data_heatmap(
    data, ground_truth, xlabel, xpos, 
    vmin, vmax, cmap_custom, custom_div_cmap_arg,
    function_dict
):
    """Draw a heatmap of ``data`` with the rows ordered by ``ground_truth``.

    Parameters
    ----------
    data : pandas.DataFrame
        Values to plot; rows are selected and ordered by the sorted
        ``ground_truth`` index.
    ground_truth : pandas.Series
        One label per row of ``data``; shown (together with the row id) as
        the y tick labels.
    xlabel, xpos : sequence or None
        Tick labels and tick positions for the x axis. When both are given
        they are used as-is; otherwise the x tick labels are only rotated if
        ``set_heatmap_size`` asks for gene names to be shown.
    vmin, vmax : number
        Colour scale limits passed to ``sns.heatmap``.
    cmap_custom : colormap
        Colormap passed to ``sns.heatmap``.
    custom_div_cmap_arg, function_dict
        Forwarded unchanged to ``set_cbar_ticks``.

    The figure is created with ``plt.figure`` and left as the current figure;
    nothing is returned.
    """
    ground_truth_sorted = ground_truth.sort_values()
    data_sorted = data.loc[ground_truth_sorted.index,:].copy()
    try:
        pat_labels_txt = ground_truth_sorted.astype(int).reset_index().values
    except (ValueError, TypeError):
        # labels that cannot be cast to int (text, NaN, ...) are shown as-is
        pat_labels_txt = ground_truth_sorted.reset_index().values

    _figure_x_size, _figure_y_size, _show_gene_names, _ = \
        set_heatmap_size(data_sorted)
    plt.figure(figsize=(_figure_x_size, _figure_y_size))
    # the colorbar is added manually below so that its ticks can be customised
    ax = sns.heatmap(data_sorted, vmin=vmin, vmax=vmax,
                     yticklabels=pat_labels_txt, xticklabels=False,
                     cmap=cmap_custom, cbar=False)
    if (_show_gene_names and (
            (xpos is None) or
            (xlabel is None))):
        plt.xticks(rotation=90)
    elif (
            (xpos is not None) and
            (xlabel is not None)):
        plt.xticks(xpos, xlabel, rotation=0)
    plt.xlabel('chromosomes (the number is aligned at the end ' +
               'of the chr region)')
    plt.ylabel('samples')
    cbar = ax.figure.colorbar(ax.collections[0])
    set_cbar_ticks(cbar, function_dict, custom_div_cmap_arg)
    
def _run_classification(
        dat, dat_target, random_state=None, n_splits=10):
    """Cross-validate a linear SVM and refit it on all the given samples.

    Parameters
    ----------
    dat : pandas.DataFrame
        Feature matrix (rows: samples, columns: features).
    dat_target : pandas.Series
        Class label per sample, in the same row order as ``dat``.
        The ROC computation assumes a binary target.
    random_state : int-like or None
        Seed of the ``LinearSVC`` model; ``None`` is replaced by 0.
    n_splits : int
        Requested number of stratified folds. It is lowered to the size of
        the smallest class when it exceeds that size or the sample count, so
        the number of folds actually run is ``len(cross_val_scores)``.

    Returns
    -------
    model : sklearn.svm.LinearSVC
        Model refitted on all of ``dat`` after the cross-validation.
    all_coefs : pandas.DataFrame
        One row of coefficients per fold, columns named like ``dat``.
    y_train_predictions : pandas.Series
        Held-out prediction for every sample (name: "train_predictions").
    cross_val_scores : list of float
        Held-out accuracy of every fold.
    fprs, tprs : list of numpy.ndarray
        Held-out ROC curve of every fold.
    interps : list of numpy.ndarray
        Every fold's TPR interpolated on ``np.linspace(0, 1, 100)``.
    aucs : list of float
        Held-out ROC AUC of every fold.
    """
    min_class_count = np.unique(dat_target, return_counts=True)[1].min()
    if n_splits is not None:
        if (n_splits > dat.shape[0]) or (n_splits > min_class_count):
            n_splits = min_class_count
    if random_state is not None:
        random_state = parse_arg_type(random_state, int)
    else:
        random_state = 0
    logger.info(
        "model: svm.LinearSVC with l2 penalty, squared_hinge loss " +
        "and random_state: "+str(random_state)
    )
    model = svm.LinearSVC(
        penalty='l2', C=1, random_state=random_state,
        loss='squared_hinge', dual=False
    )

    logger.info("Running classification...")
    dat = dat.copy()
    dat_target = dat_target.copy()

    k_fold = StratifiedKFold(n_splits=n_splits)
    cross_val_scores = []
    all_coefs = np.zeros((n_splits, dat.shape[1]))
    # explicit dtype: an index-only Series has no well-defined default dtype
    y_train_predictions = pd.Series(index=dat_target.index, dtype=float)
    y_train_predictions.name = "train_predictions"

    fprs = []
    tprs = []
    interps = []
    aucs = []
    mean_fpr = np.linspace(0, 1, 100)

    for split_i, (train_indices, test_indices) in enumerate(
            k_fold.split(dat, dat_target)):
        X_train = dat.iloc[train_indices]
        y_train = dat_target.iloc[train_indices]

        X_crossval = dat.iloc[test_indices]
        y_crossval = dat_target.iloc[test_indices]

        model.fit(X_train, y_train)
        all_coefs[split_i, :] = model.coef_[0]
        cross_val_scores.append(model.score(X_crossval, y_crossval))
        y_train_predictions.iloc[test_indices] = model.predict(X_crossval)

        # ROC from the raw SVM margins of the held-out fold. Calibrating
        # probabilities on this same fold (as was done before) fits a sigmoid
        # to the very labels being scored; that sigmoid may flip the ranking
        # and so bias the fold AUC upwards.
        fold_scores = model.decision_function(X_crossval)
        fpr, tpr, thresholds = roc_curve(y_crossval, fold_scores)
        fprs.append(fpr)
        tprs.append(tpr)
        # np.interp is the function the old ``scipy.interp`` alias pointed to
        fold_interp = np.interp(mean_fpr, fpr, tpr)
        fold_interp[0] = 0.0
        interps.append(fold_interp)
        aucs.append(auc(fpr, tpr))

    # final model: refit on every available sample
    model.fit(dat, dat_target)

    all_coefs = pd.DataFrame(all_coefs, columns=dat.columns.values)

    return model, all_coefs, y_train_predictions, cross_val_scores, fprs, tprs, interps, aucs

# plot count of correct/wrong predictions per class
def plot_prediction_counts_per_class(y_real, y_pred, class_labels=None, class_values=None):
    """Plot, for each real class, how many predictions were correct or wrong.

    Parameters
    ----------
    y_real : pandas.Series
        True numeric class per sample.
    y_pred : pandas.Series
        Predicted numeric class per sample (aligned with ``y_real`` by index).
    class_labels, class_values : array-like or None
        Parallel sequences: ``class_labels[i]`` is the display name of the
        class encoded as ``class_values[i]``. When both are given, each panel
        title becomes ``'<label>:<value>'``.

    One histogram panel is drawn per distinct value of ``y_real``; the
    figure is left as the current figure and nothing is returned.
    """
    # |pred - real| is 0 for a correct and 1 for a wrong 0/1 prediction
    compare_predictions = pd.concat(
        [y_real, np.abs(y_pred-y_real)], axis=1)
    compare_predictions.columns = ['real', 'pred_diffs']
    y_maxlim = max(
        np.histogram(compare_predictions.iloc[:, i], bins=2)[0].max()
        for i in range(2)
    )

    axes = compare_predictions.hist(
        by='real', column='pred_diffs',
        bins=2, rwidth=0.4, figsize=(10, 6))

    rename_panels = class_labels is not None and class_values is not None
    if rename_panels:
        # boolean-mask lookup below needs arrays, also when lists are passed
        class_labels = np.asarray(class_labels)
        class_values = np.asarray(class_values)

    # a single class yields one Axes object instead of an array of them
    for ax in np.ravel(axes):
        ax.set_ylim(0, y_maxlim+1)
        ax.set_xlim(0, 1)
        ax.set_xticks([0.25, 0.75])
        ax.set_xticklabels(['correct', 'wrong'], rotation=0, fontsize=14)
        if rename_panels:
            # pandas titles each panel with the group value of 'real'
            matching = class_labels[class_values == float(ax.get_title())]
            if len(matching) > 0:
                ax.set_title(
                    matching[0]+':'+str(ax.get_title()), fontsize=14)
    plt.suptitle('real predictions', fontsize=16)

# plot confusion matrix
def compute_and_plot_confusion_matrices(y_real, y_pred, class_labels=None, class_values=None):
    """Plot the raw and the normalized confusion matrix in two new figures.

    ``y_pred`` is re-ordered to the index of ``y_real`` before the matrix is
    computed. When both ``class_labels`` and ``class_values`` are given
    (arrays supporting boolean-mask indexing), the tick names are the labels
    of the classes encoded as 0 and 1, in that order; otherwise generic
    'class_0' / 'class_1' names are used.

    Note: sets numpy's global print precision to 2 as a side effect.

    Returns
    -------
    (Figure, Figure)
        The non-normalized figure followed by the normalized one.
    """
    predictions_aligned = y_pred.loc[y_real.index]
    cnf_matrix = confusion_matrix(y_real, predictions_aligned)
    np.set_printoptions(precision=2)

    if (class_labels is None) or (class_values is None):
        _classes = ['class_0', 'class_1']
    else:
        _classes = [class_labels[class_values == value][0] for value in (0, 1)]

    panels = (
        ({}, 'Confusion matrix, without normalization'),
        ({'normalize': True}, 'Normalized confusion matrix'),
    )
    figures = []
    for extra_kwargs, panel_title in panels:
        plt.figure()
        plot_confusion_matrix(
            cnf_matrix, classes=_classes, title=panel_title, **extra_kwargs)
        figures.append(plt.gcf())

    plt1, plt2 = figures
    return plt1, plt2

def plot_roc_for_many_models(model_names, fprs, tprs, aucs, figsize=(10,10)):
    """Overlay one ROC curve per model in a single new figure.

    ``fprs[i]``, ``tprs[i]`` and ``aucs[i]`` belong to ``model_names[i]``;
    each legend entry shows the model name and its AUC. A dashed red
    diagonal labelled 'Luck' marks chance level.
    """
    plt.figure(figsize=figsize)

    for idx, model_name in enumerate(model_names):
        curve_label = '%s (AUC=%0.2f)' % (model_name, aucs[idx])
        plt.plot(fprs[idx], tprs[idx], label=curve_label)

    # chance-level reference line
    plt.plot(
        [0, 1], [0, 1],
        linestyle='--', lw=2, color='r', label='Luck', alpha=.8)

    axis_limits = [-0.05, 1.05]
    plt.xlim(axis_limits)
    plt.ylim(axis_limits)
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('ROC curves')
    plt.legend(loc="best")
    
def plot_roc_with_std_for_one_model(n_splits, fprs, tprs, interps, aucs, figsize=(10,10), model_name=None):
    """Plot per-fold ROC curves with their mean and a +/- 1 std band.

    Parameters
    ----------
    n_splits : int or None
        Maximum number of fold curves to draw. At most as many curves as are
        present in ``fprs`` / ``tprs`` / ``aucs`` are drawn, so a value larger
        than the number of folds actually run no longer raises IndexError.
        ``None`` draws every available fold.
    fprs, tprs : list of array-like
        ROC curve of each fold.
    interps : list of array-like
        TPR of each fold interpolated on ``np.linspace(0, 1, 100)``; this
        grid must match the one used when the interpolation was computed.
    aucs : list of float
        ROC AUC of each fold.
    figsize : tuple
        Size of the new figure.
    model_name : str or None
        Appended to the figure title when given.
    """
    mean_fpr = np.linspace(0, 1, 100)

    # the cross-validation may have run fewer folds than were requested
    n_available = min(len(fprs), len(tprs), len(aucs))
    n_folds = n_available if n_splits is None else min(n_splits, n_available)

    plt.figure(figsize=figsize)
    for i in range(n_folds):
        plt.plot(
            fprs[i], tprs[i], lw=1, alpha=0.3,
            label='ROC fold %d (AUC = %0.2f)' % (i, aucs[i]))
    plt.plot(
        [0, 1], [0, 1], linestyle='--', lw=2, color='r',
        label='Luck', alpha=.8)

    mean_tpr = np.mean(interps, axis=0)
    mean_tpr[-1] = 1.0  # pin the end of the mean curve to (1, 1)
    mean_auc = auc(mean_fpr, mean_tpr)
    std_auc = np.std(aucs)
    plt.plot(mean_fpr, mean_tpr, color='b',
             label=r'Mean ROC (AUC = %0.2f $\pm$ %0.2f)' % (mean_auc, std_auc),
             lw=2, alpha=.8)

    # variability band, clipped to the valid [0, 1] TPR range
    std_tpr = np.std(interps, axis=0)
    tprs_upper = np.minimum(mean_tpr + std_tpr, 1)
    tprs_lower = np.maximum(mean_tpr - std_tpr, 0)
    plt.fill_between(mean_fpr, tprs_lower, tprs_upper, color='grey', alpha=.2,
                     label=r'$\pm$ 1 std. dev.')

    plt.xlim([-0.05, 1.05])
    plt.ylim([-0.05, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    if model_name is None:
        plt.title('Cross-validation training ROC curves\nwith std for one model')
    else:
        plt.title('Cross-validation training ROC curves\nwith std for model: '+model_name)
    plt.legend(loc="best")
    
def plot_scatter_scores(y_train_scores, y_test_score=None):
    """Scatter the sorted fold accuracies and, optionally, the test accuracy.

    The fold scores are drawn in black, in ascending order, at x = 1..n;
    the test score (when given) is drawn in red at x = 0.
    """
    n_scores = len(y_train_scores)
    fold_positions = np.arange(n_scores) + 1
    scores_ascending = sorted(y_train_scores)

    plt.figure(figsize=(10, 6))
    plt.scatter(fold_positions, scores_ascending, color='black')
    if y_test_score is not None:
        plt.scatter(0, y_test_score, color='red')

    plt.axhline(0, color='k')
    plt.xlim(-1, n_scores+1)
    plt.ylim(-0.5, 1.5)
    plt.xlabel("test and train kfolds")
    plt.ylabel("accuracy scores")


START ANALYSIS

In [None]:
# in debug mode, route all log records (DEBUG level and above) to stdout
if DEBUG:
    logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)

In [None]:
# numpy arrays are required by the plotting helpers, which look up a label
# with a boolean mask: class_labels[class_values == value]
class_labels = np.array(class_labels)
class_values = np.array(class_values)

In [None]:
# properly set file paths
# MainDataDir may be injected from outside the notebook; fall back to
# <cwd>/../../data only when it is undefined or not a usable path value,
# instead of hiding every possible exception behind a bare except.
try:
    os.path.exists(MainDataDir)
except (NameError, TypeError):
    MainDataDir = os.path.join(script_path, '..','..', 'data')
    logger.debug("set MainDataDir:\n"+MainDataDir)

# data input
data_fpath = check_path_integrity(data_fpath, rootDir=MainDataDir, name="data")

# sample info input
sample_info_fpath = check_path_integrity(sample_info_fpath, rootDir=MainDataDir, name="sample_info")

# gene info input
genes_info_fpath = check_path_integrity(genes_info_fpath, rootDir=MainDataDir, name="gene_info")

# data output (force=True additionally passes the path through set_directory)
output_directory = check_path_integrity(output_directory, rootDir=MainDataDir, name="output", force=True)

In [None]:
# fpaths_dict
# maps a feature-set key to the resolved path of the file holding its features
fpaths_dict = {}
fpaths_dict['genepanel'] = check_path_integrity(genepanel_path, rootDir=MainDataDir, name="genepanel features")

# one featsel_results.csv per feature-selection run, keyed by its short name
for _f, _k in zip(_dirs, _key_names):
    fpath = "output/headneck/feature_selection/"+_f+"/featsel_results.csv"
    fpaths_dict[_k] = check_path_integrity(fpath, rootDir=MainDataDir, name=_k+" features")
    

In [None]:
# Number of discrete colour levels: one per integer step between vmin and
# vmax, plus one for zero when the range spans it. It is computed whenever
# the limits are known - not only when the colormap has to be built - because
# the heatmap cells below always pass custom_div_cmap_arg on; it used to be
# undefined (NameError) as soon as cmap_custom was preset.
custom_div_cmap_arg = None
if (vmin is not None) and (vmax is not None):
    custom_div_cmap_arg = abs(vmin)+abs(vmax)
    if (vmin <= 0) and (vmax >= 0):
        custom_div_cmap_arg = custom_div_cmap_arg + 1

# build the diverging colormap only if none was supplied in the parameters
if (cmap_custom is None) and (custom_div_cmap_arg is not None):
    mincol = plot_kwargs.get('mincol', None)
    midcol = plot_kwargs.get('midcol', None)
    maxcol = plot_kwargs.get('maxcol', None)
    if (
            (mincol is not None) and
            (midcol is not None) and
            (maxcol is not None)
            ):
        cmap_custom = custom_div_cmap(
            numcolors=custom_div_cmap_arg,
            mincol=mincol, midcol=midcol, maxcol=maxcol)
    else:
        # incomplete colour triple: use custom_div_cmap's own default colours
        cmap_custom = custom_div_cmap(numcolors=custom_div_cmap_arg)

In [None]:
# load data
# tab-separated table; the first column becomes the index
# (the index is used below as the sample ids - see sample_info.loc[data.index])
data = pd.read_csv(data_fpath, sep='\t', header=0, index_col=0)
logger.info('loaded data file with shape: '+str(data.shape))

In [None]:
# preview the first rows of the loaded data table
data.head()

In [None]:
# load info table of samples
# (load_clinical receives the separator, header row and index column
#  through sample_info_read_csv_kwargs)
sample_info = load_clinical(
    sample_info_fpath, **sample_info_read_csv_kwargs)
logger.info('loaded sample_info file with shape: '+str(sample_info.shape))

In [None]:
# keep only the samples present in the data table, in the data's row order
sample_info = sample_info.loc[data.index,:]
# NOTE(review): the log text below lacks a space between 'sample_info' and 'with'
logger.info('keeping part of sample_infowith shape: '+str(sample_info.shape))

In [None]:
# preview the sample annotation after alignment with the data rows
sample_info.head()

In [None]:
# load info table of genes
# tab-separated table; the first column becomes the index
genes_info = pd.read_csv(genes_info_fpath, sep='\t', header=0, index_col=0)
logger.info('loaded gene_info file with shape: '+str(genes_info.shape))

In [None]:
# preview the gene annotation table
genes_info.head()

In [None]:
# set the ground truth
# class label of every sample, taken from the configured sample_info column
# and ordered like the rows of the data table
ground_truth = sample_info.loc[data.index, sample_class_column]

In [None]:
# preview the class labels
ground_truth.head()

In [None]:
#  Plot Heatmap of data w/ duplicates
# x tick labels/positions, computed by get_chr_ticks from the gene annotation
xlabels, xpos = get_chr_ticks(
    genes_info, data, id_col=gene_id_col, chr_col=chr_col)

plot_data_heatmap(
    data, ground_truth, xlabels, xpos, 
    vmin, vmax, cmap_custom, custom_div_cmap_arg,
    function_dict
)
plt.title('data: '+str(data.shape[1])+' gene profiles')
# written to output_directory when saveReport is True, shown otherwise
save_image(saveReport=saveReport, output_directory=output_directory, img_name="heatmap_data", img_ext=img_ext)

In [None]:
# remove all zero columns!
# positions of the columns whose absolute values sum to zero
orphancols = np.where(abs(data).sum(axis=0) == 0)[0]
if len(orphancols) > 0:
    logger.warning('removing '+str(len(orphancols))+' genes from train data with zero columns!')
    cols2drop = data.columns.values[orphancols]
    # NOTE: rebinds `data`; re-running this cell starts from the reduced table
    data = data.drop(cols2drop, axis=1).copy()

# REMOVE DUPLICATES!!!!
# Returned values as they are used below (semantics of the helper itself are
# presumed from that usage - confirm against omics_processing):
#   data_uniq     table without duplicated columns
#   dupldict      kept gene -> its duplicate genes
#   wo_dupl_set   genes that have no duplicates
#   all_dupl_set  genes involved in any duplication
data_uniq, dupldict, wo_dupl_set, all_dupl_set = remove_andSave_duplicates(
    data, to_compute_euclidean_distances=True,
    to_save_euclidean_distances=True, to_save_output=True,
    output_filename='data_wo_duplicates',
    output_directory=output_directory
)
single_dupl_set = set(dupldict.keys())

# sanity check: kept duplicates + never-duplicated genes must account for
# every column of data_uniq (a mismatch is only reported, not raised)
_countA = len(set.union(single_dupl_set, wo_dupl_set))
_countB = data_uniq.shape[1]
if not _countA == _countB:
    print(
        'ERROR: inconsistencies in the final uniq gene count!\n'+
        str(_countA)+' genes that should be in the uniq dataset VS. '+
        str(_countB)+' genes that are'
    )

In [None]:
#  Plot Heatmap of data w/o duplicates
# tick labels/positions are recomputed because the set of columns changed
xlabels_uniq, xpos_uniq = get_chr_ticks(
    genes_info, data_uniq, id_col=gene_id_col, chr_col=chr_col)

plot_data_heatmap(
    data_uniq, ground_truth, xlabels_uniq, xpos_uniq, 
    vmin, vmax, cmap_custom, custom_div_cmap_arg,
    function_dict
)
plt.title('data w/o duplicates: '+str(data_uniq.shape[1])+' gene profiles')
# written to output_directory when saveReport is True, shown otherwise
save_image(saveReport=saveReport, output_directory=output_directory, img_name="heatmap_data_uniq", img_ext=img_ext)

In [None]:
# split data in train-test ONCE!
# stratify jointly on the class label and on the 'dataset' column of
# sample_info, so both are balanced between the train and the test part
stratify_by = pd.concat(
    [ground_truth, sample_info['dataset']], axis=1, sort=False
).loc[ground_truth.index]

data_train, data_test, y_train, y_test = train_test_split(
    data_uniq,
    ground_truth,
    train_size=split_train_size,
    test_size=None,
    random_state=split_random_state,
    stratify=stratify_by,
)

# histograms of the stratification columns: all samples, train part, test part
for _strat_subset, _strat_title, _strat_img_name in (
    (stratify_by,
     'all '+str(ground_truth.shape[0])+' samples',
     "stratify_by_all"),
    (stratify_by.loc[data_train.index],
     str(y_train.shape[0])+' train samples',
     "stratify_by_train"),
    (stratify_by.loc[data_test.index],
     str(y_test.shape[0])+' test samples',
     "stratify_by_test"),
):
    _strat_subset.hist()
    plt.suptitle(_strat_title, fontsize=16)
    save_image(
        saveReport=saveReport, output_directory=output_directory,
        img_name=_strat_img_name, img_ext=img_ext)


In [None]:
# x tick labels/positions for the train and the test table
xlabels_train, xpos_train = get_chr_ticks(
    genes_info, data_train, id_col=gene_id_col, chr_col=chr_col)
xlabels_test, xpos_test = get_chr_ticks(
    genes_info, data_test, id_col=gene_id_col, chr_col=chr_col)

#  Plot Heatmap of train data, then of test data (both w/o duplicates)
for _hm_data, _hm_truth, _hm_xlabels, _hm_xpos, _hm_title, _hm_img_name in (
    (data_train, y_train, xlabels_train, xpos_train,
     'train data: ', "heatmap_data_train"),
    (data_test, y_test, xlabels_test, xpos_test,
     'test data: ', "heatmap_data_test"),
):
    plot_data_heatmap(
        _hm_data, _hm_truth, _hm_xlabels, _hm_xpos,
        vmin, vmax, cmap_custom, custom_div_cmap_arg,
        function_dict
    )
    plt.title(_hm_title+str(_hm_data.shape[1])+' gene profiles')
    save_image(
        saveReport=saveReport, output_directory=output_directory,
        img_name=_hm_img_name, img_ext=img_ext)

In [None]:
# gene names (column labels) of the data table after the all-zero columns
# were dropped, and of the duplicate-free table
all_data_genes = set(data.columns.values)
all_data_genes_uniq = set(data_uniq.columns.values)

In [None]:
# features_dict: key -> feature-selection table (gene panel excluded)
# features_sets: key -> set of gene names that make up the feature set
features_dict, features_sets = {}, {}
for key in fpaths_dict:
    df = pd.read_csv(fpaths_dict[key], sep='\t', header=0, index_col=0)

    is_genepanel = 'genepanel' in key
    if is_genepanel:
        # the gene panel lists its genes as column names
        features_sets[key] = set(df.columns.values)
    else:
        features_dict[key] = df
        features_sets[key] = extract_gene_set(df)

    n_total = len(features_sets[key])
    # one table row per unique feature, except for the gene panel
    n_unique = n_total if is_genepanel else df.shape[0]

    logger.info(str(n_unique)+' unique out of '+str(n_total)+' total features from '+key)

Venn diagrams to explain the functionality of the cell below:<br>
U --> the unique data genes (genes without duplicates + one kept copy of each duplicated gene)<br>
D --> the remaining copies of the duplicated genes<br>
fs --> a single features set<br>
IU --> the features that exist in the U set<br>
ID --> the features that exist in the D set<br>
NI --> the features that do NOT exist in either set<br>
IDa --> the features that exist in the ID set and are not represented in the IU set<br>
IDb --> the features that exist in the ID set and are already represented in the IU set<br>
_ID = IDa + IDb_<br>
U_IDa --> the features that exist in the U set (but not in the IU set) and represent the IDa features<br>
**fs in data_uniq = IU + U_IDa**<br>
<img src="./files/venn_legend.jpg?1" alt="drawing" style="float:left" width="300px"/>

In [None]:
# Map every feature set onto the columns of the duplicate-free data table.
# U: all_data_genes_uniq
# D: all_dupl_set.difference(all_data_genes_uniq)
# fs: features_sets[key]
# Result: new_features_sets[key] = IU + U_IDa (only names that are in U)
new_features_sets = {}

U_set = all_data_genes_uniq
D_set = all_dupl_set.difference(all_data_genes_uniq)
for key in features_sets:
    fs = features_sets[key]
    _fs_original_size = len(fs)
    print(key+' feature set :')
    print('--- originally ---')
    print(' original total size: '+str(_fs_original_size))
    # IU: features that are columns of the duplicate-free table
    IU_set = fs.intersection(U_set)
    _IU_size = len(IU_set)
    print(' IU_set: '+str(_IU_size))

    # ID: features that are only present as a dropped duplicate copy
    ID_set = (fs.difference(IU_set)).intersection(D_set)
    _ID_size = len(ID_set)
    print(' ID_set: '+str(_ID_size))

    # NI: features found in neither U nor D
    NI_set = (fs.difference(IU_set)).difference(ID_set)
    _NI_size =len(NI_set)
    print(' NI_set: '+str(_NI_size))

    U_IDa_set = set()
    IDa_set = set()
    IDb_set = set()
    # NOTE(review): IDc_set is never filled or read in this cell
    IDc_set = set()
    # NOTE(review): `done` is set after the first pass, so the while body
    # runs at most once (it behaves like `if temp_set:`)
    done = False
    temp_set = ID_set.copy()
    while temp_set and not done:
        # ud: kept gene, dl: its dropped duplicates (dupldict key -> value)
        for ud, dl in dupldict.items():
            _IDx = set(dl).intersection(temp_set)
            if _IDx:
                # these duplicates are now accounted for
                temp_set = temp_set.difference(_IDx)
                if ud not in IU_set:
                    # representative not selected yet: add it on their behalf
                    U_IDa_set.add(ud)
                    IDa_set.update(_IDx)
                else:
                    # representative already part of the feature set
                    IDb_set.update(_IDx) 
        done = True

    fs_in_data_uniq = set.union(IU_set, U_IDa_set)
    _fs_in_data_uniq_size = len(fs_in_data_uniq)

    ########################################
    new_features_sets[key] = fs_in_data_uniq
    ########################################
    
    print('--- finally ---')
    print(' features in data uniq : '+str(_fs_in_data_uniq_size))
    print(' IU_set: '+str(_IU_size))
    print('   U_IDa_set: '+str(len(U_IDa_set)))
    print(' ID_set: '+str(len(ID_set)))
    print('   IDa_set: '+str(len(IDa_set)))
    print('   IDb_set: '+str(len(IDb_set)))
    print(' NI_set: '+str(_NI_size))

    # consistency check: every ID feature must be classified as IDa or IDb;
    # on failure the remaining feature sets are not processed
    if ID_set != set.union(IDa_set, IDb_set):
        print(
            'ERROR: something went wrong, ID_set != IDa_set + IDb_set, for feature set: '+key
        )
        break

    print('\n')

In [None]:
# Train, cross-validate and test one model per feature set.
# fs_fprs / fs_tprs / fs_aucs collect the TEST ROC of every feature set, in
# the iteration order of new_features_sets.
fs_fprs = []
fs_tprs = []
fs_aucs = []
for key in new_features_sets:
    # Sets of strings iterate in a different order in every kernel session;
    # sorting fixes the column order so that a re-run sees identical input.
    fset = sorted(new_features_sets[key], key=str)
    X_train = data_train.loc[:,fset].copy()
    X_test = data_test.loc[:,fset].copy()

    # train model
    model, all_coefs, y_train_predictions, y_train_scores, fprs, tprs, interps, aucs = \
        _run_classification(
            X_train, y_train, **classification_args)

    # Every figure name carries the feature-set key; otherwise the files of
    # one feature set overwrite those of the previous one when saveReport is on.

    # plot_prediction_counts_per_class
    plot_prediction_counts_per_class(
        y_train, y_train_predictions, class_labels=class_labels, class_values=class_values)
    save_image(
        saveReport=saveReport, output_directory=output_directory,
        img_name="count_predictions_per_class_"+key, img_ext=img_ext)

    # compute_and_plot_confusion_matrices
    plt1, plt2 = compute_and_plot_confusion_matrices(
        y_train, y_train_predictions, class_labels=class_labels, class_values=class_values)
    save_image(
        saveReport=saveReport, output_directory=output_directory,
        img_name="confusion_matrix_"+key, img_ext=img_ext, plt_obj=plt1)
    save_image(
        saveReport=saveReport, output_directory=output_directory,
        img_name="confusion_matrix_normalized_"+key, img_ext=img_ext, plt_obj=plt2)

    # plot_roc_with_std_for_one_model
    # _run_classification may run fewer folds than configured (it caps them
    # at the size of the smallest class), so count the folds it returned.
    n_splits = len(fprs)
    plot_roc_with_std_for_one_model(n_splits, fprs, tprs, interps, aucs, figsize=(10,10), model_name=key)
    save_image(
        saveReport=saveReport, output_directory=output_directory,
        img_name="train_crossval_roc_curves_"+key, img_ext=img_ext)

    # Test the model
    y_test_score = model.score(X_test, y_test)
    y_test_predictions = model.predict(X_test)
    y_test_predictions = pd.Series(y_test_predictions, index=X_test.index)
    y_test_predictions.name = 'test_predictions'

    plot_scatter_scores(y_train_scores, y_test_score)
    save_image(
        saveReport=saveReport, output_directory=output_directory,
        img_name="scatter_scores_"+key, img_ext=img_ext)

    #  prepare for the ROC curves on each feature set
    # The ROC is computed from the raw decision values. Fitting a probability
    # calibrator on (X_test, y_test) first, as was done before, uses the test
    # labels themselves and can flip the ranking, inflating the test AUC.
    y_test_decision = model.decision_function(X_test)
    # Compute ROC curve and area the curve
    fpr, tpr, thresholds = roc_curve(y_test, y_test_decision)
    fs_fprs.append(fpr)
    fs_tprs.append(tpr)
    roc_auc = auc(fpr, tpr)
    fs_aucs.append(roc_auc)
    

In [None]:
# compare the TEST ROC curves of all feature sets in a single figure;
# the i-th key of new_features_sets labels the i-th collected curve
plot_roc_for_many_models(list(new_features_sets.keys()), fs_fprs, fs_tprs, fs_aucs, figsize=(10,10))
save_image(
    saveReport=saveReport, output_directory=output_directory, 
    img_name="test_all_models_roc_curves", img_ext=img_ext)
