<a href="https://colab.research.google.com/github/inshra12/ESM-MLP-baseline-protein-function/blob/main/notebooks/sequence-based-function-prediction.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Model 2-MLP

## Importing files

In [136]:
from data_loader import load_data,encode_labels

In [137]:
!pip install fair-esm



In [138]:
from esm_embedder import save_embeddings

## Loading Data

In [139]:
df=load_data("/content/Protein_data(265).csv")

First protein entry:
 ID                                                     O00222
Sequence    MVCEGKRSASCPCFFLLTAKFYWILTMMQRTHSQEYAHSIRVDGDI...
BP          GO:0007196; GO:0007193; GO:0007216; GO:0051966...
CC                                                 GO:0005886
MF                         GO:0004930; GO:0008066; GO:0001642
Labels      [GO:0005886,  GO:0007601,  GO:0008066,  GO:000...
Name: 0, dtype: object


In [140]:
#save_embeddings("/content/Protein_data(265).csv", "cls_embeddings.csv")

In [141]:
import pandas as pd
# Load ESM embeddings (previously extracted)
esm_df = pd.read_csv("/content/cls_embeddings.csv")  # must have column 'ID'

## Encoding multi-hot labels

In [142]:
encoded_labels, mlb = encode_labels(df["Labels"])

Number of unique GO terms: 1995


In [143]:
# Step 2: Convert encoded labels into a DataFrame
encoded_labels_df = pd.DataFrame(encoded_labels, columns=mlb.classes_)
encoded_labels_df["ID"] = df["ID"]  # Add the ID column
# Move "ID" column to the front
cols = list(encoded_labels_df.columns)
cols.remove("ID")
new_order = ["ID"] + cols
final_df = encoded_labels_df[new_order]
final_df = final_df.loc[:, final_df.columns != '']
final_df

Unnamed: 0,ID,GO:0000122,GO:0000139,GO:0000165,GO:0000226,GO:0000278,GO:0000578,GO:0001501,GO:0001503,GO:0001516,...,GO:0098978,GO:0098981,GO:1901363,GO:1902093,GO:1903143,GO:1903413,GO:1904646,GO:1990409,GO:1990430,GO:1990763
0,O00222,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
1,O14842,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
2,O15303,0,1,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
3,O43603,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
4,O43613,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
260,Q9Y271,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
261,Q9Y2T5,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
262,Q9Y2T6,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
263,Q9Y5N1,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


## Merge sequence embedding and labels in one file

In [144]:
# Drop the ID column (non-binary) and sum over the rows
go_term_counts = final_df.drop(columns=["ID"]).sum(axis=0).sort_values(ascending=False)
print(go_term_counts)  # Show top 10 most frequent GO terms


 GO:0005886    206
 GO:0007186    110
 GO:0007200     74
 GO:0007204     72
 GO:0007187     56
              ... 
GO:0045453       1
GO:0045178       1
GO:0044298       1
 GO:0006468      1
GO:0097746       1
Length: 1994, dtype: int64


In [145]:
# Filter GO terms with at least 3 positives (i.e., sum >= 3)
valid_go_terms = go_term_counts[go_term_counts >= 40].index

# Keep only those columns + the ID
filtered_df = final_df[["ID"] + valid_go_terms.tolist()]


In [146]:
filtered_df

Unnamed: 0,ID,GO:0005886,GO:0007186,GO:0007200,GO:0007204,GO:0007187,GO:0016020,GO:0004930,GO:0030425,GO:0007268,GO:0007189,GO:0045202,GO:0005886.1,GO:0007218,GO:0006954
0,O00222,0,0,0,0,0,0,0,0,0,0,0,1,0,0
1,O14842,0,0,1,1,0,0,1,0,0,0,0,1,0,0
2,O15303,1,0,0,0,0,0,1,0,0,0,1,0,0,0
3,O43603,1,0,1,1,0,1,0,0,0,1,0,0,1,0
4,O43613,0,0,0,0,0,0,0,0,1,0,1,1,1,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
260,Q9Y271,1,0,0,1,0,0,0,0,0,0,0,0,1,0
261,Q9Y2T5,0,1,0,0,0,0,1,0,0,0,0,1,0,0
262,Q9Y2T6,0,1,1,0,0,0,1,0,0,0,0,1,0,0
263,Q9Y5N1,1,0,0,0,1,0,0,0,1,0,1,0,0,0


In [147]:
# Step 3: Merge ESM embeddings (esm_df) with labels
final_df = pd.merge(esm_df, filtered_df, on="ID")



In [148]:
final_df

Unnamed: 0,ID,0,1,2,3,4,5,6,7,8,...,GO:0007187,GO:0016020,GO:0004930,GO:0030425,GO:0007268,GO:0007189,GO:0045202,GO:0005886,GO:0007218,GO:0006954
0,O00222,0.053351,-0.016727,0.095342,0.019120,-0.040266,0.029943,0.034857,-0.065425,-0.119622,...,0,0,0,0,0,0,0,1,0,0
1,O14842,0.033164,0.003283,0.070440,0.017247,-0.000506,-0.001080,0.072019,-0.019934,-0.002518,...,0,0,1,0,0,0,0,1,0,0
2,O15303,0.046058,-0.003895,0.107735,-0.043677,-0.036064,0.051198,0.008938,-0.091502,-0.082031,...,0,0,1,0,0,0,1,0,0,0
3,O43603,0.026092,0.021502,0.108253,-0.051013,-0.084989,-0.003950,0.036050,-0.097860,-0.047290,...,0,1,0,0,0,1,0,0,1,0
4,O43613,0.011840,0.010907,0.097771,-0.025915,-0.035944,0.000257,0.034860,-0.074876,-0.053891,...,0,0,0,0,1,0,1,1,1,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
260,Q9Y271,0.028590,0.042119,0.078044,0.041247,-0.061766,-0.046561,0.032011,-0.080206,-0.087077,...,0,0,0,0,0,0,0,0,1,0
261,Q9Y2T5,0.043165,0.018314,0.109320,0.002341,-0.042691,-0.040927,0.049685,-0.068692,-0.097311,...,0,0,1,0,0,0,0,1,0,0
262,Q9Y2T6,0.044411,0.027693,0.063553,0.027069,-0.050945,-0.046612,0.042627,-0.076333,-0.078617,...,0,0,1,0,0,0,0,1,0,0
263,Q9Y5N1,0.003747,-0.051823,0.133492,0.011327,-0.057135,0.037724,0.079348,-0.107627,-0.037733,...,1,0,0,0,1,0,1,0,0,0


In [149]:
# Step 4: Save final_df to CSV
final_df.to_csv("esm_with_multihot_labels.csv", index=False)
#print("Saved: esm_with_multihot_labels.csv")

## Splitting Column in X(Sequence) and Y(labels)

In [150]:

all_columns = final_df.columns

# GO columns start from index 1281 onwards
#go_columns = all_columns[1281:]

# Now define X and Y
#X = final_df.drop(columns=["ID"] + list(go_columns))  # Features (1280 ESM cols)
#Y = final_df[list(go_columns)]

                    # Labels (GO terms)
all_columns

Index(['ID', '0', '1', '2', '3', '4', '5', '6', '7', '8',
       ...
       ' GO:0007187', ' GO:0016020', ' GO:0004930', ' GO:0030425',
       ' GO:0007268', 'GO:0007189', ' GO:0045202', 'GO:0005886', ' GO:0007218',
       ' GO:0006954'],
      dtype='object', length=1295)

In [151]:
# Recalculate GO columns
go_columns = [col for col in final_df.columns if col.startswith("GO:")]
columns = all_columns[1281:]
X = final_df.drop(columns=["ID"] + list(columns))
# Extract labels
Y = final_df[columns]

In [152]:
X

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,1270,1271,1272,1273,1274,1275,1276,1277,1278,1279
0,0.053351,-0.016727,0.095342,0.019120,-0.040266,0.029943,0.034857,-0.065425,-0.119622,0.084034,...,-0.038261,-0.152742,-0.250213,-0.019936,-0.046279,-0.066123,-0.053912,-0.319778,0.154516,0.027740
1,0.033164,0.003283,0.070440,0.017247,-0.000506,-0.001080,0.072019,-0.019934,-0.002518,0.146560,...,-0.050757,-0.043355,-0.268326,0.007200,0.033871,-0.038196,-0.078597,-0.243790,0.139195,0.009131
2,0.046058,-0.003895,0.107735,-0.043677,-0.036064,0.051198,0.008938,-0.091502,-0.082031,0.126929,...,-0.083423,-0.134998,-0.274291,0.017472,-0.019652,-0.099833,-0.059169,-0.263769,0.147156,0.046976
3,0.026092,0.021502,0.108253,-0.051013,-0.084989,-0.003950,0.036050,-0.097860,-0.047290,0.132933,...,-0.095362,-0.115929,-0.210134,0.040711,0.048382,-0.081693,-0.079056,-0.241364,0.149850,-0.005024
4,0.011840,0.010907,0.097771,-0.025915,-0.035944,0.000257,0.034860,-0.074876,-0.053891,0.112991,...,-0.064174,-0.119595,-0.222651,0.037250,0.032173,-0.075750,-0.087266,-0.271265,0.127608,0.014359
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
260,0.028590,0.042119,0.078044,0.041247,-0.061766,-0.046561,0.032011,-0.080206,-0.087077,0.100788,...,-0.056868,-0.127857,-0.184754,-0.037019,0.031118,-0.029139,-0.077402,-0.247564,0.120906,-0.037632
261,0.043165,0.018314,0.109320,0.002341,-0.042691,-0.040927,0.049685,-0.068692,-0.097311,0.060280,...,0.009813,-0.106530,-0.155991,0.018616,0.009465,-0.007575,-0.045652,-0.269200,0.099570,-0.051747
262,0.044411,0.027693,0.063553,0.027069,-0.050945,-0.046612,0.042627,-0.076333,-0.078617,0.131539,...,-0.071286,-0.081688,-0.213624,-0.051272,0.036905,-0.028816,-0.089090,-0.246666,0.101319,-0.016196
263,0.003747,-0.051823,0.133492,0.011327,-0.057135,0.037724,0.079348,-0.107627,-0.037733,0.147743,...,-0.092034,-0.139131,-0.279851,0.037903,-0.047071,-0.094803,-0.119312,-0.316944,0.140985,0.022448


In [153]:
#X.to_csv("X_esm.csv", index=False)
#Y.to_csv("Y_go.csv", index=False)

In [154]:
from sklearn.model_selection import train_test_split

# Split the data
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.2, random_state=42)

print("✅ Training samples:", len(X_train))
print("✅ Test samples:", len(X_test))





✅ Training samples: 212
✅ Test samples: 53


In [155]:
import torch
from torch.utils.data import TensorDataset, DataLoader

# Convert DataFrames to tensors
X_train_tensor = torch.tensor(X_train.values, dtype=torch.float32)
Y_train_tensor = torch.tensor(Y_train.values, dtype=torch.float32)
X_test_tensor = torch.tensor(X_test.values, dtype=torch.float32)
Y_test_tensor = torch.tensor(Y_test.values, dtype=torch.float32)

# Create datasets
train_dataset = TensorDataset(X_train_tensor, Y_train_tensor)
val_dataset = TensorDataset(X_test_tensor, Y_test_tensor)

# Create dataloaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32)


In [156]:
from MLP_Model import build_mlp_model
from train_MLP import train_model


In [157]:
# Define sizes
input_size = 1280      # ESM embedding size
hidden_size = 1024      # You can adjust this
output_size = Y_train.shape[1]


  # Number of GO term labels


In [158]:
output_size

14

In [159]:
# Build the model
model = build_mlp_model(input_size, hidden_size, output_size)

# Check summary
print(model)

Sequential(
  (0): Linear(in_features=1280, out_features=1024, bias=True)
  (1): ReLU()
  (2): Dropout(p=0.3, inplace=False)
  (3): Linear(in_features=1024, out_features=14, bias=True)
  (4): Sigmoid()
)


In [160]:
print(Y_train.shape[1])

14


In [162]:
train_model(model, train_loader, val_loader, epochs=90, lr=0.001, threshold=0.3)


Epoch 1/90, Loss: 0.3025, Validation Macro F1: 0.5578, Validation Micro F1: 0.6265
Epoch 2/90, Loss: 0.2949, Validation Macro F1: 0.5924, Validation Micro F1: 0.6635
Epoch 3/90, Loss: 0.2892, Validation Macro F1: 0.6202, Validation Micro F1: 0.6747
Epoch 4/90, Loss: 0.2883, Validation Macro F1: 0.5895, Validation Micro F1: 0.6602
Epoch 5/90, Loss: 0.2831, Validation Macro F1: 0.6251, Validation Micro F1: 0.6794
Epoch 6/90, Loss: 0.2881, Validation Macro F1: 0.5872, Validation Micro F1: 0.6616
Epoch 7/90, Loss: 0.2819, Validation Macro F1: 0.6112, Validation Micro F1: 0.6606
Epoch 8/90, Loss: 0.2804, Validation Macro F1: 0.6081, Validation Micro F1: 0.6700
Epoch 9/90, Loss: 0.2793, Validation Macro F1: 0.6081, Validation Micro F1: 0.6651
Epoch 10/90, Loss: 0.2796, Validation Macro F1: 0.6167, Validation Micro F1: 0.6714
Epoch 11/90, Loss: 0.2752, Validation Macro F1: 0.6058, Validation Micro F1: 0.6732
Epoch 12/90, Loss: 0.2703, Validation Macro F1: 0.6231, Validation Micro F1: 0.6781
E