In [None]:
import warnings
warnings.filterwarnings('ignore')
warnings.simplefilter('ignore')
from model.semseg.deeplabv3plus import DeepLabV3Plus
import torch
import torch.functional as F
import numpy as np
from pytorch_grad_cam import GradCAM
import matplotlib.pyplot as plt
from explainer.semantic_segmentation_target import SemanticSegmentationTarget, GroundTruthSegmentationTarget
from explainer.segmentation_output_wrapper import SegmentationModelOutputWrapper

In [4]:
# Number of segmentation classes (21, as for Pascal VOC).
num_classes = 21

In [5]:
def compute_gradcam_heatmap(model, target_layer, input_tensor, masks):
    """Run Grad-CAM on ``target_layer`` for every sample of a batch.

    One ``SemanticSegmentationTarget`` is built per sample: sample ``idx`` of
    ``input_tensor`` is paired with ``masks[idx]``, and ``category`` is passed
    as ``None``.

    Args:
        model: network handed to ``GradCAM``.
        target_layer: the single layer whose activations/gradients Grad-CAM uses.
        input_tensor: batched input; ``input_tensor.size(0)`` is taken as the
            batch size.
        masks: collection indexed with the integers ``0 .. size(0) - 1``, one
            mask per sample.

    Returns:
        The result of calling the ``GradCAM`` object on the batch (presumably
        one grayscale heatmap per sample — confirm for the installed
        pytorch_grad_cam version).
    """
    # requires_grad is deliberately not set on input_tensor here; the original
    # note says it is already handled in the model's forward — TODO confirm.
    gradcam = GradCAM(model=model, target_layers=[target_layer])

    batch_size = input_tensor.size(0)
    per_sample_targets = [
        SemanticSegmentationTarget(category=None, mask=masks[idx])
        for idx in range(batch_size)
    ]

    return gradcam(input_tensor=input_tensor, targets=per_sample_targets)

In [6]:
import segmentation_models_pytorch as smp
# DeepLabV3+ from segmentation_models_pytorch. The class count reuses the
# `num_classes` config constant instead of repeating the literal 21, so the two
# cannot drift apart.
model = smp.DeepLabV3Plus(
    encoder_name="resnet101",        # backbone choice, e.g. resnet101
    encoder_weights="imagenet",      # ImageNet-pretrained encoder weights (downloaded on first use)
    in_channels=3,                   # number of input channels (usually 3 for RGB images)
    classes=num_classes,             # number of predicted classes, e.g. 21 for Pascal VOC
)

Downloading: "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth" to /home/loc11/.cache/torch/hub/checkpoints/resnet101-5d3b4d8f.pth
100%|██████████| 170M/170M [00:04<00:00, 40.6MB/s] 


In [7]:
# Display the full module tree (long output); presumably used to choose the
# `target_layer` passed to compute_gradcam_heatmap — confirm.
model

DeepLabV3Plus(
  (encoder): ResNetEncoder(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequentia