In [None]:
# Record the interpreter version this notebook was executed with.
from platform import python_version

print(python_version())

3.10.14


In [None]:
import os
import json
import numpy as np
import torch
from sklearn.metrics import precision_score, recall_score, f1_score
from scipy.spatial.distance import directed_hausdorff
import matplotlib.pyplot as plt
import random
from skimage import transform
import warnings
import nibabel as nib
from tqdm import tqdm
from segment_anything import sam_model_registry
import torch.nn.functional as F

In [None]:
# Root folder that the image/label paths listed in dataset.json are joined onto.
# Override with the MSD_BRAINTUMOUR_DATA_DIR environment variable to run on
# another machine; the default is the original hard-coded location.
dataset_path = os.environ.get(
    "MSD_BRAINTUMOUR_DATA_DIR", "/home/manuri/data/Task01_BrainTumour_org/"
)

In [None]:
# Folder containing the dataset.json that describes the split to evaluate.
# Override with the MSD_BRAINTUMOUR_JSON_DIR environment variable to run on
# another machine; the default is the original hard-coded location.
json_path = os.environ.get(
    "MSD_BRAINTUMOUR_JSON_DIR", "/home/manuri/data/Task01_BrainTumour/"
)

In [None]:
# Load the image files from the original source folder
json_filename = os.path.join(json_path, "dataset.json")

try:
    with open(json_filename, "r") as fp:
        experiment_data = json.load(fp)
except IOError:
    print("File {} doesn't exist. It should be part of the "
          "Decathlon directory".format(json_filename))
    # Re-raise: without the parsed JSON every statement below would fail with
    # a misleading NameError on `experiment_data`.
    raise

# Dataset-level metadata read straight from dataset.json.
output_channels = experiment_data["labels"]
input_channels = experiment_data["modality"]
description = experiment_data["description"]
name = experiment_data["name"]
release = experiment_data["release"]
license = experiment_data["licence"]
reference = experiment_data["reference"]
tensorImageSize = experiment_data["tensorImageSize"]
numFiles = experiment_data["numTraining"]
numFiles_Test = experiment_data["numTest"]

# Build absolute image/label paths for the first `numTest` entries of the
# "test" split; each entry is expected to carry both an "image" and a "label" key.
test_entries = [experiment_data["test"][idx] for idx in range(numFiles_Test)]
img = [os.path.join(dataset_path, entry["image"]) for entry in test_entries]
label = [os.path.join(dataset_path, entry["label"]) for entry in test_entries]

filenames = {'images': img, 'label': label}

In [None]:
#filenames

In [None]:
# Sanity check: number of image paths collected for evaluation.
len(filenames['images'])

73

In [None]:
# Confirm that a CUDA device is visible before running inference.
# NOTE(review): torch is already imported in the imports cell above, so this
# re-import is redundant (harmless, but imports belong in one place).
import torch
torch.cuda.is_available()

True

In [None]:
# Load and preprocess the data, then generate the segmentation masks

def load_nifti(file_path):
    """Load the NIfTI file at ``file_path`` and return its voxel data.

    The data is read through nibabel's ``get_fdata()``.
    """
    nifti_image = nib.load(file_path)
    return nifti_image.get_fdata()

def preprocess_slice(slice_data, target_size=(1024, 1024)):
    """
    Turn one 2D MRI slice into a normalised 3-channel float image.

    Steps: replace NaNs with 0, clip intensities to the 1st-99th percentile,
    min-max scale to [0, 1], resize to ``target_size`` (cubic interpolation)
    when the shape differs, and replicate the result across three channels.

    Args:
    slice_data (np.ndarray): Input 2D slice of MRI data
    target_size (tuple): Desired output size (height, width)

    Returns:
    np.ndarray: float32 array with shape (*target_size, 3); all zeros when
    the clipped slice has a single constant intensity.

    Raises:
    ValueError: if ``slice_data`` is not 2-dimensional.
    """
    if slice_data.ndim != 2:
        raise ValueError(f"Expected 2D input, got shape {slice_data.shape}")

    # NaNs would poison the percentile and min/max computations below.
    cleaned = np.nan_to_num(slice_data, nan=0.0) if np.isnan(slice_data).any() else slice_data

    # Suppress intensity outliers before scaling.
    lower, upper = np.percentile(cleaned, (1, 99))
    clipped = np.clip(cleaned, lower, upper)

    value_min = clipped.min()
    value_max = clipped.max()
    if value_min == value_max:
        # Nothing to normalise (and the division below would be 0/0).
        return np.zeros((*target_size, 3), dtype=np.float32)

    scaled = (clipped - value_min) / (value_max - value_min)

    if scaled.shape != target_size:
        scaled = transform.resize(
            scaled,
            target_size,
            order=3,  # cubic spline interpolation
            mode='constant',
            anti_aliasing=True,
            preserve_range=True
        )

    # Cubic interpolation can overshoot; pull values back into [0, 1].
    scaled = np.clip(scaled, 0, 1)

    # Grey -> RGB-like by repeating the single channel three times.
    rgb = np.stack((scaled, scaled, scaled), axis=-1)
    return rgb.astype(np.float32)

def evaluate_slice(model, image_slice, mask_slice, device):
    """Segment one 2D slice with the model, prompted by a box from the mask.

    The slice is preprocessed to a 1024x1024 3-channel image, encoded by
    ``model.image_encoder``, and decoded with a bounding-box prompt derived
    from ``mask_slice`` via ``get_bounding_box``.

    Args:
        model: model exposing ``image_encoder``, ``prompt_encoder`` and
            ``mask_decoder`` (as used by ``medsam_inference``).
        image_slice (np.ndarray): 2D image slice of shape (H, W).
        mask_slice (np.ndarray): 2D ground-truth mask used only to build the
            box prompt.
        device: torch device the image tensor is moved to.

    Returns:
        np.ndarray: uint8 mask of shape (H, W).
    """
    # Side length of the preprocessed image; box coordinates must be expressed
    # in this same coordinate system.
    model_input_size = 1024

    H, W = image_slice.shape
    image_rgb = preprocess_slice(image_slice, target_size=(model_input_size, model_input_size))
    # (H, W, 3) -> (1, 3, H, W)
    image_tensor = torch.tensor(image_rgb).permute(2, 0, 1).unsqueeze(0).to(device)

    with torch.no_grad():
        image_embedding = model.image_encoder(image_tensor)

    # Box is [x_min, y_min, x_max, y_max] in original pixel coordinates;
    # rescale x by W and y by H into the model's input resolution.
    box = get_bounding_box(mask_slice)
    box_1024 = np.array(box) / np.array([W, H, W, H]) * model_input_size
    box_1024 = box_1024[None, :]

    pred_mask = medsam_inference(model, image_embedding, box_1024, H, W)
    # medsam_inference already interpolates to (H, W); only fall back to a
    # nearest-neighbour resize if the returned shape is unexpectedly different.
    if pred_mask.shape != (H, W):
        pred_mask = transform.resize(pred_mask, (H, W), order=0, preserve_range=True, anti_aliasing=False)

    return pred_mask.astype(np.uint8)

def get_bounding_box(mask, padding=10, fallback_box_size=10):
    """Return a padded bounding box around the non-background region of a mask.

    Args:
        mask (np.ndarray): 2D label mask; any value > 0 counts as foreground.
        padding (int): pixels added on every side of the tight box (clamped
            to the mask bounds).
        fallback_box_size (int): side length of the centred box returned when
            the mask has no foreground.

    Returns:
        list[int]: ``[x_min, y_min, x_max, y_max]`` (column/row indices,
        inclusive), always inside the mask bounds.
    """
    height, width = mask.shape
    foreground = mask > 0
    row_idx = np.flatnonzero(foreground.any(axis=1))
    col_idx = np.flatnonzero(foreground.any(axis=0))

    if row_idx.size == 0 or col_idx.size == 0:
        # No foreground in this slice: return a small central box, clamped so
        # it cannot leave the image on masks smaller than the fallback size.
        center_h, center_w = height // 2, width // 2
        half = fallback_box_size // 2
        return [max(0, center_w - half), max(0, center_h - half),
                min(width - 1, center_w + half), min(height - 1, center_h + half)]

    # Tight box from the first/last occupied row and column, then padded.
    rmin = max(0, int(row_idx[0]) - padding)
    rmax = min(height - 1, int(row_idx[-1]) + padding)
    cmin = max(0, int(col_idx[0]) - padding)
    cmax = min(width - 1, int(col_idx[-1]) + padding)

    return [cmin, rmin, cmax, rmax]

def evaluate_volume(model, image_volume, mask_volume, device, modality_index=3):
    """Run slice-by-slice segmentation over a volume along its third axis.

    Args:
        model: model passed through to ``evaluate_slice``.
        image_volume (np.ndarray): (H, W, D) volume, or (H, W, D, C) volume
            from which channel ``modality_index`` is taken.
        mask_volume (np.ndarray): (H, W, D) ground-truth mask used for the
            per-slice box prompts.
        device: torch device passed through to ``evaluate_slice``.
        modality_index (int): channel of a 4D volume to segment. Defaults to
            3, the channel that was previously hard-coded.

    Returns:
        np.ndarray: predictions stacked along the last axis, shape (H, W, D).

    Raises:
        ValueError: if the spatial shapes of image and mask volumes differ.
    """
    # A mismatch would otherwise yield boxes in the wrong coordinate system
    # or an IndexError part-way through the volume.
    if tuple(image_volume.shape[:3]) != tuple(mask_volume.shape[:3]):
        raise ValueError(
            f"Image volume shape {image_volume.shape} does not match "
            f"mask volume shape {mask_volume.shape}"
        )

    is_multichannel = image_volume.ndim == 4
    predictions = []
    for i in range(image_volume.shape[2]):  # Iterate through slices
        image_slice = image_volume[:, :, i, modality_index] if is_multichannel else image_volume[:, :, i]
        mask_slice = mask_volume[:, :, i]
        predictions.append(evaluate_slice(model, image_slice, mask_slice, device))

    return np.stack(predictions, axis=-1)

@torch.no_grad()
def medsam_inference(medsam_model, img_embed, box_1024, H, W):
    """Decode a binary mask from an image embedding and a box prompt.

    Args:
        medsam_model: model exposing ``prompt_encoder`` (callable, with
            ``get_dense_pe()``) and ``mask_decoder``.
        img_embed (torch.Tensor): image embedding; its device is used for the
            box tensor.
        box_1024: box prompt(s), shape (4,) or (B, 4). Presumably given in the
            1024x1024 model input coordinate system, as the name suggests.
        H (int): output mask height.
        W (int): output mask width.

    Returns:
        np.ndarray: uint8 mask thresholded at probability 0.5. If the squeezed
        result still has more than two dimensions, only its first entry is
        returned.
    """
    boxes = torch.as_tensor(box_1024, dtype=torch.float, device=img_embed.device)
    # The prompt encoder is fed boxes shaped (B, 1, 4).
    if boxes.dim() == 1:  # (4,) -> (1, 1, 4)
        boxes = boxes[None, None, :]
    elif boxes.dim() == 2:  # (B, 4) -> (B, 1, 4)
        boxes = boxes[:, None, :]

    sparse_emb, dense_emb = medsam_model.prompt_encoder(
        points=None,
        boxes=boxes,
        masks=None,
    )
    logits, _ = medsam_model.mask_decoder(
        image_embeddings=img_embed,
        image_pe=medsam_model.prompt_encoder.get_dense_pe(),
        sparse_prompt_embeddings=sparse_emb,
        dense_prompt_embeddings=dense_emb,
        multimask_output=False,
    )

    # Logits -> probabilities, then upsample to the requested resolution.
    probabilities = F.interpolate(
        torch.sigmoid(logits),
        size=(H, W),
        mode="bilinear",
        align_corners=False,
    )
    segmentation = (probabilities > 0.5).squeeze().cpu().numpy().astype(np.uint8)

    # With several boxes the squeezed result keeps a leading axis; keep the first mask.
    return segmentation[0] if segmentation.ndim > 2 else segmentation

def dice_coefficient(y_true, y_pred, smooth=1e-7):
    """Smoothed Dice overlap between two binary masks.

    Computes ``(2 * |A ∩ B| + smooth) / (|A| + |B| + smooth)``; the smoothing
    term makes two empty masks score 1.0 instead of dividing by zero.
    """
    overlap = np.sum(y_true * y_pred)
    total = np.sum(y_true) + np.sum(y_pred)
    numerator = 2. * overlap + smooth
    return numerator / (total + smooth)

def hausdorff_distance(y_true, y_pred):
    """Symmetric Hausdorff distance between the foreground of two masks.

    ``directed_hausdorff`` operates on (N, D) arrays of point coordinates, so
    the masks are first converted to the index coordinates of their non-zero
    voxels. Passing the raw masks would treat every mask row as one point.

    Args:
        y_true (np.ndarray): ground-truth mask (any dimensionality).
        y_pred (np.ndarray): predicted mask with the same dimensionality.

    Returns:
        float: 0.0 if both masks are empty, ``np.inf`` if exactly one is
        empty, otherwise the larger of the two directed distances, measured
        in voxel index units.
    """
    true_points = np.argwhere(np.asarray(y_true) > 0)
    pred_points = np.argwhere(np.asarray(y_pred) > 0)

    if len(true_points) == 0 and len(pred_points) == 0:
        return 0.0
    if len(true_points) == 0 or len(pred_points) == 0:
        return np.inf

    forward = directed_hausdorff(true_points, pred_points)[0]
    backward = directed_hausdorff(pred_points, true_points)[0]
    return max(forward, backward)

def _binary_overlap_counts(true_bin, pred_bin):
    """Return (n_true, n_pred, n_both) positive-voxel counts for two boolean masks."""
    n_true = int(np.count_nonzero(true_bin))
    n_pred = int(np.count_nonzero(pred_bin))
    n_both = int(np.count_nonzero(true_bin & pred_bin))
    return n_true, n_pred, n_both


def calculate_metrics(true_mask, pred_mask):
    """Compute overall and per-class overlap metrics for a segmentation.

    The overall Dice compares foreground (label > 0) of both masks. Per-class
    metrics compare ``mask == class_id`` for class ids 1..3. All scores are
    derived from voxel counts on boolean masks: precision = TP / predicted,
    recall = TP / true, F1 = 2*TP / (true + predicted), and Dice is the same
    ratio with a 1e-7 smoothing term (the formula ``dice_coefficient`` uses).

    Args:
        true_mask (np.ndarray): ground-truth label mask.
        pred_mask (np.ndarray): predicted label mask of the same shape.

    Returns:
        dict: ``overall_dice`` plus ``dice_*``, ``precision_*``, ``recall_*``
        and ``f1_*`` for each class name. A class absent from both masks
        scores 1.0; a class present in exactly one of them scores 0.0.

    Raises:
        ValueError: if the two masks have different shapes.
    """
    true_mask = np.asarray(true_mask)
    pred_mask = np.asarray(pred_mask)
    if true_mask.shape != pred_mask.shape:
        raise ValueError(
            f"Shape mismatch: true_mask {true_mask.shape} vs pred_mask {pred_mask.shape}"
        )

    smooth = 1e-7
    metrics = {}
    class_names = {0: "background", 1: "edema", 2: "non-enhancing tumor", 3: "enhancing tumour"}

    # Overall metrics
    true_foreground = true_mask > 0
    pred_foreground = pred_mask > 0
    n_true, n_pred, n_both = _binary_overlap_counts(true_foreground, pred_foreground)
    metrics['overall_dice'] = (2. * n_both + smooth) / (n_true + n_pred + smooth)

    # A foreground/background prediction cannot match label ids above 1, so
    # the per-class numbers below would be misleading; make that visible.
    if np.any(true_mask > 1) and np.any(pred_foreground) and not np.any(pred_mask > 1):
        warnings.warn(
            "pred_mask only contains labels {0, 1} while true_mask has labels > 1: "
            "per-class metrics compare label ids exactly, so classes 2 and 3 will "
            "score 0 and class 1 is scored against the whole predicted foreground.",
            RuntimeWarning,
            stacklevel=2,
        )

    # Per-class metrics
    for class_id in range(1, 4):  # Skip background class
        class_name = class_names[class_id]
        n_true, n_pred, n_both = _binary_overlap_counts(
            true_mask == class_id, pred_mask == class_id
        )

        if n_true == 0 and n_pred == 0:
            # Class absent from both masks: perfect agreement.
            dice = precision = recall = f1 = 1.0
        elif n_true == 0 or n_pred == 0:
            # Class present in only one of the masks: no overlap possible.
            dice = precision = recall = f1 = 0.0
        else:
            # Both counts are non-zero here, so the divisions are safe.
            dice = (2. * n_both + smooth) / (n_true + n_pred + smooth)
            precision = n_both / n_pred
            recall = n_both / n_true
            f1 = 2. * n_both / (n_true + n_pred)

        metrics[f'dice_{class_name}'] = dice
        metrics[f'precision_{class_name}'] = precision
        metrics[f'recall_{class_name}'] = recall
        metrics[f'f1_{class_name}'] = f1

    return metrics

def plot_results(image, ground_truth, prediction, slice_index):
    """Show an image slice, its ground-truth mask and the predicted mask side by side.

    Args:
        image: 2D image shown in grayscale.
        ground_truth: 2D label mask shown with the 'nipy_spectral' colormap.
        prediction: 2D predicted mask shown with the same colormap.
        slice_index: slice number included in each panel title.
    """
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    # (data, colormap, title) for each of the three panels, left to right.
    panels = (
        (image, 'gray', f"Original Image (Slice {slice_index})"),
        (ground_truth, 'nipy_spectral', f"Ground Truth Mask (Slice {slice_index})"),
        (prediction, 'nipy_spectral', f"Predicted Mask (Slice {slice_index})"),
    )
    for ax, (data, cmap, title) in zip(axes, panels):
        ax.imshow(data, cmap=cmap)
        ax.set_title(title)
        ax.axis('off')

    plt.tight_layout()
    plt.show()


# Load the pretrained model
MedSAM_CKPT_PATH = "medsam_vit_b.pth"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = sam_model_registry['vit_b'](checkpoint=MedSAM_CKPT_PATH).to(device)
model.eval()

image_files = filenames['images']
mask_files = filenames['label']

# One metrics dict per evaluated volume.
all_metrics = []

for img_file, mask_file in tqdm(zip(image_files, mask_files), total=len(image_files)):
    image_volume = load_nifti(img_file)
    mask_volume = load_nifti(mask_file)

    pred_volume = evaluate_volume(model, image_volume, mask_volume, device)
    metrics = calculate_metrics(mask_volume, pred_volume)
    all_metrics.append(metrics)

# Aggregate and print results
if not all_metrics:
    # Without this guard the aggregation below fails with a bare IndexError.
    raise RuntimeError(
        "No volumes were evaluated; check filenames['images'] and filenames['label']."
    )

# Unweighted mean of every metric over all volumes.
average_metrics = {metric: np.mean([m[metric] for m in all_metrics]) for metric in all_metrics[0]}
print("Average Metrics:")
for metric, value in average_metrics.items():
    print(f"{metric}: {value:.4f}")

100%|██████████| 73/73 [1:24:46<00:00, 69.68s/it]

Average Metrics:
overall_dice: 0.7407
dice_edema: 0.5306
precision_edema: 0.3790
recall_edema: 0.9508
f1_edema: 0.5306
dice_non-enhancing tumor: 0.0000
precision_non-enhancing tumor: 0.0000
recall_non-enhancing tumor: 0.0000
f1_non-enhancing tumor: 0.0000
dice_enhancing tumour: 0.0000
precision_enhancing tumour: 0.0000
recall_enhancing tumour: 0.0000
f1_enhancing tumour: 0.0000



