# Model 2: MNN (Multi-Input Neural Network)

## Importing libraries and project modules

In [61]:
# 1. Importing files and custom model
import pandas as pd
import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import TensorDataset, DataLoader

In [62]:
# Import the MultiInputNet class from your new file
from MNN_Model import MultiInputNet

In [63]:
# Your custom data and training functions
from data_loader import load_data, encode_labels
from esm_embedder import save_embeddings
from train_MLP import train_model

In [64]:
# 2. Loading data
# All inputs live in one directory (the Colab /content directory); change
# DATA_DIR once instead of editing every path below.
DATA_DIR = "/content"

df = load_data(f"{DATA_DIR}/Protein_data(265).csv")
esm_df = pd.read_csv(f"{DATA_DIR}/cls_embeddings.csv")
static_df = pd.read_csv(f"{DATA_DIR}/static_features.csv")

# Both feature tables are joined to the labels on "ID" further down; fail here
# with a clear message rather than with a KeyError inside the merge.
assert "ID" in esm_df.columns and "ID" in static_df.columns, "feature CSVs must have an 'ID' column"

First protein entry:
 ID                                                     O00222
Sequence    MVCEGKRSASCPCFFLLTAKFYWILTMMQRTHSQEYAHSIRVDGDI...
BP          GO:0007196; GO:0007193; GO:0007216; GO:0051966...
CC                                                 GO:0005886
MF                         GO:0004930; GO:0008066; GO:0001642
Labels      [GO:0004930,  GO:0008066,  GO:0007216, GO:0005...
Name: 0, dtype: object


In [65]:
# 3. Encoding labels and filtering
# Keep only GO terms annotated on at least this many proteins; rarer terms
# give the classifier too few positive examples.
# NOTE(review): the model printed further down has out_features=2, so very few
# terms survive this cut-off — confirm the threshold is intended.
MIN_GO_TERM_COUNT = 40

encoded_labels, mlb = encode_labels(df["Labels"])
# One multi-hot column per GO term, indexed by protein ID (index name "ID").
encoded_labels_df = pd.DataFrame(encoded_labels, columns=mlb.classes_).set_index(df["ID"])
go_term_counts = encoded_labels_df.sum(axis=0).sort_values(ascending=False)
valid_go_terms = go_term_counts[go_term_counts >= MIN_GO_TERM_COUNT].index
# A model with zero outputs cannot train; fail loudly instead.
assert len(valid_go_terms) > 0, f"no GO term has >= {MIN_GO_TERM_COUNT} annotations"
# reset_index() turns the protein ID back into an "ID" column for the merges below.
filtered_labels_df = encoded_labels_df[valid_go_terms].reset_index()

Number of unique GO terms: 1995


In [66]:
# 4. Prepare ESM and static features
# Label columns are the kept GO terms. Strip before testing the prefix: the raw
# Labels printout shows entries with leading whitespace (e.g. " GO:0008066"),
# and a plain startswith("GO:") would silently drop such terms from training.
go_columns = [col for col in filtered_labels_df.columns if col.strip().startswith("GO:")]
go_labels = filtered_labels_df[go_columns]

In [67]:
# Align feature rows with the label rows in go_labels.
# The label IDs are the LEFT side of a left join, so the result follows
# filtered_labels_df's row order. (An inner join with the feature CSV on the
# left follows the CSV's order instead, which silently pairs features with
# another protein's labels unless both files happen to be sorted identically.)
# validate="many_to_one" rejects duplicate IDs in a feature table, which would
# otherwise add extra rows.
label_ids = filtered_labels_df[["ID"]]

esm_merged = label_ids.merge(esm_df, on="ID", how="left", validate="many_to_one")
static_merged = label_ids.merge(static_df, on="ID", how="left", validate="many_to_one")

# A labelled protein with no feature row shows up as NaNs after the left join.
assert not esm_merged.isna().any().any(), "missing/NaN ESM features for some labelled proteins"
assert not static_merged.isna().any().any(), "missing/NaN static features for some labelled proteins"

esm_features = esm_merged.drop(columns=["ID"])
static_features = static_merged.drop(columns=["ID"])


In [68]:
# 5. Data splitting and DataLoader creation
SEED = 42
TEST_SIZE = 0.2
BATCH_SIZE = 32

# Split all three aligned tables in ONE call so each protein's ESM features,
# static features and labels land in the same partition. Two separate calls
# only stay aligned by coincidence (same random_state and same row count).
(
    X_esm_train, X_esm_test,
    X_static_train, X_static_test,
    Y_train, Y_test,
) = train_test_split(
    esm_features, static_features, go_labels, test_size=TEST_SIZE, random_state=SEED
)


def _to_float_tensor(frame):
    """Convert a pandas DataFrame to a float32 torch tensor."""
    return torch.tensor(frame.values, dtype=torch.float32)


X_esm_train_tensor = _to_float_tensor(X_esm_train)
X_static_train_tensor = _to_float_tensor(X_static_train)
Y_train_tensor = _to_float_tensor(Y_train)

X_esm_test_tensor = _to_float_tensor(X_esm_test)
X_static_test_tensor = _to_float_tensor(X_static_test)
Y_test_tensor = _to_float_tensor(Y_test)

# Each batch is an (esm, static, labels) triple.
train_dataset = TensorDataset(X_esm_train_tensor, X_static_train_tensor, Y_train_tensor)
val_dataset = TensorDataset(X_esm_test_tensor, X_static_test_tensor, Y_test_tensor)

# A dedicated seeded generator makes the shuffle order reproducible across
# Restart & Run All, independent of other users of torch's global RNG.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)


In [69]:
# 6. Model instantiation and training
# Layer widths: each input width is the column count of its feature table, and
# the output width is one unit per label column.
hidden_size = 1024
esm_input_size, static_input_size, output_size = (
    frame.shape[1] for frame in (X_esm_train, X_static_train, Y_train)
)

In [70]:

# Instantiate the model from the imported class.
# Seed torch's global RNG first so weight initialisation is reproducible on
# Restart & Run All (42 matches the random_state used for the train/test split).
torch.manual_seed(42)
model = MultiInputNet(esm_input_size, static_input_size, hidden_size, output_size)
model

MultiInputNet(
  (esm_branch): Sequential(
    (0): Linear(in_features=1280, out_features=1024, bias=True)
    (1): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Dropout(p=0.3, inplace=False)
  )
  (static_branch): Sequential(
    (0): Linear(in_features=18, out_features=512, bias=True)
    (1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Dropout(p=0.3, inplace=False)
  )
  (classifier): Sequential(
    (0): Linear(in_features=1536, out_features=2, bias=True)
  )
)


In [71]:
from model_multi_input import train_model_multi_input

In [72]:
# Run the training.
# train_model_multi_input (imported from model_multi_input above) receives the
# DataLoaders built earlier, whose batches are (esm, static, labels) triples.
EPOCHS = 90
LEARNING_RATE = 1e-3
# Presumably the cut-off the trainer applies to predicted scores when computing
# the validation F1 — confirm in model_multi_input.
THRESHOLD = 0.3

train_model_multi_input(
    model,
    train_loader,
    val_loader,
    epochs=EPOCHS,
    lr=LEARNING_RATE,
    threshold=THRESHOLD,
)

Calculating class weights...
Epoch 1/90, Loss: 1.1824, Validation Macro F1: 0.3631, Validation Micro F1: 0.4138
Epoch 2/90, Loss: 0.6084, Validation Macro F1: 0.3500, Validation Micro F1: 0.3404
Epoch 3/90, Loss: 0.5350, Validation Macro F1: 0.3577, Validation Micro F1: 0.3469
Epoch 4/90, Loss: 0.4807, Validation Macro F1: 0.3926, Validation Micro F1: 0.3736
Epoch 5/90, Loss: 0.4073, Validation Macro F1: 0.3959, Validation Micro F1: 0.3765
Epoch 6/90, Loss: 0.3599, Validation Macro F1: 0.5111, Validation Micro F1: 0.5079
Epoch 7/90, Loss: 0.3590, Validation Macro F1: 0.5595, Validation Micro F1: 0.5769
Epoch 8/90, Loss: 0.3060, Validation Macro F1: 0.5357, Validation Micro F1: 0.5357
Epoch 9/90, Loss: 0.2727, Validation Macro F1: 0.6179, Validation Micro F1: 0.6522
Epoch 10/90, Loss: 0.2494, Validation Macro F1: 0.6453, Validation Micro F1: 0.6818
Epoch 11/90, Loss: 0.2511, Validation Macro F1: 0.6753, Validation Micro F1: 0.7143
Epoch 12/90, Loss: 0.2110, Validation Macro F1: 0.5793, 

## Loading Data

## Encoding multi-hot labels

## Merge sequence embeddings and labels into one file