Skip to content

Commit

Permalink
_solve() helper now for solving univariate equations
Browse files Browse the repository at this point in the history
  • Loading branch information
skirpichev committed Oct 6, 2017
1 parent a96f4c1 commit f44a422
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 63 deletions.
112 changes: 53 additions & 59 deletions diofant/solvers/solvers.py
Expand Up @@ -646,14 +646,8 @@ def _sympified_list(w):
#
# try to get a solution
###########################################################################
if bare_f:
solution = _solve(f[0], *symbols, **flags)

if not solution:
solution = []
elif not isinstance(solution[0], dict):
assert len(symbols) == 1
solution = [{symbols[0]: s} for s in solution]
if bare_f and len(symbols) == 1:
solution = [{symbols[0]: s} for s in _solve(f[0], symbols[0], **flags)]
else:
solution = _solve_system(f, symbols, **flags)

Expand Down Expand Up @@ -713,11 +707,9 @@ def test_assumptions(sol):
return solution


def _solve(f, *symbols, **flags):
def _solve(f, symbol, **flags):
"""Return a checked solution for f in terms of one or more of the
symbols. A list should be returned except for the case when a linear
undetermined-coefficients equation is encountered (in which case
a dictionary is returned).
symbols. A list (possibly empty) should be returned.
If no method is implemented to solve the equation, a NotImplementedError
will be raised. In the case that conversion of an expression to a Poly
Expand All @@ -726,49 +718,6 @@ def _solve(f, *symbols, **flags):

not_impl_msg = "No algorithms are implemented to solve equation %s"

if len(symbols) != 1:
soln = None
free = f.free_symbols
ex = free - set(symbols)
if len(ex) != 1:
ind, dep = f.as_independent(*symbols)
ex = ind.free_symbols & dep.free_symbols
# find first successful solution
failed = []
got_s = set()
result = []
for s in symbols:
n, d = solve_linear(f, symbols=[s])
if n.is_Symbol:
# no need to check but we should simplify if desired
if flags.get('simplify', True):
d = simplify(d)
if got_s and any(ss in d.free_symbols for ss in got_s):
# sol depends on previously solved symbols: discard it
continue
got_s.add(n)
result.append({n: d})
elif n and d: # otherwise there was no solution for s
failed.append(s)
if not failed:
return result
for s in failed:
try:
soln = _solve(f, s, **flags)
for sol in soln:
if got_s and any(ss in sol.free_symbols for ss in got_s):
# sol depends on previously solved symbols: discard it
continue
got_s.add(s)
result.append({s: sol})
except NotImplementedError:
continue
if got_s:
return result
else:
raise NotImplementedError(not_impl_msg % f)
symbol = symbols[0]

# /!\ capture this flag then set it to False so that no checking in
# recursive calls will be done; only the final answer is checked
checkdens = check = flags.pop('check', True)
Expand All @@ -785,7 +734,7 @@ def _solve(f, *symbols, **flags):
# all solutions have been checked but now we must
# check that the solutions do not set denominators
# in any factor to zero
dens = denoms(f, symbols)
dens = denoms(f, [symbol])
result = [s for s in result if
all(not checksol(den, {symbol: s}, **flags) for den in
dens)]
Expand All @@ -796,7 +745,7 @@ def _solve(f, *symbols, **flags):
elif f.is_Piecewise:
result = set()
for n, (expr, cond) in enumerate(f.args):
candidates = _solve(piecewise_fold(expr), *symbols, **flags)
candidates = _solve(piecewise_fold(expr), symbol, **flags)
for candidate in candidates:
if candidate in result:
continue
Expand Down Expand Up @@ -831,7 +780,7 @@ def _solve(f, *symbols, **flags):
else:
# first see if it really depends on symbol and whether there
# is a linear solution
f_num, sol = solve_linear(f, symbols=symbols)
f_num, sol = solve_linear(f, symbols=[symbol])
if symbol not in f_num.free_symbols:
return []
elif f_num.is_Symbol:
Expand Down Expand Up @@ -1069,7 +1018,7 @@ def _expand(p):
if checkdens:
# reject any result that makes any denom. affirmatively 0;
# if in doubt, keep it
dens = denoms(f, symbols)
dens = denoms(f, [symbol])
result = [s for s in result if
all(not checksol(d, {symbol: s}, **flags)
for d in dens)]
Expand All @@ -1081,6 +1030,51 @@ def _expand(p):


def _solve_system(exprs, symbols, **flags):
if len(symbols) != 1 and len(exprs) == 1:
not_impl_msg = "No algorithms are implemented to solve equation %s"

f = exprs[0]
soln = None
free = f.free_symbols
ex = free - set(symbols)
if len(ex) != 1:
ind, dep = f.as_independent(*symbols)
ex = ind.free_symbols & dep.free_symbols
# find first successful solution
failed = []
got_s = set()
result = []
for s in symbols:
n, d = solve_linear(f, symbols=[s])
if n.is_Symbol:
# no need to check but we should simplify if desired
if flags.get('simplify', True):
d = simplify(d)
if got_s and any(ss in d.free_symbols for ss in got_s):
# sol depends on previously solved symbols: discard it
continue
got_s.add(n)
result.append({n: d})
elif n and d: # otherwise there was no solution for s
failed.append(s)
if not failed:
return result
for s in failed:
try:
soln = _solve(f, s, **flags)
for sol in soln:
if got_s and any(ss in sol.free_symbols for ss in got_s):
# sol depends on previously solved symbols: discard it
continue
got_s.add(s)
result.append({s: sol})
except NotImplementedError:
continue
if got_s:
return result
else:
raise NotImplementedError(not_impl_msg % f)

polys = []
dens = set()
failed = []
Expand Down
6 changes: 2 additions & 4 deletions diofant/solvers/tests/test_solvers.py
Expand Up @@ -197,6 +197,7 @@ def test_solve_polynomial2():

assert solve(z**2*x**2 - z**2*y**2) == [{x: -y}, {x: y}, {z: 0}]
assert solve(z**2*x - z**2*y**2) == [{x: y**2}, {z: 0}]
assert solve(z**2*x - z**2*y**2, simplify=False) == [{x: y**2}, {z: 0}]


def test_solve_polynomial_cv_1a():
Expand Down Expand Up @@ -439,10 +440,7 @@ def test_solve_transcendental():
assert solve(x**y - 1) == [{x: 1}, {y: 0}]
assert solve([x**y - 1]) == [{x: 1}, {y: 0}]
assert solve(x*y*(x**2 - y**2)) == [{x: 0}, {x: -y}, {x: y}, {y: 0}]
assert (solve([x*y*(x**2 - y**2)], check=False) ==
[{x: RootOf(x**3 - x*y**2, x, 0)},
{x: RootOf(x**3 - x*y**2, x, 1)},
{x: RootOf(x**3 - x*y**2, x, 2)}])
assert solve([x*y*(x**2 - y**2)], check=False) == [{x: 0}, {x: -y}, {x: y}, {y: 0}]
# issue sympy/sympy#4739
assert solve(exp(log(5)*x) - 2**x, x) == [{x: 0}]

Expand Down

0 comments on commit f44a422

Please sign in to comment.