Permalink
Browse files

Merge pull request #2927 from smichr/abs

issue 4046: better handling of Abs by solve
  • Loading branch information...
smichr committed Mar 6, 2014
2 parents 3e706a2 + 61d447b commit 6f68fa186dc9e80cdafddb1e4d72f49684286cda
Showing with 66 additions and 49 deletions.
  1. +1 −0 sympy/core/tests/test_wester.py
  2. +43 −40 sympy/solvers/solvers.py
  3. +22 −9 sympy/solvers/tests/test_solvers.py
@@ -934,6 +934,7 @@ def test_M28():
def test_M29():
x = symbols('x', real=True)
assert solve(abs(x - 1) - 2) == [-1, 3]
@@ -23,9 +23,9 @@
from sympy.core.exprtools import factor_terms
from sympy.core.function import (expand_mul, expand_multinomial, expand_log,
Derivative, AppliedUndef, UndefinedFunction, nfloat,
count_ops, Function, expand_power_exp)
count_ops, Function, expand_power_exp, Lambda)
from sympy.core.numbers import ilcm, Float
from sympy.core.relational import Relational
from sympy.core.relational import Relational, Ge
from sympy.logic.boolalg import And, Or
from sympy.core.basic import preorder_traversal
@@ -735,58 +735,48 @@ def _sympified_list(w):
exclude = reduce(set.union, [e.free_symbols for e in sympify(exclude)])
symbols = [s for s in symbols if s not in exclude]
# Any embedded piecewise functions need to be brought out to the
# top level so that the appropriate strategy gets selected.
# However, this is necessary only if one of the piecewise
# functions depends on one of the symbols we are solving for.
def _has_piecewise(e):
if e.is_Piecewise:
return e.has(*symbols)
return any([_has_piecewise(a) for a in e.args])
# real/imag handling -----------------------------
w = Dummy('w')
piece = Lambda(w, Piecewise((w, Ge(w, 0)), (-w, True)))
for i, fi in enumerate(f):
if _has_piecewise(fi):
f[i] = piecewise_fold(fi)
# real/imag handling
for i, fi in enumerate(f):
_abs = [a for a in fi.atoms(Abs) if a.has(*symbols)]
fi = f[i] = fi.xreplace(dict(list(zip(_abs,
[sqrt(a.args[0]**2) for a in _abs]))))
if fi.has(*_abs):
if any(s.assumptions0 for a in
_abs for s in a.free_symbols):
raise NotImplementedError(filldedent('''
All absolute
values were not removed from %s. In order to solve
this equation, try replacing your symbols with
Dummy symbols (or other symbols without assumptions).
''' % fi))
else:
raise NotImplementedError(filldedent('''
Removal of absolute values from %s failed.''' % fi))
# Abs
reps = []
for a in fi.atoms(Abs):
if not a.has(*symbols):
continue
if a.args[0].is_real is None:
raise NotImplementedError('solving %s when the argument '
'is not real or imaginary.' % a)
reps.append((a, piece(a.args[0]) if a.args[0].is_real else \
piece(a.args[0]*S.ImaginaryUnit)))
fi = fi.subs(reps)
# arg
_arg = [a for a in fi.atoms(arg) if a.has(*symbols)]
f[i] = fi.xreplace(dict(list(zip(_arg,
fi = fi.xreplace(dict(list(zip(_arg,
[atan(im(a.args[0])/re(a.args[0])) for a in _arg]))))
# save changes
f[i] = fi
# see if re(s) or im(s) appear
irf = []
for s in symbols:
# if s is real or complex then re(s) or im(s) will not appear in the equation;
if s.is_real or s.is_complex:
continue
if s.is_real or s.is_imaginary:
continue # neither re(x) nor im(x) will appear
# if re(s) or im(s) appear, the auxiliary equation must be present
irs = re(s), im(s)
if any(_f.has(i) for _f in f for i in irs):
symbols.extend(irs)
if any(fi.has(re(s), im(s)) for fi in f):
irf.append((s, re(s) + S.ImaginaryUnit*im(s)))
if irf:
for s, rhs in irf:
for i, fi in enumerate(f):
f[i] = fi.xreplace({s: rhs})
f.append(s - rhs)
symbols.extend([re(s), im(s)])
if bare_f:
bare_f = False
flags['dict'] = True
f.extend(s - rhs for s, rhs in irf)
# end of real/imag handling
# end of real/imag handling -----------------------------
symbols = list(uniq(symbols))
if not ordered_symbols:
@@ -872,7 +862,7 @@ def _has_piecewise(e):
p in symset or
p.is_Add or p.is_Mul or
p.is_Pow and not implicit or
p.is_Function and not implicit):
p.is_Function and not implicit) and p.func not in (re, im):
continue
elif not p in seen:
seen.add(p)
@@ -895,6 +885,18 @@ def _has_piecewise(e):
floats = True
f[i] = nsimplify(fi, rational=True)
# Any embedded piecewise functions need to be brought out to the
# top level so that the appropriate strategy gets selected.
# However, this is necessary only if one of the piecewise
# functions depends on one of the symbols we are solving for.
def _has_piecewise(e):
if e.is_Piecewise:
return e.has(*symbols)
return any([_has_piecewise(a) for a in e.args])
for i, fi in enumerate(f):
if _has_piecewise(fi):
f[i] = piecewise_fold(fi)
#
# try to get a solution
###########################################################################
@@ -955,6 +957,7 @@ def _do_dict(solution):
# undo the dictionary solutions returned when the system was only partially
# solved with poly-system if all symbols are present
if (
not flags.get('dict', False) and
solution and
ordered_symbols and
type(solution) is not dict and
@@ -1103,12 +1103,6 @@ def test_check_assumptions():
assert solve(x**2 - 1) == [1]
def test_solve_abs():
assert set(solve(abs(x - 7) - 8)) == set([-S(1), S(15)])
r = symbols('r', real=True)
raises(NotImplementedError, lambda: solve(2*abs(r) - abs(r - 1)))
def test_issue_2957():
assert solve(tanh(x + 3)*tanh(x - 3) - 1) == []
assert set([simplify(w) for w in solve(tanh(x - 1)*tanh(x + 1) + 1)]) == set([
@@ -1215,6 +1209,7 @@ def test_issue3429():
def test_overdetermined():
x = symbols('x', real=True)
eqs = [Abs(4*x - 7) - 5, Abs(3 - 8*x) - 1]
assert solve(eqs, x) == [(S.Half,)]
assert solve(eqs, x, manual=True) == [(S.Half,)]
@@ -1260,27 +1255,45 @@ def test_issue_3693():
def test_issues_3720_3721_3722_3149():
# 3722
x, y = symbols('x y')
x, y = symbols('x y', real=True)
assert solve(abs(x + 3) - 2*abs(x - 3)) == [1, 9]
assert solve([abs(x) - 2, arg(x) - pi], x) == [
{re(x): -2, x: -2, im(x): 0}, {re(x): 2, x: 2, im(x): 0}]
assert solve([abs(x) - 2, arg(x) - pi], x) == [(-2,), (2,)]
assert set(solve(abs(x - 7) - 8)) == set([-S(1), S(15)])
# issue 4046
assert solve(2*abs(x) - abs(x - 1)) == [-1, Rational(1, 3)]
x = symbols('x')
assert solve([re(x) - 1, im(x) - 2], x) == [
{re(x): 1, x: 1 + 2*I, im(x): 2}]
# check for 'dict' handling of solution
eq = sqrt(re(x)**2 + im(x)**2) - 3
assert solve(eq) == solve(eq, x)
i = symbols('i', imaginary=True)
assert solve(abs(i) - 3) == [-3*I, 3*I]
raises(NotImplementedError, lambda: solve(abs(x) - 3))
w = symbols('w', integer=True)
assert solve(2*x**w - 4*y**w, w) == solve((x/y)**w - 2, w)
x, y = symbols('x y', real=True)
assert solve(x + y*I + 3) == {y: 0, x: -3}
# github issue 2642
assert solve(x*(1 + I)) == [0]
x, y = symbols('x y', imaginary=True)
assert solve(x + y*I + 3 + 2*I) == {x: -2*I, y: 3*I}
x = symbols('x', real=True)
assert solve(x + y + 3 + 2*I) == {x: -3, y: -2*I}
# issue 3149
f = Function('f')
assert solve(f(x + 1) - f(2*x - 1)) == [2]
assert solve(log(x + 1) - log(2*x - 1)) == [2]
x = symbols('x')
assert solve(2**x + 4**x) == [I*pi/log(2)]

0 comments on commit 6f68fa1

Please sign in to comment.