Skip to content

Commit

Permalink
Merge 55c355a into debeed9
Browse files Browse the repository at this point in the history
  • Loading branch information
bashtage committed Apr 12, 2019
2 parents debeed9 + 55c355a commit 7e5103d
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 7 deletions.
6 changes: 3 additions & 3 deletions randomgen/generator.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -722,9 +722,9 @@ cdef class RandomGenerator:
cdf /= cdf[-1]
uniform_samples = self.random_sample(shape)
idx = cdf.searchsorted(uniform_samples, side='right')
idx = np.array(idx, copy=False) # searchsorted returns a scalar
idx = np.array(idx, copy=False, dtype=np.int64) # searchsorted returns a scalar
else:
idx = self.randint(0, pop_size, size=shape)
idx = self.randint(0, pop_size, size=shape, dtype=np.int64)
else:
if size > pop_size:
raise ValueError("Cannot take a larger sample than "
Expand Down Expand Up @@ -753,7 +753,7 @@ cdef class RandomGenerator:
n_uniq += new.size
idx = found
else:
idx = self.permutation(pop_size)[:size]
idx = (self.permutation(pop_size)[:size]).astype(np.int64)
if shape is not None:
idx.shape = shape

Expand Down
20 changes: 16 additions & 4 deletions randomgen/tests/test_generator_mt19937.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,25 +541,25 @@ def test_random_sample_unsupported_type(self):
def test_choice_uniform_replace(self):
random.seed(self.seed)
actual = random.choice(4, 4)
desired = np.array([2, 3, 2, 3])
desired = np.array([2, 3, 2, 3], dtype=np.int64)
assert_array_equal(actual, desired)

def test_choice_nonuniform_replace(self):
random.seed(self.seed)
actual = random.choice(4, 4, p=[0.4, 0.4, 0.1, 0.1])
desired = np.array([1, 1, 2, 2])
desired = np.array([1, 1, 2, 2], dtype=np.int64)
assert_array_equal(actual, desired)

def test_choice_uniform_noreplace(self):
random.seed(self.seed)
actual = random.choice(4, 3, replace=False)
desired = np.array([0, 1, 3])
desired = np.array([0, 1, 3], dtype=np.int64)
assert_array_equal(actual, desired)

def test_choice_nonuniform_noreplace(self):
random.seed(self.seed)
actual = random.choice(4, 3, replace=False, p=[0.1, 0.3, 0.5, 0.1])
desired = np.array([2, 3, 1])
desired = np.array([2, 3, 1], dtype=np.int64)
assert_array_equal(actual, desired)

def test_choice_noninteger(self):
Expand Down Expand Up @@ -638,6 +638,18 @@ def test_choice_nan_probabilities(self):
p = [None, None, None]
assert_raises(ValueError, random.choice, a, p=p)

def test_choice_return_type(self):
# gh 9867
p = np.ones(4) / 4.
actual = random.choice(4, 2)
assert actual.dtype == np.int64
actual = random.choice(4, 2, replace=False)
assert actual.dtype == np.int64
actual = random.choice(4, 2, p=p)
assert actual.dtype == np.int64
actual = random.choice(4, 2, p=p, replace=False)
assert actual.dtype == np.int64

def test_bytes(self):
random.seed(self.seed)
actual = random.bytes(10)
Expand Down

0 comments on commit 7e5103d

Please sign in to comment.