In [10]:
import sympy as sp
import functools, random, re

import synthast as ast

In [11]:
# Parse one example file. Its first line (a '#' comment in this file) names the
# target expression; echo it before handing the file to the parser.
test_fn = '../../test-data/phase2/0308.txt'
with open(test_fn) as example_file:
    first_line = example_file.readline()
    print('parsing "%s":' % (first_line.strip(),))
expr = ast.parse_examples(test_fn)
print('  expr: %s' % (expr,))

parsing "# Multiply(Ite(Eq(x, 1), Multiply(x, x), Multiply(y, x)), Add(Add(z, z), 2))":
  expr: Multiply(Ite(Eq(x, 1), Multiply(x, x), Multiply(y, x)), Add(Add(z, z), 2))


In [12]:
# Show the parsed AST converted to a SymPy expression (rendered as this cell's output).
expr.to_sym()

(2*z + 2)*Piecewise((x**2, Eq(x, 1)), (x*y, True))

In [13]:
# All expression / boolean ASTs enumerated so far; each gen_all() call
# replaces them with the next, one-level-deeper enumeration.
expr_asts: 'list[ast.AstNode]' = []
bool_asts: 'list[ast.AstNode]' = []

def gen_all():
    """Replace the global expr_asts / bool_asts with one more level of nesting.

    New expression list, in order:
      * the terminals ast.c1, ast.c2, ast.c3, ast.x, ast.y, ast.z;
      * Add(s, t) then Multiply(s, t) for every ordered pair of current expressions;
      * Ite(p, s, t) for every current boolean p and ordered expression pair (s, t).
    New boolean list, in order:
      * Lt(s, t) then Eq(s, t) for every ordered pair of current expressions;
      * for each current boolean p: Not(p), then And(p, q) / Or(p, q) for every q.
    Because the terminals are re-added on every call, each call keeps all
    shallower ASTs as well. The list order matters: later cells report
    equivalence classes in this order.
    """
    global expr_asts, bool_asts
    new_expr_asts = [ast.c1, ast.c2, ast.c3, ast.x, ast.y, ast.z]
    new_bool_asts = []
    for s in expr_asts:
        for t in expr_asts:
            new_expr_asts.append(ast.Add(s, t))
            new_expr_asts.append(ast.Multiply(s, t))
            new_bool_asts.append(ast.Lt(s, t))
            new_bool_asts.append(ast.Eq(s, t))
    for p in bool_asts:
        new_bool_asts.append(ast.Not(p))
        for q in bool_asts:
            new_bool_asts.append(ast.And(p, q))
            new_bool_asts.append(ast.Or(p, q))
        for s in expr_asts:
            for t in expr_asts:
                new_expr_asts.append(ast.Ite(p, s, t))
    expr_asts = new_expr_asts
    bool_asts = new_bool_asts

# Enumerate ASTs three levels deep, then show the count plus a sample from
# each end of the list (separated by '...').
for _level in range(3):
    gen_all()

print(len(expr_asts))
for item in expr_asts[:5] + ['...'] + expr_asts[-5:]:
    print(item)


450222
1
2
3
x
y
...
Ite(Eq(z, z), Multiply(z, z), Multiply(z, x))
Ite(Eq(z, z), Multiply(z, z), Add(z, y))
Ite(Eq(z, z), Multiply(z, z), Multiply(z, y))
Ite(Eq(z, z), Multiply(z, z), Add(z, z))
Ite(Eq(z, z), Multiply(z, z), Multiply(z, z))


In [14]:
# Candidate variable values for test environments: 0, 1, the primes below 100,
# and the negatives of all non-zero entries.
nums = [0, 1, 2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59,
        61, 67, 71, 73, 79, 83, 89, 97]
nums += [-n for n in nums[1:]]
r = random.Random(53)  # fixed seed so the randomized search is reproducible

# Set to True to re-run the randomized search for a small set of test
# environments that splits expr_asts into as many value classes as possible.
# When False, the hard-coded environments below are used (presumably the
# result of an earlier run of that search).
SEARCH_TEST_ENVS = False

if SEARCH_TEST_ENVS:
    # Stage 1: out of 10 random 3-tuples of variable values, keep the one whose
    # evaluations split expr_asts into the most distinct groups.
    best_d = ()
    best_d_vals = {}

    for i in range(10):
        d = (r.choice(nums), r.choice(nums), r.choice(nums))
        vals = {}
        e = ast.Env(*d)
        for s in expr_asts:
            vals.setdefault((s.eval(e),), []).append(s)
        if len(vals) > len(best_d_vals):
            # best_d is a tuple of 3-tuples (one per environment), the same
            # shape stage 2 extends and the else-branch uses.
            best_d = (d,)
            best_d_vals = vals

    print("%r: %d" % (best_d, len(best_d_vals)))

    # Stage 2: greedily add four more environments, each the best of 10 random
    # candidates at refining the current groups.
    for i in range(4):
        next_best_d = ()
        next_best_d_vals = {}
        for j in range(10):
            d = best_d + ((r.choice(nums), r.choice(nums), r.choice(nums)),)
            vals = {}
            e = ast.Env(*(d[-1]))
            for v, ss in best_d_vals.items():
                for s in ss:
                    vals.setdefault(v + (s.eval(e),), []).append(s)
            if len(vals) > len(next_best_d_vals):
                next_best_d = d
                next_best_d_vals = vals
        best_d = next_best_d
        best_d_vals = next_best_d_vals
        print('%d: %r => %d' % (i, best_d, len(best_d_vals)))

else:
    best_d = ((-61, -13, 79), (-43, 67, -11), (67, 47, 43), (5, 83, 83), (41, 29, 37))
    d_envs = [ast.Env(*d) for d in best_d]
    # Group ASTs by their tuple of values across all test environments.
    best_d_vals = {}
    for s in expr_asts:
        cv = tuple(s.eval(e) for e in d_envs)
        best_d_vals.setdefault(cv, []).append(s)
    print('%r => %d' % (best_d, len(best_d_vals)))

# Index every enumerated AST by its string form -> (value-group index, AST).
# ASTs that evaluate identically on every test environment share a group index.
d_equiv = {}
for group_idx, members in enumerate(best_d_vals.values()):
    for node in members:
        key = str(node)
        assert key not in d_equiv
        d_equiv[key] = (group_idx, node)


((-61, -13, 79), (-43, 67, -11), (67, 47, 43), (5, 83, 83), (41, 29, 37)) => 8144


In [15]:
# Sanity-check canonical(): every AST whose canonical form is a different
# object must be provably equal to it under SymPy. Progress is printed every 10%.
n_asts = len(expr_asts)
for i, s in enumerate(expr_asts):
    canon_ast = s.canonical()
    if canon_ast is not s:
        # sp.ask returns True, False, or None (undecided); report a disproof
        # and an undecided query separately instead of lumping them together.
        verdict = sp.ask(sp.Q.eq(s.to_sym(), canon_ast.to_sym()))
        if verdict is False:
            print('Not equivalent: %s <-> %s' % (s, canon_ast))
        elif verdict is None:
            print('Equivalence unknown: %s <-> %s' % (s, canon_ast))
    pct = i * 100 // n_asts
    next_pct = (i + 1) * 100 // n_asts
    if pct // 10 < next_pct // 10:
        print('...%d%%' % (next_pct,))

...10%
...20%
...30%
...40%
...50%
...60%
...70%
...80%
...90%
...100%


In [18]:
# equiv_classes[i] is (SymPy form of the class's first member, [member ASTs]).
equiv_classes = []
# Maps str(ast) -> index into equiv_classes for every AST classified so far.
equiv_lookup = {}

def _value_group(node):
    """Return node's value-group index from d_equiv, or None if node is not indexed.

    Tries str(node) first (every enumerated AST is a key), then the string of
    its canonical form, which is not guaranteed to be an enumerated AST.
    """
    entry = d_equiv.get(str(node))
    if entry is None:
        entry = d_equiv.get(str(node.canonical()))
    return None if entry is None else entry[0]

def equiv(s, t):
    """Return True if ASTs s and t are known to be semantically equivalent.

    Checks run cheapest first:
      1. identity;
      2. different value groups (they evaluate differently on some test
         environment), which proves they are not equivalent;
      3. identical canonical forms;
      4. previously recorded class membership of both canonical forms;
      5. a SymPy equality query. An undecided query (None) counts as False.
    """
    if s is t:
        return True
    # Differing test-environment values are a definitive "no" and need no
    # canonicalisation or SymPy work.
    s_de = _value_group(s)
    t_de = _value_group(t)
    if s_de is not None and t_de is not None and s_de != t_de:
        return False
    s_str = str(s.canonical())
    t_str = str(t.canonical())
    if s_str == t_str:
        return True
    if s_str in equiv_lookup and t_str in equiv_lookup:
        return equiv_lookup[s_str] == equiv_lookup[t_str]
    return sp.ask(sp.Q.eq(s.to_sym(), t.to_sym())) is True

# Assign every AST to an equivalence class, logging each decision to
# output/eqclasses.txt (flushed per line so progress is visible) and to stdout.
#
# Only classes whose representative is in the same d_equiv value group as t
# are tried: ASTs that evaluate differently on a test environment cannot be
# equivalent. This turns the full scan over all classes into a scan over one
# small bucket. It assumes equiv() never reports ASTs from different value
# groups as equivalent, which holds as long as canonical() is sound (checked
# by the canonical() cell above). Candidates are tried in class-creation
# order, so the first match is the same one a full scan would find.
classes_by_group = {}
with open('output/eqclasses.txt', 'w') as f:
    for t in expr_asts:
        group = d_equiv[str(t)][0]
        candidates = classes_by_group.setdefault(group, [])
        for i in candidates:
            _class_sym, asts = equiv_classes[i]
            if equiv(t, asts[0]):
                s = '%s == %s' % (t, asts[0])
                asts.append(t)
                equiv_lookup[str(t)] = i
                break
        else:
            s = '%s -> new' % (t,)
            i = len(equiv_classes)
            equiv_lookup[str(t)] = i
            candidates.append(i)
            # The SymPy form is only needed for a class's first member.
            equiv_classes.append((t.to_sym(), [t]))
        f.write(s + '\n')
        f.flush()
        print(s)

print('%d equivalence classes total.' % (len(equiv_classes),))

1 -> new
2 -> new
3 -> new
x -> new
y -> new
z -> new
Add(1, 1) == 2
Multiply(1, 1) == 1
Add(1, 2) == 3
Multiply(1, 2) == 2
Add(1, 3) -> new
Multiply(1, 3) == 3
Add(1, x) -> new
Multiply(1, x) == x
Add(1, y) -> new
Multiply(1, y) == y
Add(1, z) -> new
Multiply(1, z) == z
Add(1, Add(1, 1)) == 3
Multiply(1, Add(1, 1)) == 2
Add(1, Multiply(1, 1)) == 2
Multiply(1, Multiply(1, 1)) == 1
Add(1, Add(1, 2)) == Add(1, 3)
Multiply(1, Add(1, 2)) == 3
Add(1, Multiply(1, 2)) == 3
Multiply(1, Multiply(1, 2)) == 2
Add(1, Add(1, 3)) -> new
Multiply(1, Add(1, 3)) == Add(1, 3)
Add(1, Multiply(1, 3)) == Add(1, 3)
Multiply(1, Multiply(1, 3)) == 3
Add(1, Add(1, x)) -> new
Multiply(1, Add(1, x)) == Add(1, x)
Add(1, Multiply(1, x)) == Add(1, x)
Multiply(1, Multiply(1, x)) == x
Add(1, Add(1, y)) -> new
Multiply(1, Add(1, y)) == Add(1, y)
Add(1, Multiply(1, y)) == Add(1, y)
Multiply(1, Multiply(1, y)) == y
Add(1, Add(1, z)) -> new
Multiply(1, Add(1, z)) == Add(1, z)
Add(1, Multiply(1, z)) == Add(1, z)
Multiply(

KeyboardInterrupt: 