In [7]:
import os, sys, json, pickle, random
from collections import OrderedDict
from copy import deepcopy

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
from tqdm import tqdm

from sklearn.preprocessing import StandardScaler
from sklearn.metrics import matthews_corrcoef, accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, roc_auc_score

from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem, MACCSkeys, rdMolDescriptors, Descriptors
from rdkit.ML.Descriptors import MoleculeDescriptors
from rdkit.Chem.Scaffolds import MurckoScaffold

import optuna
from optuna.trial import Trial

import matplotlib.pyplot as plt
import seaborn as sns


# -------------------- [0. g-mlp 모듈 정의] --------------------
# 사용자의 환경에 'g-mlp'가 설치되어 있지 않을 경우를 대비한 최소 정의
try:
    # GMLP_DIR이 설정되어 있다면 기존 모듈 사용
    GMLP_DIR = "/home/minji/g-mlp"
    if GMLP_DIR not in sys.path:
        sys.path.append(GMLP_DIR)
    from g_mlp import gMLP
except:
    # 모듈을 찾을 수 없을 때 임시 클래스 정의 (실제 학습에는 부적합)
    class gMLP(nn.Module):
        def __init__(self, seq_len, d_model, d_ffn, num_layers):
            super().__init__()
            # gMLP 대신 단순한 선형 레이어로 대체 (실제 학습에는 부적합)
            self.layers = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(num_layers)])
        def forward(self, x):
            for layer in self.layers:
                x = layer(x)
            return x


# -------------------- [1. 공통 유틸/환경] --------------------
def set_seed(seed=700):
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        # ❌ 성능을 낮추더라도 재현성을 위해 False로 고정
        torch.backends.cudnn.benchmark = True 
        torch.backends.cudnn.deterministic = True
        torch.use_deterministic_algorithms(True)
        os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
    else:
        torch.use_deterministic_algorithms(True, warn_only=True)

set_seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# -------------------- [2. 피처 생성 유틸] --------------------
def to_numpy_bitvect(bitvect, n_bits=None, drop_first=False):
    if n_bits is None:
        n_bits = bitvect.GetNumBits()
    arr = np.zeros((n_bits,), dtype=np.int8)
    DataStructs.ConvertToNumpyArray(bitvect, arr)
    if drop_first:
        arr = arr[1:]
    return arr.astype(np.float32)

def get_ecfp(mol, radius=2, nbits=1024):
    return to_numpy_bitvect(AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=nbits), n_bits=nbits)

def get_maccs(mol):
    bv = MACCSkeys.GenMACCSKeys(mol)
    return to_numpy_bitvect(bv, n_bits=bv.GetNumBits(), drop_first=True)

def get_avalon(mol, nbits=512):
    from rdkit.Avalon import pyAvalonTools
    return to_numpy_bitvect(pyAvalonTools.GetAvalonFP(mol, nbits), n_bits=nbits)

def get_topological_torsion(mol, nbits=1024):
    bv = rdMolDescriptors.GetHashedTopologicalTorsionFingerprintAsBitVect(mol, nBits=nbits)
    return to_numpy_bitvect(bv, n_bits=nbits)

def get_rdkit_desc(mol):
    calc = MoleculeDescriptors.MolecularDescriptorCalculator([d[0] for d in Descriptors._descList])
    try:
        descs = calc.CalcDescriptors(mol)
        descs = np.array(descs, dtype=np.float32)
        descs = np.nan_to_num(descs, nan=0.0, posinf=0.0, neginf=0.0)
    except Exception:
        descs = np.zeros(len(Descriptors._descList), dtype=np.float32)
    return descs

def get_rdkit_descriptor_length():
    return len(Descriptors._descList)


# -------------------- [3. 임베딩 로드 & 차원 감지 (MolE 포함)] --------------------
def load_molecular_embeddings(embed_paths: dict):
    embed_data, embed_dims = {}, {}
    for name, path in embed_paths.items():
        try:
            df = pd.read_csv(path)
        except Exception:
            embed_data[name] = {}
            embed_dims[name] = 0
            continue
        
        def canon(s):
            m = Chem.MolFromSmiles(s)
            return Chem.MolToSmiles(m, canonical=True) if m else None

        df['smiles'] = df['smiles'].apply(canon)
        df = df.dropna(subset=['smiles']).reset_index(drop=True)

        embed_cols = [c for c in df.columns if c != 'smiles']
        dim = len(embed_cols)
        embed_dims[name] = dim
        embed_data[name] = {
            row['smiles']: row[embed_cols].to_numpy(dtype=np.float32, copy=False)
            for _, row in df.iterrows()
        }
    return embed_data, embed_dims

# -------------------- [4. 기대 차원 계산 + 안전 결합 (MolE 차원 추가)] --------------------
def compute_expected_dims(fp_types, embed_dims: dict):
    expected = OrderedDict()
    for t in fp_types:
        if t == 'ecfp':
            expected[t] = 1024
        elif t == 'avalon':
            expected[t] = 512
        elif t == 'maccs':
            expected[t] = 166
        elif t == 'tt':
            expected[t] = 1024
        elif t == 'rdkit':
            expected[t] = get_rdkit_descriptor_length()
        elif 'scage' in t:
            default_dim = 512
            expected[t] = embed_dims.get(t, default_dim) if embed_dims.get(t, 0) > 0 else default_dim
        elif 'mole' in t: # MolE 차원 추가
            default_dim = 768 
            expected[t] = embed_dims.get(t, default_dim) if embed_dims.get(t, 0) > 0 else default_dim
        else:
            raise ValueError(f"Unknown fp_type: {t}")
    return expected

def safe_fit_to_dim(vec: np.ndarray, target_dim: int) -> np.ndarray:
    if vec is None:
        return np.zeros(target_dim, dtype=np.float32)
    vec = vec.astype(np.float32, copy=False)
    if np.any(np.isnan(vec)) or np.any(np.isinf(vec)):
        vec = np.nan_to_num(vec, nan=0.0, posinf=0.0, neginf=0.0)
    cur = vec.shape[0]
    if cur == target_dim:
        return vec
    elif cur < target_dim:
        pad = np.zeros(target_dim - cur, dtype=np.float32)
        return np.concatenate([vec, pad], axis=0)
    else:
        return vec[:target_dim]
        
def make_feature_vector(mol, smiles, fp_types, expected_dims, embed_dicts):
    chunks = []
    for t in fp_types:
        dim = expected_dims[t]
        try:
            if t == 'ecfp':
                vec = get_ecfp(mol, radius=2, nbits=dim)
            elif t == 'avalon':
                vec = get_avalon(mol, nbits=dim)
            elif t == 'maccs':
                vec = get_maccs(mol)
            elif t == 'tt':
                vec = get_topological_torsion(mol, nbits=dim)
            elif t == 'rdkit':
                vec = get_rdkit_desc(mol)
            elif 'scage' in t or 'mole' in t: # MolE 임베딩 포함
                vec = embed_dicts.get(t, {}).get(smiles, None)
            else:
                vec = None
        except Exception:
            vec = None
        chunks.append(safe_fit_to_dim(vec, dim))
    feat = np.concatenate(chunks, axis=0)
    feat = np.nan_to_num(feat, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)
    return feat

# -------------------- [5. Dataset] --------------------
class ScageConcatDataset(data.Dataset):
    def __init__(self, label_path, embed_paths: dict, fp_types, expected_dims=None): 
        df = pd.read_csv(label_path)
        
        # 라벨 컬럼 자동 감지 로직
        label_cols = [c for c in df.columns if c not in ['smiles', 'mol_id'] and df[c].dtype in ['int64', 'float64', 'int32', 'float32']]
        
        # BBBP 데이터셋 (문자열 라벨)을 위한 예외 처리: 'p_np'가 있다면 이진 분류로 간주
        if 'p_np' in df.columns and len(df['p_np'].unique()) <= 3:
            df = df.rename(columns={'p_np': 'label'})
            df['label'] = df['label'].replace({'BBB-': 0, 'BBB+': 1})
            label_cols = ['label']
        elif 'label' in df.columns and len(label_cols) == 0:
            label_cols = ['label']
        
        if not label_cols:
            raise ValueError("Could not find any suitable label columns (non-SMILES/mol_id numeric columns).")
        
        self.label_cols_names = label_cols
        self.num_labels = len(label_cols)
        
        # 라벨 컬럼에 NaN 값이 있는 행 제거
        df = df.dropna(subset=['smiles'] + self.label_cols_names).reset_index(drop=True)
        
        # MolE와 SCAGE를 모두 로드합니다.
        self.embed_dicts, embed_dims = load_molecular_embeddings(embed_paths)

        if expected_dims is None:
            expected_dims = compute_expected_dims(fp_types, embed_dims)
        self.expected_dims = expected_dims
        self.fp_types = list(fp_types)

        def canon(s):
            m = Chem.MolFromSmiles(s)
            return Chem.MolToSmiles(m, canonical=True) if m else None

        df['smiles'] = df['smiles'].apply(canon)
        df = df.dropna(subset=['smiles']).reset_index(drop=True)

        features, labels, failed = [], [], []
        for _, row in tqdm(df.iterrows(), total=len(df), desc="Generating Features"):
            smi = row['smiles']
            mol = Chem.MolFromSmiles(smi)
            if mol is None:
                failed.append(smi)
                continue
            
            # make_feature_vector에 MolE, SCAGE를 포함한 embed_dicts 전달
            feat = make_feature_vector(mol, smi, self.fp_types, self.expected_dims, self.embed_dicts)
            if feat is None or feat.ndim != 1:
                failed.append(smi)
                continue
            features.append(feat)
            # 라벨을 2D 형태로 NumPy 배열에 저장
            labels.append(row[self.label_cols_names].to_numpy(dtype=np.float32, copy=False))

        self.features = torch.tensor(np.stack(features, axis=0), dtype=torch.float32)
        self.labels = torch.tensor(np.stack(labels, axis=0), dtype=torch.float32)
        self.df = df[~df['smiles'].isin(failed)].reset_index(drop=True)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.features[idx], self.labels[idx]


# -------------------- [6. 스플릿 + RDKit 정규화 (토글)] --------------------
def split_then_normalize(
    dataset: ScageConcatDataset,
    split_mode: str = "scaffold",
    train_ratio: float = 0.8,
    val_ratio: float = 0.1,
    seed: int = 700
):
    set_seed(seed)
    df = dataset.df[['smiles']].copy() 

    def get_scaffold(smi):
        m = Chem.MolFromSmiles(smi)
        scaffold = MurckoScaffold.GetScaffoldForMol(m) if m else None
        return Chem.MolToSmiles(scaffold, canonical=True) if scaffold else None

    df['scaffold'] = dataset.df['smiles'].apply(get_scaffold)
    groups = list(df.groupby('scaffold').groups.values())

    if split_mode == "scaffold":
        groups = sorted(groups, key=lambda g: len(g), reverse=True)
    elif split_mode == "random_scaffold":
        rnd = random.Random(seed)
        rnd.shuffle(groups)
    else:
        raise ValueError("split_mode must be 'scaffold' or 'random_scaffold'")

    n = len(df)
    train_cap = int(round(train_ratio * n))
    val_cap   = int(round(val_ratio * n))

    train_idx, val_idx, test_idx = [], [], []
    for g in groups:
        g = list(g)
        if len(train_idx) + len(g) <= train_cap:
            train_idx += g
        elif len(val_idx) + len(g) <= val_cap:
            val_idx += g
        else:
            test_idx += g

    # 인덱스를 numpy 배열로 저장
    train_idx = np.array(train_idx, dtype=np.int64)
    val_idx = np.array(val_idx, dtype=np.int64)
    test_idx = np.array(test_idx, dtype=np.int64)

    return get_split_datasets(dataset, train_idx, val_idx, test_idx)

def get_split_datasets(dataset, train_idx, val_idx, test_idx):
    """저장된 인덱스 또는 새로 생성된 인덱스를 사용하여 데이터셋을 생성하고 RDKit 기술자를 정규화합니다."""
    
    def pick(idxs):
        return dataset.features[idxs], dataset.labels[idxs]

    X_train, y_train = pick(train_idx)
    X_val,   y_val   = pick(val_idx)
    X_test,  y_test  = pick(test_idx)
    
    # RDKit 기술자 슬라이스 찾기
    rd_start, rd_end = None, None
    offset = 0
    for t in dataset.fp_types:
        dim = dataset.expected_dims[t]
        if t == 'rdkit':
            rd_start, rd_end = offset, offset + dim
            break
        offset += dim
        
    # RDKit 기술자 정규화
    scaler = None
    if rd_start is not None:
        scaler = StandardScaler().fit(X_train[:, rd_start:rd_end])
        X_train[:, rd_start:rd_end] = torch.tensor(scaler.transform(X_train[:, rd_start:rd_end]), dtype=torch.float32)
        X_val[:, rd_start:rd_end]   = torch.tensor(scaler.transform(X_val[:, rd_start:rd_end]),   dtype=torch.float32)
        X_test[:, rd_start:rd_end]  = torch.tensor(scaler.transform(X_test[:, rd_start:rd_end]),  dtype=torch.float32)
        
    return (
        data.TensorDataset(X_train, y_train),
        data.TensorDataset(X_val,   y_val),
        data.TensorDataset(X_test,  y_test),
        scaler, (rd_start, rd_end),
        (train_idx, val_idx, test_idx)
    )

def load_split_dataset(dataset: ScageConcatDataset, split_dir, tag):
    """저장된 인덱스를 로드하여 데이터셋을 구성합니다."""
    
    # 인덱스 파일 경로 확인
    train_path = os.path.join(split_dir, f"train_idx_{tag}.npy")
    val_path = os.path.join(split_dir, f"val_idx_{tag}.npy")
    test_path = os.path.join(split_dir, f"test_idx_{tag}.npy")
    
    if not os.path.exists(train_path) or not os.path.exists(val_path) or not os.path.exists(test_path):
        return None
    
    # 인덱스 로드
    train_idx = np.load(train_path)
    val_idx = np.load(val_path)
    test_idx = np.load(test_path)
    
    print(f"[Info] Loaded existing split indices for tag: {tag}")
    return get_split_datasets(dataset, train_idx, val_idx, test_idx)


# -------------------- [7. 모델] --------------------
class MultiModalGMLPFromFlat(nn.Module):
    def __init__(self, mod_dims: OrderedDict, num_labels=1, d_model=512, d_ffn=1024, depth=4, dropout=0.2, use_gated_pool=True):
        super().__init__()
        self.mod_names = list(mod_dims.keys())
        self.mod_dims = [mod_dims[n] for n in self.mod_names]
        self.in_features = sum(self.mod_dims)
        self.seq_len = len(self.mod_names)
        self.use_gated_pool = use_gated_pool
        self.num_labels = num_labels

        self.proj = nn.ModuleDict({name: nn.Linear(in_dim, d_model) for name, in_dim in zip(self.mod_names, self.mod_dims)})
        self.backbone = gMLP(seq_len=self.seq_len, d_model=d_model, d_ffn=d_ffn, num_layers=depth)
        self.norm = nn.LayerNorm(d_model)
        if use_gated_pool:
            self.alpha = nn.Parameter(torch.zeros(self.seq_len))
        # 출력 차원을 라벨 개수에 맞춤
        self.head = nn.Linear(d_model, self.num_labels)
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        chunks = torch.split(x, self.mod_dims, dim=1)
        tokens = [self.proj[name](chunk) for name, chunk in zip(self.mod_names, chunks)]
        X = torch.stack(tokens, dim=1)
        X = self.backbone(X)
        if self.use_gated_pool:
            w = torch.softmax(self.alpha, dim=0)
            Xp = (X * w.view(1, -1, 1)).sum(dim=1)
        else:
            Xp = X.mean(dim=1)
        Xp = self.drop(self.norm(Xp))
        # 단일 라벨일 때만 squeeze(-1) 적용 (BCEWithLogitsLoss의 요구사항)
        logits = self.head(Xp).squeeze(dim=-1) if self.num_labels == 1 else self.head(Xp)
        return logits
    
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


# -------------------- [8. 학습/평가 루틴] --------------------
def train_model(model, optimizer, train_loader, val_loader, loss_fn, num_epochs=50, patience=10):
    best_val = float('inf'); best_state = None; bad = 0
    # 모델에서 라벨 개수를 가져옵니다.
    num_labels = model.num_labels  
    
    for epoch in range(num_epochs):
        model.train()
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            
            # 단일 라벨일 때만 타겟 텐서의 차원을 (N, 1) → (N)로 압축
            if num_labels == 1:
                y = y.squeeze(dim=-1)
                
            optimizer.zero_grad()
            logits = model(x)
            
            # logits는 (N, num_labels) 또는 (N) 형태, y는 (N, num_labels) 또는 (N) 형태로 맞춰집니다.
            loss = loss_fn(logits, y)
            loss.backward()
            optimizer.step()
        
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for x, y in val_loader:
                x, y = x.to(device), y.to(device)
                
                # 단일 라벨일 때만 타겟 텐서를 squeeze
                if num_labels == 1:
                    y = y.squeeze(dim=-1)
                    
                val_loss += loss_fn(model(x), y).item()
        val_loss /= len(val_loader)
        
        if val_loss < best_val:
            best_val = val_loss
            best_state = deepcopy(model.state_dict())
            bad = 0
        else:
            bad += 1
            if bad >= patience:
                break
    if best_state:
        model.load_state_dict(best_state)
    return model

def eval_model(model, loader, num_labels):
    model.eval()
    y_true_all, y_prob_all, y_pred_all = [], [], []
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            logits = model(x)
            probs = torch.sigmoid(logits).cpu().numpy()
            pred = (probs > 0.5).astype(int)
            y_prob_all.extend(probs)
            y_pred_all.extend(pred)
            y_true_all.extend(y.cpu().numpy())
    
    y_true_all = np.array(y_true_all)
    y_pred_all = np.array(y_pred_all)
    y_prob_all = np.array(y_prob_all)

    metrics = {}
    if num_labels > 1:
        # 멀티 라벨 평가 (micro average)
        try:
            metrics['accuracy'] = round(accuracy_score(y_true_all, y_pred_all), 3) # Subset accuracy
            metrics['precision'] = round(precision_score(y_true_all, y_pred_all, average='micro', zero_division=0), 3)
            metrics['recall'] = round(recall_score(y_true_all, y_pred_all, average='micro', zero_division=0), 3)
            metrics['f1'] = round(f1_score(y_true_all, y_pred_all, average='micro', zero_division=0), 3)
            metrics['mcc'] = round(matthews_corrcoef(y_true_all.flatten(), y_pred_all.flatten()), 3)
            metrics['roc_auc'] = round(roc_auc_score(y_true_all, y_prob_all, average='micro') if len(y_true_all) > 1 and len(set(y_true_all.flatten())) > 1 else 0.0, 3)
            metrics['sensitivity'] = metrics['recall']
            metrics['specificity'] = 'N/A' # 멀티 라벨에서는 일반적으로 계산하지 않음
        except Exception as e:
            print(f"[Warn] Multi-label evaluation failed: {e}")
            metrics = {k: 0.0 for k in ['accuracy', 'precision', 'recall', 'f1', 'mcc', 'roc_auc', 'sensitivity']}
            metrics['specificity'] = 'N/A'

    else:
        # 단일 라벨 평가 (1D 배열로 처리)
        y_true_1d = y_true_all.flatten()
        y_pred_1d = y_pred_all.flatten()
        y_prob_1d = y_prob_all.flatten()
        
        cm = confusion_matrix(y_true_1d, y_pred_1d)
        if cm.size == 4:
            tn, fp, fn, tp = cm.ravel()
            specificity = tn / (tn + fp) if (tn + fp) > 0 else 0.0
            sensitivity = recall_score(y_true_1d, y_pred_1d, zero_division=0)
        else:
            specificity, sensitivity = 0.0, 0.0
            
        metrics['accuracy'] = round(accuracy_score(y_true_1d, y_pred_1d), 3)
        metrics['precision'] = round(precision_score(y_true_1d, y_pred_1d, zero_division=0), 3)
        metrics['recall'] = round(sensitivity, 3)
        metrics['f1'] = round(f1_score(y_true_1d, y_pred_1d, zero_division=0), 3)
        metrics['roc_auc'] = round(roc_auc_score(y_true_1d, y_prob_1d) if len(set(y_true_1d)) > 1 else 0.0, 3)
        metrics['mcc'] = round(matthews_corrcoef(y_true_1d, y_pred_1d), 3)
        metrics['sensitivity'] = sensitivity
        metrics['specificity'] = specificity
    
    return metrics


# -------------------- [9. 스플릿 인덱스 저장/로드] --------------------
def save_split_indices(out_dir, tag, split_indices):
    os.makedirs(out_dir, exist_ok=True)
    train_idx, val_idx, test_idx = split_indices
    np.save(os.path.join(out_dir, f"train_idx_{tag}.npy"), np.array(train_idx, dtype=np.int64))
    np.save(os.path.join(out_dir, f"val_idx_{tag}.npy"),   np.array(val_idx,   dtype=np.int64))
    np.save(os.path.join(out_dir, f"test_idx_{tag}.npy"),  np.array(test_idx,  dtype=np.int64))

def load_split_indices(path, tag):
    # 인덱스 로드는 get_split_dataset에서 수행하므로 이 함수는 사용하지 않습니다.
    pass


# -------------------- [10. 구성/스케일러 저장·로드] --------------------
def save_config(cfg_path, config: dict):
    with open(cfg_path, 'w') as f:
        json.dump(config, f, indent=2)

def load_config(cfg_path):
    with open(cfg_path, 'r') as f:
        return json.load(f)

def save_scaler(path, scaler):
    with open(path, 'wb') as f:
        pickle.dump(scaler, f)

def load_scaler(path):
    with open(path, 'rb') as f:
        return pickle.load(f)

In [8]:
# -------------------- [11. 주 실행 함수] --------------------
def run_experiment(label_path, embed_paths, fp_types, split_mode="random_scaffold", seeds=None, split_dir="./splits", reuse_split=True):
    if seeds is None:
        seeds = [42, 100, 200, 300, 400, 500, 600, 700, 800, 900]
    
    results = []
    
    for seed in seeds:
        set_seed(seed)
        print("\n" + "="*80)
        print(f"--- [STARTING NEW RUN] Data: {os.path.basename(label_path)}, Split Mode: {split_mode}, Seed: {seed} ---")
        print("="*80)
        
        # (C) Dataset & dims
        try:
            # 데이터셋 인스턴스화 (피처 추출)
            dataset = ScageConcatDataset(label_path, embed_paths, fp_types=fp_types)
        except ValueError as e:
            print(f"[FATAL] Dataset initialization failed: {e}")
            continue

        expected_dims = dataset.expected_dims
        mod_dims = OrderedDict((t, expected_dims[t]) for t in fp_types)
        num_labels = dataset.num_labels
        label_cols_names = dataset.label_cols_names
        
        tag = f"{os.path.basename(label_path).split('.')[0]}_{split_mode}_seed{seed}"

        # (D) Split 로드/생성 + RDKit normalize
        split_data = None
        if reuse_split:
             split_data = load_split_dataset(dataset, split_dir, tag)

        if split_data is None:
            # 인덱스가 없으면 새로 분할하고 저장
            train_ds, val_ds, test_ds, scaler, rd_slice, split_indices = split_then_normalize(
                dataset, split_mode=split_mode, train_ratio=0.8, val_ratio=0.1, seed=seed
            )
            save_split_indices(split_dir, tag, split_indices)
            print(f"[Save] New split indices created and saved to {split_dir} (tag={tag})")
        else:
            # 인덱스가 있으면 로드하여 사용
            train_ds, val_ds, test_ds, scaler, rd_slice, split_indices = split_data
        
        # (F) DataLoader (재현성 확보를 위해 num_workers=0 명시)
        train_loader = data.DataLoader(train_ds, batch_size=128, shuffle=True, num_workers=0) 
        val_loader   = data.DataLoader(val_ds,   batch_size=128, shuffle=False, num_workers=0)
        test_loader  = data.DataLoader(test_ds,  batch_size=128, shuffle=False, num_workers=0)
        
        # --- 데이터셋 정보 준비 ---
        y_train = train_ds.tensors[1].cpu().numpy()
        y_val = val_ds.tensors[1].cpu().numpy()
        y_test = test_ds.tensors[1].cpu().numpy()

        if num_labels == 1:
            y_train = y_train.squeeze()
            y_val = y_val.squeeze()
            y_test = y_test.squeeze()
        
        # --- 클래스 분포 분석/출력 (기존 로직 유지) ---
        def plot_class_distribution(train_labels, val_labels, test_labels, seed, num_labels, label_cols_names):
             os.makedirs("./artifacts", exist_ok=True)
             fig, ax = plt.subplots(figsize=(8, 6))
             
             if num_labels == 1:
                 labels_map = {0: 'Negative', 1: 'Positive'}
                 train_counts = pd.Series(train_labels).map(labels_map).value_counts(normalize=True).sort_index()
                 val_counts = pd.Series(val_labels).map(labels_map).value_counts(normalize=True).sort_index()
                 test_counts = pd.Series(test_labels).map(labels_map).value_counts(normalize=True).sort_index()
                 counts_df = pd.DataFrame({'Train': train_counts, 'Validation': val_counts, 'Test': test_counts}).fillna(0)
                 counts_df.T.plot(kind='bar', stacked=False, ax=ax, rot=0)
                 ax.legend(title=label_cols_names[0])
             else:
                 train_props = np.mean(train_labels, axis=0)
                 val_props = np.mean(val_labels, axis=0)
                 test_props = np.mean(test_labels, axis=0)
                 props_df = pd.DataFrame({
                     'Train': train_props, 'Validation': val_props, 'Test': test_props
                 }, index=label_cols_names)
                 props_df.plot(kind='bar', ax=ax, rot=90)
                 ax.set_ylabel('Positive Class Proportion')
                 ax.set_xlabel('Task')
                 ax.legend(title='Dataset Split')
                 fig.subplots_adjust(bottom=0.3)
 
             ax.set_title(f'Class Distribution by Split (Seed: {seed})')
             ax.set_ylabel('Proportion')
             plt.tight_layout()
             plt.savefig(f'./artifacts/class_distribution_{tag}.png')
             plt.close()

        print(f"\n--- [Info] Class distribution for seed {seed} ---")
        if num_labels == 1:
            print(f"Train: Positive={np.mean(y_train):.2f}, Negative={1-np.mean(y_train):.2f}")
            print(f"Validation: Positive={np.mean(y_val):.2f}, Negative={1-np.mean(y_val):.2f}")
            print(f"Test: Positive={np.mean(y_test):.2f}, Negative={1-np.mean(y_test):.2f}")
        else:
            print("Multi-label data. Displaying positive ratios per task:")
            for i, task in enumerate(label_cols_names):
                print(f" - {task}: Train={np.mean(y_train[:, i]):.2f}, Val={np.mean(y_val[:, i]):.2f}, Test={np.mean(y_test[:, i]):.2f}")
        
        plot_class_distribution(y_train, y_val, y_test, seed, num_labels, label_cols_names)
        
        # (G) Model
        set_seed(seed)
        model = MultiModalGMLPFromFlat(
            mod_dims=mod_dims, num_labels=num_labels, d_model=512, d_ffn=1048, depth=4, dropout=0.2, use_gated_pool=True
        ).to(device)

        # (H) pos_weight (옵션)
        pos_weight = None
        try:
            # y_train이 1D일 경우 (N,) 형태, 2D일 경우 (N, M) 형태를 유지합니다.
            if y_train.ndim == 1:
                n_pos = np.array([np.sum(y_train)])
                n_neg = np.array([np.sum(1 - y_train)])
            else:
                n_pos = np.sum(y_train, axis=0)
                n_neg = np.sum(1 - y_train, axis=0)
                
            ratio = np.divide(n_neg, n_pos, out=np.zeros_like(n_neg, dtype=float), where=n_pos!=0)
            pos_weight = torch.tensor(ratio, dtype=torch.float32, device=device)
            print(f"[Info] Using pos_weight={pos_weight.tolist()}")
        except Exception as e:
            print(f"[Warn] pos_weight auto-calc skipped: {e}")
        
        optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)
        loss_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

        # (I) Train
        print(f"--- [STEP 1] Training gMLP on {os.path.basename(label_path).split('.')[0]} ---")
        model = train_model(model, optimizer, train_loader, val_loader, loss_fn, num_epochs=50, patience=10)
        
        # (J) Test eval
        print("\n--- [STEP 2] Evaluating on Test Split ---")
        metrics = eval_model(model, test_loader, num_labels)
        
        # (J-1) 각 시드별 결과 저장
        metrics['seed'] = seed
        results.append(metrics)
        
        # (K) 아티팩트 저장
        model_path  = f"./artifacts/gmlp_best_model_{tag}.pth"
        cfg_path    = f"./artifacts/feature_config_{tag}.json"
        scaler_path = f"./artifacts/rdkit_scaler_{tag}.pkl"
        torch.save(model.state_dict(), model_path)
        cfg = {
            "fp_types": fp_types,
            "mod_dims": {k: int(v) for k, v in mod_dims.items()},
            "num_labels": num_labels,
            "rd_slice": [rd_slice[0], rd_slice[1]] if rd_slice[0] is not None else None,
            "split_mode": split_mode,
            "seed": seed,
            "model_hparams": {"d_model": 512, "d_ffn": 1048, "depth": 4, "dropout": 0.2, "use_gated_pool": True},
            "train_params": {"lr": 1e-4, "weight_decay": 1e-5, "num_epochs": 50, "patience": 10},
            "class_balance": {"pos_weight": pos_weight.tolist() if pos_weight is not None else None}
        }
        save_config(cfg_path, cfg)
        if scaler is not None:
            save_scaler(scaler_path, scaler)

    # (M) 최종 결과 요약 및 출력
    print("\n" + "="*80)
    print(f"--- [FINAL SUMMARY] Average Performance for {os.path.basename(label_path).split('.')[0]} ---")
    print("="*80)
    
    results_df = pd.DataFrame(results).set_index('seed')
    summary = results_df.mean(numeric_only=True)
    std_dev = results_df.std(numeric_only=True)
    
    for metric in summary.index:
        print(f"{metric:<18}| Mean: {summary[metric]:.4f} | Std Dev: {std_dev[metric]:.4f}")
    
    results_df.to_csv(f"./artifacts/multi_seed_results_{tag}.csv")
    print(f"\n[Save] Detailed results saved to ./artifacts/multi_seed_results_{tag}.csv")
    print("="*80)


if __name__ == "__main__":
    # --- 실험 실행 예시: CLINTOX (Multi/Single-label) with MolE ---
    
    # 🔥 MolE 임베딩의 경로는 사용자 환경에 맞게 설정해야 합니다.
    MOLE_CLINTOX_PATH = '/home/minji/mole_public/MolE_embed_clintox.csv' 
    
    embed_paths_clintox = {
        'scage_graph': '/home/minji/scage/CLINTOX/clintox_graph.csv',
        'scage_atom': '/home/minji/scage/CLINTOX/clintox_atom.csv',
        'mole_clintox': MOLE_CLINTOX_PATH 
    }
    
    fp_types_clintox = ['rdkit', 'ecfp', 'scage_graph', 'mole_clintox']
    
    print("="*80)
    print("--- Running CLINTOX (Multi/Single-label) Experiment with MolE (Split Reuse Enabled) ---")
    
    run_experiment(
        label_path='/home/minji/scage/CLINTOX/clintox_label.csv',
        embed_paths=embed_paths_clintox,
        fp_types=fp_types_clintox,
        split_mode="scaffold",
        reuse_split=True, # 이전에 생성된 인덱스가 있으면 로드하여 재사용
        split_dir="./splits" # 인덱스 파일이 저장된/저장될 디렉토리
    )

--- Running CLINTOX (Multi/Single-label) Experiment with MolE (Split Reuse Enabled) ---

--- [STARTING NEW RUN] Data: clintox_label.csv, Split Mode: scaffold, Seed: 42 ---


  descs = np.array(descs, dtype=np.float32)
Generating Features: 100%|████████████████████████████████████████████████████████████████████████████████| 1477/1477 [00:08<00:00, 166.85it/s]


[Info] Loaded existing split indices for tag: clintox_label_scaffold_seed42

--- [Info] Class distribution for seed 42 ---
Multi-label data. Displaying positive ratios per task:
 - CT_TOX: Train=0.08, Val=0.11, Test=0.03
 - FDA_APPROVED: Train=0.94, Val=0.90, Test=0.98
[Info] Using pos_weight=[12.133333206176758, 0.06871609389781952]
--- [STEP 1] Training gMLP on clintox_label ---

--- [STEP 2] Evaluating on Test Split ---

--- [STARTING NEW RUN] Data: clintox_label.csv, Split Mode: scaffold, Seed: 100 ---


  descs = np.array(descs, dtype=np.float32)
Generating Features: 100%|████████████████████████████████████████████████████████████████████████████████| 1477/1477 [00:12<00:00, 117.69it/s]


[Info] Loaded existing split indices for tag: clintox_label_scaffold_seed100

--- [Info] Class distribution for seed 100 ---
Multi-label data. Displaying positive ratios per task:
 - CT_TOX: Train=0.08, Val=0.11, Test=0.03
 - FDA_APPROVED: Train=0.94, Val=0.90, Test=0.98
[Info] Using pos_weight=[12.133333206176758, 0.06871609389781952]
--- [STEP 1] Training gMLP on clintox_label ---

--- [STEP 2] Evaluating on Test Split ---

--- [STARTING NEW RUN] Data: clintox_label.csv, Split Mode: scaffold, Seed: 200 ---


  descs = np.array(descs, dtype=np.float32)
Generating Features: 100%|████████████████████████████████████████████████████████████████████████████████| 1477/1477 [00:12<00:00, 119.03it/s]


[Info] Loaded existing split indices for tag: clintox_label_scaffold_seed200

--- [Info] Class distribution for seed 200 ---
Multi-label data. Displaying positive ratios per task:
 - CT_TOX: Train=0.08, Val=0.11, Test=0.03
 - FDA_APPROVED: Train=0.94, Val=0.90, Test=0.98
[Info] Using pos_weight=[12.133333206176758, 0.06871609389781952]
--- [STEP 1] Training gMLP on clintox_label ---

--- [STEP 2] Evaluating on Test Split ---

--- [STARTING NEW RUN] Data: clintox_label.csv, Split Mode: scaffold, Seed: 300 ---


  descs = np.array(descs, dtype=np.float32)
Generating Features: 100%|█████████████████████████████████████████████████████████████████████████████████| 1477/1477 [00:16<00:00, 89.91it/s]


[Info] Loaded existing split indices for tag: clintox_label_scaffold_seed300

--- [Info] Class distribution for seed 300 ---
Multi-label data. Displaying positive ratios per task:
 - CT_TOX: Train=0.08, Val=0.11, Test=0.03
 - FDA_APPROVED: Train=0.94, Val=0.90, Test=0.98
[Info] Using pos_weight=[12.133333206176758, 0.06871609389781952]
--- [STEP 1] Training gMLP on clintox_label ---

--- [STEP 2] Evaluating on Test Split ---

--- [STARTING NEW RUN] Data: clintox_label.csv, Split Mode: scaffold, Seed: 400 ---


  descs = np.array(descs, dtype=np.float32)
Generating Features: 100%|█████████████████████████████████████████████████████████████████████████████████| 1477/1477 [00:15<00:00, 94.42it/s]


[Info] Loaded existing split indices for tag: clintox_label_scaffold_seed400

--- [Info] Class distribution for seed 400 ---
Multi-label data. Displaying positive ratios per task:
 - CT_TOX: Train=0.08, Val=0.11, Test=0.03
 - FDA_APPROVED: Train=0.94, Val=0.90, Test=0.98
[Info] Using pos_weight=[12.133333206176758, 0.06871609389781952]
--- [STEP 1] Training gMLP on clintox_label ---

--- [STEP 2] Evaluating on Test Split ---

--- [STARTING NEW RUN] Data: clintox_label.csv, Split Mode: scaffold, Seed: 500 ---


  descs = np.array(descs, dtype=np.float32)
Generating Features: 100%|████████████████████████████████████████████████████████████████████████████████| 1477/1477 [00:13<00:00, 111.17it/s]


[Info] Loaded existing split indices for tag: clintox_label_scaffold_seed500

--- [Info] Class distribution for seed 500 ---
Multi-label data. Displaying positive ratios per task:
 - CT_TOX: Train=0.08, Val=0.11, Test=0.03
 - FDA_APPROVED: Train=0.94, Val=0.90, Test=0.98
[Info] Using pos_weight=[12.133333206176758, 0.06871609389781952]
--- [STEP 1] Training gMLP on clintox_label ---

--- [STEP 2] Evaluating on Test Split ---

--- [STARTING NEW RUN] Data: clintox_label.csv, Split Mode: scaffold, Seed: 600 ---


  descs = np.array(descs, dtype=np.float32)
Generating Features: 100%|█████████████████████████████████████████████████████████████████████████████████| 1477/1477 [00:14<00:00, 99.64it/s]


[Info] Loaded existing split indices for tag: clintox_label_scaffold_seed600

--- [Info] Class distribution for seed 600 ---
Multi-label data. Displaying positive ratios per task:
 - CT_TOX: Train=0.08, Val=0.11, Test=0.03
 - FDA_APPROVED: Train=0.94, Val=0.90, Test=0.98
[Info] Using pos_weight=[12.133333206176758, 0.06871609389781952]
--- [STEP 1] Training gMLP on clintox_label ---

--- [STEP 2] Evaluating on Test Split ---

--- [STARTING NEW RUN] Data: clintox_label.csv, Split Mode: scaffold, Seed: 700 ---


  descs = np.array(descs, dtype=np.float32)
Generating Features: 100%|█████████████████████████████████████████████████████████████████████████████████| 1477/1477 [00:16<00:00, 89.92it/s]


[Info] Loaded existing split indices for tag: clintox_label_scaffold_seed700

--- [Info] Class distribution for seed 700 ---
Multi-label data. Displaying positive ratios per task:
 - CT_TOX: Train=0.08, Val=0.11, Test=0.03
 - FDA_APPROVED: Train=0.94, Val=0.90, Test=0.98
[Info] Using pos_weight=[12.133333206176758, 0.06871609389781952]
--- [STEP 1] Training gMLP on clintox_label ---

--- [STEP 2] Evaluating on Test Split ---

--- [STARTING NEW RUN] Data: clintox_label.csv, Split Mode: scaffold, Seed: 800 ---


  descs = np.array(descs, dtype=np.float32)
Generating Features: 100%|████████████████████████████████████████████████████████████████████████████████| 1477/1477 [00:14<00:00, 102.42it/s]


[Info] Loaded existing split indices for tag: clintox_label_scaffold_seed800

--- [Info] Class distribution for seed 800 ---
Multi-label data. Displaying positive ratios per task:
 - CT_TOX: Train=0.08, Val=0.11, Test=0.03
 - FDA_APPROVED: Train=0.94, Val=0.90, Test=0.98
[Info] Using pos_weight=[12.133333206176758, 0.06871609389781952]
--- [STEP 1] Training gMLP on clintox_label ---

--- [STEP 2] Evaluating on Test Split ---

--- [STARTING NEW RUN] Data: clintox_label.csv, Split Mode: scaffold, Seed: 900 ---


  descs = np.array(descs, dtype=np.float32)
Generating Features: 100%|████████████████████████████████████████████████████████████████████████████████| 1477/1477 [00:14<00:00, 102.86it/s]


[Info] Loaded existing split indices for tag: clintox_label_scaffold_seed900

--- [Info] Class distribution for seed 900 ---
Multi-label data. Displaying positive ratios per task:
 - CT_TOX: Train=0.08, Val=0.11, Test=0.03
 - FDA_APPROVED: Train=0.94, Val=0.90, Test=0.98
[Info] Using pos_weight=[12.133333206176758, 0.06871609389781952]
--- [STEP 1] Training gMLP on clintox_label ---

--- [STEP 2] Evaluating on Test Split ---

--- [FINAL SUMMARY] Average Performance for clintox_label ---
accuracy          | Mean: 0.5056 | Std Dev: 0.2000
precision         | Mean: 0.6749 | Std Dev: 0.1689
recall            | Mean: 0.6174 | Std Dev: 0.1595
f1                | Mean: 0.6391 | Std Dev: 0.1562
mcc               | Mean: 0.2980 | Std Dev: 0.3189
roc_auc           | Mean: 0.6962 | Std Dev: 0.2161
sensitivity       | Mean: 0.6174 | Std Dev: 0.1595

[Save] Detailed results saved to ./artifacts/multi_seed_results_clintox_label_scaffold_seed900.csv
