
Fix bag sampling when there are some smaller partitions (#9349)

Merged: pavithraes merged 2 commits into dask:main from ian-r-rose:bag-sample-with-smaller-partitions on Aug 9, 2022

Conversation

Collaborator

@ian-r-rose ian-r-rose commented Aug 3, 2022

Fixes the rest of #9249. Currently, if you have some imbalanced partitions in a bag such that some of them are small, and then sample from that bag with a k larger than those partitions, you can run into the error in #9249. The error message isn't correct in that case.

I still need to write some tests and also sit down with a pencil and paper to convince myself and others that this doesn't bias the sampling, so marking as a draft until I can get around to that.

@ian-r-rose ian-r-rose added bag bug Something is broken labels Aug 3, 2022
@ian-r-rose ian-r-rose requested a review from pavithraes August 3, 2022 19:06
@ian-r-rose
Collaborator Author

Let's look at the following:

```python
import dask.bag
from dask.bag import random

bag = dask.bag.from_sequence(range(10), partition_size=3)
bag2 = random.sample(bag, k=8, split_every=2)
bag2.compute()
```

This should sample 8 elements from a bag with four partitions, and ten total elements. Should be totally doable.

The bag looks something like

```
[0, 1, 2] [3, 4, 5] [6, 7, 8] [9]
```

On main, we try to sample k from each partition, then combine the results, sample k again, combine again, and so on until we have a single partition. But each partition here is too small to yield all 8 samples, so we just wind up sampling the whole partition. In the intermediate combine steps, the algorithm checks whether k is larger than the intermediate population size, and it is! So it throws the error, when really we wanted to know whether k is larger than the total population size.

```mermaid
flowchart TD
    A["[0, 1, 2]"];
    B["[3, 4, 5]"];
    C["[6, 7, 8]"];
    D["[9]"];
    AB["[0, 1, 2, 3, 4, 5]"];
    A --> AB;
    B --> AB;
    AB --> ABcheck["Is k > len(pop)?"] --> ABfail["Error"];
    CD["[6, 7, 8, 9]"];
    C --> CD;
    D --> CD;
    CD --> CDcheck["Is k > len(pop)?"] --> CDfail["Error"];
```
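To make the failure mode concrete, here is a minimal sketch of the eager-check reduction (hypothetical helper names, not dask's actual implementation):

```python
import random

def sample_partition(partition, k):
    # Each partition can only contribute as many elements as it has.
    return random.sample(partition, min(k, len(partition)))

def combine_with_eager_check(sampled_a, sampled_b, k):
    pop = sampled_a + sampled_b
    if k > len(pop):  # BUG: len(pop) is only an intermediate size
        raise ValueError("Sample larger than population")
    return random.sample(pop, k)

partitions = [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
k = 8
sampled = [sample_partition(p, k) for p in partitions]
try:
    # [0, 1, 2] + [3, 4, 5] is only 6 elements, so k=8 raises here
    # even though the whole bag holds 10 elements.
    combine_with_eager_check(sampled[0], sampled[1], k)
except ValueError as e:
    print(e)
```

The intermediate population of 6 is smaller than k=8, so the error fires long before the reduction sees all 10 elements.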

With this PR, we defer the k check until the end. If the current partition has fewer items than k, we just forward it along, with the understanding that we will eventually be able to do some actual downsampling. This means that we potentially have to do more work before we know whether the initial request was valid! But I don't really see a way around that: bags can have indeterminate length, so the only way to really know whether k was bigger than the length is to actually perform the reduction. (Note: this is not strictly true. We could do a len at the start, at the cost of some eager computation, or we could add some intermediate "check" steps in the graph, at the cost of a more complex, coupled task graph. I don't really like either of those.)

```mermaid
flowchart TD
    A["[0, 1, 2]"];
    B["[3, 4, 5]"];
    C["[6, 7, 8]"];
    D["[9]"];
    AB["[0, 1, 2, 3, 4, 5]"];
    A --> AB;
    B --> AB;
    CD["[6, 7, 8, 9]"];
    C --> CD;
    D --> CD;
    ABCD["[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]"] -- sample --> ABCDsample["[0, 1, 3, 4, 5, 6, 7, 9]"];
    AB --> ABCD;
    CD --> ABCD;
    ABCDsample --> ABCDcheck["Is k > len(pop)?"] --> ABCDsuccess["Success"];
```
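The deferred check can be sketched like this (again with hypothetical helper names): each combine step tracks the true population size alongside the sampled elements, forwards undersized intermediates whole, and only validates k once the reduction is finished:

```python
import random

def combine_deferred(sampled_a, n_a, sampled_b, n_b, k):
    # Carry the true population size n through the reduction.
    pop = sampled_a + sampled_b
    n = n_a + n_b
    if len(pop) <= k:
        return pop, n            # too few elements: forward everything
    return random.sample(pop, k), n

def finalize(sampled, n, k):
    if k > n:                    # the check happens only at the end
        raise ValueError("Sample larger than population")
    return sampled

parts = [([0, 1, 2], 3), ([3, 4, 5], 3), ([6, 7, 8], 3), ([9], 1)]
k = 8
ab = combine_deferred(*parts[0], *parts[1], k)   # forwards all 6 elements
cd = combine_deferred(*parts[2], *parts[3], k)   # forwards all 4 elements
final, n = combine_deferred(*ab, *cd, k)         # 10 elements: samples 8
result = finalize(final, n, k)
print(len(result))
```

With k larger than the total (say k=11), the same reduction would run to completion and only then raise in `finalize`.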

Now, we want to make sure that the result is not biased towards short partitions. That is to say, if I sample one element from the following two-partition bag:

[1, 2, 3, 4] [5]

The second partition will always grab 5, and the first will grab any of the four. So the combine step will choose from 5 and one of 1-4. Without weighting, we'd have a 50% chance of getting 5. Thankfully, the combine step already weights the different sub-lists according to the lengths of the populations from which they were drawn (the original author of the algorithm wrote a nice post about it). So I believe we are okay! I did some small sampling tests to see if things seemed out of whack:

```python
import dask.bag as db
from dask.bag import random

numbers = range(10)
buckets = {i: 0 for i in numbers}
a = db.from_sequence(numbers, partition_size=3)
for i in range(10000):
    s = random.sample(a, 1).compute(scheduler="sync")
    for e in s:
        buckets[e] += 1
print(buckets)  # {0: 1002, 1: 981, 2: 1025, 3: 972, 4: 1027, 5: 968, 6: 1010, 7: 1010, 8: 966, 9: 1039}
```

So the sampling is still evenly distributed.
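The weighting argument above can also be checked in isolation. This is a hypothetical sketch (not dask's actual combine code) of picking one element from two sub-samples, weighted by the sizes of the populations they were drawn from:

```python
import random

def weighted_pick_one(sample_a, n_a, sample_b, n_b):
    # Choose which sub-sample supplies the element in proportion to
    # the population sizes the sub-samples were drawn from.
    if random.random() < n_a / (n_a + n_b):
        return random.choice(sample_a)
    return random.choice(sample_b)

counts = {i: 0 for i in range(1, 6)}
for _ in range(50_000):
    a = [random.choice([1, 2, 3, 4])]  # one element sampled from [1, 2, 3, 4]
    b = [5]                            # one element sampled from [5]
    counts[weighted_pick_one(a, 4, b, 1)] += 1
# Each of 1..5 should appear roughly 10,000 times (uniform over the bag),
# whereas an unweighted 50/50 pick would give 5 about 25,000 hits.
```

Weighting the choice 4:1 toward the larger partition cancels out the fact that 5 is guaranteed to survive its own partition's sample.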

All of which is to say, I think this PR is ready.

@ian-r-rose ian-r-rose marked this pull request as ready for review August 4, 2022 00:20
Member

@pavithraes pavithraes left a comment


@ian-r-rose Thank you for working on this, and for including the above (super helpful!) explanation. I think this looks great!

@pavithraes pavithraes merged commit f7fd6c4 into dask:main Aug 9, 2022


Successfully merging this pull request may close these issues: dask.bag.random.sample throws various errors
