In [None]:
import sys
sys.path.insert(1, '../SyMBac/')
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from joblib import Parallel, delayed
from glob import glob
from skimage.filters.thresholding import threshold_isodata, threshold_li, threshold_mean, threshold_otsu, \
    threshold_local, threshold_niblack, threshold_sauvola
import pandas as pd
import seaborn as sns
from skimage.measure import label
from natsort import natsorted

In [None]:
from cellpose import models, core
use_GPU = core.use_gpu()
from cellpose import transforms
from omnipose.utils import normalize99

In [None]:
def flatten(l):
    """Concatenate an iterable of iterables into a single flat list.

    Order is preserved: all items of the first sublist, then the second, etc.
    """
    flat = []
    for sublist in l:
        flat.extend(sublist)
    return flat

In [None]:
# Model checkpoints to compare, in natural (human) sort order.
# NOTE(review): "*2301*" presumably selects one training run (e.g. a date tag) — confirm.
model_dirs = natsorted(glob("synthetic_training_data/models/*2301*"))

In [None]:
# Show which model directories matched the glob pattern.
model_dirs

In [None]:
def global_threshold(function, image):
    """Binarise ``image`` using the threshold produced by ``function(image)``.

    ``function`` may return a scalar (global methods) or an array that
    broadcasts against ``image`` (local methods); the comparison is elementwise.

    Returns:
        Boolean array, True where the pixel is strictly above the threshold.
    """
    return image > function(image)

def default_threshold_local(image):
    """Local (adaptive) threshold surface for ``image`` using a 41-pixel block.

    NOTE(review): this wrapper is not included in ``try_all_local_threshold``.
    """
    block_size = 41  # neighbourhood size in pixels; scikit-image requires an odd value
    return threshold_local(image, block_size=block_size)

def default_threshold_sauvola(image):
    """Sauvola threshold surface for ``image`` with a 103-pixel window."""
    window_size = 103  # neighbourhood side length in pixels
    return threshold_sauvola(image, window_size=window_size)

def default_threshold_niblack(image):
    """Niblack threshold surface for ``image`` with a 103-pixel window."""
    window_size = 103  # neighbourhood side length in pixels
    return threshold_niblack(image, window_size=window_size)

In [None]:
def try_all_global_threshold(image):
    """Binarise ``image`` with each global scikit-image thresholding method.

    Args:
        image: Image array to threshold.

    Returns:
        Dict mapping "Isodata", "Li", "Mean" and "Otsu" (in that order) to the
        boolean mask from ``global_threshold``. If a method raises, its entry
        is ``np.nan`` and a warning is emitted, so one failing method does not
        abort the comparison.
    """
    import warnings

    methods = {
        "Isodata": threshold_isodata,
        "Li": threshold_li,
        "Mean": threshold_mean,
        "Otsu": threshold_otsu,
    }
    binary_imgs = {}
    for name, function in methods.items():
        try:
            binary_imgs[name] = global_threshold(function, image)
        except Exception as exc:  # not bare: let KeyboardInterrupt/SystemExit propagate
            warnings.warn(f"{name} thresholding failed: {exc!r}")
            binary_imgs[name] = np.nan
    return binary_imgs

def try_all_local_threshold(image):
    """Binarise ``image`` with each local (windowed) thresholding method.

    Args:
        image: Image array to threshold.

    Returns:
        Dict mapping "Sauvola" and "Niblack" (in that order) to the boolean
        mask from ``global_threshold`` applied with the per-pixel threshold
        surface. If a method raises, its entry is ``np.nan`` and a warning is
        emitted, so one failing method does not abort the comparison.
    """
    import warnings

    methods = {
        "Sauvola": default_threshold_sauvola,
        "Niblack": default_threshold_niblack,
    }
    binary_imgs = {}
    for name, function in methods.items():
        try:
            binary_imgs[name] = global_threshold(function, image)
        except Exception as exc:  # not bare: let KeyboardInterrupt/SystemExit propagate
            warnings.warn(f"{name} thresholding failed: {exc!r}")
            binary_imgs[name] = np.nan
    return binary_imgs

In [None]:
# All synthetic PNG images to segment, in lexicographic path order.
images = sorted(glob("synthetic_training_data/*.png"))

In [None]:
use_gpu = use_GPU  # = False  <- remove the "#" before "= False" to force CPU
# NOTE(review): this global is not read by the segmentation code in this cell.


def segment_with_all_methods(image_dir, n_strips=15):
    """Segment one image with every method and record where each sees a division.

    The image at ``image_dir`` is segmented with the retrained Omnipose model,
    the global thresholds, the local thresholds and the pretrained
    ``bact_fluor_omni`` Omnipose model. Each resulting mask is cut into
    ``n_strips`` strips along axis 1. The index of the first strip that
    contains more than one connected object is that method's division index.

    Args:
        image_dir: Path of the image file, loaded with PIL.
        n_strips: Number of strips to split each mask into along axis 1.

    Returns:
        List of ``[method_name, division_index]`` pairs. A method contributes
        no row in two cases: its thresholding failed (the helpers store a
        non-array placeholder), or no strip ever contains two objects.
    """
    with Image.open(image_dir) as pil_image:  # close the file handle promptly
        image = np.array(pil_image)

    segmentations = (
        segment_with_omnipose(image, train_type="retrained")
        | try_all_global_threshold(image)
        | try_all_local_threshold(image)
        | segment_with_omnipose(image, train_type="bact_fluor_omni")
    )

    all_data = []
    for method, mask in segmentations.items():
        # Failed thresholds are stored as np.nan, which cannot be split along
        # axis 1; skip them instead of crashing the whole image.
        if not isinstance(mask, np.ndarray) or mask.ndim < 2:
            continue
        for i, strip in enumerate(np.array_split(mask, n_strips, axis=1)):
            # label() numbers objects 1..N consecutively, so max() > 1 means at
            # least two separate objects. Unlike counting unique values
            # (background + 2), this still holds when the objects fill the strip
            # with no background, e.g. two touching Omnipose masks.
            # strip.size guards empty strips (mask narrower than n_strips).
            if strip.size and label(strip).max() > 1:
                all_data.append([method, i])
                break
    return all_data
def segment_with_omnipose(image, train_type="retrained", pretrained_model=None):
    """Segment a single image with an Omnipose model.

    Args:
        image: Image array. After ``transforms.move_min_dim``, if it has more
            than two dimensions it is averaged over the last axis to give a
            grayscale image. It is then normalised with ``normalize99``.
        train_type: ``"retrained"`` loads the custom model from
            ``pretrained_model``; ``"bact_fluor_omni"`` loads the stock
            pretrained Omnipose model.
        pretrained_model: Path of the retrained model. When None, the
            module-level ``model_dir`` (the loop variable below) is used,
            looked up at call time.

    Returns:
        ``{"Omnipose retrained": masks}`` or ``{"Omnipose pretrained": masks}``.
        ``masks`` is the first mask returned by ``model.eval``.

    Raises:
        ValueError: if ``train_type`` is not one of the two supported values.
    """
    if train_type == "retrained":
        if pretrained_model is None:
            pretrained_model = model_dir
        key = "Omnipose retrained"
        model = models.CellposeModel(gpu=use_GPU, pretrained_model=pretrained_model, omni=True, concatenation=True)
    elif train_type == "bact_fluor_omni":
        key = "Omnipose pretrained"
        model = models.CellposeModel(gpu=use_GPU, model_type="bact_fluor_omni", omni=True, concatenation=True)
    else:
        raise ValueError(f"Unknown train_type {train_type!r}; expected 'retrained' or 'bact_fluor_omni'")
    # NOTE(review): the model is rebuilt on every call (twice per image); caching
    # it per worker would be much faster if this becomes a bottleneck.

    img = transforms.move_min_dim(image)  # move the channel dimension last
    if len(img.shape) > 2:
        image = np.mean(img, axis=-1)  # collapse channels to grayscale
    image = normalize99(image)

    masks, _flows, _styles = model.eval(
        [image],
        channels=[0, 0],      # segment on the first channel, no second channel
        rescale=None,         # no up/down-scaling of the input
        mask_threshold=-1,
        transparency=True,    # transparency in flow output
        flow_threshold=0.,    # default is 0.4; only needed to clean up spurious masks, and slows output
        omni=True,            # use Omnipose mask reconstruction
        resample=True,        # see Omnipose docs: run dynamics on the resampled grid
        verbose=0,
    )
    return {key: masks[0]}


for model_dir in model_dirs:
    # segment_with_all_methods is shipped to the workers together with the
    # current model_dir, so each iteration evaluates a different model.
    per_image = Parallel(n_jobs=4)(delayed(segment_with_all_methods)(image_dir) for image_dir in images)
    all_data = pd.DataFrame(flatten(per_image), columns=["Thresholding method", "Division Index"])
    # NOTE(review): 9 is presumably the strip index at which the synthetic cells
    # truly divide, so after the shift 0 means "correct" — confirm.
    all_data["Division Index"] -= 9

    fig, ax = plt.subplots()
    sns.violinplot(data=all_data, x="Division Index", y="Thresholding method", ax=ax)
    ax.axvline(0, c="k")  # reference line: predicted strip == expected strip
    ax.set_title(model_dir)
    plt.show()

In [None]:
# Zoom in on the two Omnipose variants. all_data holds the results of the
# last model in the loop above.
fig, ax = plt.subplots()
sns.violinplot(
    data=all_data[all_data["Thresholding method"].isin(["Omnipose retrained", "Omnipose pretrained"])],
    x="Division Index",
    y="Thresholding method",
    ax=ax,
)
ax.axvline(0, c="k")  # reference line spans the axes regardless of category count
ax.set_title(model_dir)
plt.show()

In [None]:
# Variance of the shifted division index per method. all_data is overwritten on
# each loop iteration, so this covers only the last model in model_dirs.
all_data.groupby("Thresholding method").var() 

In [None]:
#all_data.to_pickle("division_data.pickle")