# Loading images with Dask

Welcome to this notebook, where we will show how to process images bigger than memory on the GPU.
It uses Dask, Xarray, CuPy, Napari… and many others. The documentation of those projects is pretty good, so this notebook is mostly a copy-paste of that documentation. And of Stack Overflow, of course.
Let’s start with some imports:

In [None]:
from math import isqrt
from typing import Tuple

import cupy as cp
import dask
import dask.array as da
import numpy as np
import rioxarray as rx
import xarray as xr
from dask.distributed import Client
from dask.utils import parse_bytes
from dask_cuda import LocalCUDACluster
from skimage.util import view_as_blocks

Modify your settings here:

In [None]:
BASE_PATH = "YOUR PATH HERE"  # parent folder of data/ (see get_image_mosaic below)

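Create the cluster. It uses Dask-CUDA, from RAPIDS, which starts one worker per GPU. The cluster must only be initialized once. Note the memory limit: `device_memory_limit` is the amount of GPU memory a worker may use before it starts spilling data to host memory.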

In [None]:
cluster = LocalCUDACluster(device_memory_limit="1GB")

Adjust the configuration and create the client. Note the chunk size: it has been tuned for my computer, but you can try other values. Also note that CuPy is set as the default array backend, although rioxarray (we will see it later) seems to ignore it, I think. That is why the mosaic is explicitly converted to CuPy after loading.

You can see the client's dashboard. It is quite relaxing.

In [None]:
dask.config.set({
    "array.backend": "cupy",
    "array.chunk-size": "64MiB"
    })
client = Client(cluster)
client
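The chunk size set above is a plain configuration value: `rechunk_for_model` (below) reads it back with `dask.config.get` and converts it to bytes with `parse_bytes`. A minimal CPU-only sketch of that round trip. Using `dask.config.set` as a context manager, so the setting only applies inside the `with` block, is my suggestion and not something the notebook does:

```python
import dask
from dask.utils import parse_bytes

# Temporarily override the chunk size; the previous value is restored on exit
with dask.config.set({"array.chunk-size": "64MiB"}):
    chunk_size_str = dask.config.get("array.chunk-size")  # the raw string
    chunk_size = parse_bytes(chunk_size_str)              # converted to bytes

print(chunk_size_str, chunk_size)  # 64MiB 67108864
```

Setting it globally, as the cell above does, is simpler in a notebook; the context manager is handy when experimenting with several chunk sizes in the same session.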

I do not have a real model, so let's create a fake one. CuPy arrays can be shared with TensorFlow through DLPack, so adapting this for a Keras model should not be that difficult.

In [None]:
class Model:

    @property
    def input_shape(self) -> Tuple[int, int, int]:
        return 299, 299, 3

    @property
    def output_shape(self) -> Tuple[int, int, int]:
        return 1, 1, 3

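    def predict(self, images_to_predict: np.ndarray) -> np.ndarray:
        # Fake prediction: average every patch per band, which yields a
        # valid (heavily downsampled) visible image.
        # Input: (ny, nx, 299, 299, 3) grid of patches -> output: (ny, nx, 3)
        # TODO: a real Keras model expects a batch (N, 299, 299, 3), so the
        # first two (patch grid) dimensions may need to be flattened first.
        return np.mean(images_to_predict, axis=(2, 3))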

model = Model()
model
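The TODO in `predict` refers to the fact that Keras models usually take a single batch axis, `(N, rows, cols, channels)`, while the patches arrive as a `(ny, nx, rows, cols, channels)` grid. A NumPy-only sketch of flattening the grid into a batch and restoring it afterwards. `batch_predict` and `fake_keras_predict` are hypothetical names, the latter standing in for a real `model.predict` call, and the tiny 4 × 4 patches only keep the example fast:

```python
import numpy as np

def fake_keras_predict(batch: np.ndarray) -> np.ndarray:
    # Hypothetical stand-in for a Keras model: (N, H, W, C) -> (N, C)
    return batch.mean(axis=(1, 2))

def batch_predict(patches: np.ndarray) -> np.ndarray:
    # patches: (ny, nx, H, W, C) grid of patches
    ny, nx = patches.shape[:2]
    batch = patches.reshape(ny * nx, *patches.shape[2:])  # (N, H, W, C)
    predictions = fake_keras_predict(batch)               # (N, C)
    # Restore the grid so each prediction lands at its patch position
    return predictions.reshape(ny, nx, -1)

patches = np.random.default_rng(0).random((2, 3, 4, 4, 3), dtype=np.float32)
result = batch_predict(patches)
print(result.shape)  # (2, 3, 3)
```

With this fake model the result is identical to the notebook's `np.mean(..., axis=(2, 3))`; the point is only that the reshape round trip keeps every prediction in its place.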

The first step when working with images is loading them. Yeah, obviously. This function lazily opens the image with rioxarray and tiles the copies into a big 8 × 8 mosaic. Note that an xarray DataArray is returned. xarray is cool.

Caveat: at the moment, the same image is loaded 64 times, and therefore Dask optimizes it (as far as I can tell, the copies get the same task keys, so each chunk is read only once). Real, distinct images or other image formats may behave differently.

Also note the rioxarray warning (a `NotGeoreferencedWarning` from rasterio). That library is meant for geospatial images, and it complains when the file has no geospatial metadata. BUT rioxarray creates xarray DataArray objects backed by Dask arrays, and it honors the image chunks, which is a good point. I think dask-image does not honor chunks.

My test image is a Hubble image downloaded from [here](https://esahubble.org/images/heic1502a/).

In [None]:
def get_image_mosaic() -> xr.DataArray:
    # https://esahubble.org/images/heic1502a/
    base_image = BASE_PATH + '/data/heic1502a.tif'
    rows = []
    for row in range(8):
        current_row = []
        for col in range(8):
            # Lazily open the image as a Dask-backed DataArray with dims (band, y, x)
            img = rx.open_rasterio(base_image,
                                   parse_coordinates=False,
                                   chunks='auto')
            current_row.append(img)
        row_array = xr.concat(current_row, dim='x')
        rows.append(row_array)
    image = xr.concat(rows, dim='y')

    # Tidy the image
    image = image.assign_coords({'band': ['red', 'green', 'blue']})
    return image


image_mosaic = get_image_mosaic()
# rioxarray returns NumPy-backed chunks: move every chunk to the GPU
image_mosaic = image_mosaic.copy(deep=False, data=image_mosaic.data.map_blocks(cp.asarray))
image_mosaic
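The mosaic is just the same tile concatenated along `x` and then along `y`. The same idea with plain `dask.array`, using a small in-memory NumPy array as a hypothetical stand-in for the TIF so that it runs without rioxarray or a GPU. Because every copy refers to the same Dask array, the graph contains each source chunk only once, however many times it is tiled:

```python
import numpy as np
import dask.array as da

# Hypothetical stand-in for one image: (band, y, x), split into 2 chunks along x
tile_np = np.arange(3 * 4 * 6, dtype=np.uint8).reshape(3, 4, 6)
tile = da.from_array(tile_np, chunks=(3, 4, 3))

n_rows, n_cols = 2, 3
row = da.concatenate([tile] * n_cols, axis=2)    # tile along x
mosaic = da.concatenate([row] * n_rows, axis=1)  # stack rows along y

print(mosaic.shape)      # (3, 8, 18)
print(mosaic.numblocks)  # (1, 2, 6)

# Source chunk keys in the graph: one per chunk of `tile`, not one per copy
graph = dict(mosaic.__dask_graph__())
source_keys = {k for k in graph if isinstance(k, tuple) and k[0] == tile.name}
print(len(source_keys))  # 2
```

Separate `open_rasterio` calls produce separate Dask arrays, so whether their keys coincide depends on how rioxarray names them; with real, distinct images they will not.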

So far, nothing weird: we just loaded images. Now the image must be adapted to the model.

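Dask works by splitting the image into chunks. Too big, and we run out of memory; too small, and the scheduling overhead makes everything slow. We first pad the image so that its height and width are multiples of the model input size, and then pick a chunk shape whose sides are multiples of that size and whose size in bytes is close to the configured `array.chunk-size`. That way, every chunk holds a whole number of model patches.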
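The padding amount in `rechunk_for_model` is the distance from the image size to the next multiple of the model size. In Python this can be written as `-size % model_size`, because `%` takes the sign of the divisor. A small pure-Python check of that equivalence; `pad_amount` is a hypothetical helper, the notebook uses an explicit `if` instead:

```python
def pad_amount(size: int, model_size: int) -> int:
    # Same as: model_size - size % model_size if size % model_size else 0
    return -size % model_size

print(pad_amount(1000, 299))  # 196, and 1000 + 196 = 1196 = 4 * 299
print(pad_amount(1196, 299))  # 0, already a multiple
```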

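We also normalize the image to the range \[0, 1\]. Note that `normalize_chunk` copies every chunk from the GPU back to host memory, so from this point on the chunks are NumPy arrays and the rest of the pipeline runs on the CPU. I think this is needed because `view_as_blocks` (used below) works on NumPy arrays only.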

Note the `da.map_blocks` function: it applies `normalize_chunk` to every chunk independently.
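`map_blocks` calls a function on every chunk and stitches the results back together with the same chunk structure. A CPU-only sketch with a NumPy-backed Dask array, an assumption made so it runs without a GPU. The `astype` call is my choice here: it also exists on CuPy arrays, so the same hypothetical `normalize_chunk_on_device` should keep CuPy chunks on the GPU instead of copying them to the host (which would then require a patch extraction that accepts CuPy arrays):

```python
import numpy as np
import dask.array as da

def normalize_chunk_on_device(chunk):
    # astype exists on NumPy and CuPy arrays, so no host copy is made
    return chunk.astype(np.float32) / 255.0

image = da.from_array(np.array([[0, 51], [204, 255]], dtype=np.uint8), chunks=1)
normalized = da.map_blocks(normalize_chunk_on_device, image, dtype=np.float32)
result = normalized.compute()
print(result.dtype)  # float32
```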

In [None]:
def rechunk_for_model(image: xr.DataArray, model_shape: Tuple[int, int, int]) -> xr.DataArray:
    model_y, model_x, model_band = model_shape

    # Rows to pad so that the height is a multiple of the model height
    remainder_y = image.sizes['y'] % model_y
    pad_y = (0, model_y - remainder_y) if remainder_y != 0 else (0, 0)

    # Cols to pad so that the width is a multiple of the model width
    remainder_x = image.sizes['x'] % model_x
    pad_x = (0, model_x - remainder_x) if remainder_x != 0 else (0, 0)

    # Pad the image with zeros at the bottom and on the right
    image = image.pad({'y': pad_y, 'x': pad_x}, constant_values=0)

    # Rechunk in multiples of the model size. The whole band axis goes into
    # every chunk, so one patch takes model_y * model_x * model_band items.
    # Note: itemsize is that of the raw data; after the float32 normalization
    # below, the chunks grow accordingly (4x for 8-bit input).
    chunk_size = parse_bytes(dask.config.get('array.chunk-size'))
    patch_size = model_y * model_x * model_band * image.data.itemsize
    # Patches per chunk (at least one), split as evenly as possible in y and x
    num_patches = max(1, chunk_size // patch_size)
    y_patches_in_chunk = isqrt(num_patches)
    x_patches_in_chunk = num_patches // y_patches_in_chunk

    # Put the band axis last, matching the model's (rows, cols, channels) input
    image = image.transpose("y", "x", "band")
    image = image.chunk(chunks={
        'y': y_patches_in_chunk * model_y,
        'x': x_patches_in_chunk * model_x,
        'band': -1})
    return image

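def normalize_chunk(chunk) -> np.ndarray:
    # cp.asnumpy copies a CuPy chunk to host memory (NumPy chunks pass through),
    # so everything downstream works on NumPy arrays, as view_as_blocks needs.
    # Dividing by 255 assumes 8-bit data; for [-1, 1] use: chunk / 127.5 - 1.0
    chunk = cp.asnumpy(chunk).astype(np.float32)
    return chunk / 255.0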

def normalize_for_model(image: xr.DataArray) -> xr.DataArray:
    image_data = image.data
    image_data = da.map_blocks(normalize_chunk, image_data, dtype=np.float32)
    return image.copy(deep=False,
                      data=image_data)

image_model = rechunk_for_model(image_mosaic, model.input_shape)
image_model = normalize_for_model(image_model)
image_model
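To make the chunk arithmetic concrete, here is the same computation as in `rechunk_for_model` on plain integers, for the 299 × 299 × 3 model input and the 64 MiB chunk size configured above. It also shows why the item size matters: the chunk shape is computed on the raw data (8-bit, judging by the division by 255), but the chunks are converted to float32 afterwards, so in memory they end up roughly four times larger than `array.chunk-size`. `chunk_shape_for_model` is a hypothetical helper, and the `max(1, ...)` guard for chunk sizes smaller than one patch is my addition:

```python
from math import isqrt

def chunk_shape_for_model(chunk_bytes, model_shape, itemsize):
    # How many model patches fit in one chunk, split as evenly as
    # possible between rows (y) and cols (x)
    model_y, model_x, model_band = model_shape
    patch_bytes = model_y * model_x * model_band * itemsize
    num_patches = max(1, chunk_bytes // patch_bytes)
    y_patches = isqrt(num_patches)
    x_patches = num_patches // y_patches
    return y_patches * model_y, x_patches * model_x, model_band

chunk_bytes = 64 * 2**20  # "64MiB"
uint8_shape = chunk_shape_for_model(chunk_bytes, (299, 299, 3), itemsize=1)
print(uint8_shape)  # (4485, 4784, 3): 250 patches -> 15 x 16

# The same chunk once normalized to float32 (4 bytes per item)
float32_bytes = uint8_shape[0] * uint8_shape[1] * uint8_shape[2] * 4
print(float32_bytes / 2**20)  # about 245.5 MiB

# Computing the shape with the final item size keeps chunks near 64 MiB
print(chunk_shape_for_model(chunk_bytes, (299, 299, 3), itemsize=4))  # (2093, 2392, 3)
```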


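Now the model is applied to every chunk of the image. Note `da.blockwise`: it is like `map_blocks`, but much more flexible (and more difficult). `adjust_chunks` tells Dask how each chunk shrinks: every 299 × 299 patch becomes a single output pixel.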

In [None]:
def patch_and_predict(chunk: np.ndarray, model: Model) -> np.ndarray:
    # Split the chunk into non-overlapping patches of the model input shape.
    # Result shape: (ny, nx, 1, 299, 299, 3)
    patches = view_as_blocks(chunk, model.input_shape)
    # Drop the singleton axis left by taking the band dimension as a whole
    patches = patches[:, :, 0]
    # Predict: (ny, nx, 299, 299, 3) -> (ny, nx, 3)
    return model.predict(patches)

def predict_image(image: xr.DataArray, model: Model) -> xr.DataArray:
    image_data = image.data
    predicted_data = da.blockwise(
        patch_and_predict, 'yxc',
        image_data, 'yxc',
        model=model,  # extra keyword arguments are forwarded to patch_and_predict
        meta=image_data._meta,
        token='patch-and-predict',  # key prefix; `name` would fix the whole key
        # Every input_shape patch becomes an output_shape block
        adjust_chunks={'y': lambda y: y * model.output_shape[0] // model.input_shape[0],
                       'x': lambda x: x * model.output_shape[1] // model.input_shape[1],
                       'c': lambda _: model.output_shape[2]})
    return xr.DataArray(data=predicted_data,
                        dims=('y', 'x', 'band'),
                        coords={'band': ['red', 'green', 'blue']})

predicted_image = predict_image(image_model, model)
predicted_image
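`view_as_blocks` followed by a mean over the patch axes is a block average. The same patching can be written with `reshape` and `transpose` only, which as far as I know should also work on CuPy arrays, unlike `view_as_blocks`, which relies on NumPy strides. A CPU-only sketch with the same `blockwise` and `adjust_chunks` pattern, using a tiny hypothetical 2 × 2 × 3 "model" so the numbers stay small; `to_patches` and `patch_and_mean` are hypothetical helpers:

```python
import numpy as np
import dask.array as da

PATCH = (2, 2, 3)  # hypothetical model input shape (rows, cols, channels)

def to_patches(chunk, patch_shape):
    # (Y, X, C) -> (ny, nx, py, px, C), like view_as_blocks(...)[:, :, 0]
    py, px, c = patch_shape
    ny, nx = chunk.shape[0] // py, chunk.shape[1] // px
    return chunk.reshape(ny, py, nx, px, c).transpose(0, 2, 1, 3, 4)

def patch_and_mean(chunk, patch_shape):
    # Fake prediction: per-band mean of every patch -> (ny, nx, C)
    return to_patches(chunk, patch_shape).mean(axis=(2, 3))

image_np = np.arange(8 * 6 * 3, dtype=np.float32).reshape(8, 6, 3)
image = da.from_array(image_np, chunks=(4, 4, 3))  # chunk sides are multiples of 2

predicted = da.blockwise(
    patch_and_mean, 'yxc', image, 'yxc',
    patch_shape=PATCH,  # forwarded to patch_and_mean
    dtype=np.float32,
    adjust_chunks={'y': lambda n: n // PATCH[0],
                   'x': lambda n: n // PATCH[1]})
result = predicted.compute()
print(predicted.chunks)  # ((2, 2), (2, 1), (3,))
```

Because the chunks are aligned with the patches, the chunk-wise result is the same as averaging the whole image patch by patch.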

Napari could be used to display the results, if it worked on my computer.

In [None]:
#%gui qt5
#import napari
# viewer = napari.view_image(predicted_image)

So, for now, let's just save the result. Saving triggers the whole computation.

By the way, you can check the progress in the client's dashboard.

In [None]:
netcdf_path = BASE_PATH + "/data/results.nc"
predicted_image.transpose("band", "y", "x").to_dataset(name="prediction").to_netcdf(netcdf_path, engine="netcdf4")