In [33]:
import os
import re

import sympy as sp

class Env:
    """Assignment of values to the program variables x, y and z.

    Each field is either a concrete ``int`` (for numeric evaluation via
    ``AstNode.eval``) or a SymPy ``Symbol`` (for symbolic translation).
    """
    x: int | sp.Symbol
    y: int | sp.Symbol
    z: int | sp.Symbol

    def __init__(self, x: int | sp.Symbol, y: int | sp.Symbol, z: int | sp.Symbol):
        # Symbols are stored unchanged; anything else is coerced to int.
        self.x = x if isinstance(x, sp.Symbol) else int(x)
        self.y = y if isinstance(y, sp.Symbol) else int(y)
        self.z = z if isinstance(z, sp.Symbol) else int(z)

    def __str__(self):
        # %s rather than %d: %d raises TypeError for Symbol fields (e.g. sym_env),
        # while for ints both produce the same text.
        return 'x = %s, y = %s, z = %s' % (self.x, self.y, self.z)


# Symbolic environment used by Var.to_sym.  This must come after the class
# definition: instantiating Env before it is defined raises NameError on a
# fresh kernel.
(sym_x, sym_y, sym_z) = sp.symbols('x y z')
sym_env = Env(sym_x, sym_y, sym_z)

class AstNode:
    """Base class for nodes of the enumerated expression grammar.

    Children are stored positionally in ``self.p``.  Operator subclasses
    override ``eval_op`` (and ``to_sym_op`` when the SymPy form needs a
    different constructor); leaf subclasses override ``eval``/``to_sym``
    directly and use ``self.p`` for their payload instead of children.
    """
    p: tuple  # positional children (or leaf payload), as passed to __init__

    def __init__(self, *p):
        self.p = p

    def to_sym(self):
        """Translate this subtree into a SymPy expression (children first)."""
        return self.to_sym_op(*[child.to_sym() for child in self.p])

    def to_sym_op(self, *p):
        """Combine already-translated children.

        Defaults to ``eval_op``, which suffices for operators such as ``+``
        and ``*`` that SymPy objects overload directly.
        """
        return self.eval_op(*p)

    def eval(self, env: 'Env'):
        """Evaluate this subtree under the variable assignment ``env``."""
        return self.eval_op(*[child.eval(env) for child in self.p])

    def eval_op(self, *p):
        """Combine already-evaluated children; subclasses must override.

        Raises:
            NotImplementedError: always, in the base class.
        """
        # Raising a plain string is itself a TypeError in Python 3, which hid
        # this message; raise a proper exception instead.
        raise NotImplementedError('%s.eval_op must be overridden' % (type(self).__name__,))

    def __str__(self):
        return '%s(%s)' % (self.__class__.__name__, ', '.join(str(child) for child in self.p))

class Const(AstNode):
    """Literal leaf node; ``p[0]`` holds the raw value (1, 2 or 3 in this notebook)."""

    def __str__(self):
        return '%s' % (self.p[0],)

    def eval(self, _: Env):
        # A literal evaluates to itself regardless of the environment.
        return self.p[0]

    def to_sym(self):
        value = self.p[0]
        return sp.sympify(value)

class Var(AstNode):
    """Variable leaf node.

    ``p[0]`` is an accessor that reads the variable's value from an ``Env``
    (e.g. ``lambda e: e.x``); ``p[1]`` is the display name.
    """

    def to_sym(self, env=None):
        """Return this variable's symbolic value; ``env`` defaults to ``sym_env``."""
        # Resolve the default at call time instead of binding ``sym_env`` when
        # the class is defined: the class no longer requires ``sym_env`` to
        # exist first, and a re-created ``sym_env`` is picked up.
        if env is None:
            env = sym_env
        return self.p[0](env)

    def eval(self, env: Env):
        return self.p[0](env)

    def __str__(self):
        return self.p[1]

class Add(AstNode):
    """Binary sum node.

    There is no ``to_sym_op`` override: the same ``+`` also builds the SymPy sum.
    """

    def eval_op(self, e0, e1):
        total = e0 + e1
        return total

class Multiply(AstNode):
    """Binary product node.

    There is no ``to_sym_op`` override: the same ``*`` also builds the SymPy product.
    """

    def eval_op(self, e0, e1):
        product = e0 * e1
        return product

class Ite(AstNode):
    """If-then-else node; children are (condition, then-branch, else-branch)."""

    def eval_op(self, b, e0, e1):
        if b:
            return e0
        return e1

    def to_sym_op(self, b, e0, e1):
        # Piecewise picks the first branch whose condition holds; the trailing
        # ``True`` makes the else-branch unconditional.
        then_branch = (e0, b)
        else_branch = (e1, True)
        return sp.Piecewise(then_branch, else_branch)

class And(AstNode):
    """Boolean conjunction of two conditions."""

    def to_sym_op(self, b0, b1):
        return sp.And(b0, b1)

    def eval_op(self, b0, b1):
        # Same result as ``b0 and b1``: b0 itself when falsy, otherwise b1.
        return b1 if b0 else b0

class Or(AstNode):
    """Boolean disjunction of two conditions."""

    def to_sym_op(self, b0, b1):
        return sp.Or(b0, b1)

    def eval_op(self, b0, b1):
        # Same result as ``b0 or b1``: b0 itself when truthy, otherwise b1.
        return b0 if b0 else b1

class Not(AstNode):
    """Boolean negation of a single condition."""

    def to_sym_op(self, b):
        return sp.Not(b)

    def eval_op(self, b):
        return False if b else True

class Lt(AstNode):
    """Strict less-than comparison of two arithmetic sub-expressions."""

    def to_sym_op(self, e0, e1):
        return sp.Lt(e0, e1)

    def eval_op(self, e0, e1):
        is_less = e0 < e1
        return is_less

class Eq(AstNode):
    """Equality comparison of two arithmetic sub-expressions."""

    def to_sym_op(self, e0, e1):
        # sp.Eq builds a symbolic relation; Python == would compare structure.
        return sp.Eq(e0, e1)

    def eval_op(self, e0, e1):
        is_equal = e0 == e1
        return is_equal

def toExpr(a):
    """Coerce ``a`` to an AstNode.

    AstNode instances are returned unchanged; the integers 1, 2 and 3 map to
    the shared constant leaves ``c1``, ``c2`` and ``c3``.

    Raises:
        ValueError: for any other value.
    """
    if isinstance(a, AstNode):
        return a
    # c1..c3 are module globals defined after this function; they are looked
    # up at call time, so the ordering is fine.
    if a == 1:
        return c1
    if a == 2:
        return c2
    if a == 3:
        return c3
    # Raising a plain string is a TypeError in Python 3 that hides the message.
    raise ValueError('what is this parameter to toExpr(%r)?' % (a,))

# Shared literal leaves; toExpr() maps the ints 1-3 onto these same objects.
c1, c2, c3 = Const(1), Const(2), Const(3)

# Variable leaves: each pairs an Env accessor with its display name.
x = Var(lambda env: env.x, 'x')
y = Var(lambda env: env.y, 'y')
z = Var(lambda env: env.z, 'z')


In [34]:
# Load the target expression from the first line of a test file.  That line
# has the form "# <expr>", e.g. "# Ite(Not(Eq(1, z)), z, Multiply(2, y))".
TEST_FILE = '../../test-data/phase2/0307.txt'

with open(TEST_FILE) as f:
    header = f.readline()

# Non-greedy group so trailing whitespace/newline is consumed by the final \s*.
m = re.match(r'#\s*(.*?)\s*$', header)
if not m:
    # exit() does not reliably stop a Jupyter cell (execution would continue
    # into m.group on None); raise so the cell halts here.
    raise ValueError('No match! first line of %s: %r' % (TEST_FILE, header))

expr_str = m.group(1)
print('evaluating %s:' % (expr_str,))
# The grammar constructors expect nodes, not ints, so the single-digit
# literals 1-3 are rewritten to the c1-c3 constants before evaluation.
# NOTE(review): eval() executes arbitrary code from the test file; only run
# this on trusted test data.
expr = eval(expr_str.replace('1', 'c1').replace('2', 'c2').replace('3', 'c3'))
print('  expr: %s' % (expr,))
print('')

evaluating Ite(Not(Eq(1, z)), z, Multiply(2, y)):
  expr: Ite(Not(Eq(1, z)), z, Multiply(2, y))



In [35]:
# Show the SymPy form of the loaded expression (an Ite becomes a Piecewise).
expr.to_sym()

Piecewise((z, Ne(z, 1)), (2*y, True))

In [36]:
# Enumeration state: expressions and booleans of the current level, rebuilt
# by each gen_all() call below.
expr_asts, bool_asts = [], []

def gen_all():
    """Advance the bottom-up enumeration of the grammar by one level.

    Reads the global ``expr_asts`` / ``bool_asts`` and replaces them with:

    * expressions: the terminals 1, 2, 3, x, y, z; then Add/Multiply over every
      ordered pair of current expressions; then Ite(p, s, t) for every current
      boolean ``p`` and ordered expression pair ``(s, t)``;
    * booleans: Lt/Eq over every ordered pair of current expressions; then, for
      each current boolean ``p``, Not(p) followed by And(p, q)/Or(p, q) for
      every current boolean ``q``.

    List order is significant: later cells print and classify ASTs in order.
    """
    global expr_asts, bool_asts
    prev_exprs, prev_bools = expr_asts, bool_asts
    operand_pairs = [(s, t) for s in prev_exprs for t in prev_exprs]

    arithmetic = [node for s, t in operand_pairs for node in (Add(s, t), Multiply(s, t))]
    comparisons = [node for s, t in operand_pairs for node in (Lt(s, t), Eq(s, t))]

    connectives = []
    for p in prev_bools:
        connectives.append(Not(p))
        connectives.extend(node for q in prev_bools for node in (And(p, q), Or(p, q)))

    conditionals = [Ite(p, s, t) for p in prev_bools for s, t in operand_pairs]

    expr_asts = [c1, c2, c3, x, y, z] + arithmetic + conditionals
    bool_asts = comparisons + connectives

# Number of enumeration rounds.  Growth is explosive (3 rounds already give
# 450,222 expressions), so keep this small.
ENUM_DEPTH = 3
for _round in range(ENUM_DEPTH):
    gen_all()

print(len(expr_asts))
# Show only a sample from the tail of the enumeration.
for ast in expr_asts[-20:]:
    print(ast)


450222
Ite(Eq(z, z), Multiply(z, z), Add(y, 3))
Ite(Eq(z, z), Multiply(z, z), Multiply(y, 3))
Ite(Eq(z, z), Multiply(z, z), Add(y, x))
Ite(Eq(z, z), Multiply(z, z), Multiply(y, x))
Ite(Eq(z, z), Multiply(z, z), Add(y, y))
Ite(Eq(z, z), Multiply(z, z), Multiply(y, y))
Ite(Eq(z, z), Multiply(z, z), Add(y, z))
Ite(Eq(z, z), Multiply(z, z), Multiply(y, z))
Ite(Eq(z, z), Multiply(z, z), Add(z, 1))
Ite(Eq(z, z), Multiply(z, z), Multiply(z, 1))
Ite(Eq(z, z), Multiply(z, z), Add(z, 2))
Ite(Eq(z, z), Multiply(z, z), Multiply(z, 2))
Ite(Eq(z, z), Multiply(z, z), Add(z, 3))
Ite(Eq(z, z), Multiply(z, z), Multiply(z, 3))
Ite(Eq(z, z), Multiply(z, z), Add(z, x))
Ite(Eq(z, z), Multiply(z, z), Multiply(z, x))
Ite(Eq(z, z), Multiply(z, z), Add(z, y))
Ite(Eq(z, z), Multiply(z, z), Multiply(z, y))
Ite(Eq(z, z), Multiply(z, z), Add(z, z))
Ite(Eq(z, z), Multiply(z, z), Multiply(z, z))


In [37]:
# Partition the enumerated expressions into semantic equivalence classes.
# equiv_classes holds (representative SymPy expression, [member ASTs]) in
# creation order; the first AST of each class is its printed representative.
equiv_classes = []
# Fast path: every SymPy expression classified so far -> its class index.
# SymPy expressions hash structurally, so a repeat of an already-seen
# expression skips the slow sp.ask scan.  Assuming sp.ask is deterministic,
# this yields the same class the scan would (every earlier class already
# failed to match that expression).  It also keeps identical expressions
# together even if sp.ask cannot decide sym == sym.
class_of_sym = {}


def _find_class(sym):
    """Return the index of the first class whose representative sp.ask proves
    equal to ``sym``, or None when no class matches."""
    for idx, (class_sym, _members) in enumerate(equiv_classes):
        # sp.ask is three-valued (True/False/None); None counts as "not equal".
        if sp.ask(sp.Q.eq(sym, class_sym)):
            return idx
    return None


os.makedirs('output', exist_ok=True)  # open(..., 'w') does not create directories
with open('output/eqclasses.txt', 'w') as out_file:
    for ast in expr_asts:
        sym = ast.to_sym()
        idx = class_of_sym.get(sym)
        if idx is None:
            idx = _find_class(sym)
        if idx is None:
            idx = len(equiv_classes)
            equiv_classes.append((sym, [ast]))
            line = '%s -> new' % (ast,)
        else:
            members = equiv_classes[idx][1]
            line = '%s == %s' % (ast, members[0])
            members.append(ast)
        class_of_sym[sym] = idx
        out_file.write(line + '\n')
        out_file.flush()  # keep the file current so a long run can be followed
        print(line)

print('%d equivalence classes total.' % (len(equiv_classes),))

1 -> new
2 -> new
3 -> new
x -> new
y -> new
z -> new
Add(1, 1) == 2
Multiply(1, 1) == 1
Add(1, 2) == 3
Multiply(1, 2) == 2
Add(1, 3) -> new
Multiply(1, 3) == 3
Add(1, x) -> new
Multiply(1, x) == x
Add(1, y) -> new
Multiply(1, y) == y
Add(1, z) -> new
Multiply(1, z) == z
Add(1, Add(1, 1)) == 3
Multiply(1, Add(1, 1)) == 2
Add(1, Multiply(1, 1)) == 2
Multiply(1, Multiply(1, 1)) == 1
Add(1, Add(1, 2)) == Add(1, 3)
Multiply(1, Add(1, 2)) == 3
Add(1, Multiply(1, 2)) == 3
Multiply(1, Multiply(1, 2)) == 2
Add(1, Add(1, 3)) -> new
Multiply(1, Add(1, 3)) == Add(1, 3)
Add(1, Multiply(1, 3)) == Add(1, 3)
Multiply(1, Multiply(1, 3)) == 3
Add(1, Add(1, x)) -> new
Multiply(1, Add(1, x)) == Add(1, x)
Add(1, Multiply(1, x)) == Add(1, x)
Multiply(1, Multiply(1, x)) == x
Add(1, Add(1, y)) -> new
Multiply(1, Add(1, y)) == Add(1, y)
Add(1, Multiply(1, y)) == Add(1, y)
Multiply(1, Multiply(1, y)) == y
Add(1, Add(1, z)) -> new
Multiply(1, Add(1, z)) == Add(1, z)
Add(1, Multiply(1, z)) == Add(1, z)
Multiply(