# Package

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
from collections import Counter
from rdkit import Chem
import torch
import os
from glob import glob
import numpy as np
import sys
# Project root. Defaults to the original deployment location; set the
# MOLSOLVER_ROOT environment variable to run the notebook on another machine.
root_path = os.environ.get("MOLSOLVER_ROOT", "/data/mntdata/aaa/molsolver/")
# Make modules under the project root importable. The membership check keeps
# sys.path from growing every time this cell is re-run.
if root_path not in sys.path:
    sys.path.append(root_path)
# os.path.join works whether or not root_path ends with a separator.
folder_path = os.path.join(root_path, 'data', 'test_set')
eval_path = os.path.join(root_path, 'test', 'samples')

## 加载分子

In [None]:
# ligand_names = []

# # 获取所有子文件夹，并按名称排序
# subfolders = [os.path.join(folder_path, d) for d in os.listdir(folder_path) if os.path.isdir(os.path.join(folder_path, d))]
# subfolders.sort()

# # 按子文件夹顺序读取 .sdf 文件
# for subfolder in subfolders:
#     subfolder_name = os.path.basename(subfolder)  # 获取子文件夹的名称
#     for root, dirs, files in os.walk(subfolder):
#         for file in files:
#             if file.endswith('.sdf'):
#                 # 将子文件夹名和文件名组合，形式如 "123/abc.sdf"
#                 ligand_name = f"{subfolder_name}/{file}"
#                 ligand_names.append(ligand_name)


In [None]:
def _require_result_file(filename):
    """Return the path of `filename` inside `eval_path`.

    Args:
        filename: file name (or glob pattern) relative to `eval_path`.

    Returns:
        The first path matched by glob.

    Raises:
        FileNotFoundError: when nothing matches. The message names the
            missing path, which a bare ``glob(...)[0]`` (IndexError) does not.
    """
    pattern = os.path.join(eval_path, filename)
    matches = glob(pattern)
    if not matches:
        raise FileNotFoundError(f"Evaluation result file not found: {pattern}")
    return matches[0]


# Docked / pose-checked result files, one per evaluated method.
ar_path = _require_result_file("ar_vina_docked_pose_checked.pt")
flag_path = _require_result_file("flag_official_vina_docked_pose_checked.pt")
tg_path = _require_result_file("targetdiff_vina_docked_pose_checked.pt")
# decomp_path = _require_result_file("decompdiff_vina_docked_pose_checked.pt")
decomp_ref_path = _require_result_file("decompdiff_ref_vina_docked_pose_checked.pt")
ref_path = _require_result_file("crossdocked_test_vina_docked_pose_checked.pt")
pocket2mol_path = _require_result_file("pocket2mol_vina_docked_pose_checked.pt")
molcraft_path = _require_result_file("molcraft_vina_docked_pose_checked.pt")
molcraft_large_path = _require_result_file("molcraft_large_vina_docked_pose_checked.pt")
# ours_path = _require_result_file("success_ours.pt")

ours_path = _require_result_file("steps80_success.pt")
# train_path = _require_result_file("crossdocked_train.pt")
# train_path = glob(os.path.join(eval_path, "crossdocked_train.pt"))[0]

In [None]:

from eval_utils import ModelResults

# Build one ModelResults per evaluated method. The list order is the order
# used by every per-model plot below. The 'DecompDiff-O' entry stays disabled.
ref, ar, p2m, flag, tg, dcmp_ref, molcraft, ours = models = [
    ModelResults(label, path)
    for label, path in (
        ('Reference', ref_path),
        ('AR', ar_path),
        ('Pocket2Mol', pocket2mol_path),
        ('FLAG', flag_path),
        ('TargetDiff', tg_path),
        ('DecompDiff', decomp_ref_path),
        ('MolCRAFT', molcraft_path),
        ('Ours', ours_path),
    )
]

# Load the pose-checked results of every model.
for model in models:
    model.load_pose_checked()

In [None]:
# 绘制应变能

In [None]:
import matplotlib.font_manager as fm
# True when at least one font known to matplotlib has "Times New Roman" in its name.
print(any(map(lambda font: 'Times New Roman' in font.name, fm.fontManager.ttflist)))


In [None]:
import matplotlib.pyplot as plt
import math

# Collect the strain energy of every result that carries one, per model.
# `models` holds objects exposing `name` and `flat_results`; each result is a
# dict whose optional 'pose_check' entry may contain a 'strain' value.
boxplot_data = []  # one list of strain energies per model
model_names = []   # model names, aligned with boxplot_data

for model in models:
    strain_values = [
        result['pose_check']['strain']
        for result in model.flat_results
        if 'pose_check' in result and 'strain' in result['pose_check']
    ]
    # Models without any strain value are left out of the plot entirely.
    if strain_values:
        boxplot_data.append(strain_values)
        model_names.append(model.name)

# Drop NaN entries. The type check comes first so that math.isnan is only
# applied to floats.
cleaned_boxplot_data = [
    [x for x in sublist if not (isinstance(x, float) and math.isnan(x))]
    for sublist in boxplot_data
]
# Later cells read `boxplot_data`, so it now refers to the NaN-free lists.
boxplot_data = cleaned_boxplot_data

# One box per model, all in the same figure.
plt.boxplot(boxplot_data)
plt.ylabel('Strain Energy')  # fixed typo: was 'Strin Energy'
plt.autoscale()
plt.show()


In [None]:
# Trim extreme outliers: keep only strain energies that are at most 1000.
cleaned_boxplot_data = [
    [value for value in series if value <= 1000]
    for series in boxplot_data
]

# 如果你想直接更新 boxplot_data：
# plt.figure(figsize=(10, 6))
# plt.boxplot(cleaned_boxplot_data, labels=model_names)
# plt.violinplot(cleaned_boxplot_data, showmeans=True)

In [None]:
import numpy as np

# Strain-energy box plot of the outlier-trimmed data, saved as se.png.
fig, ax = plt.subplots(figsize=(4,3))
bp = ax.boxplot(cleaned_boxplot_data, patch_artist=True)

# Sample the "Pastel1" colormap evenly: one colour per box.
cmap = plt.get_cmap("Pastel1")
colors = cmap(np.linspace(0, 1, len(bp['boxes'])))
for box_patch, box_color in zip(bp['boxes'], colors):
    box_patch.set_facecolor(box_color)

# Times New Roman loaded directly from the font file (or ~/.fonts/times.ttf).
# `tnr` is also used by the RMSD and clash figure cells.
font_path = "/usr/share/fonts/truetype/msttcorefonts/times.ttf"
tnr = fm.FontProperties(fname=font_path, size=15)

ax.set_ylabel("strain energy", fontproperties=tnr)
ax.set_xticks([])  # no x tick marks; models are identified by colour
plt.show()
fig.savefig("se.png", dpi=300, bbox_inches='tight')


In [None]:
# Gather the RMSD of every result that has one, per model.
rmsd_data = []     # one list of RMSD values per model
model_names = []   # model names, aligned with rmsd_data

for model in models:
    rmsd_values = [entry['rmsd'] for entry in model.flat_results if 'rmsd' in entry]
    if not rmsd_values:
        continue  # models without RMSD values are skipped
    rmsd_data.append(rmsd_values)
    model_names.append(model.name)

# NaN-free copy restricted to RMSD <= 20.
# NOTE(review): the RMSD bar-chart cell reads `rmsd_data`, not this cleaned
# copy, so NaN entries there count as "not below 2" - confirm this is intended.
cleaned_rmsd_data = [
    [value for value in series if not math.isnan(value) and value <= 20]
    for series in rmsd_data
]

# 如果想覆盖原数据，可以将 cleaned_boxplot_data 赋值回 boxplot_data
# rmsd_data = cleaned_rmsd_data


# # 绘制所有模型的箱型图在同一图中
# plt.figure(figsize=(10, 6))
# # plt.boxplot(boxplot_data[0], labels=model_names)
# plt.boxplot(rmsd_data, labels=model_names)
# # plt.xlabel('模型名称')
# plt.ylabel('rmsd')
# # plt.title('各模型应变能箱型图')
# plt.autoscale()
# plt.show()

In [None]:
# import numpy as np

# fig, ax = plt.subplots()
# bp = ax.boxplot(rmsd_data, labels=model_names, patch_artist=True)

# # 使用内置 colormap 'viridis' 自动生成颜色，生成6个颜色
# cmap = plt.get_cmap("tab10")
# colors = cmap(np.linspace(0, 1, len(bp['boxes'])))

# for patch, color in zip(bp['boxes'], colors):
#     patch.set_facecolor(color)
# plt.title('rmsd')
# plt.show()

In [None]:
# Per model: fraction of molecules whose RMSD is below 2.
proportions = [(np.array(values) < 2).mean() for values in rmsd_data]

fig, ax = plt.subplots(figsize=(4,3))
bars = ax.bar(range(len(proportions)), proportions)

# One evenly spaced "Pastel1" colour per bar.
cmap = plt.get_cmap("Pastel1")
colors = cmap(np.linspace(0, 1, len(proportions)))
for rect, rect_color in zip(bars, colors):
    rect.set_color(rect_color)

# NOTE(review): the axis is labelled "RMSD" although the bars show the
# fraction with RMSD < 2 - confirm the label matches the figure caption.
ax.set_ylabel("RMSD", fontproperties=tnr)
ax.set_xticks([])  # no x tick marks
plt.show()
fig.savefig("rmsd.png", dpi=300, bbox_inches='tight')

In [None]:
# Show the per-model fractions with RMSD < 2 that the bar chart above plots.
print(proportions)

In [None]:
# Gather the clash value of every result that carries one, per model.
clash_data = []    # one list of clash values per model
model_names = []   # model names, aligned with clash_data

for model in models:
    clash_values = [
        entry['pose_check']['clash']
        for entry in model.flat_results
        if 'pose_check' in entry and 'clash' in entry['pose_check']
    ]
    if not clash_values:
        continue  # models without clash values are skipped
    clash_data.append(clash_values)
    model_names.append(model.name)

# Stage 1: drop float NaN entries (type check first, then math.isnan).
cleaned_clash_data = [
    [value for value in series if not (isinstance(value, float) and math.isnan(value))]
    for series in clash_data
]

# Stage 2: keep only values of at most 30, so the box plot is not dominated
# by extreme outliers.
tmp_clash_data = [
    [value for value in series if not math.isnan(value) and value <= 30]
    for series in cleaned_clash_data
]
clash_data = tmp_clash_data


# # 绘制所有模型的箱型图在同一图中
# plt.figure(figsize=(10, 6))
# # plt.boxplot(boxplot_data[0], labels=model_names)
# plt.boxplot(clash_data, labels=model_names)
# # plt.xlabel('模型名称')
# plt.ylabel('clash')
# # plt.title('各模型应变能箱型图')
# plt.autoscale()
# plt.show()

In [None]:
import numpy as np
import matplotlib.patches as mpatches

# Clash box plot, one pastel-coloured box per model, saved as clash.png.
fig, ax = plt.subplots(figsize=(4,3))
bp = ax.boxplot(clash_data, patch_artist=True)

# One evenly spaced "Pastel1" colour per box.
cmap = plt.get_cmap("Pastel1")
colors = cmap(np.linspace(0, 1, len(bp['boxes'])))
for box_patch, box_color in zip(bp['boxes'], colors):
    box_patch.set_facecolor(box_color)

ax.set_ylabel("clash", fontproperties=tnr)
ax.set_xticks([])  # no x tick marks
plt.show()
fig.savefig("clash.png", dpi=300, bbox_inches='tight')

In [None]:
import numpy as np
import matplotlib.patches as mpatches

# Same clash box plot, this time with a colour legend; saved as legend.png.
fig, ax = plt.subplots()
bp = ax.boxplot(clash_data, patch_artist=True)

# One evenly spaced "Pastel1" colour per box.
cmap = plt.get_cmap("Pastel1")
colors = cmap(np.linspace(0, 1, len(bp['boxes'])))
for box_patch, box_color in zip(bp['boxes'], colors):
    box_patch.set_facecolor(box_color)

# Legend entries pair each colour with the model name at the same position.
legend_patches = []
for swatch, label in zip(colors, model_names):
    legend_patches.append(mpatches.Patch(color=swatch, label=label))

# Anchor the legend outside the axes, to the right.
ax.legend(handles=legend_patches, bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0., prop=tnr)

ax.set_title('clash')
fig.tight_layout()  # adjust the layout before showing and saving
plt.show()
fig.savefig("legend.png", dpi=300, bbox_inches='tight')

# Generation Success

In [None]:
# Per model: number of molecules flagged both 'complete' and 'validity'.
valid_complete_counts = []
for model in models:
    passed = [m for m in model.flat_results if m['complete'] and m['validity']]
    valid_complete_counts.append(len(passed))

print(valid_complete_counts)

In [None]:
# Per model: how many molecules are both complete and valid.
valid_complete_counts = [
    sum(1 for molecule in model.flat_results
        if molecule['complete'] and molecule['validity'])
    for model in models
]

print(valid_complete_counts)

In [None]:
# Summary metrics for "Ours", restricted to molecules that are both complete
# and valid. Those molecules are also collected in `ours_success`.
ours_success = []
len_molcraft = len(ours.flat_results)  # total result count for "Ours" (not used in this cell)
vinascore = []
vinamin = []
vinadock = []  # docking affinities are not collected in this cell, so this stays empty
se = []
sa = []
clash = []
qed = []
rmsd = []
for mol in ours.flat_results:  # `ours` is the same object as models[7]
    if not (mol['complete'] and mol['validity']):
        continue
    vinascore.append(mol['vina']['score_only'][0]['affinity'])
    vinamin.append(mol['vina']['minimize'][0]['affinity'])
    # Skip NaN strain values so the quantiles below are defined
    # (same filtering as the ablation cells).
    strain = mol['pose_check']['strain']
    if not math.isnan(strain):
        se.append(strain)
    if 'clash' in mol['pose_check']:
        clash.append(mol['pose_check']['clash'])
    sa.append(mol['chem_results']['sa'])
    qed.append(mol['chem_results']['qed'])
    if 'rmsd' in mol:
        rmsd.append(mol['rmsd'])

    ours_success.append(mol)

print(np.mean(vinascore), np.median(vinascore))
print(np.mean(vinamin), np.median(vinamin))
# Only report docking statistics when some were collected. np.mean([]) would
# print nan and emit a RuntimeWarning.
if vinadock:
    print(np.mean(vinadock), np.median(vinadock))
print('strain', np.quantile(se, 0.25), np.median(se), np.quantile(se, 0.75))
print(np.mean(sa))
print(np.mean(clash))
print(np.mean(qed))
# Fraction of molecules with RMSD < 2. The comparison must happen element-wise
# before averaging; the previous form compared the mean RMSD to 2 and printed a bool.
print(np.mean(np.array(rmsd) < 2), len(rmsd))


In [None]:
# Number of molecules currently held in `ours_success`.
print(len(ours_success))

In [None]:
# Vina statistics (score / minimize / dock) for "Ours" over complete and valid
# molecules. `ours_success` is rebuilt from scratch here. Without the reset,
# every molecule would be appended a second time on top of the list built by
# the previous cell, and the list saved afterwards would contain duplicates.
ours_success = []
vinascore = []
vinamin = []
vinadock = []
se = []
sa = []
clash = []
qed = []
rmsd = []
for mol in ours.flat_results:  # `ours` is the same object as models[7]
    if not (mol['complete'] and mol['validity']):
        continue
    vinascore.append(mol['vina']['score_only'][0]['affinity'])
    vinamin.append(mol['vina']['minimize'][0]['affinity'])
    vinadock.append(mol['vina']['dock'][0]['affinity'])
    se.append(mol['pose_check']['strain'])
    if 'clash' in mol['pose_check']:
        clash.append(mol['pose_check']['clash'])
    sa.append(mol['chem_results']['sa'])
    qed.append(mol['chem_results']['qed'])
    if 'rmsd' in mol:
        rmsd.append(mol['rmsd'])

    ours_success.append(mol)

# Mean and median for each Vina mode.
print(np.mean(vinascore), np.median(vinascore))
print(np.mean(vinamin), np.median(vinamin))
print(np.mean(vinadock), np.median(vinadock))

In [None]:
# Persist the collected molecules to success_ours.pt in the current working directory.
torch.save(ours_success, "success_ours.pt")

In [None]:
# Round-trip check: reload the file written above and print how many entries it holds.
# NOTE(review): torch.load unpickles arbitrary objects; only load files you
# produced yourself. Newer PyTorch releases default to weights_only=True,
# which may reject this non-tensor payload - confirm against the installed version.
abc = torch.load("success_ours.pt")
print(len(abc))

In [None]:
# In-memory count, for comparison with the reloaded file's length.
print(len(ours_success))

In [None]:
# Spot-check: display the 'pose_check' entry of one arbitrary result (index 300).
ours.flat_results[300]['pose_check']

# 散点图

In [None]:
# Generation_success = [9036, 8292, 9667, 9880]
# times = [788, 1096, 74, 74]
# times = 100 / times
# names = ["TargetDiff", "DecompDiff", "MolCRAFT", "MolSolver"]

In [None]:
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

# Hard-coded inputs for the efficiency figure.
Generation_success = [0.9036, 0.8292, 0.9667, 0.9889]
times_original = [788, 1096, 74, 59]
# Plotted as "molecules per second"; presumably each entry of times_original
# is the time taken for 100 molecules - TODO confirm.
times = [100 / t for t in times_original]
names = ["TargetDiff", "DecompDiff", "MolCRAFT", "Ours"]

# Distinct colours and marker shapes (circle, square, diamond, star).
colors = ["#4E79A7", "#F28E2B", "#E15759", "#76B7B2"]
markers = ['o', 's', 'D', '*']

fig, ax = plt.subplots(figsize=(6, 4))

# One scatter point and one matching legend handle per model. "Ours" gets a
# larger marker in both places.
legend_handles = []
for x, y, name, color, marker in zip(Generation_success, times, names, colors, markers):
    is_ours = name == "Ours"
    ax.scatter(x, y, s=340 if is_ours else 150, color=color, marker=marker,
               edgecolor='k', alpha=0.8)
    legend_handles.append(
        Line2D([0], [0], marker=marker, color='w', label=name,
               markerfacecolor=color, markersize=18 if is_ours else 12,
               markeredgecolor='k')
    )

# Times New Roman (12 pt) for the legend entries and the axis labels.
font_path = "/usr/share/fonts/truetype/msttcorefonts/times.ttf"
tnr = fm.FontProperties(fname=font_path, size=12)
ax.legend(handles=legend_handles, title="Models", loc="upper left",
          frameon=True, fancybox=True, shadow=True, prop=tnr)

ax.grid(True, linestyle='--', alpha=0.5)
ax.set_xlabel("Success Rate", fontproperties=tnr)
ax.set_ylabel("Generated Molecules Per Second", fontproperties=tnr)

fig.savefig("efficiency.png", dpi=600, bbox_inches='tight')


In [None]:
# Hard-coded generation-success rates, one bar each.
proportions = [0.9295, 0.9831, 0.9985, 0.9036, 0.8292, 0.9667, 0.9979]
fig, ax = plt.subplots(figsize=(5.45,4.15))
bars = ax.bar(range(len(proportions)), proportions)

# One evenly spaced "Pastel1" colour per bar.
cmap = plt.get_cmap("Pastel1")
colors = cmap(np.linspace(0, 1, len(proportions)))
for rect, rect_color in zip(bars, colors):
    rect.set_color(rect_color)

ax.set_ylabel("Generation Success")
ax.set_xticks([])  # remove all x tick marks
plt.show()

In [None]:
# 初始化各环大小的计数器
# Count rings by size (3 to 9 atoms) over all molecules in `ours_success`.
# Every ring adds to `total_rings`, including sizes outside that range.
ring_counts = dict.fromkeys(range(3, 10), 0)
total_rings = 0
for r in ours_success:
    # Sizes of the rings RDKit reports for this molecule.
    ring_lengths = [len(ring) for ring in Chem.GetSymmSSSR(r['mol'])]
    total_rings += len(ring_lengths)
    for ring_size in ring_lengths:
        if ring_size in ring_counts:
            ring_counts[ring_size] += 1

# Share of each tracked ring size among all rings (0 when there are no rings).
proportions = {
    size: (ring_counts[size] / total_rings if total_rings > 0 else 0)
    for size in ring_counts
}


# Ring

In [None]:
# 初始化各环大小的计数器
# For each ring size from 3 to 8, count the MolCRAFT molecules that contain at
# least one ring of that size. A molecule counts once per size.
ring_counts = dict.fromkeys(range(3, 9), 0)
total_rings = 0
statistic_now = molcraft.flat_results
for r in statistic_now:
    # Distinct tracked ring sizes present in this molecule.
    sizes_present = {len(ring) for ring in Chem.GetSymmSSSR(r['mol'])} & ring_counts.keys()
    for size in sizes_present:
        ring_counts[size] += 1
    total_rings += len(sizes_present)

# Normalise by the number of molecules, not by the number of rings.
normalized_ring_counts = {
    size: round(count / len(statistic_now), 4)
    for size, count in ring_counts.items()
}
print(normalized_ring_counts)

In [None]:
# ---------- ① 注册并设定全局字体 ----------
# ---------- Register the font file and make it the global default ----------
font_path = "/usr/share/fonts/truetype/msttcorefonts/times.ttf"  # Times New Roman
# Register the file with matplotlib's font manager so that the family name set
# in rcParams below can be resolved, even when the font is not installed system-wide.
fm.fontManager.addfont(font_path)
font_prop = fm.FontProperties(fname=font_path, size=11)
plt.rcParams['font.family'] = font_prop.get_name()   # global default font
plt.rcParams['mathtext.fontset'] = 'cm'              # math text keeps Computer Modern
# plt.rcParams['font.size'] = 11          # default font size
# plt.rcParams['axes.titlesize'] = 12     # subplot titles
# plt.rcParams['axes.labelsize'] = 11     # axis labels
# plt.rcParams['xtick.labelsize'] = 9
# plt.rcParams['ytick.labelsize'] = 9

In [None]:
import seaborn as sns

In [None]:
# ------------------ 统计，同之前 ------------------
# Per model and ring size (3 to 8): percentage of molecules that contain at
# least one ring of that size.
ring_sizes = range(3, 9)
ring_prop_per_model = {}

for model in models:
    ring_counts = dict.fromkeys(ring_sizes, 0)
    for entry in model.flat_results:
        # Each tracked size counts once per molecule.
        sizes_in_mol = {len(r) for r in Chem.GetSymmSSSR(entry['mol'])}
        for size in sizes_in_mol.intersection(ring_sizes):
            ring_counts[size] += 1
    n_mols = len(model.flat_results)
    ring_prop_per_model[model.name] = {
        s: 100 * ring_counts[s] / n_mols for s in ring_sizes
    }

In [None]:
n_models    = len(ring_prop_per_model)
model_names = list(ring_prop_per_model.keys())
colors      = sns.color_palette("colorblind", n_models)   # one colour per model

# 2 x 3 grid: one panel per ring size, sharing the y axis.
fig, axes = plt.subplots(2, 3, figsize=(10, 7), sharey=True)
fig.subplots_adjust(wspace=0.35, hspace=0.35)
axes = axes.flatten()

proxy_bars = None  # bars of the first panel, reused as legend handles

for idx, size in enumerate(ring_sizes):
    ax    = axes[idx]
    props = [ring_prop_per_model[m][size] for m in model_names]
    bars  = ax.bar(range(n_models), props,
                   color=colors, edgecolor='black')

    if proxy_bars is None:
        proxy_bars = bars

    # Keep the tick positions but blank the labels; the legend names the models.
    ax.set_xticks(range(n_models))
    ax.set_xticklabels([''] * len(model_names))

    ax.set_ylabel('Proportion(%)')
    ax.set_title(f'{size}-atom ring')

    # Print the value above each bar.
    for rect, val in zip(bars, props):
        ax.text(rect.get_x() + rect.get_width()/2, val,
                f'{val:.2f}', ha='center', va='bottom', fontsize=9)

# Single shared legend, vertically centred just right of the figure.
fig.legend(proxy_bars, model_names,
           loc='center left',
           bbox_to_anchor=(1, 0.5),
           borderaxespad=0.,
           title='Models',
           labelspacing=0.4,
           handlelength=1.2,
           handletextpad=0.5,
           )

# The rect spans the whole figure, so no space is reserved for the legend.
# The saved image is cropped with bbox_inches='tight' instead.
fig.tight_layout(rect=[0, 0, 1, 1])
fig.savefig('ring_stats.png', dpi=1000, bbox_inches='tight')
plt.show()

In [None]:
import pandas as pd
# import matplotlib.pyplot as plt

def calc_element_ratios(models, elements=("C", "N", "O")) -> pd.DataFrame:
    """Compute, per model, the fraction of atoms that belong to each element.

    Args:
        models: iterable of objects exposing ``name`` and ``flat_results``.
            Every entry of ``flat_results`` is a mapping whose ``'mol'`` value
            provides ``GetAtoms()``, and every atom provides ``GetSymbol()``.
        elements: element symbols to report; one output column per symbol.

    Returns:
        DataFrame indexed by model name (index name ``"name"``, sorted), with
        one column per element holding a ratio in [0, 1]. The denominator is
        the number of all atoms seen for that model, not only the tracked
        elements. A model without atoms gets 0 in every column. An empty
        ``models`` yields an empty frame with the same columns instead of
        raising KeyError.
    """
    rows = []
    for model in models:
        elem_count = {e: 0 for e in elements}
        total_atoms = 0

        for res in model.flat_results:
            for atom in res['mol'].GetAtoms():
                symbol = atom.GetSymbol()
                if symbol in elem_count:
                    elem_count[symbol] += 1
                total_atoms += 1

        # ratio = atoms of this element / all atoms of the model
        ratios = {e: (elem_count[e] / total_atoms if total_atoms else 0)
                  for e in elements}
        ratios["name"] = model.name
        rows.append(ratios)

    # Passing the columns explicitly keeps "name" present even when `rows` is
    # empty, so set_index does not fail on empty input.
    columns = list(dict.fromkeys([*elements, "name"]))
    df = pd.DataFrame(rows, columns=columns).set_index("name").sort_index()
    return df


def plot_element_ratios(df: pd.DataFrame):
    """Draw one bar chart per column of `df`, side by side, and show the figure.

    Args:
        df: output of calc_element_ratios(): index = model names,
            one column of ratios per element.
    """
    elements = df.columns
    n_elements = len(elements)

    # squeeze=False always yields a 2-D axes array, so a single-column frame
    # needs no special handling.
    fig, axes = plt.subplots(
        1, n_elements,
        figsize=(5 * n_elements, 4),
        sharey=True,
        dpi=110,
        squeeze=False,
    )

    for idx, (elem, ax) in enumerate(zip(elements, axes[0])):
        bars = ax.bar(df.index, df[elem])

        # Percentage label just above each bar.
        for rect in bars:
            top = rect.get_height()
            ax.annotate(f"{top:.1%}",
                        xy=(rect.get_x() + rect.get_width() / 2, top),
                        xytext=(0, 3),
                        textcoords="offset points",
                        ha="center", va="bottom", fontsize=8)

        ax.set_title(f"{elem} 原子比例")
        ax.set_ylabel("比例")
        ax.set_ylim(0, 1)

        # The legend is drawn on the first panel only.
        if idx == 0:
            ax.legend(bars, df.index, title="模型")

    fig.suptitle("各模型中 C / N / O 原子所占比例", fontsize=14, y=1.03)
    fig.tight_layout()
    plt.show()


# ===== 主流程 =====
# Compute the per-model C / N / O atom ratios and plot them.
df_ratios = calc_element_ratios(models)
plot_element_ratios(df_ratios)

# Train set

In [None]:
# train12 = torch.load("/data/mntdata/aaa/molsolver/data/crossdocked_v1.1_rmsd1.0_pocket10_add_aromatic_transformed_simple.pt")

In [None]:
# print(type(train12))
# print("hello")

# Ablation Study

In [None]:
# Load the ablation result files from `eval_path`.
# NOTE(review): torch.load unpickles arbitrary objects; only load trusted
# files. Newer PyTorch releases default to weights_only=True, which may reject
# non-tensor payloads - confirm against the installed version.
exp0, exp1, exp2, steps80 = (
    torch.load(os.path.join(eval_path, result_file))
    for result_file in (
        "exp0_docked.pt",
        "exp1_docked.pt",
        "exp2_docked.pt",
        "steps80_success.pt",
    )
)

In [None]:
# Summary metrics for "Ours" over complete and valid molecules.
ours_success = []  # reset only; this cell does not collect molecules into it
vinascore, vinamin, vinadock = [], [], []
se, sa, clash, qed, rmsd = [], [], [], [], []

for mol in ours.flat_results:
    if not (mol['complete'] and mol['validity']):
        continue
    vina_res = mol['vina']
    vinascore.append(vina_res['score_only'][0]['affinity'])
    vinamin.append(vina_res['minimize'][0]['affinity'])
    vinadock.append(vina_res['dock'][0]['affinity'])

    pose_res = mol['pose_check']
    if not math.isnan(pose_res['strain']):  # NaN strain is excluded from the quantiles
        se.append(pose_res['strain'])
    if 'clash' in pose_res:
        clash.append(pose_res['clash'])

    chem_res = mol['chem_results']
    sa.append(chem_res['sa'])
    qed.append(chem_res['qed'])

    if 'rmsd' in mol:
        rmsd.append(mol['rmsd'])

# Vina: mean and median for each of the three modes.
for affinities in (vinascore, vinamin, vinadock):
    print(np.mean(affinities), np.median(affinities))
print('strain', np.quantile(se, 0.25), np.median(se), np.quantile(se, 0.75))
print(np.mean(sa))
print(np.mean(clash))
print(np.mean(qed))
# Fraction of molecules with RMSD < 2, and how many had an RMSD at all.
print(np.mean(np.array(rmsd) < 2), len(rmsd))
print(len(ours.flat_results))

In [None]:
# Ablation exp0: summary metrics over complete and valid molecules.
exp0_success = [mol for mol in exp0 if mol['complete'] and mol['validity']]

vinascore = [mol['vina']['score_only'][0]['affinity'] for mol in exp0_success]
vinamin = [mol['vina']['minimize'][0]['affinity'] for mol in exp0_success]
vinadock = []  # docking affinities are not collected for this experiment
# NaN strain values are excluded so the quantiles are defined.
se = [
    strain
    for strain in (mol['pose_check']['strain'] for mol in exp0_success)
    if not math.isnan(strain)
]
clash = [mol['pose_check']['clash'] for mol in exp0_success if 'clash' in mol['pose_check']]
sa = [mol['chem_results']['sa'] for mol in exp0_success]
qed = [mol['chem_results']['qed'] for mol in exp0_success]
rmsd = [mol['rmsd'] for mol in exp0_success if 'rmsd' in mol]

print(np.mean(vinascore), np.median(vinascore))
print(np.mean(vinamin), np.median(vinamin))
print('strain', np.quantile(se, 0.25), np.median(se), np.quantile(se, 0.75))
print(np.mean(sa))
print(np.mean(clash))
print(np.mean(qed))
# Fraction of molecules with RMSD < 2, and how many had an RMSD at all.
print(np.mean(np.array(rmsd) < 2), len(rmsd))
print(len(exp0_success))

In [None]:
# Ablation exp1: summary metrics over complete and valid molecules.
exp1_success = [mol for mol in exp1 if mol['complete'] and mol['validity']]

vinascore = [mol['vina']['score_only'][0]['affinity'] for mol in exp1_success]
vinamin = [mol['vina']['minimize'][0]['affinity'] for mol in exp1_success]
vinadock = []  # docking affinities are not collected for this experiment
# NaN strain values are excluded so the quantiles are defined.
se = [
    strain
    for strain in (mol['pose_check']['strain'] for mol in exp1_success)
    if not math.isnan(strain)
]
clash = [mol['pose_check']['clash'] for mol in exp1_success if 'clash' in mol['pose_check']]
sa = [mol['chem_results']['sa'] for mol in exp1_success]
qed = [mol['chem_results']['qed'] for mol in exp1_success]
rmsd = [mol['rmsd'] for mol in exp1_success if 'rmsd' in mol]

print(np.mean(vinascore), np.median(vinascore))
print(np.mean(vinamin), np.median(vinamin))
print('strain', np.quantile(se, 0.25), np.median(se), np.quantile(se, 0.75))
print(np.mean(sa))
print(np.mean(clash))
print(np.mean(qed))
# Fraction of molecules with RMSD < 2, and how many had an RMSD at all.
print(np.mean(np.array(rmsd) < 2), len(rmsd))
print(len(exp1_success))

In [None]:
# Ablation exp2: summary metrics over complete and valid molecules.
exp2_success = [mol for mol in exp2 if mol['complete'] and mol['validity']]

vinascore = [mol['vina']['score_only'][0]['affinity'] for mol in exp2_success]
vinamin = [mol['vina']['minimize'][0]['affinity'] for mol in exp2_success]
vinadock = []  # docking affinities are not collected for this experiment
# NaN strain values are excluded so the quantiles are defined.
se = [
    strain
    for strain in (mol['pose_check']['strain'] for mol in exp2_success)
    if not math.isnan(strain)
]
clash = [mol['pose_check']['clash'] for mol in exp2_success if 'clash' in mol['pose_check']]
sa = [mol['chem_results']['sa'] for mol in exp2_success]
qed = [mol['chem_results']['qed'] for mol in exp2_success]
rmsd = [mol['rmsd'] for mol in exp2_success if 'rmsd' in mol]

print(np.mean(vinascore), np.median(vinascore))
print(np.mean(vinamin), np.median(vinamin))
print('strain', np.quantile(se, 0.25), np.median(se), np.quantile(se, 0.75))
print(np.mean(sa))
print(np.mean(clash))
print(np.mean(qed))
# Fraction of molecules with RMSD < 2, and how many had an RMSD at all.
print(np.mean(np.array(rmsd) < 2), len(rmsd))
print(len(exp2_success))

# STEPS 分析

In [None]:
# Hard-coded results of the sampling-steps sweep.
steps = [10, 20, 40, 80, 100, 200]
vina_score = [-1.94, -5.86, -6.34, -6.39, -6.35, -6.31]
SA = [0.529, 0.608, 0.647, 0.664, 0.667, 0.675]
QED = [0.536, 0.568, 0.575, 0.575, 0.578, 0.578]
# SE_50 = [254, 359, 247, 212, 209, 196]
completeness = [0.9732, 0.9872, 0.9892, 0.9889, 0.9879, 0.9879]

metrics = {
    'Vina Score': vina_score,
    'SA': SA,
    'QED': QED,
    'Completeness': completeness
}
from matplotlib.ticker import MaxNLocator
# Times New Roman, 14 pt, for every figure from here on.
plt.rcParams.update({'font.family': 'Times New Roman', 'font.size': 14})

# One figure per metric. Only the first metric (Vina Score) is
# "lower is better"; all others get an up arrow.
for cnt, (name, values) in enumerate(metrics.items()):
    arrow = '↓' if cnt == 0 else '↑'
    plt.figure(figsize=(6,4))
    plt.plot(steps, values, marker='o')
    plt.xlabel('Steps', fontsize=20)
    plt.title(f'{name} ({arrow})', fontsize=20)
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(f"{name}_vs_steps.png", dpi=500)
    plt.show()
cnt = len(metrics)  # same final counter value as a manual increment per metric

In [None]:
# Hard-coded results of the sampling-steps sweep.
steps = [10, 20, 40, 80, 100, 200]
vina_score = [-1.94, -5.86, -6.34, -6.39, -6.35, -6.31]
SA = [0.529, 0.608, 0.647, 0.664, 0.667, 0.675]
QED = [0.536, 0.568, 0.575, 0.575, 0.578, 0.578]
# SE_50 = [254, 359, 247, 212, 209, 196]
completeness = [0.9732, 0.9872, 0.9892, 0.9889, 0.9879, 0.9879]

metrics = {
    'Vina Score': vina_score,
    'SA': SA,
    'QED': QED,
    'Completeness': completeness
}
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
import numpy as np

plt.rcParams.update({'font.family': 'Times New Roman', 'font.size': 14})

# One figure per metric, in a fixed order. The files written here have the
# same names as those of the previous sweep cell and replace them.
order = ['Vina Score', 'SA', 'QED', 'Completeness']
for name in order:
    values = metrics[name]
    # Vina Score is the only "lower is better" metric.
    if name == 'Vina Score':
        heading = f'{name} (↓)'
    else:
        heading = f'{name} (↑)'

    fig, ax = plt.subplots(figsize=(6, 4))
    ax.plot(steps, values, marker='o')

    # Limit both axes to MaxNLocator(nbins=6) major ticks so the grid stays sparse.
    for axis in (ax.yaxis, ax.xaxis):
        axis.set_major_locator(MaxNLocator(nbins=6))

    # Grid lines on major ticks only.
    ax.grid(True, which='major')

    ax.set_xlabel('Steps', fontsize=20)
    ax.set_title(heading, fontsize=20)
    fig.tight_layout()
    fig.savefig(f'{name}_vs_steps.png', dpi=500)
    plt.show()


## 原子分布图

## 计算分子重心

## 绘制分子