In [None]:
#from segment_anything import SamPredictor, sam_model_registry
from models.sam import SamPredictor, sam_model_registry
from models.sam.utils.transforms import ResizeLongestSide
from skimage.measure import label
from models.sam_LoRa import LoRA_Sam
#Scientific computing 
import numpy as np
import os
#Pytorch packages
import torch
from torch import nn
import torch.optim as optim
import torchvision
from torchvision import datasets
#Visulization
import matplotlib.pyplot as plt
from torchvision import transforms
from PIL import Image
#Others
from torch.utils.data import DataLoader, Subset
from torch.autograd import Variable
import matplotlib.pyplot as plt
import copy
from utils.dataset import Public_dataset
import torch.nn.functional as F
from torch.nn.functional import one_hot
from pathlib import Path
from tqdm import tqdm
from utils.losses import DiceLoss
from utils.dsc import dice_coeff
import cv2
import monai
from utils.utils import vis_image,inverse_normalize,torch_percentile
from argparse import Namespace
import cfg
import PIL
import torchio as tio
import json
# Make only GPU 0 visible to PyTorch.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# --- Checkpoint selection -------------------------------------------------
# These three settings together determine the checkpoint folder name below.
arch = "vit_t"  # Change this value as needed
finetune_type = "adapter"
dataset_name = "MRI-Prostate"  # Assuming you set this if it's dynamic

checkpoint_dir = f"2D-SAM_{arch}_encoderdecoder_{finetune_type}_{dataset_name}_noprompt"
args_path = f"{checkpoint_dir}/args.json"

# Restore the arguments stored in args.json and expose them with attribute
# access (args.arch, args.num_cls, args.dir_checkpoint, ...).
with open(args_path, 'r') as f:
    args_dict = json.load(f)
args = Namespace(**args_dict)

# Build the architecture named by args.arch, load 'checkpoint_best.pth' from the
# directory recorded in args.dir_checkpoint, move it to the GPU and switch to eval mode.
sam_fine_tune = sam_model_registry[args.arch](
    args,
    checkpoint=os.path.join(args.dir_checkpoint, 'checkpoint_best.pth'),
    num_classes=args.num_cls,
)
sam_fine_tune = sam_fine_tune.to('cuda').eval()

## Evaluate an image volume and save predictions

In [None]:
def evaluate_1_volume(image_vol, model, slice_id=None, target_spacing=None, orientation=('L', 'P', 'S')):
    """
    Run the fine-tuned SAM model (without prompts) on one 2D slice of a volume.

    The input volume is not modified: intensity scaling and optional resampling
    are applied to a local copy of its data tensor.

    Parameters:
    - image_vol: torchio-style image whose ``data`` is indexed as (channel, X, Y, Z);
      only channel 0 is used. ``spacing`` is always read; ``affine`` is read only
      when resampling.
    - model: SAM-style model exposing ``image_encoder``, ``prompt_encoder`` and
      ``mask_decoder``. The input tensor is moved to the device of its parameters.
    - slice_id: index along the last spatial axis. ``None`` selects the middle slice;
      an index outside ``[-slice_num, slice_num)`` falls back to the last slice (-1).
    - target_spacing: optional voxel spacing (scalar or 3-sequence) to resample to
      with nearest-neighbour interpolation before slicing. ``None`` disables resampling.
    - orientation: orientation codes of the volume (e.g. ``image_vol.orientation``);
      when the first code is 'R' the model input is flipped along its width axis.

    Returns:
    - ori_img: the (possibly flipped) 1024x1024 model input, de-normalized with
      ``inverse_normalize``.
    - pred: per-pixel argmax over dim 1 of the decoder output (presumably the class channel).
    - voxel_spacing: spacing of ``image_vol`` before any resampling.
    - Pil_img: the selected slice as an RGB PIL image at its native size.
    - slice_id: the slice index actually used.
    """
    voxel_spacing = image_vol.spacing

    # Scale intensities to [0, 1] on a local tensor instead of assigning to
    # image_vol.data, so the caller's volume is left untouched. Skip the scaling
    # when the maximum is not positive (avoids 0/0 and sign inversion).
    data = image_vol.data
    max_val = float(data.max())
    if max_val > 0:
        data = data / max_val

    # Optionally resample to the requested spacing. Compare as float tuples so a
    # list/scalar target equal to the current spacing does not trigger a resample.
    if target_spacing is not None:
        if isinstance(target_spacing, (int, float)):
            wanted = (float(target_spacing),) * 3
        else:
            wanted = tuple(float(s) for s in target_spacing)
        if tuple(float(s) for s in voxel_spacing) != wanted:
            resample = tio.Resample(target_spacing, image_interpolation='nearest')
            data = resample(tio.ScalarImage(tensor=data, affine=image_vol.affine)).data

    volume = data[0]  # first channel; slices are taken along its last axis
    slice_num = volume.shape[2]

    if slice_id is None:
        slice_id = slice_num // 2  # default to the middle slice
    elif not -slice_num <= slice_id < slice_num:
        slice_id = -1  # out of range (including slice_id == slice_num): use the last slice

    img_arr = volume[:, :, slice_id]

    # Min-max normalize the slice to [0, 255] and convert to uint8.
    img_arr = np.array((img_arr - img_arr.min()) / (img_arr.max() - img_arr.min() + 1e-8) * 255, dtype=np.uint8)

    # Duplicate the single channel into an RGB image (the encoder expects 3 channels).
    img_3c = np.tile(img_arr[:, :, None], [1, 1, 3])
    img = Image.fromarray(img_3c, 'RGB')
    Pil_img = img.copy()

    # Resize to the 1024x1024 encoder input, convert to a tensor and apply the
    # ImageNet mean/std normalization; add a batch dimension.
    img = transforms.Resize((1024, 1024))(img)
    img = transforms.ToTensor()(img)
    device = next(model.parameters()).device
    imgs = torch.unsqueeze(
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(img), 0
    ).to(device)

    # Flip along the width axis when the volume's first orientation code is 'R'.
    if orientation and orientation[0] == 'R':
        imgs = torch.flip(imgs, dims=[3])

    with torch.no_grad():
        img_emb = model.image_encoder(imgs)

        # No prompts: the prompt encoder is called with no points/boxes/masks.
        sparse_emb, dense_emb = model.prompt_encoder(
            points=None,
            boxes=None,
            masks=None,
        )

        pred, _ = model.mask_decoder(
            image_embeddings=img_emb,
            image_pe=model.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_emb,
            dense_prompt_embeddings=dense_emb,
            multimask_output=True,
        )

        # Most likely class per pixel.
        pred = pred.argmax(dim=1)

    ori_img = inverse_normalize(imgs.cpu()[0])

    return ori_img, pred, voxel_spacing, Pil_img, slice_id

In [None]:
# Example target mapping (define your own in the format label_name: label_value).
# NOTE: the prediction loop below does not read it; it documents the label values.
target_mapping = {
    'tumor': 1
}

# Example usage
test_dir = 'Volumes_to_test/'
predict_dir = 'Predict_results/'

# Case IDs are the sub-directory names of test_dir. Expected layout:
# test_dir
# |_ case_id1
# |  |_ image.nii.gz
# |  |_ gt.nii.gz (optional, not used here)
# |_ case_id2
#    |_ image.nii.gz
#    |_ gt.nii.gz (optional, not used here)

# Sorted so cases are processed in a deterministic order.
case_ids = sorted(d for d in os.listdir(test_dir) if os.path.isdir(os.path.join(test_dir, d)))

for case_id in case_ids:
    case_dir = os.path.join(test_dir, case_id)

    # Pick the image volume: 'image.nii.gz' if present, otherwise the first
    # (alphabetical) .nii.gz that is not a ground-truth file. Taking any
    # *.nii.gz in directory order could silently select gt.nii.gz.
    image_file = os.path.join(case_dir, 'image.nii.gz')
    if not os.path.isfile(image_file):
        candidates = sorted(
            name for name in os.listdir(case_dir)
            if name.endswith('.nii.gz') and not name.startswith('gt')
        )
        if not candidates:
            print(f"No .nii.gz files found in {case_dir}")
            continue
        image_file = os.path.join(case_dir, candidates[0])

    image1_vol = tio.ScalarImage(image_file)
    print(f'Processing {image_file}')
    print('Volume shape: %s Volume spacing: %s' % (image1_vol.shape, image1_vol.spacing))

    # Clip to the [lower, upper] percentile range and rescale to [0, 1]
    # (0 / 100 keeps the full intensity range).
    lower_percentile = 0
    upper_percentile = 100
    image_tensor = image1_vol.data
    lower_bound = torch_percentile(image_tensor, lower_percentile)
    upper_bound = torch_percentile(image_tensor, upper_percentile)

    image_tensor = torch.clamp(image_tensor, lower_bound, upper_bound)
    value_range = upper_bound - lower_bound
    if value_range > 0:
        image_tensor = (image_tensor - lower_bound) / value_range
    else:
        # Constant volume: avoid 0/0 (NaN) and feed an all-zero image instead.
        image_tensor = torch.zeros_like(image_tensor, dtype=torch.float32)
    image1_vol.set_data(image_tensor)

    # Predict every slice along the last axis and write it into a label volume
    # with the same shape as the input.
    mask_vol_numpy = np.zeros(image1_vol.shape)
    for slice_idx in range(image1_vol.shape[3]):
        _, pred_1, _, Pil_img1, _ = evaluate_1_volume(
            image1_vol, sam_fine_tune, slice_id=slice_idx, orientation=image1_vol.orientation
        )
        # Resize the prediction to the slice's native size; nearest-neighbour
        # keeps the integer label values intact.
        mask_pred_1 = pred_1.cpu().float()
        pil_mask1 = Image.fromarray(np.array(mask_pred_1[0], dtype=np.uint8), 'L').resize(
            Pil_img1.size, resample=PIL.Image.NEAREST
        )
        mask_vol_numpy[0, :, :, slice_idx] = np.asarray(pil_mask1)

    # Undo the width-axis flip applied to the model input for 'R'-oriented volumes.
    mask_tensor = torch.tensor(mask_vol_numpy, dtype=torch.int)
    if image1_vol.orientation[0] == 'R':
        mask_tensor = torch.flip(mask_tensor, dims=[2])

    # Per-case output folder (nothing is written into 'predictions' by this cell).
    mask_save_folder = os.path.join(predict_dir, case_id, 'predictions')
    Path(mask_save_folder).mkdir(parents=True, exist_ok=True)

    # Save the combined mask volume with the input's affine so it overlays the image.
    combined_mask_vol = tio.LabelMap(tensor=mask_tensor, affine=image1_vol.affine)
    combined_mask_filename = os.path.join(predict_dir, case_id, 'pred_mask.nii.gz')
    Path(os.path.dirname(combined_mask_filename)).mkdir(parents=True, exist_ok=True)
    combined_mask_vol.save(combined_mask_filename)

    print(f"Combined mask saved for case: {case_id}")

## Visualize results

In [None]:
import SimpleITK as sitk
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import os

# Paths to the directory holding the test image volumes and the directory holding
# the predicted masks. They match test_dir / predict_dir used by the prediction
# cell above (which writes <predict_dir>/<case_id>/pred_mask.nii.gz); change both
# together if you evaluate a different dataset.
test_path = 'Volumes_to_test/'
predict_dir = 'Predict_results/'

def read_image(image_path):
    """Load an image file from disk with SimpleITK and return the resulting sitk.Image."""
    image = sitk.ReadImage(image_path)
    return image

def read_mask(mask_path):
    """Load a mask (label) file from disk with SimpleITK and return the sitk.Image."""
    mask = sitk.ReadImage(mask_path)
    return mask

def visualize_slice_with_overlay(image_path, mask_path, slice_index):
    """Show one slice of a volume next to the same slice with its mask overlaid.

    Parameters:
    - image_path: path of the image volume (any format SimpleITK can read).
    - mask_path: path of the label volume; assumed to share the image's voxel grid — TODO confirm.
    - slice_index: index along the first axis of the SimpleITK array (z).

    Both panels are rotated 180 degrees in-plane (``flipud(fliplr(...))``) for
    display. Background voxels (label 0) are left transparent so only labelled
    regions are coloured.
    """
    # SimpleITK arrays are ordered (z, y, x).
    image_array = sitk.GetArrayFromImage(read_image(image_path))
    mask_array = sitk.GetArrayFromImage(read_mask(mask_path))

    # Select the slice and rotate it once for display.
    image_view = np.flipud(np.fliplr(image_array[slice_index, :, :]))
    mask_view = np.flipud(np.fliplr(mask_array[slice_index, :, :]))

    # Hide label 0: drawing it with 'jet' at alpha 0.4 would tint the whole slice
    # dark blue. Fix the colour range so label values map to stable colours.
    mask_overlay = np.ma.masked_where(mask_view == 0, mask_view)
    max_label = max(int(mask_view.max()), 1)

    plt.figure(figsize=(10, 5))

    # Left: the image alone.
    plt.subplot(1, 2, 1)
    plt.imshow(image_view, cmap='gray')
    plt.title('Original Image')
    plt.axis('off')

    # Right: the image with the mask overlay (alpha controls the transparency).
    plt.subplot(1, 2, 2)
    plt.imshow(image_view, cmap='gray')
    plt.imshow(mask_overlay, alpha=0.4, cmap='jet', vmin=0, vmax=max_label)
    plt.title('Image with Mask Overlay')
    plt.axis('off')

    plt.show()

# Pick one case from the test directory and show a slice with its predicted mask.
# Case IDs are the sub-directory names of test_path (sorted for a stable order).
case_ids = sorted(d for d in os.listdir(test_path) if os.path.isdir(os.path.join(test_path, d)))
if not case_ids:
    raise FileNotFoundError(f"No case directories found in {test_path}")

i = 0  # index of the case to visualize (0 = the first volume)
case_id = case_ids[i]
case_dir = os.path.join(test_path, case_id)

# Prefer 'image.nii.gz'; otherwise the first (alphabetical) .nii.gz that is not
# a ground-truth file, so gt.nii.gz is never shown as the image.
image_path = os.path.join(case_dir, 'image.nii.gz')
if not os.path.isfile(image_path):
    candidates = sorted(
        name for name in os.listdir(case_dir)
        if name.endswith('.nii.gz') and not name.startswith('gt')
    )
    if not candidates:
        raise FileNotFoundError(f"No .nii.gz files found in {case_dir}")
    image_path = os.path.join(case_dir, candidates[0])

# Mask written by the prediction step for this same case.
mask_path = os.path.join(predict_dir, case_id, 'pred_mask.nii.gz')
slice_index = 200  # Example slice index; must be smaller than the volume's number of slices

visualize_slice_with_overlay(image_path, mask_path, slice_index)