In [4]:
import torch
import shap
import numpy as np
from cgcnn.data import CIFData
from cgcnn.data import collate_pool
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from lime import lime_tabular

# 加载数据
dataset = CIFData('data/sample-regression/dielectricity')
test_loader = DataLoader(dataset, batch_size=256, shuffle=False, num_workers=0,
                         collate_fn=collate_pool)

# 加载模型
best_model = torch.load('model_best.pth.tar', map_location=torch.device('cpu'))

# 创建一个包装器类
class ModelWrapper(torch.nn.Module):
    def __init__(self, state_dict):
        super().__init__()
        self.state_dict = {k: v.cpu() for k, v in state_dict.items()}
        print("ModelWrapper initialized with keys:", self.state_dict.keys())
    
    def forward(self, *args):
        input_data = args[0]  # 获取输入数据
        
        # 调试输出
        print(f"Input data shape: {input_data.shape}")
        print(f"fc_out.weight shape: {self.state_dict['fc_out.weight'].shape}")
        print(f"fc_out.bias shape: {self.state_dict['fc_out.bias'].shape}")
        
        # 确保输入是2D的
        if input_data.dim() > 2:
            input_data = input_data.view(-1, input_data.size(-1))
        
        # 如果输入特征数不匹配，我们需要调整输入
        if input_data.shape[-1] < self.state_dict['fc_out.weight'].shape[1]:
            print(f"Padding input feature size from {input_data.shape[-1]} to {self.state_dict['fc_out.weight'].shape[1]}")
            padding = torch.zeros(input_data.size(0), self.state_dict['fc_out.weight'].shape[1] - input_data.size(1))
            input_data = torch.cat([input_data, padding], dim=1)
        elif input_data.shape[-1] > self.state_dict['fc_out.weight'].shape[1]:
            print(f"Truncating input feature size from {input_data.shape[-1]} to {self.state_dict['fc_out.weight'].shape[1]}")
            input_data = input_data[:, :self.state_dict['fc_out.weight'].shape[1]]
        
        # 矩阵乘法
        output = torch.matmul(input_data, self.state_dict['fc_out.weight'].t()) + self.state_dict['fc_out.bias']
        print(f"Output shape: {output.shape}")
        return output

# 使用包装器类创建模型
model = ModelWrapper(best_model['state_dict'])

# 准备数据
X_test = []
y_test = []
for i, (input, target, _) in enumerate(test_loader):
    X_test.append(input[0].cpu())
    y_test.append(target.cpu())

X_test = torch.cat(X_test, dim=0)
y_test = torch.cat(y_test).numpy()

print(f"X_test shape: {X_test.shape}")
print(f"y_test shape: {y_test.shape}")
print("X_test first few rows:")
print(X_test[:5])

# 定义一个包装函数来适配LIME
def f(X):
    with torch.no_grad():
        X = torch.tensor(X, dtype=torch.float32)
        if X.dim() == 2:
            X = X.unsqueeze(0)  # 添加批次维度
        output = model(X)
    return output.numpy()

# 创建LIME解释器
explainer = lime_tabular.LimeTabularExplainer(
    X_test.numpy(),
    mode="regression",
    feature_names=[f"feature_{i}" for i in range(X_test.shape[1])],
    verbose=True,
    random_state=42
)

# 选择一个样本进行解释
sample_idx = 0
sample = X_test[sample_idx].numpy()

# 获取LIME解释
exp = explainer.explain_instance(sample, f, num_features=20)

# 可视化LIME结果
plt.figure(figsize=(12, 8))
exp.as_pyplot_figure()
plt.tight_layout()
plt.savefig('lime_explanation.png', dpi=300, bbox_inches='tight')
plt.close()

print("LIME analysis completed. Results saved in 'lime_explanation.png'.")

# 打印LIME解释的详细信息
print("\nLIME Explanation Details:")
for feature, importance in exp.as_list():
    print(f"{feature}: {importance:.4f}")

# 对比LIME和模型预测
lime_prediction = exp.predicted_value
actual_prediction = f(sample.reshape(1, -1))[0][0]
actual_value = y_test[sample_idx]

print(f"\nLIME Prediction: {lime_prediction:.4f}")
print(f"Model Prediction: {actual_prediction:.4f}")
print(f"Actual Value: {actual_value[0]:.4f}")  # 修改这里

# 添加一些额外的分析
print("\nFeature Importance Summary:")
feature_importance = sorted(exp.as_list(), key=lambda x: abs(x[1]), reverse=True)
for feature, importance in feature_importance[:5]:  # 显示前5个最重要的特征
    print(f"{feature}: {importance:.4f}")

print("\nModel Performance:")
prediction_error = abs(actual_prediction - actual_value[0])
print(f"Prediction Error: {prediction_error:.4f}")
print(f"Relative Error: {(prediction_error / actual_value[0]) * 100:.2f}%")

ModelWrapper initialized with keys: dict_keys(['embedding.weight', 'embedding.bias', 'convs.0.fc_full.weight', 'convs.0.fc_full.bias', 'convs.0.bn1.weight', 'convs.0.bn1.bias', 'convs.0.bn1.running_mean', 'convs.0.bn1.running_var', 'convs.0.bn1.num_batches_tracked', 'convs.0.bn2.weight', 'convs.0.bn2.bias', 'convs.0.bn2.running_mean', 'convs.0.bn2.running_var', 'convs.0.bn2.num_batches_tracked', 'convs.1.fc_full.weight', 'convs.1.fc_full.bias', 'convs.1.bn1.weight', 'convs.1.bn1.bias', 'convs.1.bn1.running_mean', 'convs.1.bn1.running_var', 'convs.1.bn1.num_batches_tracked', 'convs.1.bn2.weight', 'convs.1.bn2.bias', 'convs.1.bn2.running_mean', 'convs.1.bn2.running_var', 'convs.1.bn2.num_batches_tracked', 'convs.2.fc_full.weight', 'convs.2.fc_full.bias', 'convs.2.bn1.weight', 'convs.2.bn1.bias', 'convs.2.bn1.running_mean', 'convs.2.bn1.running_var', 'convs.2.bn1.num_batches_tracked', 'convs.2.bn2.weight', 'convs.2.bn2.bias', 'convs.2.bn2.running_mean', 'convs.2.bn2.running_var', 'convs.2



X_test shape: torch.Size([26087, 92])
y_test shape: (711, 1)
X_test first few rows:
tensor([[0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
         1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.,
         0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.,
         0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
         1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.,
         0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.,
         0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0

<Figure size 1200x800 with 0 Axes>