In [1]:
import os
import json
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data, Batch, Dataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from matplotlib import cm
import seaborn as sns
from scipy.stats import gaussian_kde
import warnings

# Ignore unnecessary warnings.
# NOTE(review): with no category/module arguments this filter suppresses every
# warning for the rest of the process, including ones that may matter while debugging.
warnings.filterwarnings('ignore')

# ==========================================
# 1. 全局配置
# ==========================================
class Config:
    """Central settings for this run: file locations, target columns and model sizes."""

    # --- Input locations ---
    BASE_PATH = r"C:\机器学习材料\季鏻"
    DATASET_PATH = os.path.join(BASE_PATH, "MEL数据集.xlsx")
    PROCESSED_CACHE_PATH = os.path.join(BASE_PATH, "cached_graphs_box64_cleaned.pt")

    # --- Model checkpoint (relative to the current working directory) ---
    MODEL_PATH = "zeolite_3d_gnn_enriched_cleaned.pth"

    # --- Output directory for figures and the metrics CSV ---
    SAVE_DIR = r"C:\Users\admin\Energymodel\2-9"

    # --- Data loading ---
    BATCH_SIZE = 64
    NUM_WORKERS = 0
    MIN_SAMPLES_PER_TOPO = 0  # topologies with fewer rows than this are dropped

    # --- Regression targets; the order fixes the column order of the model output ---
    TARGET_COLS = [
        'Binding Energy (kJ/mol Si)',
        'Directivity Energy (kJ/mol Si)',
        'Competition Energy (kJ/mol Si)',
        'Binding Energy (kJ/mol OSDA)',
        'Competition Energy (kJ/mol OSDA)',
    ]

    # --- Network dimensions ---
    HIDDEN_DIM = 128
    ATOM_EMBEDDING_DIM = 64
    EMB_DIM_DEGREE = 8
    EMB_DIM_CHARGE = 8
    EMB_DIM_HYB = 8
    EMB_DIM_AROMATIC = 4
    EMB_DIM_CHIRAL = 4

    # --- Voxel grid: VOXEL_SIZE cells per axis, each VOXEL_RES wide ---
    VOXEL_SIZE = 64
    VOXEL_RES = 0.5
    SIGMA = 0.5

# Create the output directory on first use.
# exist_ok=True keeps this safe if the directory appears between the check and
# the call; checking isdir (not exists) makes a stray *file* at that path fail
# here rather than later when figures are saved.
if not os.path.isdir(Config.SAVE_DIR):
    os.makedirs(Config.SAVE_DIR, exist_ok=True)
    print(f"创建输出目录: {Config.SAVE_DIR}")

# ==========================================
# 2. 辅助函数
# ==========================================
def coords_to_voxel(coords, grid_size=32, res=0.5, sigma=0.5):
    """Rasterize 3D points into a binary occupancy grid centred on the origin.

    A point is used only if it lies strictly inside ``(-limit, limit)`` on all
    three axes, where ``limit = grid_size * res / 2``. Each such point marks its
    own voxel and the adjacent voxels (a 3x3x3 block, truncated at the grid
    faces) as occupied.

    Args:
        coords: Array-like of shape (N, 3) holding x, y, z in the same length
            unit as ``res``. A single point of shape (3,) and empty input are
            also accepted.
        grid_size: Number of voxels along each axis.
        res: Edge length of one voxel.
        sigma: Unused by this implementation; kept so existing call sites that
            pass it keep working.

    Returns:
        ``np.float32`` array of shape ``(grid_size, grid_size, grid_size)``
        containing 0.0 for empty and 1.0 for occupied voxels.

    Raises:
        ValueError: If non-empty ``coords`` cannot be read as (N, 3).
    """
    grid = np.zeros((grid_size, grid_size, grid_size), dtype=np.float32)

    # Keep the caller's dtype (no forced cast) so binning is unchanged for
    # float32 inputs; only the container type is normalised.
    coords = np.asarray(coords)
    if coords.size == 0:
        return grid
    coords = np.atleast_2d(coords)
    if coords.ndim != 2 or coords.shape[1] != 3:
        raise ValueError(f"coords must have shape (N, 3), got {coords.shape}")

    limit = (grid_size * res) / 2.0
    inside = np.all((coords > -limit) & (coords < limit), axis=1)
    if not inside.any():
        return grid

    # Shift into [0, 2*limit) and truncate to voxel indices; the clip guards
    # against rounding pushing a point just below +limit into index grid_size.
    indices = ((coords[inside] + limit) / res).astype(int)
    indices = np.clip(indices, 0, grid_size - 1)

    # Occupancy is binary, so overlapping blocks are simply set to 1.0; visiting
    # each distinct voxel once avoids redundant writes for crowded points.
    # Slice upper bounds past the grid edge are truncated by NumPy.
    for x, y, z in np.unique(indices, axis=0):
        grid[max(0, x - 1):x + 2, max(0, y - 1):y + 2, max(0, z - 1):z + 2] = 1.0
    return grid

# ==========================================
# 3. 数据集定义
# ==========================================
class ZeoliteDataset(Dataset):
    """Samples pairing a molecule graph, a zeolite graph, a voxel grid and targets.

    Rows of ``df`` are kept only when their ``CID`` is in
    ``cache_data['mol_cache']``, their ``Topology Code`` is in
    ``cache_data['zeo_cache']`` and none of ``Config.TARGET_COLS`` is NaN.

    Args:
        df: DataFrame with ``CID``, ``Topology Code`` and ``Config.TARGET_COLS``.
        cache_data: Dict with ``'mol_cache'`` and ``'zeo_cache'`` mappings to
            pre-built graph objects.
        target_scaler: Scaler for the targets; a new ``StandardScaler`` is
            created when omitted.
        props_scaler: Scaler for the molecules' ``global_attr``; a new
            ``StandardScaler`` is created when omitted.
        is_train: When True, both scalers are fitted on this dataset.

    Raises:
        ValueError: If ``is_train`` is True and no row survives filtering.
    """

    def __init__(self, df, cache_data, target_scaler=None, props_scaler=None, is_train=False):
        super().__init__()
        # Compare against None explicitly so a caller-supplied scaler is never
        # replaced just because it happens to evaluate as falsy.
        self.target_scaler = target_scaler if target_scaler is not None else StandardScaler()
        self.props_scaler = props_scaler if props_scaler is not None else StandardScaler()
        self.is_train = is_train

        mol_cache = cache_data['mol_cache']
        zeo_cache = cache_data['zeo_cache']

        self.mol_list = []
        self.zeo_list = []
        raw_y_list = []

        for _, row in df.iterrows():
            cid = row['CID']
            topo = row['Topology Code']
            if cid not in mol_cache or topo not in zeo_cache:
                continue
            targets = row[Config.TARGET_COLS].values.astype(float)
            if np.isnan(targets).any():
                continue
            self.mol_list.append(mol_cache[cid])
            self.zeo_list.append(zeo_cache[topo])
            raw_y_list.append(targets)

        self.length = len(self.mol_list)

        # Keep a 2D (n_samples, n_targets) array even with zero samples, so an
        # empty evaluation split yields an empty dataset instead of a shape
        # error from the scaler.
        y_all = np.asarray(raw_y_list, dtype=float).reshape(-1, len(Config.TARGET_COLS))
        if is_train:
            if self.length == 0:
                raise ValueError(
                    "No usable training rows: every row was missing from the "
                    "caches or had NaN targets."
                )
            y_norm = self.target_scaler.fit_transform(y_all)
        elif self.length > 0 and hasattr(self.target_scaler, 'mean_'):
            y_norm = self.target_scaler.transform(y_all)
        else:
            # Unfitted scaler (or nothing to transform): targets stay in raw units.
            y_norm = y_all

        self.y_list = [torch.tensor(y, dtype=torch.float) for y in y_norm]

        # The property matrix is only needed to fit the scaler, so it is built
        # for training datasets only.
        if is_train:
            all_props = torch.cat([m.global_attr for m in self.mol_list], dim=0).numpy()
            self.props_scaler.fit(all_props)

    def __len__(self):
        """Return the number of samples that survived filtering."""
        return self.length

    def __getitem__(self, idx):
        """Build one sample.

        Returns:
            Tuple ``(mol_data, zeo_data, voxel_tensor, y)``. ``voxel_tensor``
            stacks the molecule grid (channel 0) and the zeolite grid
            (channel 1); ``y`` is the target vector stored at construction.
        """
        # Clone so the per-sample edits below never touch the shared cache objects.
        mol_data = self.mol_list[idx].clone()
        zeo_data = self.zeo_list[idx].clone()
        y = self.y_list[idx]

        # Only a fitted scaler has `mean_`; otherwise properties pass through raw.
        if hasattr(self.props_scaler, 'mean_'):
            props_norm = self.props_scaler.transform(mol_data.global_attr.numpy())
            mol_data.global_attr = torch.tensor(props_norm, dtype=torch.float)

        # Prefer the first stored coordinate variant when present. The extra
        # attribute is then removed from the clone — presumably so it is not
        # carried into batching; TODO confirm.
        if hasattr(mol_data, 'pos_variants'):
            mol_coords = mol_data.pos_variants[0].numpy()
            del mol_data.pos_variants
        else:
            mol_coords = mol_data.pos.numpy()

        # `pos_super` (when present) is used for the voxel grid instead of `pos`.
        if hasattr(zeo_data, 'pos_super'):
            zeo_voxel_coords = zeo_data.pos_super.numpy()
            del zeo_data.pos_super
        else:
            zeo_voxel_coords = zeo_data.pos.numpy()

        mol_data.pos = torch.tensor(mol_coords, dtype=torch.float)

        grid_mol = coords_to_voxel(mol_coords, Config.VOXEL_SIZE, Config.VOXEL_RES, Config.SIGMA)
        grid_zeo = coords_to_voxel(zeo_voxel_coords, Config.VOXEL_SIZE, Config.VOXEL_RES, Config.SIGMA)

        voxel_tensor = torch.tensor(np.stack([grid_mol, grid_zeo], axis=0), dtype=torch.float)

        return mol_data, zeo_data, voxel_tensor, y

    @staticmethod
    def gpu_collate(batch):
        """Collate ``__getitem__`` tuples into batched graphs and stacked tensors.

        Despite the name, nothing here touches a device; it only groups the
        four fields of each sample and batches them.

        Returns:
            Tuple ``(mol_batch, zeo_batch, voxels, targets)``.
        """
        mol_list = [item[0] for item in batch]
        zeo_list = [item[1] for item in batch]
        voxel_list = [item[2] for item in batch]
        y_list = [item[3] for item in batch]
        return (
            Batch.from_data_list(mol_list),
            Batch.from_data_list(zeo_list),
            torch.stack(voxel_list),
            torch.stack(y_list),
        )

# ==========================================
# 4. 模型架构
# ==========================================
class Voxel3DCNN(nn.Module):
    """3D CNN that encodes a two-channel cubic voxel grid into a 128-d vector.

    Three conv/batch-norm/ReLU/max-pool stages each halve the spatial size;
    the result is flattened and projected by a linear layer.

    Args:
        voxel_size: Edge length (in voxels) of the cubic input grid. The
            default of 64 reproduces the original fixed ``64 * 8 * 8 * 8``
            linear input, so existing checkpoints keep loading.

    Raises:
        ValueError: If ``voxel_size`` is smaller than 8 (nothing would remain
            after three 2x poolings).
    """

    def __init__(self, voxel_size=64):
        super().__init__()
        # Three MaxPool3d(2) layers floor-divide the edge length by 8 overall.
        pooled = voxel_size // 8
        if pooled < 1:
            raise ValueError(f"voxel_size must be at least 8, got {voxel_size}")

        self.conv1 = nn.Conv3d(2, 16, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm3d(16)
        self.pool1 = nn.MaxPool3d(2)
        self.conv2 = nn.Conv3d(16, 32, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm3d(32)
        self.pool2 = nn.MaxPool3d(2)
        self.conv3 = nn.Conv3d(32, 64, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm3d(64)
        self.pool3 = nn.MaxPool3d(2)
        self.fc = nn.Linear(64 * pooled ** 3, 128)

    def forward(self, x):
        """Encode ``x`` of shape (batch, 2, D, D, D) into (batch, 128) features."""
        x = self.pool1(F.relu(self.bn1(self.conv1(x))))
        x = self.pool2(F.relu(self.bn2(self.conv2(x))))
        x = self.pool3(F.relu(self.bn3(self.conv3(x))))
        x = x.flatten(1)  # keep the batch axis, merge channels and space
        return F.relu(self.fc(x))

class DualBranchGNN(nn.Module):
    """Four-way fusion regressor for molecule/zeolite pairs.

    Combines pooled GCN features of the molecule graph, pooled GCN features of
    the zeolite graph, an encoding of the molecule's global attributes and the
    voxel CNN features, then maps the concatenation to one output per entry of
    ``Config.TARGET_COLS``.
    """

    def __init__(self):
        super().__init__()
        cfg = Config

        # Categorical node features, one embedding table per feature column.
        self.emb_atom = nn.Embedding(120, cfg.ATOM_EMBEDDING_DIM)
        self.emb_degree = nn.Embedding(12, cfg.EMB_DIM_DEGREE)
        self.emb_charge = nn.Embedding(15, cfg.EMB_DIM_CHARGE)
        self.emb_hyb = nn.Embedding(8, cfg.EMB_DIM_HYB)
        self.emb_aromatic = nn.Embedding(2, cfg.EMB_DIM_AROMATIC)
        self.emb_chiral = nn.Embedding(4, cfg.EMB_DIM_CHIRAL)

        total_emb_dim = sum((
            cfg.ATOM_EMBEDDING_DIM,
            cfg.EMB_DIM_DEGREE,
            cfg.EMB_DIM_CHARGE,
            cfg.EMB_DIM_HYB,
            cfg.EMB_DIM_AROMATIC,
            cfg.EMB_DIM_CHIRAL,
        ))

        # The molecule branch takes one extra input column: `x_charge` is
        # concatenated to the embeddings in forward().
        self.mol_conv1 = GCNConv(total_emb_dim + 1, cfg.HIDDEN_DIM)
        self.mol_conv2 = GCNConv(cfg.HIDDEN_DIM, cfg.HIDDEN_DIM)
        self.zeo_conv1 = GCNConv(total_emb_dim, cfg.HIDDEN_DIM)
        self.zeo_conv2 = GCNConv(cfg.HIDDEN_DIM, cfg.HIDDEN_DIM)
        self.voxel_cnn = Voxel3DCNN()

        # Width of the molecule-level attribute vector fed to the encoder.
        self.global_feat_dim = 17
        self.global_encoder = nn.Sequential(
            nn.Linear(self.global_feat_dim, 64),
            nn.ReLU(),
            nn.Linear(64, cfg.HIDDEN_DIM),
            nn.BatchNorm1d(cfg.HIDDEN_DIM),
            nn.ReLU(),
        )

        # Four HIDDEN_DIM-wide feature vectors are concatenated before the head.
        fusion_dim = cfg.HIDDEN_DIM * 4
        self.head = nn.Sequential(
            nn.Linear(fusion_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, len(cfg.TARGET_COLS)),
        )

    def _embed_features(self, x_idx):
        """Embed columns 0-5 of the integer node-feature matrix and concatenate them."""
        tables = (
            self.emb_atom,
            self.emb_degree,
            self.emb_charge,
            self.emb_hyb,
            self.emb_aromatic,
            self.emb_chiral,
        )
        embedded = [table(x_idx[:, col]) for col, table in enumerate(tables)]
        return torch.cat(embedded, dim=1)

    def forward(self, mol_batch, zeo_batch, voxel_batch):
        """Predict all targets for a batch.

        Args:
            mol_batch: Batched molecule graphs providing ``x``, ``x_charge``,
                ``edge_index``, ``edge_weight``, ``batch`` and ``global_attr``.
            zeo_batch: Batched zeolite graphs providing ``x``, ``edge_index``
                and ``batch``.
            voxel_batch: Voxel tensor passed to the voxel CNN.

        Returns:
            Tensor with one row per sample and one column per target.
        """
        # Molecule branch: embeddings plus the extra charge column, weighted edges.
        mol_edges = mol_batch.edge_index
        mol_weights = mol_batch.edge_weight
        h_mol = torch.cat([self._embed_features(mol_batch.x), mol_batch.x_charge], dim=1)
        h_mol = F.relu(self.mol_conv1(h_mol, mol_edges, edge_weight=mol_weights))
        h_mol = F.relu(self.mol_conv2(h_mol, mol_edges, edge_weight=mol_weights))
        feat_m = global_mean_pool(h_mol, mol_batch.batch)

        # Zeolite branch: embeddings only, unweighted edges.
        zeo_edges = zeo_batch.edge_index
        h_zeo = self._embed_features(zeo_batch.x)
        h_zeo = F.relu(self.zeo_conv1(h_zeo, zeo_edges))
        h_zeo = F.relu(self.zeo_conv2(h_zeo, zeo_edges))
        feat_z = global_mean_pool(h_zeo, zeo_batch.batch)

        # Voxel branch.
        feat_v = self.voxel_cnn(voxel_batch)

        # Global-attribute branch; a 3-D tensor has its middle axis squeezed.
        global_attr = mol_batch.global_attr
        if global_attr.dim() == 3:
            global_attr = global_attr.squeeze(1)
        feat_global = self.global_encoder(global_attr)

        fused = torch.cat([feat_m, feat_z, feat_global, feat_v], dim=1)
        return self.head(fused)

# ==========================================
# 5. 加载数据与模型
# ==========================================
print(">>> 1. 正在加载数据集与缓存...")
if not os.path.exists(Config.DATASET_PATH) or not os.path.exists(Config.PROCESSED_CACHE_PATH):
    raise FileNotFoundError("数据文件或缓存文件缺失，请检查路径。")

df = pd.read_excel(Config.DATASET_PATH, engine='openpyxl')

# Drop topologies that appear in fewer than MIN_SAMPLES_PER_TOPO rows.
topo_counts = df['Topology Code'].value_counts()
valid_topos = topo_counts[topo_counts >= Config.MIN_SAMPLES_PER_TOPO].index
df_filtered = df[df['Topology Code'].isin(valid_topos)].reset_index(drop=True)

# SECURITY NOTE: weights_only=False unpickles arbitrary Python objects.
# Only load cache files that were produced locally / come from a trusted source.
try:
    cache_data = torch.load(Config.PROCESSED_CACHE_PATH, weights_only=False)
except Exception as e:
    print(f"    缓存加载失败: {e}")
    # Re-raise instead of calling exit(): the traceback is preserved and the
    # behaviour is well defined inside a notebook kernel as well as a script.
    raise
else:
    print(f"    缓存加载成功。包含分子: {len(cache_data['mol_cache'])}, 沸石: {len(cache_data['zeo_cache'])}")

print(">>> 2. 正在重建数据划分 (Random State = 42)...")
# 80/10/10 split over row positions with a fixed seed.
# NOTE(review): this reproduces the training-time split only if the spreadsheet
# and the filtering above are unchanged since training — confirm.
indices = list(range(len(df_filtered)))
train_idx, temp_idx = train_test_split(indices, train_size=0.8, random_state=42)
val_idx, test_idx = train_test_split(temp_idx, train_size=0.5, random_state=42)

print(">>> 3. 拟合 Scaler 并准备测试集...")
# The training dataset is built only to fit the target/property scalers that
# the test dataset then reuses.
train_dataset_dummy = ZeoliteDataset(
    df_filtered.iloc[train_idx].reset_index(drop=True),
    cache_data=cache_data,
    is_train=True,
)
test_dataset = ZeoliteDataset(
    df_filtered.iloc[test_idx].reset_index(drop=True),
    cache_data=cache_data,
    target_scaler=train_dataset_dummy.target_scaler,
    props_scaler=train_dataset_dummy.props_scaler,
    is_train=False,
)

# torch_geometric.loader.DataLoader discards a user-supplied `collate_fn` and
# installs its own collater, so the plain PyTorch loader is used here to make
# ZeoliteDataset.gpu_collate actually take effect.
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=Config.BATCH_SIZE,
    shuffle=False,
    num_workers=Config.NUM_WORKERS,
    collate_fn=ZeoliteDataset.gpu_collate,
)

print(f">>> 4. 加载模型权重: {Config.MODEL_PATH}")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = DualBranchGNN().to(device)
if not os.path.exists(Config.MODEL_PATH):
    raise FileNotFoundError(f"找不到模型文件: {Config.MODEL_PATH}")

# A state dict contains only tensors, so the restricted (safe) unpickler suffices.
state_dict = torch.load(Config.MODEL_PATH, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
print("    模型权重加载成功！")

# Inference mode: disables dropout and makes batch norm use running statistics.
model.eval()

# ==========================================
# 6. 推理与评估
# ==========================================
print(">>> 5. 开始测试集推理...")
preds_list, targets_list = [], []

# Gradients are not needed for evaluation.
with torch.no_grad():
    for batch in test_loader:
        mol, zeo, voxel, y = (part.to(device) for part in batch)
        output = model(mol, zeo, voxel)
        preds_list.append(output.cpu().numpy())
        targets_list.append(y.cpu().numpy())

# Join the per-batch (batch, n_targets) arrays along the sample axis.
preds_norm = np.concatenate(preds_list, axis=0)
targets_norm = np.concatenate(targets_list, axis=0)

# Undo the target scaling fitted on the training split so that predictions and
# ground truth are back in the original units.
to_original_units = train_dataset_dummy.target_scaler.inverse_transform
y_pred = to_original_units(preds_norm)
y_true = to_original_units(targets_norm)

# ==========================================
# 7. 绘图与分析 (NCS 风格优化版)
# ==========================================
print(">>> 6. 生成高标准分析图表...")

# Global Matplotlib styling shared by every figure produced below.
plt.rcParams.update({
    'font.size': 18,
    'font.family': 'sans-serif',
    'font.sans-serif': ['Arial', 'Helvetica', 'DejaVu Sans'],
    'axes.labelsize': 20,
    'axes.titlesize': 22,
    'xtick.labelsize': 16,
    'ytick.labelsize': 16,
    'legend.fontsize': 16,
    'figure.dpi': 300,
    'axes.linewidth': 1.5,
    'grid.alpha': 0.4
})

# 1. Colormap for the density scatter.
# `matplotlib.cm.get_cmap` was deprecated in Matplotlib 3.7 and removed in 3.9;
# `plt.get_cmap` is available across versions.
magma = plt.get_cmap('magma')
new_colors = magma(np.linspace(0.3, 1, 256))  # start at 30% to skip the near-black end
custom_magma = mcolors.LinearSegmentedColormap.from_list("trunc_magma", new_colors)

# 2. Single colour for the marginal histograms, sampled at 35% of magma — just
# inside the truncated range above, so it matches the darkest scatter colours.
marginal_color = magma(0.35)

def plot_marginal_density(y_t, y_p, title, filename_suffix):
    """Plot predicted vs. reference values with density colouring and marginals.

    Draws a scatter coloured by a Gaussian KDE estimate of point density, an
    identity reference line, marginal histograms with KDE curves and a text box
    with R², MAE and RMSE. The figure is written to
    ``Config.SAVE_DIR/Analysis_<filename_suffix>.png`` and then closed.

    Args:
        y_t: 1-D array-like of reference values.
        y_p: 1-D array-like of predicted values, same length as ``y_t``.
        title: Text appended to the "Experimental"/"Predicted" axis labels.
        filename_suffix: Suffix used in the output file name.

    Returns:
        Tuple ``(r2, mae, mse, rmse)``.
    """
    # Plain 1-D arrays guarantee positional indexing below, whatever array-like
    # (list, Series, ndarray) the caller passed.
    y_t = np.asarray(y_t).ravel()
    y_p = np.asarray(y_p).ravel()

    r2 = r2_score(y_t, y_p)
    mae = mean_absolute_error(y_t, y_p)
    mse = mean_squared_error(y_t, y_p)
    rmse = np.sqrt(mse)

    data = pd.DataFrame({'Experimental': y_t, 'Predicted': y_p})

    xy = np.vstack([y_t, y_p])
    try:
        z = gaussian_kde(xy)(xy)
    except (np.linalg.LinAlgError, ValueError):
        # KDE fails on degenerate data (too few points, or a singular
        # covariance); fall back to uniform colouring in the original order.
        x_sorted, y_sorted, z_sorted = y_t, y_p, np.ones_like(y_t)
    else:
        # Draw the densest points last so they end up on top.
        idx = z.argsort()
        x_sorted, y_sorted, z_sorted = y_t[idx], y_p[idx], z[idx]

    g = sns.JointGrid(x='Experimental', y='Predicted', data=data, height=8, ratio=5)
    fig = g.ax_joint.figure

    # Central panel: density-coloured scatter.
    g.ax_joint.scatter(x_sorted, y_sorted, c=z_sorted, cmap=custom_magma, s=40, edgecolor='none', alpha=0.9)

    # Identity line and square limits padded by 5% of the combined data range.
    min_val = min(y_t.min(), y_p.min())
    max_val = max(y_t.max(), y_p.max())
    margin = (max_val - min_val) * 0.05
    lims = [min_val - margin, max_val + margin]

    g.ax_joint.plot(lims, lims, '--', color='gray', alpha=0.7, linewidth=2, label='Ideal')
    g.ax_joint.set_xlim(lims)
    g.ax_joint.set_ylim(lims)

    # Marginal panels share one colour taken from the magma colormap.
    sns.histplot(data=data, x='Experimental', ax=g.ax_marg_x, fill=True,
                 color=marginal_color, alpha=0.7, kde=True, line_kws={'linewidth': 2, 'color': 'k'})
    sns.histplot(data=data, y='Predicted', ax=g.ax_marg_y, fill=True,
                 color=marginal_color, alpha=0.7, kde=True, line_kws={'linewidth': 2, 'color': 'k'})

    # Annotation box (MSE and sample count are intentionally left out).
    stats_text = (f'$R^2 = {r2:.4f}$\n'
                  f'$MAE = {mae:.4f}$\n'
                  f'$RMSE = {rmse:.4f}$')

    g.ax_joint.text(0.05, 0.95, stats_text, transform=g.ax_joint.transAxes,
                    verticalalignment='top', fontsize=20,
                    bbox=dict(boxstyle='round,pad=0.5', facecolor='white', alpha=0.8, edgecolor='gray'))

    g.ax_joint.set_xlabel(f'Experimental {title}', fontweight='bold')
    g.ax_joint.set_ylabel(f'Predicted {title}', fontweight='bold')

    # Save and close this grid's own figure rather than pyplot's implicit
    # "current" figure, so the right figure is written even if others are open.
    save_path = os.path.join(Config.SAVE_DIR, f"Analysis_{filename_suffix}.png")
    fig.savefig(save_path, bbox_inches='tight', dpi=300)
    plt.close(fig)
    print(f"    图表已保存: {save_path}")

    return r2, mae, mse, rmse

# Compute every target first: plot_marginal_density prints a "saved" line per
# figure, and doing this before the table keeps those lines from being
# interleaved with the table rows.
metrics_summary = []
for i, col_name in enumerate(Config.TARGET_COLS):
    # File-name-safe tag: drop the unit in parentheses, replace spaces/slashes.
    clean_name = col_name.split(' (')[0].replace(' ', '_').replace('/', '_')
    r2, mae, mse, rmse = plot_marginal_density(
        y_true[:, i], y_pred[:, i], col_name, f"Target_{i}_{clean_name}"
    )
    metrics_summary.append([col_name, r2, mae, mse, rmse])

# Summary table; rules are as wide as the header and values are padded to the
# header's column widths so columns stay aligned.
header = f"{'Target Variable':<35} | {'R²':<8} | {'MAE':<8} | {'MSE':<8}"
print("\n" + "=" * len(header))
print(header)
print("-" * len(header))
for col_name, r2, mae, mse, _rmse in metrics_summary:
    print(f"{col_name[:35]:<35} | {r2:<8.4f} | {mae:<8.4f} | {mse:<8.4f}")
print("=" * len(header))

csv_path = os.path.join(Config.SAVE_DIR, "metrics_summary.csv")
df_metrics = pd.DataFrame(metrics_summary, columns=['Target', 'R2', 'MAE', 'MSE', 'RMSE'])
df_metrics.to_csv(csv_path, index=False)
print(f"\n详细指标数据已保存至: {csv_path}")
print("分析完成！")
plt.show()

>>> 1. 正在加载数据集与缓存...
    缓存加载成功。包含分子: 496, 沸石: 184
>>> 2. 正在重建数据划分 (Random State = 42)...
>>> 3. 拟合 Scaler 并准备测试集...
>>> 4. 加载模型权重: zeolite_3d_gnn_enriched_cleaned.pth
    模型权重加载成功！
>>> 5. 开始测试集推理...
>>> 6. 生成高标准分析图表...

Target Variable                     | R²       | MAE      | MSE     
-----------------------------------------------------------------
    图表已保存: C:\Users\admin\Energymodel\2-9\Analysis_Target_0_Binding_Energy.png
Binding Energy (kJ/mol Si)          | 0.8361   | 1.0663   | 2.2569
    图表已保存: C:\Users\admin\Energymodel\2-9\Analysis_Target_1_Directivity_Energy.png
Directivity Energy (kJ/mol Si)      | 0.8527   | 1.0781   | 2.3182
    图表已保存: C:\Users\admin\Energymodel\2-9\Analysis_Target_2_Competition_Energy.png
Competition Energy (kJ/mol Si)      | 0.8459   | 1.1122   | 2.3770
    图表已保存: C:\Users\admin\Energymodel\2-9\Analysis_Target_3_Binding_Energy.png
Binding Energy (kJ/mol OSDA)        | 0.8632   | 13.0874   | 396.3253
    图表已保存: C:\Users\admin\Energymodel\2-9\Analysi