In [2]:
import zarr
import numpy as np
import anndata
import scipy as sp
import numpy as np
import matplotlib.pyplot as plt
import dask
from pyseq import image_analysis as ia
import warnings
warnings.filterwarnings('ignore')
from dask.distributed import Client
import numba
from os import makedirs, getcwd
import joblib
from dask_jobqueue import SLURMCluster
import skimage
import time
from os.path import exists, join
from dask.distributed import progress

In [3]:
def get_cluster(queue_name = 'pe2', log_dir=None):
    """Start a SLURM-backed dask cluster and connect a client to it.

    Each worker job is requested with 6 cores and 48G of memory. No worker
    lifetime is configured (the ``--lifetime`` option is left commented out
    below), and no workers are requested here; the caller has to scale the
    returned cluster.

    Parameters
    ----------
    queue_name : str
        Passed to ``SLURMCluster`` as ``queue``.
    log_dir : str or None
        Directory handed to ``SLURMCluster`` as ``log_directory``. When
        ``None``, ``<current working directory>/dask_logs`` is used.
        The directory is created if it does not exist yet.

    Returns
    -------
    cluster, client
        The ``SLURMCluster`` and a ``Client`` connected to it
        (connection timeout "50s").
    """
    if log_dir is None:
        log_dir = join(getcwd(), 'dask_logs')
    # Ensure the directory exists for caller-supplied paths as well,
    # not only for the default one.
    makedirs(log_dir, exist_ok=True)

    cluster = SLURMCluster(
                queue = queue_name,
                cores = 6,
                memory = '48G',
                log_directory=log_dir)
                #extra=["--lifetime", "55m", "--lifetime-stagger", "4m"])
    client = Client(cluster, timeout="50s")

    return cluster, client

# Notebook-level handles; scale_cluster() reads the global `cluster`.
cluster, client = get_cluster()

In [4]:
def scale_cluster(count):
    """Scale the notebook-global ``cluster`` to ``count`` and return its dashboard link.

    Relies on ``cluster`` having been created at module level beforehand.
    """
    cluster.scale(count)
    dashboard_url = cluster.dashboard_link
    return dashboard_url
# Scale the cluster to 5; the returned dashboard link is shown as the cell output.
scale_cluster(5)

'http://10.4.200.80:8787/status'

In [5]:
from cellpose import core, utils, io, models, metrics
from glob import glob

In [6]:
# Start the cellpose logger so model/evaluation messages show up in the cell output.
logger = io.logger_setup()
# `model` is a notebook-level global; model_evaluation() and the eval cells use it directly.
model = models.CellposeModel(model_type='TN2')

2022-07-15 16:49:10,500 [INFO] WRITING LOG OUTPUT TO /gpfs/commons/home/jsingh/.cellpose/run.log
2022-07-15 16:49:10,503 [INFO] >> TN2 << model set to be used
2022-07-15 16:49:10,505 [INFO] >>>> using CPU
2022-07-15 16:49:11,131 [INFO] >>>> model diam_mean =  30.000 (ROIs rescaled to this size during training)


In [7]:
import skimage
# Open the image dataset from its zarr store via the pyseq image_analysis helper.
im = ia.get_HiSeqImages(image_path = '/gpfs/commons/home/jsingh/zarrs/m387ntga2.zarr')
# Read the label TIFF for the same section name (m387ntga2).
# NOTE(review): `labels` is not used again in the visible cells.
labels = skimage.io.imread('/gpfs/commons/groups/nygcfaculty/PySeq/20210428_mouse_genotype_2/segmented_sections/m387ntga2_labels.tiff')

ImageAnalysis::Opened m387ntga2 


In [8]:
# Fix objective step, channel and cycle by label.
# presumably what remains is a single 2-D plane (it is indexed with two slices later) - confirm
one_z_plane = im.im.sel(obj_step = 8498, channel = 558, cycle=1)

In [10]:
# Take the selection's `.values` and wrap them as a dask array (default chunking).
imgs = dask.array.from_array(one_z_plane.values)

In [11]:
# Crop a 2000 x 2000 window. NOTE: this rebinds `imgs`, so re-running the cell
# slices the already-cropped array again instead of the full plane.
imgs = imgs[6000:8000, 6000:8000]

In [18]:
# Materialise the crop; after this cell `imgs` holds the computed result, not a dask array.
# NOTE: rebinds `imgs` again, so the cell is not safe to run twice.
imgs = imgs.compute()

In [15]:
# Channel spec forwarded to model.eval in the cells below.
# presumably [0, 0] selects single-channel (grayscale) input per the Cellpose convention - confirm
channels = [[0,0]]

In [16]:
def model_evaluation(image):
    """Evaluate the notebook-global Cellpose ``model`` on ``image`` and return the flows.

    The model is called with ``diameter=None`` and ``channels=[[0, 0]]``.
    Only the second element of the ``(masks, flows, styles)`` result is
    returned; masks and styles are discarded.
    """
    channel_spec = [[0,0]]
    result = model.eval(image, diameter=None, channels=channel_spec)
    _masks, flows, _styles = result
    return flows
    

In [19]:
# NOTE(review): by execution order, `imgs` was already computed (In [18]) before this
# cell ran, and model_evaluation returns the `flows` object rather than an array.
# The repr of `out` shown further down has shape () and dtype object, so this call
# does not produce a per-block result as written.
out = dask.array.map_blocks(model_evaluation,imgs)

2022-07-15 16:50:51,577 [INFO] No cell pixels found.


In [None]:
# Run the model directly on `imgs` without dask; the result is a (masks, flows, styles) triple.
model.eval(imgs, diameter=None, channels=channels)

In [21]:
# Display the dask array returned by map_blocks.
out

Unnamed: 0,Array,Chunk
Bytes,8 B,8.0 B
Shape,(),()
Count,2 Tasks,1 Chunks
Type,object,numpy.ndarray
Array Chunk Bytes 8 B 8.0 B Shape () () Count 2 Tasks 1 Chunks Type object numpy.ndarray,,

Unnamed: 0,Array,Chunk
Bytes,8 B,8.0 B
Shape,(),()
Count,2 Tasks,1 Chunks
Type,object,numpy.ndarray


In [None]:
# model.eval runs eagerly and already returns (masks, flows, styles).
# Wrapping the call in dask.compute() would hand back a 1-tuple, which cannot be
# unpacked into three names, so the result is unpacked directly instead.
# NOTE(review): the original referenced `arr`, which no visible cell defines;
# `imgs` (the cropped plane used by the other eval cell) is assumed - confirm.
masks, flows, styles = model.eval(imgs, diameter=None, channels=channels)

In [None]:
import torch
# Check whether the distributed package is available in this torch install.
torch.distributed.is_available()

In [None]:
# A process group needs at least one participant and rank must be < world_size,
# so a single-process group is rank 0 of world_size 1 (world_size=0 is invalid).
# NOTE: with no init_method/store given, the default env:// rendezvous is used,
# which expects MASTER_ADDR and MASTER_PORT to be set in the environment.
torch.distributed.init_process_group(backend = 'gloo', rank = 0,  world_size=1)

In [None]:
import numpy as np
import dask
import dask.array as da
import dask.delayed as delayed
import ClusterWrap
import time
from cellpose import models


def distributed_eval(
    image,
    blocksize,
    mask=None,
    preprocessing_steps=[],
    model_kwargs={},
    eval_kwargs={},
    cluster_kwargs={},
):
    """
    Run Cellpose on `image` block by block on a dask cluster.

    The image is chunked into `blocksize` blocks, each block is extended by
    an overlap of twice the cell diameter, optionally preprocessed, and
    segmented with a freshly constructed `models.Cellpose` instance.

    Parameters
    ----------
    image : numpy.ndarray or array-like
        A numpy array is scattered to the cluster; anything else is wrapped
        with `dask.array.from_array` (e.g. a zarr array).
    blocksize : sequence of int
        Block shape, one entry per image axis.
    mask : array-like or None
        Optional foreground mask; it may have a different shape than
        `image` (coordinates are scaled by `mask.shape / image.shape`).
        A block whose mask window sums to less than 1 is returned as zeros
        without running the model.
    preprocessing_steps : sequence of (callable, dict)
        Each step is applied in order as `callable(block, **dict)`.
    model_kwargs : dict
        Keyword arguments for `models.Cellpose`.
    eval_kwargs : dict
        Keyword arguments for `model.eval`; `diameter` defaults to 30.
        The dict passed in is not modified.
    cluster_kwargs : dict
        Keyword arguments for `ClusterWrap.cluster`.

    Returns
    -------
    The computed result of `dask.array.map_overlap` with `trim=False`:
    the per-block label arrays still include their overlap regions and are
    not stitched together (see TODOs below).
    """

    # set eval defaults on a copy, so neither the caller's dict nor the
    # shared default argument is mutated
    eval_kwargs = dict(eval_kwargs)
    eval_kwargs.setdefault('diameter', 30)

    # compute overlap
    overlap = eval_kwargs['diameter'] * 2

    # per-axis block size as an array for coordinate arithmetic
    block_shape = np.asarray(blocksize)

    # compute mask to array ratio
    if mask is not None:
        ratio = np.array(mask.shape) / image.shape

    # pipeline to run on each block
    def preprocess_and_segment(block, mask=None, block_info=None):

        # get block origin (in image coordinates, including the overlap margin)
        origin = np.array(block_info[0]['chunk-location'])
        origin = origin * block_shape - overlap

        # check mask foreground
        if mask is not None:
            # clamp BEFORE the integer cast: blocks on the low edge have a
            # negative origin, which would wrap around in an unsigned cast
            mo = np.maximum(0, np.round(origin * ratio)).astype(int)
            ms = np.round(block_shape * ratio).astype(int)
            # one slice per axis, so the mask may have any dimensionality
            window = tuple(slice(o, o + s) for o, s in zip(mo, ms))
            mask_block = mask[window]

            # if there is no foreground, return null result
            if np.sum(mask_block) < 1:
                return np.zeros(block.shape, dtype=np.int64)

        # run preprocessing steps
        data = np.copy(block)
        for pp_step in preprocessing_steps:
            data = pp_step[0](data, **pp_step[1])

        # segment
        model = models.Cellpose(**model_kwargs)
        return model.eval(data, **eval_kwargs)[0]

    # start cluster
    with ClusterWrap.cluster(**cluster_kwargs) as cluster:

        # wrap dataset as a dask object
        if isinstance(image, np.ndarray):
            future = cluster.client.scatter(image)
            image_da = da.from_delayed(
                future, shape=image.shape, dtype=image.dtype,
            )
            image_da = image_da.rechunk(blocksize)
            # persist() returns a new collection; keep it, otherwise the
            # persisted chunks are not referenced by the graph used below
            image_da = image_da.persist()
            time.sleep(30)  ### a little time for workers to be allocated
            cluster.client.rebalance()

        # a full dataset as a zarr array
        else:
            image_da = da.from_array(image, chunks=blocksize)

        # wrap mask
        mask_d = delayed(mask) if mask is not None else None

        # distribute
        # TODO: RESULT SHOULD BE WRITTEN TO ZARR
        #    OR RETURN DASK ARRAY AND HAVE AN EXECUTE FUNCTION
        #    WITH COMPUTE OR TO_ZARR OPTIONS
        segmentation = da.map_overlap(
            preprocess_and_segment, image_da,
            mask=mask_d,
            depth=overlap,
            dtype=np.int64,
            boundary=0,
            trim=False,
            chunks=[x+2*overlap for x in blocksize],
        ).compute()

        # TODO: STITCH!

        # return result
        return segmentation

In [None]:
!pip install ClusterWrap