Fix bag sampling when there are some smaller partitions (#9349)

pavithraes merged 2 commits into dask:main
Let's look at the following:

```python
import dask.bag
from dask.bag import random

bag = dask.bag.from_sequence(range(10), partition_size=3)
bag2 = random.sample(bag, k=8, split_every=2)
bag2.compute()
```

This should sample 8 elements from a bag with four partitions and ten total elements. Should be totally doable. The bag looks something like

```
[0, 1, 2] [3, 4, 5] [6, 7, 8] [9]
```

In the current implementation, the reduction with `split_every=2` looks like:

```mermaid
flowchart TD
    A["[0, 1, 2]"];
    B["[3, 4, 5]"];
    C["[6, 7, 8]"];
    D["[9]"];
    AB["[0, 1, 2, 3, 4, 5]"];
    A --> AB;
    B --> AB;
    AB --> ABcheck["Is k > len(pop)?"] --> ABfail["Error"];
    CD["[6, 7, 8, 9]"];
    C --> CD;
    D --> CD;
    CD --> CDcheck["Is k > len(pop)?"] --> CDfail["Error"];
```
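The per-partition check in the diagram is the same kind of guard the standard library applies: Python's built-in `random.sample` raises a `ValueError` whenever `k` exceeds the population size. A toy illustration of why the short partition trips the eager check (this is plain stdlib code, not dask's actual code path):

```python
import random

# The short partition from the example above; k=8 is larger than it.
partition = [9]

try:
    random.sample(partition, k=8)  # k > len(population)
except ValueError as e:
    print(e)  # "Sample larger than population or is negative"
```

Checked eagerly at every reduction step, this errors even though the bag as a whole has plenty of elements.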
With this PR, we defer the final check on the combined set of samples. This allows you to have some small partitions:

```mermaid
flowchart TD
    A["[0, 1, 2]"];
    B["[3, 4, 5]"];
    C["[6, 7, 8]"];
    D["[9]"];
    AB["[0, 1, 2, 3, 4, 5]"];
    A --> AB;
    B --> AB;
    CD["[6, 7, 8, 9]"];
    C --> CD;
    D --> CD;
    ABCD["[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]"] -- sample --> ABCDsample["[0, 1, 3, 4, 5, 6, 7, 9]"];
    AB --> ABCD;
    CD --> ABCD;
    ABCDsample --> ABCDcheck["Is k > len(pop)?"] --> ABCDsuccess["Success"];
```
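A standalone sketch of the deferred check (the helper names here are hypothetical, not dask's internals): each partition contributes up to `k` elements plus the size of the population it came from, and only the final reduction validates `k` against the total population size:

```python
import random

def partial_sample(partition, k):
    # Take at most k elements; never error on a short partition.
    return random.sample(partition, min(k, len(partition))), len(partition)

def combine(parts, k):
    # parts: list of (sample, population_size) pairs.
    pop = [x for sample, _ in parts for x in sample]
    n = sum(size for _, size in parts)
    if k > n:  # the deferred check: only the total population matters
        raise ValueError(f"Sample larger than population: k={k} > {n}")
    # NOTE: a real combine must weight sub-samples by the population
    # sizes they were drawn from; plain sampling of the union is shown
    # here only for brevity.
    return random.sample(pop, k)

partitions = [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
parts = [partial_sample(p, k=8) for p in partitions]
print(combine(parts, k=8))  # 8 elements, no error despite the [9] partition
```

Only an impossible request, such as `k=11` against ten total elements, errors at the very end.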
Now, we want to make sure that the result is not biased towards short partitions. That is to say, suppose I sample one element from the following two-partition bag:

```
[1, 2, 3, 4] [5]
```

The second partition will always grab 5, and the first will grab any of the four. So the combine step will choose from 5 and one of 1-4. Without weighting, we'd have a 50% chance of getting 5. Thankfully, the combine already weights the different sub-lists according to the lengths of the populations from which they were drawn (the original author of the algorithm wrote a nice post about it). So I believe we are okay!

I did some small sampling tests to see if things seemed out of whack:

```python
import dask.bag as db
from dask.bag import random

numbers = range(10)
buckets = {i: 0 for i in numbers}
a = db.from_sequence(numbers, partition_size=3)

for i in range(10000):
    s = random.sample(a, 1).compute(scheduler="sync")
    for e in s:
        buckets[e] += 1

print(buckets)
# {0: 1002, 1: 981, 2: 1025, 3: 972, 4: 1027, 5: 968, 6: 1010, 7: 1010, 8: 966, 9: 1039}
```

So the sampling is still evenly distributed. All of which is to say, I think this PR is ready.
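The weighting argument can also be checked directly on the `[1, 2, 3, 4] [5]` example: if the combine step picks which sub-sample to draw from with probability proportional to the size of the population it came from, the overall distribution stays uniform. A self-contained stdlib sketch (not dask's implementation):

```python
import random
from collections import Counter

def weighted_pick(samples_with_sizes):
    # Choose which sub-sample to take the element from, weighted by
    # the size of the population each sub-sample was drawn from.
    sizes = [size for _, size in samples_with_sizes]
    (sample, _), = random.choices(samples_with_sizes, weights=sizes, k=1)
    return random.choice(sample)

counts = Counter()
for _ in range(50_000):
    a = ([random.choice([1, 2, 3, 4])], 4)  # one element from the big partition
    b = ([5], 1)                            # the short partition always yields 5
    counts[weighted_pick([a, b])] += 1

print(counts)  # each of 1..5 appears roughly 10_000 times (~20% each)
```

With weights 4 and 1, the element 5 is chosen with probability 1/5 rather than 1/2, matching a uniform draw from the full five-element bag.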
pavithraes left a comment:

@ian-r-rose Thank you for working on this, and for including the above (super helpful!) explanation. I think this looks great!
Fixes the rest of #9249. Currently, if you have some imbalanced partitions in a bag such that some of them are small, and then sample from that bag with a `k` larger than those partitions, you can run into the error in #9249. The error message isn't correct in that case.

I still need to write some tests and also sit down with a pencil and paper to convince myself and others that this doesn't bias the sampling, so marking as a draft until I can get around to that.

Linked issue: `dask.bag.random.sample` throws various errors (#9249)