In [None]:
import torch
import torch.nn.functional as F
import numpy as np
import cv2
import matplotlib.pyplot as plt
from model import DeepLabV3Plus  # Import your model

# 1. GradCAM Hook Class
class GradCAMHook:
    def __init__(self):
        self.activations = []
        self.gradients = []
    
    def forward_hook(self, module, input, output):
        self.activations.append(output.detach())
    
    def backward_hook(self, module, grad_input, grad_output):
        self.gradients.append(grad_output[0].detach())

In [5]:
import torch
from model.semseg.deeplabv3plus import DeepLabV3Plus  # Đảm bảo import đúng model của bạn
import torch.serialization
import numpy as np
# Load config và khởi tạo model
cfg = {
    'backbone': 'resnet101',
    'replace_stride_with_dilation': [False, True, True],
    'dilations': [12, 24, 36],
    'nclass': 21  # Số class của bạn
}
model = DeepLabV3Plus(cfg)

# Load checkpoint
checkpoint = torch.load('best.pth', weights_only=False)  # Chỉ dùng nếu file từ nguồn đáng tin
model.load_state_dict(checkpoint['model'])
model.eval()  # Chế độ evaluation

DeepLabV3Plus(
  (backbone): ResNet(
    (conv1): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
      (6): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    )
    (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_runnin

In [6]:
hook = GradCAMHook()
target_layer = model.fuse[0]  # Layer đầu tiên trong fuse

# Đăng ký hook
forward_handle = target_layer.register_forward_hook(hook.forward_hook)
backward_handle = target_layer.register_backward_hook(hook.backward_hook)

In [15]:
def compute_gradcam(activations, gradients):
    if not activations or not gradients:
        return None
    
    activation = activations[-1]  # Shape: [B, C, H, W]
    gradient = gradients[-1]     # Shape: [B, C, H, W]
    
    # Global average pooling of gradients
    weights = gradient.mean(dim=(2, 3), keepdim=True)  # [B, C, 1, 1]
    
    # Weighted combination of activations
    heatmap = (weights * activation).sum(dim=1)  # [B, H, W]
    heatmap = F.relu(heatmap).squeeze()  # [H, W] or [B, H, W]
    
    # Convert to numpy and normalize
    heatmap = heatmap.cpu().numpy()
    heatmap = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min() + 1e-8)
    
    return heatmap


In [21]:
# 4. Register Hook to Decoder's Layer1 (first conv in fuse)
hook = GradCAMHook()
target_layer = model.fuse[0]  # First conv in fuse
forward_handle = target_layer.register_forward_hook(hook.forward_hook)
backward_handle = target_layer.register_backward_hook(hook.backward_hook)

# 5. Process Image and Get Heatmaps
def process_image(image_path):
    # Load and preprocess image
    image = cv2.imread(image_path)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = cv2.resize(image, (321, 321))
    image_tensor = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).float()
    
    # Get heatmaps for both modes
    with torch.set_grad_enabled(True):
        # need_fp=False
        hook.activations.clear()
        hook.gradients.clear()
        output = model(image_tensor, need_fp=False)
        pred_class = output.mean(dim=(2, 3)).argmax().item()
        output[0, pred_class].sum().backward()  # Đã sửa
        heatmap_fp_false = compute_gradcam(hook.activations, hook.gradients)
        
        # need_fp=True
        hook.activations.clear()
        hook.gradients.clear()
        _, output_fp = model(image_tensor, need_fp=True)
        pred_class_fp = output_fp.mean(dim=(2, 3)).argmax().item()
        output_fp[0, pred_class_fp].sum().backward()  # Đã sửa
        heatmap_fp_true = compute_gradcam(hook.activations, hook.gradients)
    
    # Handle multi-channel heatmap (if exists)
    if isinstance(heatmap_fp_true, np.ndarray) and heatmap_fp_true.ndim == 2:
        pass  # Already correct shape
    elif heatmap_fp_true.ndim == 3:  # If [2, H, W]
        heatmap_fp_true = heatmap_fp_true[0]  # Take first channel
    
    return image, heatmap_fp_false, heatmap_fp_true

In [22]:
def visualize_results(image, heatmap_fp_false, heatmap_fp_true):
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))
    
    # Original Image
    axes[0].imshow(image)
    axes[0].set_title('Original Image')
    axes[0].axis('off')
    
    # need_fp=False Heatmap
    axes[1].imshow(image, alpha=0.7)
    axes[1].imshow(heatmap_fp_false, cmap='jet', alpha=0.3)
    axes[1].set_title('GradCAM (need_fp=False)')
    axes[1].axis('off')
    
    # need_fp=True Heatmap
    axes[2].imshow(image, alpha=0.7)
    axes[2].imshow(heatmap_fp_true, cmap='jet', alpha=0.3)
    axes[2].set_title('GradCAM (need_fp=True)')
    axes[2].axis('off')
    
    plt.tight_layout()
    plt.show()

In [23]:
if __name__ == "__main__":
    image_path = "pascal_data/JPEGImages/2007_000027.jpg"  # Change to your image path
    image, heatmap_fp_false, heatmap_fp_true = process_image(image_path)
    visualize_results(image, heatmap_fp_false, heatmap_fp_true)
    
    # Remove hooks when done
    forward_handle.remove()
    backward_handle.remove()

: 