In [None]:
%load_ext tikzmagic

In [None]:
import tfo_parser_ulambda as ul
import IPython.display
import copy

# Set to False to skip the demo calls guarded by `if g_test:` in the cells below.
g_test = True
# Single shared parser instance from tfo_parser_ulambda, used by f_expx.
g_parser = ul.parser()

def f_expx(txt):
    """Parse lambda-calculus source *txt* into an expression record.

    Returns a dict with:
      'txt'    -- the original source text,
      'parsed' -- the raw parse tree returned by g_parser.parse(txt),
      'tree'   -- a simplified nested-dict tree built from 'parsed'.

    Every tree node is a dict with an integer 'id' (assigned in pre-order,
    starting at 0) and a 'type'. Interior nodes have a 'kids' list and type
    'abstr' / 'appl' (renamed from the grammar's 'abstraction' /
    'application'); any other rule name is kept unchanged. Leaves have type
    'var' and a 'name'. The wrapper rules 'variable' and 'term' are collapsed
    into their first child.
    """
    transl = {'application': 'appl', 'abstraction': 'abstr'}
    tree_cls = ul.lark.tree.Tree
    # One-element list so the nested function can update it (no `nonlocal`
    # in Python 2).
    id_cnt = [0]

    def next_id():
        id_cnt[0] += 1
        return id_cnt[0] - 1

    def recurse(node):
        if not isinstance(node, tree_cls):
            # Not a lark Tree (presumably a Token): its text is the variable name.
            return {'id': next_id(), 'type': 'var', 'name': str(node)}
        if node.data in ('variable', 'term'):
            return recurse(node.children[0])
        # Take the id before visiting the children so ids stay in pre-order.
        node_id = next_id()
        return {'id': node_id,
                'type': transl.get(node.data, node.data),
                'kids': [recurse(x) for x in node.children]}

    parsed = g_parser.parse(txt)
    return {'txt': txt, 'parsed': parsed, 'tree': recurse(parsed)}

def f_pretty(expx):
    """Return the text rendering produced by ``.pretty()`` on expx['parsed']."""
    parse_tree = expx['parsed']
    return parse_tree.pretty()

def f_list_flatten(lists):
    """Concatenate a sequence of lists into one new list.

    Any number of inputs is accepted; an empty sequence yields [].
    """
    return [item for sublist in lists for item in sublist]


def f_latex_std_abstr(children_out):
    """Standard notation for an abstraction: ( lambda binder . body ) tokens."""
    return f_list_flatten((['(', '\\lambda'], [children_out[0]], ['.'], children_out[1:], [')']))


def f_latex_std_appl(children_out):
    """Standard notation for an application: operands wrapped in [ ]."""
    return f_list_flatten((['['], children_out, [']']))


def f_latex_2d_abstr(children_out):
    """2-D notation: an abstraction is drawn as a fraction, binder over body."""
    return f_list_flatten((['\\dfrac{'], [children_out[0]], ['}{'], children_out[1:], ['}']))


def f_latex_2d_appl(children_out):
    """2-D notation: an application is plain juxtaposition (tokens returned as-is)."""
    return children_out


# Empty style: f_latex falls back to the standard renderers above.
g_latex_style_std = {}
g_latex_style_2d = {'abstr': f_latex_2d_abstr, 'appl': f_latex_2d_appl}

def f_latex(expx, style=None, key='tree'):
    """Render an expression tree as a space-separated LaTeX token string.

    expx  -- an expression record (see f_expx) or, when key is None, a tree
             node itself.
    style -- optional dict with 'abstr' and/or 'appl' entries, each a
             function taking the flattened list of rendered child tokens and
             returning a token list. Missing entries (or style=None) fall back
             to f_latex_std_abstr / f_latex_std_appl.
    key   -- entry of expx holding the tree (e.g. 'tree', or 'canon' once
             f_cached_canon has run), or None to treat expx as the root node.

    Raises ValueError for an interior node whose type is neither 'abstr'
    nor 'appl'.
    """
    style = {} if style is None else style
    func_abstr = style.get('abstr', f_latex_std_abstr)
    func_appl = style.get('appl', f_latex_std_appl)

    def recurse(node):
        if node['type'] == 'var':
            return [node['name']]
        children_out = f_list_flatten([recurse(x) for x in node['kids']])
        if node['type'] == 'abstr':
            return func_abstr(children_out)
        if node['type'] == 'appl':
            return func_appl(children_out)
        # Previously this fell through to None and failed later with an
        # unrelated TypeError; fail here with a clear message instead.
        raise ValueError('f_latex: unsupported node type {!r}'.format(node['type']))

    root = expx if key is None else expx[key]
    return ' '.join(recurse(root))

def f_disp_latex_txt(tex):
    """Show the LaTeX string *tex* as inline math in the notebook output."""
    rendered = IPython.display.Latex('${}$'.format(tex))
    IPython.display.display(rendered)
    
def f_disp_latex(expx, style=None, key='tree'):
    """Render *expx* with f_latex and display the result as LaTeX math.

    Arguments have the same meaning as for f_latex; style=None selects the
    standard notation.
    """
    # Pass a fresh dict rather than sharing a mutable default between calls.
    f_disp_latex_txt(f_latex(expx, {} if style is None else style, key))

def test_latex():
    """Display sample expressions in the standard and 2-D (fraction) notations."""
    # Go through f_disp_latex_txt instead of the bare `display`, which is only
    # defined implicitly inside a notebook namespace.
    f_disp_latex_txt(f_latex(f_expx('[x.[y z]]')))
    f_disp_latex_txt(f_latex(f_expx('[x.[x z]]'), g_latex_style_2d))
    f_disp_latex(f_expx('[x.[y z]]'))
    f_disp_latex(f_expx('[x.[y z]]'), g_latex_style_2d)

if g_test:
    test_latex()

In [None]:
def f_tikz_tree(expx):
    """Return TikZ code that draws the expression tree in expx['tree'].

    The result is a single ``\\node ... ;`` statement suitable for the
    ``%%tikz`` cell magic. Abstractions are labelled with their bound variable
    and have the body as their child. Applications get an empty label and one
    child per operand. Variables are leaves labelled with their name.
    """
    # f_expx emits 'abstr'/'appl'; the long grammar names are accepted too.
    # Testing only the long names made the abstraction branch unreachable.
    abstr_types = ('abstr', 'abstraction')
    type_labels = {'appl': '', 'application': ''}

    def make_node(label):
        return 'node{{{}}}'.format(label)

    def make_kid_nodes(kids):
        return ' '.join(['child{{{}}}'.format(kid) for kid in kids])

    def recurse(node, pars=('{', '}')):
        # Subtrees are wrapped in braces; the root is wrapped in '\' ... ';'.
        if node['type'] == 'var':
            return make_node(node['name'])
        children_out = [recurse(x) for x in node['kids']]
        if node['type'] in abstr_types:
            head, kids = children_out[0], children_out[1:]
        else:
            head = make_node(type_labels.get(node['type'], node['type']))
            kids = children_out
        return ' '.join([pars[0] + head, make_kid_nodes(kids), pars[1]])

    return recurse(expx['tree'], ('\\', ';'))

def f_disp_tikz_tree(expx, style=None):
    """Draw the expression tree of *expx* by running a ``%%tikz`` cell.

    Requires a running IPython session with the tikzmagic extension loaded
    (``%load_ext tikzmagic``). *style* is accepted for signature parity with
    f_disp_latex but is currently unused.

    Raises RuntimeError when called outside IPython.
    """
    shell = IPython.get_ipython()
    if shell is None:
        # get_ipython() returns None outside a session; fail with a clear
        # message instead of an AttributeError on None.
        raise RuntimeError('f_disp_tikz_tree requires a running IPython session')
    shell.run_cell('%%tikz\n' + f_tikz_tree(expx))

def test_tree():
    """Print the TikZ code generated for a few sample expressions."""
    samples = ['[x.[y z]]', '[x y]', '[x. y]', '[y z]', '[x. [y z]]']
    for sample in samples:
        print(f_tikz_tree(f_expx(sample)))

if g_test:
    test_tree()
    f_disp_tikz_tree(f_expx('[x. [y [z. [z x]]]]'))

In [None]:
def f_cached_generic(expx, key, func):
    """Memoize func(expx) inside the dict expx itself.

    Returns expx[key]. If the key is absent, func(expx) is computed first and
    stored under key.
    """
    try:
        return expx[key]
    except KeyError:
        value = expx[key] = func(expx)
        return value

def f_canon(expx):
    """Return a copy of expx['tree'] with every variable renamed canonically.

    The tree is walked depth-first and one counter hands out names '0', '1',
    and so on. Each 'abstr' node gives its binder (first kid) a fresh name
    that shadows outer bindings inside its kids. A free variable gets a fresh
    name the first time it is met and keeps it afterwards. The result
    therefore does not depend on how bound or free variables were spelled.
    Nodes are shallow-copied as they are visited, so expx['tree'] is not
    modified. Interior nodes of any other type are copied but not descended
    into.

    NOTE(review): textbook alpha-equivalence keeps free names fixed, but this
    also renumbers free variables; confirm that is intended.
    """
    counter = [0]  # mutable cell so the nested helper can advance it

    def fresh():
        label = str(counter[0])
        counter[0] += 1
        return label

    def rename(node, bound, free):
        if 'kids' not in node:
            name = node['name']
            if name in bound:
                node['name'] = bound[name]
            elif name in free:
                node['name'] = free[name]
            else:
                free[name] = fresh()
                node['name'] = free[name]
            return node
        kind = node['type']
        if kind == 'abstr':
            # New scope: the binder's fresh name shadows any outer binding.
            scope = copy.copy(bound)
            scope[node['kids'][0]['name']] = fresh()
        elif kind == 'appl':
            scope = bound
        else:
            return node
        node['kids'] = [rename(copy.copy(kid), scope, free) for kid in node['kids']]
        return node

    return rename(copy.copy(expx['tree']), {}, {})


def f_cached_canon(expx):
    """Return f_canon(expx), computed once and memoized under expx['canon']."""
    return f_cached_generic(expx, 'canon', f_canon)

def f_alpha_eq(expx1, expx2):
    """Return True when expx1 and expx2 have identical canonical trees.

    Both canonical forms come from f_cached_canon, so each record gains a
    cached 'canon' entry as a side effect. The trees are compared node by
    node: the 'type' must match, then either the kid lists (same length,
    compared pairwise) or, for leaves, the 'name'.

    NOTE(review): f_canon also renumbers free variables, so e.g. 'x' and 'y'
    compare equal here. Textbook alpha-equivalence treats them as different;
    confirm which notion is intended.
    """
    def same_shape(left, right):
        if left['type'] != right['type']:
            return False
        if 'kids' not in left:
            return left['name'] == right['name']
        left_kids, right_kids = left['kids'], right['kids']
        if len(left_kids) != len(right_kids):
            return False
        return all([same_shape(a, b) for a, b in zip(left_kids, right_kids)])

    canon1 = f_cached_canon(expx1)
    canon2 = f_cached_canon(expx2)
    return same_shape(canon1, canon2)

def f_vars(expx):
    """Collect the variable names of the expression tree in expx['tree'].

    Returns a dict:
      'all'  -- every variable name that occurs, binders included,
      'free' -- the names that occur free.

    Raises ValueError for an interior node whose type is neither 'abstr'
    nor 'appl'.
    """
    def recurse(node):
        if 'kids' not in node:
            name = node['name']
            # {name}, not set(name): the latter splits a multi-character name
            # into letters. The leaf also belongs to 'all', which was
            # previously always empty.
            return ({name}, {name})
        (l_all, l_free), (r_all, r_free) = [recurse(x) for x in node['kids']]
        if node['type'] == 'abstr':
            # The left kid is the binder; its name is not free in the body.
            return (l_all | r_all, r_free - l_free)
        if node['type'] == 'appl':
            return (l_all | r_all, l_free | r_free)
        raise ValueError('f_vars: unsupported node type {!r}'.format(node['type']))

    v_all, v_free = recurse(expx['tree'])
    return {'all': v_all, 'free': v_free}


def f_cached_vars(expx):
    """Return f_vars(expx), computed once and memoized under expx['vars']."""
    return f_cached_generic(expx, 'vars', f_vars)

def f_closed(expx):
    """Return True when *expx* has no free variables (uses the cached f_vars result)."""
    free_vars = f_cached_vars(expx)['free']
    return not free_vars

test_txts = [
    'x', '[x. x]', '[x x]', '[x y]', '[x. [x y]]',
    '[[x. [x y]] x]', '[[x. x] x]', '[y. [x. [x y]]]',
]

def test_vars():
    """Print the variable sets computed by f_vars for each sample text."""
    for sample in test_txts:
        print('{} : {}'.format(sample, f_vars(f_expx(sample))))

if g_test:
    test_vars()
    
def test_canon():
    """Display each sample text next to its canonically renamed tree."""
    for sample in test_txts:
        canon_tex = f_latex(f_canon(f_expx(sample)), {}, None)
        f_disp_latex_txt('\\text{{{}}} : {}'.format(sample, canon_tex))

if g_test:
    test_canon()
    
def test_alpha_eq():
    """Print f_alpha_eq for a few pairs of sample expressions."""
    pairs = [
        ('x', 'y'),
        ('[x. [x y]]', '[y. [y z]]'),
        ('[x. y]', '[x y]'),
        ('[x y]', '[x x]'),
    ]
    for left, right in pairs:
        print(f_alpha_eq(f_expx(left), f_expx(right)))

if g_test:
    test_alpha_eq()