In [65]:
import os
import numpy as np
from matplotlib import pyplot as plt
from PIL import Image
import torch
from torch import nn
from torchvision import models, transforms
import torchvision.transforms.functional as TF
import pydicom as dicom

### This code sets up a ResNet50 model, loads pre-trained weights, splits the model into feature extractor and classifier, and then moves these components to CPU and GPU, respectively.

In [2]:
T = 1

# Initialize network

net = models.resnet50(pretrained=False)
net.fc = nn.Linear(2048, T * 2)  # T * 2 for mean-variance

# Load snapshot
#define the model path 
PATH = '/snapshots/snapshot_10000.pth.tar'
snapshot = torch.load(PATH, map_location={"cuda": "cpu"})
net.load_state_dict(snapshot['state_dict'])

# Define a custom Flatten layer
class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.size(0), -1)

# Split model into two parts: features and classifier
arch = net.__class__.__name__
print(arch)

if arch == 'ResNet':
    features_fn = nn.Sequential(*list(net.children())[:-2]) # -2 is Stage 4/layer 4 of ResNet-50 (for different stages/layers use -5, -4,-3,-2 stage 1,2,3,4)
    classifier_fn = nn.Sequential(*(list(net.children())[-2:-1] + [Flatten()] + list(net.children())[-1:])) # -2 define stage 4

# Move the feature extractor to CPU
features_fn = features_fn.eval().cpu()

# Move the classifier to GPU
classifier_fn = classifier_fn.eval().cuda()

# Move the whole network to GPU
net = net.eval().cuda()

In [289]:
net

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [290]:
features_fn

Sequential(
  (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU(inplace=True)
  (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (4): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)


In [291]:
classifier_fn

Sequential(
  (0): AdaptiveAvgPool2d(output_size=(1, 1))
  (1): Flatten()
  (2): Linear(in_features=2048, out_features=2, bias=True)
)

## Grad-CAM

In [281]:
def GradCAM(img, c, features_fn, classifier_fn):
    # Move image to GPU
    img = img.cuda()

    # Ensure that the image tensor requires gradient computation
    img.requires_grad_(True)

    # Extract features from the specified layer
    feats = features_fn(img)
    
    # Ensure that the features tensor also requires gradient computation
    feats.requires_grad_(True)

    # Get the spatial dimensions of the features
    _, N, H, W = feats.size()

    # Pass the features through the classifier
    out = classifier_fn(feats)
    
    # Get the score for the target class c
    c_score = out[0, c]
    
    # Compute gradients of the class score with respect to features
    grads = torch.autograd.grad(c_score, feats, create_graph=True)[0]
    
    # Global average pooling to obtain weights
    w = grads.mean((2, 3))

    # Perform a weighted sum of the feature maps
    sal = torch.matmul(w, feats.view(N, H * W))
    
    # Reshape and process the saliency map
    sal = sal.view(H, W).cpu().detach().numpy()
    sal = np.maximum(sal, 0)
    
    return sal

## The code uses GradCAM to generate heatmaps highlighting important regions, and it saves the original image alongside the heatmap for visual interpretation. It assumes the existence of a pre-trained neural network (net), which is used to obtain the class probabilities and apply GradCAM(output as Jpg).

In [285]:

#Define Input and Output Directories:
input_dir = "/Input_path/"

output_dir_path ="/Output_path/"

output_dir = f"{output_dir_path}/jpg_{output_dir_path.split('/')[-1]}"

# Ensure the output directory exists.
os.makedirs(output_dir, exist_ok=True)


#Image Preprocessing
img_dim = 512

default_transformation = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize((img_dim, img_dim))
])

# Function to apply the transformations and get the tensor image.
def preprocess_image(image_path):
    img = dicom.dcmread(image_path)
    image_pix = img.pixel_array
    image = Image.fromarray(image_pix)
    image = default_transformation(image)
    gray_to_rgb = np.stack([image] * 3, axis=2)
    tensor_img = TF.to_tensor(gray_to_rgb)
    tensor_unsqueeze_img = torch.unsqueeze(tensor_img, 0)
    return tensor_unsqueeze_img.cuda()

# List all image files in the input directory.
image_files = [f for f in os.listdir(input_dir) if f.endswith('.dcm')]

for image_file in image_files:
    image_path = os.path.join(input_dir, image_file)

    # Use PIL to open the images and masks.
    img = dicom.dcmread(image_path)
    image_pix=img.pixel_array

    image=Image.fromarray(image_pix)

    img_tensor = preprocess_image(image_path)
    print(torch.min(img_tensor),torch.max(img_tensor))

    pp, cc = torch.topk(nn.Softmax(dim=1)(net(img_tensor)), 2)

    # Process the GradCAM and resize the image.
    class_index = int(cc[0][0])  # Extract the index of the top class.
   # print(class_index)
    
    sal = GradCAM(img_tensor, class_index, features_fn, classifier_fn)

    #img_resize = Image.open(image_path).resize((256, 256))
    img_resize = image.resize((img_dim, img_dim))
    
    sal = Image.fromarray(sal)
    sal = sal.resize(img_resize.size, resample=Image.LINEAR)

    # Save the processed image with a different filename for each input image.
    output_path = os.path.join(output_dir, f'scapis_{image_file[:-4]}.jpg')
    
    # Save the processed image with a different filename for each input image.
    output_path = os.path.join(output_dir, f'scapis_{image_file[:-4]}.jpg')
    plt.figure(figsize=(10, 5))
    plt.subplot(1, 2, 1)
    plt.imshow(img_resize,cmap='gray')
    plt.axis('off')
    plt.subplot(1, 2, 2)
    plt.imshow(img_resize)
    plt.imshow(np.array(sal), alpha=0.5, cmap='jet')
    plt.axis('off')
    plt.savefig(output_path)
    plt.close()

tensor(0., device='cuda:0') tensor(1., device='cuda:0')
tensor(0., device='cuda:0') tensor(1., device='cuda:0')
tensor(0., device='cuda:0') tensor(1., device='cuda:0')
tensor(0., device='cuda:0') tensor(1., device='cuda:0')
tensor(0., device='cuda:0') tensor(1., device='cuda:0')
tensor(0., device='cuda:0') tensor(1., device='cuda:0')
tensor(0., device='cuda:0') tensor(1., device='cuda:0')
tensor(0., device='cuda:0') tensor(1., device='cuda:0')
tensor(0., device='cuda:0') tensor(1., device='cuda:0')
tensor(0., device='cuda:0') tensor(1., device='cuda:0')
tensor(0., device='cuda:0') tensor(1., device='cuda:0')
tensor(0., device='cuda:0') tensor(1., device='cuda:0')
tensor(0., device='cuda:0') tensor(1., device='cuda:0')
tensor(0., device='cuda:0') tensor(1., device='cuda:0')
tensor(0., device='cuda:0') tensor(1., device='cuda:0')
tensor(0., device='cuda:0') tensor(1., device='cuda:0')
tensor(0., device='cuda:0') tensor(1., device='cuda:0')
tensor(0., device='cuda:0') tensor(1., device='c

tensor(0., device='cuda:0') tensor(1., device='cuda:0')
tensor(0., device='cuda:0') tensor(1., device='cuda:0')
tensor(0., device='cuda:0') tensor(1., device='cuda:0')
tensor(0., device='cuda:0') tensor(1., device='cuda:0')
tensor(0., device='cuda:0') tensor(1., device='cuda:0')
tensor(0., device='cuda:0') tensor(1., device='cuda:0')
tensor(0., device='cuda:0') tensor(1., device='cuda:0')
tensor(0., device='cuda:0') tensor(1., device='cuda:0')
tensor(0., device='cuda:0') tensor(1., device='cuda:0')
tensor(0., device='cuda:0') tensor(1., device='cuda:0')
tensor(0., device='cuda:0') tensor(1., device='cuda:0')
tensor(0., device='cuda:0') tensor(1., device='cuda:0')
tensor(0., device='cuda:0') tensor(1., device='cuda:0')
tensor(0., device='cuda:0') tensor(1., device='cuda:0')
tensor(0., device='cuda:0') tensor(1., device='cuda:0')
tensor(0., device='cuda:0') tensor(1., device='cuda:0')
tensor(0., device='cuda:0') tensor(1., device='cuda:0')
tensor(0., device='cuda:0') tensor(1., device='c

tensor(0., device='cuda:0') tensor(1., device='cuda:0')
tensor(0., device='cuda:0') tensor(1., device='cuda:0')
tensor(0., device='cuda:0') tensor(1., device='cuda:0')
tensor(0., device='cuda:0') tensor(1., device='cuda:0')
tensor(0., device='cuda:0') tensor(1., device='cuda:0')
tensor(0., device='cuda:0') tensor(1., device='cuda:0')
tensor(0., device='cuda:0') tensor(1., device='cuda:0')
tensor(0., device='cuda:0') tensor(1., device='cuda:0')
tensor(0., device='cuda:0') tensor(1., device='cuda:0')
tensor(0., device='cuda:0') tensor(1., device='cuda:0')
tensor(0., device='cuda:0') tensor(1., device='cuda:0')
tensor(0., device='cuda:0') tensor(1., device='cuda:0')
tensor(0., device='cuda:0') tensor(1., device='cuda:0')
tensor(0., device='cuda:0') tensor(1., device='cuda:0')
tensor(0., device='cuda:0') tensor(1., device='cuda:0')
tensor(0., device='cuda:0') tensor(1., device='cuda:0')
tensor(0., device='cuda:0') tensor(1., device='cuda:0')
tensor(0., device='cuda:0') tensor(1., device='c

tensor(0., device='cuda:0') tensor(1., device='cuda:0')
tensor(0., device='cuda:0') tensor(1., device='cuda:0')
tensor(0., device='cuda:0') tensor(1., device='cuda:0')
tensor(0., device='cuda:0') tensor(1., device='cuda:0')
tensor(0., device='cuda:0') tensor(1., device='cuda:0')
tensor(0., device='cuda:0') tensor(1., device='cuda:0')
tensor(0., device='cuda:0') tensor(1., device='cuda:0')
tensor(0., device='cuda:0') tensor(1., device='cuda:0')
tensor(0., device='cuda:0') tensor(1., device='cuda:0')
tensor(0., device='cuda:0') tensor(1., device='cuda:0')


## This part of code applies GradCAM to DICOM images, blends the GradCAM visualization with the original image, and saves the result as new DICOM files in the specified output directory. The blending and normalization steps are applied to enhance the visualization 

In [44]:
import os
import pydicom
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt


# Define input and output directories.
input_dir = input_dir
output_dir =f"{output_dir}/../8bit_gradcam"

# Ensure the output directory exists.
os.makedirs(output_dir, exist_ok=True)

default_transformation = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize((img_dim, img_dim))
])

# Function to apply the transformations and get the tensor image.
def preprocess_image(image_path):
    img = dicom.dcmread(image_path)
    image_pix = img.pixel_array
    image = Image.fromarray(image_pix)
    image = default_transformation(image)
    gray_to_rgb = np.stack([image] * 3, axis=2)
    tensor_img = TF.to_tensor(gray_to_rgb)
    tensor_unsqueeze_img = torch.unsqueeze(tensor_img, 0)
    return tensor_unsqueeze_img.cuda()

# List all image files in the input directory.
image_files = [f for f in os.listdir(input_dir) if f.endswith('.dcm')]

for image_file in image_files:
    image_path = os.path.join(input_dir, image_file)

    # Use PIL to open the images and masks.
    img = dicom.dcmread(image_path)
    image_pix = img.pixel_array
    image = Image.fromarray(image_pix)

    img_tensor = preprocess_image(image_path)
    pp, cc = torch.topk(nn.Softmax(dim=1)(net(img_tensor)), 2)

    # Process the GradCAM and resize the image.
    class_index = int(cc[0][0])  # Extract the index of the top class.
    sal = GradCAM(img_tensor, class_index, features_fn, classifier_fn)

    img_resize = image.resize((img_dim, img_dim))
    sal = Image.fromarray(sal)
    sal = sal.resize(img_resize.size, resample=Image.LINEAR)

    # Convert the PIL images to numpy arrays and normalize pixel values.
    img_np = np.array(img_resize)
    sal_np = np.array(sal)
    #sal_np = (sal_np / np.max(sal_np) * 255).astype(np.uint8)
    
    
    max_sal_np = np.max(sal_np)
    if max_sal_np != 0:
        sal_np = (sal_np / max_sal_np * 255).astype(np.uint8)
    else:
        sal_np = sal_np.astype(np.uint8)  # Or any other appropriate handling for zero max value.
        
        
    # Apply alpha blending (alpha = 0.5) to combine the original image and the GradCAM visualization.
    blended_img = Image.fromarray((0.5 * img_np + 0.5 * sal_np).astype(np.uint8))

    # Create a new DICOM file from the original DICOM header and the blended image.
    new_dcm = pydicom.dcmread(image_path)
    new_dcm.PixelData = blended_img.tobytes()
    new_dcm.Rows, new_dcm.Columns = blended_img.size
    new_dcm.save_as(os.path.join(output_dir, f'{image_file[:-4]}.dcm'))