
Commit

Merge branch 'main' into forward_collections_metadata
fjetter committed May 6, 2024
2 parents 2f9e9db + 1ec61a1 commit 345b6b4
Showing 9 changed files with 143 additions and 151 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci-pre-commit.yml
@@ -11,7 +11,7 @@ jobs:
name: pre-commit hooks
runs-on: ubuntu-latest
steps:
-      - uses: actions/checkout@v4.1.2
+      - uses: actions/checkout@v4.1.3
- uses: actions/setup-python@v5
with:
python-version: '3.9'
2 changes: 1 addition & 1 deletion .github/workflows/conda.yml
@@ -26,7 +26,7 @@ jobs:
name: Build (and upload)
runs-on: ubuntu-latest
steps:
-      - uses: actions/checkout@v4.1.2
+      - uses: actions/checkout@v4.1.3
with:
fetch-depth: 0
- name: Set up Python
2 changes: 1 addition & 1 deletion .github/workflows/test-report.yaml
@@ -18,7 +18,7 @@ jobs:
run:
shell: bash -l {0}
steps:
-      - uses: actions/checkout@v4.1.2
+      - uses: actions/checkout@v4.1.3

- name: Setup Conda Environment
uses: conda-incubator/setup-miniconda@v3.0.3
2 changes: 1 addition & 1 deletion .github/workflows/tests.yaml
@@ -120,7 +120,7 @@ jobs:
shell: bash

- name: Checkout source
-        uses: actions/checkout@v4.1.2
+        uses: actions/checkout@v4.1.3
with:
fetch-depth: 0

2 changes: 1 addition & 1 deletion .github/workflows/update-gpuci.yaml
@@ -11,7 +11,7 @@ jobs:
if: github.repository == 'dask/distributed'

steps:
-      - uses: actions/checkout@v4.1.2
+      - uses: actions/checkout@v4.1.3

- name: Parse current axis YAML
id: rapids_current
100 changes: 53 additions & 47 deletions distributed/diagnostics/memray.py
@@ -144,36 +144,40 @@ def memray_workers(
# Sleep for a brief moment such that we get
# a clear profiling signal when everything starts
time.sleep(0.1)
yield
directory.mkdir(exist_ok=True)

client = get_client()
if fetch_reports_parallel is True:
fetch_parallel = len(workers)
elif fetch_reports_parallel is False:
fetch_parallel = 1
else:
fetch_parallel = fetch_reports_parallel

for w in partition(fetch_parallel, workers):
try:
profiles = client.run(
_fetch_memray_profile,
filename=filename,
report_args=report_args,
workers=w,
)
for worker_addr, profile in profiles.items():
path = directory / quote(str(worker_names[worker_addr]), safe="")
if report_args:
suffix = ".html"
else:
suffix = ".memray"
with open(str(path) + suffix, "wb") as fd:
fd.write(profile)

except Exception:
logger.exception("Exception during report downloading from worker %s", w)
try:
yield
finally:
directory.mkdir(exist_ok=True)

client = get_client()
if fetch_reports_parallel is True:
fetch_parallel = len(workers)
elif fetch_reports_parallel is False:
fetch_parallel = 1
else:
fetch_parallel = fetch_reports_parallel

for w in partition(fetch_parallel, workers):
try:
profiles = client.run(
_fetch_memray_profile,
filename=filename,
report_args=report_args,
workers=w,
)
for worker_addr, profile in profiles.items():
path = directory / quote(str(worker_names[worker_addr]), safe="")
if report_args:
suffix = ".html"
else:
suffix = ".memray"
with open(str(path) + suffix, "wb") as fd:
fd.write(profile)

except Exception:
logger.exception(
"Exception during report downloading from worker %s", w
)


@contextlib.contextmanager
@@ -226,20 +230,22 @@ def memray_scheduler(
# Sleep for a brief moment such that we get
# a clear profiling signal when everything starts
time.sleep(0.1)
yield
directory.mkdir(exist_ok=True)

client = get_client()

profile = client.run_on_scheduler(
_fetch_memray_profile,
filename=filename,
report_args=report_args,
)
path = directory / "scheduler"
if report_args:
suffix = ".html"
else:
suffix = ".memray"
with open(str(path) + suffix, "wb") as fd:
fd.write(profile)
try:
yield
finally:
directory.mkdir(exist_ok=True)

client = get_client()

profile = client.run_on_scheduler(
_fetch_memray_profile,
filename=filename,
report_args=report_args,
)
path = directory / "scheduler"
if report_args:
suffix = ".html"
else:
suffix = ".memray"
with open(str(path) + suffix, "wb") as fd:
fd.write(profile)
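
The memray change above wraps the post-`yield` report collection in `try`/`finally`, so worker and scheduler profiles are still downloaded and written to `directory` even when the code profiled inside the `with` block raises. A minimal sketch of that context-manager pattern follows; the setup/teardown functions are placeholders for illustration, not the actual dask code.

import contextlib

def start_profiling() -> None:
    # Placeholder for starting memray tracking (illustration only).
    print("profiling started")

def fetch_and_write_reports() -> None:
    # Placeholder for downloading and writing the profiling reports.
    print("reports written")

@contextlib.contextmanager
def collect_reports():
    start_profiling()
    try:
        yield
    finally:
        # Runs on normal exit *and* when the with-body raises,
        # mirroring the try/finally added around `yield` in the diff.
        fetch_and_write_reports()

# Reports are still written even though the profiled block fails.
try:
    with collect_reports():
        raise RuntimeError("workload failed")
except RuntimeError:
    pass
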
90 changes: 84 additions & 6 deletions distributed/shuffle/tests/test_shuffle.py
@@ -1232,12 +1232,90 @@ async def test_head(c, s, a, b):


def test_split_by_worker():
workers = ["a", "b", "c"]
npartitions = 5
df = pd.DataFrame({"x": range(100), "y": range(100)})
df["_partitions"] = df.x % npartitions
worker_for = {i: random.choice(workers) for i in range(npartitions)}
s = pd.Series(worker_for, name="_worker").astype("category")
pytest.importorskip("pyarrow")

df = pd.DataFrame(
{
"x": [1, 2, 3, 4, 5],
"_partition": [0, 1, 2, 0, 1],
}
)
meta = df[["x"]].head(0)
workers = ["alice", "bob"]
worker_for_mapping = {}
npartitions = 3
for part in range(npartitions):
worker_for_mapping[part] = _get_worker_for_range_sharding(
npartitions, part, workers
)
worker_for = pd.Series(worker_for_mapping, name="_workers").astype("category")
out = split_by_worker(df, "_partition", meta, worker_for)
assert set(out) == {"alice", "bob"}
assert list(out["alice"].to_pandas().columns) == list(df.columns)

assert sum(map(len, out.values())) == len(df)


def test_split_by_worker_empty():
pytest.importorskip("pyarrow")

df = pd.DataFrame(
{
"x": [1, 2, 3, 4, 5],
"_partition": [0, 1, 2, 0, 1],
}
)
meta = df[["x"]].head(0)
worker_for = pd.Series({5: "chuck"}, name="_workers").astype("category")
out = split_by_worker(df, "_partition", meta, worker_for)
assert out == {}


def test_split_by_worker_many_workers():
pytest.importorskip("pyarrow")

df = pd.DataFrame(
{
"x": [1, 2, 3, 4, 5],
"_partition": [5, 7, 5, 0, 1],
}
)
meta = df[["x"]].head(0)
workers = ["a", "b", "c", "d", "e", "f", "g", "h"]
npartitions = 10
worker_for_mapping = {}
for part in range(npartitions):
worker_for_mapping[part] = _get_worker_for_range_sharding(
npartitions, part, workers
)
worker_for = pd.Series(worker_for_mapping, name="_workers").astype("category")
out = split_by_worker(df, "_partition", meta, worker_for)
assert _get_worker_for_range_sharding(npartitions, 5, workers) in out
assert _get_worker_for_range_sharding(npartitions, 0, workers) in out
assert _get_worker_for_range_sharding(npartitions, 7, workers) in out
assert _get_worker_for_range_sharding(npartitions, 1, workers) in out

assert sum(map(len, out.values())) == len(df)


@pytest.mark.parametrize("drop_column", [True, False])
def test_split_by_partition(drop_column):
pa = pytest.importorskip("pyarrow")

df = pd.DataFrame(
{
"x": [1, 2, 3, 4, 5],
"_partition": [3, 1, 2, 3, 1],
}
)
t = pa.Table.from_pandas(df)

out = split_by_partition(t, "_partition", drop_column)
assert set(out) == {1, 2, 3}
if drop_column:
df = df.drop(columns="_partition")
assert out[1].column_names == list(df.columns)
assert sum(map(len, out.values())) == len(df)


@gen_cluster(client=True, nthreads=[("", 1)] * 2)
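
As the tests above illustrate, `split_by_worker` groups a DataFrame's rows by the worker assigned to each `_partition` value and returns per-worker shards (Arrow tables in these tests), dropping rows whose partition has no assigned worker. The partition-to-worker mapping comes from `_get_worker_for_range_sharding`; a minimal sketch of range sharding (an illustrative assumption, not the actual dask implementation) is:

from typing import Sequence

def worker_for_range_sharding(npartitions: int, part: int, workers: Sequence[str]) -> str:
    # Split the output-partition range [0, npartitions) into len(workers)
    # contiguous blocks and return the worker that owns the block of `part`.
    return workers[len(workers) * part // npartitions]

workers = ["a", "b", "c", "d", "e", "f", "g", "h"]
print(worker_for_range_sharding(10, 0, workers))  # -> "a"
print(worker_for_range_sharding(10, 5, workers))  # -> "e"
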
92 changes: 0 additions & 92 deletions distributed/shuffle/tests/test_shuffle_plugins.py
@@ -5,11 +5,6 @@
import pytest

from distributed.shuffle._scheduler_plugin import ShuffleSchedulerPlugin
from distributed.shuffle._shuffle import (
_get_worker_for_range_sharding,
split_by_partition,
split_by_worker,
)
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils_test import gen_cluster

@@ -35,90 +30,3 @@ async def test_installation_on_scheduler(s, a):
assert isinstance(ext, ShuffleSchedulerPlugin)
assert s.handlers["shuffle_barrier"] == ext.barrier
assert s.handlers["shuffle_get"] == ext.get


def test_split_by_worker():
pytest.importorskip("pyarrow")

df = pd.DataFrame(
{
"x": [1, 2, 3, 4, 5],
"_partition": [0, 1, 2, 0, 1],
}
)
meta = df[["x"]].head(0)
workers = ["alice", "bob"]
worker_for_mapping = {}
npartitions = 3
for part in range(npartitions):
worker_for_mapping[part] = _get_worker_for_range_sharding(
npartitions, part, workers
)
worker_for = pd.Series(worker_for_mapping, name="_workers").astype("category")
out = split_by_worker(df, "_partition", meta, worker_for)
assert set(out) == {"alice", "bob"}
assert list(out["alice"].to_pandas().columns) == list(df.columns)

assert sum(map(len, out.values())) == len(df)


def test_split_by_worker_empty():
pytest.importorskip("pyarrow")

df = pd.DataFrame(
{
"x": [1, 2, 3, 4, 5],
"_partition": [0, 1, 2, 0, 1],
}
)
meta = df[["x"]].head(0)
worker_for = pd.Series({5: "chuck"}, name="_workers").astype("category")
out = split_by_worker(df, "_partition", meta, worker_for)
assert out == {}


def test_split_by_worker_many_workers():
pytest.importorskip("pyarrow")

df = pd.DataFrame(
{
"x": [1, 2, 3, 4, 5],
"_partition": [5, 7, 5, 0, 1],
}
)
meta = df[["x"]].head(0)
workers = ["a", "b", "c", "d", "e", "f", "g", "h"]
npartitions = 10
worker_for_mapping = {}
for part in range(npartitions):
worker_for_mapping[part] = _get_worker_for_range_sharding(
npartitions, part, workers
)
worker_for = pd.Series(worker_for_mapping, name="_workers").astype("category")
out = split_by_worker(df, "_partition", meta, worker_for)
assert _get_worker_for_range_sharding(npartitions, 5, workers) in out
assert _get_worker_for_range_sharding(npartitions, 0, workers) in out
assert _get_worker_for_range_sharding(npartitions, 7, workers) in out
assert _get_worker_for_range_sharding(npartitions, 1, workers) in out

assert sum(map(len, out.values())) == len(df)


@pytest.mark.parametrize("drop_column", [True, False])
def test_split_by_partition(drop_column):
pa = pytest.importorskip("pyarrow")

df = pd.DataFrame(
{
"x": [1, 2, 3, 4, 5],
"_partition": [3, 1, 2, 3, 1],
}
)
t = pa.Table.from_pandas(df)

out = split_by_partition(t, "_partition", drop_column)
assert set(out) == {1, 2, 3}
if drop_column:
df = df.drop(columns="_partition")
assert out[1].column_names == list(df.columns)
assert sum(map(len, out.values())) == len(df)
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -28,7 +28,7 @@ requires-python = ">=3.9"
dependencies = [
"click >= 8.0",
"cloudpickle >= 1.5.0",
"dask == 2024.4.1",
"dask == 2024.5.0",
"jinja2 >= 2.10.3",
"locket >= 1.0.0",
"msgpack >= 1.0.0",
