<a href="https://colab.research.google.com/github/inshra12/ESM-MLP-baseline-protein-function/blob/main/notebooks/sequence-based-function-prediction.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

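# Model 1: ESM embeddings + MLP baseline for GO-term prediction

Predicts the multi-label set of GO terms for each protein from its precomputed ESM sequence embedding, using a one-hidden-layer MLP with sigmoid outputs.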

## Importing project modules

In [35]:
from data_loader import load_data,encode_labels

In [36]:
!pip install fair-esm



In [37]:
from esm_embedder import save_embeddings

## Loading Data

In [38]:
# Raw protein table: one row per protein with ID, Sequence, MF/CC/BP GO annotations and Labels.
RAW_DATA_PATH = "/content/raw_sequence_data.csv"
df = load_data(RAW_DATA_PATH)

First protein entry:
 ID                                                     P14416
Sequence    MDPLNLSWYDDDLERQNWSRPFNGSDGKADRPHYNYYATLLTLLIA...
MF          GO:0035240;GO:0001591;GO:0004930; GO:0001965;G...
CC          GO:0001669; GO:0030424; GO:0043679; GO:0060170...
BP          GO:0046717; GO:0021984; GO:0007195; GO:0007628...
Labels      [ GO:0043679,  GO:0001975,  GO:0032228,  GO:00...
Name: 0, dtype: object


In [39]:
# Compute the ESM embeddings only when the cached CSV is missing. A fresh runtime can then
# Restart & Run All without anyone manually uncommenting this cell. The next cell reads
# the same absolute path.
import os

if not os.path.exists("/content/cls_embeddings.csv"):
    save_embeddings("/content/raw_sequence_data.csv", "/content/cls_embeddings.csv")

In [40]:
import pandas as pd

# Load ESM embeddings (previously extracted to CSV).
esm_df = pd.read_csv("/content/cls_embeddings.csv")

# The merge with the label table joins on 'ID', so fail early with a clear message
# instead of a KeyError deep inside pd.merge.
assert "ID" in esm_df.columns, "cls_embeddings.csv must have an 'ID' column"

## Encoding multi-hot labels

In [41]:
# Multi-hot encode each protein's GO-term list.
# `mlb.classes_` provides the GO-term names that become the label column names.
go_label_lists = df["Labels"]
encoded_labels, mlb = encode_labels(go_label_lists)

Number of unique GO terms: 528


In [42]:
# Step 2: Wrap the multi-hot matrix in a DataFrame (one column per GO term).
# Attach the protein IDs positionally (raw array, so there is no index alignment),
# giving the table a key to merge on.
encoded_labels_df = pd.DataFrame(encoded_labels, columns=mlb.classes_).assign(
    ID=df["ID"].values
)

## Merging sequence embeddings and labels into one table

In [43]:
# Step 3: Merge ESM embeddings (esm_df) with labels on the protein ID.
# validate="one_to_one" raises if an ID is duplicated on either side. Duplicates would
# otherwise multiply rows in the merged table without any warning.
final_df = pd.merge(
    esm_df, encoded_labels_df, on="ID", how="inner", validate="one_to_one"
)

# An inner join drops labelled proteins that have no embedding row; make that visible.
n_dropped = len(encoded_labels_df) - len(final_df)
if n_dropped:
    print(f"Warning: {n_dropped} labelled proteins have no ESM embedding and were dropped")

In [44]:
# Put "ID" first; every other column keeps its existing relative order.
new_order = ["ID"] + [col for col in final_df.columns if col != "ID"]
final_df = final_df[new_order]


In [45]:
# Step 4: Write the merged embeddings + multi-hot labels table to CSV (no index column).
output_path = "esm_with_multihot_labels.csv"
final_df.to_csv(output_path, index=False)
print(f"Saved: {output_path}")

Saved: esm_with_multihot_labels.csv


## Splitting columns into X (ESM embedding features) and Y (GO-term labels)

In [46]:
# Get all column names (kept for any later inspection).
all_columns = final_df.columns

# The label columns are exactly the GO terms produced by the label encoder, so select
# them by name. A hard-coded positional slice (all_columns[1281:]) silently mixes
# features and labels if the embedding width or the column order ever changes.
go_columns = list(mlb.classes_)

# Everything that is neither the ID nor a GO label is an ESM embedding feature.
X = final_df.drop(columns=["ID"] + go_columns)  # Features (ESM embedding columns)
Y = final_df[go_columns]                        # Labels (multi-hot GO terms)

# Every feature column must be numeric to be converted into a float tensor later.
non_numeric = [c for c in X.columns if not pd.api.types.is_numeric_dtype(X[c])]
assert not non_numeric, f"non-numeric feature columns: {non_numeric[:5]}"


In [47]:
# Save features and labels as separate CSVs (no index column).
for frame, path in ((X, "X_esm.csv"), (Y, "Y_go.csv")):
    frame.to_csv(path, index=False)

In [48]:
from sklearn.model_selection import train_test_split

# Hold out 20% of proteins with a fixed seed so the split is reproducible.
# NOTE(review): this is a plain random split; it is not stratified by GO term.
TEST_FRACTION = 0.2
X_train, X_test, Y_train, Y_test = train_test_split(
    X, Y, test_size=TEST_FRACTION, random_state=42
)

for caption, subset in (("✅ Training samples:", X_train), ("✅ Test samples:", X_test)):
    print(caption, len(subset))





✅ Training samples: 16
✅ Test samples: 4


In [49]:
import torch
from torch.utils.data import TensorDataset, DataLoader

SEED = 42        # same seed as the train/test split
BATCH_SIZE = 32

# Seed torch's global RNG so anything stochastic downstream (e.g. weight init) is
# reproducible on Restart & Run All.
torch.manual_seed(SEED)

# Convert DataFrames to float32 tensors.
X_train_tensor = torch.tensor(X_train.values, dtype=torch.float32)
Y_train_tensor = torch.tensor(Y_train.values, dtype=torch.float32)
X_test_tensor = torch.tensor(X_test.values, dtype=torch.float32)
Y_test_tensor = torch.tensor(Y_test.values, dtype=torch.float32)

# Create datasets.
# NOTE(review): the held-out test split is also used as the validation set.
train_dataset = TensorDataset(X_train_tensor, Y_train_tensor)
val_dataset = TensorDataset(X_test_tensor, Y_test_tensor)

# Create dataloaders. A dedicated seeded generator makes the training shuffle order
# reproducible, independent of other consumers of the global RNG.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)


In [50]:
from MLP_Model import build_mlp_model

In [51]:
# Define sizes. Derive the input/output widths from the data instead of hard-coding
# them, so the model always matches the feature matrix and label set.
input_size = X_train.shape[1]   # ESM embedding width (1280 in the recorded run)
hidden_size = 512               # hidden layer width; tunable
output_size = Y_train.shape[1]  # number of GO-term labels (528 in the recorded run)

In [52]:
# Layer widths in the positional order build_mlp_model takes: (input, hidden, output).
mlp_dims = (input_size, hidden_size, output_size)
model = build_mlp_model(*mlp_dims)

# Print the layer stack as a quick architecture sanity check.
print(model)

Sequential(
  (0): Linear(in_features=1280, out_features=512, bias=True)
  (1): ReLU()
  (2): Dropout(p=0.3, inplace=False)
  (3): Linear(in_features=512, out_features=528, bias=True)
  (4): Sigmoid()
)


In [53]:
from train_MLP import train_model,validate_model

In [54]:
# Number of GO-term label columns in the training targets.
print(Y_train.shape[1])

# Fail fast if the model's layer sizes disagree with the data. Otherwise the mismatch
# only surfaces as a shape error somewhere inside training.
assert X_train.shape[1] == input_size, (
    f"feature width {X_train.shape[1]} != model input_size {input_size}"
)
assert Y_train.shape[1] == output_size, (
    f"label count {Y_train.shape[1]} != model output_size {output_size}"
)

528


In [None]:
# Training hyperparameters. val_loader is passed for validation; how it is used is
# defined in train_MLP.train_model.
EPOCHS = 20
LEARNING_RATE = 0.001
train_model(model, train_loader, val_loader, epochs=EPOCHS, lr=LEARNING_RATE)