# Welcome to the Zarr and Dask for large-scale imaging workshop
## Fernando Cervantes
### Systems Analyst in JAX's Research IT
### email: fernando.cervantes@jax.org

## Outcomes for today's session:
- Learn to use the Dask library with Zarr image data
- Implement and apply image analysis pipelines with Dask
- Save image analysis outputs as Zarr files


---
# Overview of the Dask package

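Dask is lazy! Operations on a Dask array are recorded in a task graph and only executed when a result is requested, for example with `.compute()`.

Find out more in the [Dask array documentation](https://docs.dask.org/en/stable/array.html).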

![image](https://docs.dask.org/en/stable/_images/dask-array.svg)

# 1. Manipulate Dask arrays

## 1.1 Create Dask arrays

- [ ] Create a $10\times10$ dask array of type `int16` that is split into chunks of size $5\times5$.

In [None]:
import dask
import dask.array as da
import numpy as np

In [None]:
d1 = da.zeros((10, 10), chunks=(5, 5), dtype=np.int16)

In [None]:
d1

- [ ] Modify the content of the dask array using slice selection.

In [None]:
d1[:5, :5] = 1
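
The slice assignment above is itself lazy. A minimal sketch, assuming a Dask version that supports item assignment on arrays; the name `demo` is illustrative and not part of the notebook:

```python
import dask.array as da
import numpy as np

# Illustrative array with the same layout as d1 in the notebook.
demo = da.zeros((10, 10), chunks=(5, 5), dtype=np.int16)

# The assignment is only recorded in the task graph; nothing is evaluated yet.
demo[:5, :5] = 1

# Evaluating the graph returns a NumPy array with the top-left 5x5 block set to 1.
demo_np = demo.compute()
print(demo_np.sum())  # 25
```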

---
## 1.3 Execute the computation graph

- [ ] Visualize the information of the dask array.

In [None]:
d1

- [ ] Use the `.compute()` method of the dask array to trigger the actual computation of the instructions.

In [None]:
d1.compute()

- [ ] Add more steps to the computation graph.

In [None]:
d1 = d1 + 1

In [None]:
d1

In [None]:
d2 = da.ones((10, 10), chunks=(3, 3))

In [None]:
d3 = d1 + d2

In [None]:
d3

- [ ] Inspect the chunk sizes of the resulting dask array.

In [None]:
d3.chunks
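
Adding arrays whose chunks do not line up still works, but Dask has to reconcile the two layouts first. The sketch below reproduces the situation with illustrative names (`a`, `b`, `total`) and prints the layouts so they can be compared; the exact layout of the result is left for you to inspect rather than stated here.

```python
import dask.array as da
import numpy as np

a = da.zeros((10, 10), chunks=(5, 5), dtype=np.int16)
b = da.ones((10, 10), chunks=(3, 3))

total = a + b

print(a.chunks)         # ((5, 5), (5, 5))
print(b.chunks)         # ((3, 3, 3, 1), (3, 3, 3, 1))
print(total.chunks)     # layout Dask chose to make both inputs compatible
print(total.numblocks)  # number of blocks along each axis
```

Irregular layouts like this mean more, smaller tasks, which is the motivation for rechunking.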

---
## 1.4 Rechunk Dask arrays

- [ ] Use the `.rechunk(...)` method of the dask array to change the size of each of its chunks.

In [None]:
d3 = d3.rechunk((5, 5))

In [None]:
d3

In [None]:
d3 = d1 + d2.rechunk(d1.chunks)

In [None]:
d3
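
A short sketch of why rechunking one operand before the addition helps: when both inputs share the same chunk layout, the result keeps that layout. The names `a`, `b`, and `aligned` are illustrative.

```python
import dask.array as da
import numpy as np

a = da.zeros((10, 10), chunks=(5, 5), dtype=np.int16)
b = da.ones((10, 10), chunks=(3, 3))

# Align b with a before combining them.
aligned = a + b.rechunk(a.chunks)

print(aligned.chunks)     # ((5, 5), (5, 5))
print(aligned.numblocks)  # (2, 2)
```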

- [ ] Apply some math operations on the dask array using `numpy`.

In [None]:
d3_sum = np.sum(d3)

In [None]:
d3_sum

In [None]:
d3_sum.compute()

In [None]:
d3_cos = np.cos(d3)

In [None]:
d3_cos

In [None]:
d3_cos.compute()
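
NumPy functions called on a dask array stay lazy because NumPy dispatches them back to Dask. A minimal sketch, assuming a NumPy version with array-function dispatch enabled (the default in current releases); `x`, `s`, and `c` are illustrative names.

```python
import dask.array as da
import numpy as np

x = da.ones((10, 10), chunks=(5, 5))

s = np.sum(x)  # reduction, still a dask array
c = np.cos(x)  # element-wise ufunc, still a dask array

print(type(s).__name__, type(c).__name__)  # Array Array
print(s.compute())                         # 100.0
```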

---
## 1.5 Persist vs Compute

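- [ ] Use the `.persist()` method of the dask array to compute the graph built so far and keep the resulting chunks in memory. The result is still a dask array, so further operations can be added on top of it.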

In [None]:
d3 = d1 + d2.rechunk((5, 5))

In [None]:
d3

In [None]:
d3 = d3.persist()

In [None]:
d3

In [None]:
d3 = d3 + 1

In [None]:
d3
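
The difference between the two methods can be sketched as follows: `.compute()` returns a NumPy array, while `.persist()` returns a dask array whose chunks are already evaluated. Names are illustrative.

```python
import dask.array as da
import numpy as np

x = da.ones((10, 10), chunks=(5, 5)) + 1

persisted = x.persist()  # dask array backed by in-memory chunks
computed = x.compute()   # plain NumPy array

# More lazy steps can be stacked on the persisted array.
later = persisted + 1

print(type(persisted).__name__)  # Array
print(type(computed).__name__)   # ndarray
```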

---

## 1.6 Delayed operations

- [ ] Create a delayed function (decorated with `@dask.delayed`) that can be applied lazily

In [None]:
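@dask.delayed
def grid_x(height, width, offset=0):
    # Each row holds the x coordinates offset, ..., offset + width - 1.
    # float64 matches the dtype declared in da.from_delayed.
    x = np.arange(offset, offset + width, dtype=np.float64)
    return np.tile(x, (height, 1))

@dask.delayed
def grid_y(height, width, offset=0):
    # Each column holds the y coordinates offset, ..., offset + height - 1.
    y = np.arange(offset, offset + height, dtype=np.float64)
    return np.tile(y[:, None], (1, width))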

In [None]:
X = grid_x(500, 500)
Y = grid_y(500, 500)

In [None]:
da_X = da.from_delayed(X, (500, 500), dtype=np.float64)

In [None]:
da_Y = da.from_delayed(Y, (500, 500), dtype=np.float64)

In [None]:
da_Z = da_X ** 2 + da_Y ** 2

In [None]:
da_Z

In [None]:
import matplotlib.pyplot as plt

In [None]:
plt.imshow(da_Z)
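
To see that a delayed function really is deferred, the sketch below records each call in a list. The function `make_ramp` and the list `calls` are hypothetical helpers for illustration, and the synchronous scheduler is used only so the bookkeeping happens in the current process.

```python
import dask
import dask.array as da
import numpy as np

calls = []  # records every real execution of make_ramp

@dask.delayed
def make_ramp(n):
    calls.append(n)
    return np.arange(n, dtype=np.float64)

# Calling the delayed function and wrapping it only builds the graph.
lazy_value = make_ramp(4)
ramp = da.from_delayed(lazy_value, shape=(4,), dtype=np.float64)
calls_before = len(calls)

# The function body runs when the result is requested.
ramp_np = ramp.compute(scheduler="synchronous")
calls_after = len(calls)

print(calls_before, calls_after)  # 0 1
```

Note that `shape` and `dtype` passed to `da.from_delayed` are taken on trust, so they should match what the function actually returns.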

---
## 1.7 Stack, Concatenate, and Block

In [None]:
da_X_0_0 = da.from_delayed(grid_x(500, 500, 0), (500, 500), dtype=np.float64)

da_Y_0_0 = da.from_delayed(grid_y(500, 500, 0), (500, 500), dtype=np.float64)

da_X_0_1 = da.from_delayed(grid_x(500, 500, 500), (500, 500), dtype=np.float64)

da_Y_1_0 = da.from_delayed(grid_y(500, 500, 500), (500, 500), dtype=np.float64)

In [None]:
da_X_0_0

In [None]:
da_X_0 = da.stack((da_X_0_0, da_X_0_1), axis=1)
da_X_0

In [None]:
da_X_0 = da.concatenate((da_X_0_0, da_X_0_1), axis=1)
da_X_0

In [None]:
da_X = da.block([[da_X_0_0, da_X_0_1],
                 [da_X_0_0, da_X_0_1]])
da_X

In [None]:
da_Y = da.block([[da_Y_0_0, da_Y_0_0],
                 [da_Y_1_0, da_Y_1_0]])
da_Y

In [None]:
da_Z = da_X ** 2 + da_Y ** 2

In [None]:
da_Z

In [None]:
plt.imshow(da_Z)
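
The three functions differ in the shape they produce. A small sketch with illustrative $2\times3$ arrays makes the difference easy to check:

```python
import dask.array as da

a = da.zeros((2, 3), chunks=(2, 3))
b = da.ones((2, 3), chunks=(2, 3))

stacked = da.stack((a, b), axis=1)       # adds a new axis
joined = da.concatenate((a, b), axis=1)  # extends an existing axis
tiled = da.block([[a, b],
                  [a, b]])               # assembles a 2x2 grid of blocks

print(stacked.shape)  # (2, 2, 3)
print(joined.shape)   # (2, 6)
print(tiled.shape)    # (4, 6)
```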

---
# 2. Open Zarr files with Dask

- [ ] Use the `tifffile` library to open a `.svs` image file, treating it as if it were a `Zarr` file (`aszarr=True`).

In [None]:
import zarr
import tifffile

In [None]:
z_grp = tifffile.imread(r"C:\Users\Public\Documents\WSI_example\CMU-1.svs", aszarr=True)
z_grp

- [ ] Pass the `Store` object returned by `tifffile.imread` to the `dask.array.from_zarr` function to open the image as a `dask.array`.

In [None]:
da_arr = da.from_zarr(z_grp, component="0")

In [None]:
da_arr

- [ ] Rechunk the image to have chunks of size $512\times512$

In [None]:
da_arr = da_arr.rechunk((512, 512, 3))

In [None]:
da_arr

- [ ] Extract a window from the image to analyze

In [None]:
offset_y = 16_000
offset_x = 8_000

In [None]:
da_sel = da_arr[offset_y:offset_y + 2048, offset_x:offset_x + 2048]

In [None]:
da_sel
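
Slicing a window is also lazy, and only the chunks that overlap the window end up in the graph of the selection. The sketch below uses a placeholder array of zeros (not the slide) with made-up offsets to show how the window inherits the chunk boundaries of the source image.

```python
import dask.array as da
import numpy as np

# Placeholder "image"; it is never computed, so no memory is allocated for it.
image = da.zeros((4096, 4096, 3), chunks=(512, 512, 3), dtype=np.uint8)

oy, ox = 1000, 500
window = image[oy:oy + 2048, ox:ox + 2048]

print(window.shape)      # (2048, 2048, 3)
print(window.numblocks)  # (5, 5, 1)
print(window.chunks[0])  # partial chunks appear at both ends of the window
```

Choosing offsets that are multiples of the chunk size avoids the partial chunks at the edges.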

In [None]:
import matplotlib.pyplot as plt

ℹ Dask arrays already work with `matplotlib.pyplot.imshow` without calling `.compute()`

In [None]:
plt.imshow(da_sel)

---
# 3. [Example] Perform image processing on Dask arrays

- [ ] Convert an image region from RGB color to grayscale.

In [None]:
offset_y = 16_000
offset_x = 8_000

In [None]:
da_sel = da_arr[offset_y:offset_y + 2048, offset_x:offset_x + 2048]

In [None]:
da_sel

In [None]:
from skimage.color import rgb2gray

@dask.delayed
def color2gray(img_chunk):
    return rgb2gray(img_chunk)

In [None]:
da_gray = da.from_delayed(color2gray(da_sel), shape=da_sel.shape[:2], dtype=np.float64)  # rgb2gray returns floats in [0, 1]

In [None]:
da_gray

In [None]:
plt.imshow(da_gray, cmap="gray")
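
Wrapping the conversion in `dask.delayed` processes the whole region as a single task. Because the conversion is per-pixel, a chunk-wise alternative with `map_blocks` is possible as long as the channel axis stays within one chunk. The sketch below is not the notebook's method: `to_gray` is a hypothetical NumPy stand-in for `rgb2gray` (using the luminance weights given in the scikit-image documentation), applied to a synthetic white image.

```python
import dask.array as da
import numpy as np

# Luminance weights documented for skimage.color.rgb2gray.
WEIGHTS = np.array([0.2125, 0.7154, 0.0721])

def to_gray(block):
    """Hypothetical stand-in for rgb2gray: uint8 RGB block -> float64 in [0, 1]."""
    return (block.astype(np.float64) / 255.0) @ WEIGHTS

# Synthetic all-white RGB image split into four spatial chunks.
rgb = da.full((64, 64, 3), 255, chunks=(32, 32, 3), dtype=np.uint8)

# drop_axis=2 tells Dask that the channel axis disappears in the output.
gray = rgb.map_blocks(
    to_gray,
    drop_axis=2,
    dtype=np.float64,
    meta=np.array((), dtype=np.float64),
)

print(gray.shape, gray.dtype)  # (64, 64) float64
print(gray.chunks)             # ((32, 32), (32, 32))
```

Unlike the delayed version, the result keeps the spatial chunking of the input, so later steps can run in parallel.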

---
# 4. [Exercise] Perform image analysis on Dask arrays

- [ ] Implement an object segmentation operation to detect nuclei pixels in a $2048\times2048$-pixel region.
    - [ ] Convert the image region from RGB to grayscale
    - [ ] Reduce noise in the image region with a Gaussian filter
    - [ ] Use a thresholding algorithm to discriminate between structures based on their pixel intensity

ℹ Dask arrays already work with `skimage` functions without calling `.compute()`

In [None]:
offset_y = 16_000
offset_x = 8_000

In [None]:
da_sel = da_arr[offset_y:offset_y + 2048, offset_x:offset_x + 2048]

In [None]:
da_sel

In [None]:
@dask.delayed
def color2gray(img_chunk):
    return rgb2gray(img_chunk)

In [None]:
da_gray = da.from_delayed(color2gray(da_sel), shape=da_sel.shape[:2], dtype=np.float64)  # rgb2gray returns floats in [0, 1]

In [None]:
from dask_image import ndfilters

In [None]:
da_gauss = ndfilters.gaussian(da_gray, 5.0, order=0, mode='reflect', cval=0.0, truncate=4.0)

In [None]:
da_gauss
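
Filters need pixels from neighboring chunks, which is why a chunk-aware implementation is used here. As a rough illustration of the idea (not a description of how `dask_image` is implemented), the sketch below applies SciPy's Gaussian filter chunk by chunk with `map_overlap` on random data and compares the interior against filtering the full array. It assumes SciPy is installed; all names are illustrative.

```python
import dask.array as da
import numpy as np
from scipy import ndimage as ndi

sigma = 2.0
truncate = 4.0
# Kernel radius used by scipy.ndimage.gaussian_filter for these settings.
depth = int(truncate * sigma + 0.5)

rng = np.random.default_rng(0)
img_np = rng.random((128, 128))
img = da.from_array(img_np, chunks=(64, 64))

# Each chunk is extended by `depth` pixels from its neighbors, filtered,
# and trimmed back to its original size.
smoothed = img.map_overlap(
    ndi.gaussian_filter,
    depth=depth,
    boundary="reflect",
    dtype=img.dtype,
    sigma=sigma,
    truncate=truncate,
)

reference = ndi.gaussian_filter(img_np, sigma=sigma, truncate=truncate)
inner = (slice(depth, -depth), slice(depth, -depth))
print(np.allclose(smoothed.compute()[inner], reference[inner]))
```

With an overlap smaller than the kernel radius, seams would appear at chunk borders.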

In [None]:
from skimage.filters import threshold_multiotsu

@dask.delayed
def thresholding(img_chunk):
    thresh_levels = threshold_multiotsu(img_chunk, classes=3)
    thresholded_chunk = img_chunk < thresh_levels[0]
    return thresholded_chunk

In [None]:
arr_nuclei = da.from_delayed(thresholding(da_gauss), shape=da_gauss.shape, dtype=bool)
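
The threshold levels depend on the intensity distribution of the whole region, which is why they are computed in one delayed call above. Once a level is known, though, the comparison itself can stay chunked. In the sketch below the mean intensity is used purely as a stand-in for the multi-Otsu level, and the image is random data.

```python
import dask.array as da
import numpy as np

rng = np.random.default_rng(0)
img_np = rng.random((64, 64))
img = da.from_array(img_np, chunks=(32, 32))

# Stand-in threshold (NOT multi-Otsu): a single global value computed up front.
level = float(img.mean().compute())

# The comparison is lazy and produces a boolean dask array, chunk by chunk.
mask = img < level

print(mask.dtype, mask.chunks)  # bool ((32, 32), (32, 32))
```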

- [ ] Visualize the results using `Matplotlib`

In [None]:
import matplotlib.pyplot as plt

In [None]:
plt.imshow(da_sel)
plt.imshow(arr_nuclei, cmap="gray", alpha=0.5)

In [None]:
plt.imshow(da_sel[1500:2000, 1000:1500])
plt.imshow(arr_nuclei[1500:2000, 1000:1500], cmap="gray", alpha=0.5)