# SMO with feature point prompts

In [1]:
import cv2
from segment_anything import build_sam, SamAutomaticMaskGenerator, SamPredictor, sam_model_registry
from PIL import Image, ImageDraw
import clip
import torch
import numpy as np
import matplotlib.pyplot as plt
import time
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
from torchvision.transforms import InterpolationMode
BICUBIC = InterpolationMode.BICUBIC

import tifffile as tiff
import os
from sklearn.metrics import jaccard_score

import json
from collections import defaultdict, Counter
import torchvision.datasets as dset
from pycocotools.coco import COCO
import xml.etree.ElementTree as ET

import os
import sys
import math

from yolo11.predict import predict_image
from util.common import show_points2, convert_box_xywh_to_xyxy, segment_image, show_anns, show_box_anns
from util.itm import retriev
from util.loss import get_scores_loss,get_indices_of_values_above_threshold_2
from util.reward import get_scores_reward,get_indices_of_values_above_threshold
from util.cal_iou import calculate_metrics
from util.cal_instance_iou import calculate_dice, calculate_miou
from util.nms import NMS
from util.featurepoint import edge, get_bz


<class 'clip.model.CLIP'>


In [2]:
time_start=time.time()

sam_checkpoint = "../checkpoints/sam_vit_h_4b8939.pth"
model_type = "vit_h"

device = "cpu"

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

mask_generator = SamAutomaticMaskGenerator(sam)
print(f"Loading time: {time.time()-time_start:.2f} s")

Loading time: 4.85 s


# Load data, e.g. CTC DIC-C2DH-HeLa

In [3]:
image_path = "../datasets/CTC/DIC-C2DH-HeLa/01/t015.tif"
image = cv2.imread(image_path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
image0 = cv2.imread(image_path)
image0 = cv2.cvtColor(image0, cv2.COLOR_BGR2RGB)

In [4]:
# init predictor
predictor = SamPredictor(sam)
predictor.set_image(image)

# Feature Points Extraction

In [5]:
# get edge contours
cnts = edge(image)

In [6]:
# automatically take the number of the most significant edge
mini = min(len(p) for p in cnts)
# merge the most significant edge
mini_batch = []
for p in cnts:
    p = np.squeeze(p, 1)
    mini_batch.append(np.random.permutation(p)[:mini])

In [7]:
# prepare the coordinates and labels
input_points = np.concatenate(mini_batch)
points = torch.Tensor(input_points).to(predictor.device).unsqueeze(1).view(len(cnts),mini,2)
labels = torch.Tensor([int(l) for _, l in input_points]).to(predictor.device).unsqueeze(1).view(len(cnts),mini)
transformed_points = predictor.transform.apply_coords_torch(points, image.shape[:2])

# predict masks
masks_p0, scores_p, logits = predictor.predict_torch(
        point_coords=transformed_points,
        point_labels=labels,
        boxes=None,
        multimask_output=False,
)
masks_p = masks_p0.cpu().detach().numpy()

# merge masks
masks = []
for i,p in enumerate(zip(masks_p,scores_p)):
    m = {}
    m["segmentation"] = p[0][0]
    m["area"] = int(p[0].sum())
    a = np.where(p[0][0]==True)
    m["bbox"] = [a[1].min(), a[0].min(), a[1].max()-a[1].min(), a[0].max()-a[0].min()]
    m["predicted_iou"] = p[1][0].item()
    m["point_coords"] = points[i].cpu().detach().numpy().tolist()
    m["stability_score"] = p[1][0].item()
    m["crop_box"] = [0, 0, image.shape[1], image.shape[0]]
    masks.append(m)

# Calculate points_count

In [8]:
bz = get_bz(cnts,image0)

In [9]:
# calculate the number of points in each mask
points_count = []
for p in masks:
    points_count.append(np.count_nonzero(p["segmentation"] & bz))
points_count = np.array(points_count)

# Filter

In [10]:
# exclude large, small area items
h0,w0,_=image0.shape
s0=w0*h0
alpha = 0.45
gamma = 3e-3
h0,w0,_=image0.shape
s0=w0*h0
masks_f = []
for i,p in enumerate(masks):
    _,_,w,h = p["bbox"]
    if w*h >= alpha*s0:
        print(str(i)+': too big')
        continue
    elif p["area"] >= alpha*s0:
        print(str(i)+': too big')
        continue
    elif p["area"] <= gamma*s0:
        print(str(i)+': too small')
        continue
    else:
        masks_f.append(p)

7: too small
9: too small
18: too small
20: too small
21: too small
29: too small
33: too small
35: too small
43: too small
47: too small
54: too small
66: too big
69: too small
78: too small
80: too small
81: too small
86: too small
90: too small
94: too small
96: too small
97: too small
98: too small
99: too small
105: too small


In [11]:
masks=masks_f

In [12]:
# Cut out all masks
image = Image.open(image_path)
cropped_boxes = []

for mask in masks:
    # crop masks from input image
    cropped_boxes.append(segment_image(image, mask["segmentation"]).crop(convert_box_xywh_to_xyxy(mask["bbox"])))

# CLIP ITM Scores

In [13]:
scores = retriev(cropped_boxes, "HeLa cell")

# Compute Reward and drop bad ones

In [14]:
scores=get_scores_reward(masks,scores,points_count,device)

In [15]:
indices = get_indices_of_values_above_threshold(scores, 0.25)

In [16]:
image = Image.open(image_path)
segmentation_masks = []
result_crops = []
result_masks = []
result_scores = []
for seg_idx in indices:
    segmentation_mask_image = Image.fromarray(masks[seg_idx]["segmentation"].astype('uint8') * 255)
    result_crop = cropped_boxes[seg_idx]
    segmentation_masks.append(segmentation_mask_image)
    result_masks.append(masks[seg_idx])
    result_scores.append(scores[seg_idx])
    result_crops.append(result_crop)


# Use NMS

In [17]:
# Concatenate targets and scores
t_scores=torch.tensor(result_scores)
b = []
for p in result_masks:
    xyxy = convert_box_xywh_to_xyxy(p["bbox"])
    b.append(xyxy)
t_boxes=torch.tensor(b)

In [18]:
n = NMS(t_boxes,t_scores,0.25,GIoU=True,eps=1e-7)
len(n)

14

In [19]:
# Select final results
n_crops = []
n_masks = []
n_scores = []
for i,v in enumerate(zip(result_crops,result_masks,result_scores)):
    if i in list(np.array(n)):
        n_crops.append(v[0])
        n_masks.append(v[1])
        n_scores.append(v[2])

# Results

In [20]:
original_image = Image.open(image_path)
array0 = np.zeros(original_image.size,dtype=bool).T
for p in n_masks:
    array0 = np.logical_or(array0, p["segmentation"])
mat_array0 = np.uint8(array0)

# Eval

In [21]:
# load gt
gt = tiff.imread(str("../datasets/CTC/DIC-C2DH-HeLa/01_ST/SEG/man_seg015.tif"))

In [22]:
gt=np.where(gt>0,np.ones_like(gt),np.zeros_like(gt))
img_gt=np.reshape(gt, (gt.shape[0], gt.shape[1])) * 255

# DICE and IoU

In [24]:
res=mat_array0.copy()
gt=gt.reshape(1,gt.shape[0],gt.shape[1])
res=res.reshape(1,res.shape[0],res.shape[1])

In [25]:
dice,iou=calculate_metrics(res,gt)
dice,iou

(0.8240552981709168, 0.7007602456894173)

In [26]:
instance_miou_score = calculate_miou(res, gt, num_classes=2)
print(f"mIoU Score: {instance_miou_score:.4f}")

mIoU Score: 0.7506
