In [3]:
import torch
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

# 假设这些导入是可用的，如果不是，您可能需要调整导入语句
from cgcnn.model import CrystalGraphConvNet
from cgcnn.data import CIFData, collate_pool

class ModifiedConvLayer(torch.nn.Module):
    def __init__(self, atom_fea_len, nbr_fea_len):
        super(ModifiedConvLayer, self).__init__()
        self.atom_fea_len = atom_fea_len
        self.nbr_fea_len = nbr_fea_len
        self.fc_full = None  # We'll initialize this in the forward pass
        self.sigmoid = torch.nn.Sigmoid()
        self.softplus1 = torch.nn.Softplus()
        self.bn1 = None  # We'll initialize this in the forward pass
        self.bn2 = torch.nn.BatchNorm1d(self.atom_fea_len)
        self.softplus2 = torch.nn.Softplus()

    def forward(self, atom_in_fea, nbr_fea, nbr_fea_idx):
        N, M = nbr_fea_idx.shape
        print(f"ModifiedConvLayer - Input shapes:")
        print(f"  atom_in_fea: {atom_in_fea.shape}")
        print(f"  nbr_fea: {nbr_fea.shape}")
        print(f"  nbr_fea_idx: {nbr_fea_idx.shape}")

        atom_nbr_fea = atom_in_fea[nbr_fea_idx, :]
        total_nbr_fea = torch.cat(
            [atom_in_fea.unsqueeze(1).expand(N, M, self.atom_fea_len),
             atom_nbr_fea, nbr_fea], dim=2)
        
        print(f"  total_nbr_fea shape: {total_nbr_fea.shape}")
        
        # Dynamically create fc_full and bn1 layers based on actual input size
        if self.fc_full is None:
            in_features = total_nbr_fea.shape[-1]
            out_features = 2 * self.atom_fea_len
            self.fc_full = torch.nn.Linear(in_features, out_features)
            self.bn1 = torch.nn.BatchNorm1d(out_features)
        
        print(f"  fc_full input shape: {total_nbr_fea.view(-1, total_nbr_fea.shape[-1]).shape}")
        print(f"  fc_full weight shape: {self.fc_full.weight.shape}")

        total_gated_fea = self.fc_full(total_nbr_fea.view(-1, total_nbr_fea.shape[-1]))
        total_gated_fea = self.bn1(total_gated_fea).view(N, M, 2*self.atom_fea_len)
        nbr_filter, nbr_core = total_gated_fea.chunk(2, dim=2)
        nbr_filter = self.sigmoid(nbr_filter)
        nbr_core = self.softplus1(nbr_core)
        nbr_sumed = torch.sum(nbr_filter * nbr_core, dim=1)
        nbr_sumed = self.bn2(nbr_sumed)
        out = self.softplus2(atom_in_fea + nbr_sumed)
        return out

class CrystalGraphConvNetWithHooks(CrystalGraphConvNet):
    def __init__(self, orig_atom_fea_len, nbr_fea_len,
                 atom_fea_len=64, n_conv=3, h_fea_len=128, n_h=1,
                 classification=False):
        super().__init__(orig_atom_fea_len, nbr_fea_len,
                         atom_fea_len=atom_fea_len, n_conv=n_conv,
                         h_fea_len=h_fea_len, n_h=n_h,
                         classification=classification)
        self.intermediate_outputs = {}
        self.atom_fea_len = atom_fea_len
        self.nbr_fea_len = nbr_fea_len
        
        # Replace the original convs with our modified version
        self.convs = torch.nn.ModuleList([ModifiedConvLayer(
            atom_fea_len=self.atom_fea_len,
            nbr_fea_len=self.nbr_fea_len)
            for _ in range(n_conv)])

    def add_hook(self, name):
        def hook(module, input, output):
            self.intermediate_outputs[name] = output.detach()
        return hook

    def forward(self, atom_fea, nbr_fea, nbr_fea_idx, crystal_atom_idx):
        print(f"CrystalGraphConvNetWithHooks - Input shapes:")
        print(f"  atom_fea: {atom_fea.shape}")
        print(f"  nbr_fea: {nbr_fea.shape}")
        print(f"  nbr_fea_idx: {nbr_fea_idx.shape}")
        print(f"  crystal_atom_idx: {crystal_atom_idx.shape if isinstance(crystal_atom_idx, torch.Tensor) else type(crystal_atom_idx)}")

        for name, module in self.named_modules():
            if isinstance(module, torch.nn.Conv1d) or isinstance(module, torch.nn.Linear):
                module.register_forward_hook(self.add_hook(name))
        
        atom_fea = self.embedding(atom_fea)
        print(f"  After embedding, atom_fea shape: {atom_fea.shape}")
        
        for i, conv_func in enumerate(self.convs):
            print(f"  Before conv {i+1}, atom_fea shape: {atom_fea.shape}")
            atom_fea = conv_func(atom_fea, nbr_fea, nbr_fea_idx)
            print(f"  After conv {i+1}, atom_fea shape: {atom_fea.shape}")
        
        # Handle crystal_atom_idx
        if isinstance(crystal_atom_idx, torch.Tensor) and crystal_atom_idx.dim() == 0:
            print("Warning: crystal_atom_idx is a 0-d tensor. Converting to a list.")
            crystal_atom_idx = [torch.arange(atom_fea.shape[0], dtype=torch.long)]
        elif not isinstance(crystal_atom_idx, list):
            print(f"Warning: Unexpected type for crystal_atom_idx: {type(crystal_atom_idx)}. Converting to a list.")
            crystal_atom_idx = [torch.arange(atom_fea.shape[0], dtype=torch.long)]

        print(f"  Processed crystal_atom_idx: {[idx.shape if isinstance(idx, torch.Tensor) else len(idx) for idx in crystal_atom_idx]}")
        
        crys_fea = self.pooling(atom_fea, crystal_atom_idx)
        print(f"  After pooling, crys_fea shape: {crys_fea.shape}")
        
        crys_fea = self.conv_to_fc(self.conv_to_fc_softplus(crys_fea))
        print(f"  After conv_to_fc, crys_fea shape: {crys_fea.shape}")
        
        crys_fea = self.conv_to_fc_softplus(crys_fea)
        print(f"  After conv_to_fc_softplus, crys_fea shape: {crys_fea.shape}")
        
        out = self.fc_out(crys_fea)
        print(f"  Final output shape: {out.shape}")

        return out

# 加载模型
checkpoint = torch.load('model_best.pth.tar', map_location=torch.device('cpu'))

print("Keys in checkpoint:", checkpoint.keys())

# 提取模型参数
state_dict = checkpoint['state_dict']

# 从checkpoint的args中提取模型参数
if 'args' in checkpoint:
    args = checkpoint['args']
    print("Model arguments found in checkpoint:")
    print(args)
    
    # 从第一个卷积层的权重形状精确计算参数
    conv_weight_shape = state_dict['convs.0.fc_full.weight'].shape
    total_fea_len = conv_weight_shape[1]
    atom_fea_len = args.get('atom_fea_len', 64)
    nbr_fea_len = total_fea_len - atom_fea_len
    
    print(f"Calculated nbr_fea_len: {nbr_fea_len}")
    print(f"atom_fea_len: {atom_fea_len}")
    
    model_args = {
        'atom_fea_len': atom_fea_len,
        'nbr_fea_len': nbr_fea_len,
        'n_conv': args.get('n_conv', 3),
        'h_fea_len': args.get('h_fea_len', 128),
        'n_h': args.get('n_h', 1),
        'classification': False  # 假设这是回归任务
    }
    
    # 检查 orig_atom_fea_len
    if 'orig_atom_fea_len' in args:
        model_args['orig_atom_fea_len'] = args['orig_atom_fea_len']
    else:
        print("orig_atom_fea_len not found in args, using value from embedding layer")
        model_args['orig_atom_fea_len'] = state_dict['embedding.weight'].shape[1]
else:
    raise ValueError("No 'args' found in checkpoint. Cannot determine model parameters.")

print("Model arguments to be used:")
print(model_args)

# 创建模型实例
model = CrystalGraphConvNetWithHooks(**model_args)

# 手动过滤掉不匹配的层
model_state_dict = model.state_dict()
filtered_state_dict = {k: v for k, v in state_dict.items() if k in model_state_dict and v.size() == model_state_dict[k].size()}

# 加载过滤后的权重
try:
    model.load_state_dict(filtered_state_dict, strict=False)
    print("Model loaded successfully with filtered weights!")
except RuntimeError as e:
    print(f"Error loading state dict: {e}")
    print("Model structure:")
    print(model)
    print("\nState dict keys:")
    print(state_dict.keys())
    raise

model.eval()

# 打印模型结构
print("Model structure:")
print(model)

# 加载数据集
dataset = CIFData('data/sample-regression/dielectricity')
loader = DataLoader(dataset, batch_size=1, collate_fn=collate_pool)

# 获取一个样本的数据
sample_input = next(iter(loader))

# 解包样本数据
atom_fea, nbr_fea, nbr_fea_idx, crystal_atom_idx = sample_input[0]
target = sample_input[1]
chemical_formula = sample_input[2][0]

print(f"Input data shapes:")
print(f"atom_fea shape: {atom_fea.shape}")
print(f"nbr_fea shape: {nbr_fea.shape}")
print(f"nbr_fea_idx shape: {nbr_fea_idx.shape}")
print(f"crystal_atom_idx type: {type(crystal_atom_idx)}")
print(f"crystal_atom_idx: {crystal_atom_idx}")
print(f"target shape: {target.shape}")
print(f"chemical_formula: {chemical_formula}")

# 确保所有输入都是正确的张量类型
atom_fea = atom_fea.float()
nbr_fea = nbr_fea.float()
nbr_fea_idx = nbr_fea_idx.long()

# 处理 crystal_atom_idx
if isinstance(crystal_atom_idx, list) and len(crystal_atom_idx) == 1 and isinstance(crystal_atom_idx[0], torch.Tensor):
    crystal_atom_idx = crystal_atom_idx[0].long()
elif not isinstance(crystal_atom_idx, torch.Tensor):
    print(f"Warning: Unexpected type for crystal_atom_idx: {type(crystal_atom_idx)}. Converting to tensor.")
    crystal_atom_idx = torch.arange(atom_fea.shape[0], dtype=torch.long)

print(f"Processed crystal_atom_idx shape: {crystal_atom_idx.shape}")


def visualize_layer_output(layer_name, output):
    output = output.squeeze().cpu().numpy()
    plt.figure(figsize=(10, 6))
    
    if output.ndim == 0:  # Scalar output
        plt.text(0.5, 0.5, f"Scalar Output: {output.item():.4f}", 
                 horizontalalignment='center', verticalalignment='center', fontsize=20)
        plt.axis('off')
    elif output.ndim == 1:
        plt.plot(output)
        plt.title(f'Output of {layer_name}')
        plt.xlabel('Index')
        plt.ylabel('Value')
    elif output.ndim == 2:
        plt.imshow(output, aspect='auto', cmap='viridis')
        plt.title(f'Output of {layer_name}')
        plt.colorbar()
        plt.xlabel('Feature dimension')
        plt.ylabel('Sample index')
    else:
        print(f"Cannot visualize output of shape {output.shape} for layer {layer_name}")
        return
    
    plt.savefig(f'{layer_name}_output.png')
    plt.close()

def analyze_feature_activation(layer_name, output):
    output = output.squeeze().cpu().numpy()
    if output.ndim == 0:  # Scalar output
        print(f"Cannot analyze feature activation for scalar output in layer {layer_name}")
        return
    elif output.ndim == 1:
        mean_activation = output
    else:
        mean_activation = np.mean(output, axis=0)
    
    plt.figure(figsize=(10, 6))
    plt.bar(range(len(mean_activation)), mean_activation)
    plt.title(f'Mean Feature Activation in {layer_name}')
    plt.xlabel('Feature index')
    plt.ylabel('Mean activation')
    plt.savefig(f'{layer_name}_mean_activation.png')
    plt.close()

# 运行模型
try:
    with torch.no_grad():
        output = model(atom_fea, nbr_fea, nbr_fea_idx, crystal_atom_idx)
    print("Model output shape:", output.shape)
    print("Model output value:", output.item())
except Exception as e:
    print(f"Error during model execution: {str(e)}")
    import traceback
    traceback.print_exc()

# 打印中间层输出
print("Intermediate outputs:")
for name, output in model.intermediate_outputs.items():
    print(f"{name}: {output.shape}")

# 执行可视化
for name, output in model.intermediate_outputs.items():
    try:
        visualize_layer_output(name, output)
        analyze_feature_activation(name, output)
    except Exception as e:
        print(f"Error visualizing {name}: {str(e)}")

print("Model dissection completed. Check the output images for visualization results.")

# 额外的分析
print("\nAdditional Analysis:")
print("1. Number of parameters in each layer:")
for name, param in model.named_parameters():
    print(f"{name}: {param.numel()}")

print("\n2. Layer output statistics:")
for name, output in model.intermediate_outputs.items():
    output_np = output.cpu().numpy()
    print(f"{name}:")
    print(f"  Mean: {np.mean(output_np):.4f}")
    print(f"  Std: {np.std(output_np):.4f}")
    print(f"  Min: {np.min(output_np):.4f}")
    print(f"  Max: {np.max(output_np):.4f}")

print("\nModel dissection and analysis completed.")

Keys in checkpoint: dict_keys(['epoch', 'state_dict', 'best_mae_error', 'optimizer', 'normalizer', 'args'])
Model arguments found in checkpoint:
{'data_options': ['data/sample-regression/dielectricity'], 'task': 'regression', 'disable_cuda': False, 'workers': 0, 'epochs': 100, 'start_epoch': 0, 'batch_size': 256, 'lr': 0.01, 'lr_milestones': [100], 'momentum': 0.9, 'weight_decay': 0, 'print_freq': 10, 'resume': '', 'radius': 20.0, 'train_ratio': 0.6, 'train_size': None, 'val_ratio': 0.2, 'val_size': None, 'test_ratio': 0.2, 'test_size': None, 'optim': 'SGD', 'atom_fea_len': 64, 'h_fea_len': 128, 'n_conv': 3, 'n_h': 1, 'cuda': True}
Calculated nbr_fea_len: 165
atom_fea_len: 64
orig_atom_fea_len not found in args, using value from embedding layer
Model arguments to be used:
{'atom_fea_len': 64, 'nbr_fea_len': 165, 'n_conv': 3, 'h_fea_len': 128, 'n_h': 1, 'classification': False, 'orig_atom_fea_len': 92}
Model loaded successfully with filtered weights!
Model structure:
CrystalGraphConvNet

In [59]:
#每层的权重和偏置统计：
for name, param in model.named_parameters():
    if param.requires_grad:
        print(f"{name}:")
        print(f"  Shape: {param.shape}")
        print(f"  Mean: {param.data.mean().item():.4f}")
        print(f"  Std: {param.data.std().item():.4f}")
        print(f"  Min: {param.data.min().item():.4f}")
        print(f"  Max: {param.data.max().item():.4f}")

embedding.weight:
  Shape: torch.Size([64, 92])
  Mean: 0.0008
  Std: 0.0605
  Min: -0.1301
  Max: 0.1380
embedding.bias:
  Shape: torch.Size([64])
  Mean: -0.0088
  Std: 0.0610
  Min: -0.1051
  Max: 0.1058
convs.0.bn2.weight:
  Shape: torch.Size([64])
  Mean: 1.0001
  Std: 0.0082
  Min: 0.9857
  Max: 1.0309
convs.0.bn2.bias:
  Shape: torch.Size([64])
  Mean: -0.0014
  Std: 0.0035
  Min: -0.0091
  Max: 0.0112
convs.0.fc_full.weight:
  Shape: torch.Size([128, 169])
  Mean: 0.0000
  Std: 0.0444
  Min: -0.0769
  Max: 0.0769
convs.0.fc_full.bias:
  Shape: torch.Size([128])
  Mean: -0.0038
  Std: 0.0468
  Min: -0.0764
  Max: 0.0768
convs.0.bn1.weight:
  Shape: torch.Size([128])
  Mean: 1.0000
  Std: 0.0000
  Min: 1.0000
  Max: 1.0000
convs.0.bn1.bias:
  Shape: torch.Size([128])
  Mean: 0.0000
  Std: 0.0000
  Min: 0.0000
  Max: 0.0000
convs.1.bn2.weight:
  Shape: torch.Size([64])
  Mean: 1.0022
  Std: 0.0067
  Min: 0.9866
  Max: 1.0220
convs.1.bn2.bias:
  Shape: torch.Size([64])
  Mean: -0.0

  print(f"  Std: {param.data.std().item():.4f}")


In [62]:
#每个卷积层的输出统计：
def conv_hook(module, input, output):
    print(f"Conv layer output stats:")
    print(f"  Shape: {output.shape}")
    print(f"  Mean: {output.mean().item():.4f}")
    print(f"  Std: {output.std().item():.4f}")
    print(f"  Min: {output.min().item():.4f}")
    print(f"  Max: {output.max().item():.4f}")

for conv in model.convs:
    conv.register_forward_hook(conv_hook)

In [4]:
from torchinfo import summary

summary(model, input_size=[(32, 92), (32, 12, 41), (32, 12), (32,)])

CrystalGraphConvNetWithHooks - Input shapes:
  atom_fea: torch.Size([32, 92])
  nbr_fea: torch.Size([32, 12, 41])
  nbr_fea_idx: torch.Size([32, 12])
  crystal_atom_idx: torch.Size([32])
  After embedding, atom_fea shape: torch.Size([32, 64])
  Before conv 1, atom_fea shape: torch.Size([32, 64])
ModifiedConvLayer - Input shapes:
  atom_in_fea: torch.Size([32, 64])
  nbr_fea: torch.Size([32, 12, 41])
  nbr_fea_idx: torch.Size([32, 12])


RuntimeError: Failed to run torchinfo. See above stack traces for more details. Executed layers up to: [Linear: 1]

In [70]:
import matplotlib.pyplot as plt
import seaborn as sns

def plot_weight_distribution(model):
    for name, param in model.named_parameters():
        if 'weight' in name:
            plt.figure(figsize=(10, 6))
            sns.histplot(param.data.cpu().numpy().flatten(), kde=True)
            plt.title(f'Weight Distribution of {name}')
            plt.xlabel('Weight Value')
            plt.ylabel('Frequency')
            plt.savefig(f'weight_dist_{name.replace(".", "_")}.png')
            plt.close()

plot_weight_distribution(model)
plt.show()

In [79]:
#特征相关性热图
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

def plot_feature_correlation(features):
    # 移除常量特征和包含 NaN 的特征
    features = features.T  # 转置以便每列是一个特征
    non_constant = np.any(features != features[0, :], axis=1)
    non_nan = ~np.isnan(features).any(axis=1)
    valid_features = non_constant & non_nan
    features = features[valid_features]
    
    if features.shape[0] == 0:
        print("No valid features to compute correlation.")
        return
    
    # 计算相关性，忽略 NaN 值
    corr = np.ma.corrcoef(features)
    if isinstance(corr, np.ma.MaskedArray):
        corr = corr.filled(np.nan)  # 将掩码值替换为 NaN
    
    plt.figure(figsize=(12, 10))
    sns.heatmap(corr, cmap='coolwarm', center=0, mask=np.isnan(corr))
    plt.title('Feature Correlation Heatmap')
    plt.savefig('feature_correlation.png')
    plt.close()
    
    print(f"Correlation matrix shape: {corr.shape}")
    print(f"Number of features plotted: {features.shape[0]}")

# 使用修改后的函数
plot_feature_correlation(atom_fea.detach().cpu().numpy())

Correlation matrix shape: (34, 34)
Number of features plotted: 34


In [77]:
#CGCNN的卷积层输出特征图
def plot_conv_outputs(conv_outputs, layer_name):
    # 获取输出的形状
    if len(conv_outputs.shape) == 2:
        num_atoms, num_features = conv_outputs.shape
    elif len(conv_outputs.shape) == 3:
        batch_size, num_atoms, num_features = conv_outputs.shape
        conv_outputs = conv_outputs[0]  # 只取第一个batch
    else:
        raise ValueError(f"Unexpected shape: {conv_outputs.shape}")
    
    # 创建一个网格来显示所有原子的特征
    num_cols = min(8, num_atoms)
    num_rows = (num_atoms - 1) // num_cols + 1
    
    fig, axes = plt.subplots(num_rows, num_cols, figsize=(20, 20))
    fig.suptitle(f'Feature maps of {layer_name}')
    
    for i, ax in enumerate(axes.flat):
        if i < num_atoms:
            # 为每个原子创建一个热力图
            im = ax.imshow(conv_outputs[i].detach().cpu().numpy().reshape(1, -1), 
                           aspect='auto', cmap='viridis')
            ax.set_title(f'Atom {i}')
        ax.axis('off')
    
    # 添加颜色条
    fig.colorbar(im, ax=axes.ravel().tolist(), shrink=0.5)
    
    plt.tight_layout()
    plt.savefig(f'conv_outputs_{layer_name}.png')
    plt.close()

# 在forward方法中收集卷积层输出
conv_outputs = []
def conv_hook(module, input, output):
    conv_outputs.append(output)

for conv in model.convs:
    conv.register_forward_hook(conv_hook)

# 运行模型
with torch.no_grad():
    model(atom_fea, nbr_fea, nbr_fea_idx, crystal_atom_idx)

# 绘制每个卷积层的输出
for i, output in enumerate(conv_outputs):
    plot_conv_outputs(output, f'conv_layer_{i}')

CrystalGraphConvNetWithHooks - Input shapes:
  atom_fea: torch.Size([32, 92])
  nbr_fea: torch.Size([32, 12, 41])
  nbr_fea_idx: torch.Size([32, 12])
  crystal_atom_idx: torch.Size([32])
  After embedding, atom_fea shape: torch.Size([32, 64])
  Before conv 1, atom_fea shape: torch.Size([32, 64])
ModifiedConvLayer - Input shapes:
  atom_in_fea: torch.Size([32, 64])
  nbr_fea: torch.Size([32, 12, 41])
  nbr_fea_idx: torch.Size([32, 12])
  total_nbr_fea shape: torch.Size([32, 12, 169])
  fc_full input shape: torch.Size([384, 169])
  fc_full weight shape: torch.Size([128, 169])
  After conv 1, atom_fea shape: torch.Size([32, 64])
  Before conv 2, atom_fea shape: torch.Size([32, 64])
ModifiedConvLayer - Input shapes:
  atom_in_fea: torch.Size([32, 64])
  nbr_fea: torch.Size([32, 12, 41])
  nbr_fea_idx: torch.Size([32, 12])
  total_nbr_fea shape: torch.Size([32, 12, 169])
  fc_full input shape: torch.Size([384, 169])
  fc_full weight shape: torch.Size([128, 169])
  After conv 2, atom_fea sha

  plt.tight_layout()
