diofant/diofant

Merge pull request #2927 from smichr/abs

`issue 4046: better handling of Abs by solve`
• Loading branch information...
smichr committed Mar 6, 2014
2 parents 3e706a2 + 61d447b commit 6f68fa186dc9e80cdafddb1e4d72f49684286cda
Showing with 66 additions and 49 deletions.
1. +1 −0 sympy/core/tests/test_wester.py
2. +43 −40 sympy/solvers/solvers.py
3. +22 −9 sympy/solvers/tests/test_solvers.py
 @@ -934,6 +934,7 @@ def test_M28(): def test_M29(): x = symbols('x', real=True) assert solve(abs(x - 1) - 2) == [-1, 3]
 @@ -23,9 +23,9 @@ from sympy.core.exprtools import factor_terms from sympy.core.function import (expand_mul, expand_multinomial, expand_log, Derivative, AppliedUndef, UndefinedFunction, nfloat, count_ops, Function, expand_power_exp) count_ops, Function, expand_power_exp, Lambda) from sympy.core.numbers import ilcm, Float from sympy.core.relational import Relational from sympy.core.relational import Relational, Ge from sympy.logic.boolalg import And, Or from sympy.core.basic import preorder_traversal @@ -735,58 +735,48 @@ def _sympified_list(w): exclude = reduce(set.union, [e.free_symbols for e in sympify(exclude)]) symbols = [s for s in symbols if s not in exclude] # Any embedded piecewise functions need to be brought out to the # top level so that the appropriate strategy gets selected. # However, this is necessary only if one of the piecewise # functions depends on one of the symbols we are solving for. def _has_piecewise(e): if e.is_Piecewise: return e.has(*symbols) return any([_has_piecewise(a) for a in e.args]) # real/imag handling ----------------------------- w = Dummy('w') piece = Lambda(w, Piecewise((w, Ge(w, 0)), (-w, True))) for i, fi in enumerate(f): if _has_piecewise(fi): f[i] = piecewise_fold(fi) # real/imag handling for i, fi in enumerate(f): _abs = [a for a in fi.atoms(Abs) if a.has(*symbols)] fi = f[i] = fi.xreplace(dict(list(zip(_abs, [sqrt(a.args[0]**2) for a in _abs])))) if fi.has(*_abs): if any(s.assumptions0 for a in _abs for s in a.free_symbols): raise NotImplementedError(filldedent(''' All absolute values were not removed from %s. In order to solve this equation, try replacing your symbols with Dummy symbols (or other symbols without assumptions). ''' % fi)) else: raise NotImplementedError(filldedent(''' Removal of absolute values from %s failed.''' % fi)) # Abs reps = [] for a in fi.atoms(Abs): if not a.has(*symbols): continue if a.args[0].is_real is None: raise NotImplementedError('solving %s when the argument ' 'is not real or imaginary.' % a) reps.append((a, piece(a.args[0]) if a.args[0].is_real else \ piece(a.args[0]*S.ImaginaryUnit))) fi = fi.subs(reps) # arg _arg = [a for a in fi.atoms(arg) if a.has(*symbols)] f[i] = fi.xreplace(dict(list(zip(_arg, fi = fi.xreplace(dict(list(zip(_arg, [atan(im(a.args[0])/re(a.args[0])) for a in _arg])))) # save changes f[i] = fi # see if re(s) or im(s) appear irf = [] for s in symbols: # if s is real or complex then re(s) or im(s) will not appear in the equation; if s.is_real or s.is_complex: continue if s.is_real or s.is_imaginary: continue # neither re(x) nor im(x) will appear # if re(s) or im(s) appear, the auxiliary equation must be present irs = re(s), im(s) if any(_f.has(i) for _f in f for i in irs): symbols.extend(irs) if any(fi.has(re(s), im(s)) for fi in f): irf.append((s, re(s) + S.ImaginaryUnit*im(s))) if irf: for s, rhs in irf: for i, fi in enumerate(f): f[i] = fi.xreplace({s: rhs}) f.append(s - rhs) symbols.extend([re(s), im(s)]) if bare_f: bare_f = False flags['dict'] = True f.extend(s - rhs for s, rhs in irf) # end of real/imag handling # end of real/imag handling ----------------------------- symbols = list(uniq(symbols)) if not ordered_symbols: @@ -872,7 +862,7 @@ def _has_piecewise(e): p in symset or p.is_Add or p.is_Mul or p.is_Pow and not implicit or p.is_Function and not implicit): p.is_Function and not implicit) and p.func not in (re, im): continue elif not p in seen: seen.add(p) @@ -895,6 +885,18 @@ def _has_piecewise(e): floats = True f[i] = nsimplify(fi, rational=True) # Any embedded piecewise functions need to be brought out to the # top level so that the appropriate strategy gets selected. # However, this is necessary only if one of the piecewise # functions depends on one of the symbols we are solving for. def _has_piecewise(e): if e.is_Piecewise: return e.has(*symbols) return any([_has_piecewise(a) for a in e.args]) for i, fi in enumerate(f): if _has_piecewise(fi): f[i] = piecewise_fold(fi) # # try to get a solution ########################################################################### @@ -955,6 +957,7 @@ def _do_dict(solution): # undo the dictionary solutions returned when the system was only partially # solved with poly-system if all symbols are present if ( not flags.get('dict', False) and solution and ordered_symbols and type(solution) is not dict and
 @@ -1103,12 +1103,6 @@ def test_check_assumptions(): assert solve(x**2 - 1) == [1] def test_solve_abs(): assert set(solve(abs(x - 7) - 8)) == set([-S(1), S(15)]) r = symbols('r', real=True) raises(NotImplementedError, lambda: solve(2*abs(r) - abs(r - 1))) def test_issue_2957(): assert solve(tanh(x + 3)*tanh(x - 3) - 1) == [] assert set([simplify(w) for w in solve(tanh(x - 1)*tanh(x + 1) + 1)]) == set([ @@ -1215,6 +1209,7 @@ def test_issue3429(): def test_overdetermined(): x = symbols('x', real=True) eqs = [Abs(4*x - 7) - 5, Abs(3 - 8*x) - 1] assert solve(eqs, x) == [(S.Half,)] assert solve(eqs, x, manual=True) == [(S.Half,)] @@ -1260,27 +1255,45 @@ def test_issue_3693(): def test_issues_3720_3721_3722_3149(): # 3722 x, y = symbols('x y') x, y = symbols('x y', real=True) assert solve(abs(x + 3) - 2*abs(x - 3)) == [1, 9] assert solve([abs(x) - 2, arg(x) - pi], x) == [ {re(x): -2, x: -2, im(x): 0}, {re(x): 2, x: 2, im(x): 0}] assert solve([abs(x) - 2, arg(x) - pi], x) == [(-2,), (2,)] assert set(solve(abs(x - 7) - 8)) == set([-S(1), S(15)]) # issue 4046 assert solve(2*abs(x) - abs(x - 1)) == [-1, Rational(1, 3)] x = symbols('x') assert solve([re(x) - 1, im(x) - 2], x) == [ {re(x): 1, x: 1 + 2*I, im(x): 2}] # check for 'dict' handling of solution eq = sqrt(re(x)**2 + im(x)**2) - 3 assert solve(eq) == solve(eq, x) i = symbols('i', imaginary=True) assert solve(abs(i) - 3) == [-3*I, 3*I] raises(NotImplementedError, lambda: solve(abs(x) - 3)) w = symbols('w', integer=True) assert solve(2*x**w - 4*y**w, w) == solve((x/y)**w - 2, w) x, y = symbols('x y', real=True) assert solve(x + y*I + 3) == {y: 0, x: -3} # github issue 2642 assert solve(x*(1 + I)) == [0] x, y = symbols('x y', imaginary=True) assert solve(x + y*I + 3 + 2*I) == {x: -2*I, y: 3*I} x = symbols('x', real=True) assert solve(x + y + 3 + 2*I) == {x: -3, y: -2*I} # issue 3149 f = Function('f') assert solve(f(x + 1) - f(2*x - 1)) == [2] assert solve(log(x + 1) - log(2*x - 1)) == [2] x = symbols('x') assert solve(2**x + 4**x) == [I*pi/log(2)]

0 comments on commit `6f68fa1`

Please sign in to comment.