# Shuffling large data at constant memory in Dask

**A showcase for P2P shuffling/rechunking at Dask Demo Day 2023-03-16**

To learn more, check out our blog post [Shuffling large data at constant memory in Dask](https://blog.coiled.io/blog/shuffling-large-data-at-constant-memory/)!

In [None]:
import coiled
import dask
from distributed import Client

## Define utilities

Adapted from:
* [https://github.com/coiled/coiled-runtime/blob/c8540241e1c2b19d9348e57e12ac62c689463100/tests/utils_test.py](https://github.com/coiled/coiled-runtime/blob/c8540241e1c2b19d9348e57e12ac62c689463100/tests/utils_test.py)
* [https://github.com/coiled/coiled-runtime/blob/c8540241e1c2b19d9348e57e12ac62c689463100/tests/benchmarks/test_dataframe.py](https://github.com/coiled/coiled-runtime/blob/c8540241e1c2b19d9348e57e12ac62c689463100/tests/benchmarks/test_dataframe.py)

In [None]:
import dask.dataframe as dd
import distributed
import pandas as pd
from dask.datasets import timeseries
from dask.sizeof import sizeof
from dask.utils import format_bytes, parse_bytes

def cluster_memory(client: distributed.Client) -> int:
    """Total memory available on the cluster, in bytes.

    Sums the ``memory_limit`` reported for every worker connected to the
    scheduler behind ``client``.
    """
    # Recent ``distributed`` releases return only a few workers from
    # ``scheduler_info()`` unless ``n_workers`` is given; ``-1`` requests all of
    # them. NOTE(review): older releases take ``**kwargs`` here and ignore it.
    workers = client.scheduler_info(n_workers=-1)["workers"]
    return int(sum(w["memory_limit"] for w in workers.values()))


def timeseries_of_size(
    target_nbytes: int | str,
    *,
    start="2000-01-01",
    freq="1s",
    partition_freq="1d",
    dtypes=None,
    seed=None,
    **kwargs,
) -> dd.DataFrame:
    """
    Generate a `dask.datasets.timeseries` of a target total size.

    Same arguments as `dask.datasets.timeseries`, but instead of specifying an
    ``end`` date, you specify ``target_nbytes``. The number of partitions is set as
    necessary to reach approximately that total dataset size. Note that you control
    the partition size via ``freq``, ``partition_freq``, and ``dtypes``.

    Parameters
    ----------
    target_nbytes
        Desired total in-memory size, either as a number of bytes or as a string
        accepted by ``dask.utils.parse_bytes`` (e.g. ``"1mb"``).
    start, freq, partition_freq, seed, **kwargs
        Forwarded to ``dask.datasets.timeseries``.
    dtypes
        Mapping of column name to dtype. Defaults to
        ``{"name": str, "id": int, "x": float, "y": float}``.

    Returns
    -------
    dask.dataframe.DataFrame
        A timeseries with ``round(target_nbytes / partition_size)`` partitions,
        where ``partition_size`` is the measured in-memory size of one example
        partition.

    Raises
    ------
    AssertionError
        If one partition is so large that the target rounds to zero partitions,
        or if the generated DataFrame does not have the expected partition count.

    Examples
    --------
    >>> timeseries_of_size(
    ...     "1mb", freq="1s", partition_freq="100s", dtypes={"x": float}
    ... ).npartitions
    278
    >>> timeseries_of_size(
    ...     "1mb", freq="1s", partition_freq="100s", dtypes={i: float for i in range(10)}
    ... ).npartitions
    93

    Notes
    -----
    The ``target_nbytes`` refers to the amount of RAM the dask DataFrame would use up
    across all workers, as many pandas partitions.

    This is typically larger than ``df.compute()`` would be as a single pandas
    DataFrame. Especially with many partitions, there can be significant overhead to
    storing all the individual pandas objects.

    Additionally, ``target_nbytes`` certainly does not correspond to the size
    the dataset would take up on disk (as parquet, csv, etc.).
    """
    if dtypes is None:
        # Build the default per call instead of sharing one mutable dict across calls.
        dtypes = {"name": str, "id": int, "x": float, "y": float}
    if isinstance(target_nbytes, str):
        target_nbytes = parse_bytes(target_nbytes)

    start_dt = pd.to_datetime(start)
    partition_freq_dt = pd.to_timedelta(partition_freq)
    # Generate exactly one partition's worth of data and measure it locally.
    example_part = timeseries(
        start=start,
        end=start_dt + partition_freq_dt,
        freq=freq,
        partition_freq=partition_freq,
        dtypes=dtypes,
        seed=seed,
        **kwargs,
    )
    p = example_part.compute(scheduler="threads")
    partition_size = sizeof(p)
    npartitions = round(target_nbytes / partition_size)
    assert npartitions > 0, (
        f"Partition size of {format_bytes(partition_size)} > "
        f"target size {format_bytes(target_nbytes)}"
    )

    # Extend ``end`` by whole partition periods so the result has ``npartitions``.
    ts = timeseries(
        start=start,
        end=start_dt + partition_freq_dt * npartitions,
        freq=freq,
        partition_freq=partition_freq,
        dtypes=dtypes,
        seed=seed,
        **kwargs,
    )
    assert ts.npartitions == npartitions
    return ts

def print_dataframe_info(df):
    """Print an estimate of a dask DataFrame's row count, width and memory size.

    Only the first partition is computed (with the threaded scheduler). Its row
    count and in-memory size are multiplied by ``df.npartitions`` to extrapolate
    the totals, so the figures assume roughly equally sized partitions.
    """
    nparts = df.npartitions
    first = df.partitions[0].compute(scheduler="threads")
    part_bytes = sizeof(first)
    est_rows = len(first) * nparts
    summary = (
        f"~{est_rows:,} rows x {len(df.columns)} columns, "
        f"{format_bytes(part_bytes * nparts)} total, "
        f"{nparts:,} {format_bytes(part_bytes)} partitions"
    )
    print(summary)

## What is shuffling?

**TL;DR:** Shuffling is needed whenever a dataset must be moved around in an all-to-all fashion, as in sorting, dataframe joins, or array rechunking.

![Diagram of an all-to-all shuffle between partitions](https://assets-global.website-files.com/63192998e5cab906c1b55f6e/633f7b5df9c63728c2ce7ac6_image-3-700x340.png)

## Problem: Task-based shuffling scales poorly

In [None]:
from coiled import Cluster

# Cluster for the task-based shuffle. It is not shut down when the client closes
# (``shutdown_on_close=False``); the scheduler's ``idle_timeout`` closes it after
# one hour without activity.
tasks_cluster = Cluster(
    name="dask-p2p-demo-tasks",
    n_workers=10,
    shutdown_on_close=False,
    wait_for_workers=True,
    # Pass VM types as lists, matching ``scheduler_vm_types`` and Coiled's API.
    worker_vm_types=["m6i.large"],
    scheduler_vm_types=["m6i.large"],
    scheduler_options={"idle_timeout": "1 hours"},
)
tasks_client = Client(tasks_cluster)

In [None]:
# Total worker memory of the task-based cluster, in bytes. It is used below as the
# target size of the demo DataFrame; the last expression displays it readably.
memory = cluster_memory(client=tasks_client)
format_bytes(memory)

In [None]:
%%capture --no-display

df = timeseries_of_size(
    memory,
    start="2020-01-01",
    freq="600ms",
    partition_freq="24h",
    dtypes={str(i): float for i in range(100)},
)

In [None]:
# Size estimate extrapolated from the first partition.
print_dataframe_info(df=df)

In [None]:
# Build the shuffle graph with the task-based method. Only graph construction
# happens inside this config context; the computation is submitted in a later cell.
with dask.config.set(dataframe__shuffle__method="tasks"):
    shuffled = df.shuffle("0")

In [None]:
# Submit the task-based shuffle. ``Client.compute`` returns a future, so the cell
# finishes immediately; follow progress on the printed dashboard link.
print(tasks_client.dashboard_link)
final = shuffled.size
f1 = tasks_client.compute(final)

## Solution: P2P shuffling

In [None]:
from coiled import Cluster

# Cluster for the P2P shuffle, configured identically to the task-based one so the
# two runs are comparable. It is not shut down when the client closes; the
# scheduler's ``idle_timeout`` closes it after one hour without activity.
p2p_cluster = Cluster(
    name="dask-p2p-demo-p2p",
    n_workers=10,
    shutdown_on_close=False,
    wait_for_workers=True,
    # Pass VM types as lists, matching ``scheduler_vm_types`` and Coiled's API.
    worker_vm_types=["m6i.large"],
    scheduler_vm_types=["m6i.large"],
    scheduler_options={"idle_timeout": "1 hours"},
)
p2p_client = Client(p2p_cluster)

In [None]:
# Restart the task-based cluster so the side-by-side comparison below starts on
# fresh workers, without memory still held by the earlier task-based shuffle.
tasks_client.restart()

In [None]:
# Run the same shuffle on both clusters side by side. The shuffle method is set
# explicitly for each run; dask's default method depends on the runtime
# environment, so relying on it could make both runs use the same algorithm.
print(f"Task-based dashboard: {tasks_client.dashboard_link}")
with dask.config.set({"dataframe.shuffle.method": "tasks"}):
    shuffled = df.shuffle("0")
f1 = tasks_client.compute(shuffled.size)

print(f"P2P dashboard: {p2p_client.dashboard_link}")
with dask.config.set({"dataframe.shuffle.method": "p2p"}):
    shuffled = df.shuffle("0")
f2 = p2p_client.compute(shuffled.size)

## Preview: P2P rechunking for arrays

In [None]:
import dask.array as da

In [None]:
# Restart both clusters so the array rechunking comparison starts on fresh
# workers, without memory still held from the dataframe shuffles above.
tasks_client.restart()
p2p_client.restart()

In [None]:
# Random float64 array of shape (2500, 1800, 3600), about 130 GB in total, with
# one chunk per index along the first axis (each chunk is 1 x 1800 x 3600).
nt, ny, nx = 2500, 1800, 3600
shape = (nt, ny, nx)
chunks = (1, ny, nx)
arr = da.random.random(shape, chunks=chunks)
arr

In [None]:
# Rechunk the same array on both clusters side by side: from one chunk per index
# along the first axis to chunks spanning that whole axis. The rechunk method is
# set explicitly for each run so the comparison does not depend on dask's default.
print(f"Task-based dashboard: {tasks_client.dashboard_link}")
with dask.config.set({"array.rechunk.method": "tasks"}):
    rechunked = arr.rechunk((-1, 90, 36))
    f1 = tasks_client.compute(rechunked.sum())

print(f"P2P dashboard: {p2p_client.dashboard_link}")
# Low-level task fusion is disabled for the P2P run, as in the original setup.
# NOTE(review): presumably fusion interferes with P2P rechunking; confirm.
with dask.config.set({"optimization.fuse.active": False, "array.rechunk.method": "p2p"}):
    rechunked = arr.rechunk((-1, 90, 36))
    f2 = p2p_client.compute(rechunked.sum())