Skip to content

Commit

Permalink
[BACKPORT] Fix random execute that a random tensor can have different…
Browse files Browse the repository at this point in the history
… results when executed in different sessions (#178)
  • Loading branch information
qinxuye authored and wjsi committed Jan 24, 2019
1 parent b590394 commit e617f0a
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 4 deletions.
2 changes: 1 addition & 1 deletion mars/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import os
import sys

version_info = (0, 1, 0, 'b1')
version_info = (0, 1, 0)
_num_index = max(idx if isinstance(v, int) else 0
for idx, v in enumerate(version_info))
__version__ = '.'.join(map(str, version_info[:_num_index + 1])) + \
Expand Down
4 changes: 1 addition & 3 deletions mars/tensor/expressions/random/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
class RandomState(object):
def __init__(self, seed=None):
self._random_state = np.random.RandomState(seed=seed)
self._curr_seed = seed

def seed(self, seed=None):
"""
Expand All @@ -53,11 +52,10 @@ def seed(self, seed=None):
RandomState
"""
self._random_state.seed(seed=seed)
self._curr_seed = seed

@property
def _state(self):
return State(self._random_state) if self._curr_seed is not None else None
return State(self._random_state)

@classmethod
def _handle_size(cls, size):
Expand Down
11 changes: 11 additions & 0 deletions mars/tests/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,3 +183,14 @@ def testArrayProtocol(self):

result = np.asarray(arr4, dtype=np.float_)
np.testing.assert_array_equal(result, np.asarray(200, dtype=np.float_))

def testRandomExecuteInSessions(self):
arr = mt.random.rand(20, 20)

sess1 = new_session()
res1 = sess1.run(arr)

sess2 = new_session()
res2 = sess2.run(arr)

np.testing.assert_array_equal(res1, res2)

0 comments on commit e617f0a

Please sign in to comment.