# Base MedSAM2 Inference on SpineMetsCT Validation Set

This notebook runs inference on the SpineMetsCT validation dataset using the base MedSAM2 model (without fine-tuning) and visualizes the results.

In [11]:
# Import required libraries
import os
import glob
import torch
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import sys
from PIL import Image

# Add MedSAM2 to Python path
sys.path.append("MedSAM2")
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

# Ensure reproducibility
torch.set_float32_matmul_precision('high')
torch.manual_seed(2024)
torch.cuda.manual_seed(2024)
np.random.seed(2024)

## Define helper functions for loading data and visualizing results

In [None]:
def load_npz_data(npz_path):
    """Load data from an NPZ file."""
    data = np.load(npz_path)
    # Updated keys based on the actual NPZ file structure
    image = data['imgs']  # 3D volume with shape (8, 512, 512)
    gt_mask = data['gts'] if 'gts' in data else None  # Ground truth mask
    return image, gt_mask

def convert_to_rgb(image):
    """Convert a grayscale image to RGB by repeating the channel 3 times.
    
    Parameters:
        image (numpy.ndarray): Grayscale image of shape (H, W).
        
    Returns:
        numpy.ndarray: RGB image of shape (3, H, W) in float32 format.
    """
    # Normalize image to 0-255
    if image.max() > 1.0:
        image_norm = (image - image.min()) / (image.max() - image.min()) * 255.0
    else:
        image_norm = image * 255.0
    
    # Convert to uint8 for PIL
    image_uint8 = image_norm.astype(np.uint8)
    
    # Convert to RGB using PIL
    pil_image = Image.fromarray(image_uint8)
    rgb_image = pil_image.convert('RGB')
    
    # Convert back to numpy array
    rgb_array = np.array(rgb_image)
    
    # Rearrange to channels-first format (3, H, W) and convert to float32
    rgb_array = np.transpose(rgb_array, (2, 0, 1)).astype(np.float32)
    
    # Normalize to 0-1 range for the model
    rgb_array = rgb_array / 255.0
    
    return rgb_array

def show_mask(mask, ax, mask_color=None, alpha=0.5):
    """
    Show mask overlay on the image

    Parameters
    ----------
    mask : numpy.ndarray
        mask of the image
    ax : matplotlib.axes.Axes
        axes to plot the mask
    mask_color : numpy.ndarray
        color of the mask
    alpha : float
        transparency of the mask
    """
    if mask_color is not None:
        color = np.concatenate([mask_color, np.array([alpha])], axis=0)
    else:
        color = np.array([251/255, 252/255, 30/255, alpha])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)

def calculate_dice(pred_mask, gt_mask):
    """Calculate Dice coefficient between prediction and ground truth"""
    intersection = (pred_mask * gt_mask).sum()
    return (2. * intersection) / (pred_mask.sum() + gt_mask.sum() + 1e-8)

## Set up the model and prediction environment

In [None]:
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Path to validation data
val_dir = "processed_data/SpineMetsCT_npz/val"

# Get list of validation NPZ files
val_files = glob.glob(os.path.join(val_dir, "*.npz"))
print(f"Found {len(val_files)} validation files")

# Load base model (MedSAM2 pretrained)
print("Loading base MedSAM2 model...")
sam_checkpoint = "MedSAM2/checkpoints/MedSAM2_2411.pt"

# Fix: Use absolute path to the config file
config_file = "MedSAM2/sam2/configs/sam2.1_hiera_t512.yaml"

# Build model and load weights
sam = build_sam2(
    config_file=config_file, 
    ckpt_path=sam_checkpoint,
    device=device
)

# Create predictor
predictor = SAM2ImagePredictor(sam)

# Create directory for saving results
output_dir = "inference_results_base_model"
os.makedirs(output_dir, exist_ok=True)

Using device: cuda
Found 1704 validation files
Loading base MedSAM2 model...


## Test Data Loading and Preprocessing

Let's first test our image loading and preprocessing to ensure everything works correctly.

In [None]:
# Load a sample image and check its properties
if len(val_files) > 0:
    sample_file = val_files[0]
    print(f"Sample file: {os.path.basename(sample_file)}")
    
    # Load data
    image_3d, gt_mask = load_npz_data(sample_file)
    print(f"Image shape: {image_3d.shape}, dtype: {image_3d.dtype}")
    print(f"Mask shape: {gt_mask.shape}, dtype: {gt_mask.dtype}")
    print(f"Image value range: [{image_3d.min()}, {image_3d.max()}]")
    
    # Select a slice
    slice_idx = image_3d.shape[0] // 2
    image_slice = image_3d[slice_idx]
    mask_slice = gt_mask[slice_idx]
    
    # Convert to RGB and check
    rgb_image = convert_to_rgb(image_slice)
    print(f"RGB image shape: {rgb_image.shape}, dtype: {rgb_image.dtype}")
    print(f"RGB image value range: [{rgb_image.min()}, {rgb_image.max()}]")
    
    # Visualize
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    
    # Original grayscale image
    axes[0].imshow(image_slice, cmap='gray')
    axes[0].set_title("Original Grayscale")
    axes[0].axis('off')
    
    # RGB channels
    rgb_display = np.transpose(rgb_image, (1, 2, 0))
    axes[1].imshow(rgb_display)
    axes[1].set_title("Converted RGB")
    axes[1].axis('off')
    
    # Ground truth mask
    axes[2].imshow(image_slice, cmap='gray')
    show_mask(mask_slice, ax=axes[2], mask_color=np.array([1.0, 0, 0]))
    axes[2].set_title("Ground Truth Mask")
    axes[2].axis('off')
    
    plt.tight_layout()
    plt.show()
else:
    print("No validation files found to test.")

## Run inference on validation dataset

We'll process samples from the validation set and visualize the results.

In [None]:
# Metrics dictionary
metrics = {'dice': []}

# Number of samples to process
num_samples = 5
max_slices_per_sample = 4  # Limit slices per sample to avoid too many visualizations

# Run inference on validation set
for i, npz_path in enumerate(tqdm(val_files[:num_samples])):
    filename = os.path.basename(npz_path)
    print(f"\nProcessing {filename}...")
    
    # Load image and ground truth
    image_3d, gt_mask = load_npz_data(npz_path)
    print(f"Loaded data shapes - Image: {image_3d.shape}, Mask: {gt_mask.shape}")
    
    # Process each slice in the 3D volume (limit to a few slices for demonstration)
    total_slices = image_3d.shape[0]
    slice_indices = [
        0,  # First slice
        total_slices // 3,  # One-third through the volume
        total_slices // 2,  # Middle slice
        total_slices - 1,  # Last slice
    ][:max_slices_per_sample]  # Limit to max_slices_per_sample
    
    for slice_idx in slice_indices:
        image = image_3d[slice_idx]
        
        # Normalize image for visualization if needed
        if image.max() > 1.0:
            image_viz = (image - image.min()) / (image.max() - image.min())
        else:
            image_viz = image
            
        # Get ground truth slice
        if gt_mask is not None:
            gt_slice = gt_mask[slice_idx]
        else:
            gt_slice = None
            
        # Convert grayscale to RGB before passing to the model
        image_rgb = convert_to_rgb(image)
        
        # Convert to PyTorch tensor and move to appropriate device
        input_tensor = torch.tensor(image_rgb, dtype=torch.float32).to(device)
        
        # Set image for predictor
        predictor.set_image(input_tensor)
        
        # Get automatic mask prediction
        masks, scores, _ = predictor.predict()
        
        # Visualize and save results
        fig, axes = plt.subplots(1, 3, figsize=(18, 6))
        
        # Plot original image
        axes[0].imshow(image_viz, cmap='gray')
        axes[0].set_title("Original Image")
        axes[0].axis('off')
        
        # Plot ground truth if available
        if gt_slice is not None:
            axes[1].imshow(image_viz, cmap='gray')
            show_mask(gt_slice, ax=axes[1], mask_color=np.array([1.0, 0, 0]))
            axes[1].set_title("Ground Truth")
            axes[1].axis('off')
        else:
            axes[1].set_visible(False)
            
        # Plot prediction
        if len(masks) > 0:
            # Use the highest scoring mask
            mask = masks[0]  # Shape: H x W
            axes[2].imshow(image_viz, cmap='gray')
            show_mask(mask, ax=axes[2], mask_color=np.array([0, 0, 1.0]))
            axes[2].set_title(f"Prediction (Score: {scores[0]:.3f})")
            axes[2].axis('off')
            
            # Calculate Dice score if ground truth is available
            if gt_slice is not None:
                dice = calculate_dice(mask, gt_slice)
                metrics['dice'].append(dice)
                print(f"Slice {slice_idx} - Dice score: {dice:.4f}")
        else:
            axes[2].imshow(image_viz, cmap='gray')
            axes[2].set_title("No prediction")
            axes[2].axis('off')
        
        # Save figure
        plt.tight_layout()
        plt.savefig(os.path.join(output_dir, f"{os.path.splitext(filename)[0]}_slice{slice_idx}_results.png"), dpi=150)
        plt.close(fig)
        
    # Limit the number of files processed
    if i >= num_samples - 1:
        break

  0%|          | 0/5 [00:00<?, ?it/s]


Processing 15067_12-03-2011-SpineSPINEBONESBRT Adult-72528_5.000000-SKINTOSKINSIM0.5MM15067a iMAR-28213_300.000000-Spine Segmentation-27808_144.npz...
Loaded data shapes - Image: (8, 512, 512), Mask: (8, 512, 512)


  0%|          | 0/5 [00:00<?, ?it/s]



RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
  File "/tmp2/b10902078/miniconda3/envs/MEDSAM/lib/python3.10/site-packages/torch/nn/modules/container.py", line 240, in forward
    def forward(self, input):
        for module in self:
            input = module(input)
                    ~~~~~~ <--- HERE
        return input
  File "/tmp2/b10902078/miniconda3/envs/MEDSAM/lib/python3.10/site-packages/torchvision/transforms/transforms.py", line 277, in forward
            Tensor: Normalized Tensor image.
        """
        return F.normalize(tensor, self.mean, self.std, self.inplace)
               ~~~~~~~~~~~ <--- HERE
  File "/tmp2/b10902078/miniconda3/envs/MEDSAM/lib/python3.10/site-packages/torchvision/transforms/functional.py", line 350, in normalize
        raise TypeError(f"img should be Tensor Image. Got {type(tensor)}")

    return F_t.normalize(tensor, mean=mean, std=std, inplace=inplace)
           ~~~~~~~~~~~~~ <--- HERE
  File "/tmp2/b10902078/miniconda3/envs/MEDSAM/lib/python3.10/site-packages/torchvision/transforms/_functional_tensor.py", line 928, in normalize
    if std.ndim == 1:
        std = std.view(-1, 1, 1)
    return tensor.sub_(mean).div_(std)
           ~~~~~~~~~~~ <--- HERE
RuntimeError: The size of tensor a (512) must match the size of tensor b (3) at non-singleton dimension 0


## Analyze and display the quantitative results

In [None]:
# Print average Dice score
if metrics['dice']:
    avg_dice = np.mean(metrics['dice'])
    std_dice = np.std(metrics['dice'])
    min_dice = np.min(metrics['dice'])
    max_dice = np.max(metrics['dice'])
    median_dice = np.median(metrics['dice'])
    
    print(f"Dice statistics on {len(metrics['dice'])} validation slices:")
    print(f"Average: {avg_dice:.4f}")
    print(f"Standard deviation: {std_dice:.4f}")
    print(f"Minimum: {min_dice:.4f}")
    print(f"Maximum: {max_dice:.4f}")
    print(f"Median: {median_dice:.4f}")
    
    # Plot histogram of Dice scores
    plt.figure(figsize=(10, 6))
    plt.hist(metrics['dice'], bins=10, edgecolor='black')
    plt.axvline(avg_dice, color='red', linestyle='dashed', linewidth=1, label=f'Mean: {avg_dice:.4f}')
    plt.axvline(median_dice, color='green', linestyle='dashed', linewidth=1, label=f'Median: {median_dice:.4f}')
    plt.xlabel('Dice Score')
    plt.ylabel('Frequency')
    plt.title('Distribution of Dice Scores on Validation Set')
    plt.legend()
    plt.grid(alpha=0.3)
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'dice_distribution.png'), dpi=150)
    plt.show()
else:
    print("No valid samples for calculating Dice scores.")
    
print(f"Inference results saved to {output_dir}")

## 3D Volume Visualization

Let's visualize a few representative slices from one 3D volume with their predictions overlaid.

In [None]:
def process_and_visualize_volume(npz_path):
    """Process a whole 3D volume and visualize selected slices."""
    filename = os.path.basename(npz_path)
    print(f"\nProcessing 3D volume: {filename}...")
    
    # Load image and ground truth
    image_3d, gt_mask = load_npz_data(npz_path)
    print(f"Volume shape: {image_3d.shape}")
    
    # Create a 3D mask for predictions
    pred_mask_3d = np.zeros_like(gt_mask)
    
    # Process all slices
    slice_dice_scores = []
    for slice_idx in tqdm(range(image_3d.shape[0]), desc="Processing slices"):
        image = image_3d[slice_idx]
        
        # Convert grayscale to RGB
        image_rgb = convert_to_rgb(image)
        
        # Convert to PyTorch tensor and move to device
        input_tensor = torch.tensor(image_rgb, dtype=torch.float32).to(device)
        
        # Set image for predictor
        predictor.set_image(input_tensor)
        
        # Get automatic mask prediction
        masks, scores, _ = predictor.predict()
        
        # Store prediction in 3D mask
        if len(masks) > 0:
            mask = masks[0]
            pred_mask_3d[slice_idx] = mask
            
            # Calculate Dice score
            if gt_mask is not None:
                dice = calculate_dice(mask, gt_mask[slice_idx])
                slice_dice_scores.append(dice)
    
    # Calculate volume-wide Dice score
    volume_dice = calculate_dice(pred_mask_3d, gt_mask)
    print(f"Volume-wide Dice score: {volume_dice:.4f}")
    
    # Select slices to visualize (start, 25%, 50%, 75%, end)
    num_slices = image_3d.shape[0]
    slice_indices = [
        0,
        num_slices // 4,
        num_slices // 2,
        3 * num_slices // 4,
        num_slices - 1
    ]
    
    # Visualize selected slices
    fig, axes = plt.subplots(len(slice_indices), 3, figsize=(15, 4*len(slice_indices)))
    
    for i, slice_idx in enumerate(slice_indices):
        # Prepare image for visualization
        image = image_3d[slice_idx]
        if image.max() > 1.0:
            image_viz = (image - image.min()) / (image.max() - image.min())
        else:
            image_viz = image
        
        # Get masks
        gt_slice = gt_mask[slice_idx]
        pred_slice = pred_mask_3d[slice_idx]
        
        # Calculate Dice score for this slice
        slice_dice = calculate_dice(pred_slice, gt_slice)
        
        # Original image
        axes[i, 0].imshow(image_viz, cmap='gray')
        axes[i, 0].set_title(f"Slice {slice_idx}")
        axes[i, 0].axis('off')
        
        # Ground truth overlay
        axes[i, 1].imshow(image_viz, cmap='gray')
        show_mask(gt_slice, ax=axes[i, 1], mask_color=np.array([1.0, 0, 0]))
        axes[i, 1].set_title("Ground Truth")
        axes[i, 1].axis('off')
        
        # Prediction overlay
        axes[i, 2].imshow(image_viz, cmap='gray')
        show_mask(pred_slice, ax=axes[i, 2], mask_color=np.array([0, 0, 1.0]))
        axes[i, 2].set_title(f"Prediction (Dice: {slice_dice:.3f})")
        axes[i, 2].axis('off')
    
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, f"{os.path.splitext(filename)[0]}_3d_visualization.png"), dpi=150)
    plt.show()
    
    return volume_dice

# Process a sample volume
if len(val_files) > 0:
    sample_volume = val_files[0]  # Use the first file
    volume_dice = process_and_visualize_volume(sample_volume)
else:
    print("No validation files available.")