
[ROI] Memory and running time for ROI-based illumination correction #27

Closed
tcompa opened this issue Aug 2, 2022 · 25 comments
Labels
High Priority Current Priorities & Blocking Issues Tables AnnData and ROI/feature tables

Comments

@tcompa
Collaborator

tcompa commented Aug 2, 2022

We already know that ROI-parallel tasks have a computational overhead compared to chunk-parallel tasks. Here we report in more detail how things look for the illumination-correction task (which is also somewhat special -- see below).

TL;DR

  1. We managed to obtain a ROI-parallel version of the task which is actually faster than the corresponding per-chunk version.
  2. The ROI-parallel memory usage is much higher than the chunk-parallel one, but maybe we can live with this (or mitigate it somehow).
  3. It is crucial (speed-wise and memory-wise) to explicitly unravel the Z indices of the ROIs, since for this task they are trivial (each ROI lives on a single Z plane) -- see below. This is typically not possible/relevant for a general ROI-parallel task.

Details

We compare the current version of the task (31fe8533e00159d42691fd1207d3174eb5d941d5) and an old chunk-parallel one (c653138cb1a37fb8e1348fdf99fa82b48a4fc987). As a test, we use a single well with 9x8 FOVs and 19 Z planes.

After some tinkering, we find that the optimal (speed-wise and memory-wise) version of the per-ROI task includes these loops:

    # Loop over channels
    data_czyx_new = []
    for ind_ch, ch in enumerate(chl_list):
        # Set the correction matrix for this channel
        illum_img = corrections[ch]
        # Lazy container for the corrected 3D data of this channel
        data_zyx_new = da.empty(
            data_czyx[ind_ch].shape,
            chunks=data_czyx[ind_ch].chunks,
            dtype=data_czyx.dtype,
        )

        # Loop over FOVs
        for indices in list_indices:
            s_z, e_z, s_y, e_y, s_x, e_x = indices[:]
            # Corrected single-Z-plane images for this FOV
            tmp_zyx = []
            # For each FOV, loop over Z planes (s_z is 0, so range(e_z) covers them all)
            for ind_z in range(e_z):
                shape = [e_y - s_y, e_x - s_x]
                new_img = delayed_correct(
                    data_czyx[ind_ch, ind_z, s_y:e_y, s_x:e_x],
                    illum_img,
                    background=background,
                )
                tmp_zyx.append(da.from_delayed(new_img, shape, dtype=data_czyx.dtype))
            # Re-assemble the single-plane results into the 3D FOV region
            data_zyx_new[s_z:e_z, s_y:e_y, s_x:e_x] = da.stack(tmp_zyx, axis=0)
        data_czyx_new.append(data_zyx_new)
    accumulated_data = da.stack(data_czyx_new, axis=0)

Notice that we do not call the function split_3D_indices_into_z_layers (which would produce multiple ROIs for each FOV, distributed over all Z planes) any more, but rather iterate "by hand" over the Z planes (for ind_z in range(e_z):), and then combine the e_z=19 single-Z-plane FOVs into a single stack (data_zyx_new[s_z:e_z, s_y:e_y, s_x:e_x] = da.stack(tmp_zyx, axis=0)). This is clearly better (in terms of both memory and speed) than an older version, see https://github.com/fractal-analytics-platform/fractal/blob/d80e83253df4d2ec6ffbadbdf52b281b3de30217/fractal/tasks/illumination_correction.py#L228-L254

Results

Here is the memory trace of the two (old/new) tasks.

[figure: fig_memory -- memory traces of the old (chunk-parallel) and new (ROI-parallel) tasks]

Detailed comments:

  1. The new per-ROI version is even faster than the per-chunk one, somehow contradicting [ROI] Per-ROI parallelization introduces dask overhead (with respect to per-chunk parallelization) #26.
  2. The memory usage is much larger than with the old version (roughly 12 times larger), which is not good. On the one hand this is expected, since we are not relying on dask's "native" parallelization over chunks (meaning we probably load and process several chunks at the same time). On the other hand, other tasks like yokogawa_to_zarr have a similar structure (they work image-by-image within for loops) but show no memory issue (however, they are mostly based on the append/stack procedure, which is clearly better than using explicit indices).

Higher-level comments:

  1. In a previous (suboptimal) version of this task, we had introduced a cutoff on how many dask threads could run in parallel, with little success. We should try again with the new (optimized) version; see the sketch after this list. Hopefully we'll see that a small cutoff reduces speed but also memory usage. That would be a perfect compromise, since the user could always say "I have a very large dataset, but please never run more than 20 corrections at the same time".
  2. If strategy 1 is not effective, and if we see that the current memory usage would break workflows on larger datasets (how large?), we need to find other ideas for mitigation. Batching is always fine, but we should find a way to make its interface robust and general (e.g. we could batch over channels, if there are more than a certain number or if the user asks for it). To be explored, if needed.
  3. Moving away from dask's standard procedures (not only the map_blocks parallelization, but also the recommended append/stack/concatenate statements) has serious consequences. In the specific case of illumination correction, there is no advantage in this change, because we always work on images and we always work at level 0. The advantages should become evident in other tasks, where (1) we have more complex (or at least 3D) ROIs, or (2) we need to work at lower resolution.
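For reference, here is a minimal sketch of how such a cutoff could be imposed with dask's local threaded scheduler. This is only an illustration (not necessarily the mechanism used in the earlier task version); out_zarr is a placeholder output path, and accumulated_data is the lazy array assembled in the snippet above.

import dask

# Cap the number of worker threads of the local threaded scheduler;
# the value 4 is only an example of a user-provided cutoff.
with dask.config.set(scheduler="threads", num_workers=4):
    accumulated_data.to_zarr(out_zarr, dimension_separator="/")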
@tcompa tcompa self-assigned this Aug 2, 2022
tcompa referenced this issue in fractal-analytics-platform/fractal-client Aug 2, 2022
@tcompa
Collaborator Author

tcompa commented Aug 2, 2022

  1. In a previous (suboptimal) version of this task, we had introduced a cutoff on how many dask threads could run in parallel, with little success. We should try again with the new (optimized) version. Hopefully we'll see that a small cutoff reduces speed but also memory usage. That would be a perfect compromise, since the user could always say "I have a very large dataset, but please never run more than 20 corrections at the same time".

A rapid test suggests this scheme is useless.

The next trivial check is to play with rechunking. The first attempt (fractal-analytics-platform/fractal-client@2ee3e22) was to use rechunk("auto") and then only set the expected chunk sizes at the end. This was combined with a more general handling of channels, as in
https://github.com/fractal-analytics-platform/fractal/blob/2ee3e22f6faf8c57505f053316792f6aab271540/fractal/tasks/illumination_correction.py#L228-L248

The result is a small reduction of memory (the new task is the green line in the figure below), but this doesn't make a real difference.
[figure: fig_memory -- memory traces; the new task with rechunking is the green line]

Other options we quickly mentioned with @jluethi:

  1. Some more checks that explicit rechunking cannot solve this issue (we don't expect it to work, as it seems that almost the whole input array is loaded before computation/writing... but we can have a look); see the sketch after this list.
  2. Batching would be useful (in channels or in Z planes), but we still don't know how to write portions of an array with to_zarr.
  3. Is there some smart way of indexing a dask array, where we can include additional information on chunks? More generally: is there a way to make assignment by index somewhat consistent with chunks (even at the price of a very verbose definition)?
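As a sketch of what "explicit rechunking" in point 1 could look like (the zarr path is a placeholder; 2000x2000 is the single-image ROI size used in this task):

import dask.array as da

# Re-chunk so that dask chunks coincide with the 2000x2000 single-image ROIs
x = da.from_zarr("path/to/well.zarr/0")
x = x.rechunk((1, 1, 2000, 2000))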

tcompa referenced this issue in fractal-analytics-platform/fractal-client Aug 3, 2022
* Remove split_3D_indices_into_z_layers;
* Remove temporary_test;
* Update test.
@tcompa
Collaborator Author

tcompa commented Aug 4, 2022

Here is the result of a comprehensive set of local tests, which only mimic the illumination-correction task. It would be more useful to translate them into more realistic tests, but the cluster is not available at the moment.

TL;DR

For this artificial example:

  • All cases where chunks align to ROIs are memory safe;
  • Cases where chunks do not align to ROIs are not memory safe, but this is solved via inline_array=True.

This is encouraging, but we need to switch to a more realistic test.

Setup

  • We define ROIs to be 2000x2000 squares in YX (similar to an image).
  • We prepare four zarr arrays stored on disk. Two of them are "small" (shape=(1, 2, 4000, 4000), made of 8 ROIs) and two of them are "large" (shape=(2, 10, 16000, 18000), made of 1440 ROIs). Apart from being small/large, we either align chunks with ROIs (chunks=(1,1,2000,2000)), or not (chunks=(1,1,1000,4000)).
  • For each one of the arrays, we can run the processing with a combination of inline_array=True/False for from_zarr and compute=None/True for the final to_zarr (the statement which triggers dask execution). Note that compute=None in this case just means that we do not include this parameter.
  • The processing function looks like
import os
import shutil
import time
from itertools import product

import dask
import dask.array as da
import numpy as np

# NOTE: get_dir_size is a small helper (defined elsewhere) that returns the
# total size, in bytes, of all files within a folder.


def function(dummy):
    return dummy.copy() + 1


def doall(shape_label, chunk_label, inline_array, compute):
    in_zarr = f"../raw_{shape_label}_{chunk_label}.zarr"
    out_zarr = f"out_{shape_label}_{chunk_label}"
    out_zarr += f"_inline_{inline_array}_compute_{compute}.zarr"
    x = da.from_zarr(in_zarr, inline_array=inline_array)
    y = da.empty(x.shape, chunks=x.chunks, dtype=x.dtype)
    if os.path.isdir(out_zarr):
        shutil.rmtree(out_zarr)

    print(out_zarr)

    CS = 2000  # ROI edge (pixels)
    delayed_function = dask.delayed(function)

    print("shape:", x.shape)
    print("chunks:", x.chunks)
    print("ROI edge:", CS)

    t0 = time.perf_counter()
    num_ind = 0
    # Loop over all (y, x) ROI corners, channels and Z planes
    for ind in product(list(range(0, x.shape[-2] - 1, CS)),
                       list(range(0, x.shape[-1] - 1, CS))):
        for i_ch in range(x.shape[0]):
            sy, sx = ind[:]
            ex = sx + CS
            ey = sy + CS
            for i_z in range(x.shape[1]):
                lazy_res = delayed_function(x[i_ch, i_z, sy:ey, sx:ex])
                y[i_ch, i_z, sy:ey, sx:ex] = da.from_delayed(lazy_res,
                                                             shape=(CS, CS),
                                                             dtype=np.uint16)
                num_ind += 1
    t1 = time.perf_counter()
    print(f"End of loop over {num_ind} indices, elapsed {t1-t0:.3f} s")

    # Write to disk (triggers execution)
    if compute is None:
        y.to_zarr(out_zarr, dimension_separator="/")
    elif compute:
        y.to_zarr(out_zarr, dimension_separator="/", compute=True)
    size_MB = get_dir_size(out_zarr) / 1e6

    t1 = time.perf_counter()
    print(f"End, elapsed {t1-t0:.3f} s")
    print(f"Folder size: {size_MB:.1f} MB")
    print()

    assert size_MB > 1

    shutil.rmtree(out_zarr)

Tests

We perform two kinds of tests:

  1. For small arrays, we look at the 8 possible dask graphs (the 8 combinations of aligned/not-aligned chunks, inline_array=True/False, and compute=True/None), looking for ordering problems as described in https://docs.dask.org/en/stable/order.html (a sketch of how such a graph can be produced follows this list).
  2. For large arrays, we perform the same 8 computations again, and we profile the memory usage with memory_profiler (as in https://github.com/pythonprofilers/memory_profiler#api).
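For point 1, this is roughly how an ordering-colored graph can be produced, following https://docs.dask.org/en/stable/order.html (it requires graphviz; y is the lazily-assembled output array from the doall snippet above, and the filename is a placeholder):

# Color nodes by dask's static execution order, as in the dask "order" docs
y.visualize(
    filename="graph.png",
    color="order",
    cmap="autumn",
    node_attr={"penwidth": "4"},
)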

I'll describe the results of the two tests in two follow-up comments -- stay tuned.

@tcompa
Collaborator Author

tcompa commented Aug 4, 2022

Test 1: dask graphs and ordering

We generated and looked at the graphs for the 8 combinations corresponding to the small array.

The interesting ones are those where the ROIs and chunks are not aligned, and for simplicity we only show those without the compute argument. In this case, the default graph (with inline_array=False) looks "bad", meaning that the different branches are somewhat mixed (note that color gradients and numbers within round boxes represent execution order). This can be identified as an ordering problem, and possibly "the poor static ordering means we fail to complete tasks that would let us release pieces of data. We load more pieces into memory at once, leading to higher memory usage." (https://docs.dask.org/en/stable/order.html). Adding the inline_array=True flag solves the issue in this case, leading to four towers of tasks which are correctly ordered (each one has roughly the same color).

inline_array=False:

[figure: fig_small_align_N_inline_False_compute_None -- task graph, chunks/ROIs not aligned, inline_array=False]

inline_array=True:

[figure: fig_small_align_N_inline_True_compute_None -- task graph, chunks/ROIs not aligned, inline_array=True]

Take-home message: when chunks and ROIs are not aligned, inlining the from_zarr arrays can be useful.
On the other hand, this issue does not appear when ROIs are aligned with chunks (not shown here, but we also looked at those graphs). Using the inline_array=True flag in that case may or may not be beneficial -- we don't know yet.

Broader context: it is unclear whether this test can help us, because in principle illumination correction works in a situation where ROIs are aligned with chunks. Still, this is something to keep in mind.

@tcompa
Collaborator Author

tcompa commented Aug 4, 2022

Test 2: memory profiling on large arrays

For a zarr array which takes 11 GB on disk (independently of the chunking choice), we run the test above 8 times, and we profile it with lines like

from memory_profiler import memory_usage
mem = memory_usage((doall, (shape_label, chunk_label, inline_array, compute)), interval=0.1)

The output is shown below. We observe that 6 cases out of 8 can be considered memory-safe (11 GB are written to disk, with a memory peak of <2.5 GB). The only two "bad" cases are those where ROIs and chunks are not aligned and where inline_array=False. This matches Test 1 in the previous comment.

[figure: fig_mem_large -- memory profiles for the 8 large-array cases]

Once again: if we are doing everything correctly, the illumination-correction task has the ROIs aligned with chunks, compute=True (this is because of our custom to_zarr function, but maybe we can get rid of it if needed), and inline_array=False. This configuration should lead to a safe execution, according to the artificial example, but this is not what we observe. What the difference between the artificial and realistic tests is remains to be understood.

@jluethi
Collaborator

jluethi commented Aug 4, 2022

Thanks @tcompa ! Very insightful and good visualization on when dask ordering can matter. Interesting to see how much of a difference "better" ordering with inline_array=True makes in the overlapping ROI case, certainly something to keep in mind.

I'm also very curious to see profiling of the illumination correction function with these learnings, and how it differs from those tests. The dask graph in #26 already looks fairly structured (with the default inline_array=False, right?), so it's surprising that illumination correction indeed hits a memory issue.

@jluethi
Collaborator

jluethi commented Aug 4, 2022

Have we measured how large your large array is vs. how large the real data for illum_corr is if it's all loaded into a numpy array? I wonder whether the real illum_corr data (3 channels instead of 2, and more planes -- 19 instead of 10, right?) may just be larger, i.e. the 12 GB bump there is the 2.5 GB bump here for most cases?

@jluethi
Collaborator

jluethi commented Aug 4, 2022

And have you looked into restricting the opportunistic caching explicitly? See here: https://docs.dask.org/en/stable/caching.html

@tcompa
Collaborator Author

tcompa commented Aug 4, 2022

We have now tested two sets of artificial examples on the cluster:

  1. The original one described above, with a shape of (2, 10, 16000, 18000), has the same memory footprint on the cluster as it does on a local machine (the timing is terrible, about 10 times slower, but that's another issue). This is good to know. The peak memory usage is around 2.5 G for the "good" cases.
  2. For a larger example, with a shape roughly three times larger ((3, 19, 16000, 18000)), we expected a linear scaling of the peak memory usage, i.e. something around 7.5 G (always looking at the same "good" cases). This does not happen; instead we find a super-linear increase, with a peak memory usage around 16 G.

This means that, even leaving the illumination-correction task aside, dask's handling of this graph changes qualitatively when going from shape (2, 10, 16000, 18000) to (3, 19, 16000, 18000), in that memory usage does not scale linearly with the array size.

Is it because of dask being unable to optimize a large graph?
Is it due to cache doing "something" wrong, or accumulating too much?
To be understood.

Vaguely related reference: https://blog.dask.org/2018/06/26/dask-scaling-limits.

PS: it's useful to do this artificial test, because in the illumination-correction case the compression of data introduces a bit more complexity.

And have you looked into restricting the opportunistic caching explicitly? See here: https://docs.dask.org/en/stable/caching.html

We used this feature extensively in the past, for testing, but never found a case where it really mattered. We may try again on this one.

EDIT: fixed wrong shapes

@jluethi
Collaborator

jluethi commented Aug 4, 2022

Very good that we have a synthetic example then that reproduces this issue! :)

Maybe the synthetic issue is also something we could report to dask to see if anyone has inputs there? It's certainly surprising.

Just for my understanding: What happens with dask arrays of shape=(2, 10, 16000, 18000) or even better shape=(3, 19, 16000, 18000)? Those would be quite close to what we actually want to process, no?

Vaguely related reference: https://blog.dask.org/2018/06/26/dask-scaling-limits.

Though we are far away from the 100s of GB or TBs that they describe as typical sizes in that post...

@tcompa
Collaborator Author

tcompa commented Aug 4, 2022

Just for my understanding: What happens with dask arrays of shape=(2, 10, 16000, 18000) or even better shape=(3, 19, 16000, 18000)? Those would be quite close to what we actually want to process, no?

My bad, these are actually the shapes we are looking at:

  • Shape (2, 10, 16000, 18000), 11 G on disk, 2.5 G peak memory during processing.
  • Shape (3, 19, 16000, 18000), 31 G on disk, 16 G peak memory during processing.

@jluethi
Collaborator

jluethi commented Aug 4, 2022

Ok, great. Then we have a synthetic example that reproduces the issue and fits in data size to our actual data! (without dealing with compression, actual function calls for illumcorr etc)

@tcompa
Collaborator Author

tcompa commented Aug 4, 2022

For this data,

Shape (3, 19, 16000, 18000), 31 G on disk, 16 G peak memory during processing.

and with parameters {ROI/chunk aligned, inline_array=True, compute=True}, we played a bit with the cache.

First tests:

  • With cache = 2 G, the peak memory is still at 16 G.
  • With cache = 1 G, the peak memory is still at 16 G.
  • With cache = 500 M, the peak memory is still at 16 G.
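For reference, the cache limits above were presumably set through dask's opportunistic cache, as documented at the link Joel posted (https://docs.dask.org/en/stable/caching.html). A minimal sketch (it requires the cachey package), with 2e9 corresponding to the 2 G limit tested above:

from dask.cache import Cache

# Register a global opportunistic cache with a 2 GB limit
cache = Cache(2e9)
cache.register()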

@jluethi
Collaborator

jluethi commented Aug 4, 2022

Ok, I guess we can rule out the opportunistic caching. Thanks for testing! :)

@tcompa
Collaborator Author

tcompa commented Aug 5, 2022

Here's the draft for a more self-contained discussion (which we can readily turn into a dask issue).


Brief description

We process a 4-dimensional array by applying a simple (delayed) function to several parts of it, and then write it to disk with to_zarr. The scaling of memory usage vs array size is unexpectedly super-linear.
CAVEAT: we are aware that map_blocks would be a better solution in this case, since indexed regions are aligned with chunks; but we must work with index assignments, since we also have to treat cases where regions are not aligned to chunks.

Detailed steps to reproduce

What follows is with python 3.8.13 and dask 2022.7.0.

We first prepare four arrays of different size, and we store them on disk as zarr files:

import numpy as np
import dask.array as da


shapes = [
    (2, 2, 16000, 16000),
    (2, 4, 16000, 16000),
    (2, 8, 16000, 16000),
    (2, 16, 16000, 16000),
]

for shape in shapes:
    shape_id = f"{shape}".replace(",", "").replace(" ", "_")[1:-1]
    x = da.random.randint(0, 2 ** 16 - 1,
                          shape,
                          chunks=(1, 1, 2000, 2000),
                          dtype=np.uint16)
    x.to_zarr(f"data_{shape_id}.zarr")

The on-disk sizes scale as expected with the array shape (i.e. they increase by a factor of two from one array to the next):

2.0G	./data_2_2_16000_16000.zarr
3.9G	./data_2_4_16000_16000.zarr
7.7G	./data_2_8_16000_16000.zarr
16G 	./data_2_16_16000_16000.zarr

For context: in our actual use case, these are 4-dimensional arrays (with indices which are named "channel, z, y, x") that store a bunch of 2-dimensional images coming from an optical microscope.

After preparing these arrays, we process them with

import sys
import numpy as np
import shutil
import dask.array as da
import dask
from memory_profiler import memory_usage

SIZE = 2000

# Function which shifts a SIZExSIZE image by one
def shift_img(img):
    assert img.shape == (SIZE, SIZE)
    return img.copy() + 1

delayed_shift_img = dask.delayed(shift_img)

def process_zarr(input_zarr):
    out_zarr = f"out_{input_zarr}"
    data_old = da.from_zarr(input_zarr)
    data_new = da.empty(data_old.shape,
                        chunks=data_old.chunks,
                        dtype=data_old.dtype)

    n_c, n_z, n_y, n_x = data_old.shape[:]

    print(f"Input file: {input_zarr}")
    print(f"Output file: {out_zarr}")
    print("Array shape:", data_old.shape)
    print("Array chunks:", data_old.chunks)
    print(f"Image size: ({SIZE},{SIZE})")

    # Loop over all images in the array
    for i_x in range(0, n_x - 1, SIZE):
        for i_y in range(0, n_y - 1, SIZE):
            for i_c in range(n_c):
                for i_z in range(n_z):
                    new_img = delayed_shift_img(data_old[i_c, i_z, i_y:i_y+SIZE, i_x:i_x+SIZE])
                    data_new[i_c, i_z, i_y:i_y+SIZE, i_x:i_x+SIZE] = da.from_delayed(new_img,
                                                                                     shape=(SIZE, SIZE),
                                                                                     dtype=data_old.dtype)

    # Write data_new to disk (triggers execution)
    data_new.to_zarr(out_zarr)

    # Clean up output folder
    shutil.rmtree(out_zarr)

if __name__ == "__main__":
    input_zarr = sys.argv[1]
    interval = 0.1
    mem = memory_usage((process_zarr, (input_zarr,)), interval=interval)
    time = np.arange(len(mem)) * interval
    mem_file = "log_memory_" + input_zarr.split(".zarr")[0] + ".dat"
    print(mem_file)
    np.savetxt(mem_file, np.array((time, mem)).T)

where we are applying the function shift_img to all 2-dimensional images forming the 4-dimensional array. We also trace the memory usage of process_zarr via the memory_profiler API.

Unexpected results

We expected a linear scaling of the memory footprint with the array size. Instead, we observe the following peak memory:

# shape, peak memory (GB)
(2, 2, 16k, 16k), 0.43
(2, 4, 16k, 16k), 0.48
(2, 8, 16k, 16k), 1.13
(2, 16, 16k, 16k), 4.29

Apart from the first two (smallest) arrays, where we are probably looking at some fixed-size overhead, we expected a roughly linear growth of memory; instead, the two-fold array-size increase from (2, 8, 16k, 16k) to (2, 16, 16k, 16k) yields a four-fold memory increase, from ~1 G to ~4 G.

The more detailed memory trace looks like this:
[figure: fig_memory -- memory traces for the four array shapes]

As obtained with

import numpy as np
import matplotlib.pyplot as plt

for n_z in [2, 4, 8, 16]:
    f = f"log_memory_data_2_{n_z}_16000_16000.dat"
    t, mem = np.loadtxt(f, unpack=True)
    plt.plot(t, mem, label=f"shape=(2,{n_z},16k,16k)")
    print(f, np.max(mem))

plt.xlabel("Time (s)")
plt.ylabel("Memory (MB)")
plt.legend(fontsize=8, framealpha=1)
plt.grid()
plt.savefig("fig_memory.png", dpi=256)

@jluethi
Collaborator

jluethi commented Aug 15, 2022

I did a bit more reading before posting the above issue on dask and hit this issue: dask/distributed#5960

It's a mix of discussions about using LocalCluster, logging and memory management in general. Apparently, there were fixes to some of the memory handling issues though, so I thought it's worthwhile to test our problem with different versions of dask. The results were somewhat surprising.

Using the 2022.07.0 release like above, I reproduce the results you hit @tcompa
[figure: fig_memory_dask202207 -- memory trace with dask 2022.07.0]

Comparing this to the 2022.01.1 release that was a workaround in the dask issue above, it doesn't seem to change anything:
[figure: fig_memory_dask202201 -- memory trace with dask 2022.01.1]

Thus, likely this issue wasn't actually related to the same underlying memory management problem. But I still tested it with the current dask version (2022.08.0). There were two main differences:

  1. It ran much slower (taking 30 min instead of 1.5 min for the 16 Z levels).
  2. The memory usage pattern is quite different: it rises much more slowly, but spikes briefly at the end (still going above 4 GB, although for a much shorter time).

[figure: fig_memory_dask202208 -- memory trace with dask 2022.08.0]

This happens reproducibly. @tcompa We had issues in the past where CPU usage was low for a while and then suddenly it would increase to process things. Do you remember what the issue was there? Because maybe that issue now hits us again here? Maybe that helps us trace why things go so much slower here?

Also, another observation with those slow runs: Nothing is written to disk until right before the end => it's not really finishing the writing to disk process. I assume that's why the memory load increases continuously.
=> Can we figure out why it's not writing anything to disk? Shouldn't it finish some of the trees and write the output to disk? If it doesn't do that, it's not surprising that memory accumulates, no?

@jluethi
Collaborator

jluethi commented Aug 16, 2022

The mystery continues. I wrote a mapblocks version of our synthetic example and compared the performance to the 2022.07.0 dask indexing performance above.

import shutil

import dask.array as da
import numpy as np
from memory_profiler import memory_usage

SIZE = 2000  # image edge, as in the indexing example above


def shift_img_block(img):
    # Function which shifts an image block by one
    return img.copy() + 1


def process_zarr_mapblocks(input_zarr):
    out_zarr = f"out_{input_zarr}"
    data_old = da.from_zarr(input_zarr)
    # data_new = da.empty(data_old.shape,
    #                     chunks=data_old.chunks,
    #                     dtype=data_old.dtype)

    n_c, n_z, n_y, n_x = data_old.shape[:]
    dtype = np.uint16
    print(f"Input file: {input_zarr}")
    print(f"Output file: {out_zarr}")
    print("Array shape:", data_old.shape)
    print("Array chunks:", data_old.chunks)
    print(f"Image size: ({SIZE},{SIZE})")

    # Apply the shift function chunk-by-chunk
    data_new = data_old.map_blocks(
        shift_img_block,
        chunks=data_old.chunks,
        meta=np.array((), dtype=dtype),
    )

    # Write data_new to disk (triggers execution)
    data_new.to_zarr(out_zarr)

    # Clean up output folder
    shutil.rmtree(out_zarr)


input_zarr = 'data_2_16_16000_16000.zarr'
interval = 0.1
mem = memory_usage((process_zarr_mapblocks, (input_zarr,)), interval=interval)
time = np.arange(len(mem)) * interval
mem_file = "log_memory_" + input_zarr.split(".zarr")[0] + "_mapblocks.dat"
print(mem_file)
np.savetxt(mem_file, np.array((time, mem)).T)

When running on the largest synthetic dataset, mapblocks significantly outperforms the indexing approach in both memory and running time (blue: indexing example from above; orange: mapblocks).

[figure: fig_memory_dask202208_Z16_mapblocks -- memory traces: indexing (blue) vs mapblocks (orange)]

The mapblocks version currently runs on the whole 4D array at a time though (but still with the same chunks?), so I'm not sure yet whether it will be slower if we apply it per channel and concatenate again, or if we run it per Z plane explicitly; see the sketch below. Will test later.
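For reference, a rough sketch of the per-channel variant mentioned above (untested here; it reuses data_old and shift_img_block from the snippet):

# Apply map_blocks to each 3D (z, y, x) channel separately and re-stack the
# results along the channel axis; chunk sizes are preserved.
per_channel = [
    data_old[i_c].map_blocks(shift_img_block, dtype=data_old.dtype)
    for i_c in range(data_old.shape[0])
]
data_new = da.stack(per_channel, axis=0)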

One very interesting observation here: the mapblocks implementation looks like it continuously writes to the zarr file, not just at the end when everything is computed. That would explain why it needs less memory: it writes to disk piece by piece, freeing up memory as it goes.

Conclusions:

  1. I will investigate further why our indexing approach does not appear to write to the zarr file continuously. I suspect this is the memory bottleneck.
  2. I will look into the mapblocks a bit more, to see whether I can understand why it runs that much faster (/ whether that still happens when running separately per channel etc.)

@jluethi
Collaborator

jluethi commented Aug 17, 2022

I couldn't figure out why the indexing approach does not seem to continuously write to the Zarr file during processing. I now created a topic on this on the dask forum, let's hope we get some feedback there: https://dask.discourse.group/t/using-da-delayed-for-zarr-processing-memory-overhead-how-to-do-it-better/1007

@jluethi
Collaborator

jluethi commented Aug 17, 2022

And I reported the differences in runtime for the indexing approach between 2022.08.0 and 2022.07.0 here: dask/dask#9389

@tcompa
Collaborator Author

tcompa commented Aug 22, 2022

@tcompa We had issues in the past where CPU usage was low for a while and then suddenly it would increase to process things. Do you remember what the issue was there? Because maybe that issue now hits us again here? Maybe that helps us trace why things go so much slower here?

I think our typical experience with low CPU usage is related to IO bottlenecks (as in fractal-analytics-platform/fractal-client#57), where writing a large number of small files is slow because of disk performance and therefore doesn't take up much CPU.
I don't see an immediate connection with our current issue. And, by the way, my tests were on a local machine, where I don't think there should be any serious IO-speed problem.

@jluethi
Collaborator

jluethi commented Aug 22, 2022

Thanks @tcompa
Makes sense. I also ran those tests locally, and given the differences between 2022.08.0 and 2022.07.0, this shouldn't be anything IO-related.
I reported it here (dask/dask#9389) and traced it to a commit that changed the graph handling in dask, so I do think it's a dask bug.

tcompa referenced this issue in fractal-analytics-platform/fractal-client Aug 22, 2022
@jluethi
Collaborator

jluethi commented Aug 26, 2022

I think I have the solution to our ROI memory issue! If we use the region parameter of to_zarr and split the computation into one task per ROI that then writes its region to disk, it seems to work! Memory usage is down to basically mapblocks levels. Runtime is longer in the synthetic example, but it was shorter in the early tests for illumination correction, so let's see how things look there.

Orange: Old
Blue: Using region parameter

[figure: Dask_indexing_regions -- memory traces: old indexing approach (orange) vs region-based writing (blue)]

I'll report further details and example code in the coming days, just quickly wanted to share the success 🥇

@jluethi
Collaborator

jluethi commented Aug 27, 2022

I reported a detailed summary and evaluation here: https://dask.discourse.group/t/using-da-delayed-for-zarr-processing-memory-overhead-how-to-do-it-better/1007/11

For reference, this is example code that now works very well for me and produces the results above:

Code
import dask.array as da
import zarr

# SIZE (=2000) and a shift_img processing function, as in the earlier snippets,
# are assumed to be defined.


def process_zarr_regions(input_zarr, inplace=False, overwrite=True):
    if inplace and not overwrite:
        raise Exception('If zarr is processed in place, overwrite needs to be True')

    out_zarr = f"out_{input_zarr}"
    data_old = da.from_zarr(input_zarr)

    # Prepare output zarr file
    if inplace:
        new_zarr = zarr.open(input_zarr)
    else:
        new_zarr = zarr.create(
            shape=data_old.shape,
            chunks=data_old.chunksize,
            dtype=data_old.dtype,
            store=da.core.get_mapper(out_zarr),
            overwrite=overwrite,
        )
    n_c, n_z, n_y, n_x = data_old.shape[:]

    print(f"Input file: {input_zarr}")
    print(f"Output file: {out_zarr}")
    print("Array shape:", data_old.shape)
    print("Array chunks:", data_old.chunks)
    print(f"Image size: ({SIZE},{SIZE})")

    # One region (tuple of slices) per ROI
    tasks = []
    regions = []
    for i_c in range(n_c):
        for i_z in range(n_z):
            for i_y in range(0, n_y - 1, SIZE):
                for i_x in range(0, n_x - 1, SIZE):
                    regions.append((slice(i_c, i_c+1), slice(i_z, i_z+1), slice(i_y, i_y+SIZE), slice(i_x, i_x+SIZE)))

    # One delayed to_zarr write per region
    for region in regions:
        data_new = shift_img(data_old[region])
        task = data_new.to_zarr(url=new_zarr, region=region, compute=False, overwrite=overwrite)
        tasks.append(task)

    # Compute tasks sequentially
    # TODO: Figure out how to run tasks in parallel where safe => batching
    # (where they don't read/write from/to the same chunk)
    for task in tasks:
        task.compute()

I think important next steps for Fractal will be:

  • Integrate this logic into core tasks and test on real-world data => illumination correction example from the initial post
  • Test whether the inplace overwrite works. Otherwise, it could actually get quite messy if we e.g. only process a subset of ROIs, or if ROIs don't cover the whole dataset (some things get processed, others should maybe be left untouched?), or if we only want to apply something to some channels. Do we still run into this issue here: "Writing inplace a modified zarr array with dask failed" (dask/dask#5942)? @tcompa How can we test this? On the dummy test, it looked fine.
  • Do we want to look into ways to parallelize the runs where possible? We cannot execute tasks in parallel that write to the same chunk, but we could check for this, or e.g. just allow it when ROIs match with chunks (a rough sketch follows this list).
  • Future goal: Can we have the whole logic for applying things to ROIs as part of the library? => Simplify tasks. And allows us to optimize things in the library, e.g. add option for parallelization depending on chunk checks or such that we solve once and then is used by all tasks (let's first get it running well for e.g. illumination correction, but let's keep in mind that this is something we may want to abstract)
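As a rough sketch of the batched parallelization idea above (only valid under the assumption that ROIs are aligned with chunks, so that no two regions in a batch touch the same chunk; BATCH_SIZE is an arbitrary example value, and tasks is the list built in the code block above):

import dask

BATCH_SIZE = 4
# Compute the delayed region-writes in small batches instead of one by one
for i in range(0, len(tasks), BATCH_SIZE):
    dask.compute(*tasks[i:i + BATCH_SIZE])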

Also, one learning for the illumination correction: overhead seems to scale with the number of ROIs. Thus, for the illumination correction, I think it will be most efficient to slightly rewrite the correction function. It should also be able to handle a 3D array and just process it plane by plane; a rough sketch of such a function follows. That way, we will save a lot on the overhead :) (we should maintain the ability to run on 2D planes only, in case a user provides 2D data or applies illumination correction after an MIP).
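A rough sketch of what such a plane-by-plane correction function could look like (this is only an illustration, not the actual Fractal implementation; the background subtraction, clipping and division by the illumination profile are assumptions, and background=110 is an arbitrary example value):

import numpy as np

def correct_planes(stack, illum_img, background=110):
    # Accept either a single (y, x) plane or a (z, y, x) stack
    stack = np.asarray(stack)
    single_plane = (stack.ndim == 2)
    if single_plane:
        stack = stack[np.newaxis, ...]
    out = np.empty_like(stack)
    for i_z, img in enumerate(stack):
        # Hypothetical correction: subtract background, clip, divide by profile
        corrected = (img.astype(np.float64) - background).clip(min=0) / illum_img
        out[i_z] = corrected.astype(stack.dtype)
    return out[0] if single_plane else out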

@tcompa
Collaborator Author

tcompa commented Aug 29, 2022

This looks promising, thank you!

I'll look into this and come back with more detailed comments/updates.

@jluethi
Collaborator

jluethi commented Aug 31, 2022

Let's implement the sequential running of ROIs for this issue. I moved the discussion of parallelization to a separate discussion issue, in case we want to come back to it in the future: #44

@tcompa
Collaborator Author

tcompa commented Sep 19, 2022

Closed with #79; further testing will be part of #30.

@tcompa tcompa closed this as completed Sep 19, 2022
@tcompa tcompa added the Tables AnnData and ROI/feature tables label Sep 19, 2023