Skip to content

Commit

Permalink
BUG/ENH: ALlow bit generators to supply their own ctor
Browse files Browse the repository at this point in the history
Allow bit generators to supply their own constructors to enable Generator
objects using arbitrary bit generators to be supported

closes #22012
  • Loading branch information
bashtage committed Jul 19, 2022
1 parent 5ba36b7 commit 95e3e7f
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 29 deletions.
5 changes: 4 additions & 1 deletion numpy/random/_generator.pyx
Expand Up @@ -220,8 +220,11 @@ cdef class Generator:
self.bit_generator.state = state

def __reduce__(self):
ctor, name_tpl, state = self._bit_generator.__reduce__()

from ._pickle import __generator_ctor
return __generator_ctor, (self.bit_generator.state['bit_generator'],), self.bit_generator.state
# Requirements of __generator_ctor are (name, ctor)
return __generator_ctor, (name_tpl[0], ctor), state

@property
def bit_generator(self):
Expand Down
49 changes: 23 additions & 26 deletions numpy/random/_pickle.py
Expand Up @@ -14,70 +14,67 @@
}


def __generator_ctor(bit_generator_name='MT19937'):
def __bit_generator_ctor(bit_generator_name='MT19937'):
"""
Pickling helper function that returns a Generator object
Pickling helper function that returns a bit generator object
Parameters
----------
bit_generator_name : str
String containing the core BitGenerator
String containing the name of the BitGenerator
Returns
-------
rg : Generator
Generator using the named core BitGenerator
bit_generator : BitGenerator
BitGenerator instance
"""
if bit_generator_name in BitGenerators:
bit_generator = BitGenerators[bit_generator_name]
else:
raise ValueError(str(bit_generator_name) + ' is not a known '
'BitGenerator module.')

return Generator(bit_generator())
return bit_generator()


def __bit_generator_ctor(bit_generator_name='MT19937'):
def __generator_ctor(bit_generator_name="MT19937",
bit_generator_ctor=__bit_generator_ctor):
"""
Pickling helper function that returns a bit generator object
Pickling helper function that returns a Generator object
Parameters
----------
bit_generator_name : str
String containing the name of the BitGenerator
String containing the core BitGenerator's name
bit_generator_ctor : callable, optional
Callable function that takes bit_generator_name as its only argument
and returns an instantized bit generator.
Returns
-------
bit_generator : BitGenerator
BitGenerator instance
rg : Generator
Generator using the named core BitGenerator
"""
if bit_generator_name in BitGenerators:
bit_generator = BitGenerators[bit_generator_name]
else:
raise ValueError(str(bit_generator_name) + ' is not a known '
'BitGenerator module.')

return bit_generator()
return Generator(bit_generator_ctor(bit_generator_name))


def __randomstate_ctor(bit_generator_name='MT19937'):
def __randomstate_ctor(bit_generator_name="MT19937",
bit_generator_ctor=__bit_generator_ctor):
"""
Pickling helper function that returns a legacy RandomState-like object
Parameters
----------
bit_generator_name : str
String containing the core BitGenerator
String containing the core BitGenerator's name
bit_generator_ctor : callable, optional
Callable function that takes bit_generator_name as its only argument
and returns an instantized bit generator.
Returns
-------
rs : RandomState
Legacy RandomState using the named core BitGenerator
"""
if bit_generator_name in BitGenerators:
bit_generator = BitGenerators[bit_generator_name]
else:
raise ValueError(str(bit_generator_name) + ' is not a known '
'BitGenerator module.')

return RandomState(bit_generator())
return RandomState(bit_generator_ctor(bit_generator_name))
5 changes: 3 additions & 2 deletions numpy/random/mtrand.pyx
Expand Up @@ -213,9 +213,10 @@ cdef class RandomState:
self.set_state(state)

def __reduce__(self):
state = self.get_state(legacy=False)
ctor, name_tpl, _ = self._bit_generator.__reduce__()

from ._pickle import __randomstate_ctor
return __randomstate_ctor, (state['bit_generator'],), state
return __randomstate_ctor, (name_tpl[0], ctor), self.get_state(legacy=False)

cdef _reset_gauss(self):
self._aug_state.has_gauss = 0
Expand Down
13 changes: 13 additions & 0 deletions numpy/random/tests/test_generator_mt19937.py
Expand Up @@ -2695,3 +2695,16 @@ def test_contig_req_out(dist, order, dtype):
assert variates is out
variates = dist(out=out, dtype=dtype, size=out.shape)
assert variates is out


def test_generator_ctor_old_style_pickle():
rg = np.random.Generator(np.random.PCG64DXSM(0))
rg.standard_normal(1)
# Directly call reduce which is used in pickline
ctor, args, state_a = rg.__reduce__()
# Simulate unpickling an old pickle that only has the name
assert args[:1] == ("PCG64DXSM",)
b = ctor(*args[:1])
b.bit_generator.state = state_a
state_b = b.bit_generator.state
assert state_a == state_b
18 changes: 18 additions & 0 deletions numpy/random/tests/test_randomstate.py
Expand Up @@ -2020,3 +2020,21 @@ def test_broadcast_size_error():
random.binomial([1, 2], 0.3, size=(2, 1))
with pytest.raises(ValueError):
random.binomial([1, 2], [0.3, 0.7], size=(2, 1))


def test_randomstate_ctor_old_style_pickle():
rs = np.random.RandomState(MT19937(0))
rs.standard_normal(1)
# Directly call reduce which is used in pickline
ctor, args, state_a = rs.__reduce__()
# Simulate unpickling an old pickle that only has the name
assert args[:1] == ("MT19937",)
b = ctor(*args[:1])
b.set_state(state_a)
state_b = b.get_state(legacy=False)

assert_equal(state_a['bit_generator'], state_b['bit_generator'])
assert_array_equal(state_a['state']['key'], state_b['state']['key'])
assert_array_equal(state_a['state']['pos'], state_b['state']['pos'])
assert_equal(state_a['has_gauss'], state_b['has_gauss'])
assert_equal(state_a['gauss'], state_b['gauss'])

0 comments on commit 95e3e7f

Please sign in to comment.