In [1]:
import argparse
import logging
import multiprocessing as mp
import os
import pytz
import shutil
from datetime import datetime

from lego_prover.env.chromas import ChromaBridge
from lego_prover.evolver import Evolver
from lego_prover.prover import Prover
import lego_prover.utils as U
# from openai_key import *

In [2]:
# Command-line interface of the LEGO-Prover launcher, declared as a table of
# (flag, add_argument keyword arguments) and registered in order.
parser = argparse.ArgumentParser(description='LEGO-Prover')
_CLI_OPTIONS = (
    ('--resume', dict(
        action='store_true',
        help='whether to resume from the checkpoint')),
    ('--data_split', dict(
        type=str, choices=['valid', 'test'], default='valid',
        help='data split to use in the miniF2F dataset')),
    ('--ckpt_dir', dict(
        type=str, default='checkpoints/lego_prover_valid_2023_10_27',
        help='path to the checkpoint directory')),
    ('--isabelle_path', dict(
        type=str, default='/data2/wanghaiming/Isabelle2022/',
        help='path to the Isabelle2022 directory')),
    ('--model_name', dict(
        type=str, choices=["gpt-3.5-turbo", "gpt-4"], default='gpt-3.5-turbo',
        help='OpenAI model name')),
    ('--temperature', dict(
        type=float, default=0.7,
        help='temperature for sampling the LLM')),
    ('--num_prover', dict(
        type=int, default=3,
        help='number of prover processes')),
    ('--num_evolver', dict(
        type=int, default=8,
        help='number of evolver processes')),
    ('--num_attempts', dict(
        type=int, default=100,
        help='number of proving attempts for each problem in the dataset')),
)
for _flag, _spec in _CLI_OPTIONS:
    parser.add_argument(_flag, **_spec)

# An explicit empty argv is parsed so that the notebook kernel's own
# command-line arguments are ignored and every option takes its default.
args = parser.parse_args([])


In [3]:
# Unpack the parsed options into the module-level names the later cells read.
resume = args.resume
data_split = args.data_split
ckpt_dir = args.ckpt_dir

# Notebook-local override of the --isabelle_path default.
args.isabelle_path = "/workspace/huangyongfeng/ananke/example/pipeline/math_experiment/LEGO-Prover/Isabelle2022"
isabelle_path = args.isabelle_path
model_name = args.model_name
temperature = args.temperature

# Notebook-local override: run a single prover process.
args.num_prover = 1
number_of_prover_processes = args.num_prover
number_of_evolver_processes = args.num_evolver
number_of_prover_attempts = args.num_attempts

# An existing checkpoint directory is either deleted (fresh run) or resumed
# from; the user decides interactively.
if os.path.exists(ckpt_dir) and not resume:
    text = input(
        f"The checkpoint directory {ckpt_dir} already exists and "
        f"you are not resuming from it. Do you want to delete it? (y/n) "
    )
    # Deleting is destructive, so require an explicit "y"/"yes" instead of
    # accepting any answer that merely contains the letter "y".
    if text.strip().lower() in ("y", "yes"):
        shutil.rmtree(ckpt_dir, ignore_errors=True)
        resume = False
    else:
        resume = True


the checkpoint directory checkpoints/lego_prover_valid_2023_10_27 is already exist, andyou are not resuming from it, do you want to delete it? (y/n) y


In [4]:
# load miniF2F tasks and resume from the checkpoint
miniF2F_tasks = mp.Queue()
problem_names = []

# Start from empty progress so that a checkpoint directory missing either JSON
# file cannot leave these names unbound when resuming.
completed_tasks = []
failed_tasks = []
if resume:
    completed_tasks_path = f"{ckpt_dir}/curriculum/completed_tasks.json"
    failed_tasks_path = f"{ckpt_dir}/curriculum/failed_tasks.json"
    if os.path.exists(completed_tasks_path):
        completed_tasks = U.load_json(completed_tasks_path)
    if os.path.exists(failed_tasks_path):
        failed_tasks = U.load_json(failed_tasks_path)
    print("Current progress: ", len(completed_tasks) + len(set(failed_tasks)))

# Collect every problem file of the split together with the length of its
# informal proof, then order the problems from shortest to longest proof.
for name in os.listdir(f"data/full_data/{data_split}"):
    path = os.path.join(f"data/full_data/{data_split}", name)
    context = U.load_json(path)
    problem_names.append((path, len(context["informal_proof"])))
problem_names = sorted(problem_names, key=lambda x: x[1])
problem_names = [pn[0] for pn in problem_names]

# Each problem is queued once per allowed attempt.
problem_names = problem_names * number_of_prover_attempts
for pn in problem_names:
    # A completed problem needs no further attempts.
    if pn in completed_tasks:
        continue
    # Each recorded failure uses up one of the problem's queued attempts.
    if pn in failed_tasks:
        failed_tasks.remove(pn)
        continue
    miniF2F_tasks.put(pn)
print(f"Sketch to finish: {miniF2F_tasks.qsize()}")

Sketch to finish: 24400


In [7]:
# setup multiprocessing logger
def _attach_rank_file_logger(logger_name, log_file):
    """Point the named logger at ``log_file`` at INFO level and return it.

    ``logging.getLogger`` returns the same logger object for the same name, so
    file handlers already attached to it (for example by an earlier run of
    this cell) are removed and closed first. Otherwise every message would
    also be written to the older log files and their handles would stay open.
    """
    rank_logger = logging.getLogger(logger_name)
    for stale_handler in list(rank_logger.handlers):
        if isinstance(stale_handler, logging.FileHandler):
            rank_logger.removeHandler(stale_handler)
            stale_handler.close()
    handler = logging.FileHandler(log_file)
    handler.setFormatter(logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    ))
    rank_logger.addHandler(handler)
    rank_logger.setLevel(logging.INFO)
    return rank_logger


# One timestamp per run names the log directories of both worker kinds.
start_time = datetime.now(pytz.timezone(
    'Asia/Shanghai')).strftime("%Y%m%d_%H%M%S")

os.makedirs(f'logs/prover/{start_time}_logs', exist_ok=True)
for rank in range(number_of_prover_processes):
    logger = _attach_rank_file_logger(
        f'prover-{rank}', f"logs/prover/{start_time}_logs/rank_{rank}.log")

# Evolver ranks continue after the prover ranks, so logger names and file
# names are numbered from number_of_prover_processes upwards.
os.makedirs(f'logs/evolver/{start_time}_logs', exist_ok=True)
for evolver_rank in range(number_of_prover_processes,
                          number_of_prover_processes + number_of_evolver_processes):
    logger = _attach_rank_file_logger(
        f'evolver-{evolver_rank}',
        f"logs/evolver/{start_time}_logs/rank_{evolver_rank}.log")

In [8]:
# define the function to run the prover and evolver
def run_prover(rank, tasks, skill_manager_lock, curriculum_agent_lock, chroma_bridge):
    """Build the ``Prover`` for ``rank`` and run its ``learn()`` method.

    The prover is given server port ``8051 + rank``. Settings that are not
    parameters (``isabelle_path``, ``model_name``, ``resume``, ``temperature``
    and ``ckpt_dir``) are read from the module-level names when this is called.

    Args:
        rank: Index of this prover worker; also offsets the server port.
        tasks: Passed to ``Prover`` as ``miniF2F_tasks``.
        skill_manager_lock: Passed through to ``Prover``.
        curriculum_agent_lock: Passed through to ``Prover``.
        chroma_bridge: Passed through to ``Prover``.
    """
    prover_settings = dict(
        rank=rank,
        isabelle_path=isabelle_path,
        server_port=8051 + rank,
        model_name=model_name,
        skill_manager_lock=skill_manager_lock,
        action_agent_task_max_retries=1,
        curriculum_task_type="queue_curriculum",
        curriculum_agent_lock=curriculum_agent_lock,
        resume=resume,
        temperature=temperature,
        miniF2F_tasks=tasks,
        ckpt_dir=ckpt_dir,
        chroma_bridge=chroma_bridge,
    )
    Prover(**prover_settings).learn()

def run_evolver(rank, skill_manager_lock, chroma_bridge):
    """Build the ``Evolver`` for ``rank`` and run its ``evolve()`` method.

    The evolver is given server port ``8011 + rank``. Settings that are not
    parameters (``isabelle_path``, ``ckpt_dir``, ``data_split``,
    ``model_name`` and ``temperature``) are read from the module-level names
    when this is called.

    Args:
        rank: Index of this evolver worker; also offsets the server port.
        skill_manager_lock: Passed through to ``Evolver``.
        chroma_bridge: Passed through to ``Evolver``.
    """
    evolver_settings = dict(
        rank=rank,
        isabelle_path=isabelle_path,
        ckpt_dir=ckpt_dir,
        server_port=8011 + rank,
        data_split=data_split,
        skill_manager_lock=skill_manager_lock,
        model_name=model_name,
        temperature=temperature,
        chroma_bridge=chroma_bridge,
    )
    Evolver(**evolver_settings).evolve()

In [9]:
# Rebinds the module-level isabelle_path that run_prover / run_evolver read
# when they are called, replacing the value taken from args earlier.
# NOTE(review): hard-coded machine-specific path — confirm before reuse.
isabelle_path='/root/Isabelle2022/'

In [10]:
# Single-process debugging setup: build one Prover directly in the notebook
# kernel instead of spawning worker processes.
processes = []
skill_manager_lock = mp.Lock()
curriculum_agent_lock = mp.Lock()
chroma_bridge = ChromaBridge(ckpt_path=ckpt_dir, resume=resume)

# Bind the worker inputs explicitly. Previously `rank` was whatever value an
# earlier loop in another cell had left behind, which is unbound on a fresh
# kernel or when that loop never ran. Rank 0 is the first prover worker.
rank = 0
tasks = miniF2F_tasks

server_port = 8051 + rank
print(server_port)
prover = Prover(
    rank=rank,
    isabelle_path=isabelle_path,
    server_port=server_port,
    model_name=model_name,
    skill_manager_lock=skill_manager_lock,
    action_agent_task_max_retries=1,
    curriculum_task_type="queue_curriculum",
    curriculum_agent_lock=curriculum_agent_lock,
    resume=resume,
    temperature=temperature,
    miniF2F_tasks=tasks,
    ckpt_dir=ckpt_dir,
    chroma_bridge=chroma_bridge,
)




Subprocess chroma_worker started with PID 22012.




8051
Subprocess isabelle_server started with PID 22070.


[36m[2024-03-29 09:34:23,494] [Azure] [DEBUG] - azure[0m


In [9]:
!kill -9 9734 9793

In [7]:
# Shared state handed to every worker process.
processes = []
skill_manager_lock = mp.Lock()
curriculum_agent_lock = mp.Lock()
chroma_bridge = ChromaBridge(ckpt_path=ckpt_dir, resume=resume)


# creating processes
for rank in range(number_of_prover_processes):
    p = mp.Process(target=run_prover, args=(rank, miniF2F_tasks,
                   skill_manager_lock, curriculum_agent_lock, chroma_bridge))
    processes.append(p)
    p.start()

# Evolver workers are currently disabled; their ranks continue after the
# prover ranks.
# for rank in range(number_of_evolver_processes):
#     rank += number_of_prover_processes
#     p = mp.Process(target=run_evolver, args=(
#         rank, skill_manager_lock, chroma_bridge))
#     processes.append(p)
#     p.start()

# completing process
try:
    for p in processes:
        p.join()
finally:
    # If the wait is cut short (for example by interrupting the cell), stop
    # the workers that are still alive instead of leaving them running.
    # After a normal completion no process is alive and this does nothing.
    for p in processes:
        if p.is_alive():
            p.terminate()
            p.join(timeout=10)

Subprocess chroma_worker started with PID 272541.




Subprocess isabelle_server started with PID 272656.


[36m[2024-03-29 07:28:25,975] [Azure] [DEBUG] - azure[0m


Subprocess isabelle_server started with PID 272764.




Subprocess isabelle_server started with PID 272830.




Subprocess isabelle_server started with PID 272917.




Subprocess isabelle_server started with PID 273195.




Subprocess isabelle_server started with PID 273273.




Subprocess isabelle_server started with PID 273394.




Subprocess isabelle_server started with PID 273481.




Subprocess isabelle_server started with PID 273587.




Subprocess isabelle_server started with PID 273700.




Subprocess isabelle_server started with PID 273777.




Subprocess isabelle_server started with PID 273894.




Subprocess isabelle_server started with PID 273988.




Subprocess isabelle_server started with PID 274079.




Subprocess isabelle_server started with PID 274184.




Subprocess isabelle_server started with PID 274258.




Subprocess isabelle_server started with PID 274354.




Subprocess isabelle_server started with PID 274429.




Subprocess isabelle_server started with PID 274515.




Subprocess isabelle_server started with PID 274614.




Subprocess isabelle_server started with PID 274694.




Subprocess isabelle_server started with PID 274795.




Subprocess isabelle_server started with PID 274858.




Subprocess isabelle_server started with PID 274948.




Subprocess isabelle_server started with PID 275034.




Subprocess isabelle_server started with PID 275096.




Subprocess isabelle_server started with PID 275183.




Subprocess isabelle_server started with PID 275253.




Subprocess isabelle_server started with PID 275355.




Subprocess isabelle_server started with PID 275441.




Subprocess isabelle_server started with PID 275504.




Subprocess isabelle_server started with PID 275597.




Subprocess isabelle_server started with PID 275667.




Subprocess isabelle_server started with PID 275754.




Subprocess isabelle_server started with PID 275840.




Subprocess isabelle_server started with PID 275899.




Subprocess isabelle_server started with PID 275986.




Subprocess isabelle_server started with PID 276048.




Subprocess isabelle_server started with PID 276134.




Subprocess isabelle_server started with PID 276225.




Subprocess isabelle_server started with PID 276290.




Subprocess isabelle_server started with PID 276379.




Subprocess isabelle_server started with PID 276477.




Subprocess isabelle_server started with PID 276583.




Subprocess isabelle_server started with PID 276698.




Subprocess isabelle_server started with PID 276809.




Subprocess isabelle_server started with PID 276953.




Subprocess isabelle_server started with PID 277063.




Subprocess isabelle_server started with PID 277170.




Subprocess isabelle_server started with PID 277264.




Subprocess isabelle_server started with PID 277340.




Subprocess isabelle_server started with PID 277454.




Subprocess isabelle_server started with PID 277524.




Subprocess isabelle_server started with PID 277632.




Subprocess isabelle_server started with PID 277743.




Subprocess isabelle_server started with PID 277938.




Subprocess isabelle_server started with PID 278046.




Subprocess isabelle_server started with PID 278112.




Subprocess isabelle_server started with PID 278229.




Subprocess isabelle_server started with PID 278328.




Subprocess isabelle_server started with PID 278402.




Subprocess isabelle_server started with PID 278509.




Subprocess isabelle_server started with PID 278624.




Subprocess isabelle_server started with PID 278731.




Subprocess isabelle_server started with PID 278838.




Subprocess isabelle_server started with PID 278900.




Subprocess isabelle_server started with PID 279013.




Subprocess isabelle_server started with PID 279092.




Subprocess isabelle_server started with PID 279187.




Subprocess isabelle_server started with PID 279295.




Subprocess isabelle_server started with PID 279361.




In [None]:
# Copyright 2024 undefined
# 
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# 
#     https://www.apache.org/licenses/LICENSE-2.0
# 
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.



In [17]:
import json
import os
import time
import re
import warnings
from typing import List

import psutil
import subprocess
import logging
import threading

import lego_prover.utils as U


class SP:
    """Launch a child process on a background thread and watch its output.

    :meth:`run` starts the process and blocks until a line of its combined
    stdout/stderr matches ``ready_match``, the output ends, or the launch
    fails. Every output line is written to a log file.

    Args:
        commands: Argument vector of the process to launch.
        name: Name used for the logger and in messages. The name
            ``"isabelle_server"`` selects a per-port log file under
            ``logs/isabelle_server/<timestamp>_logs/``.
        ready_match: Regex searched in each output line; the first match
            marks the process as ready.
        log_path: Directory of the log file for names other than
            ``"isabelle_server"``.
        callback_match: Regex searched in each output line; a match invokes
            ``callback``. The default never matches.
        callback: Called without arguments on a ``callback_match`` hit.
        finished_callback: Called without arguments once the output loop ends.
        cwd: Working directory of the child process.
        server_port: Used in the logger and log-file names of
            ``"isabelle_server"`` processes.
    """

    def __init__(
        self,
        commands: List[str],
        name: str,
        ready_match: str = r".*",
        log_path: str = "logs",
        callback_match: str = r"^(?!x)x$",  # regex that will never match
        callback: callable = None,
        finished_callback: callable = None,
        cwd: str = os.path.expanduser("~"),
        server_port: int = -1,
    ):
        self.commands = commands
        self.server_port = server_port
        start_time = time.strftime("%Y%m%d_%H%M%S")
        self.name = name
        if name == "isabelle_server":
            os.makedirs(f'logs/{name}/{start_time}_logs', exist_ok=True)
            self.logger = logging.getLogger(f'{name}-{server_port}')
            handler = logging.FileHandler(f"logs/{name}/{start_time}_logs/rank_{server_port}.log")
        else:
            self.logger = logging.getLogger(name)
            handler = logging.FileHandler(U.f_join(log_path, f"{start_time}.log"))
        formatter = logging.Formatter(
            "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
        )
        handler.setFormatter(formatter)
        # logging.getLogger returns the same logger for the same name, so an
        # earlier instance may have left its file handler attached. Drop those
        # handlers; otherwise each new instance duplicates every message into
        # the older files and keeps their handles open.
        for stale_handler in list(self.logger.handlers):
            if isinstance(stale_handler, logging.FileHandler):
                self.logger.removeHandler(stale_handler)
                stale_handler.close()
        self.logger.addHandler(handler)
        self.logger.setLevel(logging.INFO)
        self.process = None
        self.ready_match = ready_match
        self.ready_event = None
        self.ready_line = None
        self.callback_match = callback_match
        self.callback = callback
        self.finished_callback = finished_callback
        self.thread = None
        self.cwd = cwd

    def _start(self):
        """Thread body: launch the process and consume its output.

        Sets ``ready_event`` on the first ``ready_match`` hit. If the launch
        fails or the output ends without a hit, the event is still set (with
        a warning) so that :meth:`run` never blocks forever.
        """
        self.logger.info(f"Starting subprocess with commands: {self.commands}")
        self.logger.info(f"Working directory: {self.cwd}")
        try:
            self.process = psutil.Popen(
                self.commands,
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,
                stdin=subprocess.PIPE,
                universal_newlines=True,
                cwd=self.cwd
            )
        except Exception:
            # Record the launch failure and fall through to the shared
            # "failed to start" handling below instead of dying silently in
            # the thread while run() waits on ready_event.
            self.logger.exception("Could not launch subprocess.")
        else:
            print(f"Subprocess {self.name} started with PID {self.process.pid}.")
            for line in iter(self.process.stdout.readline, ""):
                self.logger.info(line.strip())
                if re.search(self.ready_match, line):
                    self.ready_line = line
                    self.logger.info("Subprocess is ready.")
                    print("Subprocess is ready.")
                    self.ready_event.set()
                    # Processes with "chroma" in their name stop being
                    # monitored here as soon as they are ready.
                    if "chroma" in self.name:
                        break
                # A callback is optional even when callback_match is given.
                if self.callback is not None and re.search(self.callback_match, line):
                    self.callback()
        if not self.ready_event.is_set():
            self.ready_event.set()
            warnings.warn(f"Subprocess {self.name} failed to start.")
        if self.finished_callback:
            self.finished_callback()

    def run(self):
        """Start the process on a new thread and wait for the ready signal."""
        self.ready_event = threading.Event()
        self.ready_line = None
        self.thread = threading.Thread(target=self._start)
        self.thread.start()
        self.ready_event.wait()

    def stop(self):
        """Terminate the process, if it is running, and wait for it to exit."""
        self.logger.info("Stopping subprocess.")
        if self.process and self.process.is_running():
            self.process.terminate()
            self.process.wait()

    def terminate(self):
        """Kill the process and all of its descendants.

        Does nothing when no process was started. Processes that have already
        exited are skipped.
        """
        if self.process is None:
            return
        try:
            parent = psutil.Process(self.process.pid)
            family = parent.children(recursive=True)
        except psutil.NoSuchProcess:
            return
        # Children first, then the parent, as before.
        family.append(parent)
        for proc in family:
            try:
                proc.kill()
            except psutil.NoSuchProcess:
                pass

    def run_action(self, inputs):
        """Send one input line and return the first JSON error/result line.

        Writes ``inputs`` plus a newline to the process's stdin, then reads
        its output until a line starting with ``{"error`` appears and returns
        that line parsed as JSON. Returns ``None`` if the output ends first.

        NOTE(review): for names without "chroma" the thread started by
        :meth:`run` keeps reading the same stdout — confirm callers only use
        this where that reader has stopped.
        """
        self.logger.info(f"Input: {inputs}")
        self.process.stdin.write(inputs + '\n')
        self.process.stdin.flush()

        for line in iter(self.process.stdout.readline, ""):
            self.logger.info(line)
            if line.startswith('{"error'):
                return json.loads(line)
        return None

    @property
    def is_running(self):
        """Whether a process has been started and is still running."""
        if self.process is None:
            return False
        return self.process.is_running()

In [18]:
# Describe a standalone server process: `bash run_server.sh <port>` executed
# from the Portal-to-ISAbelle checkout. Nothing is launched until run() is
# called on the object.
server_port = 8051
isabelle_server = SP(
    commands=["bash", "run_server.sh", str(server_port)],
    name="isabelle_server",
    # Output line that marks the server as ready.
    ready_match=r"Server is running. Press Ctrl-C to stop.",
    cwd=os.path.abspath("/root/Portal-to-ISAbelle"),
    server_port=server_port,
)

In [19]:
# Launch the server on a background thread; run() returns once SP has set its
# ready event (a ready_match hit, or the end of the process output).
isabelle_server.run()

['bash', 'run_server.sh', '8051'] /root/Portal-to-ISAbelle
Subprocess isabelle_server started with PID 23828.
Subprocess is ready.


In [20]:
!kill -9 14767

/usr/bin/sh: 1: kill: No such process



In [2]:
class Checker(object):
    """A modified version of the Draft, Sketch, Prove proof-checking client.
    (https://github.com/albertqjiang/draft_sketch_prove/blob/main/autoformalization/checker.py)

    This checker supports Isabelle2022 via the new version of PISA
    (https://albertqjiang.github.io/Portal-to-ISAbelle/).

    It supports checking a miniF2F-style proof via `check`.

    Finally, it replaces `sledgehammer` with a call to `normalhammer`.
    """

    def __init__(self, working_dir, isa_path, theory_file, port=9000):
        """Store the session settings and try to import the PISA client.

        Args:
            working_dir: Passed to ``initialise_env`` as ``working_directory``.
            isa_path: Passed to ``initialise_env`` as ``isa_path``.
            theory_file: Passed to ``initialise_env`` as ``theory_file_path``.
            port: Port passed to ``initialise_env``.

        If ``pisa_client`` cannot be imported, a hint is printed and
        ``self.initialise_env`` is left as ``None``; ``_initialize`` then
        raises ``RuntimeError``.
        """
        # Imported here so the class does not depend on which other notebook
        # cells have already been run.
        import os
        import sys

        # Tolerate an unset $PISA_PATH so that the hint below can be printed,
        # and do not add the same entry to sys.path more than once.
        pisa_path = os.environ.get("PISA_PATH")
        if pisa_path and pisa_path not in sys.path:
            sys.path.append(pisa_path)

        self.initialise_env = None
        try:
            from pisa_client import initialise_env

            self.initialise_env = initialise_env
            print(self.initialise_env)
        except Exception as exc:
            print("Set $PISA_PATH to /yourpath/to/Portal-to-ISAbelle/src/main/python")
            print("pisa_client import failed: %r" % (exc,))

        self.working_dir = working_dir
        self.isa_path = isa_path
        self.theory_file = theory_file
        self.port = port

    def _initialize(self):
        """Create a PISA environment from the stored settings.

        Raises:
            RuntimeError: If ``pisa_client`` could not be imported.
        """
        if self.initialise_env is None:
            raise RuntimeError(
                "pisa_client is not available; set $PISA_PATH to "
                "/yourpath/to/Portal-to-ISAbelle/src/main/python"
            )
        env = self.initialise_env(
            self.port,
            isa_path=self.isa_path,
            theory_file_path=self.theory_file,
            working_directory=self.working_dir,
        )
        return env

    def _exit(self, env):
        """Ask the environment to exit, then kill leftover prover processes.

        NOTE: the two shell pipelines send SIGKILL to every process whose
        ``ps aux`` line contains "Isabelle" or "poly", not only to the ones
        started by this checker.
        """
        import os

        try:
            env.post("exit")
        except BaseException:
            # Deliberately broad, like the original bare except: the failure
            # raised on a client timeout may not derive from Exception —
            # TODO confirm against pisa_client.
            print("env.post('exit') timed out")
        os.system(
            "ps aux | grep Isabelle | awk '{print $2}' | xargs kill -9 > /dev/null 2>&1"
        )
        os.system(
            "ps aux | grep poly | awk '{print $2}' | xargs kill -9 > /dev/null 2>&1"
        )

    def _parse_output(self, obs):
        """Return the text before ``<hammer>`` in ``obs``, or an empty string."""
        if "<hammer>" in obs:
            output = obs.split("<hammer>")[0]
        else:
            output = ""
        return output

    def _run_step(self, step, i, tls_name, env):
        """Apply one step on top of state ``tls_name``.

        The resulting state is stored under the name ``default_<i>``.

        Returns:
            ``(obs, reward, done, metadata, error)``; ``error`` is ``obs``
            when the observation contains an error marker, otherwise ``None``.
        """
        obs, reward, done, metadata = env.step_to_top_level_state(
            action=step, tls_name=tls_name, new_name="default_%d" % i
        )
        error = None
        if "error:" in obs or "Step error" in obs or "Unknown error" in obs:
            error = obs
        return obs, reward, done, metadata, error

    def _run_sledgehammer(self, step, i, tls_name, env):
        """Resolve a ``normalhammer`` step.

        Each heuristic tactic is tried in turn in place of ``normalhammer``.
        The first one that produces no error is reported as
        ``"<tactic> <hammer> <obs>"``. If none succeeds, the step is run
        unchanged.
        """
        # First try heuristics
        for heuristic in [
            "by auto",
            "by simp",
            "by blast",
            "by fastforce",
            "by force",
            "by eval",
            "by presburger",
            "by sos",
            "by arith",
            "by linarith",
            "by (auto simp: field_simps)",
        ]:
            step_ = step.replace("normalhammer", heuristic)
            obs, reward, done, metadata, error = self._run_step(step_, i, tls_name, env)
            if error is None:
                obs = "%s <hammer> %s" % (heuristic, obs)
                return obs, reward, done, metadata, error
        # Try sledgehammer
        out = self._run_step(step, i, tls_name, env)
        return out

    def check(self, statement_and_proof):
        """Check a theorem statement with its proof and return the result dict.

        See ``_check`` for the keys of the returned dictionary.
        """
        # Initialize environment
        env = self._initialize()
        try:
            env.initialise()

            # Wrap and parse theorem
            theory = Checker.wrap_theorem(statement_and_proof)
            steps = Checker.get_parsed(env, theory)
        except BaseException:
            # _check shuts the environment down itself; when set-up or
            # parsing fails it is never reached, so clean up here.
            self._exit(env)
            raise

        result = self._check(env, steps)
        return result

    def _check(self, env, steps):
        """Run ``steps`` one by one, then shut the environment down.

        Returns:
            A dict with keys ``success``, ``reason``, ``num_steps``,
            ``last_step``, ``step_results`` and ``theorem_and_proof`` (the
            reconstructed proof on success, otherwise an empty string).
        """
        # Imported here so a missing module-level import cannot be swallowed
        # by the broad handler below and reported as a timeout.
        import time

        done = False
        reward = None
        reason = ""
        success = False
        step_results = []
        tls_name = "default"
        for i, step in enumerate(steps):
            try:
                time0 = time.time()
                if "normalhammer" in step:
                    obs, reward, done, metadata, error = self._run_sledgehammer(
                        step, i, tls_name, env
                    )
                else:
                    obs, reward, done, metadata, error = self._run_step(
                        step, i, tls_name, env
                    )
                step_time = time.time() - time0
                step_results.append(
                    dict(
                        index=i,
                        step=step,
                        output=self._parse_output(obs),
                        step_time=step_time,
                    )
                )
                if error is not None:
                    reason = error
                    success = False
                    done = False
                    break
            except BaseException as exc:
                # Timeout - end the proof attempt.
                # Kept as broad as the original bare except (the timeout may
                # not derive from Exception — TODO confirm). The exception is
                # recorded on the step so other failures can be told apart.
                success = False
                done = False
                reason = "timeout (%d)" % len(step_results)
                step_results.append(
                    dict(index=i, step=step, output="", error=repr(exc))
                )
                break

            # Change when successful
            tls_name = "default_%d" % i

        if done and reward == 1.0:
            success = True

        result = {
            "success": success,
            "reason": reason,
            "num_steps": len(steps),
            "last_step": len(step_results),
            "step_results": step_results,
            "theorem_and_proof": self.reconstruct(step_results) if success else "",
        }
        # Exit environment
        self._exit(env)
        return result

    @staticmethod
    def reconstruct(step_results):
        """Join the steps after the first into a proof text.

        A step's recorded output is used when it is non-empty, otherwise the
        step text itself. The first entry of ``step_results`` is skipped.
        """
        steps = []
        for step_result in step_results[1:]:
            if step_result["output"] != "":
                steps.append(step_result["output"].strip())
            else:
                steps.append(step_result["step"].strip())
        theorem_and_proof = "\n".join(steps)
        return theorem_and_proof

    @staticmethod
    def wrap_theorem(theorem):
        """Prefix ``theorem`` with the ``Interactive`` theory header."""
        return (
            'theory Interactive imports HOL.HOL Complex_Main "HOL-Library.Code_Target_Numeral" "HOL-Library.Sum_of_Squares" "Symmetric_Polynomials.Vieta" "HOL-Computational_Algebra.Computational_Algebra" "HOL-Number_Theory.Number_Theory" \n begin\n%s'
            % theorem
        )

    @staticmethod
    def get_parsed(env, theory, tls_name="default"):
        """Split ``theory`` into steps with the environment's parser.

        Returns the non-blank steps, with every ``sorry`` turned into
        ``normalhammer``.
        """
        # HACK: the parsing doesn't work well with `normalhammer`, so we replace
        # all hammer calls with sorry, then replace sorry to normalhammer after parsing.
        theory = theory.replace("sledgehammer", "sorry")
        theory = theory.replace("normalhammer", "sorry")

        steps = env.post(f"<parse text> ${theory}")
        steps = steps.split("<SEP>")
        # remove weird '$' step and whitespace steps
        steps = [s for s in steps if s != "$" and s.strip() != ""]
        steps = [s.replace("sorry", "normalhammer") for s in steps]
        return steps


In [3]:
import sys
import os
# Make the parent directory importable and tell Checker where the PISA Python
# client lives (Checker reads $PISA_PATH in its constructor).
# NOTE(review): the paths below are hard-coded for one machine — confirm
# before reuse.
sys.path.append('../')
os.environ['PISA_PATH'] = '/root/Portal-to-ISAbelle/src/main/python'

# The port is the same number used for the server process defined earlier in
# this notebook (8051).
checker = Checker(
    working_dir='/root/Isabelle2022/src/HOL/Examples',
    isa_path='/root/Isabelle2022',
    theory_file='/root/Isabelle2022/src/HOL/Examples/Interactive.thy',
    port=8051
)

<function initialise_env at 0x745883c5ae60>


In [4]:
# Example input: a theorem statement followed by a proof sketch in which every
# gap is left as a `sledgehammer` call. Checker.get_parsed turns these calls
# into `normalhammer` steps.
theorem_and_sledgehammer_proof = """theorem gcd_lcm:
  assumes "gcd (n :: nat) 4 = 1" 
      and "lcm (n :: nat) 4 = 28"
  shows "n = 7"
proof -
  have c1: "1*28 = n*4" using assms
    sledgehammer
  then have c2: "n = 1*28/4"
    sledgehammer
  then show ?thesis
    sledgehammer
qed"""
# check() returns a dict; 'theorem_and_proof' is an empty string unless
# 'success' is True.
result = checker.check(theorem_and_sledgehammer_proof)

print("\n==== Success: %s" % result['success'])
print("--- Complete proof:\n%s" % result['theorem_and_proof'])

----------Path to Isabelle source----------
/root/Isabelle2022
----------Path to Isabelle working directory----------
/root/Isabelle2022/src/HOL/Examples
----------Path to Isabelle theory file----------
/root/Isabelle2022/src/HOL/Examples/Interactive.thy
<initialise>
<parse text> $theory Interactive imports HOL.HOL Complex_Main "HOL-Library.Code_Target_Numeral" "HOL-Library.Sum_of_Squares" "Symmetric_Polynomials.Vieta" "HOL-Computational_Algebra.Computational_Algebra" "HOL-Number_Theory.Number_Theory" 
 begin
theorem gcd_lcm:
  assumes "gcd (n :: nat) 4 = 1" 
      and "lcm (n :: nat) 4 = 28"
  shows "n = 7"
proof -
  have c1: "1*28 = n*4" using assms
    sorry
  then have c2: "n = 1*28/4"
    sorry
  then show ?thesis
    sorry
qed
exit

==== Success: False
--- Complete proof:

