Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 9 additions & 7 deletions distributed/shuffle/tests/test_shuffle_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,12 +136,13 @@ async def test_create(s: Scheduler, *workers: Worker):
)

metadata = await exts[0]._create_shuffle(new_metadata)
assert sorted(metadata.workers) == sorted(w.address for w in workers)

# Check shuffle was created on all workers
for ext in exts:
assert len(ext.shuffles) == 1
shuffle = ext.shuffles[new_metadata.id]
assert sorted(shuffle.metadata.workers) == sorted(w.address for w in workers)
assert shuffle.metadata.workers == metadata.workers

# TODO (resilience stage) what happens if some workers already have
# the ID registered, but others don't?
Expand Down Expand Up @@ -213,7 +214,7 @@ async def test_barrier(c: Client, s: Scheduler, *workers: Worker):
await ext._barrier(metadata.id)

# Check scheduler restrictions were set for unpack tasks
for key, i in zip(fs, range(metadata.npartitions)):
for i, key in enumerate(fs):
assert s.tasks[key].worker_restrictions == {metadata.worker_for(i)}

# Check all workers have been informed of the barrier
Expand Down Expand Up @@ -260,8 +261,10 @@ async def test_get_partition(c: Client, s: Scheduler, *workers: Worker):
)
await ext._barrier(metadata.id)

with pytest.raises(AssertionError, match="belongs on"):
ext.get_output_partition(metadata.id, 7)
for addr, ext in exts.items():
if metadata.worker_for(0) != addr:
with pytest.raises(AssertionError, match="belongs on"):
ext.get_output_partition(metadata.id, 0)
Comment on lines +264 to +267
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the change that fixes the flaky test. The rest are just drivebys for readability.


full = pd.concat([p1, p2])
expected_groups = full.groupby("partition")
Expand All @@ -278,6 +281,5 @@ async def test_get_partition(c: Client, s: Scheduler, *workers: Worker):
# Once all partitions are retrieved, shuffles are cleaned up
for ext in exts.values():
assert not ext.shuffles

with pytest.raises(ValueError, match="not registered"):
ext.get_output_partition(metadata.id, 0)
with pytest.raises(ValueError, match="not registered"):
ext.get_output_partition(metadata.id, 0)