In [1]:
import json
import os
import pprint as pp
from collections import deque

import graphviz
import gymnasium as gym
import logic_gym

MAX_THEOREMS = 100
ONE_HOP_THEOREMS_FOLDER = "dataprep"
ONE_HOP_THEOREMS = "1hop_trueontology.json"


class TheoremsManipulator:
    """Load the theorem JSON file and record proof results back into it.

    The file is read from ``ONE_HOP_THEOREMS_FOLDER/ONE_HOP_THEOREMS``. Each
    top-level entry is looked up under the key ``"example<N>"`` and must
    contain a ``'test_example'`` dict with string fields ``"question"`` and
    ``"query"``. The fields ``"formalized"``, ``"proved"`` and ``"proof"``
    are optional.

    Attributes:
        tasks: maps entry key -> the entry's ``"formalized"`` value (only for
            entries that have one).
        english_theorems: maps entry key -> list of sentences, obtained by
            splitting the question and then the query on ``". "``.
        json_theorems: the parsed JSON document; the setters mutate it in
            place and ``write_back_updated_file`` serializes it.
        file_path: path the data was read from and is written back to.
    """

    def __init__(self):
        self.tasks = {}
        self.english_theorems = {}
        self.file_path = os.path.join(ONE_HOP_THEOREMS_FOLDER, ONE_HOP_THEOREMS)
        # Explicit UTF-8: JSON is UTF-8 by specification, whereas the default
        # text encoding of open() depends on the platform/locale.
        with open(self.file_path, "r", encoding="utf-8") as file:
            self.json_theorems = json.load(file)

        for key, entry in self.json_theorems.items():
            example = entry['test_example']
            # Premises first, then the query sentence(s).
            sentences = example["question"].split(". ")
            sentences.extend(example["query"].split(". "))
            self.english_theorems[key] = sentences

            if "formalized" in example:
                self.tasks[key] = example['formalized']

    @staticmethod
    def _key(theorem):
        """Return the JSON key used for theorem number ``theorem``."""
        return f"example{theorem}"

    def get_theorem(self, theorem = 1):
        """Return the formalized task for ``theorem``, or None if there is none."""
        return self.tasks.get(self._key(theorem))

    def get_english_theorem(self, theorem = 1):
        """Return the list of English sentences for ``theorem``, or None if unknown."""
        return self.english_theorems.get(self._key(theorem))

    def set_proved(self, theorem = 1, proved = False):
        """Store the ``"proved"`` flag for ``theorem``.

        Raises:
            KeyError: if the theorem is not present in the loaded file.
        """
        self.json_theorems[self._key(theorem)]['test_example']["proved"] = proved

    def get_proved(self, theorem = 1):
        """Return the stored ``"proved"`` flag for ``theorem``.

        Returns False when the flag has never been set, and also when the
        theorem does not exist in the file (consistent with the other getters,
        which report unknown theorems without raising).
        """
        entry = self.json_theorems.get(self._key(theorem))
        if entry is None:
            return False
        return entry['test_example'].get("proved", False)

    def set_proof(self, theorem = 1, proof = None):
        """Store the ``"proof"`` value for ``theorem``.

        Raises:
            KeyError: if the theorem is not present in the loaded file.
        """
        self.json_theorems[self._key(theorem)]['test_example']["proof"] = proof

    def write_back_updated_file(self):
        """Overwrite the source JSON file with the current in-memory data."""
        with open(self.file_path, "w", encoding="utf-8") as file:
            json.dump(self.json_theorems, file, indent=4)



In [2]:
# NOTE: despite its name, `get_theorem` is a TheoremsManipulator instance;
# later cells refer to it under this name.
get_theorem = TheoremsManipulator()

# Number of the theorem to preview (looked up under the key "example12").
THEOREM_NUMBER = 12

english_theorem = get_theorem.get_english_theorem(THEOREM_NUMBER)
theorem = get_theorem.get_theorem(THEOREM_NUMBER)

# English sentences, one per line, then the formalized task.
print(*english_theorem, sep="\n")
print("-------------------------------------")
print(theorem)


Mersenne primes are prime
Composite numbers are not prime
Mersenne primes are prime numbers
Prime numbers are prime
Prime numbers are natural numbers
Each natural number is positive
Natural numbers are integers
Integers are real numbers
Each real number is real
Real numbers are numbers
131071 is a Mersenne prime.
True or false: 131071 is not prime.
-------------------------------------
# Predicates
class MersennePrime(Relation):
    "Is x a Mersenne Prime?"
    def __init__(self, *args):  # seq of args
        Relation.__init__(self, *args)

class Prime(Relation):
    "Is x prime?"
    def __init__(self, *args):  # seq of args
        Relation.__init__(self, *args)

class Composite(Relation):
    "Is x composite?"
    def __init__(self, *args):  # seq of args
        Relation.__init__(self, *args)

class PrimeNumber(Relation):
    "Is x a prime number?"
    def __init__(self, *args):  # seq of args
        Relation.__init__(self, *args)

class NaturalNumber(Relation):
    "Is x a natur

In [3]:
# Small graphviz smoke test: five nodes joined into a directed ring.
dot = graphviz.Digraph()

# Nodes are named "1".."5" and labelled "Node <n>".
for level in range(1, 6):
    dot.node(str(level), f"Node {level}")

# Only the first edge carries a label ("12"); the remaining ones are unlabelled.
dot.edge("1", "2", "12")
for tail, head in (("2", "3"), ("3", "4"), ("4", "5"), ("5", "1")):
    dot.edge(tail, head)

# Write the graph as a PNG without opening a viewer.
dot.render("sample_graph", format="png", view=False)

'sample_graph.png'

In [4]:
class Node:
    """A single vertex of the search tree built while exploring actions.

    A node records the action that produced it, the observations before and
    after that action, the termination flags, the reward, and links to its
    parent and children.

    Note that passing ``parent`` to the constructor only stores the reference;
    it does not register this node in the parent's ``children`` list. Use
    ``add_child`` on the parent for that.
    """

    def __init__(
        self,
        index,
        level,
        observation_before,
        action,
        observation_after,
        is_done,
        terminated,
        truncated,
        proof,
        parent=None,
        children= None,
        reward=0,
        logic_gym_state=None,
    ):
        # Identity and position in the tree.
        self.index, self.level = index, level
        self.parent = parent
        # A fresh list per node unless the caller supplies one explicitly.
        self.children = [] if children is None else children

        # The transition that led to this node.
        self.observation_before = observation_before
        self.action = action
        self.observation_after = observation_after
        self.reward = reward

        # Outcome flags and payload.
        self.is_done = is_done
        self.terminated = terminated
        self.truncated = truncated
        self.proof = proof
        self.logic_gym_state = logic_gym_state

    def add_child(self, child_node):
        """Append ``child_node`` to this node's children and make this node its parent."""
        self.children.append(child_node)
        child_node.parent = self

    def __str__(self) -> str:
        """Return a one-line summary: index, level, action and reward."""
        summary = f"Node {self.index} at level {self.level} with action {self.action} and reward {self.reward}"
        return summary

In [5]:
def prove_theorem_by_search(env, depth = 3, exit_on_proof = False):
    """Explore the environment's actions level by level, building a tree of ``Node``s.

    Starting from the state after ``env.reset()``, every node of a level is
    expanded by trying each action reported by ``env.unwrapped.get_all_actions()``.
    Actions flagged with ``info["bad_action"]`` are skipped; every other action
    yields a child node. A child whose step reports ``terminated`` is treated
    as a proof: it is printed and remembered.

    Args:
        env: environment whose ``unwrapped`` object provides ``get_state``,
            ``set_state``, ``get_all_actions`` and ``step_and_step_back``.
        depth: search bound. Levels ``0 .. depth-2`` are expanded, so the
            deepest level that can receive nodes is ``depth-1``.
            NOTE(review): ``nodes_by_level[depth]`` is therefore always empty;
            confirm whether ``range(depth)`` was intended.
        exit_on_proof: when true, stop the whole search as soon as a child is
            done. NOTE(review): "done" is ``terminated or truncated``, so a
            truncated step also stops the search - confirm this is intended.

    Returns:
        ``(root, nodes_by_level, proved, proof)`` where ``nodes_by_level`` maps
        level -> list of nodes, ``proved`` tells whether any terminated child
        was found, and ``proof`` is the ``"state"`` entry of the most recently
        found terminated child's state (None if none was found).
    """
    root = Node(
        index=0,
        level=0,
        observation_before=None,
        action=None,
        # NOTE(review): Gymnasium's Env.reset() returns an (observation, info)
        # tuple, so the whole tuple is stored here - confirm whether only the
        # observation was intended.
        observation_after=env.reset(),
        is_done=False,
        terminated=False,
        truncated=False,
        proof=None,
    )
    root.logic_gym_state = env.unwrapped.get_state()
    nodes_by_level = {level: [] for level in range(depth + 1)}
    nodes_by_level[0] = [root]

    index = 1

    proved = False
    proof = None

    # Explicit stop flag. Previously the outer loops re-tested `is_done`, which
    # is unbound when a node has no valid action (UnboundLocalError).
    stop_search = False

    for level in range(depth-1):
        for node in nodes_by_level[level]:
            # The state is restored once per node; this relies on
            # step_and_step_back leaving the environment in that state after
            # each call (presumably what "step back" means - confirm).
            env.unwrapped.set_state(node.logic_gym_state)
            for action in env.unwrapped.get_all_actions():
                observation, reward, terminated, truncated, info, next_state = env.unwrapped.step_and_step_back(action)
                if info["bad_action"]: # check if the action is valid
                    continue

                is_done = terminated or truncated
                child = Node(
                    index=index,
                    level=node.level + 1,
                    observation_before=node.observation_after,
                    action=action,
                    observation_after=observation,
                    is_done=is_done,
                    terminated=terminated,
                    truncated=truncated,
                    proof=next_state["state"],
                    reward=reward,
                    logic_gym_state=next_state,
                    parent=node,
                )
                if terminated:
                    print("proof found")
                    print(child.proof)
                    proved = True
                    proof = child.proof

                index += 1
                node.add_child(child)
                nodes_by_level[level+1].append(child)
                if is_done and exit_on_proof:
                    stop_search = True
                    break
            if stop_search:
                break
        if stop_search:
            break
    return root, nodes_by_level, proved, proof
    

    
    

In [6]:
# NOTE(review): this first environment is replaced before use whenever a
# theorem is attempted below, and none of the environments is closed.
env = gym.make("logic_gym/LogicGym-v0")

# Try every theorem number 1..MAX_THEOREMS that has a formalized task and is
# not already marked as proved, recording the outcome for each attempt.
for x in range(MAX_THEOREMS):
    theorem = get_theorem.get_theorem(x+1)
    if theorem is None or get_theorem.get_proved(x+1):
        continue

    # Fresh environment per theorem, configured with that theorem's task.
    env = gym.make("logic_gym/LogicGym-v0")
    env.unwrapped.set_task(theorem)
    print(f"Proving theorem {x+1}")
    try:
        root, nodes_by_level, proved, proof = prove_theorem_by_search(env, depth=4, exit_on_proof=True)
    except Exception as exc:
        # A failing search is reported and recorded as "not proved" so the
        # remaining theorems are still attempted.
        print(f"Error in proving theorem {x+1}")
        print(exc)
        proved, proof = False, None
    get_theorem.set_proved(x+1, proved)
    get_theorem.set_proof(x+1, proof)

# Persist all recorded results in one write after the loop.
get_theorem.write_back_updated_file()









Proving theorem 21
proof found
Ax.(MersennePrime(x) -> Prime(x))                             (0)  Given
Ax.(MersennePrime(x) -> PrimeNumber(x))                       (1)  Given
Ax.(PrimeNumber(x) -> Prime(x))                               (2)  Given
Ax.(PrimeNumber(x) -> NaturalNumber(x))                       (3)  Given
Ax.(NaturalNumber(x) -> ~Negative(x))                         (4)  Given
Ax.(NaturalNumber(x) -> Integer(x))                           (5)  Given
Ax.(CompositeNumber(x) -> ~Prime(x))                          (6)  Given
Ax.(Integer(x) -> RealNumber(x))                              (7)  Given
Ax.(RealNumber(x) -> Real(x))                                 (8)  Given
Ax.(RealNumber(x) -> Number(x))                               (9)  Given
MersennePrime(3)                                             (10)  Given
Prime(3)                                                     (11)  Goal
MersennePrime(3) -> Prime(3)                                 (12)  A-Elimination (0), with 3
P

In [7]:
def BFS(root, dot):
    """Walk the tree under ``root`` breadth-first, drawing it into ``dot``.

    For every visited node this adds a graph node named after ``node.index``
    (labelled with the index and the ``terminated`` flag), adds an edge from
    its parent labelled with the action that produced it, and prints the node.

    Args:
        root: node to start from; nodes need ``index``, ``terminated``,
            ``parent``, ``action`` and ``children`` attributes.
        dot: object providing ``node(name, label)`` and
            ``edge(tail, head, label)`` - presumably a ``graphviz.Digraph``.
    """
    # deque gives O(1) removal from the front; list.pop(0) shifts every
    # remaining element on each dequeue.
    queue = deque([root])
    while queue:
        node = queue.popleft()
        dot.node(str(node.index), f"Node {node.index}, terminal: {node.terminated}")
        if node.parent is not None:
            dot.edge(str(node.parent.index), str(node.index), f"Action: {node.action}")
        print(node)
        queue.extend(node.children)
