In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt 
from scipy.ndimage import gaussian_filter, gaussian_filter1d
from cellpose import io, transforms, utils, models, dynamics
from tqdm import trange
import gc
from glob import glob
import cv2
from natsort import natsorted 
import shutil
from pathlib import Path
import torch 

# Use the GPU when it is available and fall back to CPU otherwise. The models
# below are created with gpu=True (which cellpose can fall back from), but this
# device is handed directly to dynamics.compute_masks, which would fail on a
# hard-coded "cuda" device on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Folder holding the input stack (DAPI.tiff) and the trained models (models/);
# the predicted masks are written back into this folder at the end.
root = Path("Path/To/DAPIfolder")
# Raw DAPI stack; later cells index it as (Z, Y, X).
img = io.imread(root / "DAPI.tiff")

In [None]:
#masks = io.imread(root / "DAPIhippocampus_crop_masks.tif")

In [None]:
#plt.imshow(img[125,:,:], aspect="auto", vmin=0, vmax=1)
print(f"\r shape: {img.shape}")  # report the dimensions of the loaded stack
#print("\r shape: {0}".format(masks.shape))

#masks_full = np.zeros(img.shape, dtype="uint16")

In [None]:
# optional - crop the stack to reduce runtime / memory usage
#y0, y1 = 280, 2300
#x0, x1 = 200, 4368
#z0, z1 = 15, 220
#masks_full[z0:z1, y0:y1, x0:x1] = masks;
#img = img[z0:z1, y0:y1, x0:x1]

In [None]:
# normalize the z-stack
# Applies cellpose's normalize99 once over the whole stack (presumably a
# 1st/99th-percentile rescale — confirm against the cellpose docs) and replaces
# `img` with the result. NOTE(review): re-running this cell normalizes the
# already-normalized stack again, because the raw image is not kept.
img = transforms.normalize99(img)

In [None]:
# initialize the models
# modelXY is applied to XY planes; modelYZ is applied to the two orthogonal
# plane orientations in the flow computation.
io.logger_setup()
# Pass the model paths as str: some cellpose releases only accept a str (or
# list) for pretrained_model and break when given a pathlib.Path.
modelXY = models.CellposeModel(pretrained_model=str(root / "models/Modelxy"), gpu=True)
modelYZ = models.CellposeModel(pretrained_model=str(root / "models/Modelyz"), gpu=True)

# Diameters stored with each model (presumably the mean training-label
# diameter), reused as the `diameter` argument at inference time.
diameterXY = modelXY.diam_labels
diameterYZ = modelYZ.diam_labels

In [None]:
# compute the flows
# Run the 2D models slice-by-slice along each of the three axes and store every
# view's predictions back in (Z, Y, X) order:
#   view 0: XY planes (sliced along Z) -> modelXY
#   view 1: ZX planes (sliced along Y) -> modelYZ
#   view 2: ZY planes (sliced along X) -> modelYZ

# Forwarded to model.eval as `normalize=`. True reproduces this notebook's
# original behavior, which relied on eval's default (assumed True — confirm for
# your cellpose version). Because per-slice percentile normalization undoes any
# positive affine rescale, the whole-stack normalize99 step is then effectively
# overridden. Set False to feed the globally normalized intensities as-is.
normalize_slices = True

shape = img.shape
cellprob = np.zeros((3, *shape), "float32")  # (view, Z, Y, X)
dP = np.zeros((3, 2, *shape), "float32")     # (view, flow component, Z, Y, X)

# pm[p] reorders (Z, Y, X) so view p is sliced along its first axis;
# ipm[p] is the inverse permutation that restores (Z, Y, X).
pm = [(0,1,2), (1,0,2), (2,0,1)]
ipm = [(0,1,2), (1,0,2), (1,2,0)]

for p in range(3):
    model, diameter = (modelXY, diameterXY) if p == 0 else (modelYZ, diameterYZ)
    # transpose() already returns a view; copying the whole stack first only
    # costs memory (the result is non-contiguous either way).
    img0 = img.transpose(pm[p])
    y = np.zeros((3, *img0.shape), "float32")  # (2 flow components + cellprob, ...)
    for z in trange(img0.shape[0], desc=f"view {p}"):
        _, flows, _ = model.eval(img0[z], batch_size=128, compute_masks=False,
                                 diameter=diameter, normalize=normalize_slices)
        y[:2, z] = flows[1].squeeze()  # 2D flow field of this slice
        y[-1, z] = flows[2].squeeze()  # cell probability of this slice
    dP[p, 0] = y[0].transpose(ipm[p])
    dP[p, 1] = y[1].transpose(ipm[p])
    cellprob[p] = y[-1].transpose(ipm[p])

# Release the per-view buffer (3 x stack size, float32) before the
# memory-intensive mask computation.
del img0, y
gc.collect()


In [None]:
# optional - save intermediates
#np.save(root / "dP.npy", dP)
#np.save(root / "cellprob.npy", cellprob)

In [None]:
# average predictions from 3 views
# Cell probability: plain mean over the three views.
cellprob_all = np.mean(cellprob, axis=0)
# Each view predicts two of the three gradient components; every component is
# seen by two views, whose predictions are summed (result ordered dZ, dY, dX).
dP_all = np.stack(
    [
        dP[1, 0] + dP[2, 0],  # dZ: from the ZX and ZY views
        dP[0, 0] + dP[2, 1],  # dY: from the XY and ZY views
        dP[0, 1] + dP[1, 1],  # dX: from the XY and ZX views
    ],
    axis=0,
)

In [None]:
# Build 3D instance masks by following the averaged flows.
# This is the most memory-intensive step of the notebook; `device` selects
# where cellpose runs the flow dynamics.
masks_pred, p = dynamics.compute_masks(
    dP_all,
    cellprob_all,
    do_3D=True,
    device=device,
)

In [None]:
# Fill holes in the predicted masks and drop any mask smaller than
# MIN_MASK_SIZE pixels/voxels.
MIN_MASK_SIZE = 1000
masks_pred0 = utils.fill_holes_and_remove_small_masks(masks_pred, min_size=MIN_MASK_SIZE)


In [None]:
# Save the filtered 3D label image next to the input stack.
# (Only the masks are written here; the image itself is not saved.)
masks_path = root / "Masks.tiff"
io.imsave(masks_path, masks_pred0)
