In [1]:
import os, sys, json, pickle, random
from collections import OrderedDict
from copy import deepcopy

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
from tqdm import tqdm

from sklearn.preprocessing import StandardScaler
from sklearn.metrics import matthews_corrcoef, accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, roc_auc_score

from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem, MACCSkeys, rdMolDescriptors, Descriptors
from rdkit.ML.Descriptors import MoleculeDescriptors
from rdkit.Chem.Scaffolds import MurckoScaffold

import optuna
from optuna.trial import Trial

# g-mlp 모듈 경로
GMLP_DIR = "/home/minji/g-mlp"
if GMLP_DIR not in sys.path:
    sys.path.append(GMLP_DIR)
from g_mlp import gMLP

# -------------------- [0. 공통 유틸/환경] --------------------
def set_seed(seed=700):
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    if torch.cuda.is_available():
        torch.use_deterministic_algorithms(True)
        os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
    else:
        torch.use_deterministic_algorithms(True, warn_only=True)
set_seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [2]:
# -------------------- [1. 피처 생성 유틸] --------------------
def to_numpy_bitvect(bitvect, n_bits=None, drop_first=False):
    if n_bits is None:
        n_bits = bitvect.GetNumBits()
    arr = np.zeros((n_bits,), dtype=np.int8)
    DataStructs.ConvertToNumpyArray(bitvect, arr)
    if drop_first:
        arr = arr[1:]
    return arr.astype(np.float32)

def get_ecfp(mol, radius=2, nbits=1024):
    return to_numpy_bitvect(AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=nbits), n_bits=nbits)
def get_maccs(mol):
    bv = MACCSkeys.GenMACCSKeys(mol)
    return to_numpy_bitvect(bv, n_bits=bv.GetNumBits(), drop_first=True)
def get_avalon(mol, nbits=512):
    from rdkit.Avalon import pyAvalonTools
    return to_numpy_bitvect(pyAvalonTools.GetAvalonFP(mol, nbits), n_bits=nbits)
def get_topological_torsion(mol, nbits=1024):
    bv = rdMolDescriptors.GetHashedTopologicalTorsionFingerprintAsBitVect(mol, nBits=nbits)
    return to_numpy_bitvect(bv, n_bits=nbits)
def get_rdkit_desc(mol):
    calc = MoleculeDescriptors.MolecularDescriptorCalculator([d[0] for d in Descriptors._descList])
    try:
        descs = calc.CalcDescriptors(mol)
        descs = np.array(descs, dtype=np.float32)
        descs = np.nan_to_num(descs, nan=0.0, posinf=0.0, neginf=0.0)
    except Exception:
        descs = np.zeros(len(Descriptors._descList), dtype=np.float32)
    return descs
def get_rdkit_descriptor_length():
    return len(Descriptors._descList)

# -------------------- [2. SCAGE 임베딩 로드 & 차원 감지] --------------------
def load_scage_embeddings(scage_paths: dict):
    scage_embed_dicts, scage_dims = {}, {}
    for name, path in scage_paths.items():
        try:
            df = pd.read_csv(path)
        except Exception:
            scage_embed_dicts[name] = {}
            scage_dims[name] = 0
            continue
        def canon(s):
            m = Chem.MolFromSmiles(s)
            return Chem.MolToSmiles(m, canonical=True) if m else None
        df['smiles'] = df['smiles'].apply(canon)
        df = df.dropna(subset=['smiles']).reset_index(drop=True)
        embed_cols = [c for c in df.columns if c != 'smiles']
        dim = len(embed_cols)
        scage_dims[name] = dim
        scage_embed_dicts[name] = {
            row['smiles']: row[embed_cols].to_numpy(dtype=np.float32, copy=False)
            for _, row in df.iterrows()
        }
    return scage_embed_dicts, scage_dims

In [3]:
# -------------------- [3. 기대 차원 계산 + 안전 결합] --------------------
def compute_expected_dims(fp_types, scage_dims: dict):
    expected = OrderedDict()
    for t in fp_types:
        if t == 'ecfp': expected[t] = 1024
        elif t == 'avalon': expected[t] = 512
        elif t == 'maccs': expected[t] = 166
        elif t == 'tt': expected[t] = 1024
        elif t == 'rdkit': expected[t] = get_rdkit_descriptor_length()
        elif 'scage' in t: expected[t] = scage_dims.get(t, 512) if scage_dims.get(t, 0) > 0 else 512
        else: raise ValueError(f"Unknown fp_type: {t}")
    return expected
def safe_fit_to_dim(vec: np.ndarray, target_dim: int) -> np.ndarray:
    if vec is None: return np.zeros(target_dim, dtype=np.float32)
    vec = vec.astype(np.float32, copy=False)
    if np.any(np.isnan(vec)) or np.any(np.isinf(vec)):
        vec = np.nan_to_num(vec, nan=0.0, posinf=0.0, neginf=0.0)
    cur = vec.shape[0]
    if cur == target_dim: return vec
    elif cur < target_dim:
        pad = np.zeros(target_dim - cur, dtype=np.float32)
        return np.concatenate([vec, pad], axis=0)
    else: return vec[:target_dim]
def make_feature_vector(mol, smiles, fp_types, expected_dims, scage_embed_dicts):
    chunks = []
    for t in fp_types:
        dim = expected_dims[t]
        try:
            if t == 'ecfp': vec = get_ecfp(mol, radius=2, nbits=dim)
            elif t == 'avalon': vec = get_avalon(mol, nbits=dim)
            elif t == 'maccs': vec = get_maccs(mol)
            elif t == 'tt': vec = get_topological_torsion(mol, nbits=dim)
            elif t == 'rdkit': vec = get_rdkit_desc(mol)
            elif 'scage' in t: vec = scage_embed_dicts.get(t, {}).get(smiles, None)
            else: vec = None
        except Exception: vec = None
        chunks.append(safe_fit_to_dim(vec, dim))
    feat = np.concatenate(chunks, axis=0)
    feat = np.nan_to_num(feat, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)
    return feat


In [4]:
# -------------------- [4. Dataset] --------------------
class ScageConcatDataset(data.Dataset):
    def __init__(self, label_path, scage_paths: dict, fp_types, expected_dims=None):
        df = pd.read_csv(label_path)
        df = df[['smiles', 'p_np']].rename(columns={'smiles': 'smiles', 'p_np': 'label'})
        df['label'] = df['label'].replace({'BBB-': 0, 'BBB+': 1})
        df = df.drop_duplicates(subset='smiles').reset_index(drop=True)
        self.scage_embed_dicts, scage_dims = load_scage_embeddings(scage_paths)
        if expected_dims is None: expected_dims = compute_expected_dims(fp_types, scage_dims)
        self.expected_dims = expected_dims
        self.fp_types = list(fp_types)
        def canon(s):
            m = Chem.MolFromSmiles(s)
            return Chem.MolToSmiles(m, canonical=True) if m else None
        df['smiles'] = df['smiles'].apply(canon)
        df = df.dropna(subset=['smiles']).reset_index(drop=True)
        features, labels, failed = [], [], []
        for _, row in tqdm(df.iterrows(), total=len(df), desc="Generating Features"):
            smi = row['smiles']
            mol = Chem.MolFromSmiles(smi)
            if mol is None: failed.append(smi); continue
            feat = make_feature_vector(mol, smi, self.fp_types, self.expected_dims, self.scage_embed_dicts)
            if feat is None or feat.ndim != 1: failed.append(smi); continue
            features.append(feat)
            labels.append(row['label'])
        self.features = torch.tensor(np.stack(features, axis=0), dtype=torch.float32)
        self.labels = torch.tensor(labels, dtype=torch.float32)
        self.df = df[~df['smiles'].isin(failed)].reset_index(drop=True)
    def __len__(self): return len(self.labels)
    def __getitem__(self, idx): return self.features[idx], self.labels[idx]

In [5]:
# -------------------- [5. 스플릿 + RDKit 정규화 (토글)] --------------------
def split_then_normalize(dataset: ScageConcatDataset, split_mode: str = "scaffold", train_ratio: float = 0.8, val_ratio: float = 0.1, seed: int = 700):
    set_seed(seed)
    df = dataset.df.copy()
    def get_scaffold(smi):
        m = Chem.MolFromSmiles(smi)
        return Chem.MolToSmiles(MurckoScaffold.GetScaffoldForMol(m)) if m else None
    df['scaffold'] = df['smiles'].apply(get_scaffold)
    groups = list(df.groupby('scaffold').groups.values())
    if split_mode == "scaffold": groups = sorted(groups, key=lambda g: len(g), reverse=True)
    elif split_mode == "random_scaffold":
        rnd = random.Random(seed)
        rnd.shuffle(groups)
    else: raise ValueError("split_mode must be 'scaffold' or 'random_scaffold'")
    n = len(df)
    train_cap = int(round(train_ratio * n))
    val_cap = int(round(val_ratio * n))
    train_idx, val_idx, test_idx = [], [], []
    for g in groups:
        g = list(g)
        if len(train_idx) + len(g) <= train_cap: train_idx += g
        elif len(val_idx) + len(g) <= val_cap: val_idx += g
        else: test_idx += g
    def pick(idxs): return dataset.features[idxs], dataset.labels[idxs]
    X_train, y_train = pick(train_idx)
    X_val, y_val = pick(val_idx)
    X_test, y_test = pick(test_idx)
    rd_start, rd_end = None, None
    offset = 0
    for t in dataset.fp_types:
        dim = dataset.expected_dims[t]
        if t == 'rdkit':
            rd_start, rd_end = offset, offset + dim
            break
        offset += dim
    scaler = None
    if rd_start is not None:
        scaler = StandardScaler().fit(X_train[:, rd_start:rd_end])
        X_train[:, rd_start:rd_end] = torch.tensor(scaler.transform(X_train[:, rd_start:rd_end]), dtype=torch.float32)
        X_val[:, rd_start:rd_end] = torch.tensor(scaler.transform(X_val[:, rd_start:rd_end]), dtype=torch.float32)
        X_test[:, rd_start:rd_end] = torch.tensor(scaler.transform(X_test[:, rd_start:rd_end]), dtype=torch.float32)
    return (data.TensorDataset(X_train, y_train), data.TensorDataset(X_val, y_val), data.TensorDataset(X_test, y_test), scaler, (rd_start, rd_end), (train_idx, val_idx, test_idx))

# -------------------- [6. 스플릿 인덱스 저장/로드] --------------------
def save_split_indices(out_dir, tag, split_indices):
    os.makedirs(out_dir, exist_ok=True)
    train_idx, val_idx, test_idx = split_indices
    np.save(os.path.join(out_dir, f"train_idx_{tag}.npy"), np.array(train_idx, dtype=np.int64))
    np.save(os.path.join(out_dir, f"val_idx_{tag}.npy"), np.array(val_idx, dtype=np.int64))
    np.save(os.path.join(out_dir, f"test_idx_{tag}.npy"), np.array(test_idx, dtype=np.int64))
def load_split_indices(path, tag):
    train_idx = np.load(os.path.join(path, f"train_idx_{tag}.npy"))
    val_idx = np.load(os.path.join(path, f"val_idx_{tag}.npy"))
    test_idx = np.load(os.path.join(path, f"test_idx_{tag}.npy"))
    return train_idx, val_idx, test_idx
# -------------------- [7. 구성/스케일러 저장·로드] --------------------
def save_config(cfg_path, config: dict):
    with open(cfg_path, 'w') as f:
        json.dump(config, f, indent=2)
def load_config(cfg_path):
    with open(cfg_path, 'r') as f:
        return json.load(f)
def save_scaler(path, scaler):
    with open(path, 'wb') as f:
        pickle.dump(scaler, f)
def load_scaler(path):
    with open(path, 'rb') as f:
        return pickle.load(f)


In [6]:
# -------------------- [8. 모델] --------------------
class MultiModalGMLPFromFlat(nn.Module):
    def __init__(self, mod_dims: OrderedDict, d_model=512, d_ffn=1024, depth=4, dropout=0.2, use_gated_pool=True):
        super().__init__()
        self.mod_names = list(mod_dims.keys())
        self.mod_dims = [mod_dims[n] for n in self.mod_names]
        self.in_features = sum(self.mod_dims)
        self.seq_len = len(self.mod_names)
        self.use_gated_pool = use_gated_pool
        self.proj = nn.ModuleDict({name: nn.Linear(in_dim, d_model) for name, in_dim in zip(self.mod_names, self.mod_dims)})
        self.backbone = gMLP(seq_len=self.seq_len, d_model=d_model, d_ffn=d_ffn, num_layers=depth)
        self.norm = nn.LayerNorm(d_model)
        if use_gated_pool:
            self.alpha = nn.Parameter(torch.zeros(self.seq_len))
        self.head = nn.Linear(d_model, 1)
        self.drop = nn.Dropout(dropout)
    def forward(self, x):
        chunks = torch.split(x, self.mod_dims, dim=1)
        tokens = [self.proj[name](chunk) for name, chunk in zip(self.mod_names, chunks)]
        X = torch.stack(tokens, dim=1)
        X = self.backbone(X)
        if self.use_gated_pool:
            w = torch.softmax(self.alpha, dim=0)
            Xp = (X * w.view(1, -1, 1)).sum(dim=1)
        else: Xp = X.mean(dim=1)
        Xp = self.drop(self.norm(Xp))
        logits = self.head(Xp).squeeze(-1)
        return logits
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

In [7]:
# -------------------- [9. 학습/평가 루틴] --------------------
def train_model(model, optimizer, train_loader, val_loader, loss_fn, num_epochs=50, patience=10):
    best_val = float('inf'); best_state = None; bad = 0
    for epoch in range(num_epochs):
        model.train()
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            logits = model(x)
            loss = loss_fn(logits, y)
            loss.backward()
            optimizer.step()
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for x, y in val_loader:
                x, y = x.to(device), y.to(device)
                val_loss += loss_fn(model(x), y).item()
        val_loss /= len(val_loader)
        if val_loss < best_val:
            best_val = val_loss
            best_state = deepcopy(model.state_dict())
            bad = 0
        else:
            bad += 1
            if bad >= patience: break
    if best_state: model.load_state_dict(best_state)
    return model

def eval_model(model, loader):
    model.eval()
    y_true, y_prob, y_pred = [], [], []
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            logits = model(x)
            probs = torch.sigmoid(logits).cpu().numpy()
            pred = (probs > 0.5).astype(int)
            y_prob.extend(probs)
            y_pred.extend(pred)
            y_true.extend(y.numpy())
    cm = confusion_matrix(y_true, y_pred)
    if cm.size == 4:
        tn, fp, fn, tp = cm.ravel()
        specificity = tn / (tn + fp) if (tn + fp) > 0 else 0.0
        sensitivity = recall_score(y_true, y_pred, zero_division=0)
    else: specificity, sensitivity = 0.0, 0.0
    return {
        'accuracy': round(accuracy_score(y_true, y_pred), 3),
        'precision': round(precision_score(y_true, y_pred, zero_division=0), 3),
        'recall': round(sensitivity, 3),
        'f1': round(f1_score(y_true, y_pred, zero_division=0), 3),
        'roc_auc': round(roc_auc_score(y_true, y_prob) if len(set(y_true)) > 1 else 0.0, 3),
        'mcc': round(matthews_corrcoef(y_true, y_pred), 3),
        'sensitivity': sensitivity,
        'specificity': specificity,
    }

In [8]:
# -------------------- [10. Optuna 학습/평가 루틴] --------------------
def train_model_mcc_optuna(model, optimizer, train_loader, val_loader, loss_fn, num_epochs=50, patience=10):
    best_val_mcc = -1.0
    best_state = None
    bad_epochs = 0
    for epoch in range(num_epochs):
        model.train()
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            logits = model(x)
            loss = loss_fn(logits, y)
            loss.backward()
            optimizer.step()
        model.eval()
        y_true, y_pred = [], []
        with torch.no_grad():
            for x, y in val_loader:
                x, y = x.to(device), y.to(device)
                logits = model(x)
                probs = torch.sigmoid(logits).cpu().numpy()
                pred = (probs > 0.5).astype(int)
                y_true.extend(y.cpu().numpy())
                y_pred.extend(pred)
        val_mcc = matthews_corrcoef(y_true, y_pred) if len(set(y_true)) > 1 else 0.0
        if val_mcc > best_val_mcc:
            best_val_mcc = val_mcc
            best_state = deepcopy(model.state_dict())
            bad_epochs = 0
        else:
            bad_epochs += 1
            if bad_epochs >= patience: break
    return best_val_mcc, best_state

def objective(trial: Trial, train_ds, val_ds, mod_dims, pos_weight):
    set_seed(42)
    d_model = trial.suggest_categorical('d_model', [256, 512, 1024])
    d_ffn = trial.suggest_int('d_ffn', d_model * 2, d_model * 4)
    depth = trial.suggest_int('depth', 2, 6)
    dropout = trial.suggest_float('dropout', 0.1, 0.4)
    lr = trial.suggest_float('lr', 1e-5, 1e-3, log=True)
    batch_size = trial.suggest_categorical('batch_size', [64, 128, 256])
    model = MultiModalGMLPFromFlat(
        mod_dims=mod_dims, d_model=d_model, d_ffn=d_ffn, depth=depth, dropout=dropout, use_gated_pool=True
    ).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weight) if pos_weight is not None else nn.BCEWithLogitsLoss()
    g = torch.Generator(); g.manual_seed(42)
    train_loader = data.DataLoader(train_ds, batch_size=batch_size, shuffle=True, pin_memory=True, generator=g)
    val_loader = data.DataLoader(val_ds, batch_size=batch_size, shuffle=False, pin_memory=True)
    val_mcc, best_state = train_model_mcc_optuna(model, optimizer, train_loader, val_loader, loss_fn)
    trial.set_user_attr('best_model_state', best_state)
    return val_mcc

In [21]:
if __name__ == "__main__":
    # (A) 실험 스위치
    split_mode = "random_scaffold"
    seed = 42
    # (B) 경로/설정
    label_path = '/home/minji/Downloads/bbbp.csv'
    scage_paths = {'scage1': '/home/minji/scage/BBB/bench_embed.csv', 'scage2': '/home/minji/scage/BBB/bench_atom_embed.csv'}
    fp_types = ['ecfp', 'avalon', 'rdkit', 'maccs', 'tt', 'scage1', 'scage2']
    # (C) Dataset & dims
    dataset = ScageConcatDataset(label_path, scage_paths, fp_types=fp_types)
    expected_dims = dataset.expected_dims
    mod_dims = OrderedDict((t, expected_dims[t]) for t in fp_types)
    # (D) Split + RDKit normalize (8:1:1)
    train_ds, val_ds, test_ds, scaler, (rd_start, rd_end), split_indices = split_then_normalize(
        dataset, split_mode=split_mode, train_ratio=0.8, val_ratio=0.1, seed=seed
    )
    # (E) 스플릿 저장
    split_dir = "./splits"
    tag = f"{split_mode}_seed{seed}"
    save_split_indices(split_dir, tag, split_indices)
    print(f"[Save] split indices -> {split_dir} (tag={tag})")
    # (F) DataLoader
    train_loader = data.DataLoader(train_ds, batch_size=128, shuffle=True)
    val_loader = data.DataLoader(val_ds, batch_size=128, shuffle=False)
    test_loader = data.DataLoader(test_ds, batch_size=128, shuffle=False)
    # (G) Model
    set_seed(seed)
    model = MultiModalGMLPFromFlat(
        mod_dims=mod_dims, d_model=512, d_ffn=1048, depth=4, dropout=0.2, use_gated_pool=True
    ).to(device)
    # (H) pos_weight
    pos_weight = None
    try:
        y_train_np = train_ds.tensors[1].cpu().numpy()
        n_pos = (y_train_np == 1).sum()
        n_neg = (y_train_np == 0).sum()
        if n_pos > 0:
            pos_weight = torch.tensor([max(n_neg / n_pos, 1.0)], dtype=torch.float32, device=device)
            print(f"[Info] Using pos_weight={pos_weight.item():.4f} (neg/pos={n_neg}/{n_pos})")
    except Exception as e:
        print(f"[Warn] pos_weight auto-calc skipped: {e}")
    optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)
    loss_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weight) if pos_weight is not None else nn.BCEWithLogitsLoss()
    # (I) Train
    print(f"--- [STEP 1] Training gMLP on BBBP (split_mode={split_mode}) ---")
    model = train_model(model, optimizer, train_loader, val_loader, loss_fn, num_epochs=50, patience=10)
    print("\n--- [STEP 1.5] Evaluating on BBBP Validation Split ---")
    baseline_val_metrics = eval_model(model, val_loader)
    print("\n" + "="*50)
    print(f"BASELINE PERFORMANCE on Validation Set (split={split_mode}, seed={seed})")
    print("="*50)
    for k, v in baseline_val_metrics.items():
        print(f"{k:<12}| {v:.4f}")
    print("="*50)
    # (J) Test eval
    print("\n--- [STEP 2] Evaluating on BBBP Test Split ---")
    metrics = eval_model(model, test_loader)
    print("\n" + "="*50)
    print(f"FINAL PERFORMANCE on BBBP (split={split_mode}, seed={seed})")
    print("="*50)
    for k, v in metrics.items():
        print(f"{k:<12}| {v:.4f}")
    print("="*50)
    # (K) 아티팩트 저장
    os.makedirs("./artifacts", exist_ok=True)
    model_path = f"./artifacts/gmlp_best_model_{tag}.pth"
    cfg_path = f"./artifacts/feature_config_{tag}.json"
    scaler_path = f"./artifacts/rdkit_scaler_{tag}.pkl"
    torch.save(model.state_dict(), model_path)
    print(f"\n[Save] model -> {model_path}")
    cfg = {
        "fp_types": fp_types,
        "mod_dims": {k: int(v) for k, v in mod_dims.items()},
        "rd_slice": [rd_start, rd_end] if rd_start is not None else None,
        "split_mode": split_mode,
        "seed": seed,
        "model_hparams": {"d_model": 512, "d_ffn": 1048, "depth": 4, "dropout": 0.2, "use_gated_pool": True},
        "train_params": {"lr": 1e-4, "weight_decay": 1e-5, "num_epochs": 50, "patience": 10},
        "class_balance": {"n_pos": int(n_pos) if 'n_pos' in locals() else None,
                            "n_neg": int(n_neg) if 'n_neg' in locals() else None,
                            "pos_weight": float(pos_weight.item()) if pos_weight is not None else None}
    }
    save_config(cfg_path, cfg)
    print(f"[Save] config -> {cfg_path}")
    if scaler is not None:
        save_scaler(scaler_path, scaler)
        print(f"[Save] RDKit scaler -> {scaler_path}")
    else:
        print("[Info] RDKit slice not found. Skipping scaler save.")
    print("\n--- [STEP 3] Final Learned Alpha Parameters ---")
    if model.use_gated_pool:
        learned_alpha = model.alpha.detach().cpu().numpy()
        normalized_weights = torch.softmax(torch.tensor(learned_alpha), dim=0).numpy()
        for i, name in enumerate(model.mod_names):
            print(f" - '{name}': alpha = {learned_alpha[i]:.4f}, normalized weight = {normalized_weights[i]:.4f}")
    else:
        print("Gated pooling is not enabled (use_gated_pool=False).")

  descs = np.array(descs, dtype=np.float32)
Generating Features: 100%|███████████████████████████████████████████████| 2039/2039 [00:14<00:00, 144.53it/s]


[Save] split indices -> ./splits (tag=random_scaffold_seed42)
[Info] Using pos_weight=1.0000 (neg/pos=403/1228)
--- [STEP 1] Training gMLP on BBBP (split_mode=random_scaffold) ---

--- [STEP 1.5] Evaluating on BBBP Validation Split ---

BASELINE PERFORMANCE on Validation Set (split=random_scaffold, seed=42)
accuracy    | 0.9310
precision   | 0.9330
recall      | 0.9880
f1          | 0.9600
roc_auc     | 0.9130
mcc         | 0.7350
sensitivity | 0.9882
specificity | 0.6471

--- [STEP 2] Evaluating on BBBP Test Split ---

FINAL PERFORMANCE on BBBP (split=random_scaffold, seed=42)
accuracy    | 0.8730
precision   | 0.9050
recall      | 0.9380
f1          | 0.9210
roc_auc     | 0.9040
mcc         | 0.5910
sensitivity | 0.9383
specificity | 0.6190

[Save] model -> ./artifacts/gmlp_best_model_random_scaffold_seed42.pth
[Save] config -> ./artifacts/feature_config_random_scaffold_seed42.json
[Save] RDKit scaler -> ./artifacts/rdkit_scaler_random_scaffold_seed42.pkl

--- [STEP 3] Final Learned 

In [22]:
    # -------------------- [Optuna 최적화 및 최종 평가] --------------------
    print("\n\n" + "-"*50)
    print("--- Starting Optuna Hyperparameter Optimization (Objective: Maximize MCC) ---")
    print("-" * 50)

    study = optuna.create_study(direction='maximize')
    study.optimize(lambda trial: objective(trial, train_ds, val_ds, mod_dims, pos_weight), n_trials=100, catch=(torch.cuda.OutOfMemoryError,))

    print("\n--- Best Hyperparameters Found ---")
    best_params = study.best_trial.params
    print(best_params)

    print("\n--- Final Evaluation with Best Parameters ---")

    model_best = MultiModalGMLPFromFlat(
        mod_dims=mod_dims,
        d_model=best_params['d_model'],
        d_ffn=best_params['d_ffn'],
        depth=best_params['depth'],
        dropout=best_params['dropout'],
        use_gated_pool=True
    ).to(device)

    best_state = study.best_trial.user_attrs['best_model_state']
    model_best.load_state_dict(best_state)

    val_loader_optuna = data.DataLoader(val_ds, batch_size=best_params['batch_size'], shuffle=False, pin_memory=True)
    optuna_val_metrics = eval_model(model_best, val_loader_optuna)

    print("\n" + "="*50)
    print("BEST OPTUNA MODEL PERFORMANCE on Validation Set")
    print("="*50)
    for k, v in optuna_val_metrics.items():
        print(f"{k:<12}| {v:.4f}")
    print("="*50)

    test_loader_optuna = data.DataLoader(test_ds, batch_size=best_params['batch_size'], shuffle=False, pin_memory=True)
    metrics_optuna = eval_model(model_best, test_loader_optuna)

    print("\n" + "="*50)
    print("FINAL PERFORMANCE on BBBP with Best HPs (Objective: Maximize MCC)")
    print("="*50)
    for k, v in metrics_optuna.items():
        print(f"{k:<12}| {v:.4f}")
    print("="*50)

    print("\n--- Final Learned Alpha Parameters with Best HPs ---")
    if model_best.use_gated_pool:
        learned_alpha = model_best.alpha.detach().cpu().numpy()
        normalized_weights = torch.softmax(torch.tensor(learned_alpha), dim=0).numpy()
        for i, name in enumerate(model_best.mod_names):
            print(f"  - '{name}': alpha = {learned_alpha[i]:.4f}, normalized weight = {normalized_weights[i]:.4f}")
    else:
        print("Gated pooling is not enabled (use_gated_pool=False).")

[I 2025-09-25 16:10:44,450] A new study created in memory with name: no-name-4b42614e-2afa-462a-84ac-e77e37c58af7




--------------------------------------------------
--- Starting Optuna Hyperparameter Optimization (Objective: Maximize MCC) ---
--------------------------------------------------


[I 2025-09-25 16:10:46,546] Trial 0 finished with value: 0.6968257452720762 and parameters: {'d_model': 256, 'd_ffn': 873, 'depth': 6, 'dropout': 0.3380397814001426, 'lr': 2.0652974923836226e-05, 'batch_size': 256}. Best is trial 0 with value: 0.6968257452720762.
[I 2025-09-25 16:10:47,971] Trial 1 finished with value: 0.6940220937885672 and parameters: {'d_model': 512, 'd_ffn': 1260, 'depth': 3, 'dropout': 0.373125896765245, 'lr': 0.0007498117518509908, 'batch_size': 256}. Best is trial 0 with value: 0.6968257452720762.
[I 2025-09-25 16:10:54,885] Trial 2 finished with value: 0.6709076657921977 and parameters: {'d_model': 1024, 'd_ffn': 2670, 'depth': 5, 'dropout': 0.22823159581075275, 'lr': 0.0005205543608410038, 'batch_size': 256}. Best is trial 0 with value: 0.6968257452720762.
[I 2025-09-25 16:11:01,803] Trial 3 finished with value: 0.7139191262717324 and parameters: {'d_model': 1024, 'd_ffn': 3310, 'depth': 3, 'dropout': 0.1695985323160793, 'lr': 0.00014560894554239704, 'batch_si


--- Best Hyperparameters Found ---
{'d_model': 256, 'd_ffn': 688, 'depth': 5, 'dropout': 0.2903268401300295, 'lr': 1.9921885074066318e-05, 'batch_size': 64}

--- Final Evaluation with Best Parameters ---

BEST OPTUNA MODEL PERFORMANCE on Validation Set
accuracy    | 0.9460
precision   | 0.9540
recall      | 0.9820
f1          | 0.9680
roc_auc     | 0.9370
mcc         | 0.7970
sensitivity | 0.9824
specificity | 0.7647

FINAL PERFORMANCE on BBBP with Best HPs (Objective: Maximize MCC)
accuracy    | 0.8920
precision   | 0.9120
recall      | 0.9570
f1          | 0.9340
roc_auc     | 0.8990
mcc         | 0.6510
sensitivity | 0.9568
specificity | 0.6429

--- Final Learned Alpha Parameters with Best HPs ---
  - 'ecfp': alpha = 0.0052, normalized weight = 0.1436
  - 'avalon': alpha = -0.0009, normalized weight = 0.1427
  - 'rdkit': alpha = -0.0038, normalized weight = 0.1423
  - 'maccs': alpha = -0.0011, normalized weight = 0.1427
  - 'tt': alpha = 0.0031, normalized weight = 0.1433
  - 'scag

In [11]:
import os
import json
import torch
import torch.utils.data as data
from collections import OrderedDict
# 필요한 모든 함수와 클래스 정의
# MultiModalGMLPFromFlat, eval_model, load_scaler, load_split_indices 등
# 앞선 스크립트의 모든 함수를 복사/붙여넣기 해야 합니다.

# -------------------- [1. 설정] --------------------
# 비교할 시드와 태그를 명확히 정의합니다.
seed_to_load = 100
split_mode = "random_scaffold"
tag = f"{split_mode}_seed{seed_to_load}"

# -------------------- [2. 아티팩트 로드] --------------------
# 베이스라인 모델 아티팩트 로드
baseline_model_path = f"./artifacts/gmlp_best_model_{tag}.pth"
baseline_hparams_path = f"./artifacts/feature_config_{tag}.json"
scaler_path = f"./artifacts/rdkit_scaler_{tag}.pkl"
split_dir = "./splits"
label_path = '/home/minji/Downloads/bbbp.csv'
scage_paths = {'scage1': '/home/minji/scage/BBB/bench_embed.csv', 'scage2': '/home/minji/scage/BBB/bench_atom_embed.csv'}
fp_types = ['ecfp', 'avalon', 'rdkit', 'maccs', 'tt', 'scage1', 'scage2']

# Optuna 최적화 모델 아티팩트 로드 (경로를 실제로 저장한 경로로 수정해야 함)
optuna_model_path = "./best_model_artifacts/best_model_seed100.pth"
optuna_hparams_path = "./best_model_artifacts/best_hparams_seed100.json"

# -------------------- [3. 데이터셋 및 스케일러 준비] --------------------
# 베이스라인과 Optuna 모두 동일한 데이터 분할과 스케일러를 사용합니다.
train_idx, val_idx, test_idx = load_split_indices(split_dir, tag)
scaler = load_scaler(scaler_path)

# 데이터셋을 재구성하고 RDKit 피처의 시작/끝 인덱스를 찾습니다.
dataset = ScageConcatDataset(label_path, scage_paths, fp_types=fp_types)
test_ds = data.TensorDataset(dataset.features[test_idx], dataset.labels[test_idx])
# expected_dims 변수를 여기서 정의합니다.
expected_dims = dataset.expected_dims
mod_dims = OrderedDict((t, expected_dims[t]) for t in fp_types)

# **수정된 부분:** RDKit 피처 시작/끝 인덱스를 올바르게 계산
rd_start, rd_end = None, None
offset = 0
for t in fp_types:
    dim = expected_dims[t]
    if t == 'rdkit':
        rd_start, rd_end = offset, offset + dim
        break
    offset += dim

# 저장된 스케일러를 테스트 데이터셋에 적용
test_ds.tensors[0][:, rd_start:rd_end] = torch.tensor(
    scaler.transform(test_ds.tensors[0][:, rd_start:rd_end]), 
    dtype=torch.float32
)

# -------------------- [4. 모델 로드 및 평가 함수] --------------------
def load_and_evaluate(model_path, hparams_path, test_ds, mod_dims):
    with open(hparams_path, 'r', encoding='utf-8') as f:
        cfg = json.load(f)

    if "model_hparams" in cfg:
        hparams = cfg["model_hparams"]
        batch_size = cfg["train_params"]["batch_size"]
    else:
        hparams = cfg
        hparams.pop('mod_dims', None)
        batch_size = hparams.pop('batch_size', 128)

    model = MultiModalGMLPFromFlat(mod_dims=mod_dims, **hparams).to(device)
    model.load_state_dict(torch.load(model_path))

    test_loader = data.DataLoader(test_ds, batch_size=batch_size, shuffle=False)
    metrics = eval_model(model, test_loader)
    return metrics

# -------------------- [5. 결과 비교] --------------------
print(f"--- Comparing Models for seed = {seed_to_load} ---")
# 베이스라인 모델 평가
baseline_metrics = load_and_evaluate(baseline_model_path, baseline_hparams_path, test_ds, mod_dims)
print("\n" + "="*50)
print(f"BASELINE PERFORMANCE (seed={seed_to_load})")
print("="*50)
for k, v in baseline_metrics.items():
    print(f"{k:<12}| {v:.4f}")
print("="*50)

# Optuna 최적화 모델 평가
optuna_metrics = load_and_evaluate(optuna_model_path, optuna_hparams_path, test_ds, mod_dims)
print("\n" + "="*50)
print(f"OPTUNA OPTIMIZED PERFORMANCE (seed={seed_to_load})")
print("="*50)
for k, v in optuna_metrics.items():
    print(f"{k:<12}| {v:.4f}")
print("="*50)

# -------------------- [6. 최종 요약] --------------------
print("\n--- Summary ---")
print(f"Seed: {seed_to_load}")
print(f"Baseline Test MCC: {baseline_metrics['mcc']:.4f}")
print(f"Optuna Test MCC: {optuna_metrics['mcc']:.4f}")

  descs = np.array(descs, dtype=np.float32)
Generating Features: 100%|███████████████████████████████████████████████| 2039/2039 [00:14<00:00, 140.77it/s]

--- Comparing Models for seed = 100 ---





KeyError: 'batch_size'