In [4]:
#!pip install tensorflow
#!pip install sdv

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf
import sdv
import joblib
from joblib import dump, load

from sdv.tabular import GaussianCopula, CTGAN, CopulaGAN, TVAE
from sdv.sampling import Condition
from sdv.evaluation import evaluate

from numpy import random
from matplotlib.pyplot import figure
from sklearn.preprocessing import MinMaxScaler, RobustScaler
from sklearn.model_selection import train_test_split, cross_val_score, GridSearchCV, KFold
from sklearn.neighbors import KNeighborsClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC
from sklearn import metrics
from tensorflow import keras
from tensorflow.keras import layers, losses
from tensorflow.keras.models import Model, Sequential, save_model, load_model
from tensorflow.keras.layers import Dense, Input, Conv1D, Activation, Reshape, Flatten, Dropout, MaxPooling1D
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from keras.wrappers.scikit_learn import KerasClassifier
from sklearn.metrics import confusion_matrix, accuracy_score, roc_auc_score, roc_curve

def which_taxo(file):
    """Load the test taxonomy table that corresponds to a model file name.

    The file name is searched for a dataset token ("a_", "b_", "k_", ...);
    tokens are tried in the order listed below and the first one found wins.

    Parameters
    ----------
    file : str
        Model file name, e.g. "a_rawdata.joblib".

    Returns
    -------
    tuple
        (taxo, nom): the DataFrame read from the matching CSV and the
        display name of the dataset.

    Raises
    ------
    ValueError
        If the file name contains none of the known tokens. Previously this
        case surfaced as an UnboundLocalError at the return statement.
    """
    # (token, CSV path, display name) - order matters, first match wins.
    taxo_by_token = (
        ("a_", "test/taxoS_test.csv", "CRC S"),
        ("b_", "test/taxoS1_test.csv", "CRC S1"),
        ("k_", "test/taxoS_test.csv", "CRC + Regió S"),
        ("l_", "test/taxoS1_test.csv", "CRC + Regió S1"),
        ("m_", "test/taxoS_test.csv", "CRC + Seq_tool S"),
        ("n_", "test/taxoS1_test.csv", "CRC + Seq_tool S1"),
        ("o_", "test/taxoS_test.csv", "CRC + Regió + Seq_tool S"),
        ("p_", "test/taxoS1_test.csv", "CRC + Regió + Seq_tool S1"),
    )
    for token, csv_path, nom in taxo_by_token:
        if token in file:
            return pd.read_csv(csv_path), nom
    raise ValueError("no taxonomy table is associated with file name: " + repr(file))

def which_meta_extra(file):
    """Return the extra metadata column names implied by a model file name.

    The name is searched for dataset tokens; the first rule (in the order
    below) with a token present in the name decides the result. A name that
    matches no rule yields an empty list.
    """
    rules = (
        (("m_", "n_"), ("seq_tool",)),
        (("o_", "p_"), ("region", "seq_tool")),
        (("k_", "l_"), ("region",)),
    )
    for tokens, extras in rules:
        if any(token in file for token in tokens):
            # Hand back a fresh list on every call.
            return list(extras)
    return []

def filt_y_rows(taxo, meta):
    """Collect, for every row of `taxo`, the matching "condition" from `meta`.

    Rows are visited by label 0 .. len(taxo)-1 on taxo["sample"]; for each
    one the first `meta` row whose "sample" equals it supplies the value.
    An IndexError is raised when a sample has no match in `meta`.

    Returns a plain list with one condition per taxo row.
    """
    return [
        meta["condition"][meta["sample"] == taxo["sample"][row]].iloc[0]
        for row in range(len(taxo))
    ]

def add_metas(taxo, meta, meta_extra=()):
    """Append metadata columns from `meta` to `taxo`, matched by "sample".

    For every column name in `meta_extra`, each taxo row (visited by label
    0 .. len(taxo)-1) receives the value from the first `meta` row with the
    same "sample". An IndexError is raised when a sample has no match.

    Parameters
    ----------
    taxo : pandas.DataFrame
        Table with a "sample" column; it is modified in place.
    meta : pandas.DataFrame
        Metadata table with a "sample" column and the requested columns.
    meta_extra : iterable of str, optional
        Names of the metadata columns to add. The default is an immutable
        empty tuple rather than a shared mutable list.

    Returns
    -------
    pandas.DataFrame
        The same `taxo` object, with the new columns appended in order.
    """
    for column in meta_extra:
        taxo[column] = [
            meta[column][meta["sample"] == taxo["sample"][row]].iloc[0]
            for row in range(len(taxo))
        ]
    return taxo

def drop_nas(taxo, meta):
    """Attach the labels to `taxo`, drop incomplete rows, and split again.

    `meta` is stored as a "condition" column on `taxo` (this modifies the
    caller's frame). Rows containing any missing value are then discarded.

    Returns
    -------
    tuple
        (features, labels): the surviving rows without their first and last
        columns, and the surviving "condition" values as a Series.
    """
    taxo["condition"] = meta
    complete_rows = taxo.dropna()
    labels = complete_rows["condition"]
    # Strip the first column and the trailing column from the features.
    features = complete_rows.iloc[:, 1:-1]
    return features, labels

def get_roc_curves(y_test, y_prob, pos_label):
    """Plot a ROC curve and save it as a PNG under "metrics/".

    Parameters
    ----------
    y_test : array-like
        True labels.
    y_prob : array-like
        Scores passed to `roc_curve` for the positive class.
    pos_label :
        Label treated as the positive class by `roc_curve`.

    Notes
    -----
    This function reads two module-level globals that the caller must have
    set beforehand: `nomi` (used in the plot title) and `file` (used to
    build the output file name).
    """
    fper, tper, _thresholds = roc_curve(y_test, y_prob, pos_label=pos_label)
    fig, ax = plt.subplots()
    ax.set_title("Corba ROC" + " " + nomi)
    # The x axis shows the false positive rate, i.e. 1 - specificity; the
    # previous label ("Especificitat") described the wrong quantity.
    ax.set_xlabel("1 - Especificitat")
    ax.set_ylabel("Sensibilitat")
    ax.plot(fper, tper)
    # Diagonal reference line from (0, 0) to (1, 1).
    ax.plot([0, 1], [0, 1], '-')
    fig.savefig("metrics/" + file + '_roc_curve.png')
    # Close this specific figure so figures do not pile up across calls.
    plt.close(fig)

def get_test_metrics(file, model, x_test, y_test, labels):
    """Compute accuracy, sensitivity, specificity and AUC for one model.

    Parameters
    ----------
    file : str
        Model file name; a name containing ".h5" makes the raw `predict`
        output serve as the score, otherwise column 1 of `predict_proba`
        is used.
    model :
        Object exposing `predict` (and `predict_proba` for non-".h5" files).
    x_test, y_test :
        Test features and true labels.
    labels : list
        Two labels; predictions rounded to 0 map to labels[0] and to 1 map
        to labels[1]. labels[1] is the positive class for the ROC plot.

    Returns
    -------
    tuple
        (acc, sens, espe, AUC)

    Also calls `get_roc_curves`, which writes a ROC plot to disk.
    """
    raw_pred = model.predict(x_test)
    if ".h5" in file:
        y_prob = raw_pred
    else:
        y_prob = model.predict_proba(x_test)[:, 1]

    # Round to 0/1 and translate the integers into the label names.
    rounded = pd.DataFrame(np.around(raw_pred, 0).astype(int))
    y_pred = rounded.replace([0, 1], labels)

    tn, fp, fn, tp = confusion_matrix(y_test, y_pred, labels=labels).ravel()
    acc = accuracy_score(y_test, y_pred)
    sens = tp / (tp + fn)
    espe = tn / (tn + fp)
    # NOTE(review): the "1 -" presumably compensates for roc_auc_score
    # picking a different positive class than labels[1] when y_test holds
    # strings - confirm against how the models were trained.
    AUC = 1 - roc_auc_score(y_test, y_prob)
    get_roc_curves(y_test, y_prob, labels[1])
    return acc, sens, espe, AUC
    
def encode_data(x_test, meta_extra, autoencoder):
    """Encode the feature columns with `autoencoder.encoder` and rescale.

    The trailing len(meta_extra) columns of `x_test` are left out of the
    encoder input and re-attached to the encoded table afterwards. The
    combined table is min-max scaled and returned as a new DataFrame with
    default integer column names.

    Parameters
    ----------
    x_test : pandas.DataFrame
        Features, followed by len(meta_extra) metadata columns.
    meta_extra : list of str
        Names of the metadata columns present at the end of `x_test`.
    autoencoder :
        Object whose `encoder` is callable on a float32 array with an extra
        trailing axis and returns something with a `.numpy()` method.
    """
    n_meta = len(meta_extra)
    features = x_test if n_meta == 0 else x_test.iloc[:, 0:(-n_meta)]

    # Add a trailing axis of size 1 and cast to float32 before encoding.
    encoder_input = np.expand_dims(features.to_numpy(), axis=2).astype('float32')
    encoded_data = pd.DataFrame(autoencoder.encoder(encoder_input).numpy())

    if n_meta == 2:
        # NOTE(review): "region" takes the LAST column and "seq_tool" the
        # second to last. If x_test ends with (region, seq_tool) these names
        # are swapped; since the names are dropped below, only the resulting
        # column order matters - confirm it matches the training pipeline.
        realigned = x_test.reset_index(drop=True)
        encoded_data["region"] = realigned.iloc[:, -1]
        encoded_data["seq_tool"] = realigned.iloc[:, -2]
    elif n_meta == 1:
        column = "region" if meta_extra == ["region"] else "seq_tool"
        encoded_data[column] = x_test.reset_index(drop=True).iloc[:, -1]

    # NOTE(review): the scaler is fitted on the very data it transforms -
    # confirm this is consistent with how the models were trained.
    scaler = MinMaxScaler()
    scaler.fit(encoded_data)
    return pd.DataFrame(scaler.transform(encoded_data))

In [5]:
# Load the S1 test taxonomy table and display its number of columns.
taxo = pd.read_csv("test/taxoS1_test.csv")
taxo.shape[1]

715

In [23]:
# Evaluate every saved model found in models/altres_models/ on its test set,
# save one ROC plot per model, and collect the metrics into `results`.
os.makedirs("metrics", exist_ok=True)

models_dir = "models/altres_models/"
# Entries whose name contains one of these tokens are not model files.
skip_tokens = ("autoencoders", "synthetizers", "params", "checkpoints")

# The metadata table is the same for every model, so read it only once.
y = pd.read_csv("metadades_full.csv")

nom = []
nom_model = []
N = []
acc = []
sens = []
espe = []
AUC = []

# sorted(): os.listdir returns entries in arbitrary order; sorting keeps the
# row order of `results` stable between runs.
for file in sorted(os.listdir(models_dir)):
    # The original test combined the comparisons with bitwise "&", which only
    # worked by accident through comparison chaining; spell it out instead.
    if any(token in file for token in skip_tokens):
        continue
    print(file)

    x_test, nomi = which_taxo(file)
    meta_extra = which_meta_extra(file)
    y_test = filt_y_rows(x_test, y)
    x_test = add_metas(x_test, y, meta_extra)
    x_test, y_test = drop_nas(x_test, y_test)

    # Label the kind of training data from the file name. An unknown kind is
    # an error: it used to reuse the label of the previous iteration silently.
    if "rawdata" in file:
        nomi2 = "(Rawdata)"
    elif "encoded" in file:
        nomi2 = "(Encoded)"
    elif "synthetic" in file:
        nomi2 = "(Synthetic)"
    elif "synreal" in file:
        nomi2 = "(Syn+Real)"
    else:
        raise ValueError("cannot tell the data kind from file name: " + repr(file))

    # Every model except the raw-data ones takes autoencoder-encoded input;
    # the autoencoder is selected by the first character of the file name.
    if "rawdata" not in file:
        autoencoder_filename = models_dir + "autoencoders/" + file[0] + "_autoencoder"
        autoencoder = tf.saved_model.load(autoencoder_filename)
        x_test = encode_data(x_test, meta_extra, autoencoder)

    # `nomi` and `file` must remain module-level names here: get_roc_curves
    # reads them as globals for the plot title and the output file name.
    nomi = nomi + " " + nomi2

    if ".joblib" in file:
        # joblib.load unpickles arbitrary objects - only load trusted files.
        model = joblib.load(models_dir + file)
    else:
        model = load_model(models_dir + file)

    N.append(len(x_test))
    nom.append(nomi)
    nom_model.append(str(model)[0:10])
    acci, sensi, espei, AUCi = get_test_metrics(file, model, x_test, y_test, ["Control", "CRC"])
    acc.append(acci)
    sens.append(sensi)
    espe.append(espei)
    AUC.append(AUCi)

results = pd.DataFrame({
    "Nom": nom,
    "n (test)": N,
    "Model": nom_model,
    "Exactitud": acc,
    "Sensibilitat": sens,
    "Especificitat": espe,
    "AUC": AUC,
})

results.to_csv("metrics/metriques_models_alternatius.csv", index=False)

print("metrics done")

a_encoded.joblib


https://scikit-learn.org/stable/modules/model_persistence.html#security-maintainability-limitations


a_rawdata.joblib


https://scikit-learn.org/stable/modules/model_persistence.html#security-maintainability-limitations


a_synreal.joblib
a_synthetic.h5
b_encoded.joblib


https://scikit-learn.org/stable/modules/model_persistence.html#security-maintainability-limitations
https://scikit-learn.org/stable/modules/model_persistence.html#security-maintainability-limitations


b_rawdata.joblib


https://scikit-learn.org/stable/modules/model_persistence.html#security-maintainability-limitations


b_synreal.h5
b_synthetic.h5
k_encoded.joblib


https://scikit-learn.org/stable/modules/model_persistence.html#security-maintainability-limitations


k_rawdata.joblib


https://scikit-learn.org/stable/modules/model_persistence.html#security-maintainability-limitations


k_synreal.joblib




k_synthetic.h5




l_encoded.joblib


https://scikit-learn.org/stable/modules/model_persistence.html#security-maintainability-limitations


l_rawdata.joblib


https://scikit-learn.org/stable/modules/model_persistence.html#security-maintainability-limitations


l_synreal.joblib




l_synthetic.h5




m_encoded.joblib


https://scikit-learn.org/stable/modules/model_persistence.html#security-maintainability-limitations


m_rawdata.joblib


https://scikit-learn.org/stable/modules/model_persistence.html#security-maintainability-limitations


m_synreal.joblib




m_synthetic.h5




n_encoded.joblib


https://scikit-learn.org/stable/modules/model_persistence.html#security-maintainability-limitations


n_rawdata.joblib


https://scikit-learn.org/stable/modules/model_persistence.html#security-maintainability-limitations


n_synreal.h5




n_synthetic.h5




o_encoded.joblib


https://scikit-learn.org/stable/modules/model_persistence.html#security-maintainability-limitations


o_rawdata.joblib


https://scikit-learn.org/stable/modules/model_persistence.html#security-maintainability-limitations


o_synreal.joblib




o_synthetic.h5




p_encoded.joblib


https://scikit-learn.org/stable/modules/model_persistence.html#security-maintainability-limitations


p_rawdata.joblib


https://scikit-learn.org/stable/modules/model_persistence.html#security-maintainability-limitations


p_synreal.h5




p_synthetic.h5




metrics done


In [24]:
# Models trained on raw (non-encoded) data.
results.loc[results["Nom"].str.contains('Rawdata')]

Unnamed: 0,Nom,n (test),Model,Exactitud,Sensibilitat,Especificitat,AUC
1,CRC S (Rawdata),45,KNeighbors,0.622222,0.782609,0.454545,0.615613
5,CRC S1 (Rawdata),45,RandomFore,0.533333,0.347826,0.727273,0.567194
9,CRC + Regió S (Rawdata),45,KNeighbors,0.622222,0.782609,0.454545,0.625494
13,CRC + Regió S1 (Rawdata),45,KNeighbors,0.711111,0.565217,0.863636,0.712451
17,CRC + Seq_tool S (Rawdata),33,KNeighbors,0.636364,0.789474,0.428571,0.723684
21,CRC + Seq_tool S1 (Rawdata),33,KNeighbors,0.454545,0.315789,0.642857,0.503759
25,CRC + Regió + Seq_tool S (Rawdata),33,KNeighbors,0.545455,0.684211,0.357143,0.697368
29,CRC + Regió + Seq_tool S1 (Rawdata),33,KNeighbors,0.484848,0.421053,0.571429,0.530075


In [25]:
# Models whose name is tagged "Encoded".
results.loc[results["Nom"].str.contains('Encoded')]

Unnamed: 0,Nom,n (test),Model,Exactitud,Sensibilitat,Especificitat,AUC
0,CRC S (Encoded),45,KNeighbors,0.777778,0.956522,0.590909,0.741107
4,CRC S1 (Encoded),45,KNeighbors,0.577778,0.565217,0.590909,0.613636
8,CRC + Regió S (Encoded),45,KNeighbors,0.711111,0.913043,0.5,0.731225
12,CRC + Regió S1 (Encoded),45,KNeighbors,0.533333,0.608696,0.454545,0.582016
16,CRC + Seq_tool S (Encoded),33,"SVC(C=1, g",0.545455,0.684211,0.357143,0.571429
20,CRC + Seq_tool S1 (Encoded),33,KNeighbors,0.515152,0.578947,0.428571,0.541353
24,CRC + Regió + Seq_tool S (Encoded),33,"SVC(C=1, g",0.575758,0.684211,0.428571,0.509398
28,CRC + Regió + Seq_tool S1 (Encoded),33,KNeighbors,0.545455,0.947368,0.0,0.50188


In [26]:
# Models whose name is tagged "Synthetic".
results.loc[results["Nom"].str.contains('Synthetic')]

Unnamed: 0,Nom,n (test),Model,Exactitud,Sensibilitat,Especificitat,AUC
3,CRC S (Synthetic),45,<keras.eng,0.622222,0.608696,0.636364,0.705534
7,CRC S1 (Synthetic),45,<keras.eng,0.533333,0.391304,0.681818,0.628458
11,CRC + Regió S (Synthetic),45,<keras.eng,0.622222,0.652174,0.590909,0.715415
15,CRC + Regió S1 (Synthetic),45,<keras.eng,0.488889,0.0,1.0,0.715415
19,CRC + Seq_tool S (Synthetic),33,<keras.eng,0.515152,0.473684,0.571429,0.503759
23,CRC + Seq_tool S1 (Synthetic),33,<keras.eng,0.636364,0.842105,0.357143,0.507519
27,CRC + Regió + Seq_tool S (Synthetic),33,<keras.eng,0.545455,0.578947,0.5,0.526316
31,CRC + Regió + Seq_tool S1 (Synthetic),33,<keras.eng,0.636364,0.684211,0.571429,0.710526


In [27]:
# Models tagged "(Syn+Real)"; matched on 'Real' because str.contains treats
# its pattern as a regular expression by default and "+" is special there.
results.loc[results["Nom"].str.contains('Real')]

Unnamed: 0,Nom,n (test),Model,Exactitud,Sensibilitat,Especificitat,AUC
2,CRC S (Syn+Real),45,KNeighbors,0.666667,0.826087,0.5,0.747036
6,CRC S1 (Syn+Real),45,<keras.eng,0.6,0.652174,0.545455,0.63834
10,CRC + Regió S (Syn+Real),45,KNeighbors,0.688889,0.782609,0.590909,0.697628
14,CRC + Regió S1 (Syn+Real),45,KNeighbors,0.6,0.608696,0.590909,0.628458
18,CRC + Seq_tool S (Syn+Real),33,KNeighbors,0.454545,0.631579,0.214286,0.43985
22,CRC + Seq_tool S1 (Syn+Real),33,<keras.eng,0.575758,1.0,0.0,0.466165
26,CRC + Regió + Seq_tool S (Syn+Real),33,KNeighbors,0.606061,0.789474,0.357143,0.556391
30,CRC + Regió + Seq_tool S1 (Syn+Real),33,<keras.eng,0.575758,0.894737,0.142857,0.571429
