In [1]:
import time
import sys
import os

In [2]:
import os
import time
import sys

class Checker(object):
    """A modified version of the Draft, Sketch, Prove proof-checking client.
    (https://github.com/albertqjiang/draft_sketch_prove/blob/main/autoformalization/checker.py)

    This checker supports Isabelle2022 via the new version of PISA
    (https://albertqjiang.github.io/Portal-to-ISAbelle/).

    It supports checking a miniF2F-style proof via `check`.

    Steps that mention `normalhammer` or `sledgehammer` are first attempted
    with a fixed list of cheap tactics (`by auto`, `by simp`, ...); only when
    every tactic fails is the step handed to `sledgehammer`.
    """

    # Tokens that mark a step as "let automation close this goal".
    _HAMMER_TOKENS = ('normalhammer', 'sledgehammer')

    # Substrings treated as a failed step when found in an observation.
    # NOTE(review): taken from the upstream DSP checker, which detects failures
    # from the observation text instead of from exceptions -- confirm these
    # markers against the PISA version in use.
    _ERROR_MARKERS = ('error:', 'Step error', 'Unknown error')

    def __init__(self, working_dir, isa_path, theory_file_path, port=9000):
        """Store the connection settings and try to import the PISA client.

        Args:
            working_dir: working directory handed to PISA.
            isa_path: Isabelle installation directory handed to PISA.
            theory_file_path: theory file handed to PISA.
            port: port of the PISA server.

        `$PISA_PATH` is appended to `sys.path` so that `pisa_client` can be
        imported. If the import fails a hint is printed and the failure is
        reported again, as a `RuntimeError`, when an environment is requested.
        """
        pisa_path = os.environ.get('PISA_PATH', '')
        # Avoid piling up duplicate entries when several checkers are created.
        if pisa_path not in sys.path:
            sys.path.append(pisa_path)

        # Always define the attribute so a failed import cannot turn into an
        # unrelated AttributeError later on.
        self.initialise_env = None
        try:
            from pisa_client import initialise_env
            self.initialise_env = initialise_env
        except ImportError:
            print("Set $PISA_PATH to /yourpath/to/Portal-to-ISAbelle/src/main/python")

        self.working_dir = working_dir
        self.isa_path = isa_path
        self.theory_file_path = theory_file_path
        self.port = port

    def _initialize(self):
        """Create a PISA environment from the stored settings.

        Raises:
            RuntimeError: if `pisa_client` could not be imported.
        """
        if getattr(self, 'initialise_env', None) is None:
            raise RuntimeError(
                "pisa_client is not importable; "
                "set $PISA_PATH to /yourpath/to/Portal-to-ISAbelle/src/main/python"
            )
        env = self.initialise_env(
            self.port,
            isa_path=self.isa_path,
            theory_file_path=self.theory_file_path,
            working_directory=self.working_dir
        )
        return env

    def _exit(self, env):
        """Ask the environment to exit, then force-kill leftover processes.

        WARNING: the shell commands kill *every* process whose `ps aux` line
        contains "Isabelle" or "poly", not only those started by this checker.
        """
        try:
            env.post('exit')
        except Exception:
            # Best effort: the server may already be gone; the kills below
            # clean up regardless.
            pass
        os.system("ps aux | grep Isabelle | awk '{print $2}' | xargs kill -9 > /dev/null 2>&1")
        os.system("ps aux | grep poly | awk '{print $2}' | xargs kill -9 > /dev/null 2>&1")

    def _parse_output(self, obs):
        """Return the text before `<hammer>` in `obs`, or '' if there is none."""
        return obs.split('<hammer>')[0] if '<hammer>' in obs else ''

    @classmethod
    def _swap_hammer(cls, step, replacement):
        """Replace every hammer token in `step` with `replacement`."""
        for token in cls._HAMMER_TOKENS:
            step = step.replace(token, replacement)
        return step

    def _run_step(self, step, i, tls_name, env):
        """Apply one step to state `tls_name`, producing state `default_{i}`.

        Returns:
            `(obs, reward, done, metadata, error)`. `error` is None on
            success; otherwise it is the exception text (when the client
            raised) or the observation itself (when it contains one of
            `_ERROR_MARKERS`).
        """
        try:
            obs, reward, done, metadata = env.step_to_top_level_state(
                action=step,
                tls_name=tls_name,
                new_name=f'default_{i}'
            )
        except Exception as e:
            # Fall back to repr so the error is never an empty string.
            return '', 0, False, None, str(e) or repr(e)

        error = None
        if isinstance(obs, str) and any(marker in obs for marker in self._ERROR_MARKERS):
            error = obs
        return obs, reward, done, metadata, error

    def _run_sledgehammer(self, step, i, tls_name, env):
        """Try cheap tactics in place of the hammer call, then sledgehammer.

        On the first tactic that succeeds, the observation is prefixed with
        `'<tactic> <hammer> '` so `_parse_output` can recover which tactic
        closed the goal.
        """
        heuristics = [
            'by auto', 'by simp', 'by blast', 'by fastforce',
            'by force', 'by eval', 'by presburger', 'by sos',
            'by arith', 'by linarith', 'by (auto simp: field_simps)'
        ]
        for heuristic in heuristics:
            # Substitute whichever hammer token the step uses; replacing only
            # `normalhammer` would leave `sledgehammer` steps untouched and
            # mislabel their result with the first heuristic.
            step_ = self._swap_hammer(step, heuristic)
            obs, reward, done, metadata, error = self._run_step(step_, i, tls_name, env)
            if error is None:
                obs = f'{heuristic} <hammer> {obs}'
                return obs, reward, done, metadata, error
        return self._run_step(self._swap_hammer(step, 'sledgehammer'), i, tls_name, env)

    def check(self, statement_and_proof):
        """Check a theorem statement together with its proof.

        Args:
            statement_and_proof: Isabelle text of the theorem and its proof.

        Returns:
            The result dict produced by `_check`.

        The environment is always shut down via `_exit`, even when parsing or
        checking raises.
        """
        env = self._initialize()
        try:
            env.initialise()

            theory = self.wrap_theorem(statement_and_proof)
            steps = self.get_parsed(env, theory)

            result = self._check(env, steps)
        finally:
            self._exit(env)

        # Output the result
        print("\n==== Success: %s" % result['success'])
        print("--- Complete proof:\n%s" % result['theorem_and_proof'])
        return result

    def _check(self, env, steps):
        """Run the steps in order and collect per-step results.

        Stops at the first failing step. Returns a dict with keys `success`,
        `reason` (error text of the failing step, else ''), `num_steps`,
        `last_step` (number of steps attempted), `step_results` and
        `theorem_and_proof` (reconstructed proof on success, else '').
        """
        reason, done, reward = '', False, 0  # defaults cover an empty step list
        failed = False
        step_results = []
        tls_name = 'default'

        for i, step in enumerate(steps):
            time0 = time.time()
            if any(token in step for token in self._HAMMER_TOKENS):
                obs, reward, done, metadata, error = self._run_sledgehammer(step, i, tls_name, env)
            else:
                obs, reward, done, metadata, error = self._run_step(step, i, tls_name, env)

            step_time = time.time() - time0
            step_results.append({
                'index': i, 'step': step,
                'output': self._parse_output(obs),
                'step_time': step_time
            })

            if error is not None:
                reason = error
                failed = True
                break
            # The next step continues from the state this step just created.
            tls_name = f'default_{i}'

        success = (not failed) and bool(done) and reward == 1.0
        return {
            'success': success,
            'reason': reason,
            'num_steps': len(steps),
            'last_step': len(step_results),
            'step_results': step_results,
            'theorem_and_proof': self.reconstruct(step_results) if success else ''
        }

    @staticmethod
    def reconstruct(step_results):
        """Join the steps into proof text, skipping the first step result.

        For each remaining step the parsed hammer output is used when it is
        non-empty, otherwise the step text itself.
        """
        return '\n'.join(
            step_result['output'].strip() if step_result['output'] else step_result['step'].strip()
            for step_result in step_results[1:]
        )

    @staticmethod
    def wrap_theorem(theorem):
        """Prefix `theorem` with the `Interactive` theory header and imports."""
        return (
            'theory Interactive imports HOL.HOL Complex_Main '
            '"HOL-Library.Code_Target_Numeral" "HOL-Library.Sum_of_Squares" '
            '"Symmetric_Polynomials.Vieta" "HOL-Computational_Algebra.Computational_Algebra" '
            '"HOL-Number_Theory.Number_Theory" \n begin\n%s' % theorem
        )

    @staticmethod
    def get_parsed(env, theory):
        """Ask the environment to parse `theory` and return the list of steps.

        The reply is split on `<SEP>`; blank chunks and the bare `$` chunk are
        discarded. A `then` that is the first step or directly follows a step
        starting with `proof` is dropped.
        """
        raw_steps = env.post(f"<parse text> ${theory}")
        # Compare the *stripped* chunk with '$' so a padded ' $ ' is removed too.
        steps = [
            chunk
            for chunk in (piece.strip() for piece in raw_steps.split('<SEP>'))
            if chunk and chunk != '$'
        ]
        processed_steps = []
        for i, step in enumerate(steps):
            if step.lower() == "then" and (i == 0 or steps[i - 1].startswith("proof")):
                continue
            processed_steps.append(step)
        return processed_steps


In [3]:
# Make the parent directory importable so that `dsp_utils` can be found.
sys.path.append('../')
# Keep a PISA_PATH that is already exported (Checker tells users to set it);
# only fall back to the local checkout when the variable is unset.
os.environ.setdefault('PISA_PATH', '/home/siai/Portal-to-ISAbelle/src/main/python')

import dsp_utils

# Single source of truth for the Isabelle installation used below.
ISABELLE_HOME = '/home/siai/Isabelle2022'
WORKING_DIR = os.path.join(ISABELLE_HOME, 'src', 'HOL', 'Examples')

checker = Checker(
    working_dir=WORKING_DIR,
    isa_path=ISABELLE_HOME,
    theory_file_path=os.path.join(WORKING_DIR, 'Interactive.thy'),
    port=9000
)


In [14]:
# Raw string: Isabelle symbols such as `\<and>` contain backslashes that are
# not valid Python escape sequences; `r"""` keeps the text identical while
# avoiding the invalid-escape warning.
theorem_and_sledgehammer_proof = r"""theorem amc12a_2008_p8:
  fixes x y::real
  assumes h0: "0 < x \<and> 0 < y"
    and h1: "y^3 = 1"
    and h2: "6 * x^2 = 2 * (6 * y^2)"
  shows "x^3 = 2 * sqrt 2"
  using assms
  by (smt (verit, best) mult_cancel_left2 one_power2 
      power2_eq_square power2_le_imp_le 
      power2_sum power3_eq_cube power_Suc_less 
      power_commutes power_gt1_lemma real_le_lsqrt 
      real_le_rsqrt)
"""

# Start an Isabelle session, verify the proof and print the verdict.
result = checker.check(theorem_and_sledgehammer_proof)



==== Success: True
--- Complete proof:
theorem amc12a_2008_p8:
  fixes x y::real
  assumes h0: "0 < x \<and> 0 < y"
    and h1: "y^3 = 1"
    and h2: "6 * x^2 = 2 * (6 * y^2)"
  shows "x^3 = 2 * sqrt 2"
using assms
by (smt (verit, best) mult_cancel_left2 one_power2 
      power2_eq_square power2_le_imp_le 
      power2_sum power3_eq_cube power_Suc_less 
      power_commutes power_gt1_lemma real_le_lsqrt 
      real_le_rsqrt)


In [19]:

# Raw string: `\<in>` is not a valid Python escape sequence; `r"""` keeps the
# text identical while avoiding the invalid-escape warning.
theorem_and_sledgehammer_proof = r"""

lemma "alpha (trans_r R as) = {b | there is a in as such that (a, b) in trans_r as}"
proof-
  let ?as = "as";
  let ?R = "R";
  
  assume "trans_r R as";
  from ?R have "trans_r ?R as" using trans_r_less_trans_reverses
    using trans_r_trans_reverses
  with ?as have "b' \<in> ?R" by (rule trans_r_def)
  therefore "b' = ?as" by (unfold trans_r_def, blast);
  with ?as have "b \<in> ?rs ?as" by (rule trans_r_def, intro rev_image_trans_r_inv R_def) (auto)
  then have "b \<in> ?as" by (unfold trans_r_def, simp)
  hence "b' = ?as" by (unfold trans_r_def, blast)
  thus ?thesis by (rule refl)
qed

"""

# Check the lemma and show whether the proof was accepted.
result = checker.check(theorem_and_sledgehammer_proof)
print(result['success'])






==== Success: False
--- Complete proof:

False


In [20]:
# Read the verdict of the most recent `checker.check` call (None if the key is absent).
new = result.get("success")
print(new)
# Start with an empty reward list; a later cell appends to it.
rewards = []

False


In [21]:
# Score the latest check: 1.0 when the proof was accepted, 0.0 otherwise.
rewards.append(1.0 if result.get("success", False) else 0.0)

In [22]:
# Show the rewards collected so far.
print(rewards)

[0.0]
