<a href="https://colab.research.google.com/github/nye0/SAM-Med2D/blob/main/predictor_example.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Environment Set-up
Adapted from the [SAM Colab notebook](https://colab.research.google.com/github/facebookresearch/segment-anything/blob/main/notebooks/automatic_mask_generator_example.ipynb#scrollTo=MTeAdX_mHwAR).



If you're running this notebook locally using Jupyter, please clone `SAM-Med2D` into a directory named `SAM_Med2D`. Do **not** install the upstream `segment_anything` package in your local environment: this notebook imports `segment_anything` from the `SAM-Med2D` checkout, and `SAM-Med2D` and `SAM` share module and function names that could lead to conflicts.

For Google Colab users: Set `using_colab=True` in the cell below before executing it. Although you can select 'GPU' under 'Edit' -> 'Notebook Settings' -> 'Hardware Accelerator', this notebook is designed to run efficiently in a CPU environment as well.



# SAM-Med2D generates predicted object masks based on prompts.

In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
import sys

In [None]:
def show_mask(mask, ax, random_color=False):
    """Overlay a mask on ``ax`` as a half-transparent RGBA image.

    Args:
        mask: Array whose last two dimensions are the mask height and width;
            it must contain exactly ``height * width`` elements.
        ax: Axes-like object; only ``ax.imshow`` is called.
        random_color: Despite the name, nothing is randomised. ``True``
            selects green, ``False`` selects red, both at 50 % opacity.
    """
    rgba = np.array([0, 1, 0, 0.5]) if random_color else np.array([1, 0, 0, 0.5])
    height, width = mask.shape[-2:]
    # Broadcasting (H, W, 1) against (1, 1, 4) colours every non-zero pixel.
    overlay = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)
    ax.imshow(overlay)

def show_points(coords, labels, ax, marker_size=100):
    """Scatter prompt points on ``ax``: label 1 in green, label 0 in red.

    Args:
        coords: Array of shape ``(N, 2)``; column 0 is plotted as x and
            column 1 as y.
        labels: Array of length ``N`` used to select rows of ``coords``.
        ax: Axes-like object; only ``ax.scatter`` is called (twice, positive
            points first).
        marker_size: Marker area forwarded as ``s``.
    """
    for label_value, colour in ((1, 'green'), (0, 'red')):
        selected = coords[labels == label_value]
        ax.scatter(selected[:, 0], selected[:, 1], color=colour, marker='.', s=marker_size, edgecolor='white', linewidth=1.25)

def show_box(box, ax):
    """Draw an unfilled green rectangle on ``ax``.

    Args:
        box: Indexable with four entries read as ``(x0, y0, x1, y1)``.
        ax: Axes-like object; only ``ax.add_patch`` is called.
    """
    left, top = box[0], box[1]
    width = box[2] - box[0]
    height = box[3] - box[1]
    # Fully transparent face so only the outline is visible.
    outline = plt.Rectangle((left, top), width, height, edgecolor='green', facecolor=(0,0,0,0), lw=2)
    ax.add_patch(outline)


In [None]:
import SimpleITK as sitk
from skimage.measure import label, regionprops

In [None]:
def visualize(slce, predict_mask, gt, points, labels, fn):
    """Display a slice with ground truth, prediction and point prompts.

    The ground truth is overlaid in green, the prediction in red, and the
    prompts via ``show_points``. ``fn`` is accepted but unused: the
    ``plt.savefig`` call that would have written ``output/{fn}`` is disabled,
    so the figure is only shown.
    """
    plt.figure(figsize=(5,5))
    plt.imshow(slce, cmap='gray')
    axes = plt.gca()
    show_mask(gt, axes, random_color=True)
    show_mask(predict_mask, axes)
    show_points(points, labels, axes)
    plt.axis('off')
    plt.show()

In [None]:
def visualize(slce, predict_mask, gt, box, fn):
    """Display a slice with only the predicted mask overlaid in red.

    ``gt``, ``box`` and ``fn`` are accepted but unused: the ground-truth
    overlay, the box outline and the ``plt.savefig`` call are all disabled.

    NOTE(review): running this cell replaces the earlier six-argument
    ``visualize(slce, predict_mask, gt, points, labels, fn)``. While this
    definition is active, any call that passes points *and* labels raises
    ``TypeError``.
    """
    plt.figure(figsize=(5,5))
    plt.imshow(slce, cmap='gray')
    axes = plt.gca()
    show_mask(predict_mask, axes)
    plt.axis('off')
    plt.show()

## Example image

## Load SAM-Med2D model

In [None]:
from segment_anything import sam_model_registry
from segment_anything.predictor_sammed import SammedPredictor
from argparse import Namespace
# Build the first predictor from the "sam-med2d_b.pth" checkpoint
# (image_size=256, encoder adapter enabled) and move it to GPU when available.
args = Namespace(
    image_size=256,
    encoder_adapter=True,
    sam_checkpoint="sam-med2d_b.pth",
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = sam_model_registry["vit_b"](args).to(device)
predictor = SammedPredictor(model)

In [None]:
# Build a second predictor from the "sam_vit_b_01ec64.pth" checkpoint
# (image_size=1024, encoder adapter disabled).
#
# A separate namespace (`args2`) is used on purpose: `args` must keep
# describing the model behind `predictor`, because the multi-box cell further
# down reads `args.image_size` when transforming boxes for `predictor`.
# NOTE(review): the checkpoint path is machine-specific; adjust it locally.
args2 = Namespace()
args2.image_size = 1024
args2.encoder_adapter = False
args2.sam_checkpoint = "/volume/willy-dev/sota/SAM-Med2D/sam_vit_b_01ec64.pth"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model2 = sam_model_registry["vit_b"](args2).to(device)
predictor2 = SammedPredictor(model2)

Process the image to produce an image embedding by calling `SammedPredictor.set_image`. `SammedPredictor` remembers this embedding and will use it for subsequent mask prediction.

In [None]:
import glob
from tqdm import tqdm
import os
from collections import defaultdict

In [None]:
# Inspection cell: list every entry directly under the gen_data directory.
glob.glob('/volume/open-dataset-ssd/ai99/gen_data/*')

In [None]:
def norm_slce(slce):
    """Min-max normalise a 2-D slice to ``uint8`` and replicate it to 3 channels.

    The input is never modified: all arithmetic happens on a private copy.

    Args:
        slce: 2-D ``numpy.ndarray`` or ``torch.Tensor`` (any device). Integer
            inputs are promoted to ``float64``; floating inputs keep their
            dtype.

    Returns:
        ``numpy.ndarray`` of dtype ``uint8`` and shape ``(H, W, 3)`` with
        values scaled so the minimum maps to 0 and the maximum to 255. A
        constant slice (no dynamic range) yields all zeros.
    """
    if torch.is_tensor(slce):
        slce = slce.detach().cpu().numpy()
    else:
        slce = np.asarray(slce)

    work_dtype = slce.dtype if np.issubdtype(slce.dtype, np.floating) else np.float64
    # np.array copies, so the in-place steps below cannot leak into the caller
    # (a CPU tensor's .numpy() view would otherwise share memory).
    scaled = np.array(slce, dtype=work_dtype)

    scaled -= scaled.min()
    peak = scaled.max()
    if peak > 0:  # skip the division for constant slices to avoid 0/0 -> NaN
        scaled /= peak
    scaled *= 255

    channel = scaled.astype(np.uint8)
    return np.stack([channel, channel, channel], axis=2)

def get_side_pred(predictor, pred, img, gt, rot_point, centroid, offset):
    """Segment one side-view slice with point prompts and write it into ``pred``.

    The slice ``img[:, z]`` (``z = int(rot_point[1])``) is normalised, passed
    to ``predictor`` and prompted with a positive point at
    ``(int(rot_point[0]), centroid[0])``. When ``gt[centroid[0], z]`` has
    non-zero entries, two more positive points are added 5 pixels inside the
    first and last non-zero index. The first returned mask is copied into
    ``pred`` at indices ``z-1``, ``z`` and ``z+1`` of dimension 1.

    Args:
        predictor: Object exposing ``set_image`` and ``predict``.
        pred: 3-D tensor updated in place and returned.
        img: 3-D tensor the slice is taken from (previously the global
            ``rot_img`` was read by mistake and this argument was ignored).
        gt: 3-D tensor used only to derive the extra prompts.
            Assumes the same shape as ``img`` — TODO confirm with callers.
        rot_point: Pair whose first entry is the prompt x coordinate and whose
            second entry selects the slice.
        centroid: Sequence whose first entry is the prompt y coordinate.
        offset: Unused (only referenced by a disabled debug visualisation).

    Returns:
        ``pred`` with the neighbouring slices overwritten.
    """
    z = int(rot_point[1])
    row = centroid[0]

    slce = norm_slce(img[:, z])
    predictor.set_image(slce)

    input_point = [[int(rot_point[0]), row]]
    input_label = [1]

    proj = torch.nonzero(gt[row, z])
    if proj.shape[0] > 0:
        proj_min, proj_max = int(proj.min()), int(proj.max())
        input_point += [[proj_min + 5, row], [proj_max - 5, row]]
        input_label += [1, 1]

    masks, scores, logits = predictor.predict(
        point_coords=np.array(input_point),
        point_labels=np.array(input_label),
        multimask_output=True,
    )

    # Only write to neighbours that exist: z-1 would otherwise wrap around to
    # the last slice at z == 0 and z+1 would overflow at the far edge.
    targets = [k for k in (z - 1, z, z + 1) if 0 <= k < pred.shape[1]]

    # Match pred's device and dtype instead of assuming CUDA.
    best = torch.as_tensor(masks[0]).to(device=pred.device, dtype=pred.dtype)
    pred[:, targets] = best.unsqueeze(1).repeat(1, len(targets), 1)

    return pred

import math
from torchvision.transforms.functional import rotate, InterpolationMode
import torch.nn.functional as F

def rotate_(origin, point, angle):
    """
    Rotate ``point`` counterclockwise about ``origin``.

    ``origin`` and ``point`` are ``(x, y)`` pairs and ``angle`` is in radians.
    Returns the rotated point as a ``(qx, qy)`` tuple.
    """
    ox, oy = origin
    px, py = point

    # Offset of the point relative to the rotation centre.
    dx, dy = px - ox, py - oy
    cos_a, sin_a = math.cos(angle), math.sin(angle)

    qx = ox + cos_a * dx - sin_a * dy
    qy = oy + sin_a * dx + cos_a * dy
    return qx, qy

In [None]:
# Metric accumulators. `dicesb` gets one aggregated Dice value per connected
# component, `dicesc` one Dice value per prompted slice. `dices`, `mms` and
# `mmms` are not filled here but are read by later cells.
dices = []
dicesb = []
dicesc = []
mms = []
mmms = []

# Run the tensor work on the GPU when present, otherwise on the CPU.
loop_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _oblique_view(volume, degree, interpolation=InterpolationMode.NEAREST):
    """Rotate a 3-D tensor by ``degree`` in the plane of its first and last axes.

    The first two axes are swapped, the trailing two dimensions are
    zero-padded to a square, each plane is rotated, and the axes are swapped
    back. Assumes (Z, Y, X) ordering as produced by
    ``sitk.GetArrayFromImage`` — TODO confirm.
    """
    zx = volume.permute(1, 0, 2).contiguous()

    long_length = max(zx.shape[1:])
    pad_h1 = (long_length - zx.shape[1]) // 2
    pad_h2 = long_length - zx.shape[1] - pad_h1
    pad_w1 = (long_length - zx.shape[2]) // 2
    pad_w2 = long_length - zx.shape[2] - pad_w1

    zx = F.pad(zx, (pad_w2, pad_w1, pad_h2, pad_h1))
    zx = rotate(zx, degree, interpolation=interpolation)
    return zx.permute(1, 0, 2).contiguous()


for di in tqdm(glob.glob('/volume/open-dataset-ssd/ai99/gen_data/meningioma/*')):

    try:
        img = sitk.GetArrayFromImage(sitk.ReadImage(f'{di}/axc.nii.gz'))
        mask = sitk.GetArrayFromImage(sitk.ReadImage(f'{di}/seg.nii.gz'))
    except RuntimeError as err:
        # SimpleITK reports unreadable/missing files as RuntimeError; skip the
        # case but say so instead of swallowing every exception silently.
        print(f"skipping {di}: {err}")
        continue

    # Volumes with fewer than 100 slices along the first axis are ignored.
    if img.shape[0] < 100:
        continue

    # Connected-component labelling: every component gets its own integer id.
    mask = label(mask)

    # The image-side resampling does not depend on the component, so it is
    # done once per volume rather than once per component.
    degree = math.degrees(math.atan(img.shape[0] / img.shape[1]))
    rot_img = torch.tensor(img.astype(float)).to(loop_device)
    zx_img = _oblique_view(rot_img, degree, InterpolationMode.BILINEAR)

    for prop in regionprops(mask):

        centroid = np.array(prop.centroid).astype(int)

        # Binary mask of THIS component only. Using the label image directly
        # would weight overlaps by the label id (2, 3, ...) and mix in other
        # components, which is inconsistent with `prop.area` below.
        component = (mask == prop.label).astype(float)
        rot_mask = torch.tensor(component).to(loop_device)
        rot_pred = torch.zeros_like(rot_img).float()

        # Seed the "prediction" with the ground truth of the centroid slice;
        # prompts for the oblique slices are sampled from this seed.
        rot_pred[centroid[0]] = rot_mask[centroid[0]]

        zx_mask = _oblique_view(rot_mask, degree)
        zx_pred = _oblique_view(rot_pred, degree)

        dices_tp, dices_pred = 0, 0

        # Visit only the oblique slices that the seed intersects.
        for z in zx_pred.nonzero()[:, 0].unique():

            tar_img = zx_img[z].clone()
            tar_mask = zx_mask[z]
            tar_pred = zx_pred[z]

            # Treat intensities below 1 as background, then crop to the
            # bounding box of what is left.
            tar_img[tar_img < 1] = 0
            foreground = tar_img.nonzero()
            if foreground.shape[0] == 0:
                continue
            crop_min_x, crop_min_y = foreground.min(0)[0]
            # +1 because slice ends are exclusive; without it the last
            # foreground row and column were cut off.
            crop_max_x, crop_max_y = foreground.max(0)[0] + 1

            tar_img = tar_img[crop_min_x:crop_max_x, crop_min_y:crop_max_y]
            tar_mask = tar_mask[crop_min_x:crop_max_x, crop_min_y:crop_max_y]
            tar_pred = tar_pred[crop_min_x:crop_max_x, crop_min_y:crop_max_y]

            zx_pred_nonzero = tar_pred.nonzero()

            # Too few seed pixels to place three distinct prompts.
            if zx_pred_nonzero.shape[0] < 5:
                continue

            # Three positive prompts at the 50 %, 25 % and 75 % positions of
            # the seed pixels, converted from (row, col) to (x, y).
            seed_count = zx_pred_nonzero.shape[0]
            picks = [seed_count // 2, int(seed_count * 0.25), int(seed_count * 0.75)]
            input_point = zx_pred_nonzero[picks][:, [1, 0]].cpu().numpy()
            input_label = np.array([1, 1, 1])

            slce = norm_slce(tar_img)
            predictor2.set_image(slce)
            masks, scores, logits = predictor2.predict(
                            point_coords=input_point,
                            point_labels=input_label,
                            multimask_output=False,
                        )

            gt_slice = tar_mask.cpu().numpy()

            # Requires the six-argument (points + labels) `visualize`.
            visualize(slce, masks[0], gt_slice, input_point, input_label, "")

            overlap = (gt_slice * masks[0]).sum()
            dices_tp += overlap
            dices_pred += masks[0].sum()

            # Smoothed Dice: the epsilon is added once to numerator and
            # denominator (it was previously doubled in the numerator only).
            dice = (2 * overlap + 1e-5) / (gt_slice.sum() + masks[0].sum() + 1e-5)
            dicesc.append(dice)

        dicesb_ = 2 * dices_tp / (dices_pred + prop.area)
        dicesb.append(dicesb_)

    # NOTE(review): stops after the first volume that passes the filters;
    # this looks like a debugging leftover. Remove to evaluate every volume.
    break


In [None]:
# Inspection cell: twice the rotation angle (in degrees) computed in the loop
# above for the last loaded volume; 0.0174533 approximates pi/180.
math.atan(img.shape[0]/img.shape[1])/0.0174533 * 2

In [None]:
# Inspection cell: length of the second axis of the last loaded volume.
img.shape[1]

In [None]:
# Inspection cell: shape of the padded and rotated image volume.
zx_img.shape

In [None]:
import plotly.express as px

In [None]:
# Persist the metric lists as .npy files in the current working directory.
for out_path, values in (
    ('tmp_dicesb.npy', dicesb),
    ('tmp_dicesc.npy', dicesc),
    ('tmp_mms.npy', mms),
    ('tmp_mmms.npy', mmms),
):
    np.save(out_path, values)

In [None]:
# Convert the metric lists to NumPy arrays for the statistics and plots below.
dices, dicesb, dicesc = [np.array(values) for values in (dices, dicesb, dicesc)]

In [None]:
# Mean and 20-bin histogram of the per-component Dice values.
mean_dicesb = np.mean(dicesb)
print(mean_dicesb)
plt.hist(dicesb, bins=20)
plt.show()

In [None]:
# Mean and 20-bin histogram of the per-slice Dice values.
mean_dicesc = np.mean(dicesc)
print(mean_dicesc)
plt.hist(dicesc, bins=20)
plt.show()

In [None]:
# Inspection cell: number of per-slice Dice values collected.
dicesc.shape

In [None]:
# Inspection cell: count the `dices` entries whose matching `mms` value is > 0.
# NOTE(review): neither `dices` nor `mms` is appended to in the visible cells,
# so this presumably reports an empty shape — confirm whether it is still needed.
np.array(dices)[np.array(mms) > 0].shape

In [None]:
# Show the most recently processed slice with its point prompts; the axes are
# left visible so pixel coordinates can be read off.
plt.figure(figsize=(10,10))
plt.imshow(slce)
axes = plt.gca()
show_points(input_point, input_label, axes)
plt.axis('on')
plt.show()

In [None]:
# Show the most recent slice with the predicted mask (red) and point prompts.
plt.figure(figsize=(10,10))
plt.imshow(slce)
axes = plt.gca()
show_mask(masks, axes)
show_points(input_point, input_label, axes)
plt.axis('off')
plt.show()


In [None]:
# Show the most recent slice with its ground truth and point prompts.
# `gt` was never defined at notebook scope (NameError); the cropped ground
# truth of the last loop iteration is `tar_mask`, so that is used instead.
# NOTE(review): if the loop skipped its final slice, `tar_mask` and `slce`
# may come from different slices — confirm before relying on this figure.
plt.figure(figsize=(10,10))
plt.imshow(slce)
show_mask(tar_mask.cpu().numpy(), plt.gca())
show_points(input_point, input_label, plt.gca())
plt.axis('off')
plt.show()


## Optimizing Segmentation Results by Point Interaction

In [None]:
# Add one negative point (label 0) to the existing prompts and turn the
# previous logits into a sigmoid-activated mask prompt on `device`.
input_point1 = np.array([[169, 140]])
input_label1 = np.array([0])
input_points = np.vstack((input_point, input_point1))
input_labels = np.append(input_label, input_label1)
mask_inputs = torch.as_tensor(logits, dtype=torch.float, device=device).sigmoid()

In [None]:
# Re-run the prediction with the extended point prompts plus the mask prompt.
# NOTE(review): in the visible cells the prompts and `logits` come from
# `predictor2` (set_image is only called on `predictor2` before this point),
# while this call uses `predictor` — confirm which predictor is intended and
# that an image has been set on it.
masks, scores, logits = predictor.predict(
    point_coords=input_points,
    point_labels=input_labels,
    mask_input = mask_inputs,
    multimask_output=True,
)
masks.shape  # (number_of_masks) x H x W

In [None]:

# Show the refined mask together with all point prompts.
# NOTE(review): `image` is first assigned in the bounding-box section further
# down, so running the notebook top to bottom raises NameError here; the
# prompts above were derived from `slce` — confirm which image is intended.
plt.figure(figsize=(10,10))
plt.imshow(image)
show_mask(masks, plt.gca())
show_points(input_points, input_labels, plt.gca())
plt.axis('off')
plt.show()

## Specifying a specific object with a bounding box

The model can also take a box as input, provided in xyxy format.

In [None]:
# Load the demo image, fail loudly if it is missing, and hand it to the predictor.
image_path = 'data_demo/images/s0114_111.png'
image = cv2.imread(image_path)
if image is None:
    # cv2.imread signals failure by returning None instead of raising.
    raise FileNotFoundError(f"cv2.imread could not read {image_path!r}")
# OpenCV loads BGR; plt.imshow interprets arrays as RGB.
# NOTE(review): assumes SammedPredictor.set_image expects RGB as well — confirm.
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
predictor.set_image(image)
input_box = np.array([89,43,113,64])  # xyxy

In [None]:
# Predict from the box prompt alone (no points); only the masks are kept,
# the other two return values are discarded.
masks, _, _ = predictor.predict(
    point_coords=None,
    point_labels=None,
    box=input_box,
    multimask_output=True,
)

In [None]:
# Show the demo image with the first predicted mask (red) and the box prompt.
plt.figure(figsize=(10, 10))
plt.imshow(image)
axes = plt.gca()
show_mask(masks[0], axes)
show_box(input_box, axes)
plt.axis('off')
plt.show()

In [None]:
# Release cached, currently unused GPU memory held by PyTorch's allocator.
torch.cuda.empty_cache()

## Multiple bounding box prediction results

In [None]:
# Two box prompts (xyxy format), created directly on the predictor's device.
input_boxes = torch.tensor([[72,110,136,143],[124,92,160,132]], device=predictor.device)

In [None]:
# Map the boxes from the original image size to the model input size, then
# predict one set of masks per box in a single batched call.
# NOTE(review): `args` is whichever Namespace was assigned last in this
# notebook; `args.image_size` must be the input size of the model behind
# `predictor` for the box transform to be correct — verify.
transformed_boxes = predictor.apply_boxes_torch(input_boxes, image.shape[:2], (args.image_size, args.image_size))
masks, _, _ = predictor.predict_torch(
    point_coords=None,
    point_labels=None,
    boxes=transformed_boxes,
    multimask_output=True,
)
print(transformed_boxes.shape)
print(masks.shape)  # (batch_size) x (num_predicted_masks_per_input) x H x W

In [None]:
# Show the demo image with every predicted mask (green) and every box prompt.
plt.figure(figsize=(10, 10))
plt.imshow(image)
axes = plt.gca()
for mask in masks:
    show_mask(mask.cpu().numpy(), axes, random_color=True)
for box in input_boxes:
    show_box(box.cpu().numpy(), axes)
plt.axis('off')
plt.show()

