# Get mF1 scores for each of the identity intervals  

In [ ]:
# Root directory holding the input CSVs, the trained models and the pickled
# ensemble predictions read by the cells below.
# The later cells build paths as f"{data_path}/...", so the trailing slash kept
# here produces a doubled "/" in those paths.
data_path = "/home/joao/Desktop/required_data_ec_number_paper/"

In [ ]:
import pandas as pd

# Ids that are added to every selected subset below; only the `qseqid` column is
# read. Presumably these are test sequences without a BLAST hit (guess from the
# file name, TODO confirm).
not_in_predictions = pd.read_csv(f"{data_path}/not_in_predictions.csv")
# Table used to select sequences; the `qseqid`, `evalue` and `pident` columns
# are read below.
predictions = pd.read_csv(f"{data_path}/predictions_no_duplicates_ident_evalue.csv")
# Test dataset; `accession` is used as the instance id, `sequence` as the
# representation, and columns 8..2778 (slice(8, 2779)) as the labels.
test_dataset_pandas = pd.read_csv(f"{data_path}/data/test.csv")

In [ ]:
# BLAST predictions: rows are filtered on `qseqid` and the columns from
# position 6 onwards are used as the predicted label matrix.
blast_results = pd.read_csv(f"{data_path}/test_blast_predictions_right_format.csv")
# Matching ground truth: rows are filtered on `accession` and the columns from
# position 8 onwards are used as the true label matrix.
blast_results_true_values = pd.read_csv(f"{data_path}/test_right_format.csv")

## Function to get the predictions for the enzymes below a specific identity threshold

In [ ]:
from plants_sm.data_structures.dataset.single_input_dataset import SingleInputDataset
from plants_sm.models.pytorch_model import PyTorchModel
from sklearn.metrics import f1_score

def get_less_studied_predictions(model, model_path, evalue_threshold=1e-5, identity_threshold1=None, identity_threshold2=None, return_blast_results=True, features_root="/scratch/jribeiro/results"):
    """Score a saved model (and optionally BLAST) on a subset of the test set.

    The subset is selected from the module-level ``predictions`` table:

    * ``identity_threshold1`` is None: rows with ``evalue > evalue_threshold``;
    * only ``identity_threshold1`` given: rows with
      ``pident < identity_threshold1``;
    * both thresholds given: rows with
      ``identity_threshold2 < pident <= identity_threshold1``.

    In every mode the ids listed in ``not_in_predictions`` are added as well.
    Labels without any positive sample in the true values of the subset are
    dropped before the F1 scores are computed.

    Args:
        model: Name of the feature directory; features are loaded from
            ``f"{features_root}/{model}/test/"``.
        model_path: Path handed to ``PyTorchModel.load``.
        evalue_threshold: E-value cut-off, used only when
            ``identity_threshold1`` is None.
        identity_threshold1: Upper identity bound (exclusive when used alone,
            inclusive when ``identity_threshold2`` is also given).
        identity_threshold2: Exclusive lower identity bound.
        return_blast_results: When true, BLAST F1 scores are computed and
            printed too.
        features_root: Directory that contains one feature folder per model.
            The default keeps the location that used to be hard-coded.

    Returns:
        ``(macro_f1, weighted_f1, n_ids, blast_macro_f1, blast_weighted_f1)``.
        The two BLAST values are None when ``return_blast_results`` is false.
        ``n_ids`` is the length of the selected id list (including the ids
        from ``not_in_predictions``), not the number of filtered test rows.
    """
    # Imported here because the notebook only imports numpy in its plotting
    # cell, which runs after this function is called.
    import numpy as np

    if identity_threshold1 is None:
        selected_rows = predictions["evalue"] > evalue_threshold
    elif identity_threshold2 is None:
        selected_rows = predictions["pident"] < identity_threshold1
    else:
        selected_rows = (predictions["pident"] <= identity_threshold1) & (predictions["pident"] > identity_threshold2)

    high_evalue_predictions = predictions[selected_rows].qseqid.tolist()
    # NOTE(review): these ids are added in every mode, so they are counted in
    # each identity interval as well - confirm that this is intended.
    high_evalue_predictions.extend(not_in_predictions.qseqid.tolist())
    print(len(high_evalue_predictions))

    test_dataset_pandas_filtered = test_dataset_pandas[test_dataset_pandas["accession"].isin(high_evalue_predictions)]
    test_dataset_no_similarity = SingleInputDataset(test_dataset_pandas_filtered,
                                            instances_ids_field="accession", representation_field="sequence",
                                            labels_field=slice(8, 2779))
    test_dataset_no_similarity.load_features(f'{features_root}/{model}/test/')

    def labels_without_positives(label_matrix):
        # Column indices whose entries sum to zero, i.e. no positive sample.
        column_totals = np.asarray(label_matrix).sum(axis=0)
        return [index for index, total in enumerate(column_totals) if total == 0]

    if return_blast_results:
        # The BLAST matrices are only built when their scores are requested.
        blast_rows = blast_results[blast_results["qseqid"].isin(high_evalue_predictions)]
        truth_rows = blast_results_true_values[blast_results_true_values["accession"].isin(high_evalue_predictions)]

        predictions_blast = np.array(blast_rows.iloc[:, 6:])
        true_values = np.array(truth_rows.iloc[:, 8:])

        labels_to_remove = labels_without_positives(true_values)
        predictions_blast = np.delete(predictions_blast, labels_to_remove, axis=1)
        true_values = np.delete(true_values, labels_to_remove, axis=1)

        print("BLAST")
        blast_wf1 = f1_score(true_values, predictions_blast, average="weighted")
        blast_mf1 = f1_score(true_values, predictions_blast, average="macro")
        print(blast_mf1)
        print(blast_wf1)
    else:
        blast_mf1 = None
        blast_wf1 = None

    # Kept under a separate name so the `model` argument is not overwritten.
    loaded_model = PyTorchModel.load(model_path)
    model_predictions = loaded_model.predict(test_dataset_no_similarity)

    # Drop the labels that have no positive sample in this subset.
    labels_to_remove = labels_without_positives(test_dataset_no_similarity.y)
    model_predictions = np.delete(model_predictions, labels_to_remove, axis=1)
    y_true = np.delete(test_dataset_no_similarity.y, labels_to_remove, axis=1)

    print("Model")
    mf1 = f1_score(y_true, model_predictions, average="macro")
    print(mf1)
    wf1 = f1_score(y_true, model_predictions, average="weighted")
    print(wf1)
    return mf1, wf1, len(high_evalue_predictions), blast_mf1, blast_wf1
    

## Get the results for each of the identity intervals for ESM and ProtBERT models

In [ ]:
# Both names are rebound by the loop over `models` below.
model = "esm2_t36_3B_UR50D"
model_path = f"{data_path}/models/DNN_esm2_t36_3B_UR50D_trial_2_merged"

# (feature directory name, saved model path) for each embedding-based DNN.
models = [
    ("esm2_t36_3B_UR50D", f"{data_path}/models/DNN_esm2_t36_3B_UR50D_trial_2_merged"),
    ("prot_bert_vectors", f"{data_path}/models/DNN_prot_bert_vectors_trial_2_merged"),
    ("esm1b_t33_650M_UR50S", f"{data_path}/models/DNN_esm1b_t33_650M_UR50S_trial_4_merged"),
    ("esm2_t33_650M_UR50D", f"{data_path}/models/DNN_esm2_t33_650M_UR50D_trial_4_merged"),
]

# Column name -> list of values, one entry per (model, interval) pair.
results = {}
# Consecutive values delimit the identity intervals ]lower, upper].
thresholds = [0, 15, 25, 35, 45, 55, 65, 75, 85, 90, 100]

for model, model_path in models:
    for identity_threshold2, identity_threshold1 in zip(thresholds, thresholds[1:]):
        print(f"identity {identity_threshold1}")
        mf1, wf1, samples_num, _, _ = get_less_studied_predictions(
            model,
            model_path,
            identity_threshold1=identity_threshold1,
            identity_threshold2=identity_threshold2,
            return_blast_results=False,
        )

        # Each interval is labelled by its upper bound.
        results.setdefault("method", []).append(model)
        results.setdefault("identity_threshold", []).append(identity_threshold1)
        results.setdefault("macro_f1", []).append(mf1)
        results.setdefault("weighted_f1", []).append(wf1)
        results.setdefault("samples_num", []).append(samples_num)

# This first write creates the results file that the later cells extend.
pd.DataFrame(results).to_csv("less_studied_results.csv", index=False)

## Get the results for each of the identity intervals for DeepEC and DSPACE models

In [ ]:
# Both names are rebound by the loop over `models` below.
model = "esm2_t36_3B_UR50D"
model_path = f"{data_path}/models/DNN_esm2_t36_3B_UR50D_trial_2_merged"

# (method label written to the CSV, feature directory name, saved model path).
models = [
    ("DeepEC", "one_hot_encoding", f"{data_path}/models/DeepEC_merged_merged"),
    ("DSPACE", "one_hot_encoding", f"{data_path}/models/DSPACE_merged_merged"),
]

# Consecutive values delimit the identity intervals ]lower, upper].
thresholds = [0, 15, 25, 35, 45, 55, 65, 75, 85, 90, 100]


def f1_score_macro():
    """Do nothing; placeholder that no code in the visible notebook calls."""


# Column name -> list of values, one entry per (method, interval) pair.
results = {}
for model_name, model, model_path in models:
    for identity_threshold2, identity_threshold1 in zip(thresholds, thresholds[1:]):
        print(f"identity {identity_threshold1}")
        mf1, wf1, samples_num, _, _ = get_less_studied_predictions(
            model,
            model_path,
            identity_threshold1=identity_threshold1,
            identity_threshold2=identity_threshold2,
            return_blast_results=False,
        )

        # Each interval is labelled by its upper bound.
        results.setdefault("method", []).append(model_name)
        results.setdefault("identity_threshold", []).append(identity_threshold1)
        results.setdefault("macro_f1", []).append(mf1)
        results.setdefault("weighted_f1", []).append(wf1)
        results.setdefault("samples_num", []).append(samples_num)

# Extend the existing results file. pd.concat is used because DataFrame.append
# was removed in pandas 2.0; it yields the same row-wise concatenation.
pd.concat([pd.read_csv("less_studied_results.csv"), pd.DataFrame(results)]).to_csv("less_studied_results.csv", index=False)

## Get the results for each of the identity intervals for BLASTp

In [ ]:
# A model is still passed because get_less_studied_predictions always loads and
# evaluates one; only the BLAST scores it returns are kept in this cell.
model = "esm2_t36_3B_UR50D"
model_path = f"{data_path}/models/DNN_esm2_t36_3B_UR50D_trial_2_merged"

less_studied_enzymes = pd.read_csv("less_studied_results.csv")

# Column name -> list of values, one entry per identity interval.
results = {}
# Consecutive values delimit the identity intervals ]lower, upper].
thresholds = [0, 15, 25, 35, 45, 55, 65, 75, 85, 90, 100]
for identity_threshold2, identity_threshold1 in zip(thresholds, thresholds[1:]):
    mf1, wf1, samples_num, blast_mf1_score, blast_wf1_score = get_less_studied_predictions(
        model,
        model_path,
        identity_threshold1=identity_threshold1,
        identity_threshold2=identity_threshold2,
        return_blast_results=True,
    )

    # Each interval is labelled by its upper bound.
    results.setdefault("method", []).append("BLASTp")
    results.setdefault("identity_threshold", []).append(identity_threshold1)
    results.setdefault("macro_f1", []).append(blast_mf1_score)
    results.setdefault("weighted_f1", []).append(blast_wf1_score)
    results.setdefault("samples_num", []).append(samples_num)

# Persist the BLASTp rows here: the ensemble cell further down rebinds
# `results` before the "Save them to the csv file" cell runs, so otherwise
# these rows would never reach the CSV that the plot reads.
less_studied_enzymes = pd.concat([less_studied_enzymes, pd.DataFrame(results)])
less_studied_enzymes.to_csv("less_studied_results.csv", index=False)


## Get the results for each of the identity intervals for the ensemble of models and BLASTp

### Note that the predictions were already saved in a pickle file

In [ ]:
import pickle

# Saved predictions used below for the "Models ensemble" method.
# NOTE(review): pickle.load can execute code embedded in the file, so these
# pickles must come from a trusted source.
with open(f"{data_path}/predictions/predictions_models_voting.pkl", "rb") as f:
    voting_predictions = pickle.load(f)

# read predictions from pickle file
# (used below for the "Models + BLASTp ensemble" method)
import pickle

with open(f"{data_path}/predictions/predictions_models_voting_blast.pkl", "rb") as f:
    voting_predictions_blast = pickle.load(f)

In [ ]:
# Column name -> list of values, one entry per (method, interval) pair.
results = {}
# Consecutive values delimit the identity intervals ]lower, upper].
thresholds = [0, 15, 25, 35, 45, 55, 65, 75, 85, 90, 100]

for method in ["Models + BLASTp ensemble", "Models ensemble"]:
    # The saved predictions depend only on the method, so pick them once.
    voting_predictions_ = voting_predictions if method == "Models ensemble" else voting_predictions_blast

    for identity_threshold2, identity_threshold1 in zip(thresholds, thresholds[1:]):
        print(f"identity {identity_threshold1}")
        # NOTE(review): the get_less_studied_predictions definition visible in
        # this notebook requires `model_path` and returns five values, so this
        # call does not match it. Presumably a variant that scores precomputed
        # predictions and returns (macro F1, sample count) was intended.
        # TODO: confirm and restore that variant before running this cell.
        mf1, samples_num = get_less_studied_predictions(voting_predictions_,
                                    identity_threshold1=identity_threshold1,
                                    identity_threshold2=identity_threshold2)

        # Each interval is labelled by its upper bound. No weighted F1 is
        # recorded for the ensembles.
        results.setdefault("method", []).append(method)
        results.setdefault("identity_threshold", []).append(identity_threshold1)
        results.setdefault("macro_f1", []).append(mf1)
        results.setdefault("samples_num", []).append(samples_num)

# Extend the existing results file. pd.concat is used because DataFrame.append
# was removed in pandas 2.0; it yields the same row-wise concatenation.
pd.concat([pd.read_csv("less_studied_results.csv"), pd.DataFrame(results)]).to_csv("less_studied_results.csv", index=False)

## Save them to the csv file

In [ ]:
# Write the previously loaded results plus the current `results` rows.
# pd.concat is used because DataFrame.append was removed in pandas 2.0; it
# yields the same row-wise concatenation.
pd.concat([less_studied_enzymes, pd.DataFrame(results)]).to_csv("less_studied_results.csv", index=False)

## Plot the results

Note that the names of the models were changed directly in the dataset to make the plot more readable.

In [ ]:
# Method label -> RGB triple (components in [0, 1]) used as the line colour.
# The last three entries reuse the triples of the first three entries.
color_map = {
    "DNN ESM2 35M": (0.00392156862745098, 0.45098039215686275, 0.6980392156862745),
    "DNN ESM2 150M": (0.8705882352941177, 0.5607843137254902, 0.0196078431372549),
    "DNN ESM2 8M": (0.00784313725490196, 0.6196078431372549, 0.45098039215686275),
    "DNN ProtBERT": (0.8352941176470589, 0.3686274509803922, 0.0),
    "DNN ESM2 3B": (0.8, 0.47058823529411764, 0.7372549019607844),
    "DNN ESM2 650M": (0.792156862745098, 0.5686274509803921, 0.3803921568627451),
    "DNN ESM1b": (0.984313725490196, 0.6862745098039216, 0.8941176470588236),
    "DeepEC CNN3": (0.5803921568627451, 0.5803921568627451, 0.5803921568627451),
    "DSPACE EC": (0.9254901960784314, 0.8823529411764706, 0.2),
    "Models + BLASTp ensemble": (0.00392156862745098, 0.45098039215686275, 0.6980392156862745),
    "Models ensemble": (0.8705882352941177, 0.5607843137254902, 0.0196078431372549),
    "BLASTp": (0.00784313725490196, 0.6196078431372549, 0.45098039215686275),
}

In [ ]:
import pandas as pd
import numpy as np

less_studied_enzymes = pd.read_csv("less_studied_results.csv")

# generate line plot for f1 score
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib

matplotlib.rcParams['savefig.transparent'] = True

# Set the Seaborn style and the default figure size in a single call: each
# sns.set call re-applies the whole theme, so a second call without `style`
# would replace "whitegrid" with Seaborn's default style.
sns.set(style="whitegrid", rc={'figure.figsize': (11.7, 8.27)})

# Plotting with Seaborn; every value in the `method` column needs a colour in
# `color_map`.
plt.figure(figsize=(11.7, 8.27))
sns.lineplot(x='identity_threshold', y='macro_f1', hue='method', marker='o',  markers=True, markersize=5, data=less_studied_enzymes, palette=color_map)

plt.xlabel('Identity Intervals from BLASTp alignment', labelpad=20, fontsize=23)
plt.ylabel('mF1', labelpad=20, fontsize=23)
plt.legend(title='Method')
# plt.ylim takes only the two limits. The former third positional value (0.1)
# was not a tick step, and newer Matplotlib rejects extra positional arguments.
plt.ylim(0, 1)

# Consecutive values delimit the identity intervals ]lower, upper]; each
# interval is drawn at its upper bound.
thresholds = [0, 15, 25, 35, 45, 55, 65, 75, 85, 90, 100]
labels = [f"]{lower}, {upper}]" for lower, upper in zip(thresholds, thresholds[1:])]

# Add grid lines
plt.grid(True, linestyle='--', alpha=0.7)

# Adjust tick parameters; tick positions are the interval upper bounds.
plt.tick_params(axis='both', which='major', labelsize=20)
plt.xticks(thresholds[1:], labels=labels, rotation=90)
# These rcParams apply to figures created afterwards; the current figure keeps
# the size given to plt.figure above and is saved with an explicit dpi.
plt.rcParams['figure.figsize'] = [15, 12]
plt.rcParams['figure.dpi'] = 200
plt.title("mF1 per Identity Intervals from BLASTp alignment", pad=40, fontsize=23)
plt.savefig("mF1_per_identity_intervals.png", bbox_inches='tight', dpi=400)