In [2]:
# run_test.py
# 无监督流形学习生成分子 3D 构象（带拓扑门控 + 键长弹簧 + 轻量排斥 + Adam）
import os
import math
import numpy as np
import imageio.v2 as imageio
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401
from rdkit import Chem

from dist import (
    compute_augmented_graph_distance,
    compute_AE_tanimoto_distance,
    compute_embed3d_distance,
)

# ===================== Tunable hyperparameters =====================
SMILES = os.environ.get("SMILES", "c1ccccc1")  # default: benzene; override via the SMILES environment variable
K_HOP_MAX = 2                # only neighbours with hop <= K may contribute to P
N_NEIGHBOR = 3               # number of nearest neighbours kept per atom
ALPHA = 0.5                  # weight of D_feat (passed as alpha to compute_augmented_graph_distance)
BETA = 1.0                   # weight of D_edge (passed as beta)
GAMMA = 1.0                  # weight of D_topo (passed as gamma)
MIN_DIST = 0.4               # UMAP "min_dist" used when fitting the (a, b) curve
EE_EPOCHS = 80               # number of early-exaggeration epochs
EE_FACTOR = 5.0              # multiplier on the attractive term during early exaggeration
EPOCHS = 400
LR = 0.03                    # Adam learning rate
LAMBDA_BOND = 2.0            # weight of the bond-length spring term
LAMBDA_REPULSE = 0.2         # weight of the light repulsion term
REPULSION_CUTOFF = 1.2       # Å; repulsion triggers for non-bonded / distant pairs closer than this
TARGET_RMS = 1.0             # coordinates are re-centred and rescaled to this RMS after every step
SAVE_EVERY = 10              # save a frame every SAVE_EVERY epochs
OUTDIR = "./out_frames"
OUT_MP4 = "optimization.mp4"
OUT_GIF = "optimization.gif"

# Created at import time so the training loop can write frames without checking.
os.makedirs(OUTDIR, exist_ok=True)

# ===================== 工具函数 =====================
def rdkit_mol_from_smiles(smi: str):
    """Parse a SMILES string into an RDKit molecule with explicit hydrogens.

    Hydrogens are added with ``Chem.AddHs`` so that X-H bonds are present in
    the molecular graph (and hence get bond-length references downstream).

    Args:
        smi: SMILES string.

    Returns:
        RDKit molecule with explicit hydrogens.

    Raises:
        ValueError: if RDKit cannot parse ``smi`` or it describes no atoms.
            RDKit parses an empty string to an empty (0-atom) molecule rather
            than returning None, and every later stage needs at least one atom.
    """
    m = Chem.MolFromSmiles(smi)
    if m is None or m.GetNumAtoms() == 0:
        raise ValueError(f"Invalid SMILES: {smi}")
    return Chem.AddHs(m)

def hop_matrix(mol: "Chem.Mol", kmax: int = 3):
    """All-pairs bond-hop (graph) distances, truncated at ``kmax`` hops.

    Runs one breadth-first search per atom over the bond graph of ``mol``.
    Only ``mol.GetNumAtoms()`` and ``mol.GetBonds()`` (with
    ``GetBeginAtomIdx`` / ``GetEndAtomIdx``) are used.

    Args:
        mol: molecule whose bonds define the graph.
        kmax: largest hop distance that is resolved.

    Returns:
        (n, n) int array H with H[s, s] = 0, H[s, t] = number of bonds on the
        shortest s-t path when that is <= kmax, and 10**9 otherwise
        (unreached or farther than kmax).
    """
    from collections import deque

    n = mol.GetNumAtoms()
    adj = [[] for _ in range(n)]
    for b in mol.GetBonds():
        i, j = b.GetBeginAtomIdx(), b.GetEndAtomIdx()
        adj[i].append(j)
        adj[j].append(i)
    INF = 10**9
    H = np.full((n, n), INF, dtype=int)
    for s in range(n):
        H[s, s] = 0
        q = deque([s])  # O(1) popleft, unlike list.pop(0)
        while q:
            u = q.popleft()
            # Nodes already at the cut-off are recorded but not expanded; this
            # also keeps kmax <= 0 from leaking hop-1 entries via the source.
            if H[s, u] >= kmax:
                continue
            for v in adj[u]:
                if H[s, v] == INF:
                    H[s, v] = H[s, u] + 1
                    q.append(v)
    return H

def fuzzy_union_symmetrize(P):
    """Symmetrise a membership matrix with UMAP's fuzzy union.

    Computes, elementwise, P ⊕ P^T = P + P^T - P * P^T (probabilistic OR of
    the i->j and j->i memberships).
    """
    transposed = P.T
    union = P + transposed
    return union - P * transposed

def find_ab_params(spread: float = 1.0, min_dist: float = 0.4):
    """Approximate UMAP's (a, b) curve parameters by a coarse grid search.

    Chooses (a, b) minimising the mean squared error between
    ``1 / (1 + a * d^(2b))`` and UMAP's target membership curve

        1                                if d <  min_dist
        exp(-(d - min_dist) / spread)    otherwise

    sampled at 200 points in [0, 3 * spread]. Unlike umap-learn (which fits
    the curve with a nonlinear least-squares solver), only a fixed grid of
    a in [0.5, 3.0] (20 values) x b in [0.3, 1.5] (25 values) is searched, so
    the result is approximate.

    Args:
        spread: effective scale of the embedded points; must be > 0.
        min_dist: distance below which the target membership is 1.

    Returns:
        (a, b): best grid values.

    Raises:
        ValueError: if ``spread`` is not positive.
    """
    if spread <= 0:
        raise ValueError(f"spread must be positive, got {spread}")
    xs = np.linspace(0, spread * 3, 200)
    # `spread` sets the decay length of the target curve (as in umap-learn),
    # not only the sampled range.
    target = np.where(xs < min_dist, 1.0, np.exp(-(xs - min_dist) / spread))
    best_a, best_b, best_err = 1.0, 1.0, np.inf
    for a in np.linspace(0.5, 3.0, 20):
        for b in np.linspace(0.3, 1.5, 25):
            pred = 1.0 / (1.0 + a * xs ** (2 * b))
            err = np.mean((pred - target) ** 2)
            if err < best_err:
                best_a, best_b, best_err = a, b, err
    return best_a, best_b

def init_spectral(P, ndim=3):
    """Spectral initialisation of an ``ndim``-dimensional embedding.

    Builds W = (P + P^T) / 2 with negative entries clipped to 0, the
    normalised affinity S = D^{-1/2} W D^{-1/2} (D = diag of row sums), and
    uses the eigenvectors of S with the largest eigenvalues as coordinates,
    skipping the top one.

    Args:
        P: (n, n) affinity matrix.
        ndim: number of output dimensions.

    Returns:
        (n, ndim) float array, centred per column and scaled to unit overall
        RMS. When fewer than ``ndim`` non-trivial eigenvectors exist
        (n <= ndim, e.g. water with explicit H), the missing columns are
        zero, so the shape is always (n, ndim).
    """
    P = np.asarray(P, dtype=float)
    n = P.shape[0]
    W = (P + P.T) / 2.0
    W[W < 0] = 0.0
    d = W.sum(axis=1)
    D_inv_sqrt = np.diag(1.0 / np.sqrt(d + 1e-12))
    S = D_inv_sqrt @ W @ D_inv_sqrt
    vals, vecs = np.linalg.eigh(S)
    vecs = vecs[:, np.argsort(vals)[::-1]]
    # The leading eigenvector of D^{-1/2} W D^{-1/2} is ~ sqrt(degree) (not a
    # constant), which carries no geometric information, so it is skipped.
    cols = vecs[:, 1:ndim + 1]
    X = np.zeros((n, ndim))
    X[:, :cols.shape[1]] = cols  # zero-pad if n - 1 < ndim
    X = X - X.mean(axis=0, keepdims=True)
    rms = np.sqrt((X ** 2).mean())
    return X / (rms + 1e-12)

def umap_Q(Y, a, b):
    """Low-dimensional UMAP similarities and pairwise distances.

    Q_ij = 1 / (1 + a * d_ij^(2b)) with d_ij = sqrt(||y_i - y_j||^2 + 1e-12);
    the diagonal of Q is set to 0.

    Returns:
        (Q, d): both (n, n) arrays.
    """
    deltas = Y[:, np.newaxis, :] - Y[np.newaxis, :, :]
    squared = np.sum(deltas ** 2, axis=2)
    dist = np.sqrt(squared + 1e-12)
    kernel = 1.0 / (1.0 + a * (dist ** (2 * b)))
    np.fill_diagonal(kernel, 0.0)
    return kernel, dist

def ce_grad_attractive(Y, P, a, b, w_pos=1.0):
    """
    Attractive (positive-edge) part of UMAP's cross-entropy and its gradient.

    Loss:  L = -w_pos * sum_{(i, j): P_ij > 0} P_ij * log(Q_ij + eps)
    with Q and d from ``umap_Q``.

    Gradient by the chain rule:
        dL/dQ = -w_pos * P / Q
        dQ/dd = -(2ab d^(2b-1)) / (1 + a d^(2b))^2
        grad_i += dL/dd * (y_i - y_j) / d,   grad_j -= the same term
    Every ordered pair (i, j) with P_ij > 0 contributes, so for a symmetric P
    each unordered pair is counted twice in both loss and gradient.

    Args:
        Y: (n, dim) embedding coordinates.
        P: (n, n) high-dimensional affinities; only entries > 0 are used.
        a, b: UMAP curve parameters.
        w_pos: multiplier on the attractive term (early exaggeration).

    Returns:
        (loss_pos, grad) with grad shaped like Y.
    """
    Q, D = umap_Q(Y, a, b)
    eps = 1e-8
    mask = P > 0
    # np.nonzero and boolean indexing both enumerate `mask` in row-major order,
    # so idx_i/idx_j line up with Pi/Qi/Dij.
    idx_i, idx_j = np.nonzero(mask)
    Pi = P[mask]
    Qi = Q[mask] + eps
    Dij = D[mask] + eps

    dLdQ = -w_pos * (Pi / Qi)
    num = (2.0 * a * b) * (Dij ** (2 * b - 1.0))
    den = (1.0 + a * (Dij ** (2 * b))) ** 2
    dQdd = -num / (den + eps)
    dLdd = dLdQ * dQdd  # shape: (#edges,)

    vec_ij = (Y[idx_i] - Y[idx_j]) / Dij[:, None]
    contrib = dLdd[:, None] * vec_ij
    grad = np.zeros_like(Y)
    # ufunc.at is unbuffered, so repeated row indices accumulate correctly
    # (plain fancy-index `+=` would drop duplicates).
    np.add.at(grad, idx_i, contrib)
    np.subtract.at(grad, idx_j, contrib)
    loss_pos = -np.sum(w_pos * Pi * np.log(Qi))
    return loss_pos, grad

# ---- 键长弹簧 & 轻量排斥 ----
def rdkit_bond_targets(mol: Chem.Mol):
    """Reference bond lengths measured on one RDKit ETKDG conformer.

    A copy of ``mol`` is embedded, so the input molecule is left unchanged.
    No random seed is set, so the conformer (and the measured lengths) can
    vary slightly between runs.

    Args:
        mol: molecule (with explicit hydrogens if H bonds are wanted).

    Returns:
        dict mapping both (i, j) and (j, i) of every bond to the bond length
        in the embedded conformer (RDKit coordinates, i.e. Å).

    Raises:
        RuntimeError: if embedding fails even with random starting coordinates.
    """
    from rdkit.Chem import AllChem
    m2 = Chem.Mol(mol)
    params = AllChem.ETKDG()
    cid = AllChem.EmbedMolecule(m2, params)
    if cid < 0:
        # EmbedMolecule signals failure with -1 instead of raising; random
        # initial coordinates often succeed where the default start fails.
        params.useRandomCoords = True
        cid = AllChem.EmbedMolecule(m2, params)
    if cid < 0:
        raise RuntimeError(
            f"RDKit could not embed molecule {Chem.MolToSmiles(mol)} for bond targets"
        )
    conf = m2.GetConformer(cid)
    N = m2.GetNumAtoms()
    pos = np.array([list(conf.GetAtomPosition(i)) for i in range(N)], dtype=float)
    tgt = {}
    for b in m2.GetBonds():
        i, j = b.GetBeginAtomIdx(), b.GetEndAtomIdx()
        dij = np.linalg.norm(pos[i] - pos[j]) + 1e-12
        tgt[(i, j)] = tgt[(j, i)] = float(dij)
    return tgt

def bond_spring(Y, mol: Chem.Mol, targets: dict):
    """Harmonic bond-length penalty averaged over the bonds of ``mol``.

    For each bond (i, j) with current length d and target L
    (``targets[(i, j)]``, or d itself when missing, i.e. no penalty), adds
    (d - L)^2 and its gradient 2 (d - L) (y_i - y_j) / d to atom i (and the
    negative to atom j). Loss and gradient are divided by max(1, #bonds).

    Returns:
        (loss, grad) with grad shaped like Y.
    """
    total = 0.0
    grad = np.zeros_like(Y)
    for bond in mol.GetBonds():
        a, c = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        delta = Y[a] - Y[c]
        length = float(np.sqrt((delta * delta).sum()) + 1e-12)
        rest = targets.get((a, c), length)
        stretch = length - rest
        total += stretch ** 2
        force = 2.0 * stretch * (delta / length)
        grad[a] += force
        grad[c] -= force
    denom = max(1, mol.GetNumBonds())
    return total / denom, grad / denom

def repulsion(Y, hopM, cutoff=1.2, exclude_hop_le=2):
    """
    Soft repulsion between topologically distant atom pairs (vectorised).

    Every unordered pair i < j with hopM[i, j] > exclude_hop_le and distance
    d_ij < cutoff contributes (cutoff - d_ij)^2. Loss and gradient are divided
    by the number of such active pairs (by 1 when there are none).

    Args:
        Y: (n, dim) coordinates.
        hopM: (n, n) hop-distance matrix; pairs with hop <= exclude_hop_le
            are ignored.
        cutoff: distance below which the penalty is active.
        exclude_hop_le: hop threshold for excluded (near-in-graph) pairs.

    Returns:
        (loss, grad): float loss and gradient shaped like Y.
    """
    n = Y.shape[0]
    diff = Y[:, None, :] - Y[None, :, :]  # diff[i, j] = y_i - y_j
    d = np.sqrt((diff * diff).sum(axis=-1)) + 1e-12
    active = (
        np.triu(np.ones((n, n), dtype=bool), k=1)
        & (np.asarray(hopM) > exclude_hop_le)
        & (d < cutoff)
    )
    cnt = max(1, int(np.count_nonzero(active)))
    loss = float(np.sum((cutoff - d[active]) ** 2))
    # (dL/dd) / d for active pairs, 0 elsewhere.
    coef = np.where(active, 2.0 * (d - cutoff) / d, 0.0)
    # Pair (i, j) adds coef_ij * (y_i - y_j) to grad_i and coef_ij * (y_j - y_i)
    # to grad_j; symmetrising coef turns both into one sum over j.
    coef_sym = coef + coef.T
    grad = (coef_sym[:, :, None] * diff).sum(axis=1)
    return loss / cnt, grad / cnt

def center_and_rescale(Y, target_rms=1.0):
    """Shift ``Y`` to zero column means and scale it to RMS ``target_rms``.

    The RMS is taken over all entries of the centred array. If it is below
    1e-12 (all points coincide) the centred array is returned unscaled.
    """
    centered = Y - Y.mean(axis=0, keepdims=True)
    current_rms = np.sqrt(np.mean(centered ** 2))
    return centered if current_rms < 1e-12 else centered * (target_rms / current_rms)

# ---- P 的构造（k-hop 门控 + 每行选 k 个最近邻 + 指数权重）----
def build_P_from_dist(D, hopM, k=3, khop_max=2):
    """Build the symmetric high-dimensional affinity matrix P.

    For each atom i:
      1. candidates are atoms j != i with hopM[i, j] <= khop_max (topology gate);
      2. the k candidates with the smallest D[i, j] are kept;
      3. they get weights exp(-D[i, j] / sigma_i), sigma_i = median of the kept
         distances + 1e-6, normalised to sum to 1 over the row.
    The rows are then combined with the fuzzy union P + P^T - P * P^T, clipped
    to [0, 1], and the diagonal is zeroed.

    Args:
        D: (n, n) distance matrix.
        hopM: (n, n) hop-distance matrix (e.g. from ``hop_matrix``).
        k: neighbours kept per row; k < 1 yields an all-zero P.
        khop_max: maximum hop distance for a candidate.

    Returns:
        (n, n) float array.
    """
    n = D.shape[0]
    P = np.zeros((n, n), dtype=float)
    if k >= 1:  # k < 1 would take the median of an empty selection
        for i in range(n):
            mask = (hopM[i] <= khop_max) & (np.arange(n) != i)
            cand_idx = np.flatnonzero(mask)
            if cand_idx.size == 0:
                continue
            di = D[i, cand_idx]
            keep = np.argsort(di)[:k]  # slicing already caps at the candidate count
            base = di[keep]
            # Median of the kept distances sets the exponential length scale.
            sigma = float(np.median(base) + 1e-6)
            w = np.exp(-base / sigma)
            P[i, cand_idx[keep]] = w / (w.sum() + 1e-12)
    P = fuzzy_union_symmetrize(P)
    P = np.clip(P, 0.0, 1.0)  # numerical safety; the union of [0, 1] values stays in [0, 1]
    np.fill_diagonal(P, 0.0)
    return P

# ---- Adam ----
class Adam:
    """Minimal Adam optimiser that returns the update instead of applying it.

    ``step(grad)`` advances the first/second moment estimates and returns
    ``-lr * m_hat / (sqrt(v_hat) + eps)``; the caller adds it to the parameters.
    """

    def __init__(self, shape, lr=0.03, b1=0.9, b2=0.999, eps=1e-8):
        self.lr, self.b1, self.b2, self.eps = lr, b1, b2, eps
        self.t = 0
        self.m = np.zeros(shape, dtype=float)
        self.v = np.zeros(shape, dtype=float)

    def step(self, grad):
        """Return the bias-corrected Adam update for ``grad``."""
        self.t = self.t + 1
        beta1, beta2 = self.b1, self.b2
        self.m = beta1 * self.m + (1 - beta1) * grad
        self.v = beta2 * self.v + (1 - beta2) * (grad * grad)
        m_hat = self.m / (1 - beta1 ** self.t)
        v_hat = self.v / (1 - beta2 ** self.t)
        return -self.lr * m_hat / (np.sqrt(v_hat) + self.eps)

# ===================== 主流程 =====================
def _far_pair_mask(H, khop_max):
    """Bool (n, n) mask, True only for pairs i < j with H[i, j] > khop_max."""
    return np.triu(np.asarray(H) > khop_max, k=1)


def _count_collisions(Y, far_mask, threshold=1.0):
    """Count pairs selected by ``far_mask`` whose distance in ``Y`` is below ``threshold``."""
    dist = np.linalg.norm(Y[:, None, :] - Y[None, :, :], axis=-1)
    return int(np.count_nonzero(far_mask & (dist < threshold)))


def _save_frame(Y, mol, epoch, outdir):
    """Plot atoms and bonds of ``Y`` in 3D, save ``frame_<epoch>.png`` under ``outdir``, return the image."""
    fig = plt.figure(figsize=(4, 4))
    ax = fig.add_subplot(111, projection='3d')
    ax.scatter(Y[:, 0], Y[:, 1], Y[:, 2], s=30, depthshade=True)
    for bond in mol.GetBonds():
        i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        ax.plot([Y[i, 0], Y[j, 0]], [Y[i, 1], Y[j, 1]], [Y[i, 2], Y[j, 2]], linewidth=1.0)
    ax.set_title(f"Epoch {epoch}")
    ax.set_axis_off()
    fpath = os.path.join(outdir, f"frame_{epoch:04d}.png")
    fig.tight_layout()
    fig.savefig(fpath, dpi=150)
    plt.close(fig)
    return imageio.imread(fpath)


def main():
    """Optimise a 3D embedding of ``SMILES`` and write frames plus MP4/GIF animations.

    Pipeline: graph distance D1 -> topology-gated affinities P -> spectral
    initialisation -> EPOCHS Adam steps on UMAP attraction (early-exaggerated
    for EE_EPOCHS), bond-length springs and light repulsion. PNG frames go to
    OUTDIR; the animations are written to OUT_MP4 and OUT_GIF.
    """
    mol = rdkit_mol_from_smiles(SMILES)

    # High-dimensional distances; D1 is the main signal.
    D1 = compute_augmented_graph_distance(mol, alpha=ALPHA, beta=BETA, gamma=GAMMA)
    # Optional alternatives: compute_AE_tanimoto_distance(mol) to blend in with
    # your own weight, compute_embed3d_distance(mol) as an evaluation reference.

    H = hop_matrix(mol, kmax=K_HOP_MAX + 1)
    P = build_P_from_dist(D1, H, k=N_NEIGHBOR, khop_max=K_HOP_MAX)  # topology-gated neighbours only
    Y = init_spectral(P, ndim=3)
    a_umap, b_umap = find_ab_params(min_dist=MIN_DIST)
    bond_targets = rdkit_bond_targets(mol)
    opt = Adam(Y.shape, lr=LR)

    # Pairs more than K_HOP_MAX hops apart: the same set the repulsion term
    # acts on. H is fixed, so build the mask once.
    far_mask = _far_pair_mask(H, K_HOP_MAX)

    frames = []
    for epoch in range(1, EPOCHS + 1):
        wpos = EE_FACTOR if epoch <= EE_EPOCHS else 1.0  # early exaggeration

        loss_pos, g_pos = ce_grad_attractive(Y, P, a_umap, b_umap, w_pos=wpos)
        lb, g_b = bond_spring(Y, mol, bond_targets)
        lrp, g_rp = repulsion(Y, H, cutoff=REPULSION_CUTOFF, exclude_hop_le=K_HOP_MAX)

        loss = loss_pos + LAMBDA_BOND * lb + LAMBDA_REPULSE * lrp
        grad = g_pos + LAMBDA_BOND * g_b + LAMBDA_REPULSE * g_rp

        Y = Y + opt.step(grad)
        # NOTE(review): Y is forced to RMS = TARGET_RMS after every step, while
        # bond_targets and REPULSION_CUTOFF are in Å. If the molecule's real RMS
        # radius differs from TARGET_RMS, the springs cannot all be satisfied
        # (consistent with bondRMSE staying near 0.9 in the recorded run).
        # Consider matching the scale of the reference conformer instead.
        Y = center_and_rescale(Y, target_rms=TARGET_RMS)

        bond_rmse = math.sqrt(lb)  # lb is the mean squared bond-length error
        # Collisions: topologically distant pairs closer than 1.0 (neighbours
        # and second neighbours are not counted).
        coll_cnt = _count_collisions(Y, far_mask, threshold=1.0)

        if epoch % 10 == 0 or epoch == 1:
            print(f"[{epoch:04d}] loss={loss:.4f}  bondRMSE={bond_rmse:.3f}  collisions={coll_cnt}  wpos={wpos:g}")

        if epoch % SAVE_EVERY == 0 or epoch in (1, EPOCHS):
            frames.append(_save_frame(Y, mol, epoch, OUTDIR))

    # NOTE(review): this evaluates to 100 fps for SAVE_EVERY = 10, so the
    # 41-frame run plays in about 0.4 s; confirm this is intended.
    fps = max(2, 1000 // SAVE_EVERY)
    imageio.mimsave(OUT_MP4, frames, fps=fps)
    imageio.mimsave(OUT_GIF, frames, fps=fps)
    print(f"Saved {OUT_MP4} and {OUT_GIF} with {len(frames)} frames.")


if __name__ == "__main__":
    main()


[0001] loss=161.0642  bondRMSE=0.553  collisions=0  wpos=5
[0010] loss=156.6381  bondRMSE=0.466  collisions=0  wpos=5
[0020] loss=149.9649  bondRMSE=0.518  collisions=0  wpos=5
[0030] loss=144.1462  bondRMSE=0.685  collisions=0  wpos=5
[0040] loss=136.3253  bondRMSE=0.815  collisions=1  wpos=5
[0050] loss=128.2985  bondRMSE=0.874  collisions=2  wpos=5
[0060] loss=121.4432  bondRMSE=0.918  collisions=3  wpos=5
[0070] loss=117.8474  bondRMSE=0.966  collisions=6  wpos=5
[0080] loss=116.3898  bondRMSE=0.996  collisions=6  wpos=5
[0090] loss=24.8749  bondRMSE=0.985  collisions=6  wpos=1
[0100] loss=24.9650  bondRMSE=0.942  collisions=5  wpos=1
[0110] loss=25.1664  bondRMSE=0.915  collisions=4  wpos=1
[0120] loss=25.3269  bondRMSE=0.900  collisions=4  wpos=1
[0130] loss=25.4029  bondRMSE=0.889  collisions=3  wpos=1
[0140] loss=25.4774  bondRMSE=0.884  collisions=3  wpos=1
[0150] loss=25.5609  bondRMSE=0.883  collisions=3  wpos=1
[0160] loss=25.6145  bondRMSE=0.881  collisions=3  wpos=1
[0170



[0400] loss=23.5574  bondRMSE=0.958  collisions=4  wpos=1
Saved optimization.mp4 and optimization.gif with 41 frames.
