In [None]:
import os, gc
import numpy as np
from scellseg.io import imsave, imread, save_masks_noflows
from scellseg.models import sCellSeg


finetune_dir = r''  # Input the path of the directory that holds the model file
query_img_dir = r''  # Input the path of the directory that holds your query imgs

# os.listdir() returns entries in arbitrary order and also lists sub-directories
# and hidden files (e.g. '.DS_Store'). Keep regular, non-hidden files only and
# sort them so the image order -- and therefore the order of the saved masks --
# is reproducible from run to run.
query_img_paths = sorted(
    entry for entry in os.listdir(query_img_dir)
    if not entry.startswith('.') and os.path.isfile(os.path.join(query_img_dir, entry))
)
query_imgs = [imread(os.path.join(query_img_dir, i)) for i in query_img_paths]  # Read your images
query_imgs = np.array(query_imgs)

# You should check that query_imgs is an ndarray of shape (n, h, w, c).
# If it is (n, h, w) instead, add a trailing channel axis with:
# query_imgs = [query_imgi[..., np.newaxis] for query_imgi in query_imgs]
# query_imgs = np.array(query_imgs)

eval_batch_size = 16  # Set batch size
diameter = 30.0  # Set the mean diameter of the cells in your images
channel = [2, 1]  # Set the channel you want to segment; you can also provide a chan2 (e.g. a nuclei channel) for better results

In [None]:
# Build the network and run inference on the query images, using the model
# file located at <finetune_dir>/<model_type>.
model_type = 'scellseg'
finetune_model = os.path.join(finetune_dir, model_type)

model = sCellSeg(
    gpu=True,
    model_type=model_type,
    diam_mean=30.0,
    task_mode='cellpose',
    attn_on=True,
    dense_on=True,
    style_scale_on=True,
)

# Keyword arguments for inference, collected in one place so they are easy to tweak.
inference_options = dict(
    finetune_model=finetune_model,
    net_avg=False,
    query_images=query_imgs,
    channel=channel,
    diameter=diameter,
    resample=False,
    flow_threshold=0.4,
    cellprob_threshold=0.5,
    eval_batch_size=eval_batch_size,
    postproc_mode='cellpose',
)

# inference() returns three values; only the first one (the masks) is used afterwards.
masks, _, _ = model.inference(**inference_options)

In [None]:
save_mask_dir = r''  # Path for output masks
os.makedirs(save_mask_dir, exist_ok=True)  # Create the output directory if it does not exist yet

# Strip only the final extension. Splitting on the first '.' would turn both
# 'plate1.well2.tif' and 'plate1.well3.tif' into 'plate1', so their masks
# would silently overwrite each other.
query_img_names = [os.path.splitext(i)[0] for i in query_img_paths]
save_img_paths = [os.path.join(save_mask_dir, query_img_name) for query_img_name in query_img_names]
for i, save_img_path in enumerate(save_img_paths):
    maski = masks[i]
    # Store as uint16 when the largest label is small enough, otherwise fall back to uint32
    maski = maski.astype(np.uint16) if maski.max() < 2 ** 16 - 1 else maski.astype(np.uint32)
    imsave(save_img_path + '_masks.png', maski)  # Save the masks

In [None]:
# Path for visualization of the output masks; these images help you get a
# better look at the segmentation results.
save_mask_show_dir = r''
if not os.path.isdir(save_mask_show_dir):
    os.makedirs(save_mask_show_dir)

# One '<image name>.png' target per query image, in the same order as the masks.
save_img_paths_pre = [
    os.path.join(save_mask_show_dir, img_name + '.png')
    for img_name in query_img_names
]
save_masks_noflows(query_imgs, masks, save_img_paths_pre)