Skip to content

Commit

Permalink
Fixes issue sympy/sympy#9334
Browse files Browse the repository at this point in the history
the numexpr printer was only activated if the string 'numexpr' was found
in the modules list.  This commit adds the function '_module_present'
which also checks for a module with module.__name__ == 'numexpr' to
activate the numexpr printer

added test for issue

// edited by skirpichev
  • Loading branch information
pbrady authored and skirpichev committed Jul 3, 2015
1 parent 36213bf commit 90de625
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 3 deletions.
16 changes: 13 additions & 3 deletions sympy/utilities/lambdify.py
Expand Up @@ -325,7 +325,7 @@ def lambdify(args, expr, modules=None, printer=None, use_imps=True,
namespaces.append(modules)
else:
# consistency check
if 'numexpr' in modules and len(modules) > 1:
if _module_present('numexpr', modules) and len(modules) > 1:
raise TypeError("numexpr must be the only item in 'modules'")
namespaces += list(modules)
# fill namespace with first having highest priority
Expand All @@ -342,8 +342,8 @@ def lambdify(args, expr, modules=None, printer=None, use_imps=True,
for term in syms:
namespace.update({str(term): term})

if 'numexpr' in namespaces and printer is None:
#XXX: This has to be done here because of circular imports
if _module_present('numexpr', namespaces) and printer is None:
# XXX: This has to be done here because of circular imports
from sympy.printing.lambdarepr import NumExprPrinter as printer

# Get the names of the args, for creating a docstring
Expand Down Expand Up @@ -406,6 +406,16 @@ def _issue_7853_dep_check(namespaces, namespace, expr):
"supplying `modules=[{'ImmutableMatrix': numpy.matrix}, "
"'numpy']`.", issue=7853).warn()


def _module_present(modname, modlist):
if modname in modlist:
return True
for m in modlist:
if hasattr(m, '__name__') and m.__name__ == modname:
return True
return False


def _get_namespace(m):
"""
This is used by _lambdify to parse its arguments.
Expand Down
14 changes: 14 additions & 0 deletions sympy/utilities/tests/test_lambdify.py
Expand Up @@ -157,6 +157,7 @@ def test_numpy_translation_abs():
assert f(-1) == 1
assert f(1) == 1


def test_numexpr_printer():
if not numexpr:
skip("numexpr not installed.")
Expand All @@ -180,6 +181,19 @@ def test_numexpr_printer():
f = lambdify(args, ssym(*args), modules='numexpr')
assert f(*(1, )*nargs) is not None


def test_issue_9334():
if not numexpr:
skip("numexpr not installed.")
if not numpy:
skip("numpy not installed.")
expr = sympy.S('b*a - sqrt(a**2)')
a, b = sorted(expr.free_symbols, key=lambda s: s.name)
func_numexpr = lambdify((a, b), expr, modules=[numexpr], dummify=False)
foo, bar = numpy.random.random((2, 4))
func_numexpr(foo, bar)


#================== Test some functions ============================


Expand Down

0 comments on commit 90de625

Please sign in to comment.