# SMO

In [5]:
import cv2
from segment_anything import build_sam, SamAutomaticMaskGenerator, SamPredictor, sam_model_registry
from PIL import Image, ImageDraw
import clip
import torch
import numpy as np
import matplotlib.pyplot as plt
import time
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
from torchvision.transforms import InterpolationMode
BICUBIC = InterpolationMode.BICUBIC

import tifffile as tiff
import os
from sklearn.metrics import jaccard_score

import json
from collections import defaultdict, Counter
import torchvision.datasets as dset
from pycocotools.coco import COCO
import xml.etree.ElementTree as ET

import os
import sys
import math

from yolo11.predict import predict_image
from util.common import show_points2, convert_box_xywh_to_xyxy, segment_image
from util.itm import retriev
from util.loss import get_scores_loss,get_indices_of_values_above_threshold_2
from util.cal_iou import calculate_metrics
from util.cal_instance_iou import calculate_dice, calculate_miou
from util.nms import NMS


In [6]:
# Load the SAM checkpoint, place it on the chosen device and report the wall time.
time_start = time.time()

# Model configuration: registry key, checkpoint file and target device.
model_type = "vit_h"
sam_checkpoint = "../checkpoints/sam_vit_h_4b8939.pth"
device = "cpu"

# Look up the constructor for `model_type` in the registry and build the model.
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

# Generator with default settings; a later cell rebinds `mask_generator`
# with explicit parameters.
mask_generator = SamAutomaticMaskGenerator(sam)

elapsed = time.time() - time_start
print(f"Loading time: {elapsed:.2f} s")

Loading time: 4.97 s


# Load a CTC (Cell Tracking Challenge) image

In [7]:
# Read the input frame once and keep two independent RGB copies:
#   image  - passed to the SAM mask generator
#   image0 - used later for the image dimensions in the area filter
image_path = "../datasets/CTC/Fluo-N2DH-GOWT1/01/t002.tif"

# cv2.imread does not raise on a missing/unreadable file; it returns None,
# which would otherwise surface as an opaque error inside cvtColor.
image_bgr = cv2.imread(image_path)
if image_bgr is None:
    raise FileNotFoundError(f"Could not read image: {image_path}")

# OpenCV decodes to BGR channel order; convert to RGB.
image = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
# Independent copy so that changes to one array never affect the other.
image0 = image.copy()

In [8]:
# Option 1 (disabled): generate all masks with the default generator settings.
# masks = mask_generator.generate(image)

In [9]:
# Option 2: generate all masks with explicitly chosen generator parameters.
# NOTE(review): `max_mask_region_area` does not appear in the upstream
# segment_anything SamAutomaticMaskGenerator signature as far as I know —
# presumably this project uses a modified fork; confirm.
sam_generator_options = dict(
    points_per_side=20,
    points_per_batch=128,
    pred_iou_thresh=0.96,
    stability_score_thresh=0.96,
    stability_score_offset=1.0,
    box_nms_thresh=0.7,
    crop_n_layers=0,
    crop_nms_thresh=0.7,
    crop_overlap_ratio=512/1500,
    crop_n_points_downscale_factor=2,
    point_grids=None,
    min_mask_region_area=1000,  # Requires open-cv to run post-processing
    max_mask_region_area=0,
)

# Rebind the generator with the tuned options and run it on the RGB image.
mask_generator = SamAutomaticMaskGenerator(model=sam, **sam_generator_options)
masks = mask_generator.generate(image)

In [10]:
# Print how many masks the generator returned.
print(len(masks))

26


# Filter

In [15]:
# Exclude masks whose area is too large or too small relative to the image.
alpha = 4e-2  # masks with area >= alpha * image area are rejected as too big
gamma = 1e-4  # masks with area <= gamma * image area are rejected as too small

# Total pixel count of the image, used to scale both bounds.
h0, w0, _ = image0.shape
s0 = w0 * h0

masks_f = []
for i, p in enumerate(masks):
    if p["area"] >= alpha * s0:
        print(str(i)+': too big')
    elif p["area"] <= gamma * s0:
        print(str(i)+': too small')
    else:
        masks_f.append(p)

25: too big


In [16]:
# From here on `masks` refers to the area-filtered list; the unfiltered
# generator output is no longer reachable under this name.
masks=masks_f

In [17]:
# Cut out all masks
# `image` is rebound here from the NumPy array to a PIL image of the same file.
image = Image.open(image_path)
cropped_boxes = []

for mask in masks:
    # Cut this mask out of the input picture ...
    cutout = segment_image(image, mask["segmentation"])
    # ... then crop the cut-out with the mask's bbox, passed through
    # convert_box_xywh_to_xyxy.
    crop_box = convert_box_xywh_to_xyxy(mask["bbox"])
    cropped_boxes.append(cutout.crop(crop_box))

# CLIP ITM Scores

In [18]:
# Score every cropped mask against the text prompt using the project's
# `retriev` helper (util.itm).
# NOTE(review): presumably one CLIP image-text score per crop, in the same
# order as `cropped_boxes` — confirm against util.itm.
scores = retriev(cropped_boxes, "mouse stem cell")

# Compute the loss and discard poor masks

In [19]:
# Turn the retrieval scores into per-mask loss values (a later cell uses 1 - value).
# NOTE(review): `scores` is rebound from its own previous value, so re-running
# this cell without re-running the `retriev` cell feeds already-processed values
# back into get_scores_loss. Run the two cells together, top to bottom.
scores=get_scores_loss(masks,scores)

In [None]:
# Threshold handed to the selection helper; the returned indices are used
# below to index `masks`, `cropped_boxes` and `scores`.
# NOTE(review): the exact comparison is defined in util.loss — confirm.
score_threshold = 0.016
indices = get_indices_of_values_above_threshold_2(scores, score_threshold)

In [21]:
# Collect, for every selected index, the mask record, its crop, a binary mask
# image and its score. The input picture is not re-opened here: nothing in
# this cell reads it, and an extra Image.open would only leave another
# lazily opened file handle behind.
segmentation_masks = []
result_crops = []
result_masks = []
result_scores = []
for seg_idx in indices:
    selected = masks[seg_idx]
    # Render the segmentation as an 8-bit image (0 / 255).
    segmentation_masks.append(
        Image.fromarray(selected["segmentation"].astype('uint8') * 255)
    )
    result_masks.append(selected)
    result_crops.append(cropped_boxes[seg_idx])
    result_scores.append(1 - scores[seg_idx])  # 1-loss


# Use NMS

In [22]:
# Pack the scores and the converted boxes of the selected masks into tensors.
t_scores = torch.tensor(result_scores)
# One converted box per selected mask, in the same order as `result_scores`.
b = [convert_box_xywh_to_xyxy(p["bbox"]) for p in result_masks]
t_boxes = torch.tensor(b)

In [23]:
# Run the project's NMS helper (util.nms) on the boxes and scores built above.
# NOTE(review): 0.8 is passed positionally — presumably the overlap threshold;
# GIoU=True and eps are forwarded as-is. Confirm against util.nms.
n = NMS(t_boxes,t_scores,0.8,GIoU=True,eps=1e-7)
# Display how many entries NMS returned.
len(n)

25

In [24]:
# Select final results: keep the entries whose position appears in the NMS
# output `n`, preserving the original result order.
# Build the membership set once; converting `n` to a list on every iteration
# made the selection quadratic.
kept_indices = set(np.array(n).tolist())

n_crops = []
n_masks = []
n_scores = []
for i, (crop, mask_record, score) in enumerate(
    zip(result_crops, result_masks, result_scores)
):
    if i in kept_indices:
        n_crops.append(crop)
        n_masks.append(mask_record)
        n_scores.append(score)

# Results

In [25]:
# Union of all final masks as a single 0/1 image.
original_image = Image.open(image_path)
# PIL reports size as (width, height); reverse it to get array (rows, cols) order.
array0 = np.zeros(original_image.size[::-1], dtype=bool)
for p in n_masks:
    # A pixel is foreground if any kept mask covers it.
    array0 = np.logical_or(array0, p["segmentation"])
mat_array0 = array0.astype(np.uint8)

# Eval

In [26]:
# Read the ground-truth segmentation for the same frame.
gt = tiff.imread("../datasets/CTC/Fluo-N2DH-GOWT1/01_ST/SEG/man_seg002.tif")

In [None]:
# Collapse every positive label to 1 (foreground) and keep 0 as background,
# preserving the array's dtype.
gt = (gt > 0).astype(gt.dtype)

# DICE and IoU

In [28]:
# Give both the prediction and the ground truth a leading axis of length 1,
# the layout passed to the metric helpers below.
res = mat_array0.copy()
# Reshape from the last two axes so the cell can be re-run safely: `gt` is
# rebound in place, and using gt.shape[0] / gt.shape[1] would raise on a
# second run when gt is already (1, H, W).
gt = gt.reshape(1, *gt.shape[-2:])
res = res.reshape(1, *res.shape[-2:])

In [30]:
# Compute the Dice and IoU values of the prediction against the ground truth
# with the project's calculate_metrics helper (util.cal_iou).
dice,iou=calculate_metrics(res,gt)
# Bare expression so the notebook displays the pair as the cell output.
dice,iou

(0.978380519822726, 0.9576760612013341)

In [31]:
# Compute mIoU with the project's calculate_miou helper, using two classes.
# NOTE(review): presumably the classes are background and foreground — confirm
# against util.cal_instance_iou.
instance_miou_score = calculate_miou(res, gt, num_classes=2)
print(f"mIoU Score: {instance_miou_score:.4f}")

mIoU Score: 0.9773
