# Memory Aware Chunking

## Introduction

This notebook explores the concept of memory-aware chunking in Dask.
The goal is to understand how to optimize the chunk size of a Dask array based on the available memory and the size of the data being processed.
This can help improve performance and reduce memory usage when working with large datasets.

For this notebook, we will levereage all the findings from the previous notebooks and experiments, such as:
- The proper way to measure memory usage.
- How to build a model to predict memory usage.
- How to execute Dask tasks in parallel using multiple workers.

**Objectives:**
- Evaluate and discover how to do memory-aware chunking.
- Compare the performance against Dask's auto-chunking.
- Compare the performance against manual chunking.

## Limitations

The findings presented in this notebook, and the experiment as a whole, primarily apply to local Dask clusters.
Based on Dask docs, the results may be applicable to remote clusters as well, but to simplify the experimentation process and avoid the need for a remote cluster, we will focus on local clusters.

## Experiment Setup

To ensure reliable and reproducible memory profiling, this notebook follows a structured experimental setup.
The setup includes defining the environment, configuring dependencies, and establishing a controlled execution process.

### Environment & Dependencies

The experiment is conducted in a Python environment with the following key libraries:
- **Dask** - For parallel computing and task scheduling.
- **Matplotlib** - For plotting and visualizing data.
- **Pandsas** - For data manipulation and analysis.
- **Setuptools** - For managing the installation of local modules.
- **Scikit-learn** - For machine learning tasks, including model training and evaluation.

In [1]:
!pip install --upgrade pip
!pip install "dask[distributed]" pandas matplotlib setuptools scikit-learn xgboost

Looking in indexes: https://pypi.org/simple, https://daniel.d2%40doordash.com:****@ddartifacts.jfrog.io/ddartifacts/api/pypi/pypi-local/simple/
Looking in indexes: https://pypi.org/simple, https://daniel.d2%40doordash.com:****@ddartifacts.jfrog.io/ddartifacts/api/pypi/pypi-local/simple/


We also rely on a feel common tools that are shared across different experiments.
We'll install it from our local module.

In [2]:
!pip install -e ../../../libs/common

Looking in indexes: https://pypi.org/simple, https://daniel.d2%40doordash.com:****@ddartifacts.jfrog.io/ddartifacts/api/pypi/pypi-local/simple/
Obtaining file:///Users/delucca/Workspaces/src/unicamp/memory-aware-chunking/libs/common
  Installing build dependencies ... [?25ldone
[?25h  Checking if build backend supports build_editable ... [?25ldone
[?25h  Getting requirements to build editable ... [?25ldone
[?25h  Preparing editable metadata (pyproject.toml) ... [?25ldone
Building wheels for collected packages: common
  Building editable for common (pyproject.toml) ... [?25ldone
[?25h  Created wheel for common: filename=common-0.0.1-0.editable-py3-none-any.whl size=4014 sha256=1081183d84b51e95c8f106531b2000a1c588ef91061bc6b5fc87834813d54c6a
  Stored in directory: /private/var/folders/sl/3td2vnj56c38q77xf6_s8qrr0000gn/T/pip-ephem-wheel-cache-0seqs5d5/wheels/17/bf/90/e7b02ffba6777f2f799fce361bf93cbf760259590e6bb2023d
Successfully built common
Installing collected packa

With all dependencies installed, we also need to setup the experiment output directory.

In [3]:
import datetime
import os

timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
output_dir = f"../out/notebooks/02-memory-aware-chunking-{timestamp}"

os.makedirs(output_dir, exist_ok=True)

Finally, let's ignore some warning

In [4]:
import warnings

warnings.filterwarnings(
    "ignore",
    message=r"Sending large graph of size .*MiB",
    category=UserWarning,
)

## Methodology

This notebook demonstrates how to define, test, and compare different chunking strategies for the GST3D seismic operator under constrained memory conditions.
The process leverages a pre-trained memory usage model (from Experiment 2) to guide chunk size selection.

### Data Generation

- A synthetic seismic dataset is created with dimensions representative of real-world 3D seismic volumes.
- The dataset is large enough to trigger out-of-memory (OOM) risks when chunking is suboptimal.

### Cluster Configuration

- A local Dask cluster is launched with a limited memory budget per worker.
- The number of workers is chosen to facilitate parallel processing but still remain within realistic resource constraints.

### Memory Usage Estimation

- The trained memory usage model (identified as optimal for GST3D in Experiment 2) is used to estimate how different chunk sizes might impact overall memory requirements.
- The model provides a recommended chunk size that should fit in worker memory without causing OOM errors.

### Chunking Strategies

1. **Dask Auto-Chunking**
   - The cluster executes GST3D using Dask’s automatic chunking approach.
   - This run may lead to OOM failures if the default chunks are too large.

2. **Manual Chunking**
   - Several manual chunk sizes are tested to determine a workable configuration.
   - The best manual chunk size is identified through empirical trials (e.g., smallest execution time without OOM).

3. **Memory-Aware Chunking (Model-Based)**
   - The notebook applies the model-predicted chunk size for GST3D.
   - This approach aims for optimal performance without OOM errors, relying on the estimated memory usage.

### Comparison and Analysis

- Execution time, peak memory usage, and overall success/failure rates are collected for each chunking strategy.
- The outcomes are compared to highlight performance differences and validate the effectiveness of the memory-aware approach.
- Observations are used to confirm whether the model-derived chunk size matches or outperforms manual and auto-chunk configurations.

## Experiment Execution and Data Collection


### Step 1: Data Generation

In [5]:
INLINES = 200
XLINES = 200
SAMPLES = 200

In [6]:
from common.builders import build_seismic_data

data_segy_path = build_seismic_data(
    inlines=INLINES,
    xlines=XLINES,
    samples=SAMPLES,
    output_dir=output_dir,
    prefix="gst3d_experiment_"
)

print("Synthetic seismic data generated at:", data_segy_path)

Generating synthetic data for shape (200, 200, 200)
Synthetic data generated successfully
Synthetic data saved to ../out/notebooks/02-memory-aware-chunking-20250330204013/gst3d_experiment_-200-200-200.segy
Synthetic seismic data generated at: ../out/notebooks/02-memory-aware-chunking-20250330204013/gst3d_experiment_-200-200-200.segy


### Step 2: Memory Usage Estimation

In [7]:
MODEL_PATH = '../../02-predicting-memory-consumption-from-input-shapes/out/results/20250322210428/best_models/gst3d.pkl'
MEMORY_LIMIT_GB = 1.5


In [8]:
import pickle as pkl

model = pkl.load(open(MODEL_PATH, 'rb'))
model

In [9]:
import pandas as pd

df = pd.DataFrame({
    "volume": [INLINES * XLINES * SAMPLES],
})
X = df[["volume"]]

y_pred = model.predict(X)
memory_usage_pred = y_pred[0]

print(f"Estimated memory usage for GST3D: {memory_usage_pred:.2f} GB")

Estimated memory usage for GST3D: 1.89 GB


## Step 3: Evaluating Chunking Strategies

In [10]:
import time
import threading

monitoring = False


def monitor_memory(client, memory_usage_history, interval=0.2):
    while monitoring:
        info = client.scheduler_info()
        for addr, worker_info in info["workers"].items():
            memory_usage_history[addr].append(worker_info.get("metrics", {}).get("memory", 0))
        time.sleep(interval)


def start_monitoring(client, memory_usage_history):
    global monitoring
    monitoring = True
    thread = threading.Thread(target=monitor_memory, daemon=True, args=(client, memory_usage_history,))
    thread.start()

    return thread


def stop_monitoring(thread):
    global monitoring
    monitoring = False
    thread.join()

### Strategy 1: Dask Auto-Chunking

In [11]:
from dask.distributed import LocalCluster, Client

from common.operators.gst3d import gradient_structure_tensor_from_segy

cluster = LocalCluster(
    n_workers=2,
    threads_per_worker=1,
    memory_limit=f"{MEMORY_LIMIT_GB}GB",
)

client = Client(cluster)

memory_usage_history = {addr: [] for addr in client.scheduler_info()["workers"]}
monitoring_thread = start_monitoring(client, memory_usage_history)

start_time = time.time()
dip_map = gradient_structure_tensor_from_segy(data_segy_path, use_dask=True)
dip_result = dip_map.compute()
end_time = time.time()
elapsed_time = end_time - start_time
print("Dip result shape:", dip_result.shape)
print("Dip min:", dip_result.min(), "Dip max:", dip_result.max())
print(f"Elapsed time: {elapsed_time:.2f} seconds")

stop_monitoring(monitoring_thread)

peak_memory_usages = {key: max(value) for key, value in memory_usage_history.items()}
for addr, mem_bytes in peak_memory_usages.items():
    print(f"Worker {addr} peak memory: {mem_bytes / 1024 ** 3:.2f} GB")

client.shutdown()
client.close()
cluster.close()

Loaded data shape: (200, 200, 200)
Loaded data chunk sizes: ((200,), (200,), (200,))


Exception in thread Thread-5 (monitor_memory):
Traceback (most recent call last):
  File [35m"/Users/delucca/.pyenv/versions/3.13.0/lib/python3.13/threading.py"[0m, line [35m1041[0m, in [35m_bootstrap_inner[0m
    [31mself.run[0m[1;31m()[0m
    [31m~~~~~~~~[0m[1;31m^^[0m
  File [35m"/Users/delucca/.pyenv/versions/mac__3.13__experiments__03/lib/python3.13/site-packages/ipykernel/ipkernel.py"[0m, line [35m766[0m, in [35mrun_closure[0m
    [31m_threading_Thread_run[0m[1;31m(self)[0m
    [31m~~~~~~~~~~~~~~~~~~~~~[0m[1;31m^^^^^^[0m
  File [35m"/Users/delucca/.pyenv/versions/3.13.0/lib/python3.13/threading.py"[0m, line [35m992[0m, in [35mrun[0m
    [31mself._target[0m[1;31m(*self._args, **self._kwargs)[0m
    [31m~~~~~~~~~~~~[0m[1;31m^^^^^^^^^^^^^^^^^^^^^^^^^^^^^[0m
  File [35m"/var/folders/sl/3td2vnj56c38q77xf6_s8qrr0000gn/T/ipykernel_48508/2185371761.py"[0m, line [35m11[0m, in [35mmonitor_memory[0m
    [31mmemory_usage_history[0m[1;31m[add

KilledWorker: Attempted to run task ('_eig_per_voxel-c00d151fa6b319f7a754d79c3faaff0b', 0, 0, 0) on 4 different workers, but all those workers died while running it. The last worker that attempt to run the task was tcp://127.0.0.1:57309. Inspecting worker logs is often a good next step to diagnose what went wrong. For more information see https://distributed.dask.org/en/stable/killed.html.

As you can see, since Dask auto chunking ignores the expected memory usage, we run into an OOM error.

### Strategy 2: Manual Chunking

For manual we chunk we do the following:

1. Run without any limites
2. Check the used memory
3. Optimize the chunks to fit into our cluster

In [12]:
client.shutdown()
client.close()
cluster.close()

cluster = LocalCluster(
    n_workers=2,
    threads_per_worker=1,
)

client = Client(cluster)

memory_usage_history = {addr: [] for addr in client.scheduler_info()["workers"]}
monitoring_thread = start_monitoring(client, memory_usage_history)

start_time = time.time()
dip_map = gradient_structure_tensor_from_segy(data_segy_path, use_dask=True, dask_chunks=(INLINES, XLINES, SAMPLES))
dip_result = dip_map.compute()
end_time = time.time()
elapsed_time = end_time - start_time
print("Dip result shape:", dip_result.shape)
print("Dip min:", dip_result.min(), "Dip max:", dip_result.max())
print(f"Elapsed time: {elapsed_time:.2f} seconds")

stop_monitoring(monitoring_thread)

peak_memory_usages = {key: max(value) for key, value in memory_usage_history.items()}
for addr, mem_bytes in peak_memory_usages.items():
    print(f"Worker {addr} peak memory: {mem_bytes / 1024 ** 3:.2f} GB")

client.shutdown()
client.close()
cluster.close()

Loaded data shape: (200, 200, 200)
Loaded data chunk sizes: ((200,), (200,), (200,))
Dip result shape: (198, 198, 198)
Dip min: 90.0 Dip max: 90.0
Elapsed time: 5.57 seconds
Worker tcp://127.0.0.1:57416 peak memory: 0.10 GB
Worker tcp://127.0.0.1:57417 peak memory: 2.15 GB


In [13]:
cluster = LocalCluster(
    n_workers=2,
    threads_per_worker=1,
    memory_limit=f"{MEMORY_LIMIT_GB}GB",
)

client = Client(cluster)

memory_usage_history = {addr: [] for addr in client.scheduler_info()["workers"]}
monitoring_thread = start_monitoring(client, memory_usage_history)

start_time = time.time()
dip_map = gradient_structure_tensor_from_segy(data_segy_path, use_dask=True, dask_chunks=(100, 100, 100))
dip_result = dip_map.compute()
end_time = time.time()
manual_elapsed_time = end_time - start_time
print("Dip result shape:", dip_result.shape)
print("Dip min:", dip_result.min(), "Dip max:", dip_result.max())
print(f"Elapsed time: {elapsed_time:.2f} seconds")

stop_monitoring(monitoring_thread)

peak_memory_usages = {key: max(value) for key, value in memory_usage_history.items()}
for addr, mem_bytes in peak_memory_usages.items():
    print(f"Worker {addr} peak memory: {mem_bytes / 1024 ** 3:.2f} GB")

client.shutdown()
client.close()
cluster.close()

Loaded data shape: (200, 200, 200)
Loaded data chunk sizes: ((100, 100), (100, 100), (100, 100))
Dip result shape: (198, 198, 198)
Dip min: 90.0 Dip max: 90.0
Elapsed time: 5.57 seconds
Worker tcp://127.0.0.1:57443 peak memory: 0.65 GB
Worker tcp://127.0.0.1:57444 peak memory: 0.73 GB


### Strategy 3: Memory-Aware Chunking

In [14]:
cluster = LocalCluster(
    n_workers=2,
    threads_per_worker=1,
    memory_limit=f"{MEMORY_LIMIT_GB}GB",
)

client = Client(cluster)

memory_usage_history = {addr: [] for addr in client.scheduler_info()["workers"]}
monitoring_thread = start_monitoring(client, memory_usage_history)
ratio = MEMORY_LIMIT_GB / memory_usage_pred
mac_chunk_size = int(INLINES * ratio)

print(f"Using chunk size: {mac_chunk_size}")

start_time = time.time()
dip_map = gradient_structure_tensor_from_segy(data_segy_path, use_dask=True, dask_chunks=mac_chunk_size)
dip_result = dip_map.compute()
end_time = time.time()
mac_elapsed_time = end_time - start_time
print("Dip result shape:", dip_result.shape)
print("Dip min:", dip_result.min(), "Dip max:", dip_result.max())
print(f"Elapsed time: {elapsed_time:.2f} seconds")

stop_monitoring(monitoring_thread)

peak_memory_usages = {key: max(value) for key, value in memory_usage_history.items()}
for addr, mem_bytes in peak_memory_usages.items():
    print(f"Worker {addr} peak memory: {mem_bytes / 1024 ** 3:.2f} GB")

client.shutdown()
client.close()

cluster.close()

Using chunk size: 159
Loaded data shape: (200, 200, 200)
Loaded data chunk sizes: ((159, 41), (159, 41), (159, 41))


Exception in thread Thread-8 (monitor_memory):
Traceback (most recent call last):
  File [35m"/Users/delucca/.pyenv/versions/3.13.0/lib/python3.13/threading.py"[0m, line [35m1041[0m, in [35m_bootstrap_inner[0m
    [31mself.run[0m[1;31m()[0m
    [31m~~~~~~~~[0m[1;31m^^[0m
  File [35m"/Users/delucca/.pyenv/versions/mac__3.13__experiments__03/lib/python3.13/site-packages/ipykernel/ipkernel.py"[0m, line [35m766[0m, in [35mrun_closure[0m
    [31m_threading_Thread_run[0m[1;31m(self)[0m
    [31m~~~~~~~~~~~~~~~~~~~~~[0m[1;31m^^^^^^[0m
  File [35m"/Users/delucca/.pyenv/versions/3.13.0/lib/python3.13/threading.py"[0m, line [35m992[0m, in [35mrun[0m
    [31mself._target[0m[1;31m(*self._args, **self._kwargs)[0m
    [31m~~~~~~~~~~~~[0m[1;31m^^^^^^^^^^^^^^^^^^^^^^^^^^^^^[0m
  File [35m"/var/folders/sl/3td2vnj56c38q77xf6_s8qrr0000gn/T/ipykernel_48508/2185371761.py"[0m, line [35m11[0m, in [35mmonitor_memory[0m
    [31mmemory_usage_history[0m[1;31m[add

KilledWorker: Attempted to run task ('_eig_per_voxel-dfe6b8b3743dbbc02ea0345cdc3fc376', 0, 0, 0) on 4 different workers, but all those workers died while running it. The last worker that attempt to run the task was tcp://127.0.0.1:57498. Inspecting worker logs is often a good next step to diagnose what went wrong. For more information see https://distributed.dask.org/en/stable/killed.html.

Interestingly enough, it seems that the memory-aware chunking is not working as expected.
Let's further evaluate.
The actual memory usage was 2.16GB, while our prediction was 1.89GB.
Based on this, we have two alternatives:
- Either our model is not accurate enough, and we need to improve it.
- Or Dask adds memory overhead, and since we trained the model without Dask, the predictions are not accurate.

In [20]:
import os

original_exp_dir = os.path.dirname(os.path.dirname(MODEL_PATH))
profiles_file = f"{original_exp_dir}/results/operators/gst3d/results/profile_summary.csv"
df = pd.read_csv(profiles_file)

df.head()

Unnamed: 0,volume,peak_memory_usage_avg,peak_memory_usage_std_dev,peak_memory_usage_min,peak_memory_usage_max,execution_time_avg,execution_time_std_dev,execution_time_min,execution_time_max,n_samples,peak_memory_usage_cv,execution_time_cv,peak_memory_usage_unit,execution_time_unit
0,1000000,0.312251,0.000291,0.311916,0.312443,0.248579,0.001258,0.247539,0.249977,3,0.000931,0.00506,gb,s
1,2000000,0.536648,0.001451,0.534859,0.5383,0.499053,0.007539,0.489694,0.512676,9,0.002704,0.015107,gb,s
2,3000000,0.759475,0.001495,0.757774,0.761208,0.755814,0.014851,0.734732,0.773649,9,0.001969,0.019649,gb,s
3,4000000,0.984115,0.002365,0.980434,0.987648,1.018204,0.02795,0.983644,1.098931,18,0.002403,0.02745,gb,s
4,6000000,1.435727,0.001104,1.434372,1.437447,1.542636,0.051048,1.486074,1.70543,18,0.000769,0.033091,gb,s


In [23]:
df[df["volume"] == INLINES * XLINES * SAMPLES]

Unnamed: 0,volume,peak_memory_usage_avg,peak_memory_usage_std_dev,peak_memory_usage_min,peak_memory_usage_max,execution_time_avg,execution_time_std_dev,execution_time_min,execution_time_max,n_samples,peak_memory_usage_cv,execution_time_cv,peak_memory_usage_unit,execution_time_unit
5,8000000,1.886785,0.003027,1.884224,1.895241,2.065298,0.071216,1.99969,2.323209,21,0.001604,0.034482,gb,s


As we can see, the captured peak memory usage for the same volume was 1.88GB, so our model is accurate enough.
Therefore, we need to retrain our model with data captured using Dask.
Let's do it on this notebook.

In [24]:
import itertools
from common import builders

initial_size = 100
final_size = 400
step_size = 100

dataset_shapes = list(range(initial_size, final_size + 1, step_size))
dataset_combinations = list(itertools.product(dataset_shapes, repeat=3))

print(f"Generated {len(dataset_combinations)} dataset combinations")

inputs = []
inputs_ouput_dir = f"{output_dir}/train/inputs"

for inlines, xlines, samples in dataset_combinations:
    generated_data = builders.build_seismic_data(inlines=inlines, xlines=xlines, samples=samples,
                                                 output_dir=inputs_ouput_dir)
    inputs.append(generated_data)

Generated 64 dataset combinations
Generating synthetic data for shape (100, 100, 100)
Synthetic data generated successfully
Synthetic data saved to ../out/notebooks/02-memory-aware-chunking-20250330204013/train/inputs/100-100-100.segy
Generating synthetic data for shape (100, 100, 200)
Synthetic data generated successfully
Synthetic data saved to ../out/notebooks/02-memory-aware-chunking-20250330204013/train/inputs/100-100-200.segy
Generating synthetic data for shape (100, 100, 300)
Synthetic data generated successfully
Synthetic data saved to ../out/notebooks/02-memory-aware-chunking-20250330204013/train/inputs/100-100-300.segy
Generating synthetic data for shape (100, 100, 400)
Synthetic data generated successfully
Synthetic data saved to ../out/notebooks/02-memory-aware-chunking-20250330204013/train/inputs/100-100-400.segy
Generating synthetic data for shape (100, 200, 100)
Synthetic data generated successfully
Synthetic data saved to ../out/notebooks/02-memory-aware-chunking-202503

In [40]:
from common import runners

n_runs = 5

for shape in dataset_combinations:
    dataset_name = "-".join(map(str, shape))
    inlines, xlines, samples = shape
    print("---")
    print(f"Running experiment for dataset: {dataset_name}")

    runners.run_isolated_container(
        experiment_n_runs=n_runs,
        experiment_build_context="../",
        experiment_extra_contexts=["../../../libs/common"],
        experiment_volumes={
            f"{output_dir}/train": "/experiment/out"
        },
        experiment_env={
            "SESSION_ID": dataset_name,
            "INPUT_PATH": f"/experiment/out/inputs/{dataset_name}.segy",
            "DASK_CHUNKS_INLINES": f"{inlines}",
            "DASK_CHUNKS_XLINES": f"{xlines}",
            "DASK_CHUNKS_SAMPLES": f"{samples}",
        }
    )

---
Running experiment for dataset: 100-100-100
Using existing Docker volume: mac__dind-storage
Running isolated container...
Finished running isolated container. Exit status: 0
---
Running experiment for dataset: 100-100-200
Using existing Docker volume: mac__dind-storage
Running isolated container...
Finished running isolated container. Exit status: 0
---
Running experiment for dataset: 100-100-300
Using existing Docker volume: mac__dind-storage
Running isolated container...
Finished running isolated container. Exit status: 0
---
Running experiment for dataset: 100-100-400
Using existing Docker volume: mac__dind-storage
Running isolated container...
Finished running isolated container. Exit status: 0
---
Running experiment for dataset: 100-200-100
Using existing Docker volume: mac__dind-storage
Running isolated container...
Finished running isolated container. Exit status: 0
---
Running experiment for dataset: 100-200-200
Using existing Docker volume: mac__dind-storage
Running isolat

KeyboardInterrupt: 

In [None]:
import os
import random
import json

from common import transformers

profiles_directory = f"{output_dir}/train/profiles"
dask_profiles = [f for f in os.listdir(profiles_directory) if
                 os.path.isfile(os.path.join(profiles_directory, f)) and f.endswith(".json")]
sample_dask_profile_path = random.choice(dask_profiles)

sample_dask_profile = json.loads(f"{output_dir}/profiles/{sample_dask_profile_path}")
sample_peak_memory_usage = max(
    [transformers.transform_b_to_gb(float(d['dask_memory_usage'])) for d in sample_dask_profile['data']])

print(f"Sample Dask profile peak memory usage: {round(sample_peak_memory_usage, 2)} GB")

## Findings

## Next Steps