# Import libraries

In [1]:
import pandas as pd
import numpy as np
import os
import config
from sklearn.model_selection import KFold
from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import cross_val_predict
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import balanced_accuracy_score, accuracy_score
from utils import confusion
from sklearn.metrics import confusion_matrix
from tqdm import tqdm

# Load data

In [2]:
# Both tables use their first two columns as the row index
# (shown as Study_ID / Sample Accession in the previews below).
features_csv = os.path.join(config.CLEAN_DIR, "taxonomic_features.csv")
metadata_csv = os.path.join(config.CLEAN_DIR, "metadata.csv")

X = pd.read_csv(features_csv, index_col=[0, 1])

# Only the first metadata column is kept as the label; the list-style
# selector keeps it as a one-column DataFrame rather than a Series.
metadata = pd.read_csv(metadata_csv, index_col=[0, 1])
y = metadata.iloc[:, [0]]

display(X.head(), y.head())

Unnamed: 0_level_0,Unnamed: 1_level_0,k__Archaea,k__Archaea|p__Euryarchaeota,k__Archaea|p__Euryarchaeota|c__Methanobacteria,k__Archaea|p__Euryarchaeota|c__Methanobacteria|o__Methanobacteriales,k__Archaea|p__Euryarchaeota|c__Methanobacteria|o__Methanobacteriales|f__Methanobacteriaceae,k__Archaea|p__Euryarchaeota|c__Methanobacteria|o__Methanobacteriales|f__Methanobacteriaceae|g__Methanobrevibacter,k__Archaea|p__Euryarchaeota|c__Methanobacteria|o__Methanobacteriales|f__Methanobacteriaceae|g__Methanobrevibacter|s__Methanobrevibacter_smithii,k__Archaea|p__Euryarchaeota|c__Methanobacteria|o__Methanobacteriales|f__Methanobacteriaceae|g__Methanosphaera,k__Archaea|p__Euryarchaeota|c__Methanobacteria|o__Methanobacteriales|f__Methanobacteriaceae|g__Methanosphaera|s__Methanosphaera_stadtmanae,k__Archaea|p__Euryarchaeota|c__Thermoplasmata,...,k__Viruses|p__Viruses_unclassified|c__Viruses_unclassified|o__Viruses_unclassified|f__Virgaviridae|g__Hordeivirus|s__Barley_stripe_mosaic_virus,k__Viruses|p__Viruses_unclassified|c__Viruses_unclassified|o__Viruses_unclassified|f__Virgaviridae|g__Tobamovirus,k__Viruses|p__Viruses_unclassified|c__Viruses_unclassified|o__Viruses_unclassified|f__Virgaviridae|g__Tobamovirus|s__Cactus_mild_mottle_virus,k__Viruses|p__Viruses_unclassified|c__Viruses_unclassified|o__Viruses_unclassified|f__Virgaviridae|g__Tobamovirus|s__Cucumber_green_mottle_mosaic_virus,k__Viruses|p__Viruses_unclassified|c__Viruses_unclassified|o__Viruses_unclassified|f__Virgaviridae|g__Tobamovirus|s__Paprika_mild_mottle_virus,k__Viruses|p__Viruses_unclassified|c__Viruses_unclassified|o__Viruses_unclassified|f__Virgaviridae|g__Tobamovirus|s__Pepper_mild_mottle_virus,k__Viruses|p__Viruses_unclassified|c__Viruses_unclassified|o__Viruses_unclassified|f__Virgaviridae|g__Tobamovirus|s__Tobacco_mild_green_mosaic_virus,k__Viruses|p__Viruses_unclassified|c__Viruses_unclassified|o__Viruses_unclassified|f__Viruses_unclassified|g__Viruses_unclassified|s__Deep_sea_thermophilic_phage_D6E,k__Vir
uses|p__Viruses_unclassified|c__Viruses_unclassified|o__Viruses_unclassified|f__Viruses_unclassified|g__Viruses_unclassified|s__Loktanella_phage_pCB2051_A,k__Viruses|p__Viruses_unclassified|c__Viruses_unclassified|o__Viruses_unclassified|f__Viruses_unclassified|g__Viruses_unclassified|s__Tetraselmis_viridis_virus_S1
Study_ID,Sample Accession,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1
GMHI-10,SAMN03283239,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
GMHI-10,SAMN03283266,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
GMHI-10,SAMN03283281,0.009764,0.009764,0.009764,0.009764,0.009764,0.009764,0.009764,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
GMHI-10,SAMN03283294,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
GMHI-10,SAMN03283288,0.011865,0.011865,0.011865,0.011865,0.011865,0.011865,0.011865,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


Unnamed: 0_level_0,Unnamed: 1_level_0,PHENOTYPE:Healthy_Nonhealthy
Study_ID,Sample Accession,Unnamed: 2_level_1
GMHI-10,SAMN03283239,True
GMHI-10,SAMN03283266,True
GMHI-10,SAMN03283281,True
GMHI-10,SAMN03283294,True
GMHI-10,SAMN03283288,True


# Generate predictions with 10-fold cross-validation (repeated 10 times)

In [3]:
# Seed NumPy's global RNG. The cross-validation cell draws each repeat's split
# seed with np.random.randint, so this cell must run right before it for the
# folds to be repeatable.
np.random.seed(42)

In [4]:
num_repeats = 10
n_splits = 10

# Neither the thresholded feature matrix nor the flattened labels depend on the
# repeat, so build them once instead of on every iteration.
X_presence = X > config.PRESENCE_CUTOFF
y_flat = y.values.flatten()

# One single-column DataFrame of out-of-fold scores per repeat.
logit_prediction_list = []

for _ in tqdm(range(num_repeats)):

    # Each repeat uses a differently shuffled stratified split. The split seed is
    # drawn from NumPy's global RNG, so the folds are repeatable only if
    # np.random.seed was called immediately before this cell.
    kfold = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=np.random.randint(0, high=1000000))

    # instantiate the L1-penalized logistic regression classifier (gmhi2)
    gmhi2 = LogisticRegression(random_state=42, penalty="l1", solver="liblinear", C=config.REGULARIZATION, class_weight="balanced")

    # Out-of-fold decision-function values: each sample is scored by a model
    # fitted on the folds that do not contain it.
    scores = cross_val_predict(
        gmhi2, X_presence, y_flat, method="decision_function", cv=kfold, verbose=0, n_jobs=-1
    )

    # Re-attach the sample index so the scores line up with y row for row.
    GMHI2_scores_cv = pd.DataFrame(scores, index=y.index, columns=["GMHI2_cv"])
    logit_prediction_list.append(GMHI2_scores_cv)

100%|███████████████████████████████████████████████████████████████| 10/10 [00:17<00:00,  1.75s/it]


In [5]:
# Average the per-repeat tables returned by utils.confusion at several cutoffs.
# NOTE(review): the cutoff is presumably a minimum |score| a sample needs to be
# counted -- confirm against utils.confusion.
for cutoff in [0, 0.1, 0.5, 1]:
    print("cutoff:", cutoff)
    per_repeat = [confusion(scores, y, cutoff) for scores in logit_prediction_list]
    # Divide by the number of repeats actually present instead of a hard-coded 10,
    # so the average stays correct when num_repeats is changed.
    confusions = sum(per_repeat) / len(per_repeat)
    display(confusions)

cutoff: 0


Unnamed: 0,Predicted Healthy,Predicted Nonhealthy,Accuracy
Actual Nonhealthy,525.2,1996.8,0.791753
Actual healthy,4356.5,1190.5,0.785379


cutoff: 0.1


Unnamed: 0,Predicted Healthy,Predicted Nonhealthy,Accuracy
Actual Nonhealthy,468.3,1936.9,0.805298
Actual healthy,4230.1,1071.6,0.797876


cutoff: 0.5


Unnamed: 0,Predicted Healthy,Predicted Nonhealthy,Accuracy
Actual Nonhealthy,289.7,1636.5,0.849606
Actual healthy,3729.1,676.5,0.846445


cutoff: 1


Unnamed: 0,Predicted Healthy,Predicted Nonhealthy,Accuracy
Actual Nonhealthy,129.6,1217.2,0.903774
Actual healthy,2964.8,334.7,0.898561


In [6]:
def evaluate_performance(cutoff):
    """Summarize cross-validated performance on samples scored at or above a cutoff.

    For every repeat in ``logit_prediction_list``, samples whose absolute score is
    below ``cutoff`` are dropped. The remaining samples are predicted positive when
    their score is greater than 0 and are compared against ``y``.

    Prints the balanced accuracy averaged over repeats, displays the confusion
    matrix summed over repeats, and prints how many sample evaluations were kept,
    both as a count and as a percentage of all evaluations (samples x repeats).

    Parameters
    ----------
    cutoff : float
        Minimum absolute score a sample needs in order to be evaluated.
        With ``0`` every sample with a non-NaN score is kept.

    Returns
    -------
    None
    """
    # NOTE(review): assumes y holds booleans with True = healthy, as the displayed
    # head of y suggests -- confirm against metadata.csv.
    y_true_all = y.values.ravel()
    n_repeats = len(logit_prediction_list)

    balanced_accuracies = []
    mat_sum = np.zeros([2, 2])

    for lo in logit_prediction_list:
        scores = lo.values.ravel()
        # Flat 1-D mask, computed once per repeat and shared by both metrics.
        keep = abs(scores) >= cutoff
        y_curr = y_true_all[keep]
        pred_curr = scores[keep] > 0

        balanced_accuracies.append(balanced_accuracy_score(y_curr, pred_curr))
        # Pin the label order so the result is always 2x2 (rows/cols: False, True).
        # Without it, a subset containing a single class yields a 1x1 matrix that
        # would silently broadcast into every cell of the accumulator.
        mat_sum += confusion_matrix(y_curr, pred_curr, labels=[False, True])

    mean_acc = np.mean(balanced_accuracies)
    print("Mean acc:", mean_acc)

    df = pd.DataFrame(mat_sum, columns=["Predicted Nonhealthy", "Predicted Healthy"], index=["Actual Nonhealthy", "Actual healthy"])
    display(df)
    print("Total samples evaluated:", mat_sum.sum())
    # The denominator is samples x repeats; deriving it from the list length keeps
    # the percentage correct for any number of repeats, not only 10.
    print("Percentage of possible:", mat_sum.sum() / (X.shape[0] * n_repeats) * 100)

In [7]:
# Baseline: with a cutoff of 0 the |score| >= cutoff filter keeps every scored sample.
evaluate_performance(0)

Mean acc: 0.7885660308627864


Unnamed: 0,Predicted Nonhealthy,Predicted Healthy
Actual Nonhealthy,19968.0,5252.0
Actual healthy,11905.0,43565.0


Total samples evaluated: 80690.0
Percentage of possible: 100.0


In [8]:
# Evaluate only the samples whose absolute score is at least 0.1.
evaluate_performance(0.1)

Mean acc: 0.8015873009718006


Unnamed: 0,Predicted Nonhealthy,Predicted Healthy
Actual Nonhealthy,19369.0,4683.0
Actual healthy,10716.0,42301.0


Total samples evaluated: 77069.0
Percentage of possible: 95.51245507497832


In [9]:
# Evaluate only the samples whose absolute score is at least 0.5.
evaluate_performance(0.5)

Mean acc: 0.848025561304277


Unnamed: 0,Predicted Nonhealthy,Predicted Healthy
Actual Nonhealthy,16365.0,2897.0
Actual healthy,6765.0,37291.0


Total samples evaluated: 63318.0
Percentage of possible: 78.47069029619531


In [10]:
# Evaluate only the samples whose absolute score is at least 1.
evaluate_performance(1)

Mean acc: 0.9011677832786118


Unnamed: 0,Predicted Nonhealthy,Predicted Healthy
Actual Nonhealthy,12172.0,1296.0
Actual healthy,3347.0,29648.0


Total samples evaluated: 46463.0
Percentage of possible: 57.58210434998141


In [13]:
# Only the first repeat's out-of-fold scores are written to disk;
# the remaining repeats are not saved by this cell.
scores_path = os.path.join(config.LOG_DIR, "GMHI2_scores_10-fold.csv")
first_repeat_scores = logit_prediction_list[0]
first_repeat_scores.to_csv(scores_path)