In [None]:
import os
from dotenv import load_dotenv
import pandas as pd
from utils import enzyme_split30_preprocessing, read_h5, apply_embedding, read_fasta

from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, matthews_corrcoef

In [None]:
# Let python-dotenv populate the process environment before any lookup below.
load_dotenv()

# Raw inputs: the non-enzyme FASTA file and the split-30 enzyme CSV.
# Lookups without a fallback yield None when the variable is unset; the others
# fall back to the literal string "not found".
path_to_non_enzymes = os.environ.get("FASTA_NON_ENZYMES")
path_to_csv_split30 = os.environ.get("CSV30_ENZYMES", "not found")

# HDF5 files that are later handed to read_h5 / apply_embedding.
path_to_esm2_ne = os.environ.get("ESM2_NON_ENZYMES", "not found")
path_to_esm2 = os.environ.get("ESM2_ENZYMES_SPLIT_30", "not found")

# NOTE(review): "ESM25" in the first key looks like it may be a typo for "ESM2"
# -- confirm against the .env file before changing it.
path_to_csv_split30_esm2 = os.environ.get("CSV30_ENZYMES_ESM25_APPLIED")
path_to_non_enzymes_esm2 = os.environ.get("NON_ENZYMES_ESM2_APPLIED")

In [None]:
# Read the comma-separated split-30 table, then hand it to the project's
# preprocessing helper.
enzymes_csv = pd.read_csv(path_to_csv_split30, sep=",")
enzymes = enzyme_split30_preprocessing(enzymes_csv)
enzymes.head()

In [None]:
# Load the enzyme embeddings from HDF5 and attach them to the enzyme table.
# NOTE(review): the meaning of the second read_h5 argument (False) is defined
# in utils -- confirm there.
# This cell rebinds `enzymes`, so re-running it feeds the already-embedded
# frame back into apply_embedding.
enzyme_embeddings = read_h5(path_to_esm2, False)
enzymes = apply_embedding(enzyme_embeddings, enzymes)
enzymes.head()

In [None]:
# Parse the non-enzyme FASTA file with the project helper; the result is used
# as a DataFrame below (it supports .head() and column assignment).
non_enzymes = read_fasta(path_to_non_enzymes)
non_enzymes.head()

In [None]:
# Load the non-enzyme embeddings from HDF5 and attach them to the non-enzyme
# table, mirroring the enzyme cell above.
# This cell rebinds `non_enzymes`, so re-running it feeds the already-embedded
# frame back into apply_embedding.
non_enzyme_embeddings = read_h5(path_to_esm2_ne, False)
non_enzymes = apply_embedding(non_enzyme_embeddings, non_enzymes)
non_enzymes.head()

In [None]:
import numpy as np
# Build the binary dataset: enzymes are the positive class (1) and
# non-enzymes the negative class (0).
enzymes["Label"] = 1
non_enzymes["Label"] = 0

# The combined frame is called `labeled_embeddings` rather than `bin`, so the
# Python built-in bin() is not shadowed for the rest of the kernel session.
labeled_embeddings = pd.concat(
    [enzymes[["Label", "Embedding"]], non_enzymes[["Label", "Embedding"]]],
    ignore_index=True,
)

# Seeded shuffle so the row order (and therefore the split) is reproducible.
labeled_embeddings = labeled_embeddings.sample(frac=1, random_state=42).reset_index(drop=True)

# One embedding per sample as the feature vector; the label is the target.
X = labeled_embeddings["Embedding"].tolist()
y = labeled_embeddings["Label"]

# NOTE(review): the split is not stratified; if the two classes are strongly
# imbalanced, consider passing stratify=y (this would change the reported numbers).
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Plain ndarray so later code can index the test labels by position.
y_test = np.array(y_test)

In [None]:
import math
def bootstrap_statistic(y_true, y_pred, statistic_func, B=10_000, alpha=0.05, random_state=None):
    """Estimate a metric's sampling distribution by bootstrap resampling.

    Draws ``B`` resamples (with replacement, same size as the input) of the
    paired ``(y_true, y_pred)`` observations and evaluates ``statistic_func``
    on each one.

    Parameters
    ----------
    y_true, y_pred : array-like
        Ground-truth and predicted labels of equal length. They are converted
        with ``np.asarray`` so lists and pandas Series are indexed by position.
    statistic_func : callable
        Called as ``statistic_func(resampled_true, resampled_pred)``; must
        return a number.
    B : int, default 10_000
        Number of bootstrap resamples.
    alpha : float, default 0.05
        Significance level; the returned interval covers ``1 - alpha``.
    random_state : None, int or numpy.random.Generator, default None
        Seed (or generator) for the resampling. ``None`` draws fresh entropy,
        so results differ between runs; pass an int for reproducibility.

    Returns
    -------
    tuple
        ``(mean_score, standard_error, (lower_bound, upper_bound))`` where the
        standard error is the sample standard deviation (``ddof=1``) of the
        bootstrap scores and the bounds are the percentile interval.

    Raises
    ------
    ValueError
        If the inputs differ in length, are empty, or no resample produced a
        score.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)

    n_samples = len(y_pred)
    if len(y_true) != n_samples:
        raise ValueError(
            f"y_true and y_pred must have the same length, got {len(y_true)} and {n_samples}"
        )
    if n_samples == 0:
        raise ValueError("cannot bootstrap an empty sample")

    rng = np.random.default_rng(random_state)

    bootstrap_scores = []
    for _ in range(B):
        indices = rng.integers(0, n_samples, size=n_samples)
        try:
            score = statistic_func(y_true[indices], y_pred[indices])
        except ValueError:
            # A resample can be degenerate for some metrics (e.g. only one
            # class present); skip it. Any other exception is a real bug and
            # is allowed to propagate instead of being silently swallowed.
            continue
        bootstrap_scores.append(score)

    if not bootstrap_scores:
        raise ValueError("statistic_func did not produce a score for any bootstrap resample")

    mean_score = np.mean(bootstrap_scores)
    standard_error = np.std(bootstrap_scores, ddof=1)

    # Percentile confidence interval at level 1 - alpha.
    lower_bound = np.percentile(bootstrap_scores, (alpha / 2) * 100)
    upper_bound = np.percentile(bootstrap_scores, (1 - alpha / 2) * 100)

    return mean_score, standard_error, (lower_bound, upper_bound)

def calculate_f1(y_true, y_pred, average='micro'):
    """Return the F1 score of ``y_pred`` against ``y_true``.

    Thin wrapper around ``sklearn.metrics.f1_score`` with a two-argument call
    form, so it can be passed directly as ``statistic_func`` to
    ``bootstrap_statistic``.

    Parameters
    ----------
    y_true, y_pred : array-like
        Ground-truth and predicted labels.
    average : str, default 'micro'
        Averaging mode forwarded to ``f1_score``. Note that for single-label
        data the micro-averaged F1 is numerically equal to accuracy; pass
        ``average='binary'`` for the F1 of the positive class only.

    Returns
    -------
    float
        The F1 score.
    """
    return f1_score(y_true, y_pred, average=average)

def round_to_significance(x, significance):
    """Round a value and its uncertainty to matching precision.

    ``x`` is rounded at the decimal position of the leading digit of
    ``significance``; ``significance`` itself keeps one digit more.

    Parameters
    ----------
    x : float
        The value to round (e.g. a mean score).
    significance : float
        The uncertainty that sets the precision (e.g. a standard error).

    Returns
    -------
    tuple
        ``(rounded_x, rounded_significance)``. If ``significance`` is NaN or
        infinite there is no meaningful precision, so both values are returned
        unrounded instead of raising from ``math.floor``.
    """
    if not math.isfinite(significance):
        return x, significance

    if significance == 0.0:
        # log10(0) is undefined; fall back to rounding at the units position.
        sig_position = 0
    else:
        # Decimal exponent of the leading digit, e.g. 0.0034 -> -3.
        sig_position = int(math.floor(math.log10(abs(significance))))
    return round(x, -sig_position), round(significance, -sig_position + 1)

In [None]:
# Number of neighbours that vote on each prediction.
k = 7

# Construct the k-NN model and train it on the training split in one chained
# call (fit returns the estimator itself).
knn_classifier = KNeighborsClassifier(n_neighbors=k).fit(X_train, y_train)

# Predicted label for every test embedding, in the same order as X_test.
y_pred = knn_classifier.predict(X_test)


In [None]:
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
# Rows are the actual class and columns the predicted class. Pinning
# labels=[0, 1] guarantees a 2x2 matrix even if one class is absent from the
# test split or the predictions.
cm = confusion_matrix(y_test, y_pred, labels=[0, 1])

# Two-colour map (yellow and purple). imshow rescales cm to its own min..max,
# so counts in the lower half of that range are drawn yellow and counts in the
# upper half purple.
colors = ['yellow', 'purple']
cmap = mcolors.ListedColormap(colors)

# Plot the confusion matrix heatmap
fig, ax = plt.subplots()
image = ax.imshow(cm, interpolation='nearest', cmap=cmap)
fig.colorbar(image, ax=ax)

ax.set_xticks([0, 1])
ax.set_xticklabels(["Predicted 0", "Predicted 1"])
ax.set_yticks([0, 1])
ax.set_yticklabels(["Actual 0", "Actual 1"])

ax.set_xlabel('Predicted')
ax.set_ylabel('Actual')
ax.set_title('Confusion Matrix')

# Annotate each cell with its count. White text is unreadable on yellow, so
# the text colour follows the cell colour: white on purple, black on yellow.
lowest, highest = cm.min(), cm.max()
midpoint = (lowest + highest) / 2
for i in range(cm.shape[0]):
    for j in range(cm.shape[1]):
        # If every count is identical the whole image is drawn in the first
        # colour (yellow), hence the highest > lowest guard.
        on_purple = highest > lowest and cm[i, j] >= midpoint
        ax.text(j, i, str(cm[i, j]), ha='center', va='center',
                color='white' if on_purple else 'black', fontsize=16)

plt.show()

In [None]:
# Point estimate of the F1 score on the held-out test split.
initial_f1 = calculate_f1(np.array(y_test), y_pred)

# Bootstrap the same metric to obtain its mean, standard error and 95% CI.
mean_f1, se_f1, ci_95 = bootstrap_statistic(
    y_test,
    y_pred,
    statistic_func=calculate_f1,
)

# Report the mean at the precision implied by its standard error.
rounded_mean_f1, rounded_se_f1 = round_to_significance(x=mean_f1, significance=se_f1)

In [None]:
# Assemble the report first, then emit it one line per entry.
report_lines = (
    "ESM2 KNN:",
    f"  - Accuracy: {accuracy_score(y_test, y_pred)}",
    f"  - Initial F1 Score: {initial_f1:.2f}",
    f"  - MCC: {matthews_corrcoef(y_test, y_pred)}",
    f"  - Mean F1 ± SE F1: {rounded_mean_f1} ± {rounded_se_f1}",
    f"  - 95% CI: [{ci_95[0]:.2f}, {ci_95[1]:.2f}]",
)
print(*report_lines, sep="\n")

In [None]:
from sklearn.metrics import roc_curve, auc
import matplotlib.pyplot as plt
# ROC curve
# A ROC curve sweeps a decision threshold over a continuous score. Feeding it
# the hard 0/1 predictions collapses the curve to a single operating point and
# gives a misleading AUC, so use the predicted probability of the positive
# class (label 1) instead.
positive_column = list(knn_classifier.classes_).index(1)
y_score = knn_classifier.predict_proba(X_test)[:, positive_column]
fpr, tpr, thresholds = roc_curve(y_test, y_score)

# Calculate the AUC (Area Under the ROC Curve)
roc_auc = auc(fpr, tpr)

# Plot the ROC curve together with the chance diagonal for reference.
fig, ax = plt.subplots(figsize=(8, 6))
ax.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUC = {roc_auc:.2f})')
ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
ax.set_xlim([0.0, 1.0])
ax.set_ylim([0.0, 1.05])
ax.set_xlabel('False Positive Rate')
ax.set_ylabel('True Positive Rate')
# Read k from the fitted model so the title cannot drift from the classifier.
ax.set_title(f'ROC Curve for KNN (k={knn_classifier.n_neighbors})')
ax.legend(loc='lower right')
plt.show()