In [2]:
import torch  
import numpy as np  
import torch.nn.functional as F  
from pytorch_grad_cam import GradCAM  
import shap  
import matplotlib.pyplot as plt  
from model import *
import os

In [11]:
def load_model(model_path):
    """
    Load the pretrained ImprovedCNNTransformer from a training checkpoint.

    Args:
        model_path: Path to a checkpoint written with ``torch.save``. The file
            must hold a mapping with a ``'model_state_dict'`` entry.

    Returns:
        The model in evaluation mode, moved to the GPU when CUDA is available.

    Raises:
        FileNotFoundError: If ``model_path`` does not exist.
        KeyError: If the checkpoint has no ``'model_state_dict'`` entry.
    """
    # Fail early with a clear message instead of a low-level I/O error.
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model file not found: {model_path}")

    # map_location="cpu" lets a checkpoint saved on a GPU machine be opened on a
    # CPU-only host; the model is moved to the right device further down.
    # SECURITY: torch.load unpickles the file and can execute arbitrary code, so
    # only open checkpoints you trust. Passing weights_only=True is safer, but it
    # rejects checkpoints that contain non-tensor Python objects, so it is left
    # as a deliberate decision for whoever owns the checkpoint format.
    checkpoint = torch.load(model_path, map_location="cpu")

    # Build the architecture, then restore only the weights (the checkpoint may
    # carry additional entries that are ignored here).
    model = ImprovedCNNTransformer()
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()  # inference mode: disables dropout / batch-norm updates

    # Use the GPU when one is available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return model.to(device)


# Checkpoint to explain; load_model restores its 'model_state_dict' entry.
model_path = "/root/experiments/20250117_141323/checkpoint_epoch_64.pth"
model = load_model(model_path)

# One preprocessed scan stored as a NumPy array
# (judging by the file name, a resting-state BOLD run — TODO confirm).
data_path = "/root/autodl-tmp/CNNLSTM/Project/preprocssed_60_64/sub-0010001_ses-1_task-rest_run-1_bold.nii.gz.npy" 
data_npy = np.load(data_path)
# Prints the array values; NumPy abbreviates large arrays with "...".
print(data_npy)
# Last expression of the cell, so the notebook displays the shape
# (the captured output shows (60, 64, 64, 64, 1)).
data_npy.shape
# Disabled: the raw array is a NumPy array in channel-last layout; it is
# converted by prepare_data_for_inference before being passed to the model.
# out_put = model(data_npy)

# output

You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.


[[[[[-0.5889067 ]
    [-0.56582094]
    [-0.54403773]
    ...
    [-0.56337926]
    [-0.57372922]
    [-0.58394798]]

   [[-0.56309448]
    [-0.5628318 ]
    [-0.56300838]
    ...
    [-0.56690623]
    [-0.56694305]
    [-0.567388  ]]

   [[-0.56105258]
    [-0.5672651 ]
    [-0.57330655]
    ...
    [-0.56665985]
    [-0.56003193]
    [-0.55331294]]

   ...

   [[-0.57589625]
    [-0.56192421]
    [-0.54889107]
    ...
    [-0.570832  ]
    [-0.55805765]
    [-0.54490154]]

   [[-0.58557934]
    [-0.56732206]
    [-0.54964627]
    ...
    [-0.56640397]
    [-0.54791863]
    [-0.52939389]]

   [[-0.5413599 ]
    [-0.56063661]
    [-0.57900903]
    ...
    [-0.5678143 ]
    [-0.54962892]
    [-0.53127093]]]


  [[[-0.57691006]
    [-0.5761982 ]
    [-0.57515739]
    ...
    [-0.56555154]
    [-0.56277122]
    [-0.56022342]]

   [[-0.56468569]
    [-0.57447308]
    [-0.58390754]
    ...
    [-0.56236351]
    [-0.55721601]
    [-0.55196328]]

   [[-0.54658876]
    [-0.55844092]
    [-0.57

(60, 64, 64, 64, 1)

In [16]:
def prepare_data_for_inference(data_npy):
    """
    Convert a raw scan array into the batched tensor layout the model consumes.

    The input is expected in channel-last form (time, D, H, W, C), for example
    (60, 64, 64, 64, 1). The result is a float32 tensor of shape
    (1, time, C, D, H, W), placed on the GPU when CUDA is available.
    """
    # from_numpy wraps the array without copying; .to() only converts (and
    # copies) when the dtype is not already float32.
    tensor = torch.from_numpy(data_npy).to(torch.float32)

    # Channel-last -> channel-first, i.e. (time, C, D, H, W), then prepend the
    # batch axis to obtain (1, time, C, D, H, W).
    batched = tensor.permute(0, 4, 1, 2, 3)[None]

    # Follow the same device rule as the model loader.
    return batched.cuda() if torch.cuda.is_available() else batched

# Example usage: convert the raw array into the model's input layout.
data = prepare_data_for_inference(data_npy)  
print(f"Input shape: {data.shape}")  # expected: torch.Size([1, 60, 1, 64, 64, 64])

# Run inference without building the autograd graph.
with torch.no_grad():  
    output = model(data)  
# The captured output shows a (1, 2) tensor; these are the raw model outputs
# (softmax is applied separately in run_interpretation).
print(output)
print(f"Output shape: {output.shape}")

Input shape: torch.Size([1, 60, 1, 64, 64, 64])
tensor([[ 0.7346, -1.3534]], device='cuda:0')
Output shape: torch.Size([1, 2])


In [None]:
import torch  
import numpy as np  
import matplotlib.pyplot as plt  
import seaborn as sns  
from captum.attr import IntegratedGradients, LayerGradCam  
import os  
from typing import Optional, Union, List, Tuple  # 添加这行 
# Configure the PyTorch CUDA caching allocator to use expandable segments.
# NOTE(review): this setting is presumably only read when the CUDA allocator is
# first initialised. An earlier cell already moved the model to the GPU, so it
# may have no effect here unless it is set before any CUDA work — confirm.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'  

# Release cached GPU memory that is no longer in use before the heavy
# attribution computations below.
torch.cuda.empty_cache()  
class ModelInterpreter:
    """
    Interpretability helper for one trained model and one prepared input tensor.

    Bundles three attribution methods (Captum LayerGradCam, Captum
    IntegratedGradients and a capture of self-attention weights) with the
    matplotlib / seaborn code that saves their results as PNG files.
    """

    def __init__(self, model, data):
        """
        Args:
            model: Network to explain; it is switched to evaluation mode here.
            data: Input tensor passed unchanged to ``model``. The notebook feeds
                a tensor of shape (1, time, 1, D, H, W).
        """
        self.model = model
        self.data = data
        self.model.eval()  # attributions must be computed in inference mode

    def compute_layer_gradcam(self, target_class=1, target_layer=None):
        """
        Compute a Captum LayerGradCam attribution map.

        Args:
            target_class: Index of the model output to explain.
            target_layer: Layer whose activations are explained. Defaults to
                ``self.model.res1``, the layer the notebook originally used.

        Returns:
            The attribution as a NumPy array (shape is determined by the output
            of the chosen layer).
        """
        layer = self.model.res1 if target_layer is None else target_layer
        layer_gc = LayerGradCam(self.model, layer)

        # Work on a detached copy so the stored input is never modified and
        # carries no autograd history into Captum.
        input_data = self.data.clone().detach()

        attributions = layer_gc.attribute(input_data, target=target_class)
        return attributions.detach().cpu().numpy()

    def compute_attention_weights(self):
        """
        Capture the self-attention weights of every transformer encoder layer.

        A forward hook is attached to each ``layer.self_attn`` in
        ``self.model.transformer_encoder.layers``. When the hook fires, the
        attention module is evaluated once more on the same arguments with
        ``need_weights=True`` so that the real (projected, softmax-normalised,
        head-averaged) attention matrix is obtained. The model's own forward
        result is left untouched.

        Assumes each ``self_attn`` follows the ``torch.nn.MultiheadAttention``
        forward interface (``need_weights`` / ``average_attn_weights`` keyword
        arguments) — TODO confirm against the model definition. The hooks only
        fire when the layer actually invokes ``self_attn`` through its
        ``__call__``.

        Returns:
            A list with one NumPy array per encoder layer. The list is empty if
            no hook fired.
        """
        attention_weights = []

        def hook_fn(module, args, kwargs, output):
            # Re-evaluate the attention module asking for the weights. Calling
            # .forward directly avoids triggering this hook recursively.
            call_kwargs = dict(kwargs)
            call_kwargs["need_weights"] = True
            call_kwargs["average_attn_weights"] = True
            with torch.no_grad():
                _, weights = module.forward(*args, **call_kwargs)
            if weights is not None:
                attention_weights.append(weights.detach().cpu().numpy())
            # Returning None keeps the original module output unchanged.

        hooks = [
            layer.self_attn.register_forward_hook(hook_fn, with_kwargs=True)
            for layer in self.model.transformer_encoder.layers
        ]

        try:
            with torch.no_grad():
                _ = self.model(self.data.clone())
        finally:
            # Always detach the hooks, even when the forward pass fails.
            for hook in hooks:
                hook.remove()

        return attention_weights

    def compute_integrated_gradients(self, target_class=1, steps=20, internal_batch_size=4):
        """
        Compute Integrated Gradients against an all-zero baseline.

        The integral is approximated with ``steps`` interpolation points in a
        single Captum call. ``internal_batch_size`` caps how many interpolated
        inputs go through the model at once; the interpolated inputs for all
        ``steps`` are still materialised together by Captum, so lower ``steps``
        if memory is tight.

        Args:
            target_class: Index of the model output to explain.
            steps: Number of integration steps.
            internal_batch_size: Maximum number of interpolated inputs per
                forward/backward pass.

        Returns:
            Attributions as a NumPy array with the same shape as ``self.data``,
            or ``None`` if the computation failed (the error is printed).
        """
        ig = IntegratedGradients(self.model)
        inputs = self.data.clone().detach()
        baseline = torch.zeros_like(inputs)

        try:
            attributions = ig.attribute(
                inputs,
                baselines=baseline,
                target=target_class,
                n_steps=steps,
                internal_batch_size=internal_batch_size,
            )
            return attributions.detach().cpu().numpy()
        except Exception as e:
            # Best effort: the caller treats None as "skip this visualisation".
            print(f"Error in compute_integrated_gradients: {str(e)}")
            return None
        finally:
            # Hand cached blocks back to the driver after the heavy computation.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    def visualize_results(self, save_dir='interpretability_results', original_data=None):
        """
        Run every attribution method and save the resulting figures.

        Each method is executed independently; a failure is printed and the
        remaining methods still run.

        Args:
            save_dir: Directory for the PNG files (created if missing).
            original_data: Optional anatomical reference for the overlay rows.
                May be a tensor or array; see ``_to_reference_volume`` for the
                accepted layouts.
        """
        os.makedirs(save_dir, exist_ok=True)

        # Reduce the reference data to a single 3D volume once, up front.
        original_data = self._to_reference_volume(original_data)

        # Layer GradCAM
        try:
            print("Computing Layer GradCAM...")
            grad_cam_maps = self.compute_layer_gradcam()
            if grad_cam_maps is not None:
                self._plot_3d_heatmap(
                    grad_cam_maps[0],
                    os.path.join(save_dir, 'layer_gradcam.png'),
                    title='Layer GradCAM Visualization',
                    original_data=original_data
                )
                print("Layer GradCAM computation completed")
        except Exception as e:
            print(f"Layer GradCAM computation failed: {str(e)}")

        # Attention weights
        try:
            print("Computing attention weights...")
            attention_weights = self.compute_attention_weights()
            if attention_weights:
                for i, attn in enumerate(attention_weights):
                    fig, ax = plt.subplots(figsize=(10, 8))
                    # attn[0]: attention matrix of the first sample in the batch
                    sns.heatmap(attn[0], cmap='viridis', ax=ax)
                    ax.set_title(f'Attention Weights Layer {i+1}')
                    fig.savefig(os.path.join(save_dir, f'attention_layer_{i+1}.png'))
                    plt.close(fig)
                print("Attention weights visualization completed")
            else:
                # Do not report success when nothing was captured.
                print("No attention weights were captured")
        except Exception as e:
            print(f"Attention visualization failed: {str(e)}")

        # Integrated Gradients
        try:
            print("Computing Integrated Gradients...")
            ig_attributions = self.compute_integrated_gradients()
            if ig_attributions is not None:
                self._plot_3d_heatmap(
                    # first sample, averaged over its leading (time) axis
                    np.mean(ig_attributions[0], axis=0),
                    os.path.join(save_dir, 'integrated_gradients.png'),
                    title='Integrated Gradients',
                    original_data=original_data
                )
                print("Integrated Gradients computation completed")
        except Exception as e:
            print(f"Integrated Gradients computation failed: {str(e)}")

        # SHAP is intentionally not run here: KernelExplainer on the flattened
        # input (about 15.7 million features in the recorded run) did not
        # finish and had to be interrupted.

    @staticmethod
    def _normalize_unit_range(values):
        """
        Min-max scale ``values`` into [0, 1].

        A constant (or non-finite) input has no range to scale by; it maps to
        all zeros instead of producing NaNs through a division by zero.
        """
        values = np.asarray(values, dtype=float)
        low = np.min(values)
        span = np.max(values) - low
        if not np.isfinite(span) or span == 0:
            return np.zeros_like(values)
        return (values - low) / span

    @staticmethod
    def _to_reference_volume(original_data):
        """
        Reduce reference data to one 3D volume (first sample / time point / channel).

        Accepted layouts:
            * 6D (batch, time, channel, D, H, W)
            * 5D channel-last (time, D, H, W, 1), as stored in the raw .npy files
            * 5D channel-first (batch, channel, D, H, W)
            * anything else is returned unchanged (``None`` stays ``None``)

        The two 5D layouts are told apart by a trailing singleton axis;
        NOTE(review): a channel-first volume whose last axis has length 1 would
        be misclassified — confirm this never occurs in practice.
        """
        if original_data is None:
            return None
        if isinstance(original_data, torch.Tensor):
            original_data = original_data.detach().cpu().numpy()
        volume = np.asarray(original_data)

        if volume.ndim == 6:
            return volume[0, 0, 0]
        if volume.ndim == 5:
            if volume.shape[-1] == 1 and volume.shape[1] != 1:
                return volume[0, ..., 0]
            return volume[0, 0]
        return volume

    def _overlay_heatmap(self, original, heatmap, alpha=0.4):
        """
        Blend a heatmap slice over a grayscale image slice.

        Args:
            original: 2D slice of the reference image.
            heatmap: 2D slice of the attribution map (same shape).
            alpha: Opacity of the heatmap layer.

        Returns:
            An (H, W, 3) RGB array.
        """
        # Colour the normalised heatmap with the jet map and drop its alpha channel.
        heatmap_colored = plt.cm.jet(self._normalize_unit_range(heatmap))[:, :, :3]

        # The normalised image is broadcast over the three colour channels.
        original = self._normalize_unit_range(original)
        return original[..., np.newaxis] * (1 - alpha) + heatmap_colored * alpha

    def _plot_3d_heatmap(self, data, save_path, title, original_data=None):
        """
        Save the three central orthogonal slices of a 3D attribution map.

        The top row shows the raw heatmap; when ``original_data`` is given, a
        second row shows the heatmap blended over the reference volume.

        Args:
            data: Attribution array. Leading axes beyond the last three are
                averaged away.
            save_path: Destination of the PNG file.
            title: Figure title.
            original_data: Optional reference volume (see ``_to_reference_volume``).

        Raises:
            ValueError: If ``data`` has fewer than three dimensions.
        """
        data = np.asarray(data)
        # Collapse every leading axis until a 3D volume remains.
        while data.ndim > 3:
            data = data.mean(axis=0)
        if data.ndim != 3:
            raise ValueError(f"Expected at least 3D heatmap data, got shape {data.shape}")

        original_data = self._to_reference_volume(original_data)

        # Resample the heatmap onto the reference grid when the sizes differ
        # (e.g. a map taken from a down-sampled feature layer).
        if original_data is not None and data.shape != original_data.shape:
            from scipy.ndimage import zoom
            zoom_factors = tuple(o / d for o, d in zip(original_data.shape, data.shape))
            data = zoom(data, zoom_factors)

        # Central slice index along each axis.
        mid_x = data.shape[-3] // 2
        mid_y = data.shape[-2] // 2
        mid_z = data.shape[-1] // 2

        fig = plt.figure(figsize=(20, 15))
        gs = plt.GridSpec(2, 3)

        slices = (
            ('Sagittal', (mid_x, slice(None), slice(None))),
            ('Coronal', (slice(None), mid_y, slice(None))),
            ('Axial', (slice(None), slice(None), mid_z)),
        )

        for col, (view, index) in enumerate(slices):
            # Row 1: raw heatmap with its own colour bar.
            ax = fig.add_subplot(gs[0, col])
            im = ax.imshow(data[index], cmap='hot')
            ax.set_title(f'{view} View - Heatmap')
            fig.colorbar(im, ax=ax)

            # Row 2: heatmap blended over the reference volume.
            if original_data is not None:
                ax_overlay = fig.add_subplot(gs[1, col])
                ax_overlay.imshow(self._overlay_heatmap(original_data[index], data[index]))
                ax_overlay.set_title(f'{view} View - Overlay')

        fig.suptitle(title)
        fig.tight_layout()
        fig.savefig(save_path)
        plt.close(fig)

def run_interpretation(model_path, data_path):
    """
    Run the full interpretability analysis for one checkpoint and one scan.

    Loads the model and the raw array, converts the array into the model's
    input layout, saves all attribution figures and finally prints the softmax
    class probabilities.

    Args:
        model_path: Path to the training checkpoint.
        data_path: Path to a ``.npy`` file in channel-last layout
            (time, D, H, W, C), the layout ``prepare_data_for_inference`` expects.

    Returns:
        The ``ModelInterpreter`` instance, so further analyses can be run on it.
    """
    # Load the model and the data.
    model = load_model(model_path)
    data_npy = np.load(data_path)
    data = prepare_data_for_inference(data_npy)

    # Reference volume for the overlay plots: first time point, first channel.
    # Passing the whole channel-last array would be sliced as if it were
    # (batch, channel, D, H, W) and yield a (H, W, 1) slab instead of a volume.
    # Only this single volume is copied, not the entire scan.
    original_data = np.array(data_npy[0, ..., 0])

    print("Model loaded and data prepared")
    print(f"Input data shape: {data.shape}")

    interpreter = ModelInterpreter(model, data)

    print("Starting visualization process...")
    interpreter.visualize_results(original_data=original_data)

    # Report the class probabilities for the explained input.
    with torch.no_grad():
        output = model(data)
        predictions = torch.softmax(output, dim=1)
        print(f"Model predictions: {predictions}")

    print("Interpretation completed")
    return interpreter

# Example usage
if __name__ == "__main__":  
    # Same checkpoint and scan paths as in the earlier cells.
    model_path = "/root/experiments/20250117_141323/checkpoint_epoch_64.pth"  
    data_path = "/root/autodl-tmp/CNNLSTM/Project/preprocssed_60_64/sub-0010001_ses-1_task-rest_run-1_bold.nii.gz.npy"  
    
    try:  
        interpreter = run_interpretation(model_path, data_path)  
        print("Successfully completed all interpretability analyses")  
    except Exception as e:  
        # Only the exception message is printed; no traceback is shown and the
        # exception is not re-raised.
        print(f"An error occurred during interpretation: {str(e)}")

You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.


Model loaded and data prepared
Input data shape: torch.Size([1, 60, 1, 64, 64, 64])
Starting visualization process...
Computing Layer GradCAM...
Layer GradCAM computation completed
Computing attention weights...
Attention weights visualization completed
Computing Integrated Gradients...
Integrated Gradients computation completed
Computing and visualizing SHAP values...
Computing SHAP values using KernelExplainer...
Flattened data shape: (1, 15728640)
KernelExplainer created successfully


Computing SHAP values:   0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Computing SHAP values:   0%|          | 0/1 [01:59<?, ?it/s]


KeyboardInterrupt: 

In [78]:
    def compute_shap_values(self, input_data):
        """
        Estimate SHAP values for ``input_data`` with ``shap.KernelExplainer``.

        Every input is flattened to one feature vector and explained against a
        single all-zero background sample. KernelExplainer's cost grows with the
        number of features, so this is only practical for small inputs.

        Args:
            input_data: Input tensor whose first axis is the sample axis.

        Returns:
            SHAP values reshaped to ``input_data``'s shape: a single array for a
            single-output model, or a list with one array per model output.
            ``None`` if anything failed (the traceback is printed).
        """
        try:
            # detach() so tensors that require grad can be converted as well.
            input_data_cpu = input_data.detach().cpu().numpy()

            # Remember the layout so the flat SHAP vectors can be folded back.
            original_shape = input_data_cpu.shape

            # (samples, features): all non-batch axes collapsed into one vector.
            flattened_data = input_data_cpu.reshape(original_shape[0], -1)

            print(f"Flattened data shape: {flattened_data.shape}")  # debug info

            # Run the wrapper on the device the model actually lives on rather
            # than guessing from CUDA availability.
            try:
                device = next(self.model.parameters()).device
            except StopIteration:
                device = input_data.device

            def model_wrapper(x):
                # KernelExplainer passes flat rows; restore the model's layout.
                x_reshaped = torch.as_tensor(
                    x.reshape(-1, *original_shape[1:]),
                    dtype=torch.float32,
                    device=device
                )

                with torch.no_grad():
                    output = self.model(x_reshaped)
                    return output.cpu().numpy()

            # Background: a single all-zero sample.
            background = np.zeros((1, flattened_data.shape[1]))

            explainer = shap.KernelExplainer(
                model_wrapper,
                background,
                link="identity"
            )

            shap_values = explainer.shap_values(
                flattened_data,
                nsamples=100,  # kept small to limit memory use
                # NOTE(review): batch_size does not appear to be an option that
                # KernelExplainer reads — confirm, and drop it if it is ignored.
                batch_size=1
            )

            if isinstance(shap_values, list):
                # Older SHAP releases: one (samples, features) array per output.
                shap_values = [
                    np.asarray(sv).reshape(original_shape) for sv in shap_values
                ]
            else:
                shap_values = np.asarray(shap_values)
                if shap_values.ndim == 3:
                    # Newer SHAP releases stack multi-output results as
                    # (samples, features, outputs); split along the last axis so
                    # each output can be folded back to the input layout.
                    shap_values = [
                        shap_values[..., k].reshape(original_shape)
                        for k in range(shap_values.shape[-1])
                    ]
                else:
                    # Single-output model.
                    shap_values = shap_values.reshape(original_shape)

            print("SHAP values computed successfully")
            return shap_values

        except Exception as e:
            print(f"Detailed error in SHAP computation: {str(e)}")
            print(f"Error type: {type(e)}")
            import traceback
            print(traceback.format_exc())
            return None
        
    def compute_shap_values_batch(self, background_data=None, n_samples=5, batch_size=1, kernel_nsamples=50):
        """
        Estimate SHAP values for ``self.data`` in batches with ``shap.KernelExplainer``.

        The model's softmax probabilities are explained. Inputs are flattened to
        one feature vector per sample, so the cost grows with the input size;
        the recorded run on the full scan had to be interrupted.

        Args:
            background_data: Optional background tensor with the same trailing
                shape as ``self.data``. When omitted, random normal samples are
                drawn (unseeded, so results differ between runs).
            n_samples: Upper bound on the number of random background samples;
                at most 3 are generated.
            batch_size: Number of input rows explained per explainer call.
            kernel_nsamples: ``nsamples`` passed to ``explainer.shap_values``.

        Returns:
            SHAP values shaped like ``self.data``: one array for a single-output
            model, otherwise a list with one array per model output. ``None``
            when nothing could be computed (errors are printed).
        """
        try:
            print("Computing SHAP values using KernelExplainer...")
            self.model.eval()

            feature_shape = tuple(self.data.shape[1:])

            # tqdm is used for the progress bar but is not imported by the
            # notebook; import it here and fall back to a plain loop if absent.
            try:
                from tqdm.auto import tqdm as progress
            except ImportError:
                def progress(iterable, **_ignored):
                    return iterable

            def model_wrapper(x):
                if isinstance(x, np.ndarray):
                    # Flat rows from the explainer are folded back to the
                    # model's input layout.
                    if len(x.shape) == 2:
                        x = x.reshape(-1, *feature_shape)
                    x = torch.as_tensor(x, dtype=torch.float32).to(self.data.device)

                with torch.no_grad():
                    output = self.model(x)
                    probs = torch.softmax(output, dim=1)
                    return probs.cpu().numpy()

            def split_outputs(values):
                # Normalise the explainer's return value to the input layout.
                if isinstance(values, list):
                    # Older SHAP: list with one (rows, features) array per output.
                    return [np.asarray(v).reshape(-1, *feature_shape) for v in values]
                values = np.asarray(values)
                if values.ndim == 3:
                    # Newer SHAP: (rows, features, outputs). Split the last axis
                    # first; reshaping directly would interleave the outputs.
                    return [
                        values[..., k].reshape(-1, *feature_shape)
                        for k in range(values.shape[-1])
                    ]
                # Single-output model: (rows, features).
                return values.reshape(-1, *feature_shape)

            # Background data: a few random samples created on the CPU.
            if background_data is None:
                background_data = torch.randn(
                    (min(n_samples, 3),) + feature_shape,
                    device='cpu'
                )

            background_numpy = background_data.detach().cpu().numpy()
            data_numpy = self.data.detach().cpu().numpy()

            # Flatten to (rows, features).
            background_flat = background_numpy.reshape(background_numpy.shape[0], -1)
            data_flat = data_numpy.reshape(data_numpy.shape[0], -1)

            print(f"Flattened data shape: {data_flat.shape}")

            try:
                explainer = shap.KernelExplainer(
                    model_wrapper,
                    background_flat,
                    link="identity"
                )
                print("KernelExplainer created successfully")
            except Exception as e:
                print(f"Error creating KernelExplainer: {str(e)}")
                return None

            # Explain the input batch by batch; a failing batch is skipped.
            all_shap_values = []
            total_batches = (len(data_flat) + batch_size - 1) // batch_size

            for i in progress(range(0, len(data_flat), batch_size), total=total_batches, desc="Computing SHAP values"):
                try:
                    batch = data_flat[i:i + batch_size]
                    batch_shap_values = explainer.shap_values(
                        batch,
                        nsamples=kernel_nsamples
                    )
                    all_shap_values.append(split_outputs(batch_shap_values))
                    print(f"Batch {i//batch_size + 1} completed")

                except Exception as e:
                    print(f"Error processing batch {i//batch_size + 1}: {str(e)}")
                    continue

                if self.data.is_cuda:
                    torch.cuda.empty_cache()

            if not all_shap_values:
                print("No SHAP values were successfully computed")
                return None

            # Merge the per-batch results along the sample axis.
            try:
                if isinstance(all_shap_values[0], list):
                    final_shap_values = [
                        np.concatenate([batch[k] for batch in all_shap_values])
                        for k in range(len(all_shap_values[0]))
                    ]
                else:
                    final_shap_values = np.concatenate(all_shap_values)

                print("SHAP values computation completed")
                return final_shap_values

            except Exception as e:
                print(f"Error merging results: {str(e)}")
                return None

        except Exception as e:
            print(f"Error in compute_shap_values_batch: {str(e)}")
            import traceback
            traceback.print_exc()
            return None

    def _reshape_data_for_shap(self, data):  
        """  
        重塑数据以适应SHAP分析  
        """  
        # 假设输入形状为 (batch, time, channel, depth, height, width)  
        # 或 (batch, channel, depth, height, width)  
        shape = data.shape  
        if len(shape) == 6:  # 6D数据  
            # 将时间维度展平到批次维度  
            return data.reshape(-1, *shape[2:])  
        return data  
    
    def _plot_shap_summary(self,
                          shap_values: np.ndarray,
                          save_path: str,
                          feature_names: Optional[List[str]] = None):
        """
        Save a horizontal bar chart of the mean absolute SHAP value per feature.

        Args:
            shap_values: SHAP values with samples on the first axis. For a list
                (one entry per class) only the first entry is plotted.
            save_path: Destination of the image file.
            feature_names: Optional tick labels; defaults to 'Feature <index>'.
        """
        plt.figure(figsize=(12, 8))

        # Multi-class results arrive as a list; use the first class only.
        values = shap_values[0] if isinstance(shap_values, list) else shap_values

        # Importance of a feature = average magnitude of its SHAP values.
        importance = np.abs(values).mean(axis=0)
        positions = range(len(importance))

        labels = feature_names
        if labels is None:
            labels = [f'Feature {idx}' for idx in positions]

        plt.barh(positions, importance)
        plt.yticks(positions, labels)
        plt.xlabel('mean(|SHAP value|)')
        plt.title('Feature Importance Based on SHAP Values')

        plt.tight_layout()
        plt.savefig(save_path)
        plt.close()
    
    def _plot_shap_3d_heatmap(self,   
                             shap_values: np.ndarray,  
                             save_path: str,  
                             original_data: Optional[np.ndarray] = None):  
        """  
        绘制3D SHAP值热力图  
        """  
        # 确保SHAP值是3D的  
        if len(shap_values.shape) > 3:  
            shap_values = np.mean(shap_values, axis=0)  
        
        # 调用原有的_plot_3d_heatmap方法  
        self._plot_3d_heatmap(  
            data=shap_values,  
            save_path=save_path,  
            title='SHAP Values 3D Visualization',  
            original_data=original_data  
        )  