In [None]:
import pandas
import dask
import dask.array as da
import distributed
from distributed import span, wait
from distributed.metrics import meter

# Show long metric tables in full rather than truncated.
pandas.set_option('display.max_rows', 500)
# Turn off low-level task fusion (NOTE(review): presumably so metrics stay
# attributed to the original task prefixes — confirm).
dask.config.set({"optimization.fuse.active": False})

# Local cluster with 6 workers x 2 threads each.
client = distributed.Client(
    n_workers=6,
    threads_per_worker=2,
)

# Input: a 2**15 x 2**15 random array (8 GiB as float64) split into row blocks —
# "auto" rows per chunk, every column inside each chunk. Generated under its own
# span so its cost is not mixed into the "rechunk" span measured later.
with span("gen_data"):
    a = da.random.random((2**15, 2**15), chunks=("auto", -1)).persist()
    wait(a)

a

In [None]:
# Rechunk from row blocks (("auto", -1)) to column blocks ((-1, "auto")) with the
# P2P method. The client-side meter wraps the whole persist + wait so `m.delta`
# is the end-to-end time; the span groups the cluster-side metrics.
# Rebinding `a` drops the notebook's reference to the row-chunked array. Note that
# re-running this cell rechunks the already column-chunked array, so it is not
# idempotent — restart from the data-generation cell to re-measure.
with meter() as m:
    with span("rechunk"):
        a = a.rechunk((-1, "auto"), method="p2p").persist()
        wait(a)

In [None]:
# Cluster-wide cumulative worker metrics, read straight off the scheduler object;
# this only works because the LocalCluster above runs its scheduler in-process.
# NOTE(review): presumably a live reference rather than a copy, so later cells see
# any metrics that arrive after this cell runs — confirm.
local_scheduler = client.cluster.scheduler
metrics = local_scheduler.cumulative_worker_metrics

In [None]:
# Cluster-wide metrics restricted to the "execute" and "p2p" contexts.
# Non-tuple keys are skipped, so key[0] is safe to inspect; sorting the index
# groups related entries together.
pandas.Series(
    dict(
        (key, value)
        for key, value in metrics.items()
        if isinstance(key, tuple) and key[0] in ("execute", "p2p")
    )
).sort_index()

In [None]:
# Metrics recorded under the "rechunk" span only.
# The Span object is bound to `rechunk_span`, NOT `span`: `span` is the
# `distributed.span` context manager imported in the first cell, and shadowing it
# would make re-running the rechunk cell (`span("rechunk")`) fail.
spans_ext = client.cluster.scheduler.extensions["spans"]
# Single-element unpacking raises ValueError unless exactly one span matches.
(rechunk_span,) = spans_ext.find_by_tags("rechunk")
s = pandas.Series(rechunk_span.cumulative_worker_metrics).sort_index()
s

In [None]:
# Average shard size: total bytes / number of shards for the rechunk-transfer
# tasks, taken from the cluster-wide counters.
(
    metrics[("execute", "rechunk-transfer", "p2p-shards", "bytes")]
    / metrics[("execute", "rechunk-transfer", "p2p-shards", "count")]
)

In [None]:
# Time measured on the client by meter() around the rechunk + persist + wait
# (NOTE(review): presumably in seconds — confirm).
print("Client end-to-end time", m.delta, sep=" ", end="\n")

In [None]:
# Keep only the "seconds" entries of the rechunk span's metrics, then drop the
# activity (level_2) and unit (level_3) columns and sum the time per
# (level_0, level_1) pair.
s2 = (
    s.reset_index()
    .loc[lambda frame: frame["level_3"] == "seconds"]
    .drop(columns=["level_2", "level_3"])
)
s2.groupby(["level_0", "level_1"]).sum()