Skip to content

Commit

Permalink
ENH: Enable 0 samples in hypergeometric
Browse files Browse the repository at this point in the history
Enable Hypergeometric to work with 0 samples

xref numpy/numpy#9237
  • Loading branch information
bashtage committed Apr 10, 2019
1 parent 9a32519 commit 699256c
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 5 deletions.
4 changes: 2 additions & 2 deletions randomgen/generator.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -3448,14 +3448,14 @@ cdef class RandomGenerator:
return disc(&random_hypergeometric, self._brng, size, self.lock, 0, 3,
lngood, 'ngood', CONS_NON_NEGATIVE,
lnbad, 'nbad', CONS_NON_NEGATIVE,
lnsample, 'nsample', CONS_GTE_1)
lnsample, 'nsample', CONS_NON_NEGATIVE)

if np.any(np.less(np.add(ongood, onbad), onsample)):
raise ValueError("ngood + nbad < nsample")
return discrete_broadcast_iii(&random_hypergeometric, self._brng, size, self.lock,
ongood, 'ngood', CONS_NON_NEGATIVE,
onbad, 'nbad', CONS_NON_NEGATIVE,
onsample, 'nsample', CONS_GTE_1)
onsample, 'nsample', CONS_NON_NEGATIVE)

def logseries(self, p, size=None):
"""
Expand Down
4 changes: 3 additions & 1 deletion randomgen/src/distributions/distributions.c
Original file line number Diff line number Diff line change
Expand Up @@ -1179,8 +1179,10 @@ int64_t random_hypergeometric(brng_t *brng_state, int64_t good, int64_t bad,
int64_t sample) {
if (sample > 10) {
return random_hypergeometric_hrua(brng_state, good, bad, sample);
} else {
} else if (sample > 0) {
return random_hypergeometric_hyp(brng_state, good, bad, sample);
} else {
return 0;
}
}

Expand Down
4 changes: 2 additions & 2 deletions randomgen/tests/test_generator_mt19937.py
Original file line number Diff line number Diff line change
Expand Up @@ -1832,7 +1832,7 @@ def test_hypergeometric(self):
nsample = [2]
bad_ngood = [-1]
bad_nbad = [-2]
bad_nsample_one = [0]
bad_nsample_one = [-1]
bad_nsample_two = [4]
hypergeom = random.hypergeometric
desired = np.array([1, 1, 1])
Expand Down Expand Up @@ -1863,7 +1863,7 @@ def test_hypergeometric(self):

assert_raises(ValueError, hypergeom, -1, 10, 20)
assert_raises(ValueError, hypergeom, 10, -1, 20)
assert_raises(ValueError, hypergeom, 10, 10, 0)
assert_raises(ValueError, hypergeom, 10, 10, -1)
assert_raises(ValueError, hypergeom, 10, 10, 25)

def test_logseries(self):
Expand Down

0 comments on commit 699256c

Please sign in to comment.