# In [None]:
import torch
import torch.nn as nn
from torchvision import models, transforms
import cv2
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import os
from pathlib import Path


# Model architecture
class CrescentModel(nn.Module):
    """ResNet-18 classifier whose final fully connected layer is a small MLP head.

    The torchvision ResNet-18 is created without pretrained weights
    (``pretrained=False``) and its ``fc`` layer is replaced by
    Linear(fc.in_features -> 128) -> ReLU -> Dropout(0.3) -> Linear(128 -> num_classes).

    Args:
        num_classes: Number of output logits produced by the head.
    """

    def __init__(self, num_classes=2):
        super().__init__()
        resnet = models.resnet18(pretrained=False)
        head_in = resnet.fc.in_features
        resnet.fc = nn.Sequential(
            nn.Linear(head_in, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, num_classes),
        )
        # Registered under the name "backbone", so state_dict keys are "backbone.*".
        self.backbone = resnet

    def forward(self, x):
        """Run the backbone (including the replacement head); no softmax is applied."""
        logits = self.backbone(x)
        return logits

# Grad-CAM
class GradCAM:
    """Grad-CAM (gradient-weighted class activation map) for one layer of a model.

    A forward hook on ``target_layer`` stores that layer's output (the
    activations) and attaches a tensor hook to it, so the gradient of the chosen
    class score with respect to those activations is captured on ``backward()``.

    Call :meth:`remove` when finished to detach the hook from ``target_layer``.
    """

    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None  # d(score)/d(activations) from the last generate() call
        self.activations = None  # target_layer output from the last forward pass
        self._handles = []  # hook handles, kept so remove() can detach them
        self.hook_layers()

    def hook_layers(self):
        """Register the forward hook that captures activations and their gradients."""

        def forward_hook(module, input, output):
            self.activations = output.detach()
            # Capture the gradient with a tensor hook on the layer output instead of
            # the deprecated Module.register_backward_hook, which may report the
            # gradient of the wrong tensor for modules made of several autograd ops.
            # Without autograd (e.g. under torch.no_grad()) there is nothing to hook.
            if output.requires_grad:
                output.register_hook(self._save_gradient)

        self._handles.append(self.target_layer.register_forward_hook(forward_hook))

    def _save_gradient(self, grad):
        """Tensor hook: store the gradient flowing into the target layer's output."""
        self.gradients = grad.detach()

    def remove(self):
        """Detach every hook this instance registered on ``target_layer``."""
        for handle in self._handles:
            handle.remove()
        self._handles.clear()

    def generate(self, input_image, target_class=None):
        """Compute a Grad-CAM heatmap for the first image in ``input_image``.

        Puts the model into eval mode and runs one forward/backward pass.

        Args:
            input_image: Image batch tensor; assumed (N, C, H, W). Only sample 0
                is explained.
            target_class: Class index whose score is explained. Defaults to the
                highest-scoring class of sample 0.

        Returns:
            A float numpy array of shape (H, W), min-max scaled to [0, 1].

        Raises:
            RuntimeError: if the hooks did not capture activations or gradients
                (e.g. after ``remove()`` or when autograd is disabled).
        """
        self.model.eval()
        # Clear state from a previous call so a failed capture is detected
        # instead of silently reusing stale maps.
        self.activations = None
        self.gradients = None
        output = self.model(input_image)

        if target_class is None:
            # Pick from sample 0 only, consistent with the score used below.
            target_class = output[0].argmax().item()

        self.model.zero_grad()
        score = output[0, target_class]
        score.backward()

        if self.activations is None or self.gradients is None:
            raise RuntimeError(
                "Grad-CAM hooks did not capture activations/gradients; "
                "were the hooks removed or is autograd disabled?"
            )

        # Channel weights = spatially averaged gradients of sample 0: (C, 1, 1).
        weights = self.gradients[0].mean(dim=(1, 2), keepdim=True)
        cam = torch.relu((weights * self.activations[0]).sum(dim=0))  # (h, w)
        # Bilinear upsampling to the input resolution (half-pixel centres,
        # the same sampling convention as cv2.resize's default INTER_LINEAR).
        cam = nn.functional.interpolate(
            cam[None, None],
            size=tuple(input_image.shape[-2:]),
            mode="bilinear",
            align_corners=False,
        )[0, 0]
        cam = cam.cpu().numpy()
        # Min-max scale; the epsilon avoids dividing by zero for a flat map.
        cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
        return cam

# Data preprocessing: Resize(256) scales the shorter image side to 256, then a
# 224x224 center crop is taken, converted to a float tensor and normalized per
# channel. The mean/std are the common ImageNet values — presumably the same as
# used at training time (confirm).
data_transform = transforms.Compose(
    [
        transforms.Resize(size=256),
        transforms.CenterCrop(size=224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Load the model and its trained weights.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model = CrescentModel(num_classes=2)
model.to(device)  # nn.Module.to moves parameters in place
# NOTE(review): torch.load unpickles the checkpoint; only load files you trust.
model.load_state_dict(torch.load("C.pth", map_location=device))  # path to your checkpoint

# Grad-CAM is taken at conv2 of block 1 in the backbone's layer4.
target_layer = model.backbone.layer4[1].conv2
grad_cam = GradCAM(model, target_layer)

# Input folder: "Crescent" under the current working directory.
image_folder = os.path.join(os.getcwd(), "Crescent")
# Output folder (relative to the working directory) for overlays and figures.
output_folder = "crescentheatmap"

# File extensions (lower-case, leading dot) treated as images.
supported_formats = {
    '.bmp',
    '.jpeg',
    '.jpg',
    '.png',
    '.tif',
    '.tiff',
}

# Create the output folder; no error if it already exists.
os.makedirs(output_folder, exist_ok=True)

# Process every supported image in the folder.


def _make_overlay(img_bgr, cam):
    """Blend a JET-coloured Grad-CAM map 50/50 onto a same-sized BGR uint8 image."""
    heatmap = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET)
    # Boost the blue channel (index 0 in BGR) by 80. The addition is done in a
    # wider integer type: uint8 + 80 would wrap modulo 256 before np.clip runs.
    boosted = heatmap[:, :, 0].astype(np.int16) + 80
    heatmap[:, :, 0] = np.clip(boosted, 0, 255).astype(np.uint8)
    return cv2.addWeighted(img_bgr, 0.5, heatmap, 0.5, 0)


def _save_result_figure(img_bgr, cam, overlay_bgr, result_path):
    """Save a 1x3 panel (original, raw Grad-CAM, overlay) to result_path at 300 dpi."""
    fig, axs = plt.subplots(1, 3, figsize=(12, 4))
    try:
        axs[0].imshow(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB))
        axs[0].set_title("Original")
        axs[0].axis("off")

        axs[1].imshow(cam, cmap='jet')
        axs[1].set_title("Grad-CAM")
        axs[1].axis("off")

        axs[2].imshow(cv2.cvtColor(overlay_bgr, cv2.COLOR_BGR2RGB))
        axs[2].set_title("Overlay")
        axs[2].axis("off")

        fig.tight_layout()
        fig.savefig(result_path, dpi=300)
    finally:
        # Release the figure even if drawing or saving failed.
        plt.close(fig)


# Geometric part of data_transform (everything except ToTensor/Normalize), so
# the overlay background is exactly the crop the model saw. A plain
# img.resize((224, 224)) would squash the whole image and misalign the heatmap.
display_transform = transforms.Compose(
    [
        t
        for t in data_transform.transforms
        if not isinstance(t, (transforms.ToTensor, transforms.Normalize))
    ]
)

n_ok = 0
failed = []
# Sorted for a reproducible processing order.
for image_file in sorted(os.listdir(image_folder)):
    file_path = os.path.join(image_folder, image_file)
    file_ext = Path(image_file).suffix.lower()

    # Skip unsupported extensions and anything that is not a regular file.
    if file_ext not in supported_formats or not os.path.isfile(file_path):
        continue

    try:
        print(f"Processing: {image_file}")

        # Read the image; the context manager closes the file handle.
        with Image.open(file_path) as img_file:
            img_pil = img_file.convert('RGB')
        input_tensor = data_transform(img_pil).unsqueeze(0).to(device)

        # Generate Grad-CAM (sized to the model input).
        cam = grad_cam.generate(input_tensor)

        # Build the visualisation on the same crop the model received.
        img_cv = cv2.cvtColor(np.array(display_transform(img_pil)), cv2.COLOR_RGB2BGR)
        overlay = _make_overlay(img_cv, cam)

        # Output file name without extension.
        file_name = Path(image_file).stem

        # Save the overlay image; cv2.imwrite signals failure by returning False.
        overlay_path = os.path.join(output_folder, f"{file_name}_overlay.jpg")
        if not cv2.imwrite(overlay_path, overlay):
            raise OSError(f"cv2.imwrite failed for {overlay_path}")

        # Save the three-panel figure.
        result_path = os.path.join(output_folder, f"{file_name}_result.jpg")
        _save_result_figure(img_cv, cam, overlay, result_path)

        print(f"Saved results for {image_file}")
        n_ok += 1

    except Exception as e:
        # Best-effort batch processing: report the failure and continue.
        print(f"Error processing {image_file}: {e}")
        failed.append(image_file)

if failed:
    print(f"Finished: {n_ok} image(s) processed, {len(failed)} failed: {', '.join(failed)}")
else:
    print("All images processed successfully!")