<a href="https://colab.research.google.com/github/NielsRogge/Transformers-Tutorials/blob/master/SAM/Run_inference_with_MedSAM_using_HuggingFace_Transformers.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Segment medical images with MedSAM

In this notebook, we're going to perform inference with [MedSAM](https://twitter.com/BoWang87/status/1650690625317007361), a version of Meta AI's SAM (Segment Anything Model) fine-tuned on medical images, which greatly improves its performance in that domain.

* [Original repo](https://github.com/bowang-lab/medsam)
* [Hugging Face docs](https://huggingface.co/docs/transformers/main/en/model_doc/sam)

## Set-up environment

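We'll start by installing 🤗 Transformers from source (SAM was brand new when this notebook was written) and 🤗 Datasets. If TensorFlow with Keras 3 is also installed, importing SAM from Transformers can fail with an error asking you to install the backwards-compatible `tf-keras` package (`pip install tf-keras`).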

## Load model and processor

Next, let's load the MedSAM model and its corresponding processor. The model is available on the 🤗 hub here: https://huggingface.co/wanglab/medsam-vit-base.

In [1]:
import torch
from transformers import SamModel, SamProcessor

device = "cuda" if torch.cuda.is_available() else "cpu"

# load the MedSAM checkpoint (not the generic SAM one)
processor = SamProcessor.from_pretrained("wanglab/medsam-vit-base")
model = SamModel.from_pretrained("wanglab/medsam-vit-base").to(device)

## Load image + ground truth map

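Next, let's load an image from the dataset that the MedSAM authors collected, with its ground truth annotations. I've uploaded a small subset (130 images + segmentation maps) to the 🤗 hub; in this notebook, however, the images are read from a local `demo_set/images` folder with the 🤗 Datasets `imagefolder` loader, and each example carries its bounding boxes under `objects["bbox"]`.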

In [None]:
from datasets import load_dataset

dataset = load_dataset("imagefolder", data_dir="demo_set/images", split="train")
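
The `imagefolder` loader only exposes extra fields such as `objects` when the folder also contains a metadata file. Below is a minimal sketch of how a `metadata.jsonl` with boxes could be written, assuming the boxes are stored as pixel `[x_min, y_min, x_max, y_max]` lists (the format the code below relies on). The helper `write_metadata`, the file name, and the box values are hypothetical:

```python
import json
import tempfile
from pathlib import Path


def write_metadata(folder, annotations):
    """Write an imagefolder-style metadata.jsonl (hypothetical helper).

    annotations maps an image file name to a list of [x_min, y_min, x_max, y_max] boxes.
    """
    path = Path(folder) / "metadata.jsonl"
    with path.open("w") as f:
        for file_name, boxes in annotations.items():
            # one JSON object per line; "file_name" links the row to its image file
            record = {"file_name": file_name, "objects": {"bbox": boxes}}
            f.write(json.dumps(record) + "\n")
    return path


# usage with made-up annotations in a throwaway directory
with tempfile.TemporaryDirectory() as tmp:
    meta_path = write_metadata(tmp, {"img_000.png": [[12, 30, 80, 95]]})
    lines = meta_path.read_text().splitlines()

print(lines[0])
```

With such a file placed next to the images, `load_dataset("imagefolder", data_dir=...)` should return each example's boxes under `objects["bbox"]`.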

In [None]:
import numpy as np
from PIL import Image

idx = 27

# load image
image = dataset[idx]["image"]
IMAGE_ORIGINAL_W, IMAGE_ORIGINAL_H = image.size
image

In [None]:
# resize to 256 x 256; the boxes must be rescaled into the same coordinate frame
image = image.resize((256, 256), Image.BICUBIC)
SCALE_X, SCALE_Y = 256 / IMAGE_ORIGINAL_W, 256 / IMAGE_ORIGINAL_H

# boxes are [x_min, y_min, x_max, y_max] in original pixel coordinates
input_boxes = [
    [x1 * SCALE_X, y1 * SCALE_Y, x2 * SCALE_X, y2 * SCALE_Y]
    for x1, y1, x2, y2 in dataset[idx]["objects"]["bbox"]
]
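
Resizing a W × H image to 256 × 256 maps a pixel (x, y) to (x · 256/W, y · 256/H), and dividing by the same factors maps it back; the visualization step later uses that inverse to report boxes in original-image coordinates. A small standalone sketch of both directions (the helper names `to_resized` and `to_original` are our own):

```python
def to_resized(box, orig_w, orig_h, size=256):
    """Map an [x_min, y_min, x_max, y_max] box from the original image to a size x size image."""
    sx, sy = size / orig_w, size / orig_h
    x1, y1, x2, y2 = box
    return [x1 * sx, y1 * sy, x2 * sx, y2 * sy]


def to_original(box, orig_w, orig_h, size=256):
    """Inverse of to_resized: map a box from the size x size image back to original pixels."""
    sx, sy = size / orig_w, size / orig_h
    x1, y1, x2, y2 = box
    return [x1 / sx, y1 / sy, x2 / sx, y2 / sy]


# a 512 x 1024 (W x H) image: x coordinates are halved, y coordinates quartered
resized = to_resized([100, 200, 300, 400], orig_w=512, orig_h=1024)
print(resized)
```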

In [None]:
from PIL import ImageDraw

# draw the rescaled boxes on a copy of the resized image to check them
test = image.copy()
imgd = ImageDraw.Draw(test)

for x1, y1, x2, y2 in input_boxes:
    print(x1, y1, x2, y2)
    imgd.rectangle([x1, y1, x2, y2])

In [None]:
test

Where a ground-truth segmentation is available, it is a 2D NumPy array with 1s marking the region of interest, so it can be displayed directly with `plt.imshow` or overlaid on the image.
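
A minimal overlay sketch using NumPy only: it turns a binary mask into an RGBA image that is transparent outside the region of interest, which can then be drawn on top of the image with `plt.imshow`. The helper name `mask_to_rgba`, the default color, and the opacity are our own choices:

```python
import numpy as np


def mask_to_rgba(mask, color=(30, 144, 255), alpha=0.6):
    """Convert a 2D binary mask into an H x W x 4 float RGBA image with values in [0, 1].

    Pixels where the mask is nonzero get `color` with opacity `alpha`;
    all other pixels stay fully transparent.
    """
    mask = np.asarray(mask).astype(bool)
    rgba = np.zeros(mask.shape + (4,), dtype=np.float32)
    rgba[mask, :3] = np.array(color, dtype=np.float32) / 255.0
    rgba[mask, 3] = alpha
    return rgba


# toy 4 x 4 mask with a 2 x 2 region of interest
toy_mask = np.zeros((4, 4), dtype=np.uint8)
toy_mask[1:3, 1:3] = 1
overlay = mask_to_rgba(toy_mask)
# with matplotlib: plt.imshow(image); plt.imshow(overlay)
```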

## Load box prompt and predict

The MedSAM authors prompt the model with a bounding box derived from the ground-truth segmentation, and the model then generates a segmentation mask inside that box. Here we use the boxes from the dataset annotations, which were rescaled above to the 256 × 256 image.

We can prepare the inputs for the model and perform a forward pass. We move the inputs and model to the GPU if it's available.
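
Deriving a box from a ground-truth mask, as the MedSAM authors do, amounts to taking the extreme row and column indices of the foreground pixels. A NumPy sketch (the helper name `mask_to_box` is ours; it returns a tight box with no padding):

```python
import numpy as np


def mask_to_box(mask):
    """Return the tight [x_min, y_min, x_max, y_max] box around the nonzero pixels of a 2D mask.

    NumPy indexes images as [row, column], i.e. [y, x], so rows give y and columns give x.
    Returns None for an empty mask.
    """
    ys, xs = np.nonzero(mask)
    if xs.size == 0:
        return None
    return [int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max())]


toy_mask = np.zeros((6, 8), dtype=np.uint8)
toy_mask[2:5, 3:7] = 1  # rows 2..4 (y), columns 3..6 (x)
print(mask_to_box(toy_mask))  # [3, 2, 6, 4]
```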

In [None]:
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(input_boxes)

# prepare image + box prompt for the model
inputs = processor(image, input_boxes=[input_boxes], return_tensors="pt").to(device)
for k, v in inputs.items():
    print(k, v.shape)
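
The processor resizes the image so that its longest side is 1024 pixels and rescales the box coordinates by the same factors, which is why the boxes must be expressed in the frame of the image passed to it (here 256 × 256). The sketch below mimics that coordinate mapping in plain Python for illustration; it is our reimplementation under that assumption, not the library's code, and the helper names are hypothetical:

```python
def preprocess_shape(old_h, old_w, longest_edge=1024):
    """Image size after resizing the longest side to `longest_edge`, rounded to the nearest int."""
    scale = longest_edge / max(old_h, old_w)
    return int(old_h * scale + 0.5), int(old_w * scale + 0.5)


def rescale_box(box, old_h, old_w, longest_edge=1024):
    """Rescale an [x_min, y_min, x_max, y_max] box into the resized image's frame."""
    new_h, new_w = preprocess_shape(old_h, old_w, longest_edge)
    x1, y1, x2, y2 = box
    return [x1 * new_w / old_w, y1 * new_h / old_h, x2 * new_w / old_w, y2 * new_h / old_h]


# our 256 x 256 image is upscaled 4x, so every box coordinate is multiplied by 4
print(rescale_box([10, 20, 30, 40], old_h=256, old_w=256))
```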

In [None]:
model.to(device)

# forward pass
# note that the authors use `multimask_output=False` when performing inference
with torch.no_grad():
    outputs = model(**inputs, multimask_output=False)
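
With `multimask_output=True`, SAM predicts three candidate masks per prompt together with a predicted IoU score for each (returned as `outputs.iou_scores`), and a common choice is to keep the highest-scoring candidate; with `False`, as the MedSAM authors use, a single mask per box is returned. A NumPy sketch of that selection step, assuming `pred_masks` of shape `(batch, num_boxes, 3, H, W)` and `iou_scores` of shape `(batch, num_boxes, 3)`; the helper `select_best_masks` is hypothetical:

```python
import numpy as np


def select_best_masks(pred_masks, iou_scores):
    """Keep, for every prompt, the candidate mask with the highest predicted IoU.

    pred_masks: (batch, num_boxes, num_candidates, H, W)
    iou_scores: (batch, num_boxes, num_candidates)
    returns:    (batch, num_boxes, H, W)
    """
    best = iou_scores.argmax(axis=-1)   # (batch, num_boxes)
    idx = best[..., None, None, None]   # (batch, num_boxes, 1, 1, 1), broadcasts over H, W
    return np.take_along_axis(pred_masks, idx, axis=2)[:, :, 0]


rng = np.random.default_rng(0)
masks = rng.normal(size=(1, 2, 3, 4, 4))
scores = np.array([[[0.1, 0.9, 0.3], [0.7, 0.2, 0.4]]])
best_masks = select_best_masks(masks, scores)
```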

In [None]:
outputs.pred_masks.shape

In [None]:
import matplotlib.pyplot as plt

pred_maskses = outputs.pred_masks
pred_maskses = torch.reshape(pred_maskses, (-1, 256, 256))
pred_maskses = torch.sigmoid(pred_maskses)
pred_maskses = (pred_maskses > 0.5)
pred_maskses = pred_maskses.cpu().numpy()

for i in range(len(pred_maskses)):
    plt.figure()
    plt.imshow(pred_maskses[i])

In [None]:
image

Note that MedSAM was fine-tuned using a custom [DiceWithSigmoid loss](https://github.com/bowang-lab/MedSAM/blob/66cf4799a9ab9a8e08428a5087e73fc21b2b61cd/train.py#L70), which is why the postprocessing above applies a sigmoid to the predicted logits and thresholds the result at 0.5, following the authors' [postprocessing](https://github.com/bowang-lab/MedSAM/blob/66cf4799a9ab9a8e08428a5087e73fc21b2b61cd/MedSAM_Inference.py#L67).
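
Because the sigmoid is strictly increasing and σ(0) = 0.5, thresholding σ(logit) at 0.5 gives the same binary mask as thresholding the raw logits at 0. A quick NumPy check of that equivalence on made-up logits (the helper name is ours):

```python
import numpy as np


def binarize_with_sigmoid(logits, threshold=0.5):
    """Same steps as the postprocessing cell: sigmoid, then threshold the probabilities."""
    probs = 1.0 / (1.0 + np.exp(-logits))
    return probs > threshold


logits = np.array([[-3.0, -0.1, 0.0], [0.1, 2.5, -7.0]])
via_sigmoid = binarize_with_sigmoid(logits)
via_logits = logits > 0
print(via_sigmoid)
```

As far as we can tell, this is also why the Transformers SAM image processor's `post_process_masks` binarizes logits with a default `mask_threshold` of 0.0.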

## Visualize

Let's locate each predicted mask: compute its bounding box and center, crop the image to that box, and mark the centers on the image:

In [None]:
from PIL import ImageDraw

fig, axes = plt.subplots()

# the image is already 256 x 256; convert to RGBA so we can draw colored markers on it
image_PIL_square = image.convert("RGBA")
image_arr = np.array(image_PIL_square)

bbox_centers = []
# maps each mask's bbox center (256 x 256 coordinates) to the image crop inside that bbox
center_to_cropimg = dict()

for i in range(len(pred_maskses)):
    # row (y) and column (x) indices of all foreground pixels
    ys, xs = np.nonzero(pred_maskses[i])
    if xs.size == 0:
        print(f"mask {i} is empty, skipping")
        continue
    min_x, max_x = int(xs.min()), int(xs.max())
    min_y, max_y = int(ys.min()), int(ys.max())

    center = ((min_x + max_x) / 2, (min_y + max_y) / 2)
    bbox_centers.append(list(center))

    # bounding box in original-image pixel coordinates
    print(min_x / SCALE_X, min_y / SCALE_Y, max_x / SCALE_X, max_y / SCALE_Y)

    # NumPy indexes images as [row, column] = [y, x]; +1 keeps the max edge inclusive
    center_to_cropimg[center] = image_arr[min_y:max_y + 1, min_x:max_x + 1, :]

print(bbox_centers)

# mark each bounding-box center with a small yellow dot (drawn on the displayed image)
imgd = ImageDraw.Draw(image_PIL_square)
for cx, cy in bbox_centers:
    imgd.ellipse([cx - 2, cy - 2, cx + 2, cy + 2], fill="#ffff33")

axes.imshow(image_PIL_square)
axes.axis("off")

In [None]:
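# TODO: generate robot trajectories, assuming that top left corner is (0,0)

ORIGIN_X = 0  # not used yet
ORIGIN_Y = 0

# bbox_centers are in 256 x 256 coordinates: map back to original pixels,
# then divide by the original size to get normalized values in [0, 1]
for cx, cy in bbox_centers:
    print((cx / SCALE_X) / IMAGE_ORIGINAL_W, (cy / SCALE_Y) / IMAGE_ORIGINAL_H)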

In [1]:
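import torch
from transformers import pipeline

# generic SAM checkpoint (not MedSAM) for automatic mask generation on each crop;
# device=0 is the first GPU, -1 falls back to the CPU
generator = pipeline(
    "mask-generation",
    model="facebook/sam-vit-base",
    device=0 if torch.cuda.is_available() else -1,
)

for center_pt, crop in center_to_cropimg.items():
    print(f"generating mask for center: {center_pt}")
    mask_outputs = generator(Image.fromarray(crop), points_per_batch=64)
    print(mask_outputs)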
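
Each call to the `mask-generation` pipeline returns, as far as we can tell, a dict whose `"masks"` entry holds one binary mask per detected object and whose `"scores"` entry holds the matching predicted quality scores. A NumPy sketch for keeping only the best mask per crop, run on made-up stand-in data rather than real pipeline output; the helper `best_mask` is hypothetical:

```python
import numpy as np


def best_mask(result):
    """Return (mask, score) for the highest-scoring entry of a mask-generation result dict.

    Assumes result["masks"] is a sequence of 2D binary arrays and result["scores"]
    a sequence of floats of the same length. Returns (None, None) if nothing was found.
    """
    scores = np.asarray(result["scores"], dtype=float)
    if scores.size == 0:
        return None, None
    i = int(scores.argmax())
    return np.asarray(result["masks"][i]), float(scores[i])


# stand-in for one pipeline result (not real output)
fake_result = {
    "masks": [np.zeros((3, 3), dtype=bool), np.ones((3, 3), dtype=bool)],
    "scores": [0.42, 0.97],
}
mask, score = best_mask(fake_result)
```

If `"scores"` comes back as a torch tensor on the GPU, move it to the CPU (`.cpu()`) before passing the result to this helper.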