In [10]:
import json
import os
import time
import re
import warnings
from typing import List

import psutil
import subprocess
import logging
import threading

import lego_prover.utils as U


class SubprocessMonitor:
    """Run a command as a child process and watch its combined stdout/stderr.

    Every output line is written to a per-run log file.  A line matching
    ``ready_match`` marks the process as ready (releasing ``run()``), and a
    line matching ``callback_match`` triggers ``callback``.
    """

    def __init__(
        self,
        commands: List[str],
        name: str,
        ready_match: str = r".*",
        log_path: str = "logs",
        callback_match: str = r"^(?!x)x$",  # regex that will never match
        callback: callable = None,
        finished_callback: callable = None,
        cwd: str = os.path.expanduser("~"),
        server_port: int = -1,
    ):
        """Configure logging and remember how to launch the process.

        Args:
            commands: argv list handed to ``psutil.Popen``.
            name: logger name; the special value ``"isabelle_server"``
                selects a per-port log file under ``logs/<name>/``.
            ready_match: regex searched in every output line; a hit marks
                the process as ready.
            log_path: directory for the log file (ignored when ``name`` is
                ``"isabelle_server"``).
            callback_match: regex searched in every output line; a hit
                invokes ``callback``.
            callback: zero-argument callable run on a ``callback_match`` hit.
            finished_callback: zero-argument callable run once the output
                loop ends.
            cwd: working directory of the child process.
            server_port: only used to name the logger and log file of the
                ``"isabelle_server"`` case.
        """
        self.commands = commands
        self.server_port = server_port
        start_time = time.strftime("%Y%m%d_%H%M%S")
        self.name = name
        if name == "isabelle_server":
            # One log directory per launch, one file per server port.
            os.makedirs(f'logs/{name}/{start_time}_logs', exist_ok=True)
            self.logger = logging.getLogger(f'{name}-{server_port}')
            handler = logging.FileHandler(f"logs/{name}/{start_time}_logs/rank_{server_port}.log")
        else:
            self.logger = logging.getLogger(name)
            handler = logging.FileHandler(U.f_join(log_path, f"{start_time}.log"))
        formatter = logging.Formatter(
            "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
        )
        handler.setFormatter(formatter)
        self.logger.addHandler(handler)
        self.logger.setLevel(logging.INFO)
        self.process = None
        self.ready_match = ready_match
        self.ready_event = None
        self.ready_line = None
        self.callback_match = callback_match
        self.callback = callback
        self.finished_callback = finished_callback
        self.thread = None
        self.cwd = cwd

    def _start(self):
        """Launch the process and pump its output (runs in the worker thread).

        ``ready_event`` is always set before this method returns or raises,
        so a caller blocked in ``run()`` can never wait forever on a process
        that failed to launch.
        """
        self.logger.info(f"Starting subprocess with commands: {self.commands}")
        print(self.commands, self.cwd)
        try:
            self.process = psutil.Popen(
                self.commands,
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,
                stdin=subprocess.PIPE,
                universal_newlines=True,
                cwd=self.cwd
            )
            print(f"Subprocess {self.name} started with PID {self.process.pid}.")
            # readline() returns "" at EOF, which ends the loop.
            for line in iter(self.process.stdout.readline, ""):
                self.logger.info(line.strip())
                if re.search(self.ready_match, line):
                    self.ready_line = line
                    self.logger.info("Subprocess is ready.")
                    print("Subprocess is ready.")
                    self.ready_event.set()
                    if "chroma" in self.name:
                        # For names containing "chroma", stop consuming
                        # output as soon as the process reports ready.
                        break
                # Only fire the callback when one was actually supplied.
                if self.callback is not None and re.search(self.callback_match, line):
                    self.callback()
        finally:
            # Reached on EOF without a ready line, or on a launch/read
            # error: release the waiter either way.
            if not self.ready_event.is_set():
                self.ready_event.set()
                warnings.warn(f"Subprocess {self.name} failed to start.")
        if self.finished_callback:
            self.finished_callback()

    def run(self):
        """Start the worker thread and block until the process is ready.

        Returns once a line matched ``ready_match`` or the worker gave up
        (output ended or the launch failed).
        """
        self.ready_event = threading.Event()
        self.ready_line = None
        self.thread = threading.Thread(target=self._start)
        self.thread.start()
        self.ready_event.wait()

    def stop(self):
        """Terminate the process, if it is running, and wait for it to exit."""
        self.logger.info("Stopping subprocess.")
        if self.process and self.process.is_running():
            self.process.terminate()
            self.process.wait()

    def terminate(self):
        """Kill the process and all of its descendants.

        Does nothing when no process was started or when it has already
        exited.
        """
        if self.process is None:
            return
        try:
            parent = psutil.Process(self.process.pid)
            victims = parent.children(recursive=True)  # or parent.children() for recursive=False
        except psutil.NoSuchProcess:
            return
        # Children first, then the parent itself.
        victims.append(parent)
        for proc in victims:
            try:
                proc.kill()
            except psutil.NoSuchProcess:
                # Already gone between enumeration and kill.
                pass

    def run_action(self, inputs):
        """Send one line to the process's stdin and wait for an error reply.

        Reads output lines until one starts with ``{"error`` and returns it
        decoded as JSON.  Returns ``None`` if the output stream ends first.
        """
        self.logger.info(f"Input: {inputs}")
        self.process.stdin.write(inputs + '\n')
        self.process.stdin.flush()

        for line in iter(self.process.stdout.readline, ""):
            self.logger.info(line)
            if line.startswith('{"error'):
                return json.loads(line)
        return None

    @property
    def is_running(self):
        """Whether a process was started and is still running."""
        if self.process is None:
            return False
        return self.process.is_running()
# Launch the PISA server script from the Portal-to-ISAbelle checkout and
# block until it prints its "running" banner (or its output stream ends).
# Note: for the name "isabelle_server" the monitor picks its own log location,
# so no log_path is passed here.
server_port = 8051
_server_command = ["bash", "run_server.sh", str(server_port)]
isabelle_server = SubprocessMonitor(
    _server_command,
    "isabelle_server",
    ready_match=r"Server is running. Press Ctrl-C to stop.",
    cwd=os.path.abspath("/root/Portal-to-ISAbelle"),
    server_port=server_port,
)
isabelle_server.run()

class Checker(object):
    """A modified version of the Draft, Sketch, Prove proof-checking client.
    (https://github.com/albertqjiang/draft_sketch_prove/blob/main/autoformalization/checker.py)

    This checker supports Isabelle2022 via the new version of PISA
    (https://albertqjiang.github.io/Portal-to-ISAbelle/).

    It supports checking a miniF2F-style proof via `check`.

    Finally, it replaces `sledgehammer` with a call to `normalhammer`.
    """

    def __init__(self, working_dir, isa_path, theory_file, port=9000):
        """Store the Isabelle paths and load the PISA client.

        Args:
            working_dir: Isabelle working directory passed to the client.
            isa_path: path to the Isabelle installation.
            theory_file: path of the theory file passed to the client.
            port: port of the running PISA server.

        Raises:
            KeyError: if the ``PISA_PATH`` environment variable is not set.
        """
        # Imported here so the class does not depend on a module-level
        # ``import sys`` appearing before it in the file.
        import sys

        pisa_path = os.environ["PISA_PATH"]
        if pisa_path not in sys.path:
            sys.path.append(pisa_path)

        # Stays None when the client cannot be imported, so _initialize can
        # report the problem clearly instead of hitting a missing attribute.
        self.initialise_env = None
        try:
            from pisa_client import initialise_env

            self.initialise_env = initialise_env
            print(self.initialise_env)
        except Exception:
            print("Set $PISA_PATH to /yourpath/to/Portal-to-ISAbelle/src/main/python")

        self.working_dir = working_dir
        self.isa_path = isa_path
        self.theory_file = theory_file
        self.port = port

    def _initialize(self):
        """Create a PISA environment for the configured port and paths.

        Raises:
            RuntimeError: if the PISA client could not be imported.
        """
        if self.initialise_env is None:
            raise RuntimeError(
                "pisa_client could not be imported; "
                "set $PISA_PATH to /yourpath/to/Portal-to-ISAbelle/src/main/python"
            )
        env = self.initialise_env(
            self.port,
            isa_path=self.isa_path,
            theory_file_path=self.theory_file,
            working_directory=self.working_dir,
        )
        return env

    def _exit(self, env):
        """Ask the environment to exit, then force-kill leftover processes."""
        try:
            env.post("exit")
        except BaseException:
            # Deliberately broad (same as the original bare except).
            # NOTE(review): the client's timeout exception may not derive
            # from Exception -- confirm before narrowing this.
            print("env.post('exit') timed out")
        # WARNING: these kill every process whose `ps aux` line mentions
        # "Isabelle" or "poly", not only the ones started for `env`.
        os.system(
            "ps aux | grep Isabelle | awk '{print $2}' | xargs kill -9 > /dev/null 2>&1"
        )
        os.system(
            "ps aux | grep poly | awk '{print $2}' | xargs kill -9 > /dev/null 2>&1"
        )

    def _parse_output(self, obs):
        """Parse the sledgehammer output, otherwise return an empty string.

        Returns the text before the first ``<hammer>`` marker in `obs`, or
        ``""`` when the marker is absent.
        """
        if "<hammer>" in obs:
            output = obs.split("<hammer>")[0]
        else:
            output = ""
        return output

    def _run_step(self, step, i, tls_name, env):
        """Apply `step` to top-level state `tls_name`, saving it as ``default_<i>``.

        Returns:
            ``(obs, reward, done, metadata, error)`` where `error` is `obs`
            when the observation contains an error marker, else ``None``.
        """
        obs, reward, done, metadata = env.step_to_top_level_state(
            action=step, tls_name=tls_name, new_name="default_%d" % i
        )
        error = None
        if "error:" in obs or "Step error" in obs or "Unknown error" in obs:
            error = obs
        return obs, reward, done, metadata, error

    def _run_sledgehammer(self, step, i, tls_name, env):
        """Discharge a ``normalhammer`` step, trying cheap tactics first.

        Each heuristic tactic is substituted for ``normalhammer`` in turn; the
        first one that succeeds is returned with its observation prefixed by
        ``"<tactic> <hammer> "``.  If none succeeds the original step is run.
        """
        # First try heuristics
        for heuristic in [
            "by auto",
            "by simp",
            "by blast",
            "by fastforce",
            "by force",
            "by eval",
            "by presburger",
            "by sos",
            "by arith",
            "by linarith",
            "by (auto simp: field_simps)",
        ]:
            step_ = step.replace("normalhammer", heuristic)
            obs, reward, done, metadata, error = self._run_step(step_, i, tls_name, env)
            if error is None:
                obs = "%s <hammer> %s" % (heuristic, obs)
                return obs, reward, done, metadata, error
        # Try sledgehammer
        return self._run_step(step, i, tls_name, env)

    def check(self, statement_and_proof):
        """Check a theorem statement plus proof and return the result dict.

        The environment is always shut down: by `_check` on the normal path,
        and here if initialisation or parsing raises before `_check` runs.
        """
        # Initialize environment
        env = self._initialize()
        try:
            env.initialise()

            # Wrap and parse theorem
            theory = Checker.wrap_theorem(statement_and_proof)
            steps = Checker.get_parsed(env, theory)
        except BaseException:
            # Do not leak the Isabelle processes when setup fails.
            self._exit(env)
            raise

        return self._check(env, steps)

    def _check(self, env, steps):
        """Run `steps` one by one, then exit the environment.

        Returns:
            dict with keys ``success``, ``reason``, ``num_steps``,
            ``last_step``, ``step_results`` and ``theorem_and_proof`` (the
            reconstructed proof, or ``""`` on failure).
        """
        done = False
        reward = 0.0
        reason = ""
        success = False
        step_results = []
        tls_name = "default"
        for i, step in enumerate(steps):
            time0 = time.time()
            try:
                if "normalhammer" in step:
                    obs, reward, done, metadata, error = self._run_sledgehammer(
                        step, i, tls_name, env
                    )
                else:
                    obs, reward, done, metadata, error = self._run_step(
                        step, i, tls_name, env
                    )
                step_results.append(
                    dict(
                        index=i,
                        step=step,
                        output=self._parse_output(obs),
                        step_time=time.time() - time0,
                    )
                )
                if error is not None:
                    reason = error
                    success = False
                    done = False
                    break
            except BaseException:
                # Timeout - end the proof attempt.
                # Deliberately broad (same as the original bare except).
                # NOTE(review): the client's timeout exception may not derive
                # from Exception -- confirm before narrowing this.
                success = False
                done = False
                reason = "timeout (%d)" % len(step_results)
                # Same keys as a normal step record, including step_time.
                step_results.append(
                    dict(index=i, step=step, output="", step_time=time.time() - time0)
                )
                break

            # Change when successful
            tls_name = "default_%d" % i

        if done and reward == 1.0:
            success = True

        result = {
            "success": success,
            "reason": reason,
            "num_steps": len(steps),
            "last_step": len(step_results),
            "step_results": step_results,
            "theorem_and_proof": self.reconstruct(step_results) if success else "",
        }
        # Exit environment
        self._exit(env)
        return result

    @staticmethod
    def reconstruct(step_results):
        """Join the executed steps back into proof text.

        The first step result is skipped.  For every other step the parsed
        hammer output is used when non-empty, otherwise the step text itself.
        """
        steps = []
        for step_result in step_results[1:]:
            if step_result["output"] != "":
                steps.append(step_result["output"].strip())
            else:
                steps.append(step_result["step"].strip())
        theorem_and_proof = "\n".join(steps)
        return theorem_and_proof

    @staticmethod
    def wrap_theorem(theorem):
        """Prefix `theorem` with the ``Interactive`` theory header and imports."""
        return (
            'theory Interactive imports HOL.HOL Complex_Main "HOL-Library.Code_Target_Numeral" "HOL-Library.Sum_of_Squares" "Symmetric_Polynomials.Vieta" "HOL-Computational_Algebra.Computational_Algebra" "HOL-Number_Theory.Number_Theory" \n begin\n%s'
            % theorem
        )

    @staticmethod
    def get_parsed(env, theory, tls_name="default"):
        """Split `theory` into proof steps using the environment's parser.

        `tls_name` is accepted for interface compatibility but is not used.
        """
        # HACK: the parsing doesn't work well with `normalhammer`, so we replace
        # all hammer calls with sorry, then replace sorry to normalhammer after parsing.
        theory = theory.replace("sledgehammer", "sorry")
        theory = theory.replace("normalhammer", "sorry")

        steps = env.post(f"<parse text> ${theory}")
        steps = steps.split("<SEP>")
        # remove weird '$' step and whitespace steps
        steps = [s for s in steps if s != "$" and s.strip() != ""]
        steps = [s.replace("sorry", "normalhammer") for s in steps]
        return steps

import sys
import os
# Make the parent directory importable and tell Checker where the PISA
# Python client lives (Checker reads $PISA_PATH when it is constructed).
sys.path.append('../')
os.environ['PISA_PATH'] = '/root/Portal-to-ISAbelle/src/main/python'

# Connect to the PISA server on port 8051.
_checker_settings = {
    "working_dir": '/root/Isabelle2022/src/HOL/Examples',
    "isa_path": '/root/Isabelle2022',
    "theory_file": '/root/Isabelle2022/src/HOL/Examples/Interactive.thy',
    "port": 8051,
}
checker = Checker(**_checker_settings)

# Sample input: a proof sketch with one step left to sledgehammer.
theorem_and_sledgehammer_proof = """theorem
  fixes x :: int
  assumes "even x"
  shows "odd (x + 5)"
proof -
  (* If $x$ is even, then by definition, $x$ can be expressed as $2k$ for some integer $k$. *)
  obtain k where "x = 2 * k" using `even x` unfolding even_def by blast
  (* Therefore, $x + 5 = 2k + 5$. *)
  have "x + 5 = 2 * k + 5" by (simp add: `x = 2 * k`)
  (* The sum of an even number and an odd number is always odd. *)
  have "odd (2 * k + 5)" unfolding odd_def
    sledgehammer
  (* Hence, $2k + 5$ is odd, which means that $x + 5$ is odd if $x$ is even. *)
  then show ?thesis using `x + 5 = 2 * k + 5` by simp
qed

"""


# Alternative sample input (disabled):
# """theorem gcd_lcm:
#   assumes "gcd (n :: nat) 4 = 1"
#       and "lcm (n :: nat) 4 = 28"
#   shows "n = 7"
# proof -
#   have c1: "1*28 = n*4" using assms
#     sledgehammer
#   then have c2: "n = 1*28/4"
#     sledgehammer
#   then show ?thesis
#     sledgehammer
# qed"""
result = checker.check(theorem_and_sledgehammer_proof)

# Report the verdict and, on success, the reconstructed proof text.
_was_successful = result["success"]
_complete_proof = result["theorem_and_proof"]
print("\n==== Success: %s" % _was_successful)
print("--- Complete proof:\n%s" % _complete_proof)

['bash', 'run_server.sh', '8051'] /root/Portal-to-ISAbelle
Subprocess isabelle_server started with PID 492459.
Subprocess is ready.
<function initialise_env at 0x7cbda2ec2440>
----------Path to Isabelle source----------
/root/Isabelle2022
----------Path to Isabelle working directory----------
/root/Isabelle2022/src/HOL/Examples
----------Path to Isabelle theory file----------
/root/Isabelle2022/src/HOL/Examples/Interactive.thy
<initialise>
<parse text> $theory Interactive imports HOL.HOL Complex_Main "HOL-Library.Code_Target_Numeral" "HOL-Library.Sum_of_Squares" "Symmetric_Polynomials.Vieta" "HOL-Computational_Algebra.Computational_Algebra" "HOL-Number_Theory.Number_Theory" 
 begin
theorem
  fixes x :: int
  assumes "even x"
  shows "odd (x + 5)"
proof -
  (* If $x$ is even, then by definition, $x$ can be expressed as $2k$ for some integer $k$. *)
  obtain k where "x = 2 * k" using `even x` unfolding even_def by blast
  (* Therefore, $x + 5 = 2k + 5$. *)
  have "x + 5 = 2 * k + 5" by (

['bash', 'run_server.sh', '8051'] /root/Portal-to-ISAbelle
Subprocess isabelle_server started with PID 50046.
Subprocess is ready.


<function initialise_env at 0x7cbda2ec2440>


----------Path to Isabelle source----------
/root/Isabelle2022
----------Path to Isabelle working directory----------
/root/Isabelle2022/src/HOL/Examples
----------Path to Isabelle theory file----------
/root/Isabelle2022/src/HOL/Examples/Interactive.thy
<initialise>
<parse text> $theory Interactive imports HOL.HOL Complex_Main "HOL-Library.Code_Target_Numeral" "HOL-Library.Sum_of_Squares" "Symmetric_Polynomials.Vieta" "HOL-Computational_Algebra.Computational_Algebra" "HOL-Number_Theory.Number_Theory" 
 begin
theorem gcd_lcm:
  assumes "gcd (n :: nat) 4 = 1" 
      and "lcm (n :: nat) 4 = 28"
  shows "n = 7"
proof -
  have c1: "1*28 = n*4" using assms
    sorry
  then have c2: "n = 1*28/4"
    sorry
  then show ?thesis
    sorry
qed
<apply to top level state> default <apply to top level state> theory Interactive imports HOL.HOL Complex_Main "HOL-Library.Code_Target_Numeral" "HOL-Library.Sum_of_Squares" "Symmetric_Polynomials.Vieta" "HOL-Computational_Algebra.Computational_Algebra" "HOL