<a href="https://colab.research.google.com/github/inshra12/ESM-MLP-baseline-protein-function/blob/main/notebooks/sequence-based-function-prediction.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Model 1: ESM embeddings + MLP baseline for multi-label GO-term prediction

## Imports and setup

In [36]:
from data_loader import load_data,encode_labels

In [37]:
!pip install fair-esm



In [38]:
from esm_embedder import save_embeddings

## Loading Data

In [39]:
# Load the raw protein table (columns shown below: ID, Sequence, BP, CC, MF, Labels).
df = load_data("/content/Protein_data(265).csv")

First protein entry:
 ID                                                     O00222
Sequence    MVCEGKRSASCPCFFLLTAKFYWILTMMQRTHSQEYAHSIRVDGDI...
BP          GO:0007196; GO:0007193; GO:0007216; GO:0051966...
CC                                                 GO:0005886
MF                         GO:0004930; GO:0008066; GO:0001642
Labels      [ GO:0007601, GO:0004930,  GO:0051966,  GO:000...
Name: 0, dtype: object


In [40]:
import os

# Extracting embeddings is slow (about 1.5 min in the run below), so reuse the
# cached CSV on re-runs. Set RECOMPUTE_EMBEDDINGS = True after changing the
# input data or the ESM model, otherwise stale embeddings will be used.
RECOMPUTE_EMBEDDINGS = False

if RECOMPUTE_EMBEDDINGS or not os.path.exists("cls_embeddings.csv"):
    save_embeddings("/content/Protein_data(265).csv", "cls_embeddings.csv")
else:
    print("Using cached embeddings: cls_embeddings.csv")

Extracting embeddings for 265 proteins...


100%|██████████| 34/34 [01:38<00:00,  2.89s/it]


Saved CLS embeddings to: cls_embeddings.csv


In [41]:
import pandas as pd

# Load the ESM CLS embeddings written by save_embeddings() above. Read them from the
# same relative path they were written to (not a Colab-only absolute path), so the
# notebook also runs when the working directory is not /content.
esm_df = pd.read_csv("cls_embeddings.csv")
assert "ID" in esm_df.columns, "cls_embeddings.csv must contain an 'ID' column to merge on"

## Encoding multi-hot labels

In [42]:
encoded_labels, mlb = encode_labels(df["Labels"])

# Sanity check: the raw "Labels" strings contain extra spaces (see the first entry
# printed above), so report any class name that is empty or padded with whitespace.
# Such classes could split a single GO term across several label columns.
malformed_classes = [
    c for c in mlb.classes_ if not str(c).strip() or str(c) != str(c).strip()
]
if malformed_classes:
    print(f"WARNING: {len(malformed_classes)} malformed GO classes, e.g. {malformed_classes[:5]!r}")

Number of unique GO terms: 1995


In [44]:
# Step 2: Convert encoded labels into a DataFrame:
# one 0/1 column per class in mlb.classes_, plus "ID" appended last for merging.
encoded_labels_df = pd.DataFrame(encoded_labels, columns=mlb.classes_).assign(
    ID=df["ID"].to_numpy()
)
encoded_labels_df


Unnamed: 0,Unnamed: 1,GO:0000122,GO:0000139,GO:0000165,GO:0000226,GO:0000278,GO:0000578,GO:0001501,GO:0001503,GO:0001516,...,GO:0098981,GO:1901363,GO:1902093,GO:1903143,GO:1903413,GO:1904646,GO:1990409,GO:1990430,GO:1990763,ID
0,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,O00222
1,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,O14842
2,0,0,1,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,O15303
3,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,O43603
4,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,O43613
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
260,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,Q9Y271
261,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,Q9Y2T5
262,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,Q9Y2T6
263,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,Q9Y5N1


## Merge sequence embeddings and labels into one file

In [45]:
# Step 3: Merge ESM embeddings (esm_df) with labels.
# validate="one_to_one" raises a MergeError if an ID is duplicated on either side
# (which would multiply rows); the assert catches IDs present on only one side,
# which the default inner join would otherwise drop silently.
final_df = pd.merge(esm_df, encoded_labels_df, on="ID", validate="one_to_one")
assert len(final_df) == len(esm_df) == len(encoded_labels_df), (
    f"merge kept {len(final_df)} rows from {len(esm_df)} embedding rows and "
    f"{len(encoded_labels_df)} label rows; some IDs do not match"
)



In [46]:
# Put the "ID" column first; all other columns keep their current relative order.
final_df = final_df[["ID"] + [name for name in final_df.columns if name != "ID"]]


In [47]:
# Step 4: Save final_df to CSV
final_df.to_csv("esm_with_multihot_labels.csv", index=False)
print("Saved: esm_with_multihot_labels.csv")

Saved: esm_with_multihot_labels.csv


## Splitting columns into X (ESM embedding features) and Y (GO-term labels)

In [None]:
# Select the label columns by name instead of by position. A positional slice such as
# all_columns[1281:] silently misassigns columns if the embedding CSV gains or loses a
# column (e.g. a saved index), moving embedding dims into Y or labels into X.
go_columns = list(mlb.classes_)

X = final_df.drop(columns=["ID"] + go_columns)  # Features (ESM embedding columns)
Y = final_df[go_columns]                        # Labels (multi-hot GO terms)

assert Y.shape[1] == len(mlb.classes_), "every encoded GO class must appear as a Y column"
print("X shape:", X.shape, "| Y shape:", Y.shape)




In [48]:
# Write the feature and label matrices to CSV, without the DataFrame index.
for frame, csv_path in ((X, "X_esm.csv"), (Y, "Y_go.csv")):
    frame.to_csv(csv_path, index=False)

In [49]:
from sklearn.model_selection import train_test_split

# Hold out 20% of the proteins; random_state=42 makes the split reproducible.
X_train, X_test, Y_train, Y_test = train_test_split(
    X,
    Y,
    test_size=0.2,
    random_state=42,
)

print("✅ Training samples:", X_train.shape[0])
print("✅ Test samples:", X_test.shape[0])





✅ Training samples: 212
✅ Test samples: 53


In [50]:
import torch
from torch.utils.data import TensorDataset, DataLoader

# Convert the split DataFrames to float32 tensors (float targets to match the
# model's sigmoid outputs).
X_train_tensor = torch.tensor(X_train.values, dtype=torch.float32)
Y_train_tensor = torch.tensor(Y_train.values, dtype=torch.float32)
X_test_tensor = torch.tensor(X_test.values, dtype=torch.float32)
Y_test_tensor = torch.tensor(Y_test.values, dtype=torch.float32)

# Create datasets.
# NOTE: the "validation" set is the held-out test split from train_test_split,
# so there is no separate untouched test set.
train_dataset = TensorDataset(X_train_tensor, Y_train_tensor)
val_dataset = TensorDataset(X_test_tensor, Y_test_tensor)

# Create dataloaders. A seeded generator makes the training shuffle order
# reproducible across runs, independent of the global RNG state.
shuffle_generator = torch.Generator().manual_seed(42)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, generator=shuffle_generator)
val_loader = DataLoader(val_dataset, batch_size=32)


In [51]:
from MLP_Model import build_mlp_model

In [54]:
# Define sizes. Input and output widths are derived from the data, so they always
# match X/Y (a hardcoded label count breaks as soon as the label set changes).
input_size = X_train.shape[1]    # ESM embedding size (1280 in the run shown below)
hidden_size = 512                # You can adjust this
output_size = Y_train.shape[1]   # Number of GO term label columns

In [78]:
import torch

# Build the model. Seed torch first so the weight initialisation, and therefore
# the training results below, are reproducible.
torch.manual_seed(42)
model = build_mlp_model(input_size, hidden_size, output_size)

# Check summary
model

Sequential(
  (0): Linear(in_features=1280, out_features=512, bias=True)
  (1): ReLU()
  (2): Dropout(p=0.3, inplace=False)
  (3): Linear(in_features=512, out_features=653, bias=True)
  (4): Sigmoid()
)


In [56]:
from train_MLP import train_model,validate_model

In [57]:
# Sanity check: the model needs one output unit per label column, otherwise
# predictions and targets have mismatched widths during training.
assert Y_train.shape[1] == output_size, (
    f"Y_train has {Y_train.shape[1]} label columns but output_size is {output_size}"
)
Y_train.shape[1]

653


In [79]:
# Train for 20 epochs at learning rate 1e-3; train_model prints the loss and the
# validation macro F1 after each epoch.
# NOTE(review): in the run below macro F1 stays at 0.0000 in every epoch and the
# metric emits _warn_prf (ill-defined score) warnings, so the model apparently
# predicts no positive labels on the validation set.
# TODO: inspect the thresholding/metric in train_MLP and the label imbalance
# before trusting this baseline.
train_model(
    model,
    train_loader,
    val_loader,
    epochs=20,
    lr=1e-3,
)

Epoch 1/20, Loss: 0.5575
Validation Macro F1 Score: 0.0000
Epoch 2/20, Loss: 0.1381
Validation Macro F1 Score: 0.0000
Epoch 3/20, Loss: 0.0323
Validation Macro F1 Score: 0.0000
Epoch 4/20, Loss: 0.0364
Validation Macro F1 Score: 0.0000
Epoch 5/20, Loss: 0.0398
Validation Macro F1 Score: 0.0000
Epoch 6/20, Loss: 0.0398
Validation Macro F1 Score: 0.0000
Epoch 7/20, Loss: 0.0380
Validation Macro F1 Score: 0.0000
Epoch 8/20, Loss: 0.0353
Validation Macro F1 Score: 0.0000
Epoch 9/20, Loss: 0.0325
Validation Macro F1 Score: 0.0000


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch 10/20, Loss: 0.0307
Validation Macro F1 Score: 0.0000
Epoch 11/20, Loss: 0.0294
Validation Macro F1 Score: 0.0000
Epoch 12/20, Loss: 0.0287
Validation Macro F1 Score: 0.0000
Epoch 13/20, Loss: 0.0284
Validation Macro F1 Score: 0.0000
Epoch 14/20, Loss: 0.0284
Validation Macro F1 Score: 0.0000
Epoch 15/20, Loss: 0.0284
Validation Macro F1 Score: 0.0000
Epoch 16/20, Loss: 0.0283
Validation Macro F1 Score: 0.0000
Epoch 17/20, Loss: 0.0283
Validation Macro F1 Score: 0.0000
Epoch 18/20, Loss: 0.0283
Validation Macro F1 Score: 0.0000


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch 19/20, Loss: 0.0282
Validation Macro F1 Score: 0.0000
Epoch 20/20, Loss: 0.0284
Validation Macro F1 Score: 0.0000


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
