Skip to content

Commit

Permalink
Fix #196 by calling numpy.random.seed().
Browse files Browse the repository at this point in the history
  • Loading branch information
donkirkby committed Aug 9, 2018
1 parent 72d4d04 commit 3799bf4
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 5 deletions.
12 changes: 7 additions & 5 deletions plugin/PySrc/code_tracer.py
Expand Up @@ -733,7 +733,7 @@ def find_module(self, fullname, path=None):
if (fullname == self.module_name or
(fullname == SCOPE_NAME and self.is_own_driver)):
return self
if fullname not in ('matplotlib', 'matplotlib.pyplot'):
if fullname not in ('matplotlib', 'matplotlib.pyplot', 'numpy.random'):
return None
is_after = False
for finder in sys.meta_path:
Expand All @@ -742,11 +742,11 @@ def find_module(self, fullname, path=None):
continue
loader = finder.find_module(fullname, path)
if loader is not None:
return PatchedMatplotlibLoader(fullname, loader, self.is_zoomed)
return PatchedModuleLoader(fullname, loader, self.is_zoomed)
if sys.version_info < (3, 0) and not TracedModuleImporter.is_desperate:
# Didn't find anyone to load the module, get desperate.
TracedModuleImporter.is_desperate = True
return PatchedMatplotlibLoader(fullname, None, self.is_zoomed)
return PatchedModuleLoader(fullname, None, self.is_zoomed)

def load_module(self, fullname):
if '.' in self.module_name:
Expand All @@ -768,7 +768,7 @@ def load_module(self, fullname):
return new_mod


class PatchedMatplotlibLoader(object):
class PatchedModuleLoader(object):
def __init__(self, fullname, main_loader, is_zoomed):
self.fullname = fullname
self.main_loader = main_loader
Expand All @@ -781,7 +781,9 @@ def load_module(self, fullname):
else:
module = import_module(fullname)
TracedModuleImporter.is_desperate = False
if fullname == 'matplotlib':
if fullname == 'numpy.random':
module.seed(0)
elif fullname == 'matplotlib':
module.use('Agg')
elif fullname == 'matplotlib.pyplot':
self.plt = module
Expand Down
31 changes: 31 additions & 0 deletions test/PySrc/tests/test_code_tracer_matplotlib.py
Expand Up @@ -20,6 +20,18 @@ def clear_matplotlib():
yield True


@pytest.fixture(name='is_numpy_cleared')
def clear_numpy():
for should_yield in (True, False):
to_delete = [module_name
for module_name in sys.modules
if module_name.startswith('numpy')]
for module_name in to_delete:
del sys.modules[module_name]
if should_yield:
yield True


def replace_image(report):
report = trim_report(report)
report = re.sub(r"image='[a-zA-Z0-9+/=]*'", "image='...'", report)
Expand Down Expand Up @@ -47,3 +59,22 @@ def test_show(is_matplotlib_cleared):
report = tracer.trace_turtle(code)

assert expected_report == replace_image(report)


def test_numpy_random(is_numpy_cleared):
assert is_numpy_cleared
code = """\
import numpy as np
x = np.random.normal(size=3)
"""
expected_report = """\
x = array([1.76405235, 0.40015721, 0.97873798])
"""
tracer = CodeTracer()

report = tracer.trace_code(code)

assert expected_report == trim_report(report)

0 comments on commit 3799bf4

Please sign in to comment.