In [10]:
import pandas as pd

scop_metadata = pd.read_csv("/scratch/gpfs/jr8867/main/db/family-split-train-test/test/test_metadata.csv", index_col=0)
blast_results = pd.read_csv("/scratch/gpfs/jr8867/main/db/fasta/results.tsv", sep="\t", header=None)
blast_results.columns = ["qseqid", "sseqid", "pident", "length", "evalue", "bitscore"]

In [17]:
scop_metadata

Unnamed: 0_level_0,uid,fa,sf,fold,seq
index,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
1,P09147,4000088,3000038,2000148,MRVLVTGGSGYIGSHTCVQLLQNGHDVIILDNLCNSKRSVLPVIER...
5,P00327,4000010,3000040,2000005,MSTAGKVIKCKAAVLWEEKKPFSIEEVEVAPPKAHEVRIKMVATGI...
10,P0A9B2,4000077,3000043,2000005,MTIKVGINGFGRIGRIVFRAAQKRSDIEIVAINDLLDADYMAYMLK...
13,P31116,4000093,3000043,2000005,MSTKVVNVAVIGAGVVGSAFLDQLLAMKSTITYNLVLLAEAERSLI...
21,P0A9T0,4000051,3000006,2000014,MAKVSLEKDKIKFLLVEGVHQKALESLRAAGYTNIEFHKGALDDEQ...
...,...,...,...,...,...
35972,P20585,4004015,3000587,2001251,MSRRKPASGGLAASSSAPARQAVLSRFFQSTGSLKSTSSSTGAADQ...
35973,P20585,4004015,3002020,2001251,MSRRKPASGGLAASSSAPARQAVLSRFFQSTGSLKSTSSSTGAADQ...
35974,P52701,4004015,3001688,2001251,MSRQSTLYSFFPKSPALSDANKASARASREGGRAAAAPGASPSPGG...
35975,P52701,4004015,3000587,2001251,MSRQSTLYSFFPKSPALSDANKASARASREGGRAAAAPGASPSPGG...


In [7]:
blast_results = blast_results[blast_results["qseqid"] != blast_results["sseqid"]]
blast_results

Unnamed: 0,qseqid,sseqid,pident,length,evalue,bitscore
1,1,18437,52.493,341,5.030000e-131,376.0
2,1,23621,53.892,334,5.630000e-127,365.0
3,1,23622,49.575,353,5.730000e-118,343.0
4,1,18156,49.271,343,5.080000e-111,336.0
5,1,18436,38.860,386,1.970000e-71,225.0
...,...,...,...,...,...,...
294221,35976,35005,30.645,62,5.100000e+00,29.6
294222,35976,2142,30.645,62,5.100000e+00,29.6
294223,35976,12083,31.373,51,5.900000e+00,27.7
294224,35976,4466,25.000,104,6.700000e+00,28.9


In [27]:
# Convert Blast to binary classification
def classify_blast_hit(row):
    qseqid = row["qseqid"]
    sseqid = row["sseqid"]
    
    try:
        q_sf = scop_metadata.loc[qseqid]["sf"]
        s_sf = scop_metadata.loc[sseqid]["sf"]
        
        q_fa = scop_metadata.loc[qseqid]["fa"]
        s_fa = scop_metadata.loc[sseqid]["fa"]
        
        if q_sf == s_sf or q_fa == s_fa:
            return 1
        else:
            return 0
    except KeyError:
        print("error")
        return 0  # Handle cases where the ID is not in scop_metadata

blast_results["binary_classification"] = blast_results.apply(classify_blast_hit, axis=1)
blast_results["bitscore"] = blast_results["bitscore"].astype(float)

evaluation_df = blast_results[["binary_classification", "bitscore"]]
evaluation_df


Unnamed: 0,binary_classification,bitscore
0,1,704.0
1,1,376.0
2,1,365.0
3,1,343.0
4,1,336.0
...,...,...
294221,0,29.6
294222,0,29.6
294223,0,27.7
294224,0,28.9


In [29]:
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc, precision_recall_curve
import numpy as np

# ----------------------
# Generate ROC curve
# ----------------------

# Compute ROC curve and ROC area
labels = evaluation_df["binary_classification"].to_numpy()
scores = np.log(evaluation_df["bitscore"].to_numpy())

fpr, tpr, _ = roc_curve(labels, scores)
roc_auc = auc(fpr, tpr)

# Plot ROC curve
plt.figure(figsize=(8, 6))
plt.plot(fpr, tpr, color='blue', lw=2, label=f'ROC curve (AUC = {roc_auc:.2f})')
plt.plot([0, 1], [0, 1], color='gray', linestyle='--')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC Curve')
plt.legend(loc="lower right")
plt.savefig('roc_curve_blast.png')
plt.close()

print("Generated ROC curve.")

# ----------------------
# Generate PR curve
# ----------------------

# Compute PR curve and PR area
precision, recall, _ = precision_recall_curve(labels, scores)
pr_auc = auc(recall, precision)

# Plot PR curve
plt.figure(figsize=(8, 6))
plt.plot(recall, precision, color='blue', lw=2, label=f'PR curve (AUC = {pr_auc:.2f})')
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.title('PR Curve')
plt.legend(loc="upper right")
plt.savefig('pr_curve_blast.png')
plt.close()

print("Generated PR curve.")

# ------------------------------------------------------------------------------------------------
# Generate Histogram of scores by label
# ------------------------------------------------------------------------------------------------  

# Generate histogram of scores by label
plt.figure(figsize=(10, 6))
h1 = plt.hist(scores[labels == 1], bins=100, color='blue', alpha=0.5, label='Positive Scores', density=True)
h2 = plt.hist(scores[labels == 0], bins=100, color='orange', alpha=0.5, label='Negative Scores', density=True)
plt.xlabel('Scores')
plt.ylabel('Density')
plt.title('Distribution of Scores by Label')
plt.legend()

# Trace lines through the top of the histogram
max_height = max(max(h1[0]), max(h2[0]))
plt.plot(h1[1][:-1], h1[0], color='blue', linewidth=2)
plt.plot(h2[1][:-1], h2[0], color='orange', linewidth=2)

plt.savefig('score_histogram_blast.png')
plt.close()

print("Generated score histogram.")


Generated ROC curve.
Generated PR curve.
Generated score histogram.


In [31]:
from sklearn.metrics import confusion_matrix
import seaborn as sns

def analyze_threshold_train(labels, scores, num_thresholds=100):
    """
    Analyzes different thresholds to find the optimal one based on F1-score.

    Args:
        labels (np.array): Ground truth labels.
        scores (np.array): Similarity scores.
        num_thresholds (int): Number of thresholds to test.

    Returns:
        tuple: Optimal threshold and corresponding F1-score.
    """
    thresholds = np.linspace(np.min(scores), np.max(scores), num_thresholds)
    best_threshold = None
    best_f1 = 0

    for threshold in thresholds:
        y_pred = (scores >= threshold).astype(int)
        cm = confusion_matrix(labels, y_pred)
        
        if cm.shape != (2, 2):
            continue
        
        tn, fp, fn, tp = cm.ravel()
        
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0
        f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

        if f1 > best_f1:
            best_f1 = f1
            best_threshold = threshold

    return best_threshold

def plot_confusion_matrix(labels, scores, threshold, filename):
    """Plots the confusion matrix as a heatmap."""
    y_pred = (scores >= threshold).astype(int)
    cm = confusion_matrix(labels, y_pred)
    
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
                xticklabels=['Negative', 'Positive'], 
                yticklabels=['Negative', 'Positive'])
    plt.xlabel('Predicted Label')
    plt.ylabel('True Label')
    plt.title(f'BLAST Optimal Confusion Matrix')
    plt.savefig(filename)
    plt.close()

# Analyze Blast Results
print("Analyzing BLAST Results...")
best_threshold = analyze_threshold_train(labels, scores)
print(f"Best Threshold: {best_threshold}")
plot_confusion_matrix(labels, scores, best_threshold, 'confusion_matrix_blast.png')
print("Generated confusion matrix.")


Analyzing BLAST Results...
Best Threshold: 3.408007923909017
Generated confusion matrix.
