# 1. Convert a TIFF into Zarr format with `bioformats2raw`

Open a terminal (from the Launcher if you prefer) and use the following command:

```
bioformats2raw CMU-1.svs CMU-1.zarr --use-existing-resolutions -p
```

---
# 2. Compute on Dask arrays

In [None]:
import zarr
import dask
import dask.array as da
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Open the converted Zarr store in read-only mode.
z_grp = zarr.open("CMU-1.zarr", mode="r")

In [None]:
# Display a summary of the root of the store.
z_grp.info

In [None]:
# Display a summary of the "0" entry; the arrays used below live under it
# (components "0/0" and "0/2").
z_grp["0"].info

In [None]:
# Wrap the array stored under component "0/0" as a lazy Dask array.
da_arr = da.from_zarr("CMU-1.zarr", component="0/0")

In [None]:
da_arr

In [None]:
# Take index 0 on the first and third axes, keeping three axes, then rechunk
# into blocks of 3 x 512 x 512.
# NOTE(review): presumably the dropped axes are time and z of a
# (t, c, z, y, x) layout written by bioformats2raw — confirm.
da_arr = da_arr[0, :, 0].rechunk((3, 512, 512))

In [None]:
da_arr

In [None]:
# Move the leading axis to the end so the channels sit on axis 2, which is
# what the `channel_axis=2` calls further down expect.
da_arr = np.moveaxis(da_arr, 0, -1)

In [None]:
da_arr

In [None]:
# Select a 512 x 512 window starting at row 16,000 and column 8,000.
da_sel = da_arr[16_000:16_000 + 512, 8_000:8_000 + 512]

In [None]:
da_sel

# [Use case] Segmentation of nuclei in WSI (Cellpose)

In [None]:
from cellpose import models, transforms
import torch

# Name of the Cellpose model to load.
model_type = "cyto3"
# Use the GPU only when PyTorch reports that CUDA is available.
gpu = torch.cuda.is_available()

# Single model instance shared by every segmentation call in this notebook.
cellpose_model = models.CellposeModel(model_type=model_type, gpu=gpu)

---
## 1.1 Compute as a NumPy array (`.compute()`)

In [None]:
# Materialize the selected window with `.compute()` and run Cellpose's own
# pre-processing on it (`convert_image`, then `normalize_img`), with the
# channels declared on axis 2.
img_t = transforms.convert_image(da_sel.compute(), channel_axis=2, channels=[0, 0])
img_t = transforms.normalize_img(img_t, invert=False, axis=2)

# A leading singleton axis is added before calling the model. `eval` returns
# three values; only the first one (the labels) is kept.
labels, _, _ = cellpose_model.eval(img_t[None, ...], diameter=None, flow_threshold=None, channels=[0, 0])

In [None]:
# Check the shape of the returned labels.
labels.shape

In [None]:
import skimage

In [None]:
# Show the image window, then draw the color-coded labels on top of it with
# 50% opacity.
plt.imshow(da_sel)
plt.imshow(skimage.color.label2rgb(labels), alpha=0.5)

---
## 1.2 Compute the segmentation lazily with Dask (`delayed`)

In [None]:
@dask.delayed
def cellpose_segment_delayed(img, cellpose_model):
    """Segment one image with Cellpose as a Dask delayed task.

    Because of the ``dask.delayed`` decorator, calling this function only
    builds a task; the body runs when the returned object is computed.

    Parameters
    ----------
    img
        Image with its channels on axis 2.
    cellpose_model
        Object whose ``eval`` method performs the segmentation.

    Returns
    -------
    The first of the three values returned by ``cellpose_model.eval``
    (the labels).
    """
    # Cellpose pre-processing: `convert_image` followed by `normalize_img`,
    # both told that the channels are on axis 2.
    prepared = transforms.normalize_img(
        transforms.convert_image(img, channel_axis=2, channels=[0, 0]),
        invert=False,
        axis=2,
    )

    # The model receives the image with an extra leading singleton axis.
    masks, _, _ = cellpose_model.eval(
        prepared[None],
        diameter=None,
        flow_threshold=None,
        channels=[0, 0],
    )
    return masks

In [None]:
# Calling the delayed function only builds a task; nothing is segmented yet.
delayed_labels_sel = cellpose_segment_delayed(da_sel, cellpose_model)

In [None]:
delayed_labels_sel

In [None]:
# `.compute()` runs the segmentation; the resulting labels are displayed.
plt.imshow(delayed_labels_sel.compute())

---
# 2. Distribute computation with `map_blocks`

In [None]:
def cellpose_segment(img, cellpose_model):
    """Segment one image block with Cellpose and return its labels.

    Parameters
    ----------
    img
        Image block with its channels on axis 2.
    cellpose_model
        Object whose ``eval`` method performs the segmentation.

    Returns
    -------
    The first of the three values returned by ``cellpose_model.eval``
    (the labels).
    """
    # Cellpose pre-processing: `convert_image` followed by `normalize_img`,
    # both told that the channels are on axis 2.
    prepared = transforms.normalize_img(
        transforms.convert_image(img, channel_axis=2, channels=[0, 0]),
        invert=False,
        axis=2,
    )

    # The model receives the block with an extra leading singleton axis.
    masks, _, _ = cellpose_model.eval(
        prepared[None],
        diameter=None,
        flow_threshold=None,
        channels=[0, 0],
    )
    return masks

In [None]:
# Apply `cellpose_segment` to every chunk of `da_arr`, passing the model as an
# extra argument. `drop_axis=(2, )` declares that the output blocks have no
# channel axis; `dtype` and `meta` describe the output blocks up front.
da_labels = da.map_blocks(
    cellpose_segment,
    da_arr,
    cellpose_model,
    drop_axis=(2, ),
    dtype=np.int32,
    meta=np.empty(0, dtype=np.int32)
)

In [None]:
da_labels

In [None]:
# Display the labels of the same 512 x 512 window selected earlier
# (row 16,000, column 8,000).
plt.imshow(da_labels[16_000:16_000+512, 8_000:8_000+512])

In [None]:
# Overlay the color-coded labels of that window on the image with 50% opacity.
plt.imshow(da_arr[16_000:16_000+512, 8_000:8_000+512])
plt.imshow(skimage.color.label2rgb(da_labels[16_000:16_000+512, 8_000:8_000+512].compute()), alpha=0.5)

---
# 3. Debug `map_blocks` computations

In [None]:
def cellpose_segment(img, cellpose_model, block_info=None):
    """Segment one image block with Cellpose, printing block details first.

    Debugging variant: before segmenting, it prints the ``block_info`` it
    received together with the shape of ``img``.

    Parameters
    ----------
    img
        Image block with its channels on axis 2.
    cellpose_model
        Object whose ``eval`` method performs the segmentation.
    block_info
        Block description; it is only printed here. ``map_blocks`` fills it
        in when the function declares this keyword.

    Returns
    -------
    The first of the three values returned by ``cellpose_model.eval``
    (the labels).
    """
    print(block_info, img.shape)

    # Cellpose pre-processing: `convert_image` followed by `normalize_img`,
    # both told that the channels are on axis 2.
    prepared = transforms.normalize_img(
        transforms.convert_image(img, channel_axis=2, channels=[0, 0]),
        invert=False,
        axis=2,
    )

    # The model receives the block with an extra leading singleton axis.
    masks, _, _ = cellpose_model.eval(
        prepared[None],
        diameter=None,
        flow_threshold=None,
        channels=[0, 0],
    )
    return masks

In [None]:
# Rebuild the lazy label array so that it picks up the `cellpose_segment`
# version that prints its `block_info`.
da_labels = da.map_blocks(
    cellpose_segment,
    da_arr,
    cellpose_model,
    drop_axis=(2, ),
    dtype=np.int32,
    meta=np.empty(0, dtype=np.int32)
)

In [None]:
# Computing a 512 x 512 window (row 10,000, column 10,000) runs the function
# on the chunks needed for it, so their `block_info` and shapes get printed.
labels = da_labels[10_000:10_000+512, 10_000:10_000+512].compute()

---
# 4. Return arrays with different shape with `map_blocks`

## 4.1 Compute features from the segmentation result

In [None]:
def rprops(img, cellpose_model, block_info=None):
    """Segment one image block and measure the regions found in it.

    Written as a ``map_blocks`` callback: ``block_info`` has to be provided,
    since the location of the block is read from it.

    Parameters
    ----------
    img
        Image block with its channels on axis 2.
    cellpose_model
        Object whose ``eval`` method performs the segmentation.
    block_info
        Block description; ``block_info[0]["array-location"]`` is read to get
        where the block starts along the first two axes.

    Returns
    -------
    numpy.ndarray
        Object array of shape (1, 1) whose single element is a dict with the
        keys ``"rprops"`` (result of ``skimage.measure.regionprops``) and
        ``"offset"`` (start of the block along the first two axes).
    """
    # First index of this block along the first two axes of the input array.
    location = block_info[0]["array-location"]
    offset = (location[0][0], location[1][0])

    # Cellpose pre-processing: `convert_image` followed by `normalize_img`,
    # both told that the channels are on axis 2.
    prepared = transforms.normalize_img(
        transforms.convert_image(img, channel_axis=2, channels=[0, 0]),
        invert=False,
        axis=2,
    )

    masks, _, _ = cellpose_model.eval(
        prepared[None],
        diameter=None,
        flow_threshold=None,
        channels=[0, 0],
    )

    # Measure every labeled region, using the input block for intensities.
    regions = skimage.measure.regionprops(masks, intensity_image=img)

    # Wrap the result in a single-element object array so every block
    # returns an array of the same (1, 1) shape.
    features_arr = np.empty((1, 1), dtype=object)
    features_arr[0, 0] = {"rprops": regions, "offset": offset}

    return features_arr

In [None]:
# Run `rprops` on every chunk of `da_arr`. Each call returns a (1, 1) object
# array, so the output chunks are declared as (1, 1) and the channel axis is
# dropped.
da_rprops = da.map_blocks(
    rprops,
    da_arr,
    cellpose_model,
    chunks=(1, 1),
    drop_axis=(2, ),
    dtype=object,
    meta=np.empty(0, dtype=object)
)

In [None]:
da_rprops

In [None]:
# Compute only the element at position (20, 20) of the feature array.
rprops_arr = da_rprops[20, 20].compute()

In [None]:
# Mean intensity of the first region stored under the "rprops" key.
rprops_arr["rprops"][0].intensity_mean

---
# 5. Compute on masked chunks

## 5.1 Compute a mask from a low-resolution level of the input *pyramid*

In [None]:
# Read array "0/2" of the store, take index 0 on the first and third axes, and
# convert the result to grayscale with the color channels on axis 0.
dwn_gray = skimage.color.rgb2gray(z_grp["0/2"][0, :, 0], channel_axis=0)

In [None]:
dwn_gray.shape

In [None]:
plt.imshow(dwn_gray, cmap="gray")

In [None]:
# Smooth the grayscale image with a Gaussian filter (sigma=5).
dwn_blur = skimage.filters.gaussian(dwn_gray, sigma=5)

In [None]:
plt.imshow(dwn_blur, cmap="gray")

In [None]:
# Threshold selected with Otsu's method on the blurred image.
th = skimage.filters.threshold_otsu(dwn_blur)

In [None]:
# Pixels below the threshold are True in the mask.
# NOTE(review): presumably dark tissue on a bright background — confirm.
dwn_mask = dwn_blur < th

In [None]:
plt.imshow(dwn_mask)

In [None]:
# Number of chunks of `da_arr` along each axis.
list(map(len, da_arr.chunks))

In [None]:
dwn_mask.shape

In [None]:
# Average the mask over 4 x 4 windows; a cell is True when the local mean is
# above zero, i.e. when at least one pixel of its window was True.
# NOTE(review): the result is used below with one element per chunk of
# `da_arr` — confirm that the two grids have the same shape.
mask = skimage.transform.downscale_local_mean(dwn_mask, (4, 4)) > 0

In [None]:
mask.shape

In [None]:
plt.imshow(mask, cmap="gray")

In [None]:
# Number of True cells in the mask.
mask.sum()

In [None]:
# Total number of cells in the mask.
mask.size

In [None]:
# Fraction of cells that are True.
mask.sum() / mask.size

In [None]:
# Add a trailing singleton axis and chunk the mask one element at a time, so
# it can be passed to `map_blocks` next to the three-axis `da_arr`.
da_mask = da.from_array(mask[..., None], chunks=(1, 1, 1))

In [None]:
def masked_rprops(img, mask, cellpose_model, block_info=None):
    """Segment and measure one image block, but only where the mask is set.

    Written as a ``map_blocks`` callback: ``block_info`` has to be provided,
    since the location of the block is read from it.

    Parameters
    ----------
    img
        Image block with its channels on axis 2.
    mask
        Mask block matching ``img``; the block is processed only when its sum
        is non-zero.
    cellpose_model
        Object whose ``eval`` method performs the segmentation.
    block_info
        Block description; ``block_info[0]["array-location"]`` is read to get
        where the block starts along the first two axes.

    Returns
    -------
    numpy.ndarray
        Object array of shape (1, 1). For a masked-in block its element is a
        dict with the keys ``"rprops"`` and ``"offset"``; otherwise the array
        comes from ``np.zeros`` and holds ``0``.
    """
    # First index of this block along the first two axes of the input array.
    location = block_info[0]["array-location"]
    offset = (location[0][0], location[1][0])

    # Nothing selected in this block: skip the segmentation entirely.
    if not mask.sum():
        return np.zeros((1, 1), dtype=object)

    # Cellpose pre-processing: `convert_image` followed by `normalize_img`,
    # both told that the channels are on axis 2.
    prepared = transforms.normalize_img(
        transforms.convert_image(img, channel_axis=2, channels=[0, 0]),
        invert=False,
        axis=2,
    )

    masks, _, _ = cellpose_model.eval(
        prepared[None],
        diameter=None,
        flow_threshold=None,
        channels=[0, 0],
    )

    # Measure every labeled region, using the input block for intensities.
    regions = skimage.measure.regionprops(masks, intensity_image=img)

    features_arr = np.empty((1, 1), dtype=object)
    features_arr[0, 0] = {"rprops": regions, "offset": offset}

    return features_arr

In [None]:
# Same feature extraction as before, but `masked_rprops` also receives the
# matching block of `da_mask` and skips blocks whose mask is empty.
da_rprops = da.map_blocks(
    masked_rprops,
    da_arr,
    da_mask,
    cellpose_model,
    chunks=(1, 1),
    drop_axis=(2, ),
    dtype=object,
    meta=np.empty(0, dtype=object)
)

In [None]:
da_rprops

In [None]:
from dask.diagnostics import ProgressBar

In [None]:
# Compute only the element at position (20, 20) and show its first region.
da_rprops[20, 20].compute()["rprops"][0]

In [None]:
# Compute the features for the whole array, showing a progress bar.
with ProgressBar():
    rprops_arr = da_rprops.compute()

In [None]:
rprops_arr.shape

In [None]:
# Perimeter of the first region stored at position (15, 15).
rprops_arr[15, 15]["rprops"][0].perimeter

---
## 5.2 Segment only masked regions of the image

In [None]:
def cellpose_masked_segment(img, mask, cellpose_model, block_info=None):
    """Segment one image block with Cellpose, but only where the mask is set.

    Parameters
    ----------
    img
        Image block with its channels on axis 2.
    mask
        Mask block matching ``img``; the block is segmented only when its sum
        is non-zero.
    cellpose_model
        Object whose ``eval`` method performs the segmentation.
    block_info
        Block description; ``block_info[None]`` (the output block) supplies
        the shape and dtype of the array returned for masked-out blocks.

    Returns
    -------
    The labels from ``cellpose_model.eval`` for a masked-in block, otherwise
    an all-zero array with the output block's shape and dtype.
    """
    # Nothing selected in this block: return zeros without running the model.
    if not mask.sum():
        output_info = block_info[None]
        return np.zeros(output_info["chunk-shape"], dtype=output_info["dtype"])

    # Cellpose pre-processing: `convert_image` followed by `normalize_img`,
    # both told that the channels are on axis 2.
    prepared = transforms.normalize_img(
        transforms.convert_image(img, channel_axis=2, channels=[0, 0]),
        invert=False,
        axis=2,
    )

    masks, _, _ = cellpose_model.eval(
        prepared[None],
        diameter=None,
        flow_threshold=None,
        channels=[0, 0],
    )
    return masks

In [None]:
# Lazy segmentation of the whole image in which each chunk of `da_arr` is
# paired with its element of `da_mask`; masked-out chunks become zeros.
da_labels = da.map_blocks(
    cellpose_masked_segment,
    da_arr,
    da_mask,
    cellpose_model,
    drop_axis=(2, ),
    dtype=np.int32,
    meta=np.empty(0, dtype=np.int32)
)

---
## 5.3 Store the segmentation as a Zarr file on disk

In [None]:
# Compute the masked segmentation and write it to component "0" of
# "test_image_labels.zarr", showing a progress bar. Existing data at that
# location is overwritten, and the output is compressed with Blosc at level 9.
# NOTE(review): `write_empty_chunks=False` is forwarded to Zarr; presumably it
# avoids storing chunks that contain only the fill value — confirm.
with ProgressBar():
    da_labels.to_zarr(
        "test_image_labels.zarr",
        component="0",
        write_empty_chunks=False,
        compressor=zarr.Blosc(clevel=9),
        overwrite=True
    )

---
# 6. Visualize the segmentation results

In [None]:
# @title 6.1 Downsample the labels array to have a *pyramid* version (for easy visualization)

# Build pyramid levels "1" to "5": each level is read back from the previously
# written one and subsampled by a factor of 2 along both axes.
with ProgressBar():
    for s in range(1, 6):
        # Source level: the component written just before this one.
        da_labels = da.from_zarr("test_image_labels.zarr", component=str(s - 1))

        # Keep every second pixel; plain striding leaves label values intact.
        da_labels_ds = da_labels[::2, ::2]
        # Rechunk with Dask's default settings for the smaller array.
        da_labels_ds = da_labels_ds.rechunk()

        # Store this level next to the others, with the same write options.
        da_labels_ds.to_zarr(
            "test_image_labels.zarr",
            component=str(s),
            write_empty_chunks=False,
            compressor=zarr.Blosc(clevel=9),
            overwrite=True
        )

# Attach OME-NGFF "multiscales" metadata to the root of the label store so
# that the pyramid levels "0" to "5" can be discovered as one multiscale image.
z_labels = zarr.open("test_image_labels.zarr", mode="a")

# Pixel size of the full-resolution level ("0"), in millimeters.
base_pixel_size_mm = 4.942E-4

z_labels.attrs["multiscales"] = [
    {
        # The label arrays are two-dimensional, hence exactly two axes.
        "axes" : [
            {
                "unit" : "millimeter",
                "name" : "y",
                "type" : "space"
            },
            {
                "unit" : "millimeter",
                "name" : "x",
                "type" : "space"
            }
        ],
        "name" : "Cellpose labels",
        "datasets" : [
            {
                "path" : str(s),
                "coordinateTransformations": [
                    {
                        # One scale value per declared axis (y, x). Every
                        # level keeps each second pixel of the previous one,
                        # so the pixel size doubles from level to level.
                        "scale" : [base_pixel_size_mm * (2**s), base_pixel_size_mm * (2**s)],
                        "type" : "scale"
                    }
                ]
            }
            for s in range(6)
        ],
        # "axes" and per-dataset "coordinateTransformations" are part of the
        # OME-NGFF 0.4 metadata layout.
        "version" : "0.4"
    }
]

In [None]:
z_labels = zarr.open("test_image_labels.zarr", mode="r")

In [None]:
z_grp = zarr.open("test_image.zarr", mode="r")