In [1]:
import sympy as sp
import functools, random, re

import synthast as ast

In [2]:
# Load one phase-2 example: echo the file's first line (the header comment
# holding the expected expression), then parse the file into an AST.
test_fn = '../../test-data/phase2/0308.txt'
with open(test_fn) as f:
    first_line = f.readline().strip()
print('parsing "%s":' % (first_line,))
expr = ast.parse_examples(test_fn)
print('  expr: %s' % (expr,))

parsing "# Multiply(Ite(Eq(x, 1), Multiply(x, x), Multiply(y, x)), Add(Add(z, z), 2))":
  expr: Multiply(Ite(Eq(x, 1), Multiply(x, x), Multiply(y, x)), Add(Add(z, z), 2))


In [3]:
# Convert the parsed AST to its symbolic form; as the last expression of the
# cell, the result is what the notebook displays.
expr.to_sym()

(2*z + 2)*Piecewise((x**2, Eq(x, 1)), (x*y, True))

In [4]:
# Pools of generated ASTs; both are replaced wholesale by every gen_all() call.
# expr_asts holds expression nodes (constants, variables, Add/Multiply/Ite).
expr_asts: "list[ast.AstNode]" = []
# bool_asts holds condition nodes (Lt/Eq/Not/And/Or).
bool_asts: "list[ast.AstNode]" = []

def gen_all():
    """Grow the global AST pools by one level of nesting.

    The new expression pool always starts with the six leaves (three
    constants, three variables).  Every ordered pair of *current* expressions
    then contributes an Add and a Multiply expression plus an Lt and an Eq
    condition.  Every current condition contributes its negation, an And/Or
    with every current condition, and an Ite over every ordered pair of
    current expressions.  The freshly built pools replace the old ones.
    """
    global expr_asts, bool_asts
    grown_exprs = [ast.c1, ast.c2, ast.c3, ast.x, ast.y, ast.z]
    grown_bools = []
    for lhs in expr_asts:
        for rhs in expr_asts:
            # nodes are built in the order Add, Multiply, Lt, Eq for each pair
            grown_exprs += [ast.Add(lhs, rhs), ast.Multiply(lhs, rhs)]
            grown_bools += [ast.Lt(lhs, rhs), ast.Eq(lhs, rhs)]
    for cond in bool_asts:
        grown_bools.append(ast.Not(cond))
        for other in bool_asts:
            grown_bools += [ast.And(cond, other), ast.Or(cond, other)]
        grown_exprs += [ast.Ite(cond, then_e, else_e)
                        for then_e in expr_asts
                        for else_e in expr_asts]
    expr_asts, bool_asts = grown_exprs, grown_bools

# Three growth rounds: leaves only, then one operator level, then two.
for _round in range(3):
    gen_all()

# Put both pools into the ASTs' own sort order.
expr_asts.sort()
bool_asts.sort()

# Show the pool size, then its first ten and last five expressions.
print(len(expr_asts))
print(*expr_asts[:10], sep='\n')
print('...')
print(*expr_asts[-5:], sep='\n')


450222
1
2
3
x
y
z
Add(1, 1)
Add(1, 2)
Add(1, 3)
Add(1, x)
...
Ite(Eq(z, z), Multiply(z, z), Multiply(z, 2))
Ite(Eq(z, z), Multiply(z, z), Multiply(z, 3))
Ite(Eq(z, z), Multiply(z, z), Multiply(z, x))
Ite(Eq(z, z), Multiply(z, z), Multiply(z, y))
Ite(Eq(z, z), Multiply(z, z), Multiply(z, z))


In [5]:
# Candidate coordinates for discriminator points: 0, 1, the primes below 100,
# and the negation of every non-zero candidate.
nums = [0, 1, 2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59,
        61, 67, 71, 73, 79, 83, 89, 97]
nums += [-n for n in nums[1:]]
r = random.Random(53)

if False:
    # Disabled greedy random search for discriminator points (flip the
    # condition to re-run it).  A discriminator set is a tuple of value
    # triples, one ast.Env per triple; the search keeps whichever candidate
    # splits expr_asts into the most distinct value signatures.
    best_d = ()
    best_d_vals = {}

    # Round 0: pick the best single point out of ten random candidates.
    for i in range(10):
        # Wrap the triple in a 1-tuple so best_d has the same shape as the
        # hard-coded table below (a tuple of triples); a bare triple here
        # would make the later `ast.Env(*d) for d in best_d` unpack ints.
        d = ((r.choice(nums), r.choice(nums), r.choice(nums)),)
        vals = {}
        e = ast.Env(*(d[-1]))
        for s in expr_asts:
            v = (s.eval(e),)
            vals.setdefault(v, []).append(s)
        if len(vals) > len(best_d_vals):
            best_d = d
            best_d_vals = vals

    print("%r: %d" % (best_d, len(best_d_vals)))

    # Four more rounds: extend the current best set by one more point, again
    # choosing the best of ten random candidates per round.
    for i in range(4):
        next_best_d = ()
        next_best_d_vals = {}
        for j in range(10):
            d = best_d + ((r.choice(nums), r.choice(nums), r.choice(nums)),)
            vals = {}
            e = ast.Env(*(d[-1]))
            # Only the new point needs evaluating: extend each existing
            # signature with the value at that point.
            for v, ss in best_d_vals.items():
                for s in ss:
                    cv = v + (s.eval(e),)
                    vals.setdefault(cv, []).append(s)
            if len(vals) > len(next_best_d_vals):
                next_best_d = d
                next_best_d_vals = vals
        best_d = next_best_d
        best_d_vals = next_best_d_vals
        print('%d, %d: %r => %d' % (i, j, best_d, len(best_d_vals)))

    print('Discriminators %r:\n  yield %d classes' % (best_d, len(best_d_vals)))
    del best_d_vals

else:
    # Fixed discriminator points, one value triple per environment.
    best_d = ((-61, -13, 79), (-43, 67, -11), (67, -47, -43),
              (-5, -83, -83),
              (1, 11, 59), (13, 1, 71), (19, 73, 1),
              (2, 97, 23), (89, 2, 29), (79, 31, 2),
              (3, 7, 83), (17, 3, 67), (23, 61, 3),
              (5, 7, 7), (17, 5, 17), (23, 23, 5),
              (13, 13, 13),
              )

# One evaluation environment per discriminator point.
d_envs = [ast.Env(*d) for d in best_d]

def at_boundary(i, total, divs):
    """Tell whether step ``i`` of ``total`` is the last step of its division.

    The range of ``total`` steps is cut into ``divs`` equal divisions; the
    result is True when step ``i + 1`` falls into a later division than
    step ``i`` does.
    """
    division_here = (i * divs) // total
    division_next = ((i + 1) * divs) // total
    return division_next > division_here

# State of the text progress indicator shared by progress_start/progress_next.
progress_count = 0  # steps reported so far
progress_total = 1  # steps expected in the current run

def progress_start(total):
    """Reset the progress indicator for a new run of ``total`` steps."""
    global progress_count, progress_total
    progress_count, progress_total = 0, total

def progress_next():
    """Report one completed step of the run opened by ``progress_start``.

    Prints nothing for most steps.  At each 1% boundary it prints a '.', and
    at each 10% boundary it prints the percentage reached instead.  A newline
    is printed once the final step has been reported.

    Calling this more often than the announced total trips the assertion
    (or is silently ignored when assertions are disabled).
    """
    global progress_count, progress_total
    assert progress_count < progress_total
    if progress_count >= progress_total:
        return
    if at_boundary(progress_count, progress_total, 100):
        if at_boundary(progress_count, progress_total, 10):
            # Scale by the total given to progress_start, not by the size of
            # any particular list, so runs of any length show correct
            # percentages.
            print('%d%%' % ((progress_count + 1) * 100 // progress_total,), end='')
        else:
            print('.', end='')
    progress_count += 1
    if progress_count >= progress_total:
        print()


In [6]:
# Sanity check: whenever canonicalisation changes an expression, the result
# must still be provably equal to the original.
progress_start(len(expr_asts))
for s in expr_asts:
    canon_ast = s.canonical()
    if canon_ast is not s:
        # sp.ask answers None when it cannot decide, so anything that is not
        # a definite proof of equality gets reported here.
        if not sp.ask(sp.Q.eq(s.to_sym(), canon_ast.to_sym())):
            print('Not equivalent: %s <-> %s' % (s, canon_ast))
    progress_next()


.........10%.........20%.........30%.........40%.........50%.........60%.........70%.........80%.........90%.........100%


In [7]:
class EquivalenceClass:
    """One equivalence class: an identifying key plus the nodes placed in it."""

    # key identifying this class; Equivalence.idents maps it to the class index
    ident: object
    # representative member; None until the first member has been added
    rep: "ast.AstNode | None"
    # every node added so far, in insertion order
    members: "list[ast.AstNode]"

    def __init__(self, ident):
        """Create an empty class identified by ``ident``."""
        self.ident = ident
        self.rep = None
        self.members = []

    def ensure(self, a: ast.AstNode):
        """Add ``a`` to the class; the first node added becomes ``rep``.

        ``a`` must already be canonical (checked by assertion).  The node is
        appended unconditionally -- no duplicate check is made here.
        """
        assert a is a.canonical()
        if self.rep is None:
            self.rep = a
        self.members.append(a)


class Equivalence:
    """Base registry that partitions AST nodes into equivalence classes.

    Subclasses define what "equivalent" means by implementing
    ``_compute_ident``: two canonical nodes are placed in the same class
    exactly when that hook returns equal idents for them.
    """

    # every class created so far; a class's position in this list is its index
    equivs: "list[EquivalenceClass]"
    # class ident -> index into ``equivs``
    idents: "dict[object, int]"
    # str(canonical node) -> index into ``equivs``
    known: "dict[str, int]"

    def __init__(self):
        self.equivs = []
        self.idents = {}
        self.known = {}

    def _compute_ident(self, a: "ast.AstNode", hints: "list[ast.AstNode]"):
        """Return the hashable class ident for the canonical node ``a``.

        ``hints`` lists nodes the caller believes are equivalent to ``a``.

        Raises:
            NotImplementedError: always, in this base class; subclasses must
                override the hook.
        """
        # An explicit exception (rather than ``assert False``) keeps the guard
        # alive when Python runs with assertions disabled.
        raise NotImplementedError(
            '%s must implement _compute_ident' % (type(self).__name__,))

    def _arbitrary_new_ident(self):
        """Return an ident no existing class uses: the index the next class gets."""
        return len(self.equivs)

    def equiv(self, a: "ast.AstNode", hints: "list[ast.AstNode] | None" = None):
        """Return the equivalence class of ``a``, registering it if needed.

        Args:
            a: node to classify; it is canonicalised first.
            hints: optional nodes believed to be equivalent to ``a``; passed
                through to ``_compute_ident``.  Omitted means no hints.

        Returns:
            The EquivalenceClass that now contains the canonical form of ``a``.
        """
        # A fresh list per call: a shared default list would leak state
        # between calls if any hook ever mutated it.
        if hints is None:
            hints = []
        # canonicalize before doing any checks
        a_canon = a.canonical()
        a_s = str(a_canon)
        # early-out if we have already seen this (canonical) a
        if a_s in self.known:
            return self.equivs[self.known[a_s]]
        # otherwise compute its class ident
        a_i = self._compute_ident(a_canon, hints)
        if a_i in self.idents:
            # this class ident already known, return existing
            n = self.idents[a_i]
            a_c = self.equivs[n]
        else:
            # this class ident not known, make a new one; the index is taken
            # only now because the hook may itself have registered classes
            a_c = EquivalenceClass(a_i)
            n = len(self.equivs)
            self.idents[a_i] = n
            self.equivs.append(a_c)
        a_c.ensure(a_canon)
        self.known[a_s] = n
        return a_c

    def tidy(self):
        """Sort every class's members and make the smallest one its ``rep``."""
        for e in self.equivs:
            if not e.members:
                # nothing to pick a representative from
                continue
            e.members.sort()
            e.rep = e.members[0]


class DiscrimEquiv(Equivalence):
    """Equivalence by value signature over the discriminator environments."""

    def _compute_ident(self, a_canon: ast.AstNode, hints: [ast.AstNode]):
        """Return the tuple of ``a_canon``'s values, one per entry of ``d_envs``.

        ``hints`` is accepted for interface compatibility and ignored.
        """
        signature = [a_canon.eval(env) for env in d_envs]
        return tuple(signature)


class SymEquiv(Equivalence):
    """Equivalence backed by symbolic proof, pre-filtered by a DiscrimEquiv.

    Two nodes share a class only when either their operands are pairwise in
    the same classes or ``sp.ask`` confirms their symbolic forms are equal.
    The discriminator registry ``subeq`` narrows down which existing nodes
    are worth comparing against.
    """

    # cheap value-signature registry used to find comparison candidates
    subeq: DiscrimEquiv

    def __init__(self):
        """Create an empty registry with its own, empty DiscrimEquiv."""
        super().__init__()
        self.subeq = DiscrimEquiv()

    def _compute_ident(self, a_canon: ast.AstNode, hints: "list[ast.AstNode]"):
        """Return the ident of the class ``a_canon`` belongs to.

        Tried in order: (1) the caller's hints, each confirmed symbolically;
        (2) the discriminator class -- if ``a_canon`` is alone in it, it gets
        a new ident; (3) structural comparison against the other members of
        that discriminator class; (4) symbolic comparison against them;
        (5) otherwise a new ident.
        """
        a_sym = a_canon.to_sym()
        # check the hints first
        # NOTE(review): returning from this loop skips the self.subeq.equiv()
        # call below, so a node accepted via a hint is not registered in subeq.
        for h in hints:
            h_s = str(h.canonical())
            if h_s in self.known:
                c = self.equivs[self.known[h_s]]
                # symbolic equivalence check still required to confirm
                c_sym = c.rep.to_sym()
                # NOTE(review): sp.ask presumably yields None when undecided;
                # that is treated like False here -- confirm against SymPy docs
                if sp.ask(sp.Q.eq(a_sym, c_sym)):
                    return c.ident
                else:
                    print("bad equivalence hint: %s == %s" % (a_canon, h_s))
        # check discriminator equivalence
        # (this call also registers a_canon in subeq, so it shows up in
        # sub_c.members from here on)
        sub_c = self.subeq.equiv(a_canon)
        if len(sub_c.members) == 1:
            assert sub_c.members[0] is a_canon
            # discriminator equivalence not found, we are unique
            return self._arbitrary_new_ident()
        # slower checks -- first check structural equivalence
        # recursive self.equiv means sub_c.members might grow while we iterate
        # (hence the index-based loop that re-reads the length every pass)
        i = 0
        while i < len(sub_c.members):
            m = sub_c.members[i]
            i += 1
            # this expr will have been added to sub_c while being checked
            if m is a_canon:
                continue
            # if these are the same operation, and all operands are equivalent,
            # they are equivalent
            if a_canon.name == m.name:
                assert len(a_canon.p) == len(m.p)
                for a_p, m_p in zip(a_canon.p, m.p):
                    if self.equiv(a_p) is not self.equiv(m_p):
                        break
                else:
                    # definitely equivalent!
                    # (for/else: reached only when no operand pair differed)
                    m_s = str(m)
                    assert m_s in self.known
                    return self.equivs[self.known[m_s]].ident
        # if symbolic forms can be proven equal, they are equivalent
        for m in sub_c.members:
            if m is a_canon:
                continue
            m_sym = m.to_sym()
            if sp.ask(sp.Q.eq(a_sym, m_sym)):
                # NOTE(review): assumes m is already in self.known; the
                # structural loop above asserts that, this lookup does not
                return self.equivs[self.known[str(m)]].ident
        # otherwise, we must assume this is a new class
        return self._arbitrary_new_ident()

    def tidy(self):
        """Tidy the discriminator registry first, then this registry."""
        self.subeq.tidy()
        super().tidy()


# registry instance used by the cells below
symeq = SymEquiv()

In [8]:
if False:
    # Disabled warm start: replay a previously written class file so that the
    # recorded equalities can be offered to symeq as hints.
    print("Reading known equivalences")
    with open('output/eqclasses.txt', 'r') as f:
        lines = f.readlines()
    known_equivs = {}
    progress_start(len(lines))
    for l in lines:
        m = re.match('^(.*) -> new$', l)
        if m:
            # "<expr> -> new": the expression was the representative of its class
            s = m.group(1)
            s_ex = ast.parse_expr_str(s)
            known_equivs[s] = symeq.equiv(s_ex).rep
        else:
            m = re.match('^(.*) == (.*)$', l)
            if m:
                # "<expr> == <rep>": use the representative's class as a hint,
                # provided that representative has been read already
                s, t = m.groups()
                if t in known_equivs:
                    s_ex = ast.parse_expr_str(s)
                    known_equivs[s] = symeq.equiv(s_ex, hints=[known_equivs[t]]).rep
        progress_next()

# Classify every generated expression.
print("Building equivalence classes")
progress_start(len(expr_asts))
for s in expr_asts:
    c = symeq.equiv(s)
    progress_next()

# Sort each class and pick its smallest member as representative.
print("Canonicalizing equivalence classes")
symeq.tidy()

print('%d equivalence classes in subeq approximation' % (len(symeq.subeq.equivs),))
print('%d equivalence classes in symbolic equivalence' % (len(symeq.equivs),))

# Record, one line per expression, whether it is its class representative
# ("-> new") or which representative it equals ("== <rep>").
print("Writing known equivalences")
with open('output/eqclasses.txt', 'w') as f:
    progress_start(len(expr_asts))
    for s in expr_asts:
        c = symeq.equiv(s)
        res = '%s -> new' % (s,) if s == c.rep else '%s == %s' % (s, c.rep)
        f.write(res + '\n')
        f.flush()
        progress_next()


Building equivalence classes
.........10%.........20%.........30%.........40%.........50%.........60%.........70%.........80%.........90%.........100%
Canonicalizing equivalence classes
29635 equivalence classes in subeq approximation
36734 equivalence classes in symbolic equivalence
Writing known equivalences
.........10%.........20%.........30%.........40%.........50%.........60%.........70%.........80%.........90%.........100%


In [9]:
# Spot check: classify a hand-written expression, show the representative of
# its class, and show how the two compare in the ASTs' sort order.
s = ast.parse_expr_str('Add(1, Add(1, 3))')
c = symeq.equiv(s)
print(c.rep)
print(c.rep < s)
print(s < c.rep)


Add(2, 3)
True
False
