From 98eb31bafdfd63e7b727591471e5283018adb753 Mon Sep 17 00:00:00 2001
From: Xiangrui Meng
Date: Thu, 20 Nov 2014 13:38:37 -0800
Subject: [PATCH] make poisson sampling slightly faster

---
 python/pyspark/rddsampler.py | 26 ++++++++++++++++----------
 1 file changed, 16 insertions(+), 10 deletions(-)

diff --git a/python/pyspark/rddsampler.py b/python/pyspark/rddsampler.py
index 5928b1d892de0..9e7acc28e99dd 100644
--- a/python/pyspark/rddsampler.py
+++ b/python/pyspark/rddsampler.py
@@ -17,6 +17,7 @@
 
 import sys
 import random
+import math
 
 
 class RDDSamplerBase(object):
@@ -37,16 +38,21 @@ def getUniformSample(self):
         return self._random.random()
 
     def getPoissonSample(self, mean):
-        # here we simulate drawing numbers n_i ~ Poisson(lambda = 1/mean) by
-        # drawing a sequence of numbers delta_j ~ Exp(mean)
-        num_arrivals = 0
-        cur_time = self._random.expovariate(mean)
-
-        while cur_time < 1.0:
-            cur_time += self._random.expovariate(mean)
-            num_arrivals += 1
-
-        return num_arrivals
+        # Using Knuth's algorithm described in http://en.wikipedia.org/wiki/Poisson_distribution
+        if mean < 20.0:  # one exp and k+1 random calls
+            l = math.exp(-mean)
+            p = self._random.random()
+            k = 0
+            while p > l:
+                k += 1
+                p *= self._random.random()
+        else:  # switch to the log domain, k+1 expovariate (random + log) calls
+            p = self._random.expovariate(mean)
+            k = 0
+            while p < 1.0:
+                k += 1
+                p += self._random.expovariate(mean)
+        return k
 
     def func(self, split, iterator):
         raise NotImplementedError