diff --git a/sympy/utilities/lambdify.py b/sympy/utilities/lambdify.py index b3c39c1db6..5c3bc394d1 100644 --- a/sympy/utilities/lambdify.py +++ b/sympy/utilities/lambdify.py @@ -325,7 +325,7 @@ def lambdify(args, expr, modules=None, printer=None, use_imps=True, namespaces.append(modules) else: # consistency check - if 'numexpr' in modules and len(modules) > 1: + if _module_present('numexpr', modules) and len(modules) > 1: raise TypeError("numexpr must be the only item in 'modules'") namespaces += list(modules) # fill namespace with first having highest priority @@ -342,8 +342,8 @@ def lambdify(args, expr, modules=None, printer=None, use_imps=True, for term in syms: namespace.update({str(term): term}) - if 'numexpr' in namespaces and printer is None: - #XXX: This has to be done here because of circular imports + if _module_present('numexpr', namespaces) and printer is None: + # XXX: This has to be done here because of circular imports from sympy.printing.lambdarepr import NumExprPrinter as printer # Get the names of the args, for creating a docstring @@ -406,6 +406,16 @@ def _issue_7853_dep_check(namespaces, namespace, expr): "supplying `modules=[{'ImmutableMatrix': numpy.matrix}, " "'numpy']`.", issue=7853).warn() + +def _module_present(modname, modlist): + if modname in modlist: + return True + for m in modlist: + if hasattr(m, '__name__') and m.__name__ == modname: + return True + return False + + def _get_namespace(m): """ This is used by _lambdify to parse its arguments. diff --git a/sympy/utilities/tests/test_lambdify.py b/sympy/utilities/tests/test_lambdify.py index ad3ffccace..0b5837a975 100644 --- a/sympy/utilities/tests/test_lambdify.py +++ b/sympy/utilities/tests/test_lambdify.py @@ -157,6 +157,7 @@ def test_numpy_translation_abs(): assert f(-1) == 1 assert f(1) == 1 + def test_numexpr_printer(): if not numexpr: skip("numexpr not installed.") @@ -180,6 +181,19 @@ def test_numexpr_printer(): f = lambdify(args, ssym(*args), modules='numexpr') assert f(*(1, )*nargs) is not None + +def test_issue_9334(): + if not numexpr: + skip("numexpr not installed.") + if not numpy: + skip("numpy not installed.") + expr = sympy.S('b*a - sqrt(a**2)') + a, b = sorted(expr.free_symbols, key=lambda s: s.name) + func_numexpr = lambdify((a, b), expr, modules=[numexpr], dummify=False) + foo, bar = numpy.random.random((2, 4)) + func_numexpr(foo, bar) + + #================== Test some functions ============================