# Scaling-up Deep Learning Inference to Large-Scale Bioimage Data (part 2)

## Contact info:
- Fernando Cervantes
- Systems Analyst in JAX's Research IT
- email: fernando.cervantes@jax.org

# 1. [Optional] Convert a TIFF into Zarr format with `bioformats2raw`

Open a terminal (e.g. from the Launcher) and use the following command:

```
bioformats2raw CMU-1_Crop.ome.tif CMU-1_Crop.ome.zarr --use-existing-resolutions -p
```

Alternatively, download the pre-converted image from [here](https://drive.google.com/file/d/1BmNxOrO3vOFPR-PCnV00DYgFsD1sDu47/view?usp=sharing).

---
# 2. Compute on Zarr arrays

## 2.1 Set up a Dask cluster

- [ ] Use Jupyter's Dask extension to start a distributed cluster
- [ ] Click the "+ New" button at the bottom of the plugin

![image](dask_extension.png)

- [ ] Click the "<>" button to inject the code needed to connect with this cluster

![image](dask_extension_ready.png)

In [None]:
import zarr
import dask
import dask.array as da
import numpy as np
import matplotlib.pyplot as plt
import skimage

Change this to the actual path where the image is stored:

In [None]:
input_path = r"C:\Users\Public\Documents\WSI_example\CMU-1_Crop.ome.zarr"

---
# 3. [Example] Segmentation of nuclei in WSI (Cellpose)

- [ ] Load the `cyto3` pre-trained model from the Cellpose library

In [None]:
from cellpose import models, transforms
import torch

gpu = torch.cuda.is_available()
model_type = "cyto3"

cellpose_model = models.CellposeModel(gpu=gpu, model_type=model_type)

---
## 3.1 Compute as numpy array (`.compute()`)

In [None]:
img_t = transforms.convert_image(da_sel.compute(), channel_axis=2, channels=[0, 0])
img_t = transforms.normalize_img(img_t, invert=False, axis=2)

labels, _, _ = cellpose_model.eval(img_t[None, ...], diameter=None, flow_threshold=None, channels=[0, 0])

In [None]:
plt.imshow(da_sel)
plt.imshow(skimage.color.label2rgb(labels), alpha=0.5)

---
## 3.2 Compute the segmentation lazily with Dask (`delayed`)

- [ ] Define the inference pipeline as a function that can be applied to an image chunk
- [ ] Use `dask.delayed` to convert it into a lazy function
- [ ] Create a `dask.array` from the delayed output of the lazy function with `dask.array.from_delayed` (`da.from_delayed`)
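The three steps above can be sketched as follows. To keep the snippet runnable without Cellpose or a GPU, a simple intensity threshold (`segment_chunk`, a hypothetical stand-in) replaces the real inference pipeline, and a random array replaces `da_sel`; only the `dask.delayed` and `da.from_delayed` calls are the point of the example.

```python
import numpy as np
import dask
import dask.array as da


def segment_chunk(img_chunk):
    """Stand-in for the Cellpose pipeline: threshold the mean intensity."""
    gray = img_chunk.mean(axis=-1)
    return (gray > 127).astype(np.int32)


# Synthetic (Y, X, C) region standing in for `da_sel`.
rng = np.random.default_rng(0)
da_sel = da.from_array(
    rng.integers(0, 256, size=(64, 64, 3), dtype=np.uint8), chunks=(64, 64, 3)
)

# 1. Wrap the pipeline so that calling it only records a task.
lazy_segment = dask.delayed(segment_chunk)

# 2. Calling the lazy function returns a Delayed object, not labels.
lazy_labels = lazy_segment(da_sel)

# 3. Tell Dask the shape and dtype the result will have once computed.
da_labels = da.from_delayed(lazy_labels, shape=da_sel.shape[:2], dtype=np.int32)

# Nothing has run so far; `.compute()` triggers the segmentation.
labels = da_labels.compute()
```

Because `da.from_delayed` cannot inspect the function, the `shape` and `dtype` arguments must match what the function really returns.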

---
## 3.3 Distribute computation with `map_blocks`

- [ ] Use the `dask.array.map_blocks` (`da.map_blocks`) function to apply the inference pipeline to the whole image

- [ ] Create an overlay of the labels generated by the inference pipeline on top of the image pixels. Use `skimage.color.label2rgb`; note that this requires calling `.compute()` first

- [ ] Use the `.blocks` property of `dask.array`s to access the pixels with a chunk/block-based coordinate system
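A minimal sketch of the `map_blocks` step, again using a hypothetical threshold function (`segment_chunk`) in place of Cellpose and a synthetic image in place of the WSI. It assumes the image is chunked so that each block holds all channels, which is why `drop_axis=2` is passed: the function consumes a (Y, X, C) block and returns a (Y, X) block.

```python
import numpy as np
import dask.array as da


def segment_chunk(img_chunk):
    """Stand-in for the Cellpose pipeline: threshold the mean intensity."""
    return (img_chunk.mean(axis=-1) > 127).astype(np.int32)


rng = np.random.default_rng(0)
da_img = da.from_array(
    rng.integers(0, 256, size=(128, 128, 3), dtype=np.uint8), chunks=(64, 64, 3)
)

# Apply the function to every block; the channel axis disappears in the output.
da_labels = da.map_blocks(
    segment_chunk,
    da_img,
    drop_axis=2,
    dtype=np.int32,
    meta=np.empty((0, 0), dtype=np.int32),
)

# `.blocks` indexes by chunk position instead of pixel position.
first_block = da_labels.blocks[0, 0].compute()
print(da_labels.numblocks, first_block.shape)
```

Only the requested block is computed by `.blocks[0, 0].compute()`, which makes it a cheap way to inspect results before running the whole image.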

---
## 3.4 Debug `map_blocks` computations

- [ ] Show the Log Console and change the log level to "Info"
- [ ] Import the `print` replacement from `dask.distributed`. Import it as `dask_print` so that it does not shadow the built-in `print` function
- [ ] Test what happens when the `drop_axis` argument is omitted from `dask.array.map_blocks` (`da.map_blocks`)
- [ ] Add a `dask_print` statement to the `cellpose_segment` function to investigate the problem

In [None]:
from dask.distributed import print as dask_print
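The sketch below illustrates the kind of mismatch this section asks you to investigate. A hypothetical threshold function (`segment_chunk`) stands in for `cellpose_segment`, and the `try`/`except` fallback is only there so the snippet also runs where `distributed` is not installed. Without `drop_axis`, Dask assumes the output keeps the channel axis, while the function actually returns 2D blocks; printing the shapes from inside the function makes the disagreement visible.

```python
import numpy as np
import dask.array as da

try:
    from dask.distributed import print as dask_print
except ImportError:  # fallback for environments without `distributed`
    dask_print = print


def segment_chunk(img_chunk):
    """Stand-in for `cellpose_segment` that reports the shapes it sees."""
    labels = (img_chunk.mean(axis=-1) > 127).astype(np.int32)
    dask_print("input block:", img_chunk.shape, "-> output block:", labels.shape)
    return labels


da_img = da.zeros((128, 128, 3), dtype=np.uint8, chunks=(64, 64, 3))

# Without `drop_axis`, the declared output still has the channel axis.
da_bad = da.map_blocks(
    segment_chunk, da_img, dtype=np.int32, meta=np.empty((0, 0, 0), dtype=np.int32)
)

# With `drop_axis=2`, the declared output matches what the function returns.
da_ok = da.map_blocks(
    segment_chunk,
    da_img,
    drop_axis=2,
    dtype=np.int32,
    meta=np.empty((0, 0), dtype=np.int32),
)

print("declared shape without drop_axis:", da_bad.shape)
print("declared shape with drop_axis:", da_ok.shape)

# Computing one block triggers the `dask_print` call inside the function.
block = da_ok.blocks[0, 0].compute()
```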

---
## 3.5 Return arrays with different shape with `map_blocks`

- [ ] Add a post-processing step to convert the outputs from the segmentation pipeline into a set of region properties (`skimage.measure.regionprops`)
    - Note that now the output of this function is a $1\times1$ array
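One way to return a $1\times1$ array per chunk is sketched here. It assumes each chunk's region properties are reduced to plain `(label, area)` tuples stored in an object array; the function name `chunk_regionprops` and the choice of properties are illustrative. The `chunks=(1, 1)` argument tells Dask that every input block maps to a single output element.

```python
import numpy as np
import dask.array as da
from skimage import measure


def chunk_regionprops(labels_chunk):
    """Summarize one labeled chunk as a 1x1 object array of (label, area)."""
    props = measure.regionprops(labels_chunk)
    out = np.empty((1, 1), dtype=object)
    out[0, 0] = [(int(p.label), int(p.area)) for p in props]
    return out


# Synthetic labels: one object in the top-left chunk, one in the bottom-right.
labels = np.zeros((128, 128), dtype=np.int32)
labels[10:20, 10:20] = 1
labels[80:100, 90:120] = 2
da_labels = da.from_array(labels, chunks=(64, 64))

da_props = da.map_blocks(
    chunk_regionprops,
    da_labels,
    chunks=(1, 1),  # each block collapses to a single element
    dtype=object,
    meta=np.empty((0, 0), dtype=object),
)

props = da_props.compute()
print(props.shape)
```

Objects that straddle a chunk border would be measured once per chunk in this sketch, so the per-chunk results are not a substitute for a whole-image measurement.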

---
# 4. [Example] Compute on masked chunks

## 4.1 Compute a mask from a low-resolution level of the input *pyramid*

- [ ] Compute a low-resolution mask using image processing
  - Use the downsampled image at level $2$ from the input .zarr file ("0/2")
  - Convert the color image into grayscale
  - Smooth the image and apply a fixed threshold on all chunks
  - Downscale the mask with an aggregation function (e.g. `.sum()`, `.mean()`) so that a single mask pixel represents a $512\times512$ pixel region of the full-resolution image
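The steps above could look roughly like the following. The sketch assumes each pyramid level halves the resolution, so a $512\times512$ full-resolution region covers $128\times128$ pixels at level 2. The synthetic image, the Gaussian `sigma`, the intensity threshold of 200, and the 5% foreground cutoff are all illustrative values to tune on real data.

```python
import numpy as np
import dask.array as da
from scipy import ndimage

# Synthetic level-2 image (Y, X, C): bright background, one dark tissue patch.
level2 = np.full((256, 256, 3), 240, dtype=np.uint8)
level2[16:112, 16:112] = 100
da_lvl2 = da.from_array(level2, chunks=(128, 128, 3))

# Grayscale as the mean over channels.
da_gray = da_lvl2.mean(axis=-1)

# Smooth with a Gaussian; `map_overlap` shares a border between chunks
# so the filter does not produce seams at chunk edges.
da_smooth = da_gray.map_overlap(
    ndimage.gaussian_filter, sigma=2, depth=8, boundary="reflect", dtype=da_gray.dtype
)

# Fixed threshold: tissue is darker than the background.
da_fg = da_smooth < 200

# Aggregate 128x128 blocks into one mask pixel (fraction of foreground).
da_frac = da.coarsen(np.mean, da_fg.astype(np.float32), {0: 128, 1: 128})
da_mask = da_frac > 0.05

mask = da_mask.compute()
print(mask)
```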

---
# 5. [Exercise] Reduce computations on map_blocks

## 5.1 Apply the deep learning segmentation pipeline only on masked regions of the image

- [ ] Add a verification step to determine whether the current image chunk should be processed or not
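A possible form for the verification step is sketched below. It assumes the image chunks line up with mask pixels (one mask pixel per chunk) and uses the `block_info` argument that `map_blocks` passes to functions that declare it. `segment_if_masked` and the threshold-based `segment_chunk` are hypothetical names standing in for the real pipeline.

```python
import numpy as np
import dask.array as da


def segment_chunk(img_chunk):
    """Stand-in for the Cellpose pipeline: threshold the mean intensity."""
    return (img_chunk.mean(axis=-1) > 127).astype(np.int32)


def segment_if_masked(img_chunk, mask, block_info=None):
    """Run the pipeline only when the mask marks this chunk as foreground."""
    row, col = block_info[0]["chunk-location"][:2]
    if not mask[row, col]:
        # Skip the expensive model call and return an empty label chunk.
        return np.zeros(img_chunk.shape[:2], dtype=np.int32)
    return segment_chunk(img_chunk)


# One mask pixel per image chunk (2x2 chunks of 64x64 pixels here).
mask = np.array([[True, False], [False, True]])
da_img = da.full((128, 128, 3), 200, dtype=np.uint8, chunks=(64, 64, 3))

da_labels = da.map_blocks(
    segment_if_masked,
    da_img,
    mask=mask,
    drop_axis=2,
    dtype=np.int32,
    meta=np.empty((0, 0), dtype=np.int32),
)

labels = da_labels.compute()
```

Passing the mask as a small NumPy array keeps the lookup cheap; for very large masks, a Dask array aligned with the image chunks may be a better fit.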

---
## 5.2 Store the segmentation as a Zarr file on disk

- [ ] Use the `.to_zarr` method of `dask.array`s to store the array's content into a **.zarr** file.
    - Note: Use the argument `write_empty_chunks=False` to avoid creating files for empty chunks on disk
- [ ] Import `ProgressBar` and use it to show the progress of the segmentation process on the whole image

In [None]:
from dask.diagnostics import ProgressBar

---
## 5.3 Compute region properties from the stored labels

- [ ] Use the labels from disk instead of computing them again
- [ ] Apply a rule to avoid computing the region properties on chunks without any labels
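The rule for skipping empty chunks might be as simple as an early return, as in this sketch. An in-memory array stands in for the labels that would normally be opened with `da.from_zarr`, and `chunk_props_if_labeled` is a hypothetical helper that summarizes each chunk as `(label, area)` tuples.

```python
import numpy as np
import dask.array as da
from skimage import measure


def chunk_props_if_labeled(labels_chunk):
    """Return a 1x1 object array; skip `regionprops` on chunks with no labels."""
    out = np.empty((1, 1), dtype=object)
    if not labels_chunk.any():
        out[0, 0] = []
        return out
    out[0, 0] = [
        (int(p.label), int(p.area)) for p in measure.regionprops(labels_chunk)
    ]
    return out


# Stand-in for the stored labels (normally loaded with `da.from_zarr`).
labels = np.zeros((128, 128), dtype=np.int32)
labels[5:15, 5:25] = 1
labels[70:80, 70:80] = 2
da_labels = da.from_array(labels, chunks=(64, 64))

da_props = da.map_blocks(
    chunk_props_if_labeled,
    da_labels,
    chunks=(1, 1),
    dtype=object,
    meta=np.empty((0, 0), dtype=object),
)

# Flatten the per-chunk lists into a single list of measurements.
all_props = [item for cell in da_props.compute().ravel() for item in cell]
print(all_props)
```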

---
# 6. [Optional] Convert a regular Zarr into an OME-Zarr

## 6.1 Downsample the labels array to have a *pyramid* version

- [ ] Use the labels that were stored as .zarr to create a downsampled version of the whole labeled image

In [None]:
with ProgressBar():
    for s in range(1, 6):
        da_labels = da.from_zarr("CMU-1_Crop_labels_cellpose_cyto3.zarr", component=str(s - 1))

        da_labels_ds = da_labels[::2, ::2]
        da_labels_ds = da_labels_ds.rechunk()

        da_labels_ds.to_zarr(
            "CMU-1_Crop_labels_cellpose_cyto3.zarr",
            component=str(s),
            write_empty_chunks=False,
            compressor=zarr.Blosc(clevel=9),
            overwrite=True
        )

In [None]:
z_labels = zarr.open("CMU-1_Crop_labels_cellpose_cyto3.zarr", mode="a")

- [ ] Add metadata to the `.zarr` to comply with the *OME-Zarr* standard. This will enable *OME-Zarr* readers to open our `.zarr` file

In [None]:
z_labels.attrs["multiscales"] = [
    {
        "axes" : [
            {
                "unit" : "millimeter",
                "name" : "y",
                "type" : "space"
            },
            {
                "unit" : "millimeter",
                "name" : "x",
                "type" : "space"
            }
        ],
        "name" : "Cellpose labels",
        "datasets" : [
            {
                "path" : str(s),
                "coordinateTransformations": [
                    {
                        # Each level is downsampled by 2, so the pixel size doubles per level.
                        "scale" : [4.942E-4 * (2**s), 4.942E-4 * (2**s)],
                        "type" : "scale"
                    }
                ]
            }
            for s in range(6)
        ],
        "version" : "0.4"
    }
]

In [None]:
z_labels = zarr.open("CMU-1_Crop_labels_cellpose_cyto3.zarr", mode="r")

- [ ] Visualize the segmentation labels overlaid on top of the input image