In [1]:
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import tensorflow as tf
from sklearn.metrics import classification_report, confusion_matrix

import os
from dotenv import load_dotenv
from data_manipulation.reading_util import load_ml_data_emb
from data_manipulation.reading_util import load_non_enz_esm2

load_dotenv()



2023-09-30 14:00:29.142617: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2023-09-30 14:00:29.385510: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2023-09-30 14:00:29.388783: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


True

In [2]:
def plot_report(report, y, predictions):
    """
    Plots results of model
    :param report: Report of model
    """

    class_0_metrics = report.split('\n')[2].split()[1:]
    class_1_metrics = report.split('\n')[3].split()[1:]
    class_2_metrics = report.split('\n')[4].split()[1:]
    class_3_metrics = report.split('\n')[5].split()[1:]
    class_4_metrics = report.split('\n')[6].split()[1:]
    class_5_metrics = report.split('\n')[7].split()[1:]
    class_6_metrics = report.split('\n')[8].split()[1:]
    class_7_metrics = report.split('\n')[9].split()[1:]

    metrics = [class_0_metrics,
               class_1_metrics,
               class_2_metrics,
               class_3_metrics,
               class_4_metrics,
               class_5_metrics,
               class_6_metrics,
               class_7_metrics
               ]

    precs = []
    recs = []
    f1_s = []

    for class_m in metrics:
        precision = float(class_m[0])
        recall = float(class_m[1])
        f1_score = float(class_m[2])
        precs.append(precision)
        recs.append(recall)
        f1_s.append(f1_score)

    # class_names = [1, 2, 3, 4, 5, 6, 7, 0]
    class_names = [
        "Class 1",
        "Class 2",
        "Class 3",
        "Class 4",
        "Class 5",
        "Class 6",
        "Class 7",
        "Non Enzyme"
    ]

    # Create subplots for accuracy, precision, and F1
    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5))

    # Plot accuracy for each class
    ax1.bar(class_names, recs)

    ax1.xticks(rotation=45)
    ax1.set_title("Recall")
    ax1.set_xlabel("Main Class")


    # Plot precision for each class
    ax2.bar(class_names, precs)
    ax2.xticks(rotation=45)
    ax2.set_title("Precision")
    ax2.set_xlabel("Main Class")

    # Plot F1 score for each class
    ax3.bar(class_names, f1_s)
    ax3.xticks(rotation=45)
    ax3.set_title("F1 Score")
    ax3.set_xlabel("Main Class")

    # Adjust layout
    plt.tight_layout()

    # Show the plots
    plt.show()


    conf_matrix = confusion_matrix(y, predictions)

    # Create a confusion matrix heatmap
    plt.figure(figsize=(10, 7))
    sns.set(font_scale=1.2)  # Adjust font size as needed
    sns.heatmap(conf_matrix, annot=True, fmt="d", cmap="Blues", xticklabels=class_names, yticklabels=class_names)
    plt.xlabel("Predicted")
    plt.ylabel("Actual")
    plt.title("Confusion Matrix")
    plt.show()


In [3]:
def validate_on_test_data(path_to_test_csv: str, path_to_test_esm2: str, 
                          path_to_non_ez_fasta: str, path_to_non_ez_esm2: str, 
                          path_to_cnn_multiclass_model, path_to_cnn_binaryclass_model):
    """
    Easy way to test a model on test dataset
    :param path_to_test_csv: Path to test.csv
    :param path_to_test_esm2: Path to test_esm2.h5
    :param path_to_non_ez_fasta: Path to non_ez fasta
    :param path_to_non_ez_esm2: Path to non_ez esm2
    :param path_to_cnn_multiclass_model: Path to cnn classifying 1st ec class
    :param path_to_cnn_binaryclass_model: Path to binary classification cnn
    """

    # Load model and test data
    multiclass__model = tf.keras.models.load_model(path_to_cnn_multiclass_model)
    binaryclass_model = tf.keras.models.load_model(path_to_cnn_binaryclass_model)


    X_enzymes, y_enzymes = load_ml_data_emb(path_to_esm2=path_to_test_esm2, path_to_enzyme_csv=path_to_test_csv)
    # X_enzymes = X_enzymes[0:100]
    # y_enzymes = y_enzymes[0:100]


    # Load non_enzymes
    X_non_enzymes, y_non_enzymes = load_non_enz_esm2(non_enzymes_fasta_path=path_to_non_ez_fasta,
                                                     non_enzymes_esm2_path=path_to_non_ez_esm2)
    # X_non_enzymes = X_non_enzymes[0:100]
    # y_non_enzymes = y_non_enzymes[0:100]
    
    # Combine data
    X = np.vstack((X_enzymes, X_non_enzymes))
    y = np.hstack((y_enzymes, y_non_enzymes))


    print("\n\n============================================= PREDICTING TEST DATA =============================================\n\n")

    complete_predictions = []

    # in this for loop we use the binaryclassifier in order to decide if we have an enzyme or not. If yes, we pass the embedding into
    # the multiclassification cnn in order to predict its main class
    # Labels: 0-6 → main ec classes; 7 → non_enzymes
    for index, prediction in enumerate((binaryclass_model.predict(X) > 0.5).astype(int)):
        if prediction == np.array(1):
            print("Multiclass")
            print(index, prediction)
            ec_class_pred = np.argmax(multiclass__model.predict(X[index:index+1]), axis=-1)
            print(ec_class_pred, "will be appended to predictions", type(ec_class_pred))
            complete_predictions.append(ec_class_pred)
        else:
            print("Non Enzyme")
            print(index, prediction)
            print("Will append 7 to complete predictions")
            complete_predictions.append(np.array([7]))

    complete_predictions = np.array(complete_predictions)
    print(complete_predictions)
            
            
    
        
    # np.argmax(model.predict(X), axis=-1)) # for multi class classification (using softmax)

    print("LOG: Making predictions: DONE")

    report = classification_report(y, complete_predictions)

    print(report)
    
    plot_report(report, y, complete_predictions)

    # Evaluation

In [4]:
bin_class_path = "/home/malte/01_Documents/projects/pbl_binary_classifier/tf_cnn_esm2/enzyme_non_enzyme_models/cnn_binary_v2_opt_S70.keras"
mul_class_path = "/home/malte/01_Documents/projects/pbl_binary_classifier/tf_cnn_esm2/only_enzyme_models/cnn_v5_1_split100.keras"

emb_path = os.getenv("ESM2_ENZYMES_SPLIT_10")
csv_path = os.getenv("CSV10_ENZYMES")

fasta_path = os.getenv("FASTA_NON_ENZYMES")
emb_non_ez_path = os.getenv("ESM2_NON_ENZYMES")

In [None]:
validate_on_test_data(path_to_test_csv=csv_path,
                      path_to_test_esm2=emb_path,
                      path_to_non_ez_esm2=emb_non_ez_path,
                      path_to_non_ez_fasta=fasta_path,
                      path_to_cnn_binaryclass_model=bin_class_path,
                      path_to_cnn_multiclass_model=mul_class_path)

LOG: 3 Sequences with aa O in /home/malte/Desktop/Dataset/data/enzymes/csv/split10.csv
LOG: 12 Sequences with aa U in /home/malte/Desktop/Dataset/data/enzymes/csv/split10.csv
LOG: 166 multifunctional enzymes with diff ec main classes in /home/malte/Desktop/Dataset/data/enzymes/csv/split10.csv
LOG: 181 entries will be ignored
LOG: Data loaded in: 0.797 min
LOG: ESM2 of enzymes: 7212
LOG: Labels of enzymes: 7212
LOG: 0 Sequences with aa O in /home/malte/Desktop/Dataset/data/non_enzyme/fasta/no_enzyme_train.fasta
LOG: 17 Sequences with aa U in /home/malte/Desktop/Dataset/data/non_enzyme/fasta/no_enzyme_train.fasta
LOG: 2138 non enzymes are longer than 1022 cutoff
LOG: 2155 entries will be ignored
LOG: Non Enzymes data loaded in: 4.469 min
LOG: ESM2 of non enzymes: 39502
LOG: Labels of non enzymes: 39502



Multiclass
0 [1]
[2] will be appended to predictions <class 'numpy.ndarray'>
Multiclass
1 [1]
[2] will be appended to predictions <class 'numpy.ndarray'>
Multiclass
2 [1]
[1] will be ap