In [61]:
import numpy as np
from pathlib import Path
import json

from difflib import SequenceMatcher, Differ
import re

# Read every line of the extracted-theorems dump into `lines`; only the
# first line is decoded as JSON to build `theorems`.
with open('extracted_theorems.json', 'r') as f:
    lines = list(f)

theorems = json.loads(lines[0])


def hole_terms():
    """Yield hole names forever: '?Goal', then '?Goal0', '?Goal1', ...

    The first item is the unnumbered hole; every later item carries an
    increasing integer suffix starting at 0. The generator never terminates,
    so callers must stop consuming it themselves.
    """
    # The unnumbered hole comes first, before any numbered one.
    yield '?Goal'
    number = 0
    while True:
        yield f'?Goal{number}'
        number += 1


def get_hole_term_re(i = -1):
    """Compile a regex that matches exactly one hole name.

    Args:
        i: Hole number. ``-1`` (the default) selects the unnumbered hole
            ``?Goal``; any other value selects ``?Goal<i>``.

    Returns:
        A compiled pattern matching the literal hole name, provided it is not
        immediately followed by another digit.
    """
    suffix = '' if i == -1 else str(i)
    # The negative lookahead stops a shorter name from matching inside a
    # longer one: '?Goal' must not match in '?Goal0', and '?Goal1' must not
    # match in '?Goal10'.
    return re.compile(r'\?Goal' + suffix + r'(?!\d)')


def hole_terms_re():
    """Yield compiled hole regexes forever, in hole order.

    The first pattern is the one for the unnumbered hole (index -1), followed
    by the patterns for indices 0, 1, 2, ... as built by get_hole_term_re.
    The generator is infinite; callers break out once a pattern stops
    matching.
    """
    hole_number = -1
    while True:
        yield get_hole_term_re(hole_number)
        hole_number += 1
        
        
def find_all(a_str, sub):
    """Yield the start index of every non-overlapping occurrence of ``sub``.

    Args:
        a_str: String to search in.
        sub: Substring to look for.

    Yields:
        Indices into ``a_str`` in increasing order. For an empty ``sub``
        every position from 0 to ``len(a_str)`` inclusive is yielded once.
    """
    # Advance by at least one character: with an empty `sub` a step of
    # len(sub) == 0 would make the generator yield the same index forever.
    step = max(len(sub), 1)  # use step = 1 to find overlapping matches
    start = 0
    while True:
        start = a_str.find(sub, start)
        if start == -1:
            return
        yield start
        start += step
        

def find_parens(s):
    """Map the index of every '(' in ``s`` to the index of its matching ')'.

    Args:
        s: String to scan. Only '(' and ')' are significant.

    Returns:
        A dict ``{open_index: close_index}`` whose keys are in ascending
        order.

    Raises:
        IndexError: If a ')' has no preceding unmatched '(', or if a '(' is
            never closed. The message names the offending index.
    """
    pairs = {}
    open_stack = []

    for pos, ch in enumerate(s):
        if ch == '(':
            open_stack.append(pos)
        elif ch == ')':
            if not open_stack:
                # A closing paren with nothing left to close.
                raise IndexError("No matching opening parens for ')' at: " + str(pos))
            pairs[open_stack.pop()] = pos

    if open_stack:
        # Report the innermost '(' that was never closed.
        raise IndexError("No matching closing parens for '(' at: " + str(open_stack.pop()))

    # Pairs are recorded in closing order; return them ordered by opening index.
    return dict(sorted(pairs.items()))


def get_goal_digit(goal_term):
    """Return the number attached to the first ``?Goal`` in ``goal_term``.

    The marker may appear anywhere in the term (for example ``(?Goal3)``),
    not only at the start; the digits directly following it are parsed.

    Args:
        goal_term: A token that contains a hole name such as ``?Goal``,
            ``?Goal0@{n:=n0})`` or ``(?Goal12``.

    Returns:
        The hole number as an int, or ``-1`` when the hole is the unnumbered
        ``?Goal`` or when the term contains no ``?Goal`` at all.
    """
    marker = '?Goal'
    marker_pos = goal_term.find(marker)
    if marker_pos == -1:
        return -1

    digits = ''
    for ch in goal_term[marker_pos + len(marker):]:
        # isdecimal (not isdigit) so that every collected character is one
        # int() accepts; characters such as superscript digits end the run.
        if not ch.isdecimal():
            break
        digits += ch

    return int(digits) if digits else -1

In [62]:
def get_stem(s, goal_idx = -1):
    """Reduce a proof term to a skeleton of parentheses and hole markers.

    Every character of ``s`` is dropped except matched parentheses and the
    first character of each hole found. Holes are looked up in the order
    produced by hole_terms_re (unnumbered first, then 0, 1, ...), and the
    lookup stops at the first hole name that does not occur in ``s``.

    Args:
        s: Proof term text.
        goal_idx: Number of the hole to mark with '_' (-1 is the unnumbered
            hole). Every other hole found is marked with '?'.

    Returns:
        A 4-tuple ``(stem, parens, positions, cnt)``:
        the skeleton string; the dict returned by find_parens; the sorted
        indices in ``s`` of all kept characters; and the hole counter, which
        starts at -1 and is incremented once per hole found.
    """
    parens = find_parens(s)
    skeleton = [''] * len(s)
    kept_positions = []

    for open_pos, close_pos in parens.items():
        skeleton[open_pos] = '('
        skeleton[close_pos] = ')'
        kept_positions.append(open_pos)
        kept_positions.append(close_pos)

    cnt = -1
    for pattern in hole_terms_re():
        match = pattern.search(s)
        if match is None:
            # First hole name absent from the term: stop looking further.
            break
        hole_pos = match.start()
        skeleton[hole_pos] = '_' if cnt == goal_idx else '?'
        kept_positions.append(hole_pos)
        cnt += 1

    kept_positions.sort()
    return ''.join(skeleton), parens, kept_positions, cnt


In [68]:
from pprint import pprint

# Driver cell: for each proof step of one theorem, work out which hole the
# step filled and try to extract the text that replaced it, printing
# intermediate results along the way.
theorem = theorems['aaron_stump_cse545']

for idx, steps in enumerate(theorem['proof_steps']):
    # The first proof step is skipped.
    # NOTE(review): the reason is not visible here — confirm it is intended.
    if idx == 0:
        continue
    # pprint(steps['proof_term_before'][0])
    # pprint(steps['proof_term_after'][0])
    print(f'+++++++++++++ {idx} ++++++++++++++')
    # Take the first proof term of each list and collapse newlines and runs of
    # whitespace to single spaces; an empty list yields ''.
    proof_term_before = ' '.join(steps['proof_term_before'][0].replace('\n', ' ').split()) if len(steps['proof_term_before']) != 0 else ''
    proof_term_after = ' '.join(steps['proof_term_after'][0].replace('\n', ' ').split()) if len(steps['proof_term_after']) != 0 else ''


    # hole_idx_ranges is only referenced from commented-out code below.
    hole_idx_ranges = {}
    # holes maps a hole number to the text extracted for it in this step.
    holes = {}

    # Space-separated tokens of both terms, compared pairwise below.
    proof_term_split_before = proof_term_before.split(' ')
    proof_term_split_after = proof_term_after.split(' ')

    goal_filled = None
    goal_filled_idx = None
    include_pattern_matching = False

    # Walk both token lists in lockstep. The first differing pair whose
    # "before" token contains '?Goal' while the "after" token does not is
    # taken as the filled hole. A differing pair seen earlier whose "before"
    # token is '=>' sets include_pattern_matching.
    for b, a in zip(proof_term_split_before, proof_term_split_after):
        if b != a:
            if '?Goal' in b and '?Goal' not in a:
                goal_filled = b
                goal_filled_idx = get_goal_digit(b)
                break
            if b == '=>':
                include_pattern_matching = True
                
    print('goal_filled', goal_filled, goal_filled_idx)

    # Visit holes in order (unnumbered, 0, 1, ...); cnt tracks the number of
    # the hole whose regex is currently being tried. The loop ends at the
    # first hole name that is absent from proof_term_before.
    cnt = -1
    for hole_term_re in hole_terms_re():
        search = hole_term_re.search(proof_term_before)
        if search is None:
            break
        # Only the hole identified as filled is processed; if no filled hole
        # was found (goal_filled_idx is None) this branch never runs.
        if cnt == goal_filled_idx:
            print(f'--------- {cnt} ---------')
            
            # Skeleton of the "before" term with the filled hole marked '_',
            # and of the "after" term with every hole marked '?'.
            stem_before, parens_before, parens_idx_lst_before, cnt_before = get_stem(proof_term_before, cnt)
            stem_after, parens_after, parens_idx_lst_after, cnt_after = get_stem(proof_term_after, None)

            # tmp_stem_before = stem_before.replace(str(cnt), '_')
            # for j in range(cnt_before):
            #     if j != cnt:
            #         tmp_stem_before = tmp_stem_before.replace(str(j), '+')
            # Character-level diffs of the two skeletons, forwards and on the
            # reversed strings. fwd_diff is only used by commented-out code.
            fwd_diff = list(Differ().compare(stem_before, stem_after))
            bwd_diff = list(Differ().compare(stem_before[::-1], stem_after[::-1]))


            # Scan the reversed diff for the entry that removes the '_'
            # marker, then consume the run of '+' entries that follows it.
            # Positions are converted to count from the other end of the list.
            # hole_start is only used by commented-out code.
            hole_end = -1
            hole_start = -1
            in_hole = False
            for i, p in enumerate(bwd_diff):
                if not in_hole:
                    if p == '- _':
                        hole_end = len(bwd_diff) - i - 1
                        in_hole = True
                else:
                    if p[0] == '+':
                        hole_start = len(bwd_diff) - i - 1
                    else:
                        break
            # print(hole_start, hole_end)
            # print(bwd_diff[::-1][hole_start:hole_end])

            # hole_idx_ranges[cnt] = (hole_start, hole_end)
            print()
            # Character offset of the filled hole inside the "before" term.
            search = get_hole_term_re(cnt).search(proof_term_before)
            hole_text_start = search.start()
            if include_pattern_matching:
                # NOTE(review): presumably steps back over a preceding '=> '
                # (3 characters) — confirm.
                hole_text_start = hole_text_start - 3

            # NOTE(review): hole_end is a position in bwd_diff but is used
            # here to index the sorted paren/hole positions of the "after"
            # term; if no '- _' entry was found it is still -1 and selects
            # the last position. Confirm this is the intended mapping.
            hole_text_end = parens_idx_lst_after[hole_end]
            print(hole_text_start, hole_text_end)
            print(len(proof_term_after))
            # Slice of the "after" term recorded as the text for this hole.
            holes[cnt] = proof_term_after[hole_text_start:hole_text_end]

            # for f, b in zip(fwd_diff, bwd_diff[::-1]):
            #     print(f, '|', b)
        
        
        # Printed on every iteration of the hole loop, not only on the one
        # that extracted a hole.
        for k, v in holes.items():
            print(k, v)
        
        cnt += 1 
    
    # Skeletons (default marking: the unnumbered hole is '_') and the
    # normalised terms, for visual inspection.
    print(get_stem(proof_term_before)[0])
    print(get_stem(proof_term_after)[0])
    print(proof_term_before)
    print(proof_term_after)

+++++++++++++ 1 ++++++++++++++
goal_filled ?Goal) -1
--------- -1 ---------

101 245
246
-1 nat_ind (fun n0 : nat => forall n' : nat, n' <= n0 -> Q n') ?Goal (fun (n0 : nat) (IHn : forall n' : nat, n' <= n0 -> Q n') => ?Goal0@{n:=n0}) n
(()(())()_)
(()(())()()_(()()?))
(fun (Q : nat -> Prop) (sIH : forall n : nat, (forall n' : nat, n' < n -> Q n') -> Q n) (n : nat) => ?Goal)
(fun (Q : nat -> Prop) (sIH : forall n : nat, (forall n' : nat, n' < n -> Q n') -> Q n) (n : nat) => nat_ind (fun n0 : nat => forall n' : nat, n' <= n0 -> Q n') ?Goal (fun (n0 : nat) (IHn : forall n' : nat, n' <= n0 -> Q n') => ?Goal0@{n:=n0}) n)
+++++++++++++ 2 ++++++++++++++
goal_filled ?Goal -1
--------- -1 ---------

161 204
282
-1 (fun (n' : nat) (Hn' : n' <= 0) => ?Goal0) 
(()(())()()_(()()?))
(()(())()()(()()?)(()()_))
(fun (Q : nat -> Prop) (sIH : forall n : nat, (forall n' : nat, n' < n -> Q n') -> Q n) (n : nat) => nat_ind (fun n0 : nat => forall n' : nat, n' <= n0 -> Q n') ?Goal (fun (n0 : nat) (IHn : fo