In [1]:
import numpy as np
from scipy.ndimage import laplace

def op_mean(array_2d):
    """Plain mean of a 2D array."""
    return np.mean(array_2d)

def op_mean_laplace(array_2d):
    """Mean of the laplacian of a 2D array."""
    return np.mean(laplace(array_2d))

def op_std_dev(array_2d):
    """Standard deviation of a 2D array."""
    return np.std(array_2d)

In [2]:
from dask.distributed import Client, LocalCluster

# --- Setup Dask Client ---
client = Client(LocalCluster(n_workers=4, threads_per_worker=2, memory_limit='16GB'))

In [3]:
import dask.array as da
import pandas as pd
import time


def run_dask_benchmark(methods_to_test, shapes_to_test, chunk_configs_func, n_trials=3):
    """
    Runs the benchmark by loading data from Zarr stores on disk.
    """
    results = []
    base_dir = "dask_data"

    for shape in shapes_to_test:
        chunk_configs = chunk_configs_func(shape)
        for chunk_name in chunk_configs.keys(): # Iterate by name
            print(f"\n--- Shape: {shape} | Chunks: {chunk_name} ---")
            
            shape_str = 'x'.join(map(str, shape))
            path = os.path.join(base_dir, f"data_{shape_str}_{chunk_name}.zarr")
            
            # Load the array lazily from disk
            dask_arr = da.from_zarr(path)
            
            for method_name, core_func in methods_to_test.items():
                print(f"  -> Testing Method: {method_name}...")
                
                for _ in range(n_trials):
                    start = time.perf_counter()
                    # The computation itself remains the same
                    da.apply_gufunc(core_func, "(i,j)->()", dask_arr).compute()
                    end = time.perf_counter()
                    results.append({
                        'method': method_name, 
                        'shape': str(shape), 
                        'chunks': chunk_name, 
                        'time': end - start
                    })
            
    return pd.DataFrame(results)

In [4]:
import os
import dask.array as da
import zarr # You may need to install this: pip install zarr

def prepare_datasets(shapes_to_test, chunk_configs_func):
    """
    Generates Dask arrays and saves them to disk in Zarr format.
    """
    print("--- Preparing and saving datasets to disk ---")
    base_dir = "dask_data"
    os.makedirs(base_dir, exist_ok=True)
    
    for shape in shapes_to_test:
        chunk_configs = chunk_configs_func(shape)
        for chunk_name, chunks in chunk_configs.items():
            shape_str = 'x'.join(map(str, shape))
            # Define a unique path for each dataset configuration
            path = os.path.join(base_dir, f"data_{shape_str}_{chunk_name}.zarr")
            
            if os.path.exists(path):
                print(f"  -> Dataset already exists: {path}")
                continue
                
            print(f"  -> Creating dataset: {path}")
            # Create a random array
            dask_arr = da.random.random(size=shape, chunks=chunks)
            # Save it to a Zarr store on disk
            dask_arr.to_zarr(path, overwrite=True)

In [5]:
import seaborn as sns
import matplotlib.pyplot as plt

# --- Define Test Configurations ---
METHODS_TO_TEST = {
    "Mean": op_mean,
    "Mean Laplace": op_mean_laplace,
}

xy_size = 500
other_size = 50

SHAPES_TO_TEST = [
    (xy_size, xy_size),
    (4, xy_size, xy_size),
    (other_size, other_size, 4, xy_size),
    (other_size, other_size, 4, xy_size, xy_size)
]

def get_chunk_configs(shape):
    # Same function as before
    n, m = shape[-2], shape[-1]
    other_dims = shape[:-2] if len(shape) > 2 else ()
    return {
        "Small": (*[1 for d in other_dims], 50, 50),
        "Sliced": (*[1 for d in other_dims], n, m),
        "Ideal": (*[d//2 if d > 1 else 1 for d in other_dims], n, m),
        "Realistic": (*[d//2 if d > 1 else 1 for d in other_dims], n//4, m//4),
    }

# --- Execute and Plot ---
prepare_datasets(SHAPES_TO_TEST, get_chunk_configs)
df_results = run_dask_benchmark(METHODS_TO_TEST, SHAPES_TO_TEST, get_chunk_configs)

--- Preparing and saving datasets to disk ---
  -> Creating dataset: dask_data/data_500x500_Small.zarr
  -> Creating dataset: dask_data/data_500x500_Sliced.zarr
  -> Creating dataset: dask_data/data_500x500_Ideal.zarr
  -> Creating dataset: dask_data/data_500x500_Realistic.zarr
  -> Creating dataset: dask_data/data_4x500x500_Small.zarr
  -> Creating dataset: dask_data/data_4x500x500_Sliced.zarr
  -> Creating dataset: dask_data/data_4x500x500_Ideal.zarr
  -> Creating dataset: dask_data/data_4x500x500_Realistic.zarr
  -> Creating dataset: dask_data/data_50x50x4x500_Small.zarr


This may cause some slowdown.
Consider loading the data with Dask directly
 or using futures or delayed objects to embed the data into the graph without repetition.
See also https://docs.dask.org/en/stable/best-practices.html#load-data-with-dask for more information.


  -> Creating dataset: dask_data/data_50x50x4x500_Sliced.zarr
  -> Creating dataset: dask_data/data_50x50x4x500_Ideal.zarr
  -> Creating dataset: dask_data/data_50x50x4x500_Realistic.zarr
  -> Creating dataset: dask_data/data_50x50x4x500x500_Small.zarr


This may cause some slowdown.
Consider loading the data with Dask directly
 or using futures or delayed objects to embed the data into the graph without repetition.
See also https://docs.dask.org/en/stable/best-practices.html#load-data-with-dask for more information.


  -> Creating dataset: dask_data/data_50x50x4x500x500_Sliced.zarr


This may cause some slowdown.
Consider loading the data with Dask directly
 or using futures or delayed objects to embed the data into the graph without repetition.
See also https://docs.dask.org/en/stable/best-practices.html#load-data-with-dask for more information.


  -> Creating dataset: dask_data/data_50x50x4x500x500_Ideal.zarr


2025-07-26 13:55:35,332 - distributed.worker - ERROR - Compute Failed
Key:       ('random_sample-store-map-06b51e628cfba7cf4a23f534493c940e', 1, 1, 1, 0, 0)
State:     executing
Task:  <Task ('random_sample-store-map-06b51e628cfba7cf4a23f534493c940e', 1, 1, 1, 0, 0) _execute_subgraph(...)>
Exception: "ValueError('Codec does not support buffers of > 2147483647 bytes')"
Traceback: '  File "/media/data/Development/hi/repos/pixel-patrol/.venv/lib/python3.12/site-packages/dask/array/core.py", line 4617, in load_store_chunk\n    out[index] = x\n    ~~~^^^^^^^\n  File "/media/data/Development/hi/repos/pixel-patrol/.venv/lib/python3.12/site-packages/zarr/core.py", line 1449, in __setitem__\n    self.set_orthogonal_selection(pure_selection, value, fields=fields)\n  File "/media/data/Development/hi/repos/pixel-patrol/.venv/lib/python3.12/site-packages/zarr/core.py", line 1638, in set_orthogonal_selection\n    self._set_selection(indexer, value, fields=fields)\n  File "/media/data/Development/hi/

ValueError: Codec does not support buffers of > 2147483647 bytes

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

def plot_dask_results(df):
    """
    Generates and saves a separate plot for each benchmarked method,
    comparing chunking strategies.
    """
    print("\n--- Generating Dask performance plots ---")
    
    for method_name in df['method'].unique():
        print(f"  -> Creating plot for: {method_name}")
        
        df_method = df[df['method'] == method_name]
        
        g = sns.catplot(
            data=df_method,
            kind="bar",
            col="shape",      # Create a subplot column for each shape
            x="chunks",       # Compare chunking strategies on the x-axis
            y="time",
            sharey=False,
            height=6,
            aspect=1.0
        )
        
        # Configure titles and labels
        g.set_axis_labels("Chunking Strategy", "Execution Time (s)")
        g.set_titles("Shape: {col_name}")
        g.fig.suptitle(f"Dask Performance for: {method_name}", y=1.03, fontsize=16)
        
        # Create a dynamic filename for each plot
        safe_filename = method_name.lower().replace(' ', '_')
        output_filename = f'dask_benchmark_{safe_filename}.png'
        
        # Save the figure
        g.figure.savefig(output_filename, dpi=300, bbox_inches='tight')
        print(f"     ✅ Plot saved to {output_filename}")
        
        plt.show()
        plt.close(g.fig)

In [None]:
plot_dask_results(df_results)