In [17]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import KFold
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader

# Import from your Model folder
from Model.model import thermalMLP, ProteinTmDataset


In [18]:
# Single place to change where the input files live.
# TODO(review): this absolute, user-specific path will not exist on other
# machines; point DATA_DIR at a relative or configurable location instead.
DATA_DIR = Path("/Users/aikium_intern/Desktop/aikium_thermal_stability/Data")

# np.load on an .npz returns an NpzFile that keeps the file handle open; the
# `with` block closes it once the arrays below have been read into memory.
# allow_pickle=True is kept from the original (presumably `ids` is stored as an
# object array — TODO confirm). Only use it on files you produced yourself:
# unpickling untrusted data can execute arbitrary code.
with np.load(DATA_DIR / "esm2_embeddings_with_ids.npz", allow_pickle=True) as embedding_np:
    embedding = embedding_np["embeddings"]   # shape: (N, hidden_dim)
    ids = embedding_np["ids"]

dataset_df = pd.read_csv(DATA_DIR / "Stratified_data.csv")

emb_df = pd.DataFrame({
    "ID": ids,
    "embedding": list(embedding),   # one row-vector per protein
})

emb_df["ID"] = emb_df["ID"].astype(int)
dataset_df["ID"] = dataset_df["ID"].astype(int)

# Inner join on ID. validate="one_to_one" fails loudly on duplicate IDs, which
# would otherwise multiply rows and leak copies of a protein across the
# train/validation split later on.
merged = emb_df.merge(dataset_df, on="ID", how="inner", validate="one_to_one")

# The inner join silently drops rows without a partner; surface how many.
n_dropped_emb = len(emb_df) - len(merged)
n_dropped_csv = len(dataset_df) - len(merged)
if n_dropped_emb or n_dropped_csv:
    print(f"Dropped {n_dropped_emb} embeddings and {n_dropped_csv} CSV rows "
          f"without a matching ID")

merged




Unnamed: 0,ID,embedding,Tm,Sequence,From,AA_length,MW
0,117012,"[3.4952238, 2.9369724, -3.406106, -1.0034976, ...",45.556042,MKWAYKEENNFEKRRAEGDKIRRKYPDRIPVIVEKAPKSKLHDLDK...,meltome,123,14.763731
1,313786,"[2.2217994, -2.0552912, -0.14634663, 1.3754547...",43.863329,MVKATNVDLSLEDIISKTRKTTGSIQKKSFGGARRGNTRPTGLPRR...,meltome,240,25.695370
2,169171,"[3.0030677, 4.205013, 4.074065, 0.7811, -6.895...",43.574470,MPRANEIKKGMVLNYNGKLLLVKDIDIQSPTARGAATLYKMRFSDV...,meltome,190,21.532381
3,227821,"[4.7244706, -0.123155594, 1.057647, 4.5875287,...",40.496359,MTKSELIERMLTKQPQLSAKDVELAVKTILDHMSQSLSTGERIEIR...,meltome,96,10.757199
4,97927,"[2.16275, -1.2902148, 2.1692014, 0.009807652, ...",37.073883,MRQVVLDTETTGIGAEKGHRIIEIGCVELIDRKLTGRHYHQYVNPQ...,meltome,234,26.063256
...,...,...,...,...,...,...,...
23415,200713,"[3.211512, 0.9198779, -0.08337058, -1.2356853,...",86.289547,MTVRQVLVHKGGGVHAIHPEATVLDALRKLAEHDIGALLVMEGERL...,meltome,143,15.916415
23416,267005,"[-1.4136505, -2.4800174, 0.06798242, 2.689519,...",88.057949,MTGLELLAVALGMRHGVDPDHLAAVDGLSRVRPSPLNGVLFALGHG...,meltome,220,22.874996
23417,269836,"[2.5266905, -5.975683, -3.1002605, 1.2522675, ...",88.583454,MASLSFMIKEYNDYYIIDFERPVRKFSSAPFNGGVGTSLRYINRTV...,meltome,215,24.262483
23418,298355,"[0.30273947, 2.1589444, -0.4573475, -0.4192088...",86.755190,MRFKAELMNAPEMRRALYRIAHEIVEANKGTEGLALVGIHTRGIPL...,meltome,181,20.466363


In [19]:
def _batch_loss(model, criterion, xb, yb, device):
    """Move one batch to ``device``, run ``model`` and return ``(loss, batch_size)``.

    Predictions are reshaped to the target's shape so ``nn.MSELoss`` cannot
    silently broadcast e.g. ``(B, 1)`` predictions against ``(B,)`` targets into
    a ``(B, B)`` error matrix; a genuine element-count mismatch raises instead.
    """
    xb, yb = xb.to(device), yb.to(device)
    preds = model(xb).reshape(yb.shape)
    return criterion(preds, yb), yb.shape[0]


def train_and_eval(params, input_dim, train_loader, val_loader,
                   device=None, seed=None, verbose=True):
    """Train a ``thermalMLP`` with Adam + MSE and record per-epoch losses.

    Args:
        params: dict with keys ``"hidden_sizes"``, ``"activation"`` and
            ``"dropout"`` (forwarded to ``thermalMLP``), ``"lr"`` (Adam learning
            rate) and ``"epochs"`` (number of passes over ``train_loader``).
        input_dim: feature dimension passed to ``thermalMLP``.
        train_loader: DataLoader yielding ``(xb, yb)`` training batches.
        val_loader: DataLoader yielding ``(xb, yb)`` validation batches.
        device: torch device (or device string). ``None`` picks CUDA when
            available, else CPU.
        seed: if not ``None``, ``torch.manual_seed(seed)`` is called before the
            model is built, fixing weight init and (unseeded) shuffle order.
        verbose: print one progress line per epoch when True.

    Returns:
        ``(train_losses, val_losses)``: lists with one entry per epoch. Each entry
        is the sample-weighted mean MSE over the whole loader. Training entries
        are accumulated while the weights are still being updated.

    Raises:
        ValueError: if either loader yields no batches.
    """
    if len(train_loader) == 0 or len(val_loader) == 0:
        raise ValueError("train_loader and val_loader must each yield at least one batch")

    if seed is not None:
        torch.manual_seed(seed)
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = thermalMLP(
        input_dim=input_dim,
        hidden_sizes=params["hidden_sizes"],
        activation=params["activation"],
        dropout=params["dropout"]
    ).to(device)

    criterion = nn.MSELoss()
    # Build the optimizer after .to(device) so it references the moved parameters.
    optimizer = optim.Adam(model.parameters(), lr=params["lr"])

    n_epochs = params["epochs"]
    train_losses = []
    val_losses = []

    for epoch in range(n_epochs):
        # ---- Training ----
        model.train()
        train_sum, train_n = 0.0, 0
        for xb, yb in train_loader:
            optimizer.zero_grad()
            loss, n = _batch_loss(model, criterion, xb, yb, device)
            loss.backward()
            optimizer.step()
            # MSELoss averages within a batch; re-weight by batch size so a
            # short final batch does not skew the epoch mean.
            train_sum += loss.item() * n
            train_n += n
        train_loss = train_sum / train_n
        train_losses.append(train_loss)

        # ---- Validation ----
        model.eval()
        val_sum, val_n = 0.0, 0
        with torch.no_grad():
            for xb, yb in val_loader:
                loss, n = _batch_loss(model, criterion, xb, yb, device)
                val_sum += loss.item() * n
                val_n += n
        val_loss = val_sum / val_n
        val_losses.append(val_loss)

        if verbose:
            print(f"Epoch {epoch+1}/{n_epochs} | "
                  f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")

    return train_losses, val_losses

In [None]:
# One seed drives the split, weight init and batch shuffling so that configs
# are compared on equal footing and the whole sweep is reproducible.
SEED = 42

embeddings = np.stack(merged["embedding"].values)
Tm = merged["Tm"].values
input_dim = embeddings.shape[1]

# NOTE(review): the validation split is used both to pick the best config and
# to report its loss, so the reported value is optimistic; hold out a separate
# test split for a final estimate.
X_train, X_val, y_train, y_val = train_test_split(
    embeddings, Tm, test_size=0.2, random_state=SEED
)

train_ds = ProteinTmDataset(X_train, y_train)
val_ds = ProteinTmDataset(X_val, y_val)

train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=32)


# Define search space
hyperparam_sweep = [
    {"hidden_sizes": [64], "activation": nn.ReLU, "dropout": 0.0, "lr": 1e-3, "epochs": 10},
    {"hidden_sizes": [128], "activation": nn.ReLU, "dropout": 0.0, "lr": 1e-3, "epochs": 10},
    {"hidden_sizes": [256, 128], "activation": nn.ReLU, "dropout": 0.2, "lr": 1e-3, "epochs": 10},
    {"hidden_sizes": [512, 256], "activation": nn.ReLU, "dropout": 0.2, "lr": 1e-3, "epochs": 10},
    {"hidden_sizes": [512, 256, 128], "activation": nn.ReLU, "dropout": 0.3, "lr": 5e-4, "epochs": 15},
    {"hidden_sizes": [1024, 512, 256], "activation": nn.ReLU, "dropout": 0.3, "lr": 5e-4, "epochs": 15},
]

results = []  # (params, train_loss_per_epoch, val_loss_per_epoch)
for i, params in enumerate(hyperparam_sweep):
    # Re-seed before every config: without this, each run starts from a
    # different RNG state (init + shuffle) and part of the val-loss gap
    # between configs is just luck.
    torch.manual_seed(SEED)
    train_loss, val_loss = train_and_eval(params, input_dim, train_loader, val_loader)
    results.append((params, train_loss, val_loss))

    epochs_axis = range(1, len(train_loss) + 1)  # 1-based, matching the printed log
    fig, ax = plt.subplots()
    ax.plot(epochs_axis, train_loss, label="Train Loss")
    ax.plot(epochs_axis, val_loss, label="Val Loss")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("MSE Loss")
    ax.set_title(f"Training vs Validation Loss\nConfig {i+1}: {params['hidden_sizes']}")
    ax.legend()
    plt.show()
    plt.close(fig)  # avoid accumulating open figures across the loop

# Show best: ranked by final-epoch validation loss (the model each run ends with).
best = min(results, key=lambda r: r[2][-1])
print("\nBest config:", best[0], "| Final Val Loss:", best[2][-1])

KeyboardInterrupt: 