<center>
    
# 4 - Combining `Dask.array` and `Numba`

   
</center>

### Let's see some useful ways to combine the power of `Dask` and `Numba`!
Some interesting links: 
- https://examples.dask.org/applications/stencils-with-numba.html
- https://developer.nvidia.com/blog/accelerated-portfolio-construction-with-numba-and-dask-in-python/

In [None]:
import os

import dask.array as da
import dask_image.ndfilters
import numba
import numpy as np

from utils import show_images

## Let's load some big data! 
We are going to **lazy-load** some images acquired by the [Hubble Space Telescope](https://www.wikiwand.com/en/Hubble_Space_Telescope). 

This image is known as the ***Hubble Ultra Deep Field***. It captures a view of nearly 10,000 galaxies and is the deepest visible-light image of the cosmos. 

In [None]:
hubble_image = da.from_zarr(os.path.join("imgs", "hubble_enh.zarr"))
print(f"The image has {hubble_image[..., 0].size / 1e6 : .0f} MPix, and takes:")
hubble_image

In [None]:
# Plot a few chunks
show_images(images=[hubble_image.blocks[0, 20], hubble_image.blocks[0,23], hubble_image.blocks[4,19]])

Imagine that you are an astronomer and want to find all galaxies in the early stages of their formation (red, circular and small).

Here we will use Dask and Numba to do so in an efficient manner:

## Application: The Structure Tensor    

The [Structure Tensor](https://www.wikiwand.com/en/Structure_tensor) is a powerful tool for analyzing the structure of images and extracting useful information from them, often used in
- image segmentation, 
- object recognition, and
- optical flow estimation

It makes a great showcase for *Numba* and *Dask* because it relies on very **simple** operations (convolutions) and is **highly parallelizable**.

<div style="text-align: right"><a href="https://www.crisluengo.net/archives/1132/">This</a> is a great post on the structure tensor, by Cris Luengo.</div>

The structure tensor is the outer product of the gradient vector with itself, locally averaged. 

$$\mathbf{S} = \overline{(\nabla \mathbf{I})(\nabla \mathbf{I})^{\top}} = \overline{\begin{pmatrix} \mathbf{I}_{x} \\
\mathbf{I}_{y}\end{pmatrix}  \begin{pmatrix} \mathbf{I}_{x} &
\mathbf{I}_{y}\end{pmatrix}} = \begin{pmatrix}
\overline{\mathbf{I}_{x}\mathbf{I}_{x}} & \overline{\mathbf{I}_{x}\mathbf{I}_{y}} \\
\overline{\mathbf{I}_{x}\mathbf{I}_{y}} & \overline{\mathbf{I}_{y}\mathbf{I}_{y}} 
\end{pmatrix}$$

where $\mathbf{I}_{x}$ indicates the partial derivative along axis $x$, and the overlines $\overline{\cdot}$ indicate local averaging, usually by means of a Gaussian kernel. 

*Note that the structure tensor is composed of **first-order** partial derivatives.*

## 1) The image gradient
To compute the structure tensor, we first calculate the gradient of the image at each point, which gives us a vector that describes the direction and strength of the local intensity change. 

Because an image usually has two dimensions, the gradient stacks two partial derivatives:

$$ \nabla I(x, y) = \begin{pmatrix} 
                            \frac{\partial I(x, y) }{\partial x }\\
                            \frac{\partial I(x, y) }{\partial y }
                    \end{pmatrix} = 
                    \begin{pmatrix} 
                            \mathbf{I}_{x}\\
                            \mathbf{I}_{y}
                    \end{pmatrix}
$$


Let's start with the first-order partial derivatives. 

The derivative of a function is defined as: 

$$ \frac{\partial f(x) }{\partial x } =  \lim_{h \to 0} \frac{f(x+h) - f(x)}{h}$$


On a **discrete grid**, the smallest distance obtainable without interpolation is $h=1$, which yields the [Finite Difference Method](https://www.wikiwand.com/en/Finite_difference_method) approximation of the derivative:

$$ \left(\frac{\partial f(x) }{\partial x }\right)_{FD} \approx  f(x+1) - f(x)$$


In practice, the finite difference operator corresponds to a *convolution* with a linear filter with values `[1, -1]`:

$$ f(x+1) - f(x) = f(x) * [1 \, \, -1]$$
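
As a quick sanity check of the identity above, here is a minimal sketch in plain NumPy. The 1D signal `f` is made up for illustration. Note that `np.convolve` flips the kernel, which is why `[1, -1]` lines up with $f(x+1) - f(x)$.

```python
import numpy as np

# Hypothetical 1D signal, used only for illustration
f = np.array([0.0, 1.0, 4.0, 9.0, 16.0])

# Forward difference f(x+1) - f(x), written with slices
forward_diff = f[1:] - f[:-1]

# The same quantity as a convolution with the kernel [1, -1];
# mode="valid" keeps only positions where the kernel fully overlaps the signal
conv_diff = np.convolve(f, [1.0, -1.0], mode="valid")

print(forward_diff)  # [1. 3. 5. 7.]
print(conv_diff)     # [1. 3. 5. 7.]
```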


Let's use *Numba Stencils* to efficiently perform  `finite_difference`!

In [None]:
@numba.stencil
def _finite_difference_x(images):
    """
    Apply finite differences on the x-axis of a 3D image.
    
    Parameters
    ----------
        images: nd-array (npix_y, npix_x, nchan)
                Stack of images
                
    Returns
    -------
        derivatives: nd-array (npix_y, npix_x, nchan)
                Stack of images
    """
    return images[0, 1, 0] - images[0, 0, 0]


@numba.stencil
def _finite_difference_y(images):
    """
    Apply finite differences on the y-axis of a 3D image.
    
    Parameters
    ----------
        images: nd-array (npix_y, npix_x, nchan)
                Stack of images

    Returns
    -------
        derivatives: nd-array (npix_y, npix_x, nchan)
                Stack of images
    """
    return images[1, 0, 0] - images[0, 0, 0]

As we saw in the first notebook, we can compile these functions to run even faster!

In [None]:
@numba.jit(parallel=True)
def fd_x(images):
    return _finite_difference_x(images)

@numba.jit(parallel=True)
def fd_y(images):
    return _finite_difference_y(images)

Now we can apply each function to each array chunk with [`map_overlap`](https://docs.dask.org/en/stable/generated/dask.array.map_overlap.html). Dask will run the computation in multi-threaded mode. 
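
To see what `depth` buys us, here is a small sketch on a toy array. The array, the chunk sizes and the helper `forward_diff_x` are assumptions made for illustration and are not part of this notebook. Each chunk borrows one column from its neighbours, so the difference across chunk borders matches the result computed on the full array.

```python
import dask.array as da
import numpy as np


def forward_diff_x(block):
    # Hypothetical helper: forward difference along axis 1.
    # The last column has no right neighbour and is left at 0.
    out = np.zeros_like(block)
    out[:, :-1] = block[:, 1:] - block[:, :-1]
    return out


# Toy 6x6 "image" split into four 3x3 chunks
image = (np.arange(36, dtype=np.float64) ** 2).reshape(6, 6)
chunked = da.from_array(image, chunks=(3, 3))

# depth=(0, 1): no overlap along axis 0, one column of overlap along axis 1
result = chunked.map_overlap(
    forward_diff_x, depth=(0, 1), boundary="reflect", dtype=image.dtype
).compute()

# Away from the right edge of the image, the chunked result equals np.diff
print(np.allclose(result[:, :-1], np.diff(image, axis=1)))
```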

In [None]:
dfdx_fd = hubble_image.map_overlap(fd_x, depth=(0, 1, 0))
dfdy_fd = hubble_image.map_overlap(fd_y, depth=(1, 0, 0))

In [None]:
show_images(images=[hubble_image.blocks[0, 20], dfdx_fd.blocks[0, 20].mean(-1), dfdy_fd.blocks[0,20].mean(-1)], 
            titles=["Original", "fd_x", "fd_y"],
)

The problem with the `finite_difference` filter is that it responds very strongly to the noise in the image!

As a solution, we can use the **Gaussian derivative** (see another [great post](https://www.crisluengo.net/archives/22/) from Cris Luengo on its benefits with respect to the finite difference method). 

The Gaussian derivative is defined as:
$$\begin{aligned}
 \frac{\partial}{\partial x} [I(x, y) \ast g(x, y)] &=  \frac{\partial}{\partial x} \ast I(x, y) \ast g(x, y) \\
                                                    &=  I(x, y) \ast \frac{\partial}{\partial x} \ast g(x, y) \\
                                                    &=  I(x, y) \ast \left[\frac{\partial}{\partial x} g(x, y)\right]
\end{aligned}$$

where:
- in the first step we used the associative property of the convolution, 
- in the second step we used the commutative property of the convolution, and
- in the third step we used the associative property of the convolution again. 

These properties show us that computing the gradient of an image blurred with a Gaussian is the same thing as convolving the image with the gradient of a Gaussian!
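
The same argument can be checked numerically. The sketch below is a 1D illustration under stated assumptions: the signal is random, the finite difference `[1, -1]` stands in for the derivative operator, and the Gaussian is sampled by hand rather than taken from SciPy.

```python
import numpy as np

rng = np.random.default_rng(0)
signal = rng.normal(size=200)  # hypothetical noisy 1D signal

# Hand-sampled Gaussian (sigma=1.5, radius=10), normalized to sum to 1
x = np.arange(-10, 11)
gaussian = np.exp(-0.5 * (x / 1.5) ** 2)
gaussian /= gaussian.sum()

# Discrete stand-in for the derivative operator
derivative = np.array([1.0, -1.0])

# Route 1: blur the signal, then differentiate it
blur_then_diff = np.convolve(np.convolve(signal, gaussian), derivative)

# Route 2: differentiate the kernel once, then convolve the signal with it
gaussian_derivative_kernel = np.convolve(gaussian, derivative)
single_pass = np.convolve(signal, gaussian_derivative_kernel)

print(np.allclose(blur_then_diff, single_pass))  # True
```

Route 2 is the cheaper one in practice, since the kernel is differentiated once and the image is convolved a single time.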

Let's again use *Numba stencils* to perform the `gaussian_derivative`. This time, we'll define the convolutional kernel using a SciPy function.

In [None]:
import matplotlib.pyplot as plt
from scipy.ndimage._filters import _gaussian_kernel1d

gaussian_derivative = _gaussian_kernel1d(1.5, 1, 10)
plt.plot(gaussian_derivative, label="Gaussian Derivative kernel")
plt.xlabel(r"$x$")
plt.ylabel(r"$g'(x)$")
plt.legend();

Some helper functions that create stencils from a 1D kernel array are pre-defined in `utils`. Check them out if you have time. 

In [None]:
from utils import apply_kernel_x, apply_kernel_y
@numba.jit(parallel=True)
def gd_x(images):
    return apply_kernel_x(images, gaussian_derivative)


@numba.jit(parallel=True)
def gd_y(images):
    return apply_kernel_y(images, gaussian_derivative)

In [None]:
dfdx_gd = hubble_image.map_overlap(gd_x, depth=(0, 10, 0))
dfdy_gd = hubble_image.map_overlap(gd_y, depth=(10, 0, 0))

#### Can you see any difference between finite differences and Gaussian derivative?

In [None]:
show_images(images=[hubble_image.blocks[0, 20], dfdx_fd.blocks[0, 20].mean(-1), dfdy_fd.blocks[0, 20].mean(-1)], 
            titles=["Original", "Finite Diff. (x-axis)", "Finite Diff. (y-axis)"],
           )
show_images(images=[hubble_image.blocks[0, 20], dfdx_gd.blocks[0, 20].mean(-1), dfdy_gd.blocks[0, 20].mean(-1)], 
            titles=["Original", "Gauss. Der. (x-axis)", "Gauss. Der. (y-axis)"],
           )

## 2) The Gradient outer product

$$\mathbf{S} = \overline{(\nabla \mathbf{I})(\nabla \mathbf{I})^{\top}} = \overline{\begin{pmatrix} \mathbf{I}_{x} \\
\mathbf{I}_{y}\end{pmatrix}  \begin{pmatrix} \mathbf{I}_{x} &
\mathbf{I}_{y}\end{pmatrix}} = \begin{pmatrix}
\overline{\mathbf{I}_{x}\mathbf{I}_{x}} & \overline{\mathbf{I}_{x}\mathbf{I}_{y}} \\
\overline{\mathbf{I}_{x}\mathbf{I}_{y}} & \overline{\mathbf{I}_{y}\mathbf{I}_{y}} 
\end{pmatrix}$$


In [None]:
methods = {0: "gaussian_derivative", 1: "finite_differences"}

method = methods[0]

if method == "finite_differences":
    IxIx = dfdx_fd * dfdx_fd
    IxIy = dfdx_fd * dfdy_fd
    IyIy = dfdy_fd * dfdy_fd
elif method == "gaussian_derivative":
    IxIx = dfdx_gd * dfdx_gd
    IxIy = dfdx_gd * dfdy_gd
    IyIy = dfdy_gd * dfdy_gd

## 3) Local averaging with a Gaussian window

We then take these gradient vectors and compute a second-order tensor that describes the covariance of the gradient vectors in a small region around each point. This is in practice performed by a weighted average using a Gaussian window like the following one:

In [None]:
gaussian_kernel1d = _gaussian_kernel1d(2, 0, 10)
plt.plot(gaussian_kernel1d, label="Gaussian kernel")
plt.legend()

In [None]:
@numba.jit(parallel=True)
def smooth(images):
    return apply_kernel_y(
        apply_kernel_x(images, gaussian_kernel1d), 
        gaussian_kernel1d)

In [None]:
IxIx_bar = IxIx.map_overlap(smooth, depth=(10, 10, 0))
IxIy_bar = IxIy.map_overlap(smooth, depth=(10, 10, 0))
IyIy_bar = IyIy.map_overlap(smooth, depth=(10, 10, 0))

This tensor tells us how the gradient vectors are aligned and how strong they are in different directions, which gives us information about the local texture and patterns in the image. We can use this information to identify features like edges, corners, and lines, and to track the movement of objects in a sequence of images.
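
As a small illustration of this idea, the sketch below builds the three smoothed tensor components for a synthetic image of vertical stripes, using plain NumPy instead of Dask and Numba. The image, the helper `smooth2d` and the hand-sampled window are assumptions made for this example. Since the intensity only changes along $x$, only the $\overline{\mathbf{I}_{x}\mathbf{I}_{x}}$ component should be non-zero.

```python
import numpy as np


def smooth2d(img, kernel):
    """Separable smoothing: convolve along x (axis 1), then along y (axis 0)."""
    tmp = np.apply_along_axis(np.convolve, 1, img, kernel, mode="same")
    return np.apply_along_axis(np.convolve, 0, tmp, kernel, mode="same")


# Hand-sampled Gaussian window (sigma=2, radius=10), normalized to sum to 1
t = np.arange(-10, 11)
window = np.exp(-0.5 * (t / 2.0) ** 2)
window /= window.sum()

# Synthetic image: vertical stripes, intensity varies only along x
x = np.arange(64)
stripes = np.tile(np.sin(2 * np.pi * x / 8), (64, 1))

# np.gradient returns the derivatives in axis order: (d/dy, d/dx)
Iy, Ix = np.gradient(stripes)

# Gradient outer product, locally averaged
Sxx = smooth2d(Ix * Ix, window)
Sxy = smooth2d(Ix * Iy, window)
Syy = smooth2d(Iy * Iy, window)

# Only Sxx is non-zero at the image centre
print(Sxx[32, 32], Sxy[32, 32], Syy[32, 32])
```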

In [None]:
IxIx_bar

## 4) The structure tensor

In [None]:
show_images(
    images=[hubble_image.blocks[0,20], 
            IxIx_bar.blocks[0, 20].mean(-1), 
            IxIy_bar.blocks[0, 20].mean(-1), 
            IyIy_bar.blocks[0, 20].mean(-1)],
    titles=["Original", "IxIx_bar", "IxIy_bar", "IyIy_bar"],
    cmap=[None, "pink_r", "pink_r", "pink_r"],
)


In [None]:
structure_tensor = da.stack(
    [da.stack([IxIx_bar, IxIy_bar], axis=3), da.stack([IxIy_bar, IyIy_bar], axis=3)],
    axis=4,
)
structure_tensor = structure_tensor.rechunk(hubble_image.chunksize + (2, 2))
structure_tensor

Rechunking across axes can be expensive and incur a lot of communication, but Dask array has fairly efficient algorithms to accomplish this.

In our case, rechunking is necessary for the next steps of the pipeline:

## 5) Eigendecomposition of the structure tensor
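
The signature `"(m,m)->(m),(m,m)"` used in this step says that the last two axes hold one square matrix per pixel, and that each matrix yields a vector of eigenvalues and a matrix of eigenvectors. A minimal NumPy sketch on a made-up field of 2x2 symmetric tensors shows the shapes involved (`np.linalg.eigh` already broadcasts over the leading axes):

```python
import numpy as np

# Hypothetical field of 2x2 symmetric tensors, shape (ny, nx, 2, 2)
tensors = np.zeros((4, 5, 2, 2))
tensors[..., 0, 0] = 3.0
tensors[..., 1, 1] = 1.0
tensors[..., 0, 1] = tensors[..., 1, 0] = 0.5

# One eigendecomposition per pixel, applied over the last two axes
w, v = np.linalg.eigh(tensors)
print(w.shape, v.shape)  # (4, 5, 2) (4, 5, 2, 2)

# eigh returns the eigenvalues in ascending order
print(bool(np.all(w[..., 0] <= w[..., 1])))  # True

# Sanity check: V diag(w) V^T reconstructs the original tensors
reconstructed = (v * w[..., None, :]) @ np.swapaxes(v, -1, -2)
print(np.allclose(reconstructed, tensors))  # True
```

Because the two core axes must be available together for each pixel, they need to live in a single chunk, which is what the rechunking step takes care of.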

In [None]:
w, v = da.apply_gufunc(np.linalg.eigh, "(m,m)->(m),(m,m)", structure_tensor)
eigvals, eigvecs = w.persist(), v.persist()

In [None]:
# Bonus exercise to deepen your Dask+Numba skills:
# 1) Create a generalized universal function with @numba.guvectorize.
#    It should perform the eigendecomposition in parallel along some axis (see numba.prange).
# 2) Apply it to the structure tensor via dask.array.apply_gufunc.

# Is it faster than the da.apply_gufunc(np.linalg.eigh) call?

## 6) Results and interpretation

##### 6.1. Local gradient strength and Local gradient variation

In [None]:
show_images(
    images=[hubble_image.blocks[0, 20], eigvals.blocks[0, 20][..., 0].mean(-1), eigvals.blocks[0, 20][..., 1].mean(-1)],
    titles=["Original", "Gradient strength", "Gradient Variation"],
    cmap=[None, "pink_r", "pink_r"],
)

##### 6.2. Energy and anisotropy
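
In this step, the energy is the sum of the two eigenvalues, and the anisotropy is their difference normalized by that sum. A tiny sketch with hand-picked eigenvalue pairs (made up for illustration, ordered as smaller then larger) shows how the two quantities behave:

```python
import numpy as np

# Hypothetical eigenvalue pairs: (smaller eigenvalue, larger eigenvalue)
eig = np.array(
    [
        [1.0, 1.0],  # isotropic: gradients equally strong in all directions
        [0.0, 2.0],  # one dominant orientation, like a straight edge
        [0.5, 1.5],  # something in between
    ]
)

energy_demo = eig[:, 0] + eig[:, 1]
anisotropy_demo = (eig[:, 1] - eig[:, 0]) / energy_demo

print(energy_demo)      # [2. 2. 2.]
print(anisotropy_demo)  # [0.  1.  0.5]
```

All three cases have the same energy but very different anisotropy. In perfectly flat regions both eigenvalues are zero, so the ratio is undefined and evaluates to NaN.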

In [None]:
energy = eigvals[..., 0] + eigvals[..., 1]
anisotropy1 = (eigvals[..., 1] - eigvals[..., 0]) / energy
anisotropy2 = 1 - (eigvals[..., 0] / eigvals[..., 1])

show_images(
    images=[energy.blocks[0, 20].mean(-1), anisotropy1.blocks[0, 20].mean(-1)],
    titles=["energy", "anisotropy1"],
    cmap=["pink_r", "pink_r"],
)

##### 6.3 Feature extraction
We can now select the features that best capture young galaxies. For example, low anisotropy and high energy in the red channel would be good attributes to track.

In [None]:
# low anisotropy
feature1 = (1 - anisotropy1[...].mean(-1)) 

# high red, low blue and green
feature2 = energy[..., 0] ** 2 / (energy[..., 0] + energy[..., 1] + energy[..., 2]) 

features = feature1 * feature2

show_images(
    images=[hubble_image, features],
    zoom=[True, True],
    titles=["Original", "low anisotropy + high energy"],
    cmap=[None, "pink_r"],
)

We can now filter the image and select only the chunks that contain a high expression of those features.

In [None]:
import itertools

top1 = 0.008
blocks_of_interest = []
for bl in itertools.product(np.arange(24), np.arange(27)):
    if da.nansum(features.blocks[bl] > top1):
        blocks_of_interest.append(bl)

In [None]:
len(blocks_of_interest)

In [None]:
show_images(
    images=[
        hubble_image.blocks[blocks_of_interest[0]],
        hubble_image.blocks[blocks_of_interest[1]],
        hubble_image.blocks[blocks_of_interest[2]],
        hubble_image.blocks[blocks_of_interest[3]],
    ]
)

show_images(
    images=[
        features.blocks[blocks_of_interest[0]],
        features.blocks[blocks_of_interest[1]],
        features.blocks[blocks_of_interest[2]],
        features.blocks[blocks_of_interest[3]],
    ]
)

In [None]:
show_images(
    images=[
        hubble_image.blocks[blocks_of_interest[4]],
        hubble_image.blocks[blocks_of_interest[5]],
        hubble_image.blocks[blocks_of_interest[6]],
        hubble_image.blocks[blocks_of_interest[7]],
    ]
)

show_images(
    images=[
        features.blocks[blocks_of_interest[4]],
        features.blocks[blocks_of_interest[5]],
        features.blocks[blocks_of_interest[6]],
        features.blocks[blocks_of_interest[7]],
    ]
)

The whole picture can be seen at https://esahubble.org/images/heic0406a/

# BONUS
##### Create a single function that computes the structure tensor (ST) using Dask and Numba without storing any intermediate steps (such as partial derivatives).