In [None]:
from gradcam import GradCAM, GradCAMpp
from gradcam.utils import visualize_cam
from torchvision.utils import make_grid, save_image
from PIL import Image, ImageDraw, ImageFont
import glob
import os
import torch.nn.functional as F
import torch

# example
save_path = 'best_model.pth'

# Build a ResNet-152 whose final layer is replaced by a 2-output linear head,
# then restore the fine-tuned weights from save_path.
# ImageNet weights are deliberately not requested: load_state_dict() is strict
# by default and overwrites every parameter and buffer, so fetching pretrained
# weights first would only cost a large download.
best_loss_model = models.resnet152(pretrained=False)
num_ftrs = best_loss_model.fc.in_features
best_loss_model.fc = nn.Linear(num_ftrs, 2)
best_loss_model.to(device)
best_loss_model.eval()
# map_location remaps the stored tensors onto `device`, so a checkpoint that
# was saved on a GPU still loads on a CPU-only machine.
best_loss_model.load_state_dict(torch.load(save_path, map_location=device))

# Grad-CAM and Grad-CAM++ are both attached to the last element of layer4.
target_layer = best_loss_model.layer4[-1]
gradcam = GradCAM(best_loss_model, target_layer)
gradcam_pp = GradCAMpp(best_loss_model, target_layer)

# example: glob pattern selecting the input images
image_path = 'test_folder/*'

# example: output directory, created up front if it does not exist yet
save_dir = 'test_gradcam'
os.makedirs(save_dir, exist_ok=True)

# Value stored under 'actual_label' in every classification record.
actual_label = 0

# Three independent accumulators:
#   images          - every tile, used for the final overview grid
#   save_images     - tiles waiting to be written out as the next grid file
#   classifications - one prediction record per processed image
images, save_images, classifications = [], [], []

# Loop-invariant objects are built once here rather than once per image.
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
])
normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
font_size = 15
font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", font_size)

# Sorting makes the recorded 'index' reproducible (glob order is arbitrary).
# Directories matched by the pattern are skipped because they cannot be
# opened as images.
image_files = sorted(p for p in glob.glob(image_path) if os.path.isfile(p))

# For every input image: classify it, compute the Grad-CAM++ visualisation,
# render a caption tile, record the prediction, and write the tiles to
# save_dir as grid files of 16 tiles each.
for i, path in enumerate(image_files):
    # The context manager closes the file handle once the pixels have been
    # converted to a tensor.
    with Image.open(path) as img:
        torch_img = preprocess(img).to(device)

    # [None] adds the leading batch dimension.
    normed_torch_img = normalize(torch_img)[None]

    # Plain inference needs no gradients; the CAM calls below run outside
    # this context.
    with torch.no_grad():
        output = best_loss_model(normed_torch_img)
        pred_probabilities = F.softmax(output, dim=1)
        pred_label = torch.argmax(pred_probabilities, dim=1).item()
        confidence = pred_probabilities[0, pred_label].item()

    # NOTE(review): the plain Grad-CAM outputs are not used anywhere in the
    # visible code; kept in case a later cell reads them - confirm and drop
    # to save one extra CAM pass per image.
    mask, _ = gradcam(normed_torch_img)
    heatmap, result = visualize_cam(mask, torch_img)

    mask_pp, _ = gradcam_pp(normed_torch_img)
    heatmap_pp, result_pp = visualize_cam(mask_pp, torch_img)

    # Caption tile: gray 224x224 canvas carrying the prediction text.
    black_image = Image.new("RGB", (224, 224), "gray")
    draw = ImageDraw.Draw(black_image)
    classification_text = f"Pred: {pred_label}, Conf: {confidence:.2f}"
    draw.text((10, 100), classification_text, font=font, fill=(255, 255, 255))
    black_image = transforms.ToTensor()(black_image)

    # Four tiles per image (input, the two visualize_cam outputs, caption),
    # i.e. exactly one row of a 4-column grid.
    tiles = [torch_img.cpu(), heatmap_pp, result_pp, black_image]
    images.extend(tiles)
    save_images.extend(tiles)

    classifications.append({
        'index': i,
        'pred_label': pred_label,
        'actual_label': actual_label,
        'confidence': confidence
    })

    # Write a grid file every 16 tiles (4 images), and also after the last
    # image so a final partial batch is not silently dropped.
    is_last_image = i == len(image_files) - 1
    if len(save_images) == 16 or is_last_image:
        grid_image = make_grid(save_images, nrow=4)
        # Save inside save_dir, which the script creates for this purpose.
        output_path = os.path.join(save_dir, f"gradcam_images_({i}).png")
        save_image(grid_image, output_path)

        print(f"グリッド画像を {output_path} に保存しました。")
        save_images = []

# Show every collected tile in a single 4-column overview grid.
# make_grid cannot build a grid from an empty list, so report the situation
# instead of crashing when the glob pattern matched no files.
if images:
    grid_image = make_grid(images, nrow=4)
    transforms.ToPILImage()(grid_image).show()
else:
    print(f"No images found for pattern: {image_path}")
