In [3]:
import ast
import json
import math
import os
import pickle
import random
import warnings
from collections import Counter, deque

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from iterstrat.ml_stratifiers import (
    MultilabelStratifiedKFold, MultilabelStratifiedShuffleSplit
)
from scipy.stats import median_abs_deviation
from sklearn.exceptions import UndefinedMetricWarning
from sklearn.metrics import auc
from sklearn.metrics import (
    accuracy_score, average_precision_score, f1_score,
    precision_score, recall_score, roc_auc_score, classification_report,
    precision_recall_curve
)
from sklearn.preprocessing import MultiLabelBinarizer, StandardScaler
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm




In [30]:
# Two-column, header-less TSV: GO id and its IA weight
# (presumably "information accretion" -- TODO confirm against the data source).
ia_df = pd.read_csv(r'IA_all.tsv', sep='\t', header=None)
ia_df.columns = ['GO', 'IA']

# GO id -> IA weight lookup, used to weight false positives / false negatives
# when computing misinformation and remaining uncertainty.
ic_dict = dict(zip(ia_df['GO'].tolist(), ia_df['IA'].tolist()))

In [5]:
# Switch to the parent directory; the data-loading cells below use relative paths.
# `os` is imported with the other modules in the import cell at the top.
# NOTE(review): this is not idempotent -- re-running the cell moves up one more
# level each time, so run it exactly once per kernel session.
os.chdir('..')

In [6]:
def process_GO_data(file_path):
    """
    Loads a GO annotation TSV and drops entries that have no propagated GO terms.

    The 'Raw propagated GO terms' column is parsed from its string form into
    Python objects with ``ast.literal_eval``. Rows whose parsed value is empty
    are removed and the index is reset to 0..n-1.

    Parameters:
    - file_path: str, path to the tab-separated file. It must contain the
      columns 'Raw propagated GO terms' and 'Raw GO terms'.

    Returns:
    - GO_df: pd.DataFrame, the filtered DataFrame with a fresh index.
    - GO_list: list, parsed 'Raw propagated GO terms' for each remaining entry.
    - GO_annotated: list, 'Raw GO terms' for each remaining entry, exactly as
      read from the file (this column is NOT parsed).
    """
    GO_df = pd.read_csv(file_path, sep='\t')
    GO_df['Raw propagated GO terms'] = GO_df['Raw propagated GO terms'].apply(ast.literal_eval)

    # Compute the "has at least one term" mask once and filter with it.
    has_terms = GO_df['Raw propagated GO terms'].apply(len) > 0
    GO_df = GO_df[has_terms].reset_index(drop=True)

    GO_list = GO_df['Raw propagated GO terms'].tolist()
    GO_annotated = GO_df['Raw GO terms'].tolist()

    return GO_df, GO_list, GO_annotated

In [9]:
# Plasmodium annotation sets, one file per GO aspect
# (function = MF, component = CC, process = BP).
(plasmodium_MF_df,
 plasmodium_MF_GO_list,
 plasmodium_MF_GO_annotated) = process_GO_data(file_path=r'processed_data_90_30/Plasmodium_MA_function.tsv')
(plasmodium_CC_df,
 plasmodium_CC_GO_list,
 plasmodium_CC_GO_annotated) = process_GO_data(file_path=r'processed_data_90_30/Plasmodium_MA_component.tsv')
(plasmodium_BP_df,
 plasmodium_BP_GO_list,
 plasmodium_BP_GO_annotated) = process_GO_data(file_path=r'processed_data_90_30/Plasmodium_MA_process.tsv')

# Test-split annotation sets, same three aspects.
(test_MF_df,
 test_MF_GO_list,
 test_MF_GO_annotated) = process_GO_data(file_path=r'processed_data_90_30/function_test.tsv')
(test_CC_df,
 test_CC_GO_list,
 test_CC_GO_annotated) = process_GO_data(file_path=r'processed_data_90_30/component_test.tsv')
(test_BP_df,
 test_BP_GO_list,
 test_BP_GO_annotated) = process_GO_data(file_path=r'processed_data_90_30/process_test.tsv')

In [34]:
# Ground-truth lookups: entry id -> set of propagated GO terms.
# Each dict is built in a single pass by pairing the 'Entry' column with the
# matching GO-term list (both come from the same filtered DataFrame).
test_MF_df_dict = {entry: set(terms) for entry, terms in zip(test_MF_df['Entry'], test_MF_GO_list)}
test_CC_df_dict = {entry: set(terms) for entry, terms in zip(test_CC_df['Entry'], test_CC_GO_list)}
test_BP_df_dict = {entry: set(terms) for entry, terms in zip(test_BP_df['Entry'], test_BP_GO_list)}

plasmodium_MF_df_dict = {entry: set(terms) for entry, terms in zip(plasmodium_MF_df['Entry'], plasmodium_MF_GO_list)}
plasmodium_CC_df_dict = {entry: set(terms) for entry, terms in zip(plasmodium_CC_df['Entry'], plasmodium_CC_GO_list)}
plasmodium_BP_df_dict = {entry: set(terms) for entry, terms in zip(plasmodium_BP_df['Entry'], plasmodium_BP_GO_list)}

In [20]:
# BLAST baseline predictions, loaded from pickles. The evaluation cells treat
# each object as a dict: entry id -> {GO id: score}.
# NOTE(security): pickle.load can execute arbitrary code -- only load files that
# were produced by this project.
# Each file is opened in a `with` block so its handle is closed right after loading.

# Test split, transferred from the training set.
with open('blast_results/blast_function_test_on_train_pred_annots_dict.pkl', 'rb') as fh:
    blast_function_on_train = pickle.load(fh)
with open('blast_results/blast_component_test_on_train_pred_annots_dict.pkl', 'rb') as fh:
    blast_component_on_train = pickle.load(fh)
with open('blast_results/blast_process_test_on_train_pred_annots_dict.pkl', 'rb') as fh:
    blast_process_on_train = pickle.load(fh)

# Test split, transferred from SwissProt.
with open('blast_results/blast_function_test_on_swissprot_pred_annots_dict.pkl', 'rb') as fh:
    blast_function_on_swissprot = pickle.load(fh)
with open('blast_results/blast_component_test_on_swissprot_pred_annots_dict.pkl', 'rb') as fh:
    blast_component_on_swissprot = pickle.load(fh)
with open('blast_results/blast_process_test_on_swissprot_pred_annots_dict.pkl', 'rb') as fh:
    blast_process_on_swissprot = pickle.load(fh)

# Plasmodium set, transferred from SwissProt.
with open('blast_results/blast_function_pf_on_swissprot_pred_annots_dict.pkl', 'rb') as fh:
    blast_plasmodium_function_on_swissprot = pickle.load(fh)
with open('blast_results/blast_component_pf_on_swissprot_pred_annots_dict.pkl', 'rb') as fh:
    blast_plasmodium_component_on_swissprot = pickle.load(fh)
with open('blast_results/blast_process_pf_on_swissprot_pred_annots_dict.pkl', 'rb') as fh:
    blast_plasmodium_process_on_swissprot = pickle.load(fh)

In [25]:
# Foldseek baseline predictions, loaded from pickles. The evaluation cells treat
# each object as a dict: entry id -> {GO id: score}.
# NOTE(security): pickle.load can execute arbitrary code -- only load files that
# were produced by this project.
# Each file is opened in a `with` block so its handle is closed right after loading.

# Test split, transferred from the training set.
with open('foldseek_results/foldseek_function_on_train_pred_annots_dict.pkl', 'rb') as fh:
    foldseek_function_on_train = pickle.load(fh)
with open('foldseek_results/foldseek_component_on_train_pred_annots_dict.pkl', 'rb') as fh:
    foldseek_component_on_train = pickle.load(fh)
with open('foldseek_results/foldseek_process_on_train_pred_annots_dict.pkl', 'rb') as fh:
    foldseek_process_on_train = pickle.load(fh)

# Test split, transferred from SwissProt.
with open('foldseek_results/foldseek_function_on_swissprot_pred_annots_dict.pkl', 'rb') as fh:
    foldseek_function_on_swissprot = pickle.load(fh)
with open('foldseek_results/foldseek_component_on_swissprot_pred_annots_dict.pkl', 'rb') as fh:
    foldseek_component_on_swissprot = pickle.load(fh)
with open('foldseek_results/foldseek_process_on_swissprot_pred_annots_dict.pkl', 'rb') as fh:
    foldseek_process_on_swissprot = pickle.load(fh)

# Plasmodium set, transferred from SwissProt.
with open('foldseek_results/foldseek_function_pf_on_swissprot_pred_annots_dict.pkl', 'rb') as fh:
    foldseek_plasmodium_function_on_swissprot = pickle.load(fh)
with open('foldseek_results/foldseek_component_pf_on_swissprot_pred_annots_dict.pkl', 'rb') as fh:
    foldseek_plasmodium_component_on_swissprot = pickle.load(fh)
with open('foldseek_results/foldseek_process_pf_on_swissprot_pred_annots_dict.pkl', 'rb') as fh:
    foldseek_plasmodium_process_on_swissprot = pickle.load(fh)

In [26]:
def evaluate_annotations(ic_dict, real_annots_dict, pred_annots_dict):
    """
    Evaluates precision, recall, F1-score, remaining uncertainty (ru),
    and misinformation (mi) using sets of GO terms.

    Only entries present in BOTH dictionaries are scored.
    NOTE(review): entries that have true annotations but no key in
    ``pred_annots_dict`` are therefore excluded from the recall / ru averages --
    confirm this is the intended benchmark definition.

    Parameters:
    - ic_dict: dict, GO id -> weight added to mi (false positives) or ru
      (false negatives). GO ids missing from it contribute nothing.
    - real_annots_dict: dict, entry -> set of true GO ids.
    - pred_annots_dict: dict, entry -> set of predicted GO ids.

    Returns:
    - tuple ``(f, p, r, s, ru, mi, f_micro, p_micro, r_micro,
      tp_global, fp_global, fn_global)`` where p/r/f are averaged per entry
      (precision only over entries with at least one prediction), the
      ``*_micro`` values come from the pooled tp/fp/fn counts, ru and mi are
      per-entry averages, and ``s = sqrt(ru**2 + mi**2)``.
      With no common entries every value is zero instead of raising
      ZeroDivisionError.
    """
    total = 0
    p_total = 0
    p_sum = 0.0
    r_sum = 0.0
    ru = 0.0
    mi = 0.0
    tp_global, fp_global, fn_global = 0, 0, 0

    common_entries = set(real_annots_dict.keys()).intersection(pred_annots_dict.keys())

    for entry in common_entries:
        real_annots = real_annots_dict[entry]
        pred_annots = pred_annots_dict[entry]

        tp = real_annots.intersection(pred_annots)
        fp = pred_annots - tp
        fn = real_annots - tp

        tp_global += len(tp)
        fp_global += len(fp)
        fn_global += len(fn)

        # Misinformation: weight of wrongly predicted terms.
        for go_id in fp:
            if go_id in ic_dict:
                mi += ic_dict[go_id]

        # Remaining uncertainty: weight of true terms that were missed.
        for go_id in fn:
            if go_id in ic_dict:
                ru += ic_dict[go_id]

        total += 1

        recall = len(tp) / (len(tp) + len(fn)) if (len(tp) + len(fn)) > 0 else 0
        precision = len(tp) / (len(tp) + len(fp)) if (len(tp) + len(fp)) > 0 else 0

        r_sum += recall
        # Precision is averaged only over entries that received a prediction.
        if len(pred_annots) > 0:
            p_total += 1
            p_sum += precision

    r = r_sum / total if total > 0 else 0
    p = p_sum / p_total if p_total > 0 else 0

    p_micro = tp_global / (tp_global + fp_global) if (tp_global + fp_global) > 0 else 0
    r_micro = tp_global / (tp_global + fn_global) if (tp_global + fn_global) > 0 else 0

    f = 2 * p * r / (p + r) if (p + r) > 0 else 0
    f_micro = 2 * p_micro * r_micro / (p_micro + r_micro) if (p_micro + r_micro) > 0 else 0

    # Guarded like every other ratio above: with no common entries ru and mi
    # stay at 0.0 rather than dividing by zero.
    if total > 0:
        ru /= total
        mi /= total

    s = math.sqrt(ru * ru + mi * mi)

    return f, p, r, s, ru, mi, f_micro, p_micro, r_micro, tp_global, fp_global, fn_global

def _calculate_metrics_at_threshold(ic_dict, real_annots_dict, pred_annots_dict_with_scores, threshold):
    """
    Helper function to calculate metrics at a specific threshold.
    Only predictions with a score >= threshold are kept.
    """
    # Keep only GO terms with scores >= threshold
    filtered_pred_annots_dict = {
        entry: {go_id for go_id, score in go_scores.items() if score >= threshold}
        for entry, go_scores in pred_annots_dict_with_scores.items()
    }

    # Evaluate annotations using the filtered predictions
    f, p, r, s, ru, mi, f_micro, p_micro, r_micro, tp_global, fp_global, fn_global = \
        evaluate_annotations(ic_dict, real_annots_dict, filtered_pred_annots_dict)

    # Calculate coverage
    cov = len([1 for preds in filtered_pred_annots_dict.values() if len(preds) > 0]) / len(real_annots_dict)

    return {
        'n': threshold,  # Store the current threshold
        'tp': tp_global,
        'fp': fp_global,
        'fn': fn_global,
        'pr': p,
        'rc': r,
        'cov': cov,
        'mi': mi,
        'ru': ru,
        'f': f,
        's': s,
        'pr_micro': p_micro,
        'rc_micro': r_micro,
        'f_micro': f_micro,
        'cov_max': cov
    }


def threshold_performance_metrics(ic_dict, real_annots_dict, pred_annots_dict_with_scores, threshold_range=None, set_threshold=None):
    """
    Calculates S-min and F-max over a range of thresholds or at a set threshold.

    Parameters:
    - ic_dict: dict, GO id -> weight used for ru / mi.
    - real_annots_dict: dict, entry -> set of true GO ids.
    - pred_annots_dict_with_scores: dict, entry -> {GO id: score}.
    - threshold_range: iterable of thresholds to sweep (ignored when
      ``set_threshold`` is given).
    - set_threshold: single threshold to evaluate. When given, the returned
      summary values describe that threshold.

    Returns:
    - tuple ``(smin, fmax, best_threshold_s, best_threshold_f, s_at_fmax,
      results_df)``; ``results_df`` has one row per evaluated threshold.

    Raises:
    - ValueError: if neither ``threshold_range`` nor ``set_threshold`` is given.
    """
    if threshold_range is None and set_threshold is None:
        raise ValueError("Either threshold_range or set_threshold must be provided.")

    smin = float('inf')
    fmax = 0
    best_threshold_s = None
    best_threshold_f = None
    s_at_fmax = None
    results = []

    if set_threshold is not None:
        # Single threshold: its metrics ARE the summary. Previously this branch
        # left smin/fmax at their initial inf/0/None values.
        result = _calculate_metrics_at_threshold(ic_dict, real_annots_dict, pred_annots_dict_with_scores, set_threshold)
        results.append(result)
        smin = s_at_fmax = result['s']
        fmax = result['f']
        best_threshold_s = best_threshold_f = set_threshold
    else:
        # Iterate over the threshold range and track the best S and F values
        for threshold in tqdm(threshold_range, desc='Calculating Smin & Fmax'):
            result = _calculate_metrics_at_threshold(ic_dict, real_annots_dict, pred_annots_dict_with_scores, threshold)
            results.append(result)

            if result['s'] < smin:
                smin = result['s']
                best_threshold_s = threshold
            if result['f'] > fmax:
                fmax = result['f']
                best_threshold_f = threshold
                s_at_fmax = result['s']

    results_df = pd.DataFrame(results)

    print(f"F-max @ Best Threshold ({best_threshold_f}): {fmax}")
    print(f"S-min @ Best Threshold ({best_threshold_s}): {smin}")
    print(f"S-min @ F-max Threshold ({best_threshold_f}): {s_at_fmax}")

    return smin, fmax, best_threshold_s, best_threshold_f, s_at_fmax, results_df


def calculate_aupr_micro(real_annots_dict, pred_annots_dict_with_scores, include_missing_positives=False):
    """
    Calculate AUPR Micro for the entire dataset.

    Every (entry, GO id, score) prediction is flattened into one pooled binary
    problem: the label is 1 when the GO id is among the entry's true
    annotations, else 0.

    Parameters:
    - real_annots_dict: dict, entry -> set of true GO ids. Entries missing from
      it are treated as having no true annotations.
    - pred_annots_dict_with_scores: dict, entry -> {GO id: score}.
    - include_missing_positives: bool, default False (original behaviour).
      By default only *predicted* pairs are pooled, so true annotations that
      were never predicted are not counted as misses. When True, each such
      annotation is added as a positive whose score is no higher than any
      predicted score, so it counts against recall.

    Returns:
    - float, area under the pooled precision-recall curve, or NaN when there is
      nothing to score.
    """
    # Flatten the ground truth and predicted scores into lists
    y_true_flat = []
    y_scores_flat = []

    for entry, go_scores in pred_annots_dict_with_scores.items():
        real_annots = real_annots_dict.get(entry, set())

        for go_id, score in go_scores.items():
            y_true_flat.append(1 if go_id in real_annots else 0)
            y_scores_flat.append(score)

    if include_missing_positives:
        # Rank never-predicted true terms at (or below) the lowest seen score.
        floor_score = min(min(y_scores_flat, default=0.0), 0.0)
        for entry, real_annots in real_annots_dict.items():
            predicted = pred_annots_dict_with_scores.get(entry, {})
            n_missing = sum(1 for go_id in real_annots if go_id not in predicted)
            y_true_flat.extend([1] * n_missing)
            y_scores_flat.extend([floor_score] * n_missing)

    # Nothing to score: return NaN instead of passing empty arrays on.
    if not y_true_flat:
        return float('nan')

    # Calculate precision-recall curve and AUPR micro
    precision, recall, _ = precision_recall_curve(y_true_flat, y_scores_flat)
    return auc(recall, precision)



In [27]:
# Score cut-offs 0.01, 0.02, ..., 0.99 swept when searching for F-max / S-min.
# NOTE(review): after the float32 cast the thresholds are not exactly the
# two-decimal values (0.99 is reported as 0.9900000095367432 in the outputs
# below). Confirm the prediction scores use the same dtype; otherwise
# `score >= threshold` can misclassify scores that sit exactly on a cut-off.
threshold_range = np.arange(0.01, 1.00, 0.01).astype(np.float32)


# BLAST

In [35]:
# BLAST, molecular function, test split transferred from the training set.
(
    blast_func_test_on_train_smin,
    blast_func_test_on_train_fmax,
    blast_func_test_on_train_best_threshold_s,
    blast_func_test_on_train_best_threshold_f,
    blast_func_test_on_train_s_at_fmax,
    blast_func_test_on_train_results_df,
) = threshold_performance_metrics(
    ic_dict, test_MF_df_dict, blast_function_on_train, threshold_range=threshold_range
)
print(f"AUPR Micro: {calculate_aupr_micro(test_MF_df_dict, blast_function_on_train)}")

Calculating Smin & Fmax: 100%|██████████| 99/99 [04:48<00:00,  2.91s/it]


F-max @ Best Threshold (0.9900000095367432): 0.6858464271163891
S-min @ Best Threshold (0.9900000095367432): 6.280941047277003
S-min @ F-max Threshold (0.9900000095367432): 6.280941047277003
AUPR Micro: 0.7866795704822008


In [38]:
# Persist the per-threshold metrics table for BLAST / function / test-on-train.
# NOTE(review): results are written under 'baseline models/' while the inputs
# were read from 'blast_results/' -- confirm the working directory is the same.
blast_func_test_on_train_results_df.to_csv(
    path_or_buf='baseline models/blast_results/blast_func_test_on_train_fmax_smin_df.csv',
    index=False,
)

In [45]:
# BLAST, biological process, test split transferred from the training set.
(
    blast_process_test_on_train_smin,
    blast_process_test_on_train_fmax,
    blast_process_test_on_train_best_threshold_s,
    blast_process_test_on_train_best_threshold_f,
    blast_process_test_on_train_s_at_fmax,
    blast_process_test_on_train_results_df,
) = threshold_performance_metrics(
    ic_dict, test_BP_df_dict, blast_process_on_train, threshold_range=threshold_range
)
print(f"AUPR Micro: {calculate_aupr_micro(test_BP_df_dict, blast_process_on_train)}")

# BLAST, cellular component, test split transferred from the training set.
(
    blast_component_test_on_train_smin,
    blast_component_test_on_train_fmax,
    blast_component_test_on_train_best_threshold_s,
    blast_component_test_on_train_best_threshold_f,
    blast_component_test_on_train_s_at_fmax,
    blast_component_test_on_train_results_df,
) = threshold_performance_metrics(
    ic_dict, test_CC_df_dict, blast_component_on_train, threshold_range=threshold_range
)
print(f"AUPR Micro: {calculate_aupr_micro(test_CC_df_dict, blast_component_on_train)}")

# BLAST, molecular function, test split transferred from SwissProt.
(
    blast_func_test_on_swissprot_smin,
    blast_func_test_on_swissprot_fmax,
    blast_func_test_on_swissprot_best_threshold_s,
    blast_func_test_on_swissprot_best_threshold_f,
    blast_func_test_on_swissprot_s_at_fmax,
    blast_func_test_on_swissprot_results_df,
) = threshold_performance_metrics(
    ic_dict, test_MF_df_dict, blast_function_on_swissprot, threshold_range=threshold_range
)
print(f"AUPR Micro: {calculate_aupr_micro(test_MF_df_dict, blast_function_on_swissprot)}")

# BLAST, biological process, test split transferred from SwissProt.
(
    blast_process_test_on_swissprot_smin,
    blast_process_test_on_swissprot_fmax,
    blast_process_test_on_swissprot_best_threshold_s,
    blast_process_test_on_swissprot_best_threshold_f,
    blast_process_test_on_swissprot_s_at_fmax,
    blast_process_test_on_swissprot_results_df,
) = threshold_performance_metrics(
    ic_dict, test_BP_df_dict, blast_process_on_swissprot, threshold_range=threshold_range
)
print(f"AUPR Micro: {calculate_aupr_micro(test_BP_df_dict, blast_process_on_swissprot)}")

# BLAST, cellular component, test split transferred from SwissProt.
(
    blast_component_test_on_swissprot_smin,
    blast_component_test_on_swissprot_fmax,
    blast_component_test_on_swissprot_best_threshold_s,
    blast_component_test_on_swissprot_best_threshold_f,
    blast_component_test_on_swissprot_s_at_fmax,
    blast_component_test_on_swissprot_results_df,
) = threshold_performance_metrics(
    ic_dict, test_CC_df_dict, blast_component_on_swissprot, threshold_range=threshold_range
)
print(f"AUPR Micro: {calculate_aupr_micro(test_CC_df_dict, blast_component_on_swissprot)}")

Calculating Smin & Fmax: 100%|██████████| 99/99 [05:24<00:00,  3.28s/it]


F-max @ Best Threshold (0.9900000095367432): 0.6961835777412881
S-min @ Best Threshold (0.9900000095367432): 6.9090238953397
S-min @ F-max Threshold (0.9900000095367432): 6.9090238953397
AUPR Micro: 0.753624132171876


Calculating Smin & Fmax: 100%|██████████| 99/99 [01:30<00:00,  1.09it/s]


F-max @ Best Threshold (0.9900000095367432): 0.7218816715715645
S-min @ Best Threshold (0.9900000095367432): 2.1367866810710554
S-min @ F-max Threshold (0.9900000095367432): 2.1367866810710554
AUPR Micro: 0.7900266491781599


Calculating Smin & Fmax: 100%|██████████| 99/99 [04:19<00:00,  2.62s/it]


F-max @ Best Threshold (0.9900000095367432): 0.598541236591428
S-min @ Best Threshold (0.9900000095367432): 11.452622432349974
S-min @ F-max Threshold (0.9900000095367432): 11.452622432349974
AUPR Micro: 0.6847284791780713


Calculating Smin & Fmax: 100%|██████████| 99/99 [10:23<00:00,  6.30s/it]


F-max @ Best Threshold (0.9900000095367432): 0.5186763590505143
S-min @ Best Threshold (0.9900000095367432): 31.50684534579625
S-min @ F-max Threshold (0.9900000095367432): 31.50684534579625
AUPR Micro: 0.5228277628542191


Calculating Smin & Fmax: 100%|██████████| 99/99 [02:14<00:00,  1.36s/it]


F-max @ Best Threshold (0.9900000095367432): 0.5902444074626783
S-min @ Best Threshold (0.9900000095367432): 6.974960289964934
S-min @ F-max Threshold (0.9900000095367432): 6.974960289964934
AUPR Micro: 0.601654432870621


In [46]:
# Persist the per-threshold metrics tables for the remaining BLAST test-split runs.
blast_func_test_on_swissprot_results_df.to_csv(
    path_or_buf='baseline models/blast_results/blast_func_test_on_swissprot_fmax_smin_df.csv', index=False)
blast_process_test_on_swissprot_results_df.to_csv(
    path_or_buf='baseline models/blast_results/blast_process_test_on_swissprot_fmax_smin_df.csv', index=False)
blast_component_test_on_swissprot_results_df.to_csv(
    path_or_buf='baseline models/blast_results/blast_component_test_on_swissprot_fmax_smin_df.csv', index=False)
blast_process_test_on_train_results_df.to_csv(
    path_or_buf='baseline models/blast_results/blast_process_test_on_train_fmax_smin_df.csv', index=False)
blast_component_test_on_train_results_df.to_csv(
    path_or_buf='baseline models/blast_results/blast_component_test_on_train_fmax_smin_df.csv', index=False)

Plasmodium

In [41]:
# BLAST, Plasmodium set transferred from SwissProt -- molecular function.
(
    blast_plasmodium_func_on_swissprot_smin,
    blast_plasmodium_func_on_swissprot_fmax,
    blast_plasmodium_func_on_swissprot_best_threshold_s,
    blast_plasmodium_func_on_swissprot_best_threshold_f,
    blast_plasmodium_func_on_swissprot_s_at_fmax,
    blast_plasmodium_func_on_swissprot_results_df,
) = threshold_performance_metrics(
    ic_dict=ic_dict,
    real_annots_dict=plasmodium_MF_df_dict,
    pred_annots_dict_with_scores=blast_plasmodium_function_on_swissprot,
    threshold_range=threshold_range
)
print(f"AUPR Micro: {calculate_aupr_micro(plasmodium_MF_df_dict, blast_plasmodium_function_on_swissprot)}")

# Biological process.
(
    blast_plasmodium_process_on_swissprot_smin,
    blast_plasmodium_process_on_swissprot_fmax,
    blast_plasmodium_process_on_swissprot_best_threshold_s,
    blast_plasmodium_process_on_swissprot_best_threshold_f,
    blast_plasmodium_process_on_swissprot_s_at_fmax,
    blast_plasmodium_process_on_swissprot_results_df,
) = threshold_performance_metrics(
    ic_dict=ic_dict,
    real_annots_dict=plasmodium_BP_df_dict,
    pred_annots_dict_with_scores=blast_plasmodium_process_on_swissprot,
    threshold_range=threshold_range
)
print(f"AUPR Micro: {calculate_aupr_micro(plasmodium_BP_df_dict, blast_plasmodium_process_on_swissprot)}")

# Cellular component.
(
    blast_plasmodium_component_on_swissprot_smin,
    blast_plasmodium_component_on_swissprot_fmax,
    blast_plasmodium_component_on_swissprot_best_threshold_s,
    blast_plasmodium_component_on_swissprot_best_threshold_f,
    blast_plasmodium_component_on_swissprot_s_at_fmax,
    blast_plasmodium_component_on_swissprot_results_df,
) = threshold_performance_metrics(
    ic_dict=ic_dict,
    real_annots_dict=plasmodium_CC_df_dict,
    pred_annots_dict_with_scores=blast_plasmodium_component_on_swissprot,
    threshold_range=threshold_range
)
# AUPR was reported for function and process but not for component; added so
# all three aspects report the same set of metrics.
print(f"AUPR Micro: {calculate_aupr_micro(plasmodium_CC_df_dict, blast_plasmodium_component_on_swissprot)}")

Calculating Smin & Fmax: 100%|██████████| 99/99 [00:11<00:00,  8.79it/s]


F-max @ Best Threshold (0.9900000095367432): 0.5948729642193064
S-min @ Best Threshold (0.9900000095367432): 10.712260957990571
S-min @ F-max Threshold (0.9900000095367432): 10.712260957990571
AUPR Micro: 0.6961778204890277


Calculating Smin & Fmax: 100%|██████████| 99/99 [00:32<00:00,  3.09it/s]


F-max @ Best Threshold (0.9900000095367432): 0.5134676228319336
S-min @ Best Threshold (0.9900000095367432): 24.594680429941278
S-min @ F-max Threshold (0.9900000095367432): 24.594680429941278
AUPR Micro: 0.5563032589912397


Calculating Smin & Fmax: 100%|██████████| 99/99 [00:06<00:00, 15.40it/s]

F-max @ Best Threshold (0.9300000071525574): 0.5559618222408098
S-min @ Best Threshold (0.9900000095367432): 6.493964649252618
S-min @ F-max Threshold (0.9300000071525574): 13.073187048030315





In [42]:
# Persist the per-threshold metrics tables for the BLAST Plasmodium runs.
blast_plasmodium_func_on_swissprot_results_df.to_csv(
    path_or_buf='baseline models/blast_results/blast_plasmodium_func_on_swissprot_fmax_smin_df.csv', index=False)
blast_plasmodium_process_on_swissprot_results_df.to_csv(
    path_or_buf='baseline models/blast_results/blast_plasmodium_process_on_swissprot_fmax_smin_df.csv', index=False)
blast_plasmodium_component_on_swissprot_results_df.to_csv(
    path_or_buf='baseline models/blast_results/blast_plasmodium_component_on_swissprot_fmax_smin_df.csv', index=False)

# Foldseek

Plasmodium

In [43]:
# Foldseek, Plasmodium set transferred from SwissProt -- molecular function.
(
    foldseek_plasmodium_func_on_swissprot_smin,
    foldseek_plasmodium_func_on_swissprot_fmax,
    foldseek_plasmodium_func_on_swissprot_best_threshold_s,
    foldseek_plasmodium_func_on_swissprot_best_threshold_f,
    foldseek_plasmodium_func_on_swissprot_s_at_fmax,
    foldseek_plasmodium_func_on_swissprot_results_df,
) = threshold_performance_metrics(
    ic_dict=ic_dict,
    real_annots_dict=plasmodium_MF_df_dict,
    pred_annots_dict_with_scores=foldseek_plasmodium_function_on_swissprot,
    threshold_range=threshold_range
)
print(f"AUPR Micro: {calculate_aupr_micro(plasmodium_MF_df_dict, foldseek_plasmodium_function_on_swissprot)}")

# Biological process.
(
    foldseek_plasmodium_process_on_swissprot_smin,
    foldseek_plasmodium_process_on_swissprot_fmax,
    foldseek_plasmodium_process_on_swissprot_best_threshold_s,
    foldseek_plasmodium_process_on_swissprot_best_threshold_f,
    foldseek_plasmodium_process_on_swissprot_s_at_fmax,
    foldseek_plasmodium_process_on_swissprot_results_df,
) = threshold_performance_metrics(
    ic_dict=ic_dict,
    real_annots_dict=plasmodium_BP_df_dict,
    pred_annots_dict_with_scores=foldseek_plasmodium_process_on_swissprot,
    threshold_range=threshold_range
)
print(f"AUPR Micro: {calculate_aupr_micro(plasmodium_BP_df_dict, foldseek_plasmodium_process_on_swissprot)}")

# Cellular component.
(
    foldseek_plasmodium_component_on_swissprot_smin,
    foldseek_plasmodium_component_on_swissprot_fmax,
    foldseek_plasmodium_component_on_swissprot_best_threshold_s,
    foldseek_plasmodium_component_on_swissprot_best_threshold_f,
    foldseek_plasmodium_component_on_swissprot_s_at_fmax,
    foldseek_plasmodium_component_on_swissprot_results_df,
) = threshold_performance_metrics(
    ic_dict=ic_dict,
    real_annots_dict=plasmodium_CC_df_dict,
    pred_annots_dict_with_scores=foldseek_plasmodium_component_on_swissprot,
    threshold_range=threshold_range
)
# AUPR was reported for function and process but not for component; added so
# all three aspects report the same set of metrics.
print(f"AUPR Micro: {calculate_aupr_micro(plasmodium_CC_df_dict, foldseek_plasmodium_component_on_swissprot)}")

Calculating Smin & Fmax: 100%|██████████| 99/99 [00:31<00:00,  3.18it/s]


F-max @ Best Threshold (0.9900000095367432): 0.4866229268602174
S-min @ Best Threshold (0.9900000095367432): 15.568842784653118
S-min @ F-max Threshold (0.9900000095367432): 15.568842784653118
AUPR Micro: 0.5476459602352697


Calculating Smin & Fmax: 100%|██████████| 99/99 [02:07<00:00,  1.29s/it]


F-max @ Best Threshold (0.9800000190734863): 0.3976062910496501
S-min @ Best Threshold (0.9900000095367432): 40.230700432573585
S-min @ F-max Threshold (0.9800000190734863): 51.914073272113704
AUPR Micro: 0.3475218193278338


Calculating Smin & Fmax: 100%|██████████| 99/99 [00:17<00:00,  5.50it/s]

F-max @ Best Threshold (0.9700000286102295): 0.5346706180359273
S-min @ Best Threshold (0.9900000095367432): 8.480191253889489
S-min @ F-max Threshold (0.9700000286102295): 12.671456936545338





In [44]:
# Persist the per-threshold metrics tables for the Foldseek Plasmodium runs.
foldseek_plasmodium_func_on_swissprot_results_df.to_csv(
    path_or_buf='baseline models/foldseek_results/foldseek_plasmodium_func_on_swissprot_fmax_smin_df.csv', index=False)
foldseek_plasmodium_process_on_swissprot_results_df.to_csv(
    path_or_buf='baseline models/foldseek_results/foldseek_plasmodium_process_on_swissprot_fmax_smin_df.csv', index=False)
foldseek_plasmodium_component_on_swissprot_results_df.to_csv(
    path_or_buf='baseline models/foldseek_results/foldseek_plasmodium_component_on_swissprot_fmax_smin_df.csv', index=False)

In [47]:
# Foldseek, molecular function, test split transferred from the training set.
(
    foldseek_func_test_on_train_smin,
    foldseek_func_test_on_train_fmax,
    foldseek_func_test_on_train_best_threshold_s,
    foldseek_func_test_on_train_best_threshold_f,
    foldseek_func_test_on_train_s_at_fmax,
    foldseek_func_test_on_train_results_df,
) = threshold_performance_metrics(
    ic_dict, test_MF_df_dict, foldseek_function_on_train, threshold_range=threshold_range
)
print(f"AUPR Micro: {calculate_aupr_micro(test_MF_df_dict, foldseek_function_on_train)}")

# Foldseek, biological process, test split transferred from the training set.
(
    foldseek_process_test_on_train_smin,
    foldseek_process_test_on_train_fmax,
    foldseek_process_test_on_train_best_threshold_s,
    foldseek_process_test_on_train_best_threshold_f,
    foldseek_process_test_on_train_s_at_fmax,
    foldseek_process_test_on_train_results_df,
) = threshold_performance_metrics(
    ic_dict, test_BP_df_dict, foldseek_process_on_train, threshold_range=threshold_range
)
print(f"AUPR Micro: {calculate_aupr_micro(test_BP_df_dict, foldseek_process_on_train)}")

# Foldseek, cellular component, test split transferred from the training set.
(
    foldseek_component_test_on_train_smin,
    foldseek_component_test_on_train_fmax,
    foldseek_component_test_on_train_best_threshold_s,
    foldseek_component_test_on_train_best_threshold_f,
    foldseek_component_test_on_train_s_at_fmax,
    foldseek_component_test_on_train_results_df,
) = threshold_performance_metrics(
    ic_dict, test_CC_df_dict, foldseek_component_on_train, threshold_range=threshold_range
)
print(f"AUPR Micro: {calculate_aupr_micro(test_CC_df_dict, foldseek_component_on_train)}")

# Foldseek, molecular function, test split transferred from SwissProt.
(
    foldseek_func_test_on_swissprot_smin,
    foldseek_func_test_on_swissprot_fmax,
    foldseek_func_test_on_swissprot_best_threshold_s,
    foldseek_func_test_on_swissprot_best_threshold_f,
    foldseek_func_test_on_swissprot_s_at_fmax,
    foldseek_func_test_on_swissprot_results_df,
) = threshold_performance_metrics(
    ic_dict, test_MF_df_dict, foldseek_function_on_swissprot, threshold_range=threshold_range
)
print(f"AUPR Micro: {calculate_aupr_micro(test_MF_df_dict, foldseek_function_on_swissprot)}")

# Foldseek, biological process, test split transferred from SwissProt.
(
    foldseek_process_test_on_swissprot_smin,
    foldseek_process_test_on_swissprot_fmax,
    foldseek_process_test_on_swissprot_best_threshold_s,
    foldseek_process_test_on_swissprot_best_threshold_f,
    foldseek_process_test_on_swissprot_s_at_fmax,
    foldseek_process_test_on_swissprot_results_df,
) = threshold_performance_metrics(
    ic_dict, test_BP_df_dict, foldseek_process_on_swissprot, threshold_range=threshold_range
)
print(f"AUPR Micro: {calculate_aupr_micro(test_BP_df_dict, foldseek_process_on_swissprot)}")

# Foldseek, cellular component, test split transferred from SwissProt.
(
    foldseek_component_test_on_swissprot_smin,
    foldseek_component_test_on_swissprot_fmax,
    foldseek_component_test_on_swissprot_best_threshold_s,
    foldseek_component_test_on_swissprot_best_threshold_f,
    foldseek_component_test_on_swissprot_s_at_fmax,
    foldseek_component_test_on_swissprot_results_df,
) = threshold_performance_metrics(
    ic_dict, test_CC_df_dict, foldseek_component_on_swissprot, threshold_range=threshold_range
)
print(f"AUPR Micro: {calculate_aupr_micro(test_CC_df_dict, foldseek_component_on_swissprot)}")



Calculating Smin & Fmax: 100%|██████████| 99/99 [01:42<00:00,  1.04s/it]


F-max @ Best Threshold (0.009999999776482582): 0.4258024535315746
S-min @ Best Threshold (0.1599999964237213): 7.828731444270688
S-min @ F-max Threshold (0.009999999776482582): 7.8800929093743655
AUPR Micro: 0.7699852969233438


Calculating Smin & Fmax: 100%|██████████| 99/99 [01:23<00:00,  1.19it/s]


F-max @ Best Threshold (0.07000000029802322): 0.29871472160136453
S-min @ Best Threshold (0.6000000238418579): 10.030916308423194
S-min @ F-max Threshold (0.07000000029802322): 10.624523458300976
AUPR Micro: 0.51053537707953


Calculating Smin & Fmax: 100%|██████████| 99/99 [01:13<00:00,  1.34it/s]


F-max @ Best Threshold (0.009999999776482582): 0.3519402126671549
S-min @ Best Threshold (0.5): 2.6961853069370645
S-min @ F-max Threshold (0.009999999776482582): 3.0487624655127337
AUPR Micro: 0.5937977952694989


Calculating Smin & Fmax: 100%|██████████| 99/99 [05:23<00:00,  3.26s/it]


F-max @ Best Threshold (0.9900000095367432): 0.8085844137941027
S-min @ Best Threshold (0.9900000095367432): 8.107724794554588
S-min @ F-max Threshold (0.9900000095367432): 8.107724794554588
AUPR Micro: 0.8449195348088556


Calculating Smin & Fmax: 100%|██████████| 99/99 [24:15<00:00, 14.70s/it]


F-max @ Best Threshold (0.9900000095367432): 0.7128679020096702
S-min @ Best Threshold (0.9900000095367432): 29.388975073308888
S-min @ F-max Threshold (0.9900000095367432): 29.388975073308888
AUPR Micro: 0.6832811042734481


Calculating Smin & Fmax: 100%|██████████| 99/99 [04:03<00:00,  2.46s/it]


F-max @ Best Threshold (0.9900000095367432): 0.6900694797157101
S-min @ Best Threshold (0.9900000095367432): 6.957947572977
S-min @ F-max Threshold (0.9900000095367432): 6.957947572977
AUPR Micro: 0.6899680977012932


In [48]:
# Persist the per-threshold metrics tables for the Foldseek test-split runs.
foldseek_func_test_on_swissprot_results_df.to_csv(
    path_or_buf='baseline models/foldseek_results/foldseek_func_test_on_swissprot_fmax_smin_df.csv', index=False)
foldseek_process_test_on_swissprot_results_df.to_csv(
    path_or_buf='baseline models/foldseek_results/foldseek_process_test_on_swissprot_fmax_smin_df.csv', index=False)
foldseek_component_test_on_swissprot_results_df.to_csv(
    path_or_buf='baseline models/foldseek_results/foldseek_component_test_on_swissprot_fmax_smin_df.csv', index=False)
foldseek_func_test_on_train_results_df.to_csv(
    path_or_buf='baseline models/foldseek_results/foldseek_func_test_on_train_fmax_smin_df.csv', index=False)
foldseek_process_test_on_train_results_df.to_csv(
    path_or_buf='baseline models/foldseek_results/foldseek_process_test_on_train_fmax_smin_df.csv', index=False)
foldseek_component_test_on_train_results_df.to_csv(
    path_or_buf='baseline models/foldseek_results/foldseek_component_test_on_train_fmax_smin_df.csv', index=False)
