In [1]:
import os
import re
import time

In [2]:
def adjust_depth(line, d=0):
    """Return the parenthesis depth after scanning ``line``.

    Starts from depth ``d`` and adds one for every '(' and subtracts one
    for every ')' found in ``line``.
    """
    opened = line.count('(')
    closed = line.count(')')
    return d + opened - closed

def split_paren(tuples):
    """Split ``tuples`` on parentheses and newlines, dropping empty pieces.

    Pieces that are the empty string or a single space are discarded; every
    other piece is returned unchanged, in order of appearance.
    """
    # Raw string: '\(' and '\)' are not valid *string* escapes, so the
    # non-raw form triggers an invalid-escape warning on modern Python.
    pieces = re.split(r'\(|\)|\n', tuples)
    return [piece for piece in pieces if piece not in ('', ' ')]

def find_arguments(line):
    """Extract the argument list from a ``(synth-fun lemma ...`` line.

    The search for the opening parenthesis starts at index 2, which skips
    the parenthesis that opens the line itself. The balanced group that
    starts there is split into ``name sort`` pairs and returned as a dict
    mapping each pair's first field to its second.
    """
    start = line.find('(', 2)
    depth = 0
    length = 0
    # Walk forward until the group opened at `start` is closed again.
    for length, ch in enumerate(line[start:], start=1):
        if ch == '(':
            depth += 1
        elif ch == ')':
            depth -= 1
        if depth == 0:
            break
    group = line[start : start + length]
    pairs = [chunk.split(' ') for chunk in split_paren(group)]
    return {pair[0]: pair[1] for pair in pairs}

def find_symbols(line):
    """Read the ``name sort`` pairs from the grammar's symbol line.

    Everything from the first '(' onward is split on parentheses; each
    remaining piece is split on single spaces and stored as
    ``{first field: second field}``.
    """
    tail = line[line.find('('):]
    pairs = [chunk.split(' ') for chunk in split_paren(tail)]
    symbols = {}
    for fields in pairs:
        symbols[fields[0]] = fields[1]
    return symbols

In [14]:
def read_grammar(filename):
    '''
    Read the ``(synth-fun lemma ...`` grammar from ``filename`` into a dict.

    Formatting assumptions about the input file:
    - The line containing ``(synth-fun lemma`` also holds the arguments.
    - The next non-blank line holds the symbols.
    - No two rules on the same line.
    - No symbol (should yet) have two rules.

    Returns a dict that, once the header line has been found, contains:
    - 'arguments': result of find_arguments on the header line
    - 'symbols':   result of find_symbols on the following non-blank line
    - 'rules':     list of rule strings, tokens joined by single spaces

    The dict is also printed before being returned.
    '''
    with open(filename) as f:
        d = 0               # running parenthesis depth across the lines read
        lem_symb = False    # True while the next non-blank line is the symbol line
        lem_rule = False    # True while rule lines are being collected
        G = {}
        R = []              # tokens of the rule currently being accumulated
        for line in f:
            if line.isspace():
                continue
            elif '(synth-fun lemma' in line:
                d = adjust_depth(line, d)
                G['arguments'] = find_arguments(line)
                lem_symb = True
            elif lem_symb:
                d = adjust_depth(line, d)
                # Depth after the symbol line; rule boundaries are measured
                # relative to this value.
                d_rules = d
                lem_symb = False
                G['symbols'] = find_symbols(line)
                lem_rule = True
                G['rules'] = []
            elif lem_rule:
                d = adjust_depth(line, d)
                # re has more! use assumption from grammar formatting
                R.extend([s for s in re.split(' |\n', line) if s != ''])
                # Depth is back to at most one level above d_rules: the
                # accumulated tokens are treated as one complete rule.
                if d <= d_rules+1 and len(R) > 0:
                    # Drop one leading '(' when the first token starts with
                    # two of them (presumably the paren opening the rule
                    # list sharing a line with the first rule -- confirm).
                    if len(R[0]) > 1 and R[0][1] == '(':
                        R[0] = R[0][1:]
                    s = ' '.join(R)
                    # Keep the rule only if it has characters other than
                    # parentheses (e.g. skips a line holding only '))').
                    if set(s) - {'(', ')'}:
                        G['rules'].append(s)
                    R = []
                # Depth fell back to the symbol-line level: stop reading.
                if d <= d_rules:
                    lem_rule = False
                    break
            else:
                # Lines before the synth-fun header are ignored.
                continue
    print(G)
    return G

In [21]:
def var_appears(line, var):
    """Return True if ``var`` occurs in ``line`` as a whole token.

    Tokens are obtained by splitting ``line`` on parentheses, newlines and
    single spaces, so ``'x'`` matches in ``'(lft x)'`` but not in ``'(f xy)'``.
    """
    # Raw string avoids the invalid '\(' / '\)' string escapes that the
    # non-raw literal relied on (a warning on modern Python).
    return var in re.split(r'\(|\)|\n| ', line)

def process_rules(rules):
    """Split a rule body into its top-level alternatives.

    Space-separated tokens are accumulated until the running parenthesis
    depth is zero or less. The accumulated text is then trimmed from the
    right while it ends in whitespace or contains more ')' than '(', and
    stored as one alternative. Whitespace-only accumulations are skipped.
    """
    alternatives = []
    pending = ''
    depth = 0
    for token in rules.split(' '):
        depth = adjust_depth(token, depth)
        pending = pending + token + ' '
        if depth > 0:
            # Still inside an open parenthesis: keep accumulating.
            continue
        if not pending.isspace():
            while adjust_depth(pending) < 0 or pending[-1].isspace():
                pending = pending[:-1]
            alternatives.append(pending)
        pending = ''
    return alternatives

'''
Loc1 x y becomes... # takes n many bools
(and (=> b1 (and x (not (or b2 b3 b4))))
     (=> b2 (and y (not (or b1 b3 b4))))
     (=> b3 (and (lft x) (not (or b1 b2 b4))))
     (=> b4 (and (rght x) (not (or b1 b2 b3))))
     (or b1 b2 b3 b4)
)
'''
def rule_options_bool(rules, i=1):
    """Encode a choice between ``rules`` using boolean selectors b<i>, b<i+1>, ...

    - more than two rules: a conjunction of ``(=> bK (and rule (not (or others))))``
      clauses followed by ``(or b...)`` over all selectors;
    - exactly two rules: ``(ite b<i> rule0 rule1)``;
    - otherwise: the first rule on its own.
    """
    n = len(rules)
    if n > 2:
        flags = ['b{}'.format(i + k) for k in range(n)]
        clauses = []
        for j, rule in enumerate(rules):
            # Every selector except the one guarding this rule.
            others = ' '.join(flags[:j] + flags[j + 1:])
            clauses.append(
                ' (=> {} (and {} (not (or {}))))'.format(flags[j], rule, others)
            )
        return '(and' + ''.join(clauses) + ' (or {}) )'.format(' '.join(flags))
    if n == 2:
        return '(ite b{} {} {})'.format(i, rules[0], rules[1])
    return rules[0]

'''
(ite b1 x (ite b2 y (ite b3 (lft x) (rght x))))
'''
def rule_options(rules, p=1):
    """Encode a choice between ``rules`` as a right-nested ``ite`` chain.

    Selectors are named p<p>, p<p+1>, ...; ``n`` alternatives consume
    ``n - 1`` selectors. A single alternative is returned unchanged.

    Parameters:
        rules: non-empty sequence of alternative strings.
        p: index of the first selector to use.

    Returns:
        (replacement, next_p): the encoded string and the index of the
        next unused selector.

    Raises:
        ValueError: if ``rules`` is empty (previously this surfaced as an
        UnboundLocalError because no branch assigned the result).
    """
    n = len(rules)
    if n == 0:
        raise ValueError('rule_options requires at least one alternative')
    if n == 1:
        return rules[0], p
    # Spacing is kept exactly as before: each opener ends in a space and
    # the pieces are joined with a space, so alternatives after the first
    # are preceded by two spaces.
    openers = ' '.join(
        '(ite p{} {} '.format(p + j, rule) for j, rule in enumerate(rules[:-1])
    )
    replacement = openers + ' ' + rules[-1] + ')' * (n - 1)
    return replacement, p + n - 1

# Unpack grammar dict to construct raw lemma using terminal symbols
def intuit(H, terminal):
    """Expand 'Start' into a raw lemma string using the parsed grammar ``H``.

    Symbols are taken from a work set that initially holds 'Start'. Every
    occurrence of the current symbol in the lemma is replaced, one at a
    time, by the ``rule_options`` encoding of that symbol's rule, and the
    symbol's 'nonterminal' set is then added to the work set. ``terminal``
    is accepted but not used here.

    Returns ``(lemma, p)`` where ``p`` is the next unused selector index.
    """
    expansion = 'Start'
    pending = {'Start'}
    p = 1
    # Ordering based on distance from terminal symbols would be optimal.
    # If the grammar is finite, this loop halts.
    # Need treatment of "partially terminal" symbols?
    while pending:
        current = pending.pop()
        while current in expansion:
            # Context-free: each occurrence gets its own fresh selectors.
            encoded, p = rule_options(H[current]['rule'], p)
            expansion = expansion.replace(current, encoded, 1)
        pending |= H[current]['nonterminal']
    return expansion, p

def parse(G):
    """Build a per-symbol table from the grammar dict ``G``.

    For each rule string the symbol name is the first space-delimited
    field without its leading character, and the rule body is what follows
    the first '(' of the remainder, minus the last two characters.

    Each table entry holds:
    - 'rule':        alternatives from ``process_rules`` on the body
    - 'dependent':   grammar symbols appearing as tokens in the body
    - 'variables':   lemma arguments appearing as tokens in the body
    - 'nonterminal': dependents that are not terminal symbols
    - 'terminal':    the remaining dependents

    Returns ``(table, terminal)`` where ``terminal`` is the set of symbols
    whose body mentions no grammar symbol.
    """
    terminal = set()
    H = {}
    for raw_rule in G['rules']:
        fields = raw_rule.split(' ', 1)
        name = fields[0][1:]
        body = fields[-1].split('(', 1)[-1][:-2]
        entry = {
            'rule': process_rules(body),
            'dependent': {s for s in G['symbols'] if var_appears(body, s)},
            'variables': [v for v in G['arguments'] if var_appears(body, v)],
        }
        H[name] = entry
        if not entry['dependent']:
            terminal.add(name)
    for entry in H.values():
        entry['nonterminal'] = entry['dependent'] - terminal
        entry['terminal'] = entry['dependent'] - entry['nonterminal']
    return H, terminal

def process_terminals(H, terminal, lem):
    """Turn each terminal-symbol occurrence in ``lem`` into a numbered call.

    For every terminal symbol, the lemma is split on whitespace and each
    word containing the symbol has it replaced by ``(<symbol><k> <vars>)``,
    where ``k`` counts matching words from 1 and ``<vars>`` are the symbol's
    'variables' from ``H``. An underscore separates the symbol from ``k``
    when the symbol name ends in a digit. Words are re-joined with single
    spaces after each symbol is processed.
    """
    for symb in dict.fromkeys(terminal):
        arg_suffix = ''.join(' ' + var for var in H[symb]['variables'])
        words = lem.split()
        occurrence = 0
        for pos, word in enumerate(words):
            if symb not in word:
                continue
            occurrence += 1
            sep = '_' if symb[-1].isdigit() else ''
            call = '({}{}{}{})'.format(symb, sep, occurrence, arg_suffix)
            words[pos] = word.replace(symb, call)
        lem = ' '.join(words)
    return lem

def synthesize_lemma(G):
    """Run parse -> intuit -> process_terminals on the grammar dict ``G``.

    Prints the parsed symbol table and returns
    ``(table, terminal, lemma, p)``.
    """
    table, terminal = parse(G)
    skeleton, p = intuit(table, terminal)
    lem = process_terminals(table, terminal, skeleton)
    print(table)
    return table, terminal, lem, p

In [22]:
def gen_function(arguments, n):
    """Build a right-nested ``ite`` chain choosing among ``arguments``.

    Selectors are named b<n>, b<n+1>, ...; the last argument is the final
    else-branch, so ``k`` arguments use ``k - 1`` selectors.
    """
    leading = arguments[:-1]
    openers = ''.join(
        '(ite b{} {} '.format(n + offset, arg) for offset, arg in enumerate(leading)
    )
    return openers + arguments[-1] + ')' * len(leading)

def format_lem(lem):
    """Insert line breaks and indentation into a flat lemma string.

    The lemma is split on single spaces. A stack records, for each
    ``(=>``, ``(ite``, ``(and``, ``(or`` or ``(<=`` opener seen, the depth
    after that word and an indentation column. Whenever the depth after a
    word is at or below the depth on top of the stack, deeper stack entries
    are discarded and a newline plus indentation is appended to the word.
    """
    # could have lem written as a nested data structure to avoid need for this
    openers = {'(=>', '(ite', '(and', '(or', '(<='}
    tokens = lem.split(' ')
    depth = 0
    depth_stack = [0]
    indent_stack = [0]
    column = 0
    for idx, token in enumerate(tokens):
        depth = adjust_depth(token, depth)
        column += len(token) + 1
        if depth <= depth_stack[-1]:
            # Unwind to the opener this word belongs to.
            while depth < depth_stack[-1]:
                depth_stack.pop()
                indent_stack.pop()
            tokens[idx] = token + '\n' + ' ' * (indent_stack[-1] - 1)
            column = 0
        if token in openers:
            depth_stack.append(depth)
            indent_stack.append(indent_stack[-1] + column)
            column = 0
    return ' '.join(tokens)

def scan_lemma(lem, terminal, G, H, p, indent):
    """Produce the text pieces of the synthesized lemma file.

    Returns ``(const, flow, fun, lemma)``:
    - const: ``declare-const`` lines for the b-selectors used by the
      generated terminal functions;
    - flow:  ``declare-const`` lines for p1 .. p<p-1>;
    - fun:   one ``define-fun`` per occurrence of each terminal symbol in
      ``lem``, with a body built by ``gen_function``;
    - lemma: the ``define-fun lemma`` text, passed through ``format_lem``
      when ``indent`` is true.
    """
    occurrences = {symbol: lem.count(symbol) for symbol in terminal}
    fun = [''] * sum(occurrences.values())
    flow = ['(declare-const p{} Bool)'.format(k) for k in range(1, p)]
    slot = 0
    next_bool = 1
    for symbol in terminal:
        for instance in range(1, occurrences[symbol] + 1):
            # Names ending in a digit get an underscore before the counter.
            if symbol[-1].isdigit():
                suffix = ''.join(['_', str(instance)])
            else:
                suffix = instance
            params = ' '.join(
                ['({} {})'.format(v, G['arguments'][v]) for v in H[symbol]['variables']]
            )
            fun[slot] = '(define-fun {}{} ({}) {}\n{}\n)'.format(
                symbol,
                suffix,
                params,
                G['symbols'][symbol],
                gen_function(H[symbol]['rule'], next_bool),
            )
            slot += 1
            # Each generated function consumes one selector per extra alternative.
            next_bool += len(H[symbol]['rule']) - 1
    const = ['(declare-const b{} Bool)'.format(k) for k in range(1, next_bool)]
    signature = ' '.join(
        ['({} {})'.format(name, sort) for name, sort in G['arguments'].items()]
    )
    body = format_lem(lem) if indent else lem
    lemma = '(define-fun lemma ({}) Bool\n{})'.format(signature, body)
    return const, flow, fun, lemma

In [23]:
def write_lemma(filename, const, flow, fun, lemma):
    """Write the synthesized lemma next to ``filename``.

    The output path is ``filename`` with its extension (if any) replaced by
    ``_syn.txt``; e.g. ``lem_test.txt`` -> ``lem_test_syn.txt``. Each of
    ``const``, ``flow`` and ``fun`` is written one entry per line, with a
    blank line after every non-empty section, followed by ``lemma``.

    Parameters:
        filename: path of the grammar file the lemma was derived from.
        const, flow, fun: lists of strings to emit, in that order.
        lemma: final ``define-fun`` text, written without a trailing newline.
    """
    # splitext handles extensions of any length (and none at all), unlike
    # slicing off a fixed four characters.
    stem, _ext = os.path.splitext(filename)
    new_filename = stem + '_syn.txt'
    with open(new_filename, 'w') as out:
        for section in (const, flow, fun):
            for entry in section:
                out.write(entry)
                out.write('\n')
            if section:
                out.write('\n')
        out.write(lemma)

def synthesize(filename, indent=True):
    """Read the grammar in ``filename`` and write its synthesized lemma file.

    Chains read_grammar -> synthesize_lemma -> scan_lemma -> write_lemma.
    ``indent`` controls whether the lemma body is pretty-printed.
    """
    grammar = read_grammar(filename)
    table, terminal, raw_lemma, next_p = synthesize_lemma(grammar)
    pieces = scan_lemma(raw_lemma, terminal, grammar, table, next_p, indent)
    const, flow, fun, lemma = pieces
    write_lemma(filename, const, flow, fun, lemma)

-------------------

In [24]:
# Grammar files to process, in order. Add 'lem_test2.txt' to include the
# second test grammar.
for filename in ['lem_test.txt']:
    # perf_counter is a monotonic high-resolution clock, so the elapsed
    # time cannot be skewed by system clock adjustments.
    start = time.perf_counter()
    synthesize(filename)
    print('runtime: {:.4f}s'.format(time.perf_counter() - start))

{'arguments': {'x': 'Int', 'y': 'Int'}, 'symbols': {'Start': 'Bool', 'B1': 'Bool', 'B2': 'Bool', 'B3': 'Bool', 'Loc': 'Int'}, 'rules': ['(Start Bool ( (=> B1 (and B2 B3))))', '(B1 Bool ((member Loc (hbst Loc))))', '(B2 Bool ((<= (key Loc) (maxr Loc))))', '(B3 Bool ((<= (minr Loc) (key Loc))))', '(Loc Int (x y))']}
{'Start': {'rule': ['(=> B1 (and B2 B3))'], 'dependent': {'B2', 'B1', 'B3'}, 'variables': [], 'nonterminal': {'B2', 'B1', 'B3'}, 'terminal': set()}, 'B1': {'rule': ['(member Loc (hbst Loc))'], 'dependent': {'Loc'}, 'variables': [], 'nonterminal': set(), 'terminal': {'Loc'}}, 'B2': {'rule': ['(<= (key Loc) (maxr Loc))'], 'dependent': {'Loc'}, 'variables': [], 'nonterminal': set(), 'terminal': {'Loc'}}, 'B3': {'rule': ['(<= (minr Loc) (key Loc))'], 'dependent': {'Loc'}, 'variables': [], 'nonterminal': set(), 'terminal': {'Loc'}}, 'Loc': {'rule': ['x', 'y'], 'dependent': set(), 'variables': ['x', 'y'], 'nonterminal': set(), 'terminal': set()}}
runtime: 0.0017s


Loc x y z
yields (ite b1 x (ite b2 y z))
#or...
(and (=> b1 x)
     (=> b2 y)
     (=> b3 z)
     (xor b1 b2 b3))

Loc -> x y (lft x) (rght x)

Loc1 x y becomes... # takes n many bools
(and (=> b1 (and x (not (or b2 b3 b4))))
     (=> b2 (and y (not (or b1 b3 b4))))
     (=> b3 (and (lft x) (not (or b1 b2 b4))))
     (=> b4 (and (rght x) (not (or b1 b2 b3))))
     (or b1 b2 b3 b4)
)

OR becomes...
(ite b1 x (ite b2 y (ite b3 (lft x) (rght x)))) # takes n-1 many bools

OR becomes...
(ite b1 (ite b2 x y) (ite b3 (lft x) (rght x))) # nesting depth log_2(n); takes log_2(n) many bools only if the inner selector is shared (b2 in place of b3)

In [5]:
# Scratch check: dict keys of mixed types can be collected into a set.
dic = dict([(1, 2), ('3', '4')])
print(set(dic))

{1, '3'}
