<a href="https://colab.research.google.com/github/inshra12/ESM-MLP-baseline-protein-function/blob/main/notebooks/sequence-based-function-prediction.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Model 1

## Importing files

In [1]:
from data_loader import load_data,encode_labels

In [2]:
!pip install fair-esm

Collecting fair-esm
  Downloading fair_esm-2.0.0-py3-none-any.whl.metadata (37 kB)
Downloading fair_esm-2.0.0-py3-none-any.whl (93 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m93.1/93.1 kB[0m [31m4.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: fair-esm
Successfully installed fair-esm-2.0.0


In [3]:
from esm_embedder import save_embeddings

## Loading Data

In [4]:
df=load_data("/content/Protein_data(265).csv")

First protein entry:
 ID                                                     O00222
Sequence    MVCEGKRSASCPCFFLLTAKFYWILTMMQRTHSQEYAHSIRVDGDI...
BP          GO:0007196; GO:0007193; GO:0007216; GO:0051966...
CC                                                 GO:0005886
MF                         GO:0004930; GO:0008066; GO:0001642
Labels      [ GO:0001642,  GO:0007601, GO:0007196, GO:0004...
Name: 0, dtype: object


In [5]:
#save_embeddings("/content/Protein_data(265).csv", "cls_embeddings.csv")

In [6]:
import pandas as pd
# Load ESM embeddings (previously extracted)
esm_df = pd.read_csv("/content/cls_embeddings.csv")  # must have column 'ID'

## Encoding multi-hot labels

In [7]:
encoded_labels, mlb = encode_labels(df["Labels"])

Number of unique GO terms: 1995


In [8]:
# Step 2: Convert encoded labels into a DataFrame
encoded_labels_df = pd.DataFrame(encoded_labels, columns=mlb.classes_)
encoded_labels_df["ID"] = df["ID"]  # Add the ID column
# Move "ID" column to the front
cols = list(encoded_labels_df.columns)
cols.remove("ID")
new_order = ["ID"] + cols
final_df = encoded_labels_df[new_order]
final_df = final_df.loc[:, final_df.columns != '']
final_df

Unnamed: 0,ID,GO:0000122,GO:0000139,GO:0000165,GO:0000226,GO:0000278,GO:0000578,GO:0001501,GO:0001503,GO:0001516,...,GO:0098978,GO:0098981,GO:1901363,GO:1902093,GO:1903143,GO:1903413,GO:1904646,GO:1990409,GO:1990430,GO:1990763
0,O00222,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
1,O14842,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
2,O15303,0,1,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
3,O43603,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
4,O43613,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
260,Q9Y271,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
261,Q9Y2T5,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
262,Q9Y2T6,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
263,Q9Y5N1,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


## Merge sequence embedding and labels in one file

In [9]:
# Drop the ID column (non-binary) and sum over the rows
go_term_counts = final_df.drop(columns=["ID"]).sum(axis=0).sort_values(ascending=False)
print(go_term_counts)  # Show top 10 most frequent GO terms


 GO:0005886    206
 GO:0007186    110
 GO:0007200     74
 GO:0007204     72
 GO:0007187     56
              ... 
GO:0045453       1
GO:0045178       1
GO:0044298       1
 GO:0006468      1
GO:0097746       1
Length: 1994, dtype: int64


In [10]:
# Filter GO terms with at least 3 positives (i.e., sum >= 3)
valid_go_terms = go_term_counts[go_term_counts >= 40].index

# Keep only those columns + the ID
filtered_df = final_df[["ID"] + valid_go_terms.tolist()]


In [11]:
filtered_df

Unnamed: 0,ID,GO:0005886,GO:0007186,GO:0007200,GO:0007204,GO:0007187,GO:0016020,GO:0004930,GO:0030425,GO:0007268,GO:0007189,GO:0045202,GO:0005886.1,GO:0007218,GO:0006954
0,O00222,0,0,0,0,0,0,0,0,0,0,0,1,0,0
1,O14842,0,0,1,1,0,0,1,0,0,0,0,1,0,0
2,O15303,1,0,0,0,0,0,1,0,0,0,1,0,0,0
3,O43603,1,0,1,1,0,1,0,0,0,1,0,0,1,0
4,O43613,0,0,0,0,0,0,0,0,1,0,1,1,1,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
260,Q9Y271,1,0,0,1,0,0,0,0,0,0,0,0,1,0
261,Q9Y2T5,0,1,0,0,0,0,1,0,0,0,0,1,0,0
262,Q9Y2T6,0,1,1,0,0,0,1,0,0,0,0,1,0,0
263,Q9Y5N1,1,0,0,0,1,0,0,0,1,0,1,0,0,0


In [12]:
# Step 3: Merge ESM embeddings (esm_df) with labels
final_df = pd.merge(esm_df, filtered_df, on="ID")



In [13]:
final_df

Unnamed: 0,ID,0,1,2,3,4,5,6,7,8,...,GO:0007187,GO:0016020,GO:0004930,GO:0030425,GO:0007268,GO:0007189,GO:0045202,GO:0005886,GO:0007218,GO:0006954
0,O00222,0.053351,-0.016727,0.095342,0.019120,-0.040266,0.029943,0.034857,-0.065425,-0.119622,...,0,0,0,0,0,0,0,1,0,0
1,O14842,0.033164,0.003283,0.070440,0.017247,-0.000506,-0.001080,0.072019,-0.019934,-0.002518,...,0,0,1,0,0,0,0,1,0,0
2,O15303,0.046058,-0.003895,0.107735,-0.043677,-0.036064,0.051198,0.008938,-0.091502,-0.082031,...,0,0,1,0,0,0,1,0,0,0
3,O43603,0.026092,0.021502,0.108253,-0.051013,-0.084989,-0.003950,0.036050,-0.097860,-0.047290,...,0,1,0,0,0,1,0,0,1,0
4,O43613,0.011840,0.010907,0.097771,-0.025915,-0.035944,0.000257,0.034860,-0.074876,-0.053891,...,0,0,0,0,1,0,1,1,1,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
260,Q9Y271,0.028590,0.042119,0.078044,0.041247,-0.061766,-0.046561,0.032011,-0.080206,-0.087077,...,0,0,0,0,0,0,0,0,1,0
261,Q9Y2T5,0.043165,0.018314,0.109320,0.002341,-0.042691,-0.040927,0.049685,-0.068692,-0.097311,...,0,0,1,0,0,0,0,1,0,0
262,Q9Y2T6,0.044411,0.027693,0.063553,0.027069,-0.050945,-0.046612,0.042627,-0.076333,-0.078617,...,0,0,1,0,0,0,0,1,0,0
263,Q9Y5N1,0.003747,-0.051823,0.133492,0.011327,-0.057135,0.037724,0.079348,-0.107627,-0.037733,...,1,0,0,0,1,0,1,0,0,0


In [14]:
# Step 4: Save final_df to CSV
final_df.to_csv("esm_with_multihot_labels.csv", index=False)
#print("Saved: esm_with_multihot_labels.csv")

## Splitting Column in X(Sequence) and Y(labels)

In [15]:

all_columns = final_df.columns

# GO columns start from index 1281 onwards
#go_columns = all_columns[1281:]

# Now define X and Y
#X = final_df.drop(columns=["ID"] + list(go_columns))  # Features (1280 ESM cols)
#Y = final_df[list(go_columns)]

                    # Labels (GO terms)
all_columns

Index(['ID', '0', '1', '2', '3', '4', '5', '6', '7', '8',
       ...
       ' GO:0007187', ' GO:0016020', ' GO:0004930', ' GO:0030425',
       ' GO:0007268', 'GO:0007189', ' GO:0045202', 'GO:0005886', ' GO:0007218',
       ' GO:0006954'],
      dtype='object', length=1295)

In [16]:
# Recalculate GO columns
go_columns = [col for col in final_df.columns if col.startswith("GO:")]
columns = all_columns[1281:]
X = final_df.drop(columns=["ID"] + list(columns))
# Extract labels
Y = final_df[columns]

In [17]:
Y

Unnamed: 0,GO:0005886,GO:0007186,GO:0007200,GO:0007204,GO:0007187,GO:0016020,GO:0004930,GO:0030425,GO:0007268,GO:0007189,GO:0045202,GO:0005886.1,GO:0007218,GO:0006954
0,0,0,0,0,0,0,0,0,0,0,0,1,0,0
1,0,0,1,1,0,0,1,0,0,0,0,1,0,0
2,1,0,0,0,0,0,1,0,0,0,1,0,0,0
3,1,0,1,1,0,1,0,0,0,1,0,0,1,0
4,0,0,0,0,0,0,0,0,1,0,1,1,1,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
260,1,0,0,1,0,0,0,0,0,0,0,0,1,0
261,0,1,0,0,0,0,1,0,0,0,0,1,0,0
262,0,1,1,0,0,0,1,0,0,0,0,1,0,0
263,1,0,0,0,1,0,0,0,1,0,1,0,0,0


In [18]:
#X.to_csv("X_esm.csv", index=False)
#Y.to_csv("Y_go.csv", index=False)

In [19]:
from sklearn.model_selection import train_test_split

# Split the data
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.2, random_state=42)

print("✅ Training samples:", len(X_train))
print("✅ Test samples:", len(X_test))





✅ Training samples: 212
✅ Test samples: 53


In [20]:
import torch
from torch.utils.data import TensorDataset, DataLoader

# Convert DataFrames to tensors
X_train_tensor = torch.tensor(X_train.values, dtype=torch.float32)
Y_train_tensor = torch.tensor(Y_train.values, dtype=torch.float32)
X_test_tensor = torch.tensor(X_test.values, dtype=torch.float32)
Y_test_tensor = torch.tensor(Y_test.values, dtype=torch.float32)

# Create datasets
train_dataset = TensorDataset(X_train_tensor, Y_train_tensor)
val_dataset = TensorDataset(X_test_tensor, Y_test_tensor)

# Create dataloaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32)


In [21]:
from MLP_Model import build_mlp_model
from train_MLP import train_model


In [22]:
# Define sizes
input_size = 1280      # ESM embedding size
hidden_size = 512      # You can adjust this
output_size = Y_train.shape[1]


  # Number of GO term labels


In [23]:
output_size

14

In [24]:
# Build the model
model = build_mlp_model(input_size, hidden_size, output_size)

# Check summary
print(model)

Sequential(
  (0): Linear(in_features=1280, out_features=512, bias=True)
  (1): ReLU()
  (2): Dropout(p=0.3, inplace=False)
  (3): Linear(in_features=512, out_features=14, bias=True)
  (4): Sigmoid()
)


In [25]:
print(Y_train.shape[1])

14


In [26]:
train_model(model, train_loader, val_loader, epochs=20, lr=0.001, threshold=0.3)


Epoch 1/20, Loss: 0.5816, Validation Macro F1 Score: 0.1055
Epoch 2/20, Loss: 0.5189, Validation Macro F1 Score: 0.1055
Epoch 3/20, Loss: 0.5030, Validation Macro F1 Score: 0.1632
Epoch 4/20, Loss: 0.4976, Validation Macro F1 Score: 0.1317
Epoch 5/20, Loss: 0.4973, Validation Macro F1 Score: 0.1055
Epoch 6/20, Loss: 0.4942, Validation Macro F1 Score: 0.1356
Epoch 7/20, Loss: 0.4940, Validation Macro F1 Score: 0.1681
Epoch 8/20, Loss: 0.4908, Validation Macro F1 Score: 0.1627
Epoch 9/20, Loss: 0.4819, Validation Macro F1 Score: 0.1246
Epoch 10/20, Loss: 0.4816, Validation Macro F1 Score: 0.1341
Epoch 11/20, Loss: 0.4760, Validation Macro F1 Score: 0.1750
Epoch 12/20, Loss: 0.4711, Validation Macro F1 Score: 0.1892
Epoch 13/20, Loss: 0.4626, Validation Macro F1 Score: 0.1654
Epoch 14/20, Loss: 0.4627, Validation Macro F1 Score: 0.2087
Epoch 15/20, Loss: 0.4587, Validation Macro F1 Score: 0.3244
Epoch 16/20, Loss: 0.4568, Validation Macro F1 Score: 0.3037
Epoch 17/20, Loss: 0.4539, Valida