In [None]:
import os
import math
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import matplotlib.patheffects as pe  
from matplotlib import cm
from tqdm import tqdm
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from torch_geometric.data import Batch
from torch_geometric.nn import GCNConv, global_mean_pool

class Config:
    """Static configuration: file locations, target columns and model sizes."""

    # --- File locations ---
    BASE_PATH = r"./data"
    # NOTE(review): the file name carries a doubled ".xlsx" suffix; it is kept
    # as-is because the on-disk name cannot be verified from here.
    DATASET_PATH = os.path.join(BASE_PATH, "Energy_data.xlsx.xlsx")
    PROCESSED_CACHE_PATH = r"./models/cached_graphs_box64_cleaned.pt"
    MODEL_PATH = r"./models/zeolite_3d_gnn_enriched_cleaned.pth"
    # Alias so that code referring to MODEL_WEIGHT_PATH resolves to the same file.
    MODEL_WEIGHT_PATH = MODEL_PATH
    SAVE_DIR = r"./"

    # Spreadsheet columns the model predicts; the order fixes the output order
    # of the network head and of the target scaler.
    TARGET_COLS = [
        'Binding Energy (kJ/mol Si)',
        'Directivity Energy (kJ/mol Si)',
        'Competition Energy (kJ/mol Si)',
        'Binding Energy (kJ/mol OSDA)',
        'Competition Energy (kJ/mol OSDA)'
    ]

    # --- Network widths ---
    ATOM_EMBEDDING_DIM = 64   # width of the atom embedding
    HIDDEN_DIM = 128          # width of the graph / global feature branches
    EMB_DIM_DEGREE = 8
    EMB_DIM_CHARGE = 8
    EMB_DIM_HYB = 8
    EMB_DIM_AROMATIC = 4
    EMB_DIM_CHIRAL = 4

    # --- Voxelisation (values forwarded to coords_to_voxel) ---
    VOXEL_SIZE = 64           # voxels per axis
    VOXEL_RES = 0.5           # voxel edge length
    SIGMA = 0.5

    # Topologies with fewer dataset rows than this are filtered out.
    MIN_SAMPLES_PER_TOPO = 0

# Create the output directory on first use.
if not os.path.exists(Config.SAVE_DIR):
    os.makedirs(Config.SAVE_DIR)

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"运行设备: {device}")

# =========================================
# 1. 基础工具函数
# =========================================
def coords_to_voxel(coords, grid_size=32, res=0.5, sigma=0.5):
    """Rasterise points into a cubic occupancy grid centred on the origin.

    Every point lying strictly inside the cube of half-width
    ``grid_size * res / 2`` marks its own voxel plus the directly adjacent
    ones (a 3x3x3 box, cut off at the grid border). Overlapping marks are
    saturated, so the returned grid only holds values in ``[0, 1]``.

    Args:
        coords: array whose columns 0, 1, 2 are the x, y, z coordinates.
        grid_size: number of voxels along each axis.
        res: voxel edge length, in the same unit as ``coords``.
        sigma: accepted for interface compatibility; not used here.

    Returns:
        ``float32`` array of shape ``(grid_size, grid_size, grid_size)``.
    """
    volume = np.zeros((grid_size,) * 3, dtype=np.float32)
    half_extent = (grid_size * res) / 2.0

    # Keep only the points that fall strictly inside the box on all three axes.
    inside = np.ones(len(coords), dtype=bool)
    for axis in range(3):
        column = coords[:, axis]
        inside &= (column > -half_extent) & (column < half_extent)
    kept = coords[inside]
    if len(kept) == 0:
        return volume

    # Shift into [0, 2 * half_extent) and convert to integer voxel indices.
    cells = ((kept + half_extent) / res).astype(int)
    cells = np.clip(cells, 0, grid_size - 1)

    for cx, cy, cz in cells:
        box = tuple(
            slice(max(0, c - 1), min(grid_size, c + 2)) for c in (cx, cy, cz)
        )
        volume[box] += 1.0

    return np.clip(volume, 0, 1.0)

def get_rotation_matrix_z(angle_deg):
    """Return the 3x3 matrix of a rotation by ``angle_deg`` degrees about z."""
    theta = np.radians(angle_deg)
    c = np.cos(theta)
    s = np.sin(theta)
    rotation = np.array([
        [c, -s, 0],
        [s, c, 0],
        [0, 0, 1],
    ])
    return rotation

# =========================================
# 2. 模型定义
# =========================================
class Voxel3DCNN(nn.Module):
    """3D CNN that encodes a two-channel voxel grid into a 128-value vector.

    Three conv -> batch-norm -> ReLU -> 2x max-pool stages (16, 32 and 64
    channels) are followed by a single linear layer. That layer expects
    ``64 * 8 * 8 * 8`` inputs, i.e. a spatial size of 8 per axis after the
    three halvings (64 voxels per axis at the input).
    """

    def __init__(self):
        super().__init__()
        # Stage 1: 2 -> 16 channels.
        self.conv1 = nn.Conv3d(in_channels=2, out_channels=16, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm3d(num_features=16)
        self.pool1 = nn.MaxPool3d(kernel_size=2)
        # Stage 2: 16 -> 32 channels.
        self.conv2 = nn.Conv3d(in_channels=16, out_channels=32, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm3d(num_features=32)
        self.pool2 = nn.MaxPool3d(kernel_size=2)
        # Stage 3: 32 -> 64 channels.
        self.conv3 = nn.Conv3d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm3d(num_features=64)
        self.pool3 = nn.MaxPool3d(kernel_size=2)
        # Projection of the flattened feature volume.
        self.fc = nn.Linear(in_features=64 * 8 * 8 * 8, out_features=128)

    def forward(self, x):
        """Run the three conv stages and return the ReLU-activated projection."""
        stages = (
            (self.conv1, self.bn1, self.pool1),
            (self.conv2, self.bn2, self.pool2),
            (self.conv3, self.bn3, self.pool3),
        )
        for conv, norm, pool in stages:
            x = pool(F.relu(norm(conv(x))))
        flat = x.view(x.size(0), -1)
        return F.relu(self.fc(flat))

class DualBranchGNN(nn.Module):
    """Predict the target energies from four fused feature branches.

    Branches: a two-layer GCN over the molecule graph, a two-layer GCN over
    the zeolite graph, an MLP over 17 global molecule descriptors, and a
    ``Voxel3DCNN`` over the voxel grid. The four vectors are concatenated and
    passed to an MLP head with ``len(Config.TARGET_COLS)`` outputs.
    """

    def __init__(self):
        super().__init__()
        # One embedding table per categorical node-feature column.
        self.emb_atom = nn.Embedding(120, Config.ATOM_EMBEDDING_DIM)
        self.emb_degree = nn.Embedding(12, Config.EMB_DIM_DEGREE)
        self.emb_charge = nn.Embedding(15, Config.EMB_DIM_CHARGE)
        self.emb_hyb = nn.Embedding(8, Config.EMB_DIM_HYB)
        self.emb_aromatic = nn.Embedding(2, Config.EMB_DIM_AROMATIC)
        self.emb_chiral = nn.Embedding(4, Config.EMB_DIM_CHIRAL)

        embedding_widths = (
            Config.ATOM_EMBEDDING_DIM,
            Config.EMB_DIM_DEGREE,
            Config.EMB_DIM_CHARGE,
            Config.EMB_DIM_HYB,
            Config.EMB_DIM_AROMATIC,
            Config.EMB_DIM_CHIRAL,
        )
        node_dim = sum(embedding_widths)

        # The molecule branch receives one extra column (x_charge) per node.
        self.mol_conv1 = GCNConv(node_dim + 1, Config.HIDDEN_DIM)
        self.mol_conv2 = GCNConv(Config.HIDDEN_DIM, Config.HIDDEN_DIM)
        self.zeo_conv1 = GCNConv(node_dim, Config.HIDDEN_DIM)
        self.zeo_conv2 = GCNConv(Config.HIDDEN_DIM, Config.HIDDEN_DIM)
        self.voxel_cnn = Voxel3DCNN()
        self.global_encoder = nn.Sequential(
            nn.Linear(17, 64),
            nn.ReLU(),
            nn.Linear(64, Config.HIDDEN_DIM),
            nn.BatchNorm1d(Config.HIDDEN_DIM),
            nn.ReLU(),
        )
        self.head = nn.Sequential(
            nn.Linear(Config.HIDDEN_DIM * 4, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, len(Config.TARGET_COLS)),
        )

    def _embed_features(self, x_idx):
        """Embed columns 0..5 of ``x_idx`` and concatenate them per node."""
        tables = (
            self.emb_atom,
            self.emb_degree,
            self.emb_charge,
            self.emb_hyb,
            self.emb_aromatic,
            self.emb_chiral,
        )
        embedded = [table(x_idx[:, col]) for col, table in enumerate(tables)]
        return torch.cat(embedded, dim=1)

    def forward(self, mol_batch, zeo_batch, voxel_batch):
        """Return the head output for a batch of (molecule, zeolite, voxel) inputs."""
        # Molecule graph branch (edge-weighted).
        mol_nodes = torch.cat(
            [self._embed_features(mol_batch.x), mol_batch.x_charge], dim=1
        )
        mol_edges = mol_batch.edge_index
        mol_weights = mol_batch.edge_weight
        h_mol = F.relu(self.mol_conv1(mol_nodes, mol_edges, edge_weight=mol_weights))
        h_mol = F.relu(self.mol_conv2(h_mol, mol_edges, edge_weight=mol_weights))
        feat_m = global_mean_pool(h_mol, mol_batch.batch)

        # Zeolite graph branch (unweighted).
        zeo_nodes = self._embed_features(zeo_batch.x)
        zeo_edges = zeo_batch.edge_index
        h_zeo = F.relu(self.zeo_conv1(zeo_nodes, zeo_edges))
        h_zeo = F.relu(self.zeo_conv2(h_zeo, zeo_edges))
        feat_z = global_mean_pool(h_zeo, zeo_batch.batch)

        # Voxel branch.
        feat_v = self.voxel_cnn(voxel_batch)

        # Global descriptors: drop the middle axis when a 3-D tensor arrives.
        global_attr = mol_batch.global_attr
        if global_attr.dim() == 3:
            global_attr = global_attr.squeeze(1)
        feat_global = self.global_encoder(global_attr)

        fused = torch.cat([feat_m, feat_z, feat_global, feat_v], dim=1)
        return self.head(fused)

# =========================================
# 3. 资源加载
# =========================================
def load_resources():
    """Load the trained model, the cached graphs, the dataset and both scalers.

    The target scaler is fitted on the 80 % split produced by
    ``train_test_split(..., random_state=42)`` over the filtered dataset,
    restricted to rows whose molecule and topology are both present in the
    cache and whose targets contain no NaN. The property scaler is fitted on
    the flattened ``global_attr`` of every cached molecule.

    Returns:
        tuple: ``(model, cache_data, df_filtered, target_scaler, props_scaler)``.

    Raises:
        ValueError: if no usable training row is found for the target scaler.
    """
    print(">>> 正在加载资源...")
    model = DualBranchGNN().to(device)
    # NOTE(security): weights_only=False unpickles arbitrary objects; only
    # load checkpoint and cache files from a trusted source.
    state_dict = torch.load(Config.MODEL_PATH, map_location=device, weights_only=False)
    model.load_state_dict(state_dict)
    model.eval()

    # The cached tensors are converted with .numpy() further down, so they
    # must live on the CPU regardless of where they were saved from.
    cache_data = torch.load(
        Config.PROCESSED_CACHE_PATH, map_location='cpu', weights_only=False
    )
    df = pd.read_excel(Config.DATASET_PATH, engine='openpyxl')

    # Keep only topologies with enough samples.
    topo_counts = df['Topology Code'].value_counts()
    valid_topos = topo_counts[topo_counts >= Config.MIN_SAMPLES_PER_TOPO].index
    df_filtered = df[df['Topology Code'].isin(valid_topos)].reset_index(drop=True)

    train_idx, _ = train_test_split(
        list(range(len(df_filtered))), train_size=0.8, random_state=42
    )

    mol_cache, zeo_cache = cache_data['mol_cache'], cache_data['zeo_cache']
    train_targets = []
    for idx in train_idx:
        row = df_filtered.iloc[idx]
        if row['CID'] not in mol_cache or row['Topology Code'] not in zeo_cache:
            continue
        targets = row[Config.TARGET_COLS].values.astype(float)
        if not np.isnan(targets).any():
            train_targets.append(targets)

    if not train_targets:
        raise ValueError(
            "No training rows with cached graphs and complete targets were "
            "found; cannot fit the target scaler."
        )

    target_scaler = StandardScaler()
    target_scaler.fit(np.array(train_targets))

    props_scaler = StandardScaler()
    all_props = [mol.global_attr.numpy().flatten() for mol in mol_cache.values()]
    props_scaler.fit(np.array(all_props))

    return model, cache_data, df_filtered, target_scaler, props_scaler

def get_most_unstable_molecule(df, topo_code, target_col=None):
    """Return the ``(CID, energy)`` pair with the highest target value for a topology.

    Args:
        df: DataFrame with 'Topology Code', 'CID' and the target column.
        topo_code: topology code to look up.
        target_col: column to maximise; defaults to ``Config.TARGET_COLS[0]``.

    Returns:
        ``(cid, energy)`` of the row with the largest non-missing target
        value, or ``(None, None)`` when the topology has no such row. A pair
        is always returned so callers can unpack the result unconditionally.
    """
    if target_col is None:
        target_col = Config.TARGET_COLS[0]

    subset = df[df['Topology Code'] == topo_code]
    # Missing energies cannot be ranked; ignore them up front.
    energies = subset[target_col].dropna()
    if energies.empty:
        return None, None

    most_unstable_row = subset.loc[energies.idxmax()]
    return most_unstable_row['CID'], most_unstable_row[target_col]

# =========================================
# 4. 旋转计算
# =========================================
def compute_rotation_profile(model, cid, topo, cache_data, t_scaler, p_scaler,
                             angle_step=10, batch_size_rot=12):
    """Predict the first target for a molecule rotated about z inside a zeolite.

    The molecule coordinates are centred on their mean, rotated in steps of
    ``angle_step`` degrees over a full turn and voxelised together with the
    (fixed) zeolite grid. Only the coordinates and the voxel grid change
    between angles; the graph features are the cached ones.

    Args:
        model: network called as ``model(mol_batch, zeo_batch, voxel_batch)``.
        cid: key into ``cache_data['mol_cache']``.
        topo: key into ``cache_data['zeo_cache']``.
        cache_data: dict holding the 'mol_cache' and 'zeo_cache' mappings.
        t_scaler: scaler used to map predictions back to real units.
        p_scaler: scaler applied to the molecule's ``global_attr``.
        angle_step: angular increment in degrees.
        batch_size_rot: number of rotations evaluated per forward pass.

    Returns:
        ``(angles, energies)`` where ``energies`` is column 0 of the
        inverse-transformed predictions, or ``(None, None)`` when either the
        molecule or the topology is missing from the cache.
    """
    mol_raw = cache_data['mol_cache'].get(cid)
    zeo_raw = cache_data['zeo_cache'].get(topo)
    # Test for absence explicitly rather than relying on object truthiness.
    if mol_raw is None or zeo_raw is None:
        return None, None

    angles = np.arange(0, 360, angle_step)
    preds_list = []

    if hasattr(mol_raw, 'pos_variants'):
        mol_base = mol_raw.pos_variants[0].numpy()
    else:
        mol_base = mol_raw.pos.numpy()
    mol_base = mol_base - np.mean(mol_base, axis=0)

    if hasattr(zeo_raw, 'pos_super'):
        zeo_coords = zeo_raw.pos_super.numpy()
    else:
        zeo_coords = zeo_raw.pos.numpy()
    grid_zeo = coords_to_voxel(zeo_coords, Config.VOXEL_SIZE, Config.VOXEL_RES, Config.SIGMA)

    # The scaler was fitted on flattened descriptor vectors, so present the
    # descriptors as a single 2-D sample whatever their stored shape is.
    raw_props = mol_raw.global_attr.numpy().reshape(1, -1)
    props_norm = p_scaler.transform(raw_props)
    global_attr = torch.tensor(props_norm, dtype=torch.float).to(device)

    with torch.no_grad():
        for start in range(0, len(angles), batch_size_rot):
            ang_batch = angles[start:start + batch_size_rot]
            mol_list, voxel_list = [], []
            for ang in ang_batch:
                rot_mat = get_rotation_matrix_z(ang)
                # Row vectors times the matrix, i.e. the points turn by the
                # transposed rotation; the sweep covers the full circle anyway.
                mol_rot = np.dot(mol_base, rot_mat)
                grid_mol = coords_to_voxel(
                    mol_rot, Config.VOXEL_SIZE, Config.VOXEL_RES, Config.SIGMA
                )
                # Channel 0: molecule occupancy, channel 1: zeolite occupancy.
                voxel_list.append(
                    torch.tensor(np.stack([grid_mol, grid_zeo], axis=0), dtype=torch.float)
                )
                m = mol_raw.clone()
                m.pos = torch.tensor(mol_rot, dtype=torch.float)
                m.global_attr = global_attr
                if hasattr(m, 'pos_variants'):
                    del m.pos_variants
                mol_list.append(m)

            mol_batch_gpu = Batch.from_data_list(mol_list).to(device)
            voxel_gpu = torch.stack(voxel_list).to(device)
            # One zeolite copy per rotated molecule so the batches line up.
            zeo_list_expanded = [zeo_raw.clone() for _ in range(len(mol_list))]
            zeo_batch_gpu = Batch.from_data_list(zeo_list_expanded).to(device)

            pred = model(mol_batch_gpu, zeo_batch_gpu, voxel_gpu)
            preds_list.append(pred.cpu().numpy())

    preds_all = np.vstack(preds_list)
    preds_real = t_scaler.inverse_transform(preds_all)
    return angles, preds_real[:, 0]

# =========================================
# 5. Core plotting (all topologies in one grid of polar subplots)
# =========================================
def _collect_rotation_stats(df, model, cache_data, t_scaler, p_scaler, topologies):
    """Compute the rotation profile and its CV for each topology's least stable guest."""
    records = []
    for topo in tqdm(topologies, desc="预计算"):
        lookup = get_most_unstable_molecule(df, topo)
        # The lookup may yield None or a pair whose CID is None when the
        # topology has no usable row; skip both instead of failing to unpack.
        if lookup is None:
            continue
        cid = lookup[0]
        if cid is None:
            continue

        angles, energies = compute_rotation_profile(
            model, cid, topo, cache_data, t_scaler, p_scaler
        )
        if energies is None:
            continue

        mean_e = np.mean(energies)
        std_e = np.std(energies)
        # Coefficient of variation in percent; defined as 0 for a zero mean.
        cv = (std_e / abs(mean_e)) * 100 if mean_e != 0 else 0

        records.append({
            'topo': topo,
            'mean': mean_e,
            'cv': cv,
            'angles': angles,
            'energies': np.abs(energies)
        })
    return records


def plot_all_topologies_in_one(df, model, cache_data, t_scaler, p_scaler, cols=27):
    """Draw one polar rotation profile per topology in a single figure.

    For every topology in ``cache_data['zeo_cache']`` the least stable
    molecule is rotated about z, the absolute predicted energy is drawn as a
    closed polar curve, and the curve is coloured by the coefficient of
    variation (CV) of the profile. The figure is written to
    ``Config.SAVE_DIR`` as "All_Topologies_Row30.png".

    Args:
        df: dataset passed on to ``get_most_unstable_molecule``.
        model, cache_data, t_scaler, p_scaler: passed on to
            ``compute_rotation_profile``.
        cols: number of subplot columns in the grid.

    Returns:
        None. Prints a message and returns early when no profile could be
        computed.
    """
    plt.rcParams['font.family'] = 'sans-serif'
    plt.rcParams['font.sans-serif'] = ['Arial', 'Helvetica', 'DejaVu Sans']

    # 1. All topologies available in the cache, in sorted order.
    available_topos = sorted(cache_data['zeo_cache'].keys())

    print(f"检测到共 {len(available_topos)} 个分子筛骨架，开始计算所有 CV 值以生成色谱...")

    # 2. Pre-compute every profile so the colour scale spans all CVs.
    plot_data = _collect_rotation_stats(
        df, model, cache_data, t_scaler, p_scaler, available_topos
    )
    if not plot_data:
        print("未找到有效数据。")
        return

    # 3. Colour mapping: the 30 %-100 % portion of magma, normalised to the CV range.
    all_cvs = [record['cv'] for record in plot_data]
    min_cv, max_cv = min(all_cvs), max(all_cvs)
    print(f"CV Range: {min_cv:.4f}% - {max_cv:.4f}%")

    try:
        magma = matplotlib.colormaps['magma']
    except AttributeError:
        # Older matplotlib releases have no `matplotlib.colormaps` registry.
        magma = plt.get_cmap('magma')

    new_colors = magma(np.linspace(0.3, 1.0, 256))
    new_cmap = mcolors.LinearSegmentedColormap.from_list("trunc_magma", new_colors)
    norm = mcolors.Normalize(vmin=min_cv, vmax=max_cv)

    # 4. Grid layout: `cols` panels per row, 3 x 3 inches per panel.
    N = len(plot_data)
    rows = int(math.ceil(N / cols))
    fig_w, fig_h = cols * 3, rows * 3

    # squeeze=False always yields a 2-D array, whatever rows/cols are.
    fig, axes = plt.subplots(
        rows, cols, figsize=(fig_w, fig_h),
        subplot_kw=dict(polar=True), squeeze=False
    )
    axes = axes.flatten()

    print("开始绘图...")
    for i, ax in enumerate(axes):
        # Hide the unused panels at the end of the last row.
        if i >= N:
            ax.set_visible(False)
            continue

        data = plot_data[i]
        color = new_cmap(norm(data['cv']))

        # Close the curve by repeating the first sample at the end.
        theta = np.deg2rad(data['angles'])
        values = data['energies']
        theta = np.concatenate((theta, [theta[0]]))
        values = np.concatenate((values, [values[0]]))

        ax.plot(theta, values, color=color, linewidth=3)
        ax.fill(theta, values, color=color, alpha=0.8)

        # Leave 25 % head-room; fall back to 1.0 so an all-zero profile
        # does not produce a zero-height radial axis.
        max_val = np.max(values)
        ax.set_ylim(0, max_val * 1.25 if max_val > 0 else 1.0)

        ax.set_yticklabels([])
        ax.set_xticklabels([])
        ax.grid(True, linestyle='--', alpha=0.3)
        ax.spines['polar'].set_visible(False)

        # Centre label: topology code and (signed) mean energy, drawn in
        # bold white with a black outline so it stays readable on the fill.
        center_text = f"{data['topo']}\n{data['mean']:.1f}"
        txt = ax.text(0, 0, center_text,
                      ha='center', va='center',
                      fontsize=51, weight='bold', color='white', zorder=10)
        txt.set_path_effects([pe.withStroke(linewidth=3.5, foreground='black')])

    # Shared horizontal colour bar in the space reserved at the bottom.
    plt.tight_layout()
    plt.subplots_adjust(bottom=0.1)

    cbar_ax = fig.add_axes([0.15, 0.05, 0.7, 0.02])
    sm = plt.cm.ScalarMappable(cmap=new_cmap, norm=norm)
    sm.set_array([])
    cbar = fig.colorbar(sm, cax=cbar_ax, orientation='horizontal')
    cbar.set_label('Coefficient of Variation (CV %)', fontsize=120, weight='bold')
    cbar.ax.tick_params(labelsize=120)

    save_path = os.path.join(Config.SAVE_DIR, "All_Topologies_Row30.png")
    plt.savefig(save_path, dpi=600, bbox_inches='tight')
    plt.close(fig)
    print(f"  --> 保存成功: {save_path}")

# =========================================
# 6. 主程序
# =========================================
if __name__ == "__main__":
    # Load the model, cache, dataset and scalers once, then render the figure.
    model, cache_data, df, t_scaler, p_scaler = load_resources()
    plot_all_topologies_in_one(
        df=df,
        model=model,
        cache_data=cache_data,
        t_scaler=t_scaler,
        p_scaler=p_scaler,
    )

运行设备: cuda
>>> 正在加载资源...
检测到共 184 个分子筛骨架，开始计算所有 CV 值以生成色谱...


预计算: 100%|████████████████████████████████████████████████████████████████████████| 184/184 [00:37<00:00,  4.87it/s]


CV Range: 0.0000% - 9.2970%
开始绘图...
  --> 保存成功: C:\Users\admin\Energymodel\2-9\All_Topologies_Row30.png
