In [None]:
import time

from lean_dojo import (
    Dojo,
    DojoCrashError,
    DojoHardTimeoutError,
    DojoInitError,
    LeanGitRepo,
    ProofFinished,
    TacticState,
    trace,
)

In [None]:
# Alternative target (currently disabled): mathlib4 pinned to commit 29dcec07.
# repo = LeanGitRepo(
#     "https://github.com/leanprover-community/mathlib4",
#     "29dcec074de168ac2bf835a77ef68bbe069194c5",
# )

# Small Lean 4 example repository, pinned to a fixed commit hash.
repo = LeanGitRepo(
    url="https://github.com/yangky11/lean4-example",
    commit="7b6ecb9ad4829e4e73600a3329baeb3b5df8d23f",
)

# Bare expression so the notebook cell displays the repo object.
repo

In [None]:
# Display the repo's "lean-toolchain" config as the cell output.
# NOTE(review): presumably this is the Lean version the repo pins — confirm.
repo.get_config("lean-toolchain")

In [None]:
# Trace `repo` with LeanDojo and keep the traced repository in `traced_repo`.
# NOTE(review): presumably this builds and analyzes the whole repo, so it may
# be slow on large repositories — confirm.
traced_repo = trace(repo)

In [None]:
def _extract_tactic(generation):
    """Return the tactic text embedded in one raw model generation.

    Keeps the text after the last "```lean\\n" fence, cuts it at the next
    "```" fence and at any "---" separator, then strips whitespace.
    """
    tactic = generation.split("```lean\n")[-1]
    tactic = tactic.split('```')[0]
    tactic = tactic.split('---')[0]
    return tactic.strip()


def best_first_search(
        theorem,
        model,
        tokenizer,
        max_iters,
        temperatures,
        num_samples,
        batch_size,
        timeout=600,
        early_stop=False,
        max_seq_len=512,
        top_k=200
) -> list:
    """Search for a proof of `theorem` by repeatedly sampling tactics.

    Despite its name, this keeps no priority queue. It maintains
    `num_samples` independent proof chains, each starting from the theorem's
    initial tactic state. Every iteration generates one tactic per chain and
    runs it in the Dojo:

    * A chain advances when the tactic yields a new `TacticState`.
    * On any other non-finishing result the chain stays where it is and is
      resampled on the next iteration.
    * The search stops at the first `ProofFinished`.

    Args:
        theorem: Theorem to prove. It is passed to `Dojo`, and its
            `full_name` is recorded in the results.
        model, tokenizer: Forwarded to `generate_tactic`.
        max_iters: Maximum number of generate-and-run rounds.
        temperatures: Forwarded as `temperature` to `generate_tactic` and
            recorded in a successful result.
        num_samples: Number of independent proof chains.
        batch_size: Unused; kept for interface compatibility.
        timeout: Seconds. Used as the Dojo hard timeout and as the wall-clock
            budget checked at the start of each round.
        early_stop: If True, return from inside the Dojo context as soon as a
            proof is found. The search stops after the first proof in either
            case.
        max_seq_len: Forwarded to `generate_tactic`.
        top_k: Unused; kept for interface compatibility.

    Returns:
        A non-empty list of result dicts.

        On success, the dict has keys 'theorem', 'proof' (list of tactics),
        'success'=True, 'failure_reason'='', 'trace', 'temperature',
        'elapsed' (seconds since search start) and 'iteration'.

        Otherwise, a single dict with 'theorem', 'success'=False and
        'failure_reason'. That reason is the Dojo error class name, or
        'SearchEnded' when iterations or time ran out.
    """
    attempt_results = []
    print("theorem: ", theorem)
    try:
        with Dojo(theorem, hard_timeout=timeout) as (dojo, init_state):
            start = time.time()
            # Per-chain current state, tactics applied so far, and step traces.
            # Sharing `init_state` is safe: chain states are replaced, never mutated.
            states = [init_state] * num_samples
            steps = [[] for _ in range(num_samples)]
            traces = [[] for _ in range(num_samples)]

            for iteration in trange(max_iters):
                # Soft wall-clock budget, checked once per round.
                if time.time() - start > timeout:
                    break

                ts = [_tactic_state(state) for state in states]
                generations = generate_tactic(
                    ts,
                    model,
                    tokenizer,
                    max_seq_len=max_seq_len,
                    num_samples=1,
                    temperature=temperatures
                )
                tactics = [_extract_tactic(g) for g in generations]

                proof_finished = False
                for i in range(num_samples):
                    tactic = tactics[i]
                    result = dojo.run_tac(states[i], tactic)
                    step_trace = {
                        "tactic": tactic,
                        "full_cot": generations[i],
                        "state_before": ts[i]
                    }
                    if isinstance(result, ProofFinished):
                        attempt_results.append({
                            'theorem': theorem.full_name,
                            'proof': steps[i] + [tactic],
                            'success': True,
                            'failure_reason': '',
                            'trace': traces[i] + [step_trace],
                            'temperature': temperatures,
                            # Was `start - time.time()`, which is always negative.
                            'elapsed': time.time() - start,
                            'iteration': iteration
                        })
                        if early_stop:
                            return attempt_results
                        proof_finished = True
                        break
                    if isinstance(result, TacticState):
                        states[i] = result
                        steps[i].append(tactic)
                        traces[i].append(step_trace)
                    # Any other result (e.g. a tactic error) leaves chain i
                    # unchanged so it is resampled next round.

                if proof_finished:
                    break
    except (DojoInitError, DojoHardTimeoutError, DojoCrashError) as e:
        print("Error: ", e)
        # Keep a proof found before the error; otherwise record the failure.
        if not attempt_results:
            attempt_results.append({
                'theorem': theorem.full_name,
                'success': False,
                'failure_reason': type(e).__name__
            })

    if not attempt_results:
        attempt_results.append({
            'theorem': theorem.full_name,
            'success': False,
            'failure_reason': 'SearchEnded'
        })

    return attempt_results