Skip to content

Commit

Permalink
Merge pull request #1074 from skirpichev/no-strategies
Browse files Browse the repository at this point in the history
Drop deps on strategies
  • Loading branch information
skirpichev committed Oct 17, 2020
2 parents 1d220d4 + 25a8046 commit f9dc278
Show file tree
Hide file tree
Showing 11 changed files with 211 additions and 43 deletions.
195 changes: 188 additions & 7 deletions diofant/core/strategies.py
Expand Up @@ -3,7 +3,7 @@
This file assumes knowledge of Basic and little else.
"""

from strategies.dispatch import dispatch
import functools

from ..utilities.iterables import sift
from .basic import Atom, Basic
Expand All @@ -13,32 +13,38 @@
'glom', 'flatten', 'unpack', 'sort')


@dispatch(Basic)
@functools.singledispatch
def arguments(o):
"""Extract arguments from an expression."""
return o.args


@arguments.register((int, Atom))
@arguments.register(int)
@arguments.register(Atom)
def arguments_atomic(o):
return ()


@dispatch(Basic)
@functools.singledispatch
def operator(o):
"""Extract the head of an expression."""
return o.func


@operator.register((int, Atom))
@operator.register(int)
@operator.register(Atom)
def operator_atomic(o):
return o


@dispatch(type, (tuple, list))
@functools.singledispatch
def term(op, args):
"""Build an expression from the head and arguments."""
return op(*args)


@term.register((int, Atom), (tuple, list))
@term.register(int)
@term.register(Atom)
def term_atomic(op, args):
return op

Expand Down Expand Up @@ -154,3 +160,178 @@ def flatten(expr):
else:
args.append(arg)
return term(cls, args)


def identity(x):
return x


def switch(key, ruledict):
"""Select a rule based on the result of key called on the function."""
def switch_rl(expr):
rl = ruledict.get(key(expr), identity)
return rl(expr)
return switch_rl


def typed(ruletypes):
"""Apply rules based on the expression type.
Examples
========
>>> rm_zeros = rm_id(lambda x: x == 0)
>>> rm_ones = rm_id(lambda x: x == 1)
>>> remove_idents = typed({Add: rm_zeros, Mul: rm_ones})
"""
return switch(type, ruletypes)


def treeapply(tree, join, leaf=identity):
"""Apply functions onto recursive containers (tree).
join - a dictionary mapping container types to functions
e.g. ``{list: minimize, tuple: chain}``
Keys are containers/iterables. Values are functions [a] -> a.
Examples
========
>>> tree = [(3, 2), (4, 1)]
>>> treeapply(tree, {list: max, tuple: min})
2
>>> def mul(*args):
... total = 1
... for arg in args:
... total *= arg
... return total
>>> treeapply(tree, {list: mul, tuple: lambda *args: sum(args)})
25
"""
for typ in join:
if isinstance(tree, typ):
return join[typ](*map(functools.partial(treeapply, join=join, leaf=leaf),
tree))
return leaf(tree)


def minimize(*rules, objective=identity):
"""Select result of rules that minimizes objective.
Examples
========
>>> from diofant.core.strategies import minimize
>>> rl = minimize(lambda x: x + 1, lambda x: x - 1)
>>> rl(4)
3
"""
def minrule(expr):
return min((rule(expr) for rule in rules), key=objective)
return minrule


def chain(*rules):
"""Compose a sequence of rules so that they apply to the expr sequentially."""
def chain_rl(expr):
for rule in rules:
expr = rule(expr)
return expr
return chain_rl


def greedy(tree, objective=identity, **kwargs):
"""Execute a strategic tree. Select alternatives greedily,
Examples
========
>>> tree = [lambda x: x + 1,
... (lambda x: x - 1, lambda x: 2*x)] # either inc or dec-then-double
>>> fn = greedy(tree)
>>> fn(4) # lowest value comes from the inc
5
>>> fn(1) # lowest value comes from dec then double
0
This function selects between options in a tuple. The result is chosen that
minimizes the objective function.
>>> fn = greedy(tree, objective=lambda x: -x) # maximize
>>> fn(4) # highest value comes from the dec then double
6
>>> fn(1) # highest value comes from the inc
2
"""
optimize = functools.partial(minimize, objective=objective)
return treeapply(tree, {list: optimize, tuple: chain}, **kwargs)


def do_one(rules):
"""Try each of the rules until one works. Then stop."""
def do_one_rl(expr):
for rl in rules:
result = rl(expr)
if result != expr:
return result
return expr
return do_one_rl


def condition(cond, rule):
"""Only apply rule if condition is true."""
def conditioned_rl(expr):
if cond(expr):
return rule(expr)
else:
return expr
return conditioned_rl


def exhaust(rule):
"""Apply a rule repeatedly until it has no effect."""
def exhaustive_rl(expr):
new, old = rule(expr), expr
while new != old:
new, old = rule(new), new
return new
return exhaustive_rl


basic_fns = {'op': type,
'new': Basic.__new__,
'leaf': lambda x: not isinstance(x, Basic) or x.is_Atom,
'children': lambda x: x.args}


def sall(rule, fns=basic_fns):
"""Strategic all - apply rule to args."""
op, new, children, leaf = map(fns.get, ('op', 'new', 'children', 'leaf'))

def all_rl(expr):
if leaf(expr):
return expr
else:
args = map(rule, children(expr))
return new(op(expr), *args)

return all_rl


def bottom_up(rule, fns=basic_fns):
"""Apply a rule down a tree running it on the bottom nodes first."""
return chain(lambda expr: sall(bottom_up(rule, fns), fns)(expr), rule)


def null_safe(rule):
"""Return original expr if rule returns None."""
def null_safe_rl(expr):
result = rule(expr)
if result is None:
return expr
else:
return result
return null_safe_rl
7 changes: 2 additions & 5 deletions diofant/matrices/expressions/blockmatrix.py
@@ -1,9 +1,6 @@
from strategies import condition, do_one, exhaust
from strategies.core import typed
from strategies.traverse import bottom_up

from ...core import Add, Expr, Integer, sympify
from ...core.strategies import unpack
from ...core.strategies import (bottom_up, condition, do_one, exhaust, typed,
unpack)
from ...logic import false
from ...utilities import sift
from .determinant import Determinant
Expand Down
4 changes: 1 addition & 3 deletions diofant/matrices/expressions/hadamard.py
@@ -1,7 +1,5 @@
from strategies import condition, do_one, exhaust

from ...core import Mul, sympify
from ...core.strategies import flatten, unpack
from ...core.strategies import condition, do_one, exhaust, flatten, unpack
from ..matrices import ShapeError
from .matexpr import MatrixExpr

Expand Down
5 changes: 2 additions & 3 deletions diofant/matrices/expressions/matadd.py
@@ -1,11 +1,10 @@
import functools
import operator

from strategies import condition, do_one, exhaust

from ...core import Add, Expr, sympify
from ...core.logic import _fuzzy_group
from ...core.strategies import flatten, glom, rm_id, sort, unpack
from ...core.strategies import (condition, do_one, exhaust, flatten, glom,
rm_id, sort, unpack)
from ...functions import adjoint
from ...utilities import default_sort_key, sift
from ..matrices import MatrixBase, ShapeError
Expand Down
5 changes: 1 addition & 4 deletions diofant/matrices/expressions/matmul.py
@@ -1,9 +1,6 @@
from strategies import do_one, exhaust
from strategies.core import typed

from ...core import Add, Expr, Mul, Number, sympify
from ...core.logic import _fuzzy_group
from ...core.strategies import flatten, rm_id, unpack
from ...core.strategies import do_one, exhaust, flatten, rm_id, typed, unpack
from ...functions import adjoint
from ..matrices import MatrixBase, ShapeError
from .matexpr import Identity, MatrixExpr, ZeroMatrix
Expand Down
15 changes: 1 addition & 14 deletions diofant/simplify/fu.py
Expand Up @@ -188,13 +188,10 @@

from collections import defaultdict

from strategies.core import debug, identity
from strategies.tree import greedy

from .. import DIOFANT_DEBUG
from ..core import (Add, Dummy, Expr, I, Integer, Mul, Pow, Rational,
expand_mul, factor_terms, gcd_terms, pi, sympify)
from ..core.exprtools import Factors
from ..core.strategies import greedy, identity
from ..functions import (binomial, cos, cosh, cot, coth, csc, sec, sin, sinh,
sqrt, tan, tanh)
from ..functions.elementary.hyperbolic import HyperbolicFunction
Expand Down Expand Up @@ -1572,16 +1569,6 @@ def L(rv):

# ============== end of basic Fu-like tools =====================

if DIOFANT_DEBUG: # pragma: no cover
(TR0, TR1, TR2, TR3, TR4, TR5,
TR6, TR7, TR8, TR9, TR10, TR11, TR12, TR13,
TR2i, TRmorrie, TR14, TR15, TR16,
TR12i, TR111, TR22) = list(map(debug, (TR0, TR1, TR2, TR3, TR4, TR5,
TR6, TR7, TR8, TR9, TR10, TR11,
TR12, TR13, TR2i, TRmorrie, TR14,
TR15, TR16, TR12i, TR111, TR22)))


# tuples are chains -- (f, g) -> lambda x: g(f(x))
# lists are choices -- [f, g] -> lambda x: min(f(x), g(x), key=objective)

Expand Down
4 changes: 1 addition & 3 deletions diofant/simplify/trigsimp.py
@@ -1,14 +1,12 @@
import functools
from collections import defaultdict

from strategies.core import identity
from strategies.tree import greedy

from ..core import (Add, Basic, Dummy, E, Expr, FunctionClass, I, Integer, Mul,
Pow, Rational, Wild, cacheit, count_ops, expand,
expand_mul, factor_terms, igcd, symbols, sympify)
from ..core.compatibility import iterable
from ..core.function import _mexpand
from ..core.strategies import greedy, identity
from ..domains import ZZ
from ..functions import cos, cosh, cot, coth, exp, sin, sinh, tan, tanh
from ..functions.elementary.hyperbolic import HyperbolicFunction
Expand Down
15 changes: 13 additions & 2 deletions diofant/tests/core/test_strategies.py
@@ -1,7 +1,7 @@
from diofant import Add, Basic, Integer
from diofant.abc import x
from diofant.core.strategies import (arguments, flatten, glom, operator, rm_id,
sort, term, unpack)
from diofant.core.strategies import (arguments, flatten, glom, null_safe,
operator, rm_id, sort, term, unpack)


__all__ = ()
Expand Down Expand Up @@ -55,3 +55,14 @@ def test_term():
assert operator(Integer(2)) == Integer(2)
assert term(Add, (2, x)) == 2 + x
assert term(Integer(2), ()) == Integer(2)


def test_null_safe():
def rl(expr):
if expr == 1:
return 2
safe_rl = null_safe(rl)
assert rl(1) == safe_rl(1)

assert rl(3) is None
assert safe_rl(3) == 3
2 changes: 1 addition & 1 deletion diofant/tests/matrices/test_matmul.py
@@ -1,10 +1,10 @@
import pytest
from strategies.core import null_safe

from diofant import (Adjoint, Basic, I, Identity, ImmutableMatrix, Inverse,
MatMul, MatPow, Matrix, MatrixSymbol, ShapeError,
Transpose, ZeroMatrix, adjoint, det, eye, symbols,
transpose)
from diofant.core.strategies import null_safe
from diofant.matrices.expressions.matmul import (any_zeros, factor_in_front,
only_squares, remove_ids,
unpack, xxinv)
Expand Down
1 change: 1 addition & 0 deletions docs/release/notes-0.12.rst
Expand Up @@ -41,6 +41,7 @@ Developer changes

* Depend on `flake8-sfs <https://github.com/peterjc/flake8-sfs>`_, see :pull:`983`.
* Depend on `mypy <http://mypy-lang.org/>`_, see :pull:`1022`.
* Drop dependency on strategies, see :pull:`1074`.

Issues closed
=============
Expand Down
1 change: 0 additions & 1 deletion setup.cfg
Expand Up @@ -38,7 +38,6 @@ setup_requires = setuptools>=36.7.0
pip>=9.0.1
isort
install_requires = mpmath>=0.19
strategies>=0.2.3
tests_require = diofant[develop]
[options.package_data]
diofant = tests/logic/*.cnf
Expand Down

0 comments on commit f9dc278

Please sign in to comment.