In [None]:
from dask.distributed import Client, progress
import dask.array as da
import numpy as np
import scipy.signal  # import the submodule explicitly so scipy.signal.convolve2d is available
client = Client(processes=False)
client

Overlapping Operations
======================

Some operations depend on neighbouring values, for example derivatives, sliding sums, and image filters.

For these, Dask provides `overlap` and `map_overlap`.  
They add a border to each chunk, taken from the neighbouring chunks, before a function is mapped over the chunks.  

![](https://docs.dask.org/en/stable/_images/overlapping-neighbors.svg)

This is done in every dimension, including the diagonals:

![](https://docs.dask.org/en/stable/_images/overlapping-blocks.svg)

In [None]:
x = da.from_array(np.arange(100).reshape((10, 10)), chunks=(5, 2))
x

In [None]:
x.compute()

In [None]:
extented = da.overlap.overlap(x, depth=(1, 1), boundary=("reflect", -100)) 
extented

In [None]:
extented.compute()[:, :10]

The outer border of the whole array can be:
- "periodic"
- "reflect"
- Any constant
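
To see what each option does, here is a minimal sketch on a small 1D array. The array, the chunk size, and the constant `-1` are arbitrary choices for illustration, not values used elsewhere in this notebook.

```python
import dask.array as da
import numpy as np

# Six values split into two blocks: [0 1 2] and [3 4 5]
x = da.from_array(np.arange(6), chunks=3)

# Extend every block by one element on each side, once per boundary option
extended = {}
for boundary in ("periodic", "reflect", -1):
    extended[boundary] = da.overlap.overlap(x, depth=1, boundary=boundary).compute()
    print(boundary, extended[boundary])
```

The first and last printed values of each row come from the chosen boundary, while the values in the middle are copied from the neighbouring block.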

Once the array is extended, we can map a function over each block.  
This works well with SciPy array functions.


In [None]:
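def convolve2d(arr):
    # Discrete Laplacian kernel; named `kernel` to avoid shadowing the built-in `filter`
    kernel = np.array([[0, 1, 0], [1, -4, 1], [0, 1, 0]])
    # mode="same" keeps the output the same shape as the (extended) input block
    return scipy.signal.convolve2d(arr, kernel, mode="same")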

filtered = extented.map_blocks(convolve2d)
filtered

In [None]:
filtered.blocks[0, 1].compute()

Once the function has been applied to each block, we usually want to remove the borders.  
This can be done with `trim_overlap` or `trim_internal`.

In [None]:
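# Pass the same boundary as for the overlap so the outer border is trimmed as well
trimmed = da.overlap.trim_overlap(filtered, (1, 1), boundary=("reflect", -100))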
trimmed

In [None]:
trimmed.blocks[0, 1].compute()

In [None]:
trimmed.blocks[0, 0].compute()

In [None]:
trimmed.compute()
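
The three steps above (overlap, map, trim) can also be written as a single `map_overlap` call. The following is a minimal sketch, assuming a constant boundary of `0` so that the result can be compared with SciPy applied to the whole array at once; the names `laplacian`, `kernel`, and `data` are illustrative.

```python
import dask.array as da
import numpy as np
from scipy.signal import convolve2d

# Same Laplacian kernel as in the cells above
kernel = np.array([[0, 1, 0], [1, -4, 1], [0, 1, 0]])

def laplacian(block):
    # Applied to each extended block; the border is trimmed afterwards
    return convolve2d(block, kernel, mode="same")

data = np.arange(100).reshape((10, 10))
x = da.from_array(data, chunks=(5, 2))

# overlap + map_blocks + trim in one call; dtype is given so Dask does not have to infer it
result = x.map_overlap(laplacian, depth=1, boundary=0, dtype=data.dtype).compute()

# Reference: the same convolution on the full NumPy array, zero-padded at the edges
reference = convolve2d(data, kernel, mode="same", boundary="fill", fillvalue=0)
matches = np.array_equal(result, reference)
print(matches)
```

Because the border is trimmed inside `map_overlap`, the result has the same shape as the input array.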

## Exercise: gradient

We define a function on a 2D plane, then compute its derivatives numerically, for example with the central difference:

$\frac{\partial f(x, y)}{\partial x} \approx \frac{f(x + dx, y) - f(x - dx, y)}{2 dx}$
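
As a warm-up, the formula can be sketched in plain NumPy on a 1D function. This is only an illustration under assumed names (`central_difference` is a hypothetical helper, and `sin` is an arbitrary test function); it is not the solution to the exercise.

```python
import numpy as np

def central_difference(values, step):
    """Approximate the derivative at interior points: (f(x + h) - f(x - h)) / (2 h)."""
    return (values[2:] - values[:-2]) / (2 * step)

ts = np.linspace(0, 2 * np.pi, 1001)
step = ts[1] - ts[0]

approx = central_difference(np.sin(ts), step)
exact = np.cos(ts[1:-1])  # the derivative of sin, at the interior points only

max_error = np.max(np.abs(approx - exact))
print(max_error)
```

Each output value needs one neighbour on either side, which is why the result is two elements shorter than the input and why a chunked version needs an overlap of depth 1.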


In [None]:
import matplotlib.pyplot as plt
ts = np.linspace(-2, 2, 4001)
xs, ys = np.meshgrid(ts, ts)

func = (xs**2 - ys) **2 * np.exp(-(xs**2 + ys**2))

# Forward differences along each axis; the grid spacing of ts is 4 / 4000 = 0.001
dx = np.diff(func, axis=0) / 0.001
dy = np.diff(func, axis=1) / 0.001

plt.imshow(func)
plt.show()
plt.imshow(dx)
plt.show()
plt.imshow(dy)

### Solution
<!---
ts = da.linspace(-2, 2, 4001, chunks=(1000,))
xs, ys = da.meshgrid(ts, ts)

func = (xs**2 - ys) **2 * np.exp(-(xs**2 + ys**2))

dx = da.diff(func, axis=0) / 0.001
dy = da.diff(func, axis=1) / 0.001

dx.visualize()
plt.imshow(dx.compute())
--->