In [4]:
!pip install grad-cam

Collecting grad-cam
  Downloading grad-cam-1.5.2.tar.gz (7.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.8/7.8 MB[0m [31m47.3 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25h  Installing build dependencies ... [?25ldone
[?25h  Getting requirements to build wheel ... [?25ldone
[?25h  Preparing metadata (pyproject.toml) ... [?25ldone
Collecting ttach (from grad-cam)
  Downloading ttach-0.0.3-py3-none-any.whl.metadata (5.2 kB)
Downloading ttach-0.0.3-py3-none-any.whl (9.8 kB)
Building wheels for collected packages: grad-cam
  Building wheel for grad-cam (pyproject.toml) ... [?25ldone
[?25h  Created wheel for grad-cam: filename=grad_cam-1.5.2-py3-none-any.whl size=38335 sha256=76ed2208e543e3895c7bb84329ce2a6f42cae98ee2526ddbcc5ace6550f968c1
  Stored in directory: /root/.cache/pip/wheels/28/25/dd/cf5dc1751e3d5b89ea4d877a61ba969939c78cf4223ace9c59
Successfully built grad-cam
Installing collected packages: ttach, grad-cam
Successfully installed grad-cam-1.5.2 ttach-0.0.3

In [11]:
import torch
import os
import torch.nn as nn
import torch.nn.functional as F
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
from scipy import ndimage

class CNN3D(nn.Module):
    """3-D convolutional network with a single sigmoid output.

    Four stages of Conv3d -> MaxPool3d -> BatchNorm3d -> ReLU are followed by
    global average pooling, a 512-unit dense layer with dropout and a final
    one-unit dense layer. The input must have exactly one channel; the output
    has shape (N, 1) with values in [0, 1].
    """

    def __init__(self):
        super().__init__()

        # Stage 1: 1 -> 64 channels
        self.conv1 = nn.Conv3d(in_channels=1, out_channels=64, kernel_size=3, padding=1)
        self.pool1 = nn.MaxPool3d(kernel_size=2)
        self.bn1 = nn.BatchNorm3d(num_features=64)

        # Stage 2: 64 -> 64 channels
        self.conv2 = nn.Conv3d(in_channels=64, out_channels=64, kernel_size=3, padding=1)
        self.pool2 = nn.MaxPool3d(kernel_size=2)
        self.bn2 = nn.BatchNorm3d(num_features=64)

        # Stage 3: 64 -> 128 channels
        self.conv3 = nn.Conv3d(in_channels=64, out_channels=128, kernel_size=3, padding=1)
        self.pool3 = nn.MaxPool3d(kernel_size=2)
        self.bn3 = nn.BatchNorm3d(num_features=128)

        # Stage 4: 128 -> 256 channels
        self.conv4 = nn.Conv3d(in_channels=128, out_channels=256, kernel_size=3, padding=1)
        self.pool4 = nn.MaxPool3d(kernel_size=2)
        self.bn4 = nn.BatchNorm3d(num_features=256)

        # Classification head
        self.global_pool = nn.AdaptiveAvgPool3d((1, 1, 1))
        self.fc1 = nn.Linear(in_features=256, out_features=512)
        self.dropout = nn.Dropout(p=0.3)
        self.fc2 = nn.Linear(in_features=512, out_features=1)

    def forward(self, x):
        """Run the network on a 5-D batch and return the sigmoid output, shape (N, 1)."""
        stages = (
            (self.conv1, self.pool1, self.bn1),
            (self.conv2, self.pool2, self.bn2),
            (self.conv3, self.pool3, self.bn3),
            (self.conv4, self.pool4, self.bn4),
        )
        for conv, pool, bn in stages:
            x = F.relu(bn(pool(conv(x))))

        # Collapse the spatial axes, leaving one 256-vector per sample.
        features = self.global_pool(x).view(-1, 256)
        hidden = self.dropout(F.relu(self.fc1(features)))
        return torch.sigmoid(self.fc2(hidden))

# Pick the compute device first so the checkpoint can be mapped onto it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Create the model and restore the trained weights.
# map_location keeps the checkpoint loadable on a CPU-only machine even if
# its tensors were saved from a GPU.
model = CNN3D()
model.load_state_dict(torch.load('3d_image_classification.pth', map_location=device))
model = model.to(device)
model.eval()

# Define the target layer for Grad-CAM (the last convolution of the network)
target_layers = [model.conv4]

# Create the Grad-CAM object
cam = GradCAM(model=model, target_layers=target_layers)

# Paths of the input NIfTI scan and of the output file
output_dir = os.path.join(os.getcwd(), "MosMedData")
input_nifti_path = os.path.join(output_dir, "CT-23", "study_0939.nii.gz")
output_nifti_path = os.path.join(output_dir, "output", "study_0939.nii.gz")

def read_nifti_file(filepath):
    """Load the file at ``filepath`` with nibabel and return its voxel data array."""
    return nib.load(filepath).get_fdata()

def normalize(volume, lower=-1000, upper=400):
    """Clip intensities to ``[lower, upper]`` and rescale them to ``[0, 1]``.

    Unlike masked in-place assignment, the input array is left untouched, so
    the function can be re-run on the same raw volume.

    Args:
        volume: Array of voxel intensities.
        lower: Intensity mapped to 0; smaller values are clipped to it.
        upper: Intensity mapped to 1; larger values are clipped to it.

    Returns:
        A new float32 array with the same shape as ``volume``.

    Raises:
        ValueError: If ``upper`` is not greater than ``lower``.
    """
    if upper <= lower:
        raise ValueError("upper must be greater than lower")
    # np.clip returns a new array, so the caller's data is not modified.
    clipped = np.clip(volume, lower, upper)
    scaled = (clipped - lower) / (upper - lower)
    return scaled.astype("float32")

def resize_volume(img, desired_depth=64, desired_width=128, desired_height=128):
    """Rotate a 3-D volume by 90 degrees and resample it to a fixed size.

    Axis 0 is treated as width, axis 1 as height and the last axis as depth.
    The defaults reproduce the previously hard-coded 128 x 128 x 64 target.

    Args:
        img: 3-D array to resize.
        desired_depth: Target size of the last axis.
        desired_width: Target size of axis 0.
        desired_height: Target size of axis 1.

    Returns:
        The rotated volume, resampled with linear interpolation (``order=1``).

    Raises:
        ValueError: If ``img`` is not 3-D or has an empty axis.
    """
    if img.ndim != 3 or 0 in img.shape:
        raise ValueError(f"expected a non-empty 3-D volume, got shape {img.shape}")

    current_width, current_height, current_depth = img.shape
    # Zoom factor per axis = target size / current size.
    zoom_factors = (
        desired_width / current_width,
        desired_height / current_height,
        desired_depth / current_depth,
    )
    # Rotation happens in the plane of the first two axes and keeps the shape.
    rotated = ndimage.rotate(img, 90, reshape=False)
    return ndimage.zoom(rotated, zoom_factors, order=1)

def process_scan(path):
    """Read the scan at ``path``, normalize its intensities and resize the volume."""
    raw_volume = read_nifti_file(path)
    return resize_volume(normalize(raw_volume))


# Preprocess the scan and add batch and channel axes in front of the volume.
data = process_scan(input_nifti_path)
input_tensor = torch.tensor(data).unsqueeze(0).unsqueeze(0).to(device, dtype=torch.float32)
print(input_tensor.shape)

# Compute the Grad-CAM for the model's single output unit.
# NOTE(review): the error recorded under this cell (cv2.resize: 'dsize' expected
# sequence length 2, got 3) looks like it is raised inside this call while the
# library rescales a volumetric CAM -- presumably a grad-cam release with 3-D
# input support is required; confirm against the installed version.
targets = [ClassifierOutputTarget(0)]
grayscale_cam = cam(input_tensor=input_tensor, targets=targets)

# Assumes the CAM comes back as (batch, *volume shape) -- checked right below.
cam_volume = np.asarray(grayscale_cam)[0]
if cam_volume.shape != data.shape:
    raise ValueError(
        f"Grad-CAM volume has shape {cam_volume.shape}, expected {data.shape}"
    )

# show_cam_on_image overlays one 2-D heat map on one H x W x 3 float image, so a
# single slice is visualised: the one (along the last axis) with the largest
# total activation.
slice_index = int(np.argmax(cam_volume.sum(axis=(0, 1))))

# Resampling in resize_volume may push values slightly outside [0, 1];
# presumably show_cam_on_image rejects images above 1, so clip first.
ct_slice = np.clip(data[:, :, slice_index], 0.0, 1.0)
rgb_slice = np.repeat(ct_slice[:, :, np.newaxis], 3, axis=2).astype(np.float32)
visualization = show_cam_on_image(rgb_slice, cam_volume[:, :, slice_index], use_rgb=True)

# Display the overlay
fig, ax = plt.subplots()
ax.imshow(visualization)
ax.set_title(f"Grad-CAM overlay (slice {slice_index})")
ax.axis("off")
plt.show()

torch.Size([1, 1, 128, 128, 64])


error: OpenCV(4.8.0) :-1: error: (-5:Bad argument) in function 'resize'
> Overload resolution failed:
>  - Can't parse 'dsize'. Expected sequence length 2, got 3
>  - Can't parse 'dsize'. Expected sequence length 2, got 3
