In [1]:
import pandas as pd 
import numpy as np
import matplotlib.pyplot as plt
import torch

from Bio.Seq import Seq
from Bio import SeqIO
from Bio import Align
from Bio import AlignIO
from Bio.Align import substitution_matrices
from Bio.Data import IUPACData
from Bio.Blast import NCBIWWW, NCBIXML
from Bio.SeqRecord import SeqRecord
from Bio import pairwise2
from Bio.pairwise2 import format_alignment

'''import cafaeval
from cafaeval.evaluation import cafa_eval
from cafaeval.parser import obo_parser, gt_parser'''

from pathlib import Path
import os

import h5py
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
import ast

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MultiLabelBinarizer




In [2]:
# Path to train data
training_data_path = Path('../data/train')

# Path to test data
test_data_path = Path('../data/test')

# Path to baseline data
baseline_data_path = Path('../data/baseline')


In [3]:
# Extracting test_ids.txt
with open(test_data_path / 'test_ids.txt', 'r') as file:
    test_ids = file.read().splitlines()

# Display the first few IDs to verify
print(test_ids[:5])

['O43747', 'Q969H0', 'Q9JMA2', 'P18065', 'A0A8I6AN32']


In [4]:
test_fasta_list = list(SeqIO.parse(test_data_path / 'test.fasta', 'fasta'))

# Print the first sequence to verify
print(test_fasta_list[0])

ID: O43747
Name: O43747
Description: O43747
Number of features: 0
Seq('MPAPIRLRELIRTIRTARTQAEEREMIQKECAAIRSSFREEDNTYRCRNVAKLL...SWQ')


In [5]:
# Extract relevant information from SeqRecord
test_fasta_dict = [{
    'ID': record.id,
    'name': record.name,
    'description': record.description,
    'num_features': len(record.features),
    'sequence': record.seq,
} for record in test_fasta_list]

# Create a DataFrame from the extracted data
test_fasta = pd.DataFrame(test_fasta_dict)

# Display the DataFrame
test_fasta.head()

Unnamed: 0,ID,name,description,num_features,sequence
0,O43747,O43747,O43747,0,"(M, P, A, P, I, R, L, R, E, L, I, R, T, I, R, ..."
1,Q969H0,Q969H0,Q969H0,0,"(M, N, Q, E, L, L, S, V, G, S, K, R, R, R, T, ..."
2,Q9JMA2,Q9JMA2,Q9JMA2,0,"(M, A, A, V, G, S, P, G, S, L, E, S, A, P, R, ..."
3,P18065,P18065,P18065,0,"(M, L, P, R, V, G, C, P, A, L, P, L, P, P, P, ..."
4,A0A8I6AN32,A0A8I6AN32,A0A8I6AN32,0,"(M, A, S, N, D, Y, T, Q, Q, A, T, Q, S, Y, G, ..."


In [6]:
# Checking for differences between the ID and name columns
diff_id_name = sum(test_fasta['ID'] != test_fasta['name'])

# Checking for differences between the ID and description columns
diff_id_description = sum(test_fasta['ID'] != test_fasta['description'])

print(f"We have a total of {diff_id_name} differences between the ID and name columns.\nWe have a total of {diff_id_description} differences between the ID and description columns.")

We have a total of 0 differences between the ID and name columns.
We have a total of 0 differences between the ID and description columns.


In [7]:
num_features_values = sum(test_fasta['num_features'] != 0)

print(f"We have a total of {num_features_values} sequences with features.")

We have a total of 0 sequences with features.


In [8]:
test_fasta.drop(columns=['name', 'description', 'num_features'], inplace=True)


test_fasta.head()

Unnamed: 0,ID,sequence
0,O43747,"(M, P, A, P, I, R, L, R, E, L, I, R, T, I, R, ..."
1,Q969H0,"(M, N, Q, E, L, L, S, V, G, S, K, R, R, R, T, ..."
2,Q9JMA2,"(M, A, A, V, G, S, P, G, S, L, E, S, A, P, R, ..."
3,P18065,"(M, L, P, R, V, G, C, P, A, L, P, L, P, P, P, ..."
4,A0A8I6AN32,"(M, A, S, N, D, Y, T, Q, Q, A, T, Q, S, Y, G, ..."


In [9]:
len_ID = len(test_fasta['ID'].unique()) # assigned because gave problem on else statement print

if len(test_ids) == len_ID:
    print(f"The number of IDs in train_ids.txt is equal to the number of unique IDs in the train set ({len(test_ids)}).\n"
          "Proceeding with the analysis.")
else:
    print(f'The numbers are not the same: test_ids are {len(test_ids)}, while the length of the fasta file is {len_ID})')

The number of IDs in train_ids.txt is equal to the number of unique IDs in the train set (1000).
Proceeding with the analysis.


In [10]:
data_list = []

with h5py.File(test_data_path / "test_embeddings.h5", "r") as f:
    for dataset_name in f.keys():
        dataset = f[dataset_name][:]
        data_list.append([dataset_name, dataset])

test_embeddings = pd.DataFrame(data_list, columns=["ID", "embeddings"])

test_embeddings.head()


Unnamed: 0,ID,embeddings
0,A0A0B4JCV4,"[0.00979, -0.03973, 0.03653, -0.006447, -0.040..."
1,A0A0B4KHT0,"[0.02786, -0.01154, 0.008865, -0.01765, 0.0073..."
2,A0A0B4P506,"[0.01643, 0.01802, 0.03702, -0.0591, 0.0356, 0..."
3,A0A0G2K1A2,"[0.00882, 0.0835, -0.001374, -0.0003645, -0.06..."
4,A0A0G2K1V4,"[0.0659, 0.0929, -0.001803, 0.0226, 0.0383, 0...."


In [11]:
test_protein2ipr = pd.read_csv(test_data_path / 'test_protein2ipr.dat', sep='\t')

# Rename Protein_ID and aspect columns
test_protein2ipr.columns = ['ID', 'ipr', 'domain', 'familyID', 'start', 'end']

# Remove 'domain' that is useless
test_protein2ipr.drop('domain', axis=1)

test_protein2ipr.head()

Unnamed: 0,ID,ipr,domain,familyID,start,end
0,A0A0B4JCV4,IPR039915,TACC family,PTHR13924,38,1206
1,A0A0B4KHT0,IPR000315,B-box-type zinc finger,PF00643,177,219
2,A0A0B4KHT0,IPR000315,B-box-type zinc finger,PF00643,236,274
3,A0A0B4KHT0,IPR000315,B-box-type zinc finger,PS50119,173,220
4,A0A0B4KHT0,IPR000315,B-box-type zinc finger,PS50119,235,282


In [12]:
# Group by 'ID' and aggregate other columns into lists
test_protein2ipr_grouped = test_protein2ipr.groupby('ID').agg(lambda x: tuple(x)).reset_index()

print(f"Test protein2ipr ({test_protein2ipr.shape}):")
test_protein2ipr_grouped.head()

Test protein2ipr ((11263, 6)):


Unnamed: 0,ID,ipr,domain,familyID,start,end
0,A0A0B4JCV4,"(IPR039915,)","(TACC family,)","(PTHR13924,)","(38,)","(1206,)"
1,A0A0B4KHT0,"(IPR000315, IPR000315, IPR000315, IPR000315, I...","(B-box-type zinc finger, B-box-type zinc finge...","(PF00643, PF00643, PS50119, PS50119, SM00336, ...","(177, 236, 173, 235, 173, 235, 976, 826, 988, ...","(219, 274, 220, 282, 220, 276, 1048, 839, 1004..."
2,A0A0B4P506,"(IPR003417, IPR003417, IPR036552, IPR036552)","(Core-binding factor, beta subunit, Core-bindi...","(PF02312, PTHR10276, G3DSA:2.40.250.10, SSF50723)","(1, 1, 1, 4)","(164, 168, 142, 140)"
3,A0A0G2K1A2,"(IPR010255, IPR019791, IPR019791, IPR019791, I...","(Haem peroxidase superfamily, Haem peroxidase,...","(SSF48113, PF03098, PR00457, PR00457, PR00457,...","(142, 148, 172, 226, 374, 392, 417, 470, 598, ...","(718, 692, 183, 241, 392, 412, 443, 480, 618, ..."
4,A0A0G2K1V4,"(IPR000048, IPR000048, IPR001609, IPR001609, I...","(IQ motif, EF-hand binding site, IQ motif, EF-...","(PS50096, SM00015, PF00063, PR00193, PR00193, ...","(789, 788, 89, 116, 172, 228, 459, 513, 86, 80...","(818, 810, 774, 135, 197, 255, 487, 541, 786, ..."


In [13]:
combined_test = pd.merge(test_embeddings, test_fasta, on='ID')
combined_test = pd.merge(combined_test, test_protein2ipr_grouped, on='ID', how='left')

missing_rows = combined_test[combined_test['ipr'].isna()].shape[0]
print(f"Number of rows missing from train_protein2ipr_grouped: {missing_rows}")

print(f"Combined DataFrame shape: {combined_test.shape}")
combined_test.head()

Number of rows missing from train_protein2ipr_grouped: 19
Combined DataFrame shape: (1000, 8)


Unnamed: 0,ID,embeddings,sequence,ipr,domain,familyID,start,end
0,A0A0B4JCV4,"[0.00979, -0.03973, 0.03653, -0.006447, -0.040...","(M, E, F, D, D, A, E, N, G, L, G, M, G, F, G, ...","(IPR039915,)","(TACC family,)","(PTHR13924,)","(38,)","(1206,)"
1,A0A0B4KHT0,"[0.02786, -0.01154, 0.008865, -0.01765, 0.0073...","(M, D, M, D, L, E, Q, L, K, N, D, F, L, P, L, ...","(IPR000315, IPR000315, IPR000315, IPR000315, I...","(B-box-type zinc finger, B-box-type zinc finge...","(PF00643, PF00643, PS50119, PS50119, SM00336, ...","(177, 236, 173, 235, 173, 235, 976, 826, 988, ...","(219, 274, 220, 282, 220, 276, 1048, 839, 1004..."
2,A0A0B4P506,"[0.01643, 0.01802, 0.03702, -0.0591, 0.0356, 0...","(M, P, R, V, V, P, D, Q, R, S, K, F, E, N, E, ...","(IPR003417, IPR003417, IPR036552, IPR036552)","(Core-binding factor, beta subunit, Core-bindi...","(PF02312, PTHR10276, G3DSA:2.40.250.10, SSF50723)","(1, 1, 1, 4)","(164, 168, 142, 140)"
3,A0A0G2K1A2,"[0.00882, 0.0835, -0.001374, -0.0003645, -0.06...","(M, K, L, F, L, A, L, A, G, L, L, A, P, L, A, ...","(IPR010255, IPR019791, IPR019791, IPR019791, I...","(Haem peroxidase superfamily, Haem peroxidase,...","(SSF48113, PF03098, PR00457, PR00457, PR00457,...","(142, 148, 172, 226, 374, 392, 417, 470, 598, ...","(718, 692, 183, 241, 392, 412, 443, 480, 618, ..."
4,A0A0G2K1V4,"[0.0659, 0.0929, -0.001803, 0.0226, 0.0383, 0....","(M, S, S, D, A, E, M, A, V, F, G, E, A, A, P, ...","(IPR000048, IPR000048, IPR001609, IPR001609, I...","(IQ motif, EF-hand binding site, IQ motif, EF-...","(PS50096, SM00015, PF00063, PR00193, PR00193, ...","(789, 788, 89, 116, 172, 228, 459, 513, 86, 80...","(818, 810, 774, 135, 197, 255, 487, 541, 786, ..."


In [14]:
X_test = combined_test[['ID','embeddings']]

In [15]:
# Step 1: Expand vectors while keeping IDs as index
df_expanded = X_test.set_index('ID')['embeddings'].apply(lambda x: x.flatten()).apply(pd.Series)

# Step 2: Convert expanded DataFrame to a PyTorch tensor
tensor = torch.tensor(df_expanded.values, dtype=torch.float32)

# Step 3: Keep IDs as a list for reference
protein_ids = df_expanded.index.tolist()

# Verifying the outputs
print("Tensor shape:", tensor.shape)  # Expected: torch.Size([1000, 1024])
print("First few IDs:", protein_ids[:5])  # Checking the first few IDs

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

X_test_tensor = tensor.to(device)


Tensor shape: torch.Size([1000, 1024])
First few IDs: ['A0A0B4JCV4', 'A0A0B4KHT0', 'A0A0B4P506', 'A0A0G2K1A2', 'A0A0G2K1V4']


In [16]:
print(X_test.shape)

(1000, 2)


In [17]:
# protein_ids = combined_test["ID"].tolist()
# protein_ids = np.array(protein_ids)  # Keep it as an array to ensure alignment

In [18]:
# # Convert to tensor
# X_test_tensor = torch.tensor(X_test, dtype=torch.float32)

# # Move to GPU if available
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# X_test_tensor = X_test_tensor.to(device)

In [19]:
'''import numpy as np

def format_predictions(protein_ids, go_terms, y_pred, threshold=0.2, max_labels_per_protein=1500, min_threshold=0.05):
    """
    Formats the model predictions according to the competition rules.

    Args:
    - protein_ids (list): List of protein target IDs.
    - go_terms (list): List of GO term labels corresponding to output nodes.
    - y_pred (numpy.ndarray): Predicted probabilities (shape: [num_proteins, num_classes]).
    - threshold (float): Minimum probability to include a GO term.
    - max_labels_per_protein (int): Maximum allowed labels per protein.
    - min_threshold (float): Lower probability bound if additional labels are needed.

    Returns:
    - list of formatted predictions (protein_id, go_term, probability)
    """
    formatted_results = []

    for i, protein_id in enumerate(protein_ids):
        # Get predictions and sort them in descending order
        pred_probs = y_pred[i]
        sorted_indices = np.argsort(-pred_probs)  # Descending order

        # Select labels above threshold
        selected_indices = [idx for idx in sorted_indices if pred_probs[idx] > threshold]

        # If too many labels, keep only top-K
        if len(selected_indices) > max_labels_per_protein:
            selected_indices = selected_indices[:max_labels_per_protein]

        # If too few labels, add lower-probability terms down to min_threshold
        elif len(selected_indices) < max_labels_per_protein:
            additional_indices = [idx for idx in sorted_indices if min_threshold <= pred_probs[idx] <= threshold]
            additional_needed = max_labels_per_protein - len(selected_indices)
            selected_indices += additional_indices[:additional_needed]

        # Format predictions
        for idx in selected_indices:
            go_term = go_terms[idx]
            probability = round(float(pred_probs[idx]), 3)  # Keep 3 decimal places
            probability = max(probability, 0.001)  # Ensure probability > 0
            formatted_results.append(f"{protein_id} {go_term} {probability}")

    return formatted_results

# Example usage:
# protein_ids = ['P9WHI7', 'P04637', ...]  # List of protein IDs
# go_terms = ['GO:0009274', 'GO:0071944', 'GO:0005575', ...]  # GO terms corresponding to model outputs
# y_pred = model1_cc(X_test_tensor).cpu().numpy()  # Get model predictions
# result_lines = format_predictions(protein_ids, go_terms, y_pred, threshold=0.2, max_labels_per_protein=1500, min_threshold=0.05)

# Print or save results
# for line in result_lines:
#     print(line)
'''


'import numpy as np\n\ndef format_predictions(protein_ids, go_terms, y_pred, threshold=0.2, max_labels_per_protein=1500, min_threshold=0.05):\n    """\n    Formats the model predictions according to the competition rules.\n\n    Args:\n    - protein_ids (list): List of protein target IDs.\n    - go_terms (list): List of GO term labels corresponding to output nodes.\n    - y_pred (numpy.ndarray): Predicted probabilities (shape: [num_proteins, num_classes]).\n    - threshold (float): Minimum probability to include a GO term.\n    - max_labels_per_protein (int): Maximum allowed labels per protein.\n    - min_threshold (float): Lower probability bound if additional labels are needed.\n\n    Returns:\n    - list of formatted predictions (protein_id, go_term, probability)\n    """\n    formatted_results = []\n\n    for i, protein_id in enumerate(protein_ids):\n        # Get predictions and sort them in descending order\n        pred_probs = y_pred[i]\n        sorted_indices = np.argsort(-pre

In [20]:
'''
import torch
import pickle
import numpy as np

# Load saved models
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_cc = torch.load("model_cc.pth", map_location=device)
model_bp = torch.load("model_bp.pth", map_location=device)
model_mf = torch.load("model_mf.pth", map_location=device)

model_cc.eval()
model_bp.eval()
model_mf.eval()

# Load MultiLabelBinarizers (MLBs) for mapping indices to GO terms
with open("mlb_cc.pkl", "rb") as f:
    mlb_cc = pickle.load(f)
with open("mlb_bp.pkl", "rb") as f:
    mlb_bp = pickle.load(f)
with open("mlb_mf.pkl", "rb") as f:
    mlb_mf = pickle.load(f)

# Function to make predictions with threshold and top-k
def predict_with_threshold(model, X_test, mlb, threshold=0.5, top_k=100):
    """ Generate GO term predictions from a model with a threshold and top-k filtering. """
    X_test = X_test.to(device)
    
    with torch.no_grad():
        probs = model(X_test).cpu().numpy()  # Get probabilities

    predictions = []
    for i, prob in enumerate(probs):
        # Select indices where probability > threshold
        high_confidence_indices = np.where(prob > threshold)[0]
        
        # Sort indices by probability in descending order
        sorted_indices = high_confidence_indices[np.argsort(prob[high_confidence_indices])[::-1]]
        
        # Keep only the top_k predictions
        top_indices = sorted_indices[:top_k]
        
        # Map indices to GO terms
        go_terms = mlb.classes_[top_indices]
        
        # Store predictions
        predictions.append(list(go_terms))
    
    return predictions, probs  # Return both terms and probabilities

# Ensure X_test_tensor is on the same device as the models
X_test_tensor = X_test_tensor.to(device)

# Get predictions from each model
predictions_cc, probs_cc = predict_with_threshold(model_cc, X_test_tensor, mlb_cc, threshold=0.5, top_k=500)
predictions_bp, probs_bp = predict_with_threshold(model_bp, X_test_tensor, mlb_bp, threshold=0.5, top_k=500)
predictions_mf, probs_mf = predict_with_threshold(model_mf, X_test_tensor, mlb_mf, threshold=0.5, top_k=500)

# Merge predictions into a dictionary
final_predictions = {}
for i, protein_id in enumerate(protein_ids):  # Ensure protein_ids matches test set order
    all_terms = set(predictions_cc[i] + predictions_bp[i] + predictions_mf[i])  # Merge GO terms
    all_probs = {term: max(probs_cc[i][mlb_cc.transform([[term]])[0][0]] if term in predictions_cc[i] else 0,
                           probs_bp[i][mlb_bp.transform([[term]])[0][0]] if term in predictions_bp[i] else 0,
                           probs_mf[i][mlb_mf.transform([[term]])[0][0]] if term in predictions_mf[i] else 0)
                 for term in all_terms}
    
    # Sort by highest probabilities and keep top 1500 GO terms per protein
    sorted_terms = sorted(all_probs.items(), key=lambda x: x[1], reverse=True)[:1500]
    
    final_predictions[protein_id] = sorted_terms

# Save the predictions in submission format
with open("submission.txt", "w") as f:
    for protein, terms in final_predictions.items():
        for term, score in terms:
            formatted_score = round(score, 3)
            formatted_score = max(formatted_score, 0.001)  # Ensure score > 0
            f.write(f"{protein} {term} {formatted_score}\n")

print("Submission file saved as submission.txt")
'''

'\nimport torch\nimport pickle\nimport numpy as np\n\n# Load saved models\ndevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")\n\nmodel_cc = torch.load("model_cc.pth", map_location=device)\nmodel_bp = torch.load("model_bp.pth", map_location=device)\nmodel_mf = torch.load("model_mf.pth", map_location=device)\n\nmodel_cc.eval()\nmodel_bp.eval()\nmodel_mf.eval()\n\n# Load MultiLabelBinarizers (MLBs) for mapping indices to GO terms\nwith open("mlb_cc.pkl", "rb") as f:\n    mlb_cc = pickle.load(f)\nwith open("mlb_bp.pkl", "rb") as f:\n    mlb_bp = pickle.load(f)\nwith open("mlb_mf.pkl", "rb") as f:\n    mlb_mf = pickle.load(f)\n\n# Function to make predictions with threshold and top-k\ndef predict_with_threshold(model, X_test, mlb, threshold=0.5, top_k=100):\n    """ Generate GO term predictions from a model with a threshold and top-k filtering. """\n    X_test = X_test.to(device)\n    \n    with torch.no_grad():\n        probs = model(X_test).cpu().numpy()  # Get probabil

In [21]:
#model for cc
import torch
import torch.nn as nn

class MultilabelNN(nn.Module):
    def __init__(self, input_size, hidden_sizes, output_size):
        super(MultilabelNN, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(input_size, hidden_sizes[0]),
            nn.ReLU(),
            nn.Dropout(0.5),  # Dropout for regularization
            nn.Linear(hidden_sizes[0], hidden_sizes[1]),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(hidden_sizes[1], output_size),
            nn.Sigmoid()
        )
    
    def forward(self, x):
        return self.model(x)

# Define device first!
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define model parameters
input_size = 1024
hidden_sizes = [4096, 4096]    
output_size = 678   

# Initialize the model
model_cc = MultilabelNN(input_size, hidden_sizes, output_size).to(device)

# Load saved weights
model_cc.load_state_dict(torch.load("model_cc.pth", map_location=device))

# Set model to evaluation mode
model_cc.eval()

print("Model loaded successfully and ready for inference!")


Model loaded successfully and ready for inference!


In [22]:
#model for mf

class MultilabelNN_mf(nn.Module):
    def __init__(self, input_size, hidden_sizes, output_size):
        super(MultilabelNN_mf, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(input_size, hidden_sizes[0]),
            nn.ReLU(),
            nn.Dropout(0.4),  # Dropout for regularization
            nn.Linear(hidden_sizes[0], hidden_sizes[1]),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_sizes[1], output_size),
            nn.Sigmoid()  # For multilabel classification
        )
    
    def forward(self, x):
        return self.model(x)

# Define model parameters
input_size = 1024  # 300 features
hidden_sizes = [4096, 4096]            # Hidden layer sizes
output_size = 839    # Number of labels

# Initialize the model
model_mf = MultilabelNN_mf(input_size, hidden_sizes, output_size).to(device)

# Load saved weights
model_mf.load_state_dict(torch.load("model_mf.pth", map_location=device))

# Set model to evaluation mode
model_mf.eval()

print("Model loaded successfully and ready for inference!")


Model loaded successfully and ready for inference!


In [23]:
#model for bp
class NN(nn.Module):
    def __init__(self, input_size, hidden_sizes, output_size):
        super(NN, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(input_size, hidden_sizes[0]),
            nn.ReLU(),
            nn.Dropout(0.5),  # Dropout for regularization
            nn.Linear(hidden_sizes[0], hidden_sizes[1]),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(hidden_sizes[1], output_size),
            nn.Sigmoid()  # For multilabel classification
        )
    
    def forward(self, x):
        return self.model(x)
    
    # Define model parameters
input_size = 1024  # 300 features
hidden_sizes = [8192, 4096]            # Hidden layer sizes
output_size = 1487    # Number of labels

# Initialize the model
model_bp = NN(input_size, hidden_sizes, output_size).to(device)

# Load saved weights
model_bp.load_state_dict(torch.load("model_bp.pth", map_location=device))

# Set model to evaluation mode
model_bp.eval()

print("Model loaded successfully and ready for inference!")

Model loaded successfully and ready for inference!


In [24]:
import pickle

# Load MultiLabelBinarizers (MLBs) for mapping indices to GO terms
with open("mlb_cc.pkl", "rb") as f:
    mlb_cc = pickle.load(f)
with open("mlb_bp.pkl", "rb") as f:
    mlb_bp = pickle.load(f)
with open("mlb_mf.pkl", "rb") as f:
    mlb_mf = pickle.load(f)

# Precompute GO term index mappings (to avoid repeated calls to transform)
cc_go_mapping = {idx: term for idx, term in enumerate(mlb_cc.classes_)}
bp_go_mapping = {idx: term for idx, term in enumerate(mlb_bp.classes_)}
mf_go_mapping = {idx: term for idx, term in enumerate(mlb_mf.classes_)}

# Function to make predictions with threshold and top-k filtering
def predict_with_threshold(model, X_test, go_mapping, threshold=0.5, top_k=10):
    """ Generate GO term predictions from a model with a threshold and top-k filtering. """
    X_test = X_test.to(device)

    with torch.no_grad():
        probs = model(X_test).cpu().numpy()  # Get probabilities

    predictions = []
    for prob in probs:
        # Select indices where probability > threshold
        high_confidence_indices = np.where(prob > threshold)[0]

        # Sort indices by probability in descending order
        sorted_indices = high_confidence_indices[np.argsort(prob[high_confidence_indices])[::-1]]

        # Keep only the top_k predictions
        top_indices = sorted_indices[:top_k]

        # Map indices to GO terms
        go_terms = [go_mapping[idx] for idx in top_indices]

        predictions.append((go_terms, prob))  # Store GO terms and their probabilities

    return predictions

# Ensure X_test_tensor is on the same device as the models
X_test_tensor = X_test_tensor.to(device)

# Get predictions from each model
predictions_cc = predict_with_threshold(model_cc, X_test_tensor, cc_go_mapping, threshold=0.25, top_k=30)
predictions_bp = predict_with_threshold(model_bp, X_test_tensor, bp_go_mapping, threshold=0.25, top_k=30)
predictions_mf = predict_with_threshold(model_mf, X_test_tensor, mf_go_mapping, threshold=0.25, top_k=30)

# Merge predictions into a dictionary
final_predictions = {}
for i, protein_id in enumerate(protein_ids):  # Ensure protein_ids matches test set order
    all_terms = {}
    
    # Create a list of tuples for each branch along with its corresponding MLB
    branch_predictions = [
        (predictions_cc[i], mlb_cc),
        (predictions_bp[i], mlb_bp),
        (predictions_mf[i], mlb_mf)
    ]
    
    for (go_terms, probs), mlb in branch_predictions:
        for term in go_terms:
            idx = np.where(mlb.classes_ == term)[0]
            term_prob = probs[idx][0] if len(idx) > 0 else 0
            
            if term in all_terms:
                all_terms[term] = max(all_terms[term], term_prob)  # Keep max probability
            else:
                all_terms[term] = term_prob
    
    # Sort by highest probabilities and keep top 1500 GO terms per protein
    sorted_terms = sorted(all_terms.items(), key=lambda x: x[1], reverse=True)[:1500]
    
    final_predictions[protein_id] = sorted_terms

# Save the predictions in submission format
with open("submission.txt", "w") as f:
    for protein, terms in final_predictions.items():
        for term, score in terms:
            # If score is a numpy array, get the scalar value using .item()
            if isinstance(score, np.ndarray):
                formatted_score = round(score.item(), 3)
            else:
                # If score is already a scalar (int or float), directly round it
                formatted_score = round(score, 3)

            formatted_score = max(formatted_score, 0.001)  # Ensure score > 0
            f.write(f"{protein} {term} {formatted_score:.3f}\n")
print("Submission file saved as submission.txt")

# Save the predictions in TSV format
with open("submission.tsv", "w") as f:
    for protein, terms in final_predictions.items():
        for term, score in terms:
            # If score is a numpy array, get the scalar value using .item()
            if isinstance(score, np.ndarray):
                formatted_score = round(score.item(), 3)
            else:
                # If score is already a scalar (int or float), directly round it
                formatted_score = round(score, 3)

            formatted_score = max(formatted_score, 0.001)  # Ensure score > 0
            # Write the protein, GO term, and probability as a tab-separated line
            f.write(f"{protein}\t{term}\t{formatted_score:.3f}\n")

print("Submission file saved as submission.tsv")

Submission file saved as submission.txt
Submission file saved as submission.tsv
