In [None]:
import os
import math
import pandas as pd
import numpy as np
import scanpy as sc
import matplotlib.pyplot as plt
from sklearn import metrics
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, confusion_matrix, roc_curve, auc
from sklearn.metrics import ConfusionMatrixDisplay, RocCurveDisplay
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import KFold, cross_val_score, cross_validate

In [None]:
def class_label(df):
    """Derive binary class labels from the sample names in the first column.

    A sample name containing 'H' is labelled 0; one containing 'R' is labelled 1.

    Parameters
    ----------
    df : pandas.DataFrame
        Table whose first column (``df.iloc[:, 0]``) holds string sample names.

    Returns
    -------
    numpy.ndarray
        Exactly one label per row of ``df``, in row order.

    Raises
    ------
    ValueError
        If a sample name contains neither or both of 'H' and 'R'. Such a row
        cannot be classified, and skipping it or labelling it twice would
        misalign the labels with the feature rows.
    """
    classlb = []
    for name in df.iloc[:, 0]:
        has_h = 'H' in name
        has_r = 'R' in name
        # Neither or both markers -> ambiguous; refuse rather than misalign.
        if has_h == has_r:
            raise ValueError(
                f"Cannot infer class from sample name {name!r}: "
                "expected exactly one of 'H' (healthy) or 'R'."
            )
        classlb.append(1 if has_r else 0)

    return np.array(classlb)

In [None]:
def ns_assign(df):
    """Split the nanosensor signal columns of ``df`` into model inputs.

    Columns 1-5 (by position) are taken as the five individual nanosensor
    readouts; column 0 is not used here.

    Returns
    -------
    tuple of numpy.ndarray
        ``(ns1, ns2, ns3, ns4, ns5, ns_all)``: each ``nsK`` is positional
        column K reshaped to a single-feature ``(n_rows, 1)`` array, and
        ``ns_all`` is columns 1-5 together as one 2-D array.
    """
    # Single-sensor matrices, one per column 1..5, each shaped (n_rows, 1).
    single_sensors = tuple(
        df.iloc[:, col].to_numpy().reshape(-1, 1) for col in range(1, 6)
    )
    # All five sensors side by side for the multiplex model.
    combined = df.iloc[:, 1:6].to_numpy()

    return single_sensors + (combined,)

In [None]:
def clf_fit(df1, df2, clf):
    """Train ``clf`` on each nanosensor feature set and compute test ROC curves.

    The six feature sets from ns_assign (five single sensors, then all five
    combined) are used in order. For each one, ``clf`` is refit on ``df1`` and
    scored on ``df2`` with the positive-class probability.

    Parameters
    ----------
    df1 : pandas.DataFrame
        Training table (labels via class_label, features via ns_assign).
    df2 : pandas.DataFrame
        Test table in the same layout.
    clf : estimator with ``fit`` and ``predict_proba``
        Refit in place for every feature set. On return it is left fitted on
        the combined (multiplex) feature set.

    Returns
    -------
    tuple
        ``(fpr1, tpr1, fpr2, tpr2, fpr3, tpr3, fpr4, tpr4, fpr5, tpr5,
        fpr_all, tpr_all)`` as produced by ``roc_curve``.
    """
    train_sets = ns_assign(df1)
    y_train = class_label(df1)
    test_sets = ns_assign(df2)
    y_test = class_label(df2)

    curves = []
    # Order matters: the combined set is fitted last, so clf ends up as the
    # multiplex model.
    for X_tr, X_te in zip(train_sets, test_sets):
        clf.fit(X_tr, y_train)
        scores = clf.predict_proba(X_te)[:, 1]
        fpr, tpr, _ = roc_curve(y_test, scores)
        curves.extend((fpr, tpr))

    return tuple(curves)

In [None]:
def plot_roc_allns(df1, df2, clf):
    """Plot test-set ROC curves for each single nanosensor and for the multiplex model.

    Parameters
    ----------
    df1 : pandas.DataFrame
        Training table, passed to clf_fit.
    df2 : pandas.DataFrame
        Test table, passed to clf_fit.
    clf : estimator
        Refit by clf_fit on every feature set; afterwards it is left fitted on
        the combined (multiplex) features.

    Each legend entry's curve, name and AUC are taken from one row of the
    ``sensors`` table below. This guarantees they refer to the same nanosensor
    column.
    """
    pr = clf_fit(df1, df2, clf)

    # (nanosensor column, legend name, colour), in legend order.
    # Column c's ROC curve is pr[2*(c-1)] (fpr) and pr[2*(c-1)+1] (tpr).
    sensors = (
        (1, 'S70-HFA1', 'red'),          # vABN1: BV01 (S70)-HFA1 (col1)
        (4, 'S72-d7isop', 'gold'),       # vABN4: PP03 (S72)-d7isop (col4)
        (2, 'S26-d3but', 'green'),       # vABN5: OCW32 (S26) furin-d3but (col2)
        (3, 'S8-HFA3', 'saddlebrown'),   # vABN2: S8-HFA3 (col3)
        (5, 'S108-d5eth', 'blue'),       # vABN3: S108-d5eth (col5)
    )

    plt.figure(figsize=(3, 3), dpi=160)
    plt.rc('font', family='Arial')

    for col, name, color in sensors:
        fpr, tpr = pr[2 * (col - 1)], pr[2 * (col - 1) + 1]
        plt.plot(fpr, tpr, color=color, lw=1.5,
                 label='%s (AUC: %0.2f)' % (name, auc(fpr, tpr)))

    fpr_all, tpr_all = pr[10], pr[11]
    plt.plot(fpr_all, tpr_all, color='black', lw=1.5,
             label='Multiplex (AUC = %0.2f)' % auc(fpr_all, tpr_all))

    plt.plot([0, 1], [0, 1], 'k--', alpha=0.2, lw=1.5, label='Random classifier')

    plt.xticks(fontsize=15)
    plt.yticks(fontsize=15)
    plt.xlabel('1 - Specificity', fontsize=15)
    plt.ylabel('Sensitivity', fontsize=15)
    plt.legend(bbox_to_anchor=(1.05, 0.9), loc='upper left', edgecolor="None", fontsize=13)

    plt.show()

In [None]:
def plot_roc_multiplex(df1, df2, clf):
    """Plot the test-set ROC curve of the multiplex (all-nanosensor) classifier.

    Only the combined feature set (ns_assign(...)[5]) is fitted. The five
    single-sensor models that clf_fit would also train are never used here,
    so they are skipped.

    Parameters
    ----------
    df1 : pandas.DataFrame
        Training table (labels via class_label, features via ns_assign).
    df2 : pandas.DataFrame
        Test table in the same layout.
    clf : estimator with ``fit`` and ``predict_proba``
        Refit in place and left fitted on the combined features.
    """
    X_train = ns_assign(df1)[5]
    y_train = class_label(df1)
    X_test = ns_assign(df2)[5]
    y_test = class_label(df2)

    clf.fit(X_train, y_train)
    fpr, tpr, _ = roc_curve(y_test, clf.predict_proba(X_test)[:, 1])
    roc_auc = auc(fpr, tpr)

    plt.figure(figsize=(3, 3), dpi=160)
    plt.rc('font', family='Arial')

    plt.plot(fpr, tpr, color='black', lw=1.5, label='PR8 (AUC: %0.2f)' % roc_auc)
    plt.plot([0, 1], [0, 1], 'k--', lw=1, alpha=0.3, label='Random classifier')

    plt.xticks(fontsize=11)
    plt.yticks(fontsize=11)
    plt.xlabel('1 - Specificity', fontsize=13, labelpad=7)
    plt.ylabel('Sensitivity', fontsize=13, labelpad=7)

    plt.legend(loc='best', facecolor="None", edgecolor="None", fontsize=10)

    plt.show()

In [None]:
def plot_cm(df2, classifier=None):
    """Print accuracy and draw the confusion matrix of a fitted classifier on ``df2``.

    Parameters
    ----------
    df2 : pandas.DataFrame
        Test table. Labels come from class_label; features are the combined
        nanosensor columns, ns_assign(df2)[5].
    classifier : fitted estimator, optional
        Model used for prediction. When omitted, the notebook-global ``clf`` is
        used, which is the original behaviour. That global must already be
        fitted on the combined feature set, e.g. by a previous clf_fit or
        plot_roc_* call.
    """
    # Explicit argument preferred; fall back to the hidden global for old calls.
    model = clf if classifier is None else classifier

    y_test = class_label(df2)
    y_pred = model.predict(ns_assign(df2)[5])
    print("ACCURACY OF THE MODEL: ", metrics.accuracy_score(y_test, y_pred))

    fig, ax = plt.subplots(figsize=(3, 3), dpi=160)
    plt.rc('font', family='Arial')
    ConfusionMatrixDisplay.from_predictions(
        y_test, y_pred, display_labels=['Healthy', 'PR8'], cmap=plt.cm.Greys, ax=ax)
    # NOTE: this mutates the global matplotlib rcParams, so it also changes
    # figures created later in the notebook.
    plt.rcParams.update({'font.size': 11})

    label_font = {'size': '12'}  # Adjust to fit
    ax.set_xlabel('Predicted labels', fontdict=label_font, labelpad=5.0)
    ax.set_ylabel('True labels', fontdict=label_font, labelpad=5.0)

In [None]:
#Cross validation by k-fold
def cross_val(classifier, X, y, n_splits, random_state=42):
    """Plot per-fold and mean ROC curves from shuffled k-fold cross-validation.

    Parameters
    ----------
    classifier : estimator
        Refit in place on every fold's training split.
    X : numpy.ndarray
        Nanosensor signals from the training set (indexed with integer arrays).
    y : numpy.ndarray
        Class labels (0/1) from the training set, one per row of ``X``.
    n_splits : int
        Number of KFold splits.
    random_state : int, default 42
        Seed for the KFold shuffle.

    Folds whose held-out split contains only one class, or whose AUC is NaN,
    are left out of the mean curve.

    Raises
    ------
    ValueError
        If no fold yields a usable ROC curve, so no mean can be computed.
    """
    cv = KFold(n_splits=n_splits, random_state=random_state, shuffle=True)

    tprs = []
    aucs = []
    mean_fpr = np.linspace(0, 1, 100)

    plt.rc('font', family='Arial')

    fig, ax = plt.subplots(figsize=(3, 3), dpi=160)

    for fold, (train, test) in enumerate(cv.split(X, y)):
        classifier.fit(X[train], y[train])

        y_fold = y[test]
        # An ROC curve is only defined when the held-out fold has both classes.
        if any(i != 1 for i in y_fold) and any(i != 0 for i in y_fold):
            viz = RocCurveDisplay.from_estimator(
                classifier,
                X[test],
                y_fold,
                name=f"Fold {fold}",
                alpha=0.4,
                lw=1,
                ax=ax)

            if not math.isnan(viz.roc_auc):
                # Resample onto a common FPR grid so folds can be averaged.
                interp_tpr = np.interp(mean_fpr, viz.fpr, viz.tpr)
                interp_tpr[0] = 0.0
                tprs.append(interp_tpr)
                aucs.append(viz.roc_auc)

    if not tprs:
        # Otherwise np.mean([]) yields a NaN scalar and the indexing below fails.
        plt.close(fig)
        raise ValueError(
            "No cross-validation fold produced a valid ROC curve; every test "
            "fold contained a single class. Try fewer splits or more data.")

    mean_tpr = np.mean(tprs, axis=0)
    mean_tpr[-1] = 1.0
    mean_auc = auc(mean_fpr, mean_tpr)
    std_auc = np.std(aucs)

    ax.plot(
        mean_fpr,
        mean_tpr,
        color="black",
        label=r"Mean (AUC = %0.2f $\pm$ %0.2f)" % (mean_auc, std_auc),
        lw=1.5,
        alpha=1.0)

    ax.set_xlabel('1 - Specificity', fontsize=12)
    ax.set_ylabel('Sensitivity', fontsize=12)

    plt.xticks(fontsize=12)
    plt.yticks(fontsize=12)
    plt.legend(loc='lower right', fontsize=5)
    plt.show()
    
#references: 
#https://scikit-learn.org/stable/modules/cross_validation.html
#https://stackoverflow.com/questions/46598301/how-to-compute-precision-recall-and-f1-score-of-an-imbalanced-dataset-for-k-fold
#https://scikit-learn.org/stable/auto_examples/model_selection/plot_roc_crossval.html