forked from sympy/sympy
-
Notifications
You must be signed in to change notification settings - Fork 1
/
test_unify.py
89 lines (69 loc) · 3 KB
/
test_unify.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
from sympy.unify.core import Compound, Variable, CondVariable, allcombinations
from sympy.unify import core
# Plain one-character strings, used by the tests as constant (non-variable) terms.
a, b, c = 'a', 'b', 'c'
# Unification variables named after the letters they are bound to.
w, x, y, z = (Variable(letter) for letter in 'wxyz')
# Short alias so compound terms can be spelled compactly in the tests.
C = Compound
def is_associative(x):
    """Return True when ``x`` is a Compound whose op is treated as associative.

    The ops considered associative are 'Add', 'Mul', 'CAdd' and 'CMul'.
    Anything that is not a Compound is never associative.
    """
    if not isinstance(x, Compound):
        return False
    return x.op in ('Add', 'Mul', 'CAdd', 'CMul')
def is_commutative(x):
    """Return True when ``x`` is a Compound whose op is treated as commutative.

    Only the 'CAdd' and 'CMul' ops are considered commutative. Anything that
    is not a Compound is never commutative.
    """
    if not isinstance(x, Compound):
        return False
    return x.op in ('CAdd', 'CMul')
def unify(a, b, s=None):
    """Unify ``a`` with ``b`` using this module's op predicates.

    Thin wrapper around ``core.unify`` that supplies the ``is_associative``
    and ``is_commutative`` predicates defined in this file.

    Parameters
    ----------
    a, b : terms to unify against each other.
    s : dict, optional
        Initial substitution. A fresh empty dict is used when omitted.

    Returns
    -------
    Whatever ``core.unify`` returns; the tests in this file consume it as an
    iterable of substitution dicts (via ``list``, ``tuple`` and ``next``).
    """
    # A ``{}`` default would be one dict object shared by every call, so a
    # new dict is created per call instead.
    if s is None:
        s = {}
    return core.unify(a, b, s=s, is_associative=is_associative,
                      is_commutative=is_commutative)
def test_basic():
    """Unification of atoms, variables and plain tuples."""

    def solve(lhs, rhs, subst):
        # Materialise the matches so they can be compared with a list.
        return list(unify(lhs, rhs, subst))

    # A free variable binds to whatever it is matched against.
    assert solve(a, x, {}) == [{x: a}]
    # An existing, conflicting binding for x rules the match out.
    assert solve(a, x, {x: 10}) == []
    assert solve(1, x, {}) == [{x: 1}]
    # Identical constants unify without binding anything.
    assert solve(a, a, {}) == [{}]

    # Tuples unify element by element.
    assert solve((w, x), (y, z), {}) == [{w: y, x: z}]
    assert solve(x, (a, b), {}) == [{x: (a, b)}]
    # x cannot be both a and b at the same time ...
    assert solve((a, b), (x, x), {}) == []
    # ... but two variables can both be matched against x.
    assert solve((y, z), (x, x), {}) != []
    assert solve((a, (b, c)), (a, (x, y)), {}) == [{x: b, y: c}]
def test_ops():
    """Compound terms with the same op unify argument by argument."""
    flat_term = C('Add', (a, b, c))
    flat_pattern = C('Add', (a, x, y))
    assert list(unify(flat_term, flat_pattern, {})) == [{x: b, y: c}]

    # A nested compound can itself be bound to a variable.
    nested_term = C('Add', (C('Mul', (1, 2)), b, c))
    nested_pattern = C('Add', (x, y, c))
    expected = [{x: C('Mul', (1, 2)), y: b}]
    assert list(unify(nested_term, nested_pattern, {})) == expected
def test_associative():
    """A three-argument 'Add' is matched against a two-variable pattern.

    'Add' is listed in ``is_associative``; the expected matches regroup the
    three arguments as (1, Add(2, 3)) and (Add(1, 2), 3), in that order.
    """
    c1 = C('Add', (1, 2, 3))
    c2 = C('Add', (x, y))
    # Run the unification once and assert on the stored value; the order of
    # the matches is part of the expectation, so compare as a tuple.
    result = tuple(unify(c1, c2, {}))
    assert result == ({x: 1, y: C('Add', (2, 3))},
                      {x: C('Add', (1, 2)), y: 3})
def test_commutative():
    """A three-argument 'CAdd' is matched against a two-variable pattern."""
    term = C('CAdd', (1, 2, 3))
    pattern = C('CAdd', (x, y))
    matches = list(unify(term, pattern, {}))

    assert {x: 1, y: C('CAdd', (2, 3))} in matches
    # With x bound to 2, the leftover arguments are accepted in either order.
    leftover_orders = ((1, 3), (3, 1))
    assert any({x: 2, y: C('CAdd', rest)} in matches
               for rest in leftover_orders)
def _test_combinations_assoc():
    """Check the splits of (1, 2, 3) over two slots when passing ``True``.

    NOTE(review): the leading underscore means the name lacks the ``test_``
    prefix, so this looks deliberately disabled -- confirm before renaming.
    """
    expected = {
        (((1, 2), (3,)), (a, b)),
        (((1,), (2, 3)), (a, b)),
    }
    assert set(allcombinations((1, 2, 3), (a, b), True)) == expected
def _test_combinations_comm():
    """Check the splits of (1, 2, 3) over two slots when passing ``None``.

    NOTE(review): the leading underscore means the name lacks the ``test_``
    prefix, so this looks deliberately disabled -- confirm before renaming.
    """
    # Every expected way of splitting (1, 2, 3) into two groups.
    splits = [
        ((1,), (2, 3)), ((2,), (3, 1)), ((3,), (1, 2)),
        ((1, 2), (3,)), ((2, 3), (1,)), ((3, 1), (2,)),
    ]
    # Each split is paired with the unchanged second argument.
    expected = {(split, ('a', 'b')) for split in splits}
    assert set(allcombinations((1, 2, 3), (a, b), None)) == expected
def test_allcombinations():
    """In 'commutative' mode both orderings of the second argument appear."""
    in_order = ((1,), (2,))
    swapped = ((2,), (1,))
    expected = {(in_order, in_order), (in_order, swapped)}
    assert set(allcombinations((1, 2), (1, 2), 'commutative')) == expected
def test_commutativity():
    """Two two-argument 'CAdd' terms unify in exactly two ways."""
    lhs = Compound('CAdd', (a, b))
    rhs = Compound('CAdd', (x, y))
    # Sanity check: both terms are recognised as commutative.
    assert all(is_commutative(term) for term in (lhs, rhs))

    matches = list(unify(lhs, rhs, {}))
    assert len(matches) == 2
def test_CondVariable():
    """Conditional variables only bind to values that satisfy their predicate."""
    expr = C('CAdd', (1, 2))
    # Local x shadows the module-level variable of the same name.
    x = Variable('x')
    # y accepts even values only; z accepts values greater than 3 only.
    y = CondVariable('y', lambda a: a % 2 == 0)
    z = CondVariable('z', lambda a: a > 3)

    # Only 2 is even, so y must take 2 and x takes 1.
    pattern = C('CAdd', (x, y))
    assert list(unify(expr, pattern, {})) == [{x: 1, y: 2}]

    # Neither 1 nor 2 is greater than 3, so z can never be bound.
    pattern = C('CAdd', (z, y))
    assert list(unify(expr, pattern, {})) == []
def test_defaultdict():
    """``unify`` can be called without an explicit substitution argument."""
    first_match = next(unify(Variable('x'), 'foo'))
    expected = {Variable('x'): 'foo'}
    assert first_match == expected