In [43]:
import sympy as sp
import functools, random, re

import synthast as ast

In [44]:
# Sample problem file. Its first line is echoed before parsing so the printed
# header can be compared by eye with the AST that the parser reconstructs.
test_fn = '../../test-data/phase2/0308.txt'
with open(test_fn) as f:
    header_line = f.readline().strip()
    print('parsing "%s":' % (header_line,))

# parse_examples takes the path itself (not the open handle).
expr = ast.parse_examples(test_fn)
print('  expr: %s' % (expr,))

parsing "# Multiply(Ite(Eq(x, 1), Multiply(x, x), Multiply(y, x)), Add(Add(z, z), 2))":
  expr: Multiply(Ite(Eq(x, 1), Multiply(x, x), Multiply(y, x)), Add(Add(z, z), 2))


In [45]:
# Convert the parsed AST to its symbolic form; as the last expression of the
# cell, the result is what the notebook displays.
expr.to_sym()

(2*z + 2)*Piecewise((x**2, Eq(x, 1)), (x*y, True))

In [46]:
# Pools of generated ASTs; gen_all() reads the current pools and replaces them
# with the next, larger generation.
# NOTE(review): `[ast.AstNode]` is a list literal rather than a typing
# construct; presumably "list of AstNode" is meant -- confirm before changing.
expr_asts: [ast.AstNode] = []  # value-producing expressions
bool_asts: [ast.AstNode] = []  # boolean conditions

def gen_all():
    """Replace the global AST pools with the next generation.

    The new expression pool always starts with the six leaves (three constants
    and x, y, z), followed by every Add/Multiply over ordered pairs of the
    previous expressions and every Ite built from a previous condition and an
    ordered pair of previous expressions.  The new condition pool holds every
    Lt/Eq over ordered pairs of previous expressions, plus Not/And/Or over the
    previous conditions.

    Nothing is returned; ``expr_asts`` and ``bool_asts`` are rebound in place.
    """
    global expr_asts, bool_asts
    prev_exprs, prev_bools = expr_asts, bool_asts

    next_exprs = [ast.c1, ast.c2, ast.c3, ast.x, ast.y, ast.z]
    next_bools = []

    # Binary arithmetic and comparisons over every ordered pair.
    for lhs in prev_exprs:
        for rhs in prev_exprs:
            next_exprs.append(ast.Add(lhs, rhs))
            next_exprs.append(ast.Multiply(lhs, rhs))
            next_bools.append(ast.Lt(lhs, rhs))
            next_bools.append(ast.Eq(lhs, rhs))

    # Boolean connectives, then conditionals guarded by each old condition.
    for cond in prev_bools:
        next_bools.append(ast.Not(cond))
        for other in prev_bools:
            next_bools.append(ast.And(cond, other))
            next_bools.append(ast.Or(cond, other))
        next_exprs.extend(
            ast.Ite(cond, then_e, else_e)
            for then_e in prev_exprs
            for else_e in prev_exprs
        )

    expr_asts, bool_asts = next_exprs, next_bools

# Three generations: leaves only, then one level of operators, then a second
# level that also includes conditionals.
for _pass in range(3):
    gen_all()

# Report the pool size and preview both ends of the expression pool.
print(len(expr_asts))
head, tail = expr_asts[:5], expr_asts[-5:]
for node in head:
    print(node)
print('...')
for node in tail:
    print(node)


450222
1
2
3
x
y
...
Ite(Eq(z, z), Multiply(z, z), Multiply(z, x))
Ite(Eq(z, z), Multiply(z, z), Add(z, y))
Ite(Eq(z, z), Multiply(z, z), Multiply(z, y))
Ite(Eq(z, z), Multiply(z, z), Add(z, z))
Ite(Eq(z, z), Multiply(z, z), Multiply(z, z))


In [57]:
# Candidate coordinates for discriminator environments: 0, 1, the primes below
# 100, and the negation of every non-zero entry.
nums = [0, 1, 2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59,
        61, 67, 71, 73, 79, 83, 89, 97]
nums += [-n for n in nums[1:]]
r = random.Random(53)

# Set to True to re-run the randomized greedy search for discriminator
# environments instead of using the hard-coded set below.
SEARCH_DISCRIMINATORS = False

if SEARCH_DISCRIMINATORS:

    # best_d is a tuple of argument triples for ast.Env; best_d_vals maps the
    # tuple of values an expression takes under those environments to the list
    # of expressions producing it.
    best_d = ()
    best_d_vals = {}

    # Seed: out of 10 random environments keep the one that splits expr_asts
    # into the most distinct values.
    for i in range(10):
        # Wrap the triple so best_d has the same tuple-of-triples shape as the
        # hard-coded set (previously the seed was stored as a flat triple and
        # later triples were appended nested, giving a mixed-shape tuple).
        d = ((r.choice(nums), r.choice(nums), r.choice(nums)),)
        vals = {}
        e = ast.Env(*(d[-1]))
        for s in expr_asts:
            vals.setdefault((s.eval(e),), []).append(s)
        if len(vals) > len(best_d_vals):
            best_d = d
            best_d_vals = vals

    print("%r: %d" % (best_d, len(best_d_vals)))

    # Greedy extension: four times, try 10 random extra environments and keep
    # the one that refines the current buckets the most.
    for i in range(4):
        next_best_d = ()
        next_best_d_vals = {}
        for j in range(10):
            d = best_d + ((r.choice(nums), r.choice(nums), r.choice(nums)),)
            vals = {}
            e = ast.Env(*(d[-1]))
            # Only the new environment is evaluated; earlier values are reused
            # from the bucket key.
            for v, ss in best_d_vals.items():
                for s in ss:
                    vals.setdefault(v + (s.eval(e),), []).append(s)
            if len(vals) > len(next_best_d_vals):
                next_best_d = d
                next_best_d_vals = vals
        best_d = next_best_d
        best_d_vals = next_best_d_vals
        print('%d, %d: %r => %d' % (i, j, best_d, len(best_d_vals)))

else:
    # Fixed discriminator environments, one argument triple per ast.Env.
    best_d = ((-61, -13, 79), (-43, 67, -11), (67, 47, 43), (5, 83, 83),
              (1, 11, 59), (13, 1, 71), (19, 73, 1),
              (2, 23, 97), (29, 2, 89), (31, 79, 2),
              (3, 7, 83), (17, 3, 67), (23, 61, 3),
              )
    best_d_vals = {}
    d_envs = [ast.Env(*d) for d in best_d]
    for s in expr_asts:
        # Signature of s: its value under every discriminator environment.
        cv = tuple(s.eval(e) for e in d_envs)
        best_d_vals.setdefault(cv, []).append(s)
    print('%r => %d' % (best_d, len(best_d_vals)))

# Reverse index: str(expression) -> (signature, expression).  The assert
# guards the assumption that no two generated expressions print identically.
d_equiv = {}
for d, ss in best_d_vals.items():
    for s in ss:
        assert str(s) not in d_equiv
        d_equiv[str(s)] = (d, s)


((-61, -13, 79), (-43, 67, -11), (67, 47, 43), (5, 83, 83), (1, 11, 59), (13, 1, 71), (19, 73, 1), (2, 23, 97), (29, 2, 89), (31, 79, 2), (3, 7, 83), (17, 3, 67), (23, 61, 3)) => 25840


In [58]:
# Sanity-check canonicalisation: whenever canonical() returns a different
# object, ask sympy whether the original and canonical forms are equal.
# canonical() is called once per expression and its result reused.
total = len(expr_asts)
for i, s in enumerate(expr_asts):
    canon_ast = s.canonical()
    if canon_ast is not s:
        # sp.ask yields None when it cannot decide, so this line is printed
        # both for proven mismatches and for pairs sympy could not prove equal.
        if not sp.ask(sp.Q.eq(s.to_sym(), canon_ast.to_sym())):
            print('Not equivalent: %s <-> %s' % (s, canon_ast))
    # Print a progress line each time another 10% boundary is crossed.
    pct = i * 100 // total
    next_pct = (i + 1) * 100 // total
    if pct // 10 < next_pct // 10:
        print('...%d%%' % (next_pct,))

...10%
...20%
...30%
...40%
...50%
...60%
...70%
...80%
...90%
...100%


In [59]:
# equiv_classes: one (symbolic form of the first member, [member ASTs]) tuple
# per equivalence class, in order of discovery.
equiv_classes = []
# equiv_lookup: str(ast) -> index of its class in equiv_classes.
equiv_lookup = {}

def equiv(s, t):
    """Decide whether the ASTs ``s`` and ``t`` can be shown equivalent.

    Checks are tried from cheapest to most expensive: object identity, equal
    canonical strings, membership in already-known equivalence classes
    (``equiv_lookup``), differing discriminator signatures (``d_equiv``),
    operand-wise structural equivalence of the canonical forms, and finally a
    sympy query on the symbolic forms.

    Returns True when equivalence is established and False otherwise; False
    therefore also covers "could not be proven".
    """
    # If they are literally the same object they are equivalent
    if s is t:
        return True
    s_canon = s.canonical()
    t_canon = t.canonical()
    s_str = str(s_canon)
    t_str = str(t_canon)
    # If their canonical forms are equal they are equivalent
    if s_str == t_str:
        return True
    # If they are in the same equivalence class they are equivalent
    if s_str in equiv_lookup and t_str in equiv_lookup:
        return equiv_lookup[s_str] == equiv_lookup[t_str]
    # If they have different discriminators they are not equivalent
    if s_str in d_equiv and t_str in d_equiv:
        if d_equiv[s_str][0] != d_equiv[t_str][0]:
            return False
    # If they are the same operation, and all their operands are equivalent,
    # they are equivalent
    if s_canon.name == t_canon.name:
        # The operands compared below are those of the canonical forms, so the
        # arity check must look at them too: canonicalisation may change the
        # operator, in which case s.p / t.p describe a different node.
        assert len(s_canon.p) == len(t_canon.p)
        if all(equiv(s_p, t_p) for s_p, t_p in zip(s_canon.p, t_canon.p)):
            return True
    # If their symbolic forms can be proven equal, they are equivalent
    s_sym = s_canon.to_sym()
    t_sym = t_canon.to_sym()
    if sp.ask(sp.Q.eq(s_sym, t_sym)):
        return True
    # Otherwise, we must assume they are not equivalent
    return False

def load_known_equivs(path):
    """Read a previously written equivalence-class log.

    Each line is either ``<expr> -> new`` (recorded as ``{expr: None}``) or
    ``<expr> == <representative>`` (recorded as ``{expr: representative}``);
    lines matching neither form are ignored.

    Returns the resulting dict.  A missing file is treated as an empty cache
    so that a first run, before any log has been written, still works.
    """
    known = {}
    try:
        with open(path, 'r') as f:
            for line in f:
                m = re.match(r'^(.*) -> new$', line)
                if m:
                    known[m.group(1)] = None
                    continue
                m = re.match(r'^(.*) == (.*)$', line)
                if m:
                    known[m.group(1)] = m.group(2)
    except FileNotFoundError:
        # Deliberate: no log yet simply means there is nothing to reuse.
        pass
    return known


# Results of the previous run, used as hints by the classification loop.
known_equivs = load_known_equivs('output/eqclasses.txt')

# Assign every generated expression to an equivalence class and log the
# decision.  Opening with 'w' truncates the previous log; its contents were
# already loaded into known_equivs.
with open('output/eqclasses.txt', 'w') as f:
    total = len(expr_asts)
    for i, s in enumerate(expr_asts):
        s_str = str(s)
        s_d = d_equiv[s_str][0]
        equiv_s = None
        # Indices of classes whose representative was already compared with s
        # and rejected, so no representative is tested twice.
        rejected = set()
        # True only when the previous run logged s as the start of a new class.
        cached_as_new = s_str in known_equivs and known_equivs[s_str] is None

        # First try the class suggested by the previous run, if any.
        if s_str in known_equivs:
            s_ke = known_equivs[s_str]
            if s_ke in equiv_lookup:
                class_idx = equiv_lookup[s_ke]
                asts = equiv_classes[class_idx][1]
                if equiv(s, asts[0]):
                    equiv_s = asts[0]
                    asts.append(s)
                    equiv_lookup[s_str] = class_idx
                else:
                    # Report the representative that failed the check
                    # (equiv_s is still None at this point).
                    print('bad equivalence: %s != %s!' % (s, asts[0]))
                    rejected.add(class_idx)

        # Otherwise scan the already classified expressions that share s's
        # discriminator signature.
        if equiv_s is None:
            for t in best_d_vals[s_d]:
                t_str = str(t)
                if t_str not in equiv_lookup:
                    continue
                class_idx = equiv_lookup[t_str]
                if class_idx in rejected:
                    continue
                asts = equiv_classes[class_idx][1]
                if equiv(s, asts[0]):
                    equiv_s = asts[0]
                    asts.append(s)
                    equiv_lookup[s_str] = class_idx
                    if cached_as_new:
                        # The previous run read s as '-> new', yet a match
                        # exists now.
                        print('bad inequivalence: %s == %s' % (s, equiv_s))
                    break
                rejected.add(class_idx)

        if equiv_s is not None:
            res = '%s == %s' % (s, equiv_s)
        else:
            res = '%s -> new' % (s,)
            equiv_lookup[s_str] = len(equiv_classes)
            # The symbolic form is only needed for a class's first member.
            equiv_classes.append((s.to_sym(), [s]))
        f.write(res + '\n')
        # Flush per line so the log stays usable if the run is interrupted.
        f.flush()

        # Print a progress line each time another 10% boundary is crossed.
        pct = i * 100 // total
        next_pct = (i + 1) * 100 // total
        if pct // 10 < next_pct // 10:
            print('...%d%%' % (next_pct,))
print('%d equivalence classes total.' % (len(equiv_classes),))

...10%
...20%
...30%
...40%
