## Run_entire_model

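Wrapper that chains the functions loaded from the helper notebooks (joining/normalization, feature reduction, modeling) to run the whole classification process end to end.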

In [8]:
import pandas as pd
import os
import csv
import math
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from scipy import stats

#Load other functions
%run feature_reductions.ipynb
%run join_and_normalize.ipynb
%run modeling.ipynb

In [7]:
# Label name is 'Group' for the HIV dataset

def run_entire_model(osu_files, 
                     tax_files, 
                     meta_file, 
                     norm_type,
                     ncomp,
                     levelup, 
                     groups,
                     feat_reduction, 
                     test_frac, 
                     model,
                     plot_pca,
                     plot_lc,
                     plot_ab_comp,
                     cutoff,
                     scorer):
    '''
    Wrapper script to run the HIV Classifier!

    Loads and normalizes the OSU tables, joins them with the labels, selects
    the two comparison groups, optionally levels up the taxonomy, reduces the
    features, splits into train/test sets and optimizes the chosen model(s).
    Optional plots: PCA scatter, abundance comparison, learning curve.

    Inputs: `osu_files` - list of locations of osu_abundances.txt files
            `tax_files` - list of locations of osu_taxonomy.txt files
            `meta_file` - str of location of meta data file
            `norm_type` - type of normalization: 'clr', 'css', 'tss'
            `ncomp` - integer number of components
            `levelup` - taxonomy level to group data: None, 'genus', 'order','class', 'family', 'phylum'
            `groups` - List of two groups: e.g. ['NEG','RHI']
            `feat_reduction` - Feat reduction method: None, 'svd', 'zscore', 'corr', 'diff'
            `test_frac` - Test fraction (e.g. 0.25); also used as the validation
                          fraction of the learning-curve shuffle split
            `model` - type of model: 'lg', 'rf', 'xgb' (legacy spelling 'xg' is
                      also accepted), 'all'. Any other value skips the
                      classifier optimization with a printed notice.
            `plot_pca` - Bool of whether to plot PCA
            `plot_lc` - Bool to plot learning curve 
            `plot_ab_comp` - Bool to plot abundance comparison
            `cutoff` - cutoff level (e.g. 0.001)
            `scorer` - type of score: precision_score, recall_score, f1_score, fbeta_score, roc_auc_score
    Returns: None. Output is produced as side effects (prints and plots); any
             performance report comes from the `opt_*` helpers themselves.
    Raises: ValueError if `feat_reduction` is not one of the supported methods.
    '''

    # Fail fast on an unsupported feature-reduction method, before any file is
    # read. Without this check no branch below would define X/Y/labels and the
    # function would die later with an unbound-variable error.
    valid_reductions = (None, 'svd', 'zscore', 'corr', 'diff')
    if feat_reduction not in valid_reductions:
        raise ValueError(
            "Unknown feat_reduction {!r}; expected one of {}".format(
                feat_reduction, valid_reductions))

    #Make dataframe of OSU data and normalize according to <norm_type>
    osu_df = join_osus(osu_files, norm_type)

    #Get label meta data
    meta_df = get_labels(meta_file)

    #Join OSU data with meta data and load taxonomy data
    osu_df = join_osu_with_labels(osu_df, meta_df)
    tax_df = join_taxonomy(tax_files)

    #Select only the comparison <groups>
    pair_df = select_groups(osu_df, groups)

    #If selected, level up the taxonomy information
    if levelup is not None:
        pair_df = make_df_up_level(pair_df, tax_df, levelup, norm_type)

    #Use feature reduction strategy
    # `feats` stays None for 'svd', which returns no feature list.
    feats = None
    if feat_reduction == 'svd':
        X, Y, labels = SVD_truncate(pair_df, ncomp, cutoff)
    elif feat_reduction == 'zscore':
        plot_cutoff = 0.4
        X, Y, labels, feats = make_dataset_zscore(pair_df, ncomp, cutoff, plot_cutoff, norm_type)
    elif feat_reduction == 'corr':
        X, Y, labels, feats = feature_from_correlation(pair_df, ncomp, cutoff, norm_type)
    elif feat_reduction == 'diff':
        X, Y, labels, feats = make_dataset_osu_diff(pair_df, ncomp, cutoff)
    else:
        # feat_reduction is None: use the full dataset
        X, Y, labels, feats = make_dataset(pair_df, cutoff)

    print("Comparing the following groups:", labels)

    #Print the top features if feature reduction is chosen
    # Identity check against None: an equality test would be evaluated
    # element-wise (and fail in `if`) should feats be an array-like.
    if feats is not None and len(feats) < 50:
        print("Top features:", feats)

    #Plot PCA data 
    if plot_pca:
        if feat_reduction == 'svd':
            #PCA does not need to be done if SVD was already done
            X1 = X[:, 0]
            X2 = X[:, 1]
        else:
            X_PCA = PCA(n_components=2, random_state=42).fit_transform(np.array(X))
            X1 = X_PCA[:, 0]
            X2 = X_PCA[:, 1]
        # Samples labelled 0 are drawn red, all others black
        colors = np.where(np.array(Y) == 0, 'r', 'k')
        plt.figure(figsize=(6, 6))
        plt.scatter(X1, X2, c=colors)
        plt.show()

    #Split into training and test
    seed = 30
    X_train, X_test, y_train, y_test = train_test_split(X, Y,
                                                        stratify=Y,
                                                        test_size=test_frac,
                                                        random_state=seed)

    #Optimize classifier(s)
    banner = '-' * 50
    if model == 'lg':
        print(banner)
        print('Logistic Regression')
        opt_log_reg(X_train, y_train, X_test, y_test, labels, scorer)
    elif model == 'rf':
        print(banner)
        print('Random Forest')
        opt_random_forest(X_train, y_train, X_test, y_test, labels, scorer)
    elif model in ('xg', 'xgb'):
        # 'xgb' is the documented name; 'xg' is kept for existing callers
        print(banner)
        print('XG Boost')
        opt_xgboost(X_train, y_train, X_test, y_test, labels, scorer)
    elif model == 'all':
        print(banner)
        print('Logistic Regression')
        opt_log_reg(X_train, y_train, X_test, y_test, labels, scorer)
        print(banner)
        print('Random Forest')
        opt_random_forest(X_train, y_train, X_test, y_test, labels, scorer)
        print(banner)
        print('XG Boost')
        opt_xgboost(X_train, y_train, X_test, y_test, labels, scorer)
    else:
        # Stay non-fatal (plots below may still be wanted) but do not skip silently
        print("Unrecognized model:", model, "- skipping classifier optimization.")

    #Plot comparison of features
    if plot_ab_comp:
        if feats is None:
            print("Plotting abundance only applicable when zscore, diff, corr feature reduction used.")
        elif len(feats) > 20:
            print("Too many features to plot effectively.")
        else:
            abundance_comparison(pair_df, feats, norm_type)

    #Plot learning curve
    if plot_lc:
        title = "Learning Curve for Logistic Regression"
        # Cross validation with 100 iterations to get smoother mean test and train
        # score curves, each time with a <test_frac> share of the data randomly
        # selected as a validation set.
        cv = ShuffleSplit(n_splits=100, test_size=test_frac, random_state=0)

        #get classweights
        cw = get_class_weight(Y)

        #Use Logistic Regression 
        estimator = LogisticRegression(random_state=42, solver='liblinear', class_weight=cw)
        plot_learning_curve(estimator, title, X, Y, ylim=(0.0, 1.01), cv=cv, n_jobs=10)

        plt.show()
