# Toy Example: Pancreatic Medical Image Segmentation with Segment Anything (SAM)

This notebook demonstrates how to apply Meta AI's Segment Anything Model (SAM) to a single CT slice. SAM performs promptable segmentation rather than semantic segmentation: given one foreground point, it predicts several candidate masks for the structure at that point and does not assign class labels.

In [None]:
!pip install opencv-python matplotlib torch torchvision
!pip install git+https://github.com/facebookresearch/segment-anything.git

In [None]:
import sys
sys.path.append('/Users/olihiidikwu/Desktop/segment-anything')  # Local path to the SAM repo; only needed if the pip install above was skipped

In [None]:
import torch
import numpy as np
import cv2
import matplotlib.pyplot as plt
from segment_anything import sam_model_registry, SamPredictor

# Load SAM model
sam_checkpoint = "/Users/olihiidikwu/Desktop/segment-anything/sam_vit_b_01ec64.pth"
model_type = "vit_b"
device = "cuda" if torch.cuda.is_available() else "cpu"

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

We load the ViT-B variant of the Segment Anything Model (SAM), the smallest of the three released backbones (`vit_b`, `vit_l`, `vit_h`). The model is initialized from a pretrained checkpoint and moved to the GPU if one is available, otherwise it stays on the CPU.

In [None]:
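# Load the sample CT slice and convert from OpenCV's BGR order to RGB
image_path = "/Users/olihiidikwu/Desktop/segment-anything/sample.png"
image = cv2.imread(image_path)
if image is None:  # cv2.imread returns None instead of raising on a bad path
    raise FileNotFoundError(f"Could not read image at {image_path}")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)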

plt.imshow(image)
plt.title("Input Image (CT Slice)")
plt.axis('off')
plt.show()

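We load a CT image (`sample.png`) representing an axial abdominal slice. Although the slice is grayscale, `cv2.imread` reads it as a three-channel image by default, which matches the H x W x 3 `uint8` RGB input that `SamPredictor.set_image` expects. This image is the input for the segmentation task.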
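If the slice is available as raw Hounsfield units rather than a PNG (for example from a DICOM or NIfTI volume, which this notebook does not load), it needs to be windowed and converted to three-channel `uint8` before being passed to SAM. The following is a minimal sketch of that step. The function name `window_ct_to_rgb` is hypothetical, and the default window (center 40, width 400) is an assumed soft-tissue setting, not a value taken from this notebook.

```python
import numpy as np

def window_ct_to_rgb(hu_slice, center=40.0, width=400.0):
    """Clip a 2D CT slice (Hounsfield units) to a window and return an HxWx3 uint8 array.

    The default center/width are an assumed soft-tissue window; adjust for your data.
    """
    lo = center - width / 2.0
    hi = center + width / 2.0
    clipped = np.clip(np.asarray(hu_slice, dtype=np.float32), lo, hi)
    # Rescale the window linearly to the 0-255 range
    scaled = (clipped - lo) / (hi - lo) * 255.0
    gray = np.round(scaled).astype(np.uint8)
    # Replicate the single channel three times to mimic an RGB image
    return np.stack([gray, gray, gray], axis=-1)

# Synthetic 2x2 slice: air, lower window bound, mid-window tissue, upper window bound
hu_demo = np.array([[-1000.0, -160.0], [40.0, 240.0]])
rgb_demo = window_ct_to_rgb(hu_demo)
print(rgb_demo.shape, rgb_demo.dtype)
```

The resulting array could then be passed to `predictor.set_image` in place of the PNG loaded above.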

In [None]:
# Initialize the predictor
predictor = SamPredictor(sam)
predictor.set_image(image)

# Define a single point prompt in (x, y) pixel coordinates
input_point = np.array([[300, 300]])  # Modify to lie on the structure of interest
input_label = np.array([1])  # 1 = foreground, 0 = background

# Predict masks
masks, scores, logits = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    multimask_output=True,
)

# Display each predicted mask
for i, mask in enumerate(masks):
    plt.figure()
    plt.imshow(image)
    plt.imshow(mask, alpha=0.5)
    plt.title(f"Mask {i+1} — Score: {scores[i]:.3f}")
    plt.axis('off')
    plt.show()

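We use a single (x, y) coordinate as the segmentation prompt. With `multimask_output=True`, SAM returns three candidate masks, each with a score that is the model's own estimate of mask quality (a predicted IoU), not a calibrated probability. The masks are overlaid on the original image for visualization.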
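When only one mask is needed downstream, a common choice is to keep the candidate with the highest score. The sketch below illustrates this on synthetic arrays shaped like the output of `predictor.predict`. The helper name `select_best_mask` and the demo values are hypothetical and only serve as an illustration.

```python
import numpy as np

def select_best_mask(masks, scores):
    """Return (mask, score, index) for the candidate with the highest score."""
    best_idx = int(np.argmax(scores))
    return masks[best_idx], float(scores[best_idx]), best_idx

# Synthetic stand-ins: three 4x4 boolean masks and their scores
demo_masks = np.zeros((3, 4, 4), dtype=bool)
demo_masks[1, 1:3, 1:3] = True
demo_scores = np.array([0.71, 0.93, 0.85])

best_mask, best_score, best_idx = select_best_mask(demo_masks, demo_scores)
print(best_idx, best_score)
```

In the notebook itself, the same helper could be called as `select_best_mask(masks, scores)`. The highest score does not guarantee the anatomically correct mask, so a visual check is still worthwhile.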

### 📌 Next Steps
- Experiment with different prompt points to refine results
- Test SAM on a batch of medical images
- Integrate medical-specific pre-processing (e.g., windowing)
- Compare SAM masks with ground truth labels (if available)
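
For the last step in the list, overlap metrics such as the Dice coefficient and intersection over union (IoU) are a standard way to compare a predicted mask with a ground-truth label. The following is a minimal sketch on synthetic masks. The function name `dice_and_iou` is hypothetical, and returning 1.0 when both masks are empty is an assumed convention.

```python
import numpy as np

def dice_and_iou(pred_mask, gt_mask):
    """Compute (Dice, IoU) between two binary masks of identical shape."""
    pred = np.asarray(pred_mask).astype(bool)
    gt = np.asarray(gt_mask).astype(bool)
    if pred.shape != gt.shape:
        raise ValueError("pred_mask and gt_mask must have the same shape")
    intersection = int(np.logical_and(pred, gt).sum())
    union = int(np.logical_or(pred, gt).sum())
    total = int(pred.sum()) + int(gt.sum())
    if total == 0:
        # Both masks empty: treated as perfect agreement (assumed convention)
        return 1.0, 1.0
    return 2.0 * intersection / total, intersection / union

# Synthetic example: a 2x2 prediction inside a 2x3 ground-truth region
pred_demo = np.zeros((4, 4), dtype=bool)
pred_demo[1:3, 1:3] = True
gt_demo = np.zeros((4, 4), dtype=bool)
gt_demo[1:3, 1:4] = True

dice, iou = dice_and_iou(pred_demo, gt_demo)
print(f"Dice: {dice:.3f}, IoU: {iou:.3f}")
```

With real data, `pred_mask` would be one of the masks returned by SAM and `gt_mask` a label mask of the same height and width as the input slice.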

### 🔬 Remarks
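SAM was not trained specifically on medical images. It can still produce plausible masks on CT slices, but this notebook does not measure accuracy, so results should be validated against ground truth before drawing conclusions. Future work may involve fine-tuning SAM on medical segmentation datasets or combining it with domain-specific techniques.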