# SMO for MoNuSeg or TNBC

In [None]:
import cv2
from segment_anything import build_sam, SamAutomaticMaskGenerator, SamPredictor, sam_model_registry
from PIL import Image, ImageDraw
import clip
import torch
import numpy as np
import matplotlib.pyplot as plt
import time
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
from torchvision.transforms import InterpolationMode
BICUBIC = InterpolationMode.BICUBIC

import tifffile as tiff
import os
from sklearn.metrics import jaccard_score

import json
from collections import defaultdict, Counter
import torchvision.datasets as dset
from pycocotools.coco import COCO
import xml.etree.ElementTree as ET

import os
import sys
import math

from yolo11.predict import predict_image
from util.common import show_points2, convert_box_xywh_to_xyxy, segment_image
from util.itm import retriev
from util.loss import get_scores_loss,get_indices_of_values_above_threshold_2
from util.cal_iou import calculate_metrics
from util.cal_instance_iou import calculate_dice, calculate_miou
from util.nms import NMS


In [None]:
# Build the ViT-H Segment Anything model from its checkpoint on the CPU and
# report how long construction + weight loading took.
time_start = time.time()

model_type = "vit_h"
sam_checkpoint = "../checkpoints/sam_vit_h_4b8939.pth"
device = "cpu"

# nn.Module.to() moves parameters in place and returns the same module.
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint).to(device=device)

elapsed = time.time() - time_start
print(f"Loading time: {elapsed:.2f} s")

# Load MoNuSeg

In [3]:
im_path = "../datasets/MoNuSeg/MoNuSegTestData/TissueImages/"
# List every file whose name contains "tif", in sorted (deterministic) order.
img_res = sorted(name for name in os.listdir(im_path) if "tif" in name)

In [97]:
# Select one test image by position and load it as an RGB numpy array.
# `image` is handed to SAM; `image0` is an independent copy of the same pixels.
index = 0
image_path = im_path + img_res[index]
image = cv2.imread(image_path)
if image is None:
    # cv2.imread signals failure by returning None rather than raising.
    raise FileNotFoundError(f"could not read image: {image_path}")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# Copy the decoded array instead of reading and decoding the file a second time.
image0 = image.copy()

In [None]:
# Wrap the model in a SamPredictor and compute the image embedding once, so the
# prompt-based prediction cells can reuse it. `image` is RGB (converted above).
predictor = SamPredictor(sam)
predictor.set_image(image, image_format="RGB")

# Prompt layout: point prompts or box prompts

In [94]:
# Prompt configuration:
#   layout: 'point' -> one point prompt per nucleus
#           'box'   -> one box prompt per nucleus
#   mode:   'manual'    -> read prompts from a local text file
#           'automatic' -> obtain prompts from predict_image (yolo11.predict)
layout = 'box'
mode = 'automatic'

In [None]:
if layout=='point' and mode=='manual':
    # manual_points.txt holds one whitespace-separated "x y" pair per line.
    # Each point is stored as a (1, 2) float array.
    points_nuclei = []
    with open('manual_points.txt', 'r') as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank lines (e.g. a trailing empty line)
            x, y = map(float, line.split())
            points_nuclei.append(np.array([x, y]).reshape(1,2))

if layout=='point' and mode=='automatic':
    # Get nucleus points from predict_image, run on the PNG copy of the current
    # image. Swap the extension with splitext rather than split('tif'), which
    # breaks on names that contain "tif" before the extension.
    model_path='../checkpoints/best.pt'
    image_path_yolo = ('../datasets/MoNuSeg/MoNuSegTestData/test/images/'
                       + os.path.splitext(img_res[index])[0] + '.png')
    points_nuclei = predict_image(model_path, image_path_yolo, isbox=False)

In [73]:
if layout=='box' and mode=='manual':
    # manual_boxes.txt holds one whitespace-separated "x1 y1 x2 y2" box per line.
    boxes = []
    with open('manual_boxes.txt', 'r') as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank lines (e.g. a trailing empty line)
            x1, y1, x2, y2 = map(float, line.split())
            boxes.append(np.array([x1, y1, x2, y2]))

In [None]:
if layout=='box' and mode=='automatic':
    # Get nucleus boxes from predict_image, run on the PNG copy of the current
    # image. Swap the extension with splitext rather than split('tif'), which
    # breaks on names that contain "tif" before the extension.
    model_path='../checkpoints/best.pt'
    image_path_yolo = ('../datasets/MoNuSeg/MoNuSegTestData/test/images/'
                       + os.path.splitext(img_res[index])[0] + '.png')
    boxes = predict_image(model_path, image_path_yolo, isbox=True)

In [23]:
if layout=='point':
    # Build one single-point prompt per nucleus: `points` is viewed as (N, 1, 2)
    # and `labels` as (N, 1), i.e. N prompts of one point each.
    input_points = np.concatenate(points_nuclei)
    points = torch.Tensor(input_points).to(predictor.device).unsqueeze(1).view(len(points_nuclei),1,2)
    # Every prompt point marks a nucleus, so every label is 1 (SAM foreground).
    # The previous `[int(l) for _, l in input_points]` used each point's y
    # coordinate as its label, which is neither 0 (background) nor 1 (foreground).
    labels = torch.ones(len(points_nuclei), 1, device=predictor.device)
    # Map original-image coordinates into SAM's resized input frame.
    transformed_points = predictor.transform.apply_coords_torch(points, image.shape[:2])

    # Run SAM on at most `batch_size` prompts at a time to bound peak memory.
    batch_size = 500
    cir = max(1, math.ceil(transformed_points.shape[0] / batch_size))
    mmm = []
    sss = []
    for pc in range(cir):
        batch = slice(pc * batch_size, (pc + 1) * batch_size)
        mp, sp, _ = predictor.predict_torch(
            point_coords=transformed_points[batch],
            point_labels=labels[batch],
            boxes=None,
            multimask_output=False,
        )
        mmm.append(mp.cpu().detach().numpy())
        sss.append(sp.cpu().detach().numpy())
    masks_p = np.concatenate(mmm, axis=0)
    scores_p = np.concatenate(sss, axis=0)

    # Convert each (mask, score) pair into a record with the same keys as
    # SamAutomaticMaskGenerator output; bbox is xywh.
    masks = []
    for i, (mask_p, score_p) in enumerate(zip(masks_p, scores_p)):
        seg = mask_p[0]
        ys, xs = np.where(seg == True)
        if xs.size:
            bbox = [xs.min(), ys.min(), xs.max() - xs.min(), ys.max() - ys.min()]
        else:
            # Empty mask: min()/max() of an empty array would raise. Use a
            # degenerate box; its area of 0 is below any positive area cutoff.
            bbox = [0, 0, 0, 0]
        m = {
            "segmentation": seg,
            "area": int(mask_p.sum()),
            "bbox": bbox,
            "predicted_iou": score_p[0].item(),
            "point_coords": points[i].cpu().detach().numpy().tolist(),
            # No separate stability score is computed; reuse the predicted IoU.
            "stability_score": score_p[0].item(),
            "crop_box": [0, 0, image.shape[1], image.shape[0]],
        }
        masks.append(m)

In [98]:
if layout=='box':
    # One box prompt per detected nucleus, in original-image xyxy coordinates,
    # mapped into SAM's resized input frame.
    input_boxes = torch.tensor([boxes], device=predictor.device).squeeze()
    transformed_boxes = predictor.transform.apply_boxes_torch(input_boxes, image.shape[:2])

    # Run SAM on at most 500 boxes at a time to bound peak memory.
    n_boxes = transformed_boxes.shape[0]
    cir = math.ceil(n_boxes / 500) if n_boxes > 500 else 1
    mask_chunks = []
    score_chunks = []
    for start in range(0, cir * 500, 500):
        batch_masks, batch_scores, _ = predictor.predict_torch(
            point_coords=None,
            point_labels=None,
            boxes=transformed_boxes[start:start + 500],
            multimask_output=False,
        )
        mask_chunks.append(batch_masks.cpu().detach().numpy())
        score_chunks.append(batch_scores.cpu().detach().numpy())
    masks_p = np.concatenate(mask_chunks, axis=0)
    scores_p = np.concatenate(score_chunks, axis=0)

    # One record per box; bbox is the prompt box converted from xyxy to xywh.
    # The predicted IoU doubles as the stability score.
    masks = []
    for mask_p, score_p, box in zip(masks_p, scores_p, boxes):
        iou_pred = score_p[0].item()
        masks.append({
            "segmentation": mask_p[0],
            "area": int(mask_p.sum()),
            "bbox": [box[0], box[1], box[2] - box[0], box[3] - box[1]],
            "predicted_iou": iou_pred,
            "stability_score": iou_pred,
            "crop_box": [0, 0, image.shape[1], image.shape[0]],
        })

# Filter

In [99]:
# exclude large, small area items
h0,w0,_=image0.shape
s0=w0*h0
alpha=0.00195
gamma=2.8e-05
h0,w0,_=image0.shape
s0=w0*h0
masks_f = []
for i,p in enumerate(masks):
    _,_,w,h = p["bbox"]
    if p["area"] >= alpha*s0:
        print(str(i)+': too big')
        continue
    elif p["area"] <= gamma*s0:
        print(str(i)+': too small')
        continue
    else:
        masks_f.append(p)

In [100]:
# From here on, `masks` refers only to the masks that passed the area filter.
masks=masks_f

In [101]:
# Cut out all masks
image = Image.open(image_path)
cropped_boxes = []

for mask in masks:
    # crop masks from input image
    cropped_boxes.append(segment_image(image, mask["segmentation"]).crop(convert_box_xywh_to_xyxy(mask["bbox"])))

# CLIP ITM Scores

In [102]:
# Score each masked crop against the text prompt "cell" via util.itm.retriev
# (image-text matching, per the section heading). Presumably one score per crop,
# in `cropped_boxes` order — TODO confirm in util/itm.py.
scores = retriev(cropped_boxes, "cell")

# Compute the per-mask loss and drop low-quality masks

In [103]:
# Turn the crop scores into a per-mask loss with util.loss.get_scores_loss, which
# also receives the mask records (exact formula defined there — TODO confirm).
# Later cells treat 1 - scores[i] as the confidence of mask i.
scores=get_scores_loss(masks,scores)

In [104]:
# Indices of the masks to keep, selected by the util.loss helper with threshold
# 0.075 (the exact comparison is defined in that helper — TODO confirm).
indices = get_indices_of_values_above_threshold_2(scores, 0.075)

In [105]:
# Gather the crops, mask records, confidences (1 - loss) and 0/255 PIL mask
# images of the kept masks, all in `indices` order.
kept = list(indices)
result_masks = [masks[k] for k in kept]
result_crops = [cropped_boxes[k] for k in kept]
result_scores = [1 - scores[k] for k in kept]  # 1-loss
segmentation_masks = [
    Image.fromarray(rm["segmentation"].astype('uint8') * 255) for rm in result_masks
]

# Use NMS

In [106]:
# Concatenate targets and scores
t_scores=torch.tensor(result_scores)
b = []
for p in result_masks:
    xyxy = convert_box_xywh_to_xyxy(p["bbox"])
    b.append(xyxy)
t_boxes=torch.tensor(b)

In [None]:
# Non-maximum suppression over the candidate boxes (util.nms.NMS, threshold 0.9,
# GIoU=True). The next cell uses `n` as a collection of kept candidate indices —
# presumably indices into t_boxes; TODO confirm.
n = NMS(t_boxes,t_scores,0.9,GIoU=True,eps=1e-7)
# Notebook display: presumably the number of candidates kept by NMS.
len(n)

In [108]:
# Select final results
n_crops = []
n_masks = []
n_scores = []
for i,v in enumerate(zip(result_crops,result_masks,result_scores)):
    if i in list(np.array(n)):
        n_crops.append(v[0])
        n_masks.append(v[1])
        n_scores.append(v[2])

# Results

In [109]:
# Merge all kept instance masks into one binary foreground map of shape (H, W).
# Only the image size is needed, so close the PIL file handle right away.
with Image.open(image_path) as original_image:
    # PIL's size is (width, height); transpose the zeros to get (H, W).
    array0 = np.zeros(original_image.size, dtype=bool).T
for p in n_masks:
    array0 = np.logical_or(array0, p["segmentation"])
mat_array0 = np.uint8(array0)
# Save segmentation results (tifffile's imsave is deprecated; use imwrite)
# tiff.imwrite(str("./results/yolo11monuseg_mask_"+img_res[index]), mat_array0)
# Save segmentation results
# tiff.imsave(str("./results/yolo11monuseg_mask_"+img_res[index]), mat_array0)

# Eval

In [110]:
# load gt
label_path="../datasets/MoNuSeg/MoNuSegTestData/labelcol/"
# list all images
ann_img=os.listdir(label_path)
ann_img=[s for s in ann_img if "png" in s]
ann_img=sorted(ann_img)

In [111]:
# Load the ground-truth label for the current index and binarize it.
# NOTE(review): labels are paired with images purely by sorted position in the
# two directories (ann_img[index] <-> img_res[index]) — confirm the names align.
gt_file = label_path + ann_img[index]
gt_png = cv2.imread(gt_file)
if gt_png is None:
    # cv2.imread signals failure by returning None rather than raising.
    raise FileNotFoundError(f"could not read ground-truth image: {gt_file}")
gray_gt = cv2.cvtColor(gt_png, cv2.COLOR_BGR2GRAY)
# THRESH_BINARY with maxval 1: pixels > 200 become 1, all others 0.
mat_array_anno = cv2.threshold(gray_gt, 200, 1, cv2.THRESH_BINARY)[1]

# DICE and IoU

In [112]:
# Give both 2-D (H, W) maps a leading axis, producing (1, H, W) copies for the
# metric helpers; the originals stay untouched.
gt = np.expand_dims(mat_array_anno.copy(), axis=0)
res = np.expand_dims(mat_array0.copy(), axis=0)

In [None]:
# Dice and IoU of the merged binary prediction against the binarized ground
# truth, both shaped (1, H, W); definitions live in util.cal_iou.
dice,iou=calculate_metrics(res,gt)
# Notebook display of the two metrics.
dice,iou

In [None]:
# Mean IoU over 2 classes (presumably background and foreground) computed by
# util.cal_instance_iou.calculate_miou on the same (1, H, W) maps.
instance_miou_score = calculate_miou(res, gt, num_classes=2)
print(f"mIoU Score: {instance_miou_score:.4f}")