# MedSAM Inference

This notebook runs MedSAM (Segment Anything fine-tuned for medical images) on a single 2-D image, segmenting the region inside a user-supplied bounding box.

## Requirements
- PyTorch
- segment-anything
- numpy
- matplotlib
- opencv-python

In [None]:
# %% Environment and functions
import numpy as np
import matplotlib.pyplot as plt
import os
import cv2
import torch
from segment_anything import sam_model_registry
import torch.nn.functional as F

## Visualization Functions

In [None]:
# Visualization functions
# Source: https://github.com/facebookresearch/segment-anything/blob/main/notebooks/predictor_example.ipynb

def show_mask(mask, ax, random_color=False):
    """Overlay a mask on a Matplotlib axis as a semi-transparent RGBA image.

    Args:
        mask: Mask array whose last two dimensions are (H, W); nonzero pixels are
            tinted, zero pixels stay fully transparent.
        ax: Matplotlib axis (anything with an ``imshow`` method).
        random_color: When True draw one random RGB colour from the global NumPy
            RNG; otherwise use the default yellow (251, 252, 30).
    """
    alpha = 0.6
    if random_color:
        rgb = np.random.random(3)
    else:
        rgb = np.array([251 / 255, 252 / 255, 30 / 255])
    rgba = np.append(rgb, alpha)

    height, width = mask.shape[-2:]
    # Broadcast (H, W, 1) * (1, 1, 4) -> (H, W, 4) RGBA overlay.
    overlay = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)
    ax.imshow(overlay)

def show_box(box, ax):
    """Draw an unfilled blue rectangle for a box prompt on a Matplotlib axis.

    Args:
        box: Indexable [x0, y0, x1, y1] in the axis' pixel coordinates.
        ax: Matplotlib axis the rectangle patch is added to.
    """
    x_min, y_min, x_max, y_max = box[0], box[1], box[2], box[3]
    rect = plt.Rectangle(
        (x_min, y_min),
        x_max - x_min,  # width
        y_max - y_min,  # height
        edgecolor='blue',
        facecolor=(0, 0, 0, 0),  # transparent fill
        lw=2,
    )
    ax.add_patch(rect)

## MedSAM Inference Function

In [None]:
@torch.no_grad()
def medsam_inference(medsam_model, img_embed, box_1024, H, W, threshold=0.5):
    """Decode binary mask(s) from precomputed image embeddings and box prompt(s).

    Runs the model's prompt encoder on the box(es), then its mask decoder in
    single-mask mode, converts logits to probabilities, upsamples them to (H, W)
    and thresholds.

    Args:
        medsam_model: MedSAM/SAM model exposing ``prompt_encoder`` and ``mask_decoder``.
        img_embed: Image embeddings from the encoder, (B, 256, 64, 64).
        box_1024: Box prompt(s) [x0, y0, x1, y1] in 1024x1024 coordinates, shape
            (B, 4) or (B, 1, 4); anything ``torch.as_tensor`` accepts.
        H: Height of the returned mask.
        W: Width of the returned mask.
        threshold: Probability cut-off applied after the sigmoid (default 0.5).

    Returns:
        np.ndarray of dtype uint8 with values {0, 1}. Shape (H, W) for a single
        prompt, otherwise (B, H, W).
    """
    box_torch = torch.as_tensor(box_1024, dtype=torch.float, device=img_embed.device)
    if box_torch.ndim == 2:
        box_torch = box_torch[:, None, :]  # (B, 4) -> (B, 1, 4)

    sparse_embeddings, dense_embeddings = medsam_model.prompt_encoder(
        points=None,
        boxes=box_torch,
        masks=None,
    )

    low_res_logits, _ = medsam_model.mask_decoder(
        image_embeddings=img_embed,  # (B, 256, 64, 64)
        image_pe=medsam_model.prompt_encoder.get_dense_pe(),  # (1, 256, 64, 64)
        sparse_prompt_embeddings=sparse_embeddings,  # (B, 2, 256)
        dense_prompt_embeddings=dense_embeddings,  # (B, 256, 64, 64)
        multimask_output=False,  # one mask per prompt -> channel dim of size 1
    )

    low_res_pred = torch.sigmoid(low_res_logits)  # (B, 1, 256, 256) probabilities

    pred = F.interpolate(
        low_res_pred,
        size=(H, W),
        mode="bilinear",
        align_corners=False,
    )  # (B, 1, H, W)

    # Drop only the channel dim, and the batch dim when there is a single prompt.
    # A bare .squeeze() would also collapse H or W whenever either equals 1.
    pred = pred[:, 0]  # (B, H, W)
    if pred.shape[0] == 1:
        pred = pred[0]  # (H, W)

    pred_np = pred.cpu().numpy()
    medsam_seg = (pred_np > threshold).astype(np.uint8)
    return medsam_seg

## Utility Functions

In [None]:
def get_bboxes_from_mask(mask):
    """Return the largest bounding box around the connected regions of a mask.

    Args:
        mask: 2-D mask as a numpy array, torch tensor (any device, may require
            grad) or nested list. Only pixels exactly equal to 1 count as
            foreground, so e.g. a 0/255 mask yields None.

    Returns:
        [x0, y0, x1, y1] as Python floats with x1 = x + w and y1 = y + h from
        ``cv2.boundingRect``, or None if the mask has no foreground pixels.
        "Largest" is measured by bounding-box area (w * h), not contour area;
        on ties the first contour returned by OpenCV wins.
    """
    if torch.is_tensor(mask):
        # detach() so tensors that require grad can still be converted to numpy.
        mask = mask.detach().cpu().numpy()
    mask_np = (np.asarray(mask) == 1).astype(np.uint8)

    # OpenCV 3.x returns (image, contours, hierarchy) while 4.x returns
    # (contours, hierarchy); indexing [-2] selects the contours in both.
    contours = cv2.findContours(mask_np, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
    if len(contours) == 0:
        return None

    rects = [cv2.boundingRect(contour) for contour in contours]
    # max() keeps the first maximal element, matching a strict ">" scan.
    x, y, w, h = max(rects, key=lambda rect: rect[2] * rect[3])
    return [float(x), float(y), float(x + w), float(y + h)]

## Model Loading

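Choose **one** of the following methods to load your MedSAM model (run only one of the two cells):
- **Method 1** loads a single checkpoint file containing the model weights.
- **Method 2** builds the model from pretrained SAM/MedSAM weights, then overwrites them with a training checkpoint saved as `{"model": ..., "epoch": ...}`.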

In [None]:
# Method 1: Load model with checkpoint directly (for older checkpoints)
# Build the ViT-B architecture without weights, then load the checkpoint ourselves
# with map_location so a checkpoint saved from a GPU also loads on a CPU-only machine.
MedSAM_CKPT_PATH = "path/to/medsam_model_best.pth"
device = "cuda:0" if torch.cuda.is_available() else "cpu"

medsam_model = sam_model_registry['vit_b'](checkpoint=None)
state_dict = torch.load(MedSAM_CKPT_PATH, map_location=device)
# Training checkpoints wrap the weights as {"model": ..., "epoch": ...} (see Method 2);
# plain weight files are the state dict itself.
if isinstance(state_dict, dict) and "model" in state_dict:
    state_dict = state_dict["model"]
medsam_model.load_state_dict(state_dict)
medsam_model = medsam_model.to(device)
medsam_model.eval()

print(f"Model loaded successfully on {device}")

In [None]:
# Method 2: Load model with separate pretrained weights and trained checkpoint
# (Uncomment if you need this method)

# MedSAM_CKPT_PATH = "path/to/medsam_model_best.pth"
# sam_pretrain_path = "path/to/medsam_vit_b.pth"
# device = "cuda:0" if torch.cuda.is_available() else "cpu"

# medsam_model = sam_model_registry['vit_b'](checkpoint=sam_pretrain_path)

# checkpoint = torch.load(MedSAM_CKPT_PATH, map_location=device)
# start_epoch = checkpoint["epoch"] + 1
# medsam_model.load_state_dict(checkpoint["model"])

# medsam_model = medsam_model.to(device)
# medsam_model.eval()

# print(f"Model loaded successfully on {device}, trained until epoch {start_epoch}")

## Example: Single Image Inference

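The following cells load one image, resize it to the model's 1024×1024 input, define a bounding-box prompt, run the image encoder and mask decoder, and visualize and save the predicted mask.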

In [None]:
# Load and preprocess image
image_path = 'path/to/your/image.png'
img_cv = cv2.imread(image_path, cv2.IMREAD_COLOR)
# cv2.imread reports a missing/unreadable file by returning None rather than raising.
if img_cv is None:
    raise FileNotFoundError(f"Could not read image: {image_path}")

# Ensure image is 3-channel. IMREAD_COLOR already returns 3 channels, so this branch
# only matters if the imread flag above is changed to one that can return 2-D arrays.
if img_cv.ndim == 2:
    img_3c = cv2.cvtColor(img_cv, cv2.COLOR_GRAY2BGR)
else:
    img_3c = img_cv

# Convert from BGR (OpenCV's channel order) to RGB
img_3c = cv2.cvtColor(img_3c, cv2.COLOR_BGR2RGB)
H, W, _ = img_3c.shape

print(f"Original image size: {H}x{W}")

In [None]:
# Bring the RGB image to the 1024x1024 resolution the image encoder takes as input.
img_1024 = cv2.resize(img_3c, (1024, 1024), interpolation=cv2.INTER_CUBIC)

# Min-max scale to [0, 1]; the 1e-8 floor avoids division by zero on a constant image.
pix_min, pix_max = img_1024.min(), img_1024.max()
img_1024 = (img_1024 - pix_min) / np.clip(pix_max - pix_min, a_min=1e-8, a_max=None)

# HWC float array -> NCHW float32 tensor on the model's device.
img_1024_tensor = (
    torch.tensor(img_1024, dtype=torch.float32)  # (1024, 1024, 3)
    .permute(2, 0, 1)  # (3, 1024, 1024)
    .unsqueeze(0)  # (1, 3, 1024, 1024)
    .to(device)
)

print(f"Preprocessed image tensor shape: {img_1024_tensor.shape}")

In [None]:
# Example box prompt: the central half of the image, as [x0, y0, x1, y1] in
# original-image pixel coordinates (one row per box).
box_np = np.array([[W // 4, H // 4, 3 * W // 4, 3 * H // 4]])

# Map x by 1024/W and y by 1024/H so the box matches the resized 1024x1024 input.
wh_scale = np.array([W, H, W, H])
box_1024 = box_np / wh_scale * 1024

print(f"Bounding box (original scale): {box_np[0]}")
print(f"Bounding box (1024 scale): {box_1024[0]}")

In [None]:
# Encode the image once; medsam_inference takes the embedding separately from the prompt.
with torch.set_grad_enabled(False):
    image_embedding = medsam_model.image_encoder(img_1024_tensor)  # (1, 256, 64, 64)

# Decode the box prompt into a binary mask at 1024x1024, matching img_1024.
medsam_seg = medsam_inference(medsam_model, image_embedding, box_1024, H=1024, W=1024)

print(f"Segmentation mask shape: {medsam_seg.shape}")
print(f"Number of positive pixels: {medsam_seg.sum()}")

In [None]:
# Left: the untouched input. Right: the 1024x1024 model input with mask and box prompt overlaid.
fig, axes = plt.subplots(1, 2, figsize=(12, 6))
ax_orig, ax_pred = axes

ax_orig.imshow(img_3c)
ax_orig.set_title("Original Image")
ax_orig.axis("off")

ax_pred.imshow(img_1024)
show_mask(medsam_seg, ax_pred)
show_box(box_1024[0], ax_pred)
ax_pred.set_title("Predicted Mask")
ax_pred.axis("off")

plt.tight_layout()
plt.show()

In [None]:
# Optional: Save the segmentation mask at the original image resolution (H x W) so it
# overlays the input file directly. medsam_seg was decoded at 1024x1024 above; the
# resize is a no-op if it is already H x W. Nearest-neighbour keeps the mask binary.
output_path = 'output_mask.png'
mask_orig_size = cv2.resize(medsam_seg, (W, H), interpolation=cv2.INTER_NEAREST)  # cv2 size is (width, height)
# cv2.imwrite signals failure (bad directory/extension) by returning False, not raising.
if not cv2.imwrite(output_path, mask_orig_size * 255):  # scale {0, 1} -> {0, 255}
    raise IOError(f"Failed to write mask to {output_path}")
print(f"Mask saved to {output_path}")