<a href="https://colab.research.google.com/github/daehyun827/2025-winter-URP/blob/GNN/ems1b%2Bgnn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

2023312038 겨울 URP 단백질 4종류(Soluble, Lipid Anchor, Peripheral, Transmembrane)분류 모델

기본 구조

PLM(ESM1b_t33_650M_UR50S) -> Amino acid embedding & Attention map(Contact map) -> GNN(구조예측) -> 구조를 통한 단백질의 특징 추출 -> 단백질 분류

이런 구조를 정한 이유
ProteinMPNN에서 입력과 출력을 반대로 만든 모델이라고도 볼 수 있음. Prot-Trans 논문에서는 PLM의 마지막 hidden layer의 embedding만으로 학습을 시켜도 구조 정보와 아미노산의 특징을 제대로 반영할 수 있음을 보여줬음. 그렇기에 서열을 받았을 때 embedding에 담긴 구조 정보를 이용해 이를 다시 구조로 변환하고, 이로부터 단백질 외부 특징을 뽑아낼 수 있다면 위 단백질들을 분류할 수 있지 않을까라는 생각에서 이런 아키텍처로 구성하게 되었음.
여러 종류의 PLM 중에서 ESM을 선택한 이유는, ESM은 아미노산 간의 contact map을 제공하지만 Prot-Trans는 이를 제공하지 않아 hidden layer에서 attention map을 추출하여 가공해야 하는 번거로움이 있기 때문임. 그래서 일단 ESM1b_t33_650M_UR50S를 사용하게 되었음.

* 참고내용1 :
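GNN의 경우 서열을 그래프 형태(노드와 엣지(간선))로 봐야 하기에 이를 정의해야 하는 문제가 있었음. 엣지의 경우 아미노산 간의 위치 관계를 나타내는 contact map을 사용하였고, 노드의 경우 input 서열에 해당하는 아미노산 embedding을 사용하였음. (단, 현재 코드에서는 ESM의 contact prediction head 대신, 마지막 layer의 attention map을 head 평균한 뒤 threshold(0.6)를 넘는 아미노산 쌍을 엣지로 사용하고 있음.)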

* 참고내용2 :
Prot-Trans 논문에서는 2차구조를 예측할 때 2 layer CNN을 달아서 구조를 예측했다고 나오는데, 시간이 남는다면 CNN으로 구조를 예측하여 단백질의 특징을 뽑고 분류하였을때 결과 차이가 날 지 보는 것도 좋을듯함.

* 참고내용3:
앞서 Prot-Trans의 경우 attention map을 가공해야 돼서 ESM을 사용했다고 했는데, Prot-Trans 모델(그중 큰 모델인 Prot-T5) 중에는 ESM보다 더 좋은 embedding을 제공하는 경우도 있으므로, 시간이 된다면 Prot-Trans 버전으로도 만들어 보면 좋을 것 같음.

* 참고내용4:
코랩에서 DeepLoc 데이터를 on the fly로 가공해서 학습하려다 보니 GPU 메모리 부족 문제가 발생해서, 데이터를 100개의 단백질마다 끊어서 저장한 뒤 이 데이터로 Training과 Validation을 진행했음. 혹시 코딩을 잘한다면 이 문제를 해결해 주길 바람 ㅎㅎ



In [None]:
!pip -q install transformers torch_geometric tqdm scikit-learn

import torch
# Report whether PyTorch can see a CUDA GPU and, if it can, which one is device 0.
cuda_ok = torch.cuda.is_available()
print("CUDA:", cuda_ok)
if cuda_ok:
    print("GPU:", torch.cuda.get_device_name(0))


CUDA: True
GPU: Tesla T4


In [None]:
import os, math, glob
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader as GraphDataLoader
from torch_geometric.nn import EdgeConv, global_add_pool
from transformers import EsmModel, EsmTokenizer
from tqdm import tqdm

# ===== 1) GNN 모델 =====
class ProteinGNN(nn.Module):
    """Graph classifier over per-residue protein-language-model embeddings.

    Pipeline: linear projection of the node features -> ``num_layers``
    EdgeConv layers (max aggregation) with residual connections -> a sigmoid
    gate score per node -> gated sum pooling per graph -> MLP classifier.

    Args:
        input_dim: Size of each input node feature vector.
        hidden_dim: Width of the hidden node representation.
        num_layers: Number of EdgeConv layers.
        num_classes: Number of output classes.
    """

    def __init__(self, input_dim=1280, hidden_dim=256, num_layers=4, num_classes=4):
        super().__init__()
        self.input_proj = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU()
        )
        # EdgeConv feeds [x_i, x_j - x_i] (2 * hidden_dim) into each MLP.
        self.edge_mlps = nn.ModuleList([
            nn.Sequential(
                nn.Linear(2 * hidden_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim)
            ) for _ in range(num_layers)
        ])
        # Build the EdgeConv wrappers ONCE. EdgeConv.__init__ calls
        # reset_parameters() on the wrapped MLP, so constructing it inside
        # forward() (as before) re-initialised the edge MLP weights on every
        # call and they never learned. The MLPs are shared with `edge_mlps`,
        # so they are still registered and trained exactly once. Note that
        # checkpoints saved by the old code lack the `convs.*` state-dict keys;
        # their edge-MLP weights were effectively untrained anyway.
        self.convs = nn.ModuleList([EdgeConv(mlp, aggr='max') for mlp in self.edge_mlps])
        self.residuals = nn.ModuleList([
            nn.Linear(hidden_dim, hidden_dim) if i == 0 else nn.Identity()
            for i in range(num_layers)
        ])
        self.gate_scorer = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 4),
            nn.ReLU(),
            nn.Linear(hidden_dim // 4, 1),
            nn.Sigmoid()
        )
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim // 2, num_classes)
        )

    def forward(self, x_esm, edge_index, batch):
        """Classify every graph in a (PyG-style) mini-batch.

        Args:
            x_esm: Node feature matrix, one row of ``input_dim`` values per node.
                Any floating dtype is accepted (e.g. fp16 features loaded from
                disk); it is cast to the projection weights' dtype.
            edge_index: ``(2, num_edges)`` connectivity for EdgeConv.
            batch: Graph index of each node, used for pooling.

        Returns:
            ``(logits, scores)``: one row of ``num_classes`` logits per graph,
            and the per-node gate score in ``[0, 1]``.
        """
        # Stored embeddings may be fp16 while the weights are fp32.
        x = self.input_proj(x_esm.to(self.input_proj[0].weight.dtype))
        for conv, res in zip(self.convs, self.residuals):
            x = F.relu(conv(x, edge_index) + res(x))
        scores = self.gate_scorer(x).squeeze(-1)
        x_weighted = x * scores.unsqueeze(-1)
        graph_vec = global_add_pool(x_weighted, batch)
        logits = self.classifier(graph_vec)
        return logits, scores


# ===== 2) 데이터 로드/전처리 =====
def load_and_preprocess_data(csv_path, min_len=100, max_len=800, train_max_partition=2):
    """Load the membrane-protein CSV and split it into train/validation frames.

    Each row's class is the label column with the largest value, mapped as
    Peripheral=0, Transmembrane=1, LipidAnchor=2, Soluble=3 (the label columns
    are presumably one-hot — TODO confirm; on ties ``idxmax`` picks the first).

    Args:
        csv_path: CSV with ``Sequence``, ``Partition`` and the four label columns.
        min_len: Minimum sequence length kept (inclusive).
        max_len: Maximum sequence length kept (inclusive).
        train_max_partition: Rows with ``Partition <= train_max_partition`` form
            the training split; rows with a larger ``Partition`` form validation.

    Returns:
        ``(train_df, val_df)``, each re-indexed from 0 and carrying an integer
        ``label`` column.
    """
    df = pd.read_csv(csv_path)
    # Drop a leftover unnamed index column if the CSV has one.
    df = df.drop(columns=['Unnamed: 0'], errors='ignore')

    label_cols = ['Peripheral', 'Transmembrane', 'LipidAnchor', 'Soluble']
    label_to_id = {name: idx for idx, name in enumerate(label_cols)}
    df['label'] = df[label_cols].idxmax(axis=1).map(label_to_id).astype(int)

    # One length mask shared by both splits (inclusive on both ends).
    in_range = df['Sequence'].str.len().between(min_len, max_len)
    train_df = df[(df['Partition'] <= train_max_partition) & in_range]
    val_df = df[(df['Partition'] > train_max_partition) & in_range]

    print(f"Train: {len(train_df)}, Val: {len(val_df)}")
    return train_df.reset_index(drop=True), val_df.reset_index(drop=True)


# ===== 3) 평가 함수 =====
def evaluate(model, loader, device):
    """Return the graph-level classification accuracy of ``model`` on ``loader``.

    Each batch is moved to ``device`` and passed as
    ``model(x, edge_index, batch)``. The argmax of the returned logits is
    compared with ``batch.y``. The model's train/eval mode is restored on exit.

    Raises:
        ValueError: If ``loader`` yields no graphs, because accuracy is
            undefined in that case.
    """
    was_training = model.training
    model.eval()
    correct, total = 0, 0
    try:
        with torch.no_grad():
            for batch in loader:
                batch = batch.to(device)
                # Node features may be stored as fp16 on disk; the model's
                # weights are fp32, so cast to avoid a dtype-mismatch error.
                logits, _ = model(batch.x.float(), batch.edge_index, batch.batch)
                pred = logits.argmax(dim=1)
                correct += (pred == batch.y).sum().item()
                total += batch.num_graphs
    finally:
        model.train(was_training)
    if total == 0:
        raise ValueError("evaluate() received a loader with no graphs")
    return correct / total


# ===== 4) (핵심) ESM→그래프 저장 =====
def precompute_and_save_onefile(df, split_name, esm_model, tokenizer, device,
                                out_dir="/content/graphs", prob_threshold=0.6,
                                save_x_dtype=torch.float16, autocast_dtype=None):
    """Embed every sequence in ``df`` with ESM and save all graphs to one file.

    For each row a ``Data`` graph is built with these fields:

    * ``x``: the final-layer ESM hidden states, with the first and last token
      positions (CLS/EOS) dropped. It is stored on the CPU as ``save_x_dtype``.
    * ``edge_index``: residue pairs ``(i, j)`` with ``i != j`` whose last-layer
      attention, averaged over heads, exceeds ``prob_threshold``. The
      thresholding runs on the model's device, and only the resulting edge
      list is copied to the CPU.
    * ``y``: the row's integer ``label``.

    The list of graphs is written to ``{out_dir}/{split_name}_graphs.pt``.
    All graphs stay in host memory until that single save, so peak RAM grows
    with the size of the split.

    NOTE(review): edges come from raw last-layer attention, not from ESM's
    contact-prediction head.

    Args:
        df: Frame with ``Sequence`` and ``label`` columns.
        split_name: Output file prefix, e.g. ``"train"``.
        esm_model: Hugging Face ESM model that can return attentions.
        tokenizer: The matching ESM tokenizer.
        device: ``"cuda"``/``"cpu"`` string or a ``torch.device``.
        out_dir: Output directory; created if missing.
        prob_threshold: Attention cut-off for creating an edge.
        save_x_dtype: dtype of the stored node features. fp16 halves the file size.
        autocast_dtype: Mixed-precision dtype used on CUDA. ``None`` picks
            bfloat16 on GPUs with compute capability >= 8 (Ampere or newer)
            and float16 otherwise (e.g. T4, which lacks native bf16). Pass
            ``torch.bfloat16`` to reproduce the previous fixed behaviour.
            Ignored on CPU.
    """
    from contextlib import nullcontext

    os.makedirs(out_dir, exist_ok=True)
    # Attentions and return_dict are requested per call below, so the
    # caller's model config is no longer mutated.
    esm_model.eval()

    # str() also accepts torch.device objects.
    on_cuda = str(device).startswith("cuda")
    if on_cuda and autocast_dtype is None:
        major, _ = torch.cuda.get_device_capability(device)
        autocast_dtype = torch.bfloat16 if major >= 8 else torch.float16

    graphs = []
    n = len(df)
    rows = zip(df["Sequence"], df["label"])
    for i, (seq, label) in enumerate(tqdm(rows, total=n, desc=f"Precompute {split_name}")):
        inputs = tokenizer(seq, return_tensors="pt", truncation=True, max_length=1022)
        inputs = {k: v.to(device) for k, v in inputs.items()}

        amp = torch.autocast("cuda", dtype=autocast_dtype) if on_cuda else nullcontext()
        with torch.inference_mode():
            with amp:
                outputs = esm_model(**inputs, output_attentions=True, return_dict=True)

            # Node embeddings: drop the first/last (special) token positions.
            x = outputs.last_hidden_state[0, 1:-1, :].to(save_x_dtype).cpu()

            # Keep the attention map on the device; only the edge list goes to the CPU.
            attn = outputs.attentions[-1].mean(dim=1)[0, 1:-1, 1:-1]
            row_idx, col_idx = torch.where(attn > prob_threshold)
            keep = row_idx != col_idx  # drop self-loops
            edge_index = torch.stack([row_idx[keep], col_idx[keep]], dim=0).cpu()

        graphs.append(Data(x=x, edge_index=edge_index,
                           y=torch.tensor(int(label), dtype=torch.long)))

        # Release per-sequence GPU tensors; periodically return cached blocks.
        del outputs, inputs, attn
        if on_cuda and i % 200 == 0:
            torch.cuda.empty_cache()

        if i % 100 == 0:
            print(f"[{split_name}] {i+1}/{n}")

    save_path = os.path.join(out_dir, f"{split_name}_graphs.pt")
    torch.save(graphs, save_path)
    print(f"✅ saved: {save_path} ({len(graphs)} graphs)")


# ===== Dataset location: adjust CSV_PATH to point at your copy of the file =====
CSV_PATH = "/content/Swissprot_Membrane_Train_Validation_dataset.csv"
train_df, val_df = load_and_preprocess_data(csv_path=CSV_PATH)


Train: 12571, Val: 8831


In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)

# Load ESM-1b with the eager attention implementation; the precompute step
# requests attention maps (output_attentions=True) from every forward pass.
model_name = "facebook/esm1b_t33_650M_UR50S"
tokenizer = EsmTokenizer.from_pretrained(model_name)
esm_model = (
    EsmModel.from_pretrained(model_name, attn_implementation="eager")
    .to(device)
    .eval()
)

# Optional: cast ESM to fp16 for a lighter memory footprint.
# esm_model = esm_model.half()

THRESH = 0.6
OUT_DIR = "/content/graphs"

# Build and save the graphs for each split in turn (train first, then val).
for split_name, split_df in (("train", train_df), ("val", val_df)):
    precompute_and_save_onefile(
        split_df, split_name, esm_model, tokenizer, device,
        out_dir=OUT_DIR, prob_threshold=THRESH,
    )


Using device: cuda


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm1b_t33_650M_UR50S and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


[train] chunk 1/126  sample 1/12571
[train] chunk 1/126  sample 11/12571
[train] chunk 1/126  sample 21/12571
[train] chunk 1/126  sample 31/12571
[train] chunk 1/126  sample 41/12571
[train] chunk 1/126  sample 51/12571
[train] chunk 1/126  sample 61/12571
[train] chunk 1/126  sample 71/12571
[train] chunk 1/126  sample 81/12571
[train] chunk 1/126  sample 91/12571
✅ saved: /content/graphs/train_chunk0000.pt (100 graphs)
[train] chunk 2/126  sample 101/12571
[train] chunk 2/126  sample 111/12571
[train] chunk 2/126  sample 121/12571
[train] chunk 2/126  sample 131/12571
[train] chunk 2/126  sample 141/12571
[train] chunk 2/126  sample 151/12571
[train] chunk 2/126  sample 161/12571
[train] chunk 2/126  sample 171/12571
[train] chunk 2/126  sample 181/12571
[train] chunk 2/126  sample 191/12571
✅ saved: /content/graphs/train_chunk0001.pt (100 graphs)
[train] chunk 3/126  sample 201/12571
[train] chunk 3/126  sample 211/12571
[train] chunk 3/126  sample 221/12571
[train] chunk 3/126  sa

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"

GRAPH_DIR = "/content/graphs"


def _load_graph_split(split_name, graph_dir=GRAPH_DIR):
    """Load every saved graph for ``split_name`` from ``graph_dir`` as one list.

    Two on-disk layouts are accepted:

    * a single ``{split}_graphs.pt`` file, which is preferred when present;
    * numbered ``{split}_chunk*.pt`` files, concatenated in sorted order.

    ``torch.load`` does not expand wildcards, so chunk files are collected
    with ``glob``. Each file is assumed to hold a list of graphs.

    Raises:
        FileNotFoundError: If no graph file exists for the split.
    """
    single_file = os.path.join(graph_dir, f"{split_name}_graphs.pt")
    if os.path.exists(single_file):
        paths = [single_file]
    else:
        # Zero-padded chunk numbers (e.g. train_chunk0000.pt) sort correctly.
        paths = sorted(glob.glob(os.path.join(graph_dir, f"{split_name}_chunk*.pt")))
    if not paths:
        raise FileNotFoundError(f"No graph files for split '{split_name}' in {graph_dir}")

    graphs = []
    for path in paths:
        # The files hold pickled torch_geometric Data objects. torch>=2.6
        # defaults to weights_only=True, which refuses to load them.
        # weights_only=False unpickles arbitrary objects, so only load files
        # this notebook produced.
        graphs.extend(torch.load(path, weights_only=False))
    return graphs


train_dataset = _load_graph_split("train")
val_dataset = _load_graph_split("val")

train_loader = GraphDataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=0)
val_loader   = GraphDataLoader(val_dataset, batch_size=8, shuffle=False, num_workers=0)

model = ProteinGNN(input_dim=1280).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)

# Start below any possible accuracy so the first epoch always saves a checkpoint.
best_val_acc = float("-inf")
EPOCHS = 20

print("\n=== Training GNN (no ESM during training) ===")
for epoch in range(EPOCHS):
    model.train()
    train_loss, train_correct, train_total = 0.0, 0, 0

    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}")
    for step, batch in enumerate(pbar, start=1):
        batch = batch.to(device)
        optimizer.zero_grad()

        # Node features were saved as fp16; the model weights are fp32.
        logits, _ = model(batch.x.float(), batch.edge_index, batch.batch)
        loss = F.cross_entropy(logits, batch.y)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        train_loss += loss.item()
        pred = logits.argmax(1)
        train_correct += (pred == batch.y).sum().item()
        train_total += batch.num_graphs

        # Running mean over the batches seen so far. The old code divided by
        # the total batch count, so the displayed loss ramped up from ~0.
        pbar.set_postfix(loss=train_loss / step,
                         acc=train_correct / max(1, train_total))

    val_acc = evaluate(model, val_loader, device)
    print(f"Epoch {epoch+1:2d}: TrainLoss {train_loss/max(1, len(train_loader)):.4f}, "
          f"TrainAcc {train_correct/max(1, train_total):.4f}, ValAcc {val_acc:.4f}")

    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), "/content/best_esm_gnn.pth")
        print("✅ saved best model")

print(f"\nBest Val Acc: {best_val_acc:.4f}")
print("Model saved: /content/best_esm_gnn.pth")
