In [None]:
import torch
import torch.nn as nn
import shap
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from torch.utils.data import DataLoader
from cgcnn.data import CIFData, collate_pool


##############################################################################
# A. 自动解析模型维度的函数
##############################################################################
def infer_cgcnn_dims_from_sd(state_dict):
    """
    Infer the CGCNN layer sizes from the parameter shapes in a checkpoint.

    Args:
        state_dict: Mapping from parameter name to an object exposing
            ``.shape``. It must contain ``embedding.weight``, at least one
            ``convs.<i>.fc_full.weight`` and ``fc_out.weight``.

    Returns:
        Tuple ``(orig_atom_fea_len, atom_fea_len, nbr_fea_len, n_conv,
        h_fea_len, out_dim)``.

    Raises:
        ValueError: If a required key is missing, a weight is not 2-D, or the
            conv weight shape is inconsistent with the embedding width.
    """

    def shape_2d(key):
        # Linear weights are read as (out_features, in_features).
        shape = tuple(state_dict[key].shape)
        if len(shape) != 2:
            raise ValueError(f"{key} must be 2-D, got shape {shape}.")
        return shape

    # 1) embedding.weight => (atom_fea_len, orig_atom_fea_len)
    emb_key = "embedding.weight"
    if emb_key not in state_dict:
        raise ValueError(f"Cannot find {emb_key} in checkpoint to infer dims.")
    atom_fea_len, orig_atom_fea_len = shape_2d(emb_key)

    # 2) convs.<i>.fc_full.weight => (2*atom_fea_len, 2*atom_fea_len + nbr_fea_len)
    #    One such key exists per conv layer, so counting them yields n_conv.
    conv_keys = [
        key for key in state_dict
        if key.startswith("convs.") and key.endswith(".fc_full.weight")
    ]
    if not conv_keys:
        raise ValueError("Cannot find any convs.X.fc_full.weight in checkpoint.")
    n_conv = len(conv_keys)

    # Prefer layer 0 explicitly instead of relying on dict iteration order;
    # fall back to the first key found if layer 0 is absent.
    first_layer_key = "convs.0.fc_full.weight"
    conv0_key = first_layer_key if first_layer_key in state_dict else conv_keys[0]
    out_features, in_features = shape_2d(conv0_key)
    if out_features != 2 * atom_fea_len:
        raise ValueError(f"fc_full out_features({out_features}) != 2*atom_fea_len({2*atom_fea_len}).")
    nbr_fea_len = in_features - out_features
    if nbr_fea_len < 0:
        # A negative width can never match real neighbour features; fail here
        # rather than with an opaque shape error during the forward pass.
        raise ValueError(
            f"{conv0_key} has in_features({in_features}) < out_features({out_features}); "
            "cannot infer nbr_fea_len."
        )

    # 3) fc_out.weight => (out_dim, h_fea_len)
    fc_out_key = "fc_out.weight"
    if fc_out_key not in state_dict:
        raise ValueError("Cannot find fc_out.weight in checkpoint.")
    out_dim, h_fea_len = shape_2d(fc_out_key)

    return (
        orig_atom_fea_len,
        atom_fea_len,
        nbr_fea_len,
        n_conv,
        h_fea_len,
        out_dim,
    )


##############################################################################
# B. ModifiedConvLayer
##############################################################################
class ModifiedConvLayer(nn.Module):
    """
    Gated graph-convolution layer.

    For every (atom, neighbour) pair the centre-atom features, neighbour-atom
    features and bond features are concatenated, passed through one linear
    layer and batch norm, then split into a sigmoid gate and a softplus core.
    The gated messages are summed over neighbours, batch-normalised and added
    back to the input features before a final softplus.
    """

    def __init__(self, atom_fea_len, nbr_fea_len):
        super().__init__()
        self.atom_fea_len = atom_fea_len
        self.nbr_fea_len = nbr_fea_len

        # The linear layer emits gate and core halves side by side.
        gate_width = 2 * atom_fea_len
        self.fc_full = nn.Linear(gate_width + nbr_fea_len, gate_width)
        self.bn1 = nn.BatchNorm1d(gate_width)
        self.bn2 = nn.BatchNorm1d(atom_fea_len)

        self.sigmoid = nn.Sigmoid()
        self.softplus1 = nn.Softplus()
        self.softplus2 = nn.Softplus()

    def forward(self, atom_in_fea, nbr_fea, nbr_fea_idx):
        """
        Run one convolution step.

        Args:
            atom_in_fea: Atom features, indexed along dim 0 by ``nbr_fea_idx``.
            nbr_fea: Bond features, concatenated on dim 2 with the atom pairs.
            nbr_fea_idx: 2-D index tensor (atoms x neighbours).

        Returns:
            Updated atom features with the same width as ``atom_in_fea``.
        """
        n_atoms, n_nbrs = nbr_fea_idx.shape
        width = self.atom_fea_len

        centre = atom_in_fea.unsqueeze(1).expand(n_atoms, n_nbrs, width)
        neighbours = atom_in_fea[nbr_fea_idx, :]  # (n_atoms, n_nbrs, width)
        pair_input = torch.cat((centre, neighbours, nbr_fea), dim=2)

        # BatchNorm1d works on 2-D input, so flatten the pair axes first.
        flat_pairs = pair_input.view(-1, pair_input.shape[-1])
        gated = self.bn1(self.fc_full(flat_pairs)).view(n_atoms, n_nbrs, 2 * width)

        gate, core = gated.chunk(2, dim=2)
        gate = self.sigmoid(gate)
        core = self.softplus1(core)

        message = self.bn2(torch.sum(gate * core, dim=1))
        return self.softplus2(atom_in_fea + message)


##############################################################################
# C. CrystalGraphConvNetWithHooks
##############################################################################
class CrystalGraphConvNetWithHooks(nn.Module):
    """
    CGCNN variant that records intermediate layer outputs.

    Pipeline:
      - embedding: Linear(orig_atom_fea_len -> atom_fea_len)
      - ``n_conv`` ModifiedConvLayer blocks
      - per-crystal mean pooling over atoms
      - conv_to_fc + softplus -> fc_out

    The first call to ``forward`` installs forward hooks on the embedding,
    every conv layer and ``fc_out``. Each hook stores the detached output of
    its module in ``self.intermediate_outputs`` under the keys ``"embedding"``,
    ``"conv_<i>"`` and ``"fc_out"``. Hooks are installed only once, so repeated
    forward passes do not stack duplicate hooks.
    """

    def __init__(
        self,
        orig_atom_fea_len,
        atom_fea_len,
        nbr_fea_len,
        n_conv,
        h_fea_len,
        out_dim=1,
        classification=False
    ):
        """
        Args:
            orig_atom_fea_len: Width of the raw atom feature vectors.
            atom_fea_len: Width of the atom features after embedding.
            nbr_fea_len: Width of the bond feature vectors.
            n_conv: Number of convolution layers.
            h_fea_len: Width of the crystal-level hidden layer.
            out_dim: Width of the output layer.
            classification: Kept for interface compatibility; the output
                layer is the same Linear(h_fea_len, out_dim) either way.
        """
        super().__init__()
        self.intermediate_outputs = {}
        self.classification = classification
        # Handles of installed forward hooks; empty until the first forward().
        self._hook_handles = []

        # 1) embedding
        self.embedding = nn.Linear(orig_atom_fea_len, atom_fea_len)

        # 2) n_conv layers
        self.convs = nn.ModuleList(
            ModifiedConvLayer(atom_fea_len, nbr_fea_len) for _ in range(n_conv)
        )

        # 3) pooling
        self.pooling = self._pooling_mean

        # 4) conv_to_fc
        self.conv_to_fc = nn.Linear(atom_fea_len, h_fea_len)
        self.conv_to_fc_softplus = nn.Softplus()

        # 5) fc_out
        self.fc_out = nn.Linear(h_fea_len, out_dim)

    def _pooling_mean(self, atom_fea, crystal_atom_idx_list):
        """Average atom features per crystal; returns one row per index set."""
        return torch.cat(
            [atom_fea[idx].mean(dim=0, keepdim=True) for idx in crystal_atom_idx_list],
            dim=0,
        )

    def add_hook(self, name):
        """Return a forward hook that stores the detached output under ``name``."""
        def hook(module, inputs, output):
            self.intermediate_outputs[name] = output.detach()
        return hook

    def _register_hooks_once(self):
        """Install the recording hooks unless they are already installed."""
        if self._hook_handles:
            return
        targets = [("embedding", self.embedding)]
        targets.extend((f"conv_{i}", layer) for i, layer in enumerate(self.convs))
        targets.append(("fc_out", self.fc_out))
        self._hook_handles = [
            module.register_forward_hook(self.add_hook(name))
            for name, module in targets
        ]

    def _crystal_features(self, atom_fea, nbr_fea, nbr_fea_idx, crystal_atom_idx):
        """Run embedding, convolutions, pooling and conv_to_fc (everything before fc_out)."""
        atom_fea = self.embedding(atom_fea)
        for conv_layer in self.convs:
            atom_fea = conv_layer(atom_fea, nbr_fea, nbr_fea_idx)

        # A bare tensor is treated as the atom indices of a single crystal.
        if isinstance(crystal_atom_idx, torch.Tensor):
            crystal_atom_idx = [crystal_atom_idx]
        crys_fea = self.pooling(atom_fea, crystal_atom_idx)
        return self.conv_to_fc_softplus(self.conv_to_fc(crys_fea))

    def forward(self, atom_fea, nbr_fea, nbr_fea_idx, crystal_atom_idx):
        """
        Compute the model output and record intermediate activations.

        Args:
            atom_fea: Raw atom features fed to the embedding layer.
            nbr_fea: Bond features passed to every conv layer.
            nbr_fea_idx: Neighbour index tensor passed to every conv layer.
            crystal_atom_idx: List of atom-index tensors (one per crystal), or
                a single tensor for one crystal.

        Returns:
            Output of ``fc_out``, one row per crystal.
        """
        self._register_hooks_once()
        crys_fea = self._crystal_features(atom_fea, nbr_fea, nbr_fea_idx, crystal_atom_idx)
        return self.fc_out(crys_fea)

    def get_graph_embedding(self, atom_fea, nbr_fea, nbr_fea_idx, crystal_atom_idx):
        """
        Return the crystal-level features that feed ``fc_out`` (for SHAP or
        visualisation), computed without gradient tracking.

        If ``forward`` has already been called, the installed hooks also fire
        here and refresh the embedding/conv entries of
        ``intermediate_outputs``.
        """
        with torch.no_grad():
            return self._crystal_features(atom_fea, nbr_fea, nbr_fea_idx, crystal_atom_idx)


##############################################################################
# D. 自动解析维度 + 构造模型的 ModelWrapper
##############################################################################
class ModelWrapper(nn.Module):
    """
    SHAP-stage wrapper around CrystalGraphConvNetWithHooks.

    The constructor reads a checkpoint, infers the network sizes from its
    parameter shapes, rebuilds the network and loads the weights. Calling the
    wrapper returns the graph embedding (the features just before ``fc_out``)
    instead of the final prediction.
    """

    def __init__(self, checkpoint_path):
        super().__init__()

        # 1) Read the checkpoint; weights may sit under 'state_dict' or at top level.
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        weights = checkpoint
        if 'state_dict' in checkpoint:
            weights = checkpoint['state_dict']

        # 2) Recover the layer sizes from the weight shapes.
        size_names = (
            "orig_atom_fea_len",
            "atom_fea_len",
            "nbr_fea_len",
            "n_conv",
            "h_fea_len",
            "out_dim",
        )
        sizes = dict(zip(size_names, infer_cgcnn_dims_from_sd(weights)))

        print("[ModelWrapper] Inferred dims =>",
              ", ".join(f"{key}={value}" for key, value in sizes.items()))

        # 3) Rebuild the network; more than one output is treated as classification.
        self.model = CrystalGraphConvNetWithHooks(
            classification=(sizes["out_dim"] > 1),
            **sizes
        )

        # 4) Load the weights (every key must match) and switch to eval mode.
        self.model.load_state_dict(weights, strict=True)
        self.model.eval()

    def forward(self, graph_batch):
        """Return the graph embedding for a 4-tuple ``graph_batch``."""
        atoms, bonds, bond_idx, crystal_idx = graph_batch
        return self.model.get_graph_embedding(atoms, bonds, bond_idx, crystal_idx)


##############################################################################
# E. Layer 可视化工具
##############################################################################
def visualize_layer_output(layer_name, output):
    """
    Save a picture of one layer's output as ``<layer_name>_output.png``.

    After squeezing size-1 axes, a scalar is rendered as text, a 1-D array as
    a line plot and a 2-D array as a heat map. Arrays with more than two axes
    are reported on stdout and no file is written.
    """
    values = output.squeeze().cpu().numpy()
    plt.figure(figsize=(8, 4))
    rank = values.ndim

    if rank > 2:
        print(f"[visualize_layer_output] shape {values.shape} not easily visualizable.")
        plt.close()
        return

    if rank == 0:
        plt.text(0.5, 0.5, f"{layer_name} scalar: {values:.4f}",
                 ha='center', va='center', fontsize=14)
        plt.axis('off')
    elif rank == 1:
        plt.plot(values, marker='o')
        plt.title(f"{layer_name} (1D)")
        plt.xlabel('Index')
        plt.ylabel('Value')
    else:
        plt.imshow(values, aspect='auto', cmap='viridis')
        plt.colorbar()
        plt.title(f"{layer_name} (2D)")
        plt.xlabel('Feature dim')
        plt.ylabel('Sample index')

    plt.tight_layout()
    plt.savefig(f"{layer_name}_output.png", dpi=150)
    plt.close()

def analyze_feature_activation(layer_name, output):
    """
    Save a bar chart of the per-column mean activation of a 2-D layer output
    as ``<layer_name>_mean_activation.png``.

    Outputs that are not 2-D after squeezing size-1 axes are skipped with a
    message on stdout.
    """
    activations = output.squeeze().cpu().numpy()
    if activations.ndim != 2:
        print(f"[analyze_feature_activation] skip, arr.ndim={activations.ndim}")
        return

    # Average over rows so that each bar corresponds to one column.
    channel_means = np.mean(activations, axis=0)
    positions = range(len(channel_means))

    plt.figure(figsize=(8, 4))
    plt.bar(positions, channel_means)
    plt.title(f"{layer_name} Mean Activation")
    plt.xlabel('Channel index')
    plt.ylabel('Mean')
    plt.tight_layout()
    plt.savefig(f"{layer_name}_mean_activation.png", dpi=150)
    plt.close()


##############################################################################
# F. dissect_single_sample: 可视化单样本卷积层
##############################################################################
def dissect_single_sample(model_ckpt_path, data_path, radius=20.0):
    """
    Run one sample through the network and plot every recorded layer output.

    The network sizes are inferred from the checkpoint, the first sample of
    the dataset is pushed through ``CrystalGraphConvNetWithHooks`` and each
    entry of ``model.intermediate_outputs`` is passed to
    ``visualize_layer_output`` and ``analyze_feature_activation``.

    Args:
        model_ckpt_path: Path of the checkpoint file.
        data_path: Dataset directory handed to ``CIFData``.
        radius: Neighbour cutoff handed to ``CIFData``. It has to match the
            value used for training, otherwise the bond-feature width will
            not fit the checkpoint.

    Raises:
        ValueError: If the dataset yields no sample, or the bond-feature
            width of the sample differs from the one in the checkpoint.
    """
    print("\n=== [dissect_single_sample] ===")

    # 1) Infer the sizes from the checkpoint and rebuild the hooked network
    #    (the hooked class is used directly because the intermediate
    #    activations are needed).
    # NOTE: torch.load unpickles the file; only open trusted checkpoints.
    ckpt = torch.load(model_ckpt_path, map_location='cpu')
    state_dict = ckpt['state_dict'] if 'state_dict' in ckpt else ckpt

    (orig_atom_fea_len,
     atom_fea_len,
     nbr_fea_len,
     n_conv,
     h_fea_len,
     out_dim) = infer_cgcnn_dims_from_sd(state_dict)

    model = CrystalGraphConvNetWithHooks(
        orig_atom_fea_len=orig_atom_fea_len,
        atom_fea_len=atom_fea_len,
        nbr_fea_len=nbr_fea_len,
        n_conv=n_conv,
        h_fea_len=h_fea_len,
        out_dim=out_dim,
        classification=(out_dim > 1)
    )
    model.load_state_dict(state_dict, strict=True)
    model.eval()

    print("[dissect_single_sample] =>",
          f"orig_atom_fea_len={orig_atom_fea_len}, atom_fea_len={atom_fea_len},",
          f"nbr_fea_len={nbr_fea_len}, n_conv={n_conv}, h_fea_len={h_fea_len}, out_dim={out_dim}")

    # 2) Take the first sample of the dataset.
    dataset = CIFData(data_path, radius=radius)
    loader = DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=collate_pool)
    sample_input = next(iter(loader), None)
    if sample_input is None:
        raise ValueError(f"No samples found under {data_path!r}.")
    graph_inputs, target, cif_id = sample_input

    # 3) If the collate function delivered only three tensors, synthesise an
    #    index tensor so the model still receives four inputs.
    if len(graph_inputs) == 3:
        atom_fea, nbr_fea, crystal_atom_idx = graph_inputs
        nbr_fea_idx = torch.arange(atom_fea.shape[0]).unsqueeze(-1)
        graph_inputs = (atom_fea, nbr_fea, nbr_fea_idx, crystal_atom_idx)

    # Fail early with a readable message instead of a shape error deep inside
    # the first conv layer.
    sample_nbr_fea_len = graph_inputs[1].shape[-1]
    if sample_nbr_fea_len != nbr_fea_len:
        raise ValueError(
            f"Bond feature width of the data ({sample_nbr_fea_len}) does not match "
            f"the checkpoint ({nbr_fea_len}); check that radius={radius} matches training."
        )

    # 4) forward
    with torch.no_grad():
        output = model(*graph_inputs)
    print("[dissect_single_sample] forward done. Output shape:", output.shape)
    if output.numel() == 1:
        print("Model output value:", output.item())
    else:
        print("Model output (batch):", output)

    # 5) Plot every recorded intermediate output.
    print("\n[dissect_single_sample] Intermediate outputs:")
    for name, out_tensor in model.intermediate_outputs.items():
        print(f"  {name}: shape={tuple(out_tensor.shape)}")
        visualize_layer_output(name, out_tensor)
        analyze_feature_activation(name, out_tensor)


##############################################################################
# G. run_shap_analysis: 做图级嵌入 + SHAP
##############################################################################
def _collect_graph_embeddings(emb_model, loader):
    """Embed every batch of ``loader``; return (embeddings tensor, targets array)."""
    emb_chunks, target_chunks = [], []
    for graph_inputs, targets, _cif_ids in loader:
        # If the collate function delivered only three tensors, synthesise an
        # index tensor so the model still receives four inputs.
        if len(graph_inputs) == 3:
            atom_fea, nbr_fea, crystal_atom_idx = graph_inputs
            nbr_fea_idx = torch.arange(atom_fea.shape[0]).unsqueeze(-1)
            graph_inputs = (atom_fea, nbr_fea, nbr_fea_idx, crystal_atom_idx)
        emb_chunks.append(emb_model(graph_inputs).cpu())
        target_chunks.append(targets)

    if not emb_chunks:
        raise ValueError("Dataset produced no batches; nothing to explain.")
    return torch.cat(emb_chunks, dim=0), torch.cat(target_chunks).numpy()


def run_shap_analysis(model_ckpt_path, data_path, radius=20.0, max_samples=200, top_k=10):
    """
    Explain the final linear layer with SHAP on top of the graph embeddings.

    Every crystal in the dataset is embedded with ``ModelWrapper``. A
    ``shap.KernelExplainer`` is then fitted on the function
    ``embedding -> fc_out(embedding)``, using the first ``max_samples``
    embeddings both as background and as the rows to explain.

    Files written to the working directory:
      - shap_summary_swarm.png
      - shap_values_emb.csv
      - feature_importance_emb.csv
      - shap_input_emb.csv

    Args:
        model_ckpt_path: Path of the checkpoint file.
        data_path: Dataset directory handed to ``CIFData``.
        radius: Neighbour cutoff handed to ``CIFData``; must match training.
        max_samples: Upper bound on background rows and on explained rows.
        top_k: How many of the most important embedding dims to print.

    Raises:
        ValueError: If the dataset yields no batches.
    """
    print("\n=== [run_shap_analysis] ===")
    emb_model = ModelWrapper(model_ckpt_path)
    emb_model.eval()

    dataset = CIFData(data_path, radius=radius)
    loader = DataLoader(dataset, batch_size=16, shuffle=False, collate_fn=collate_pool)

    X_emb, y_test = _collect_graph_embeddings(emb_model, loader)
    print("Graph embedding shape:", X_emb.shape)
    print("y_test shape:", y_test.shape)

    # The wrapper already holds the checkpoint weights, so read fc_out from it
    # instead of unpickling the checkpoint a second time.
    fc_out = emb_model.model.fc_out
    fc_w = fc_out.weight.detach().cpu()
    fc_b = fc_out.bias.detach().cpu()

    def predict_from_emb(X_numpy):
        X_t = torch.tensor(X_numpy, dtype=torch.float32)
        out = X_t @ fc_w.T + fc_b
        return out.squeeze(-1).detach().numpy()

    # shap
    background_size = min(max_samples, X_emb.shape[0])
    background_data = X_emb[:background_size].numpy()

    explainer = shap.KernelExplainer(predict_from_emb, background_data)
    sample_size = min(max_samples, X_emb.shape[0])
    sample_data = X_emb[:sample_size].numpy()

    shap_values = explainer.shap_values(sample_data)
    if isinstance(shap_values, list):
        shap_values = shap_values[0]
    shap_values = np.asarray(shap_values)
    if shap_values.ndim == 3:
        # NOTE(review): for multi-output models some shap releases appear to
        # return (samples, features, outputs) instead of a list; keep the
        # first output, matching the list branch above. Confirm against the
        # installed shap version.
        shap_values = shap_values[:, :, 0]
    shap_values = shap_values.reshape(sample_size, -1)
    print("SHAP values shape:", shap_values.shape)

    emb_dim = shap_values.shape[1]
    feature_names = [f"GNN_Emb_{i}" for i in range(emb_dim)]

    # summary plot
    plt.figure(figsize=(10, 6))
    shap.summary_plot(shap_values, sample_data, feature_names=feature_names,
                      plot_type='dot', max_display=10, show=False)
    plt.savefig('shap_summary_swarm.png', dpi=300, bbox_inches='tight')
    plt.close()
    print("[SHAP] summary swarm plot saved -> shap_summary_swarm.png")

    # top K (clamped so that narrow embeddings do not index out of range)
    mean_importance = np.mean(np.abs(shap_values), axis=0)
    n_top = min(top_k, emb_dim)
    ranked = np.argsort(mean_importance)[::-1][:n_top]
    print(f"\nTop {n_top} important embedding dims:")
    for idx in ranked:
        print(f"  {feature_names[idx]} = {mean_importance[idx]:.4f}")

    # save shap
    df_shap = pd.DataFrame(shap_values, columns=feature_names)
    df_shap.to_csv('shap_values_emb.csv', index=False)
    print("[SHAP] saved shap_values_emb.csv")

    df_importance = pd.DataFrame({
        'Feature': feature_names,
        'Importance': mean_importance
    }).sort_values('Importance', ascending=False)
    df_importance.to_csv('feature_importance_emb.csv', index=False)
    print("[SHAP] saved feature_importance_emb.csv")

    # Save the embeddings whose rows line up with the SHAP values, together
    # with their targets.
    df_emb_for_shap = pd.DataFrame(sample_data, columns=feature_names)
    y_sample = y_test[:sample_size].reshape(sample_size, -1)
    if y_sample.shape[1] == 1:
        # Flatten the (N, 1) target array into a plain column.
        df_emb_for_shap["target"] = y_sample[:, 0]
    else:
        for j in range(y_sample.shape[1]):
            df_emb_for_shap[f"target_{j}"] = y_sample[:, j]
    df_emb_for_shap.to_csv("shap_input_emb.csv", index=False)
    print("[SHAP] saved shap_input_emb.csv (the embedding input to shap).")


##############################################################################
# H. main
##############################################################################
def main():
    """Run the single-sample layer dissection, then the SHAP analysis."""
    checkpoint_file = r"model_best.pth.tar"
    cif_directory = r"E:\桌面文件\香港科技大学固态电解质项目\project2_1\文章整理_project2\CGCNN\cgcnn-master_2\data\sample-regression\dielectricity"

    # Stage order: 1) visualise conv layers for one sample, 2) SHAP on all data.
    for stage in (dissect_single_sample, run_shap_analysis):
        stage(checkpoint_file, cif_directory)


if __name__ == "__main__":
    main()


IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html



=== [dissect_single_sample] ===
[dissect_single_sample] => orig_atom_fea_len=92, atom_fea_len=64, nbr_fea_len=101, n_conv=3, h_fea_len=128, out_dim=1
[dissect_single_sample] forward done. Output shape: torch.Size([1, 1])
Model output value: 0.08809866011142731

[dissect_single_sample] Intermediate outputs:
  embedding: shape=(44, 64)
  conv_0: shape=(44, 64)
  conv_1: shape=(44, 64)
  conv_2: shape=(44, 64)
  fc_out: shape=(1, 1)
[analyze_feature_activation] skip, arr.ndim=0

=== [run_shap_analysis] ===
[ModelWrapper] Inferred dims => orig_atom_fea_len=92, atom_fea_len=64, nbr_fea_len=101, n_conv=3, h_fea_len=128, out_dim=1


Issues encountered while parsing CIF: Some fractional coordinates rounded to ideal values to avoid issues with finite precision.
Using 200 background data samples could cause slower run times. Consider using shap.sample(data, K) or shap.kmeans(data, K) to summarize the background as K samples.


Graph embedding shape: torch.Size([6686, 128])
y_test shape: (6686, 1)


100%|██████████| 200/200 [04:39<00:00,  1.40s/it]


SHAP values shape: (200, 128)
[SHAP] summary swarm plot saved -> shap_summary_swarm.png

Top 10 important embedding dims:
  GNN_Emb_73 = 0.0383
  GNN_Emb_23 = 0.0326
  GNN_Emb_118 = 0.0276
  GNN_Emb_124 = 0.0221
  GNN_Emb_106 = 0.0213
  GNN_Emb_39 = 0.0206
  GNN_Emb_56 = 0.0203
  GNN_Emb_90 = 0.0197
  GNN_Emb_61 = 0.0188
  GNN_Emb_71 = 0.0186
[SHAP] saved shap_values_emb.csv
[SHAP] saved feature_importance_emb.csv
