# Cellpose 2D prediction
## Run Cellpose 2D segmentation on a set of images and save the results

## 1. Load dependencies

In [None]:
import numpy as np
import glob
import time, os, sys
from cellpose import core, utils, io, models, metrics, plot
import matplotlib.pyplot as plt

## 2. Load images

In [None]:
mark_image_directories = '/path/to/marker/images/' # path to cell type marker maximum projection images (if a second channel should be defined for cellpose segmentation, make sure you provide multi-channel images)
# Sort by (path length, path): ordering by length alone leaves files of equal
# length in arbitrary filesystem order; the name tie-break makes the order
# deterministic while still placing e.g. "img_2.tif" before "img_10.tif".
mark_images = sorted(glob.glob(os.path.join(mark_image_directories, "*.tif")),
                     key=lambda path: (len(path), path))
# Fail early with a clear message instead of an IndexError when the first
# image is accessed later.
if not mark_images:
    raise FileNotFoundError(f"No .tif images found in {mark_image_directories!r}")
mark_image_list = [io.imread(f) for f in mark_images]

## 3. Load model

In [None]:
# Built-in Cellpose "cyto" model, evaluated on the CPU (gpu=False).
model = models.CellposeModel(model_type="cyto", gpu=False)

# To use your own pre-trained model instead, point pretrained_model at it:
#modelPath = '/path/to/model'
#model = models.CellposeModel(gpu=False, pretrained_model=modelPath)

## 3b. Inspect the channels of the first image

In [None]:
# Display each channel of the first image side by side, plus a colour merge,
# to help choose the Cellpose `channels` setting below.
plt.rcParams['figure.dpi'] = 150
# Apply the style before any figure is created: plt.style.use only affects
# figures created after the call, not one that was already drawn.
plt.style.use("dark_background")
first_image = mark_image_list[0]
if first_image.ndim > 2:
    # Assumes channel-first layout (C, Y, X) — TODO confirm for your TIFFs.
    nc = first_image.shape[0]
    # (C, Y, X) -> (Y, X, C). transpose(2, 1, 0) would additionally swap the
    # Y and X axes and show the image transposed.
    img = first_image.transpose(1, 2, 0)
    # Scale to 0..255 for display. Mapping the maximum to 255 (not 256) keeps
    # values inside uint8's range; an all-zero image is left black instead of
    # dividing by zero.
    peak = np.amax(img)
    if peak > 0:
        img = np.clip(img / peak * 255, 0, 255).astype("uint8")
    else:
        img = np.zeros(img.shape, dtype="uint8")
    fig = plt.figure(figsize=(50, 30))
    for c in range(nc):
        fig.add_subplot(1, nc+1, c+1)
        plt.imshow(img[:, :, c], cmap="Greys_r")
        plt.axis('off')
        # Cellpose numbers colour channels from 1 (0 = grayscale, R=1, G=2,
        # B=3), so image channel c corresponds to Cellpose channel c + 1.
        plt.title(f"Channel {c} (Cellpose channel {c + 1})")
    # imshow only renders (Y, X, 3) or (Y, X, 4) arrays as colour, so build an
    # RGB composite from the first (up to) three channels, zero-padding the rest.
    merge = np.zeros(img.shape[:2] + (3,), dtype="uint8")
    merge[:, :, :min(nc, 3)] = img[:, :, :3]
    fig.add_subplot(1, nc+1, nc+1)
    plt.imshow(merge)
    plt.axis('off')
    plt.title("Merge")
    plt.tight_layout()
else:
    img = first_image
    fig = plt.figure(figsize=(30, 20))
    plt.imshow(img, cmap="Greys_r")
    plt.axis('off')
    plt.tight_layout()

## 4. Define prediction parameters

In [None]:
"""
# define CHANNELS to run segementation on
# grayscale=0, R=1, G=2, B=3
# channels = [cytoplasm, nucleus]
# if NUCLEUS channel does not exist, set the second channel to 0
# channels = [0,0]
# IF ALL YOUR IMAGES ARE THE SAME TYPE, you can give a list with 2 elements
# channels = [0,0] # IF YOU HAVE GRAYSCALE
# channels = [2,3] # IF YOU HAVE G=cytoplasm and B=nucleus
# channels = [2,1] # IF YOU HAVE G=cytoplasm and R=nucleus
"""
chan = 2 # primary channel to segment on
chan2 = 1 # optional secondary channel
channels = [chan, chan2]
diameter = 22.2
flow_threshold = 0.4
cellprob_threshold = 0

## 5. Prediction

In [None]:
# Segment all loaded images in a single call; the results are indexed per
# image afterwards (masks[i], flows[i]).
masks, flows, styles = model.eval(
    mark_image_list,
    diameter=diameter,
    channels=channels,
    cellprob_threshold=cellprob_threshold,
    flow_threshold=flow_threshold,
)

## 6. Inspect prediction results

In [None]:
# Show each image with its predicted masks and flows using cellpose's
# plot.show_segmentation.
# Apply the style once, before any figure is created: plt.style.use only
# affects figures created after the call.
plt.style.use("dark_background")
for img, maski, flow_set in zip(mark_image_list, masks, flows):
    flowi = flow_set[0]
    # Scale to 0..255 for display. Mapping the maximum to 255 (not 256) keeps
    # values inside uint8's range; an all-zero image stays black instead of
    # dividing by zero.
    peak = np.amax(img)
    if peak > 0:
        img = np.clip(img / peak * 255, 0, 255).astype("uint8")
    else:
        img = np.zeros(img.shape, dtype="uint8")
    fig = plt.figure(figsize=(50, 30))
    # NOTE(review): cellpose's example notebooks also pass channels=channels
    # here, which affects how 1-2 channel images are coloured — confirm.
    plot.show_segmentation(fig, img, maski, flowi)
    plt.tight_layout()
    plt.show()

## 7. Save results

In [None]:
# Write the segmentation results to disk; the input paths in mark_images are
# passed so the outputs can be named after their source images.
io.save_masks(
    mark_image_list,
    masks,
    flows,
    mark_images,
    channels=channels,
    # which outputs to write
    png=False,           # PNG masks plus an example image (disabled)
    tif=True,            # masks as TIFFs
    save_flows=True,     # flows as TIFFs
    save_outlines=True,  # outlines as PNGs
    save_txt=True,       # txt outlines for ImageJ
    # where to write them
    in_folders=True,     # save everything in separate folders
)