In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
from PIL import Image
from PIL import Image as PILImage
import os

import torchvision.models
from torchvision.models.vgg import VGG16_Weights

vgg = torchvision.models.vgg16(weights=VGG16_Weights.DEFAULT)
vgg.eval()  # 设置为评估模式

# # 加载预训练的 VGG 模型
# vgg = torchvision.models.vgg16(pretrained=True)
# vgg.eval()  # 设置为评估模式

# 初始化一个字典来存储每层的输出
vgg_hooks = {}

# 注册钩子函数到每一层
def register_hooks(model, hooks_dict):
    def hook(module, input, output):
        # 使用模块的名字作为键
        module_name = str(module)
        if module_name not in hooks_dict:
            hooks_dict[module_name] = []
        hooks_dict[module_name].append(output)
    
    # 注册钩子到所有的卷积层
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Conv2d):
            module.register_forward_hook(hook)

# 注册钩子
register_hooks(vgg.features, vgg_hooks)


# 图片预处理
img_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))  # 使用ImageNet的均值和标准差
])

# 加载图片
img = Image.open("lena.png").convert("RGB")
img_tensor = img_transform(img).unsqueeze(0)  # 添加批次维度

# 获取模型的预测结果
scores = vgg(img_tensor)
proba = torch.softmax(scores, dim=1)
top5 = torch.topk(proba, k=5, dim=1)

# 打印预测结果
print("Image Name:", "lena.png")
print("Top 5 Predictions:", top5)

# 可视化中间层的输出
output_dir = "./vgg_feature_maps"  # 定义输出目录
os.makedirs(output_dir, exist_ok=True)  # 创建输出目录

# 收集所有channel的feature maps
all_features = []
print (f" len vgg_hooks: {len(vgg_hooks)}")
for module_name, features_list in vgg_hooks.items():
    print ("module_name: ", module_name)
    # print ("features_list: ", features_list)
    # break
    for features in features_list:
        n, c, h, w = features.shape
        print (f"n, c, h, w: {n, c, h, w}")
        channel_images = []
        
        for i in range(c):  # 遍历每个channel
            feature_map = features[0, i].cpu().detach()  # 获取单个channel的feature map
            feature_map = (feature_map - feature_map.min()) / (feature_map.max() - feature_map.min())  # 归一化
            feature_map = (feature_map * 255).byte()  # 转换为0-255的灰度值
            feature_map_pil = PILImage.fromarray(feature_map.numpy()).convert('L')
            
            # 保存单个channel的图片
            feature_map_pil.save(os.path.join(output_dir, f"layer_{module_name.replace('.', '_')}_channel_{i}.png"))
            
            # 将图片添加到列表中用于后续拼接
            channel_images.append(feature_map_pil)
        
        # 拼接所有channel的图片
        combined_image = PILImage.new('RGB', (w * c, h))  # 水平拼接
        for i, img in enumerate(channel_images):
            combined_image.paste(img.convert('RGB'), (w * i, 0))
        
        # 保存拼接后的图片
        all_features.append(combined_image)
        combined_image.save(os.path.join(output_dir, f"layer_{module_name.replace('.', '_')}_combined.png"))

# 如果有多个层，可以进一步将它们拼接起来
if len(all_features) > 1:
    final_combined = PILImage.new('RGB', (all_features[0].width, sum([img.height for img in all_features])))  # 垂直拼接
    y_offset = 0
    for img in all_features:
        final_combined.paste(img, (0, y_offset))
        y_offset += img.height
    final_combined.save(os.path.join(output_dir, "final_combined.png"))

Image Name: lena.png
Top 5 Predictions: torch.return_types.topk(
values=tensor([[0.1427, 0.0897, 0.0712, 0.0469, 0.0379]], grad_fn=<TopkBackward0>),
indices=tensor([[552, 452, 808, 515, 903]]))
 len vgg_hooks: 8
module_name:  Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
n, c, h, w: (1, 64, 224, 224)
module_name:  Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
n, c, h, w: (1, 64, 224, 224)
module_name:  Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
n, c, h, w: (1, 128, 112, 112)
module_name:  Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
n, c, h, w: (1, 128, 112, 112)
module_name:  Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
n, c, h, w: (1, 256, 56, 56)
module_name:  Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
n, c, h, w: (1, 256, 56, 56)
n, c, h, w: (1, 256, 56, 56)
module_name:  Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
n, c, h, w: