<a href="https://colab.research.google.com/github/nanpolend/machine-learning/blob/master/RNA3D_Gemini%E4%BF%AE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#-- coding: utf-8 --#

""" RNA 3D 結構預測 Kaggle 比賽範例程式碼
對手：deepseek
作者：ChatGPT
說明：以分段方式撰寫，方便手動排查與除錯
"""

# ==========================
# 1. 環境設定與函式庫載入
# ==========================

import os
import math
import numpy as np
import pandas as pd
from glob import glob
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

# Biopython 用於 PDB 解析
from Bio.PDB import PDBParser

# ==========================
# 2. 全域參數與設定
# ==========================

class Config:
    """Paths, model hyper-parameters and training settings used by this script."""

    # --- data locations ---
    DATA_DIR = '/mnt/data/'     # directory holding the PDB files and the sequence table
    SEQ_FILE = 'sequences.csv'  # RNA sequence table, relative to DATA_DIR

    # --- model hyper-parameters ---
    EMBED_DIM, HIDDEN_DIM = 128, 256
    NUM_LAYERS = 4
    DROPOUT = 0.1

    # --- training settings ---
    BATCH_SIZE = 8
    LR = 1e-4
    EPOCHS = 50
    # Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
    DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


config = Config()

# ==========================
# 3. 資料載入與前處理
# ==========================

class RNADataset(Dataset):
    """Dataset pairing each RNA sequence with phosphate coordinates from its PDB file.

    Every item is a tuple ``(one_hot, coords)``:

    - ``one_hot``: float32 tensor of shape ``(len(sequence), 4)``; the columns
      are ordered A, U, C, G.
    - ``coords``: float32 tensor of shape ``(n_residues_with_P, 3)`` holding the
      ``P`` atom position of each residue of the first model, centred on the
      origin and divided by the Frobenius norm of the centred array.

    When ``augment`` is true the coordinates are additionally rotated by a
    random angle about the z axis (no translation is applied).

    Note:
        Residues that have no ``P`` atom (for example waters or ions) are
        skipped, so the number of coordinate rows can differ from the sequence
        length. Callers that need a strict 1:1 alignment must compare the two
        shapes themselves.
    """

    # Column of each nucleotide in the one-hot encoding.
    NT_INDEX = {'A': 0, 'U': 1, 'C': 2, 'G': 3}

    def __init__(self, seq_df, pdb_dir, augment=False):
        """Store the sequence table and create the PDB parser.

        Args:
            seq_df: table with an ``id`` and a ``sequence`` column.
            pdb_dir: directory containing one ``<id>.pdb`` file per row.
            augment: whether to apply the random z-axis rotation.
        """
        self.seq_df = seq_df
        self.pdb_dir = pdb_dir
        self.augment = augment
        self.pdb_parser = PDBParser(QUIET=True)

    def __len__(self):
        """Return the number of rows in the sequence table."""
        return len(self.seq_df)

    @staticmethod
    def _phosphate_coords(structure):
        """Collect the P-atom coordinates of the first model as an (N, 3) float32 array."""
        # Only the first model is used: multi-model files would otherwise
        # contribute the same residues several times.
        first_model = next(iter(structure), None)
        if first_model is None:
            return np.zeros((0, 3), dtype=np.float32)
        points = [
            res['P'].get_coord()
            for chain in first_model
            for res in chain
            if 'P' in res  # residues without a phosphate would raise KeyError
        ]
        # reshape keeps the (N, 3) layout even when no residue qualified.
        return np.array(points, dtype=np.float32).reshape(-1, 3)

    @staticmethod
    def _normalize(coords):
        """Centre the coordinates and scale them to unit Frobenius norm."""
        if len(coords) == 0:
            return coords
        coords = coords - coords.mean(axis=0)
        scale = np.linalg.norm(coords)
        # A zero norm (e.g. a single residue) would turn every entry into NaN.
        if scale > 0:
            coords = coords / scale
        return coords

    def __getitem__(self, idx):
        """Return ``(one_hot, coords)`` for the row at position ``idx``.

        Raises:
            KeyError: if the sequence contains a symbol other than A, U, C or G.
        """
        # 1) Look up the sequence record.
        row = self.seq_df.iloc[idx]
        seq_id = row['id']
        seq = row['sequence']

        # 2) Parse the matching PDB file and extract normalised P coordinates.
        pdb_path = os.path.join(self.pdb_dir, f"{seq_id}.pdb")
        structure = self.pdb_parser.get_structure(seq_id, pdb_path)
        coords = self._normalize(self._phosphate_coords(structure))

        # 3) One-hot encode the sequence.
        one_hot = np.zeros((len(seq), 4), dtype=np.float32)
        for i, nt in enumerate(seq):
            one_hot[i, self.NT_INDEX[nt]] = 1.0

        # 4) Optional augmentation: random rotation about the z axis.
        if self.augment:
            angle = np.random.rand() * 2 * math.pi
            R = np.array([[math.cos(angle), -math.sin(angle), 0],
                          [math.sin(angle), math.cos(angle), 0],
                          [0, 0, 1]])
            coords = coords.dot(R)

        return torch.from_numpy(one_hot), torch.from_numpy(coords).float()

# Read the sequence table from the data directory.
seq_df = pd.read_csv(
    os.path.join(config.DATA_DIR, config.SEQ_FILE),
)

# Reproducible 80/20 split: sample 80% of the rows for training and keep the
# remaining rows for validation.
train_df = seq_df.sample(frac=0.8, random_state=42)
valid_df = seq_df.drop(index=train_df.index)

# Only the training set is augmented.
train_dataset = RNADataset(seq_df=train_df, pdb_dir=config.DATA_DIR, augment=True)
valid_dataset = RNADataset(seq_df=valid_df, pdb_dir=config.DATA_DIR, augment=False)

# NOTE(review): no collate_fn is given, so the default collation stacks the
# samples of a batch; that presumably requires all sequences in a batch to
# have the same length - confirm against the data.
train_loader = DataLoader(dataset=train_dataset, batch_size=config.BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset=valid_dataset, batch_size=config.BATCH_SIZE)

# ==========================
# 4. 模型定義
# ==========================

class SimpleTransformer(nn.Module):
    """Transformer encoder mapping one-hot nucleotides to per-residue 3D coordinates.

    A parameter-free sinusoidal positional encoding is added to the projected
    input. Without it self-attention is permutation-equivariant, so every
    occurrence of the same nucleotide would receive the same predicted
    coordinate regardless of where it sits in the chain. The encoding is
    computed on the fly and adds neither parameters nor buffers, so the
    ``state_dict`` layout is the same as without it.

    Args:
        embed_dim: width of the token embedding (``d_model``); the encoder
            layers use 8 attention heads.
        hidden_dim: width of the feed-forward sub-layer.
        num_layers: number of stacked encoder layers.
        dropout: dropout probability inside the encoder layers.
    """

    def __init__(self, embed_dim, hidden_dim, num_layers, dropout):
        super().__init__()
        self.input_proj = nn.Linear(4, embed_dim)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim, nhead=8, dim_feedforward=hidden_dim, dropout=dropout,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)
        self.output_proj = nn.Linear(embed_dim, 3)

    @staticmethod
    def _positional_encoding(length, dim, device, dtype):
        """Return the (length, dim) sinusoidal position table."""
        positions = torch.arange(length, device=device, dtype=torch.float32).unsqueeze(1)  # (L, 1)
        freqs = torch.exp(
            torch.arange(0, dim, 2, device=device, dtype=torch.float32)
            * (-math.log(10000.0) / dim)
        )
        angles = positions * freqs
        encoding = torch.zeros(length, dim, device=device, dtype=torch.float32)
        encoding[:, 0::2] = torch.sin(angles)
        # For an odd dim there is one cosine column fewer than sine columns.
        encoding[:, 1::2] = torch.cos(angles[:, : dim // 2])
        return encoding.to(dtype)

    def forward(self, x):
        """Predict coordinates for a batch of one-hot sequences.

        Args:
            x: tensor of shape ``(B, L, 4)``.

        Returns:
            Tensor of shape ``(B, L, 3)``.
        """
        h = self.input_proj(x)  # (B, L, D)
        # Inject position information; the (L, D) table broadcasts over the batch.
        h = h + self._positional_encoding(h.size(1), h.size(2), h.device, h.dtype)
        # The encoder is built with the default sequence-first layout: (L, B, D).
        h = h.permute(1, 0, 2)
        h = self.transformer(h)
        h = h.permute(1, 0, 2)  # back to (B, L, D)
        coords = self.output_proj(h)  # (B, L, 3)
        return coords

# Build the network from the configured hyper-parameters, then move it to the
# configured device.
model = SimpleTransformer(
    config.EMBED_DIM,
    config.HIDDEN_DIM,
    config.NUM_LAYERS,
    config.DROPOUT,
)
model = model.to(config.DEVICE)

# ==========================
# 5. 損失函數與優化器
# ==========================

# Adam over all model parameters, trained against a mean-reduced squared error.
optimizer = torch.optim.Adam(params=model.parameters(), lr=config.LR)
criterion = nn.MSELoss(reduction='mean')

# ==========================
# 6. 訓練與驗證函式
# ==========================

def train_one_epoch(model, loader, optimizer, criterion, device):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    Args:
        model: network to train; switched to training mode.
        loader: iterable of ``(inputs, targets)`` batches that supports ``len``.
        optimizer: optimiser stepped once per batch.
        criterion: callable computing the loss from ``(prediction, target)``.
        device: device the batches are moved to before the forward pass.

    Returns:
        Sum of the per-batch losses divided by the number of batches.
    """
    model.train()
    batch_losses = []
    progress = tqdm(loader, desc='Train')
    for inputs, targets in progress:
        inputs = inputs.to(device)
        targets = targets.to(device)
        batch_loss = criterion(model(inputs), targets)
        # Clear stale gradients before back-propagating this batch.
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()
        batch_losses.append(batch_loss.item())
    return sum(batch_losses) / len(loader)

def validate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without gradients and return the mean batch loss.

    Args:
        model: network to evaluate; switched to evaluation mode.
        loader: iterable of ``(inputs, targets)`` batches that supports ``len``.
        criterion: callable computing the loss from ``(prediction, target)``.
        device: device the batches are moved to before the forward pass.

    Returns:
        Sum of the per-batch losses divided by the number of batches.
    """
    model.eval()
    batch_losses = []
    with torch.no_grad():
        progress = tqdm(loader, desc='Valid')
        for inputs, targets in progress:
            inputs = inputs.to(device)
            targets = targets.to(device)
            batch_loss = criterion(model(inputs), targets)
            batch_losses.append(batch_loss.item())
    return sum(batch_losses) / len(loader)

# ==========================
# 7. 主訓練迴圈
# ==========================

best_val = float('inf')
for epoch in range(1, config.EPOCHS + 1):
    train_loss = train_one_epoch(
        model=model,
        loader=train_loader,
        optimizer=optimizer,
        criterion=criterion,
        device=config.DEVICE,
    )
    val_loss = validate(
        model=model,
        loader=valid_loader,
        criterion=criterion,
        device=config.DEVICE,
    )
    print(f"Epoch {epoch}: Train Loss={train_loss:.4f}, Valid Loss={val_loss:.4f}")
    if not val_loss < best_val:
        continue
    # New best validation loss: remember it and checkpoint the weights.
    best_val = val_loss
    torch.save(model.state_dict(), 'best_model.pt')
    print("[Info] Best model saved.")

# ==========================
# 8. 推論與結果儲存
# ==========================

# Load the best checkpoint. map_location lets a checkpoint that was saved from
# a GPU run be restored on a CPU-only host.
model.load_state_dict(torch.load('best_model.pt', map_location=config.DEVICE))
model.eval()

output_dir = 'predictions'
os.makedirs(output_dir, exist_ok=True)

# Same nucleotide -> column encoding as the one used by RNADataset.
nt_index = {'A': 0, 'U': 1, 'C': 2, 'G': 3}

with torch.no_grad():
    records = zip(valid_df['id'], valid_df['sequence'])
    for seq_id, seq in tqdm(records, total=len(valid_df)):
        # One-hot encode the sequence exactly as during training.
        one_hot = np.zeros((len(seq), 4), dtype=np.float32)
        for i, nt in enumerate(seq):
            one_hot[i, nt_index[nt]] = 1.0
        x = torch.from_numpy(one_hot).unsqueeze(0).to(config.DEVICE)  # (1, L, 4)
        pred_coords = model(x).squeeze(0).cpu().numpy()  # (L, 3)
        # One whitespace-separated "x y z" row per residue.
        np.savetxt(os.path.join(output_dir, f"{seq_id}.xyz"), pred_coords)

print("推論完成，結束程式。")