import sys
import random
import math


class RDDSamplerBase(object):
    """Base class providing random-draw helpers for RDD samplers.

    NOTE(review): this class was reconstructed from the hunks of a diff.
    Members that fall outside those hunks (for example whatever assigns
    ``self._random``) are not shown here and must be kept from the real file.
    """

    # Below this mean the multiplicative (Knuth) method is used; at or above
    # it the sampler switches to summing exponential inter-arrival times.
    _KNUTH_MAX_MEAN = 20.0

    def getUniformSample(self):
        """Return the next value produced by ``self._random.random()``."""
        return self._random.random()

    def getPoissonSample(self, mean):
        """Draw one Poisson-distributed integer with the given mean.

        Two equivalent formulations of Knuth's algorithm (see
        http://en.wikipedia.org/wiki/Poisson_distribution) are used:

        * ``mean < 20``: multiply uniform draws until the running product
          drops to ``exp(-mean)`` or below. Costs one ``exp`` call and
          ``k + 1`` calls to ``random()``.
        * otherwise: work in the log domain by summing
          ``expovariate(mean)`` inter-arrival times until they reach 1.0.
          Costs ``k + 1`` calls to ``expovariate()``.

        :param mean: expected value (lambda) of the distribution.
        :return: the sampled count ``k`` (a non-negative int).
        :raises ValueError: if ``mean`` is positive infinity. In that case
            every ``expovariate(mean)`` draw is 0.0, so the accumulation
            loop would never terminate.
        """
        if math.isinf(mean) and mean > 0:
            raise ValueError("mean must be finite, got %r" % (mean,))

        rng = self._random
        if mean < self._KNUTH_MAX_MEAN:
            threshold = math.exp(-mean)
            product = rng.random()
            count = 0
            while product > threshold:
                count += 1
                product *= rng.random()
        else:
            # Count how many arrivals of a rate-`mean` process land in [0, 1).
            elapsed = rng.expovariate(mean)
            count = 0
            while elapsed < 1.0:
                count += 1
                elapsed += rng.expovariate(mean)
        return count

    def func(self, split, iterator):
        """Not implemented in this base class; always raises NotImplementedError."""
        raise NotImplementedError