# **Formal Prover first overview**

# Preparations

## Lexer

In [27]:
from enum import Enum, auto
from typing import Tuple, NamedTuple, Any
from collections import deque, namedtuple

class TokenType(Enum):
    """Kinds of tokens the lexer can emit for a propositional formula."""
    VAR = auto()        # Variable (e.g., P, Q)
    AND = auto()        # Logical AND
    OR = auto()         # Logical OR
    NOT = auto()        # Logical NOT
    IMPLIES = auto()    # Logical IMPLIES (e.g., ->)
    IFF = auto()        # Logical IFF (e.g., <->, biconditional)
    LPAREN = auto()     # (
    RPAREN = auto()     # )
    EOF = auto()        # End of formula (End of File/Input)

class Token(NamedTuple):
    """A single lexical token: its kind plus the text attached to it."""
    type: TokenType    # Category of the token
    value: Any = None  # Text carried by the token; defaults to None when none is given

In [28]:
from typing import List, Set, Optional
import re

class Lexer:
    """Tokenizer for propositional formulas written with keyword operators.

    The whole input is tokenized eagerly in the constructor; the resulting
    list always ends with a single EOF token. Operators are the words AND,
    OR, NOT, IMPLIES and IFF, recognised case-insensitively and stored with
    an upper-cased value. Any other identifier becomes a VAR token.

    Attributes:
        text: The original input string.
        keywords: Mapping from upper-cased keyword to its TokenType.
        tokens: All tokens of the input, terminated by EOF.
        token_idx: Index of the next token to hand out.
    """

    # Identifiers are matched first and classified as keyword or variable
    # afterwards, so no separate keyword patterns are needed: a dedicated
    # 'AND' alternative placed after VAR could never win the alternation.
    # Newlines are listed in SKIP explicitly: '.' does not match '\n', so
    # without it a newline would only be ignored because finditer silently
    # steps over characters that no alternative matches.
    _TOKEN_SPECIFICATION = (
        ('VAR',      r'[A-Za-z_][A-Za-z0-9_]*'),  # Variables and keywords
        ('LPAREN',   r'\('),
        ('RPAREN',   r'\)'),
        ('SKIP',     r'[ \t\n]+'),                # Whitespace is ignored
        ('MISMATCH', r'.'),                       # Any other character
    )
    # Compiled once for the class instead of once per Lexer instance.
    _TOKEN_RE = re.compile(
        '|'.join('(?P<%s>%s)' % pair for pair in _TOKEN_SPECIFICATION)
    )

    def __init__(self, text: str):
        """Tokenize *text*.

        Raises:
            ValueError: If the input contains a character that is neither
                part of an identifier, a parenthesis, nor whitespace.
        """
        self.text = text
        self.keywords = {
            "AND": TokenType.AND,
            "OR": TokenType.OR,
            "NOT": TokenType.NOT,
            "IMPLIES": TokenType.IMPLIES,
            "IFF": TokenType.IFF,
        }

        self.tokens: List[Token] = []
        for mo in self._TOKEN_RE.finditer(self.text):
            kind = mo.lastgroup
            value = mo.group()

            if kind == 'SKIP':
                continue
            if kind == 'MISMATCH':
                raise ValueError(f"Lexer error: Unexpected character '{value}'")

            if kind == 'VAR':
                upper_value = value.upper()
                if upper_value in self.keywords:
                    # Keywords are normalised to upper case in the token value.
                    self.tokens.append(Token(self.keywords[upper_value], upper_value))
                else:
                    self.tokens.append(Token(TokenType.VAR, value))
            elif kind == 'LPAREN':
                self.tokens.append(Token(TokenType.LPAREN, value))
            elif kind == 'RPAREN':
                self.tokens.append(Token(TokenType.RPAREN, value))

        self.tokens.append(Token(TokenType.EOF))
        self.token_idx = 0

    def get_next_token(self) -> Token:
        """Return the next token and advance; returns EOF once exhausted."""
        if self.token_idx < len(self.tokens):
            token = self.tokens[self.token_idx]
            self.token_idx += 1
            return token
        # Only reached if a caller keeps reading past the trailing EOF token.
        return Token(TokenType.EOF)

    def peek_token(self) -> Token:
        """Return the next token without advancing; EOF once exhausted."""
        if self.token_idx < len(self.tokens):
            return self.tokens[self.token_idx]
        return Token(TokenType.EOF)

## Parser

In [29]:
class ASTNode:
    """Common ancestor of every node in the formula syntax tree."""

    def __repr__(self):
        # Generic fallback: show only the concrete class name.
        class_name = type(self).__name__
        return "{}(...)".format(class_name)

class Variable(ASTNode):
    """Leaf node holding the name of a propositional variable."""

    def __init__(self, name: str):
        self.name = name

    def __repr__(self):
        return "Variable(name='{}')".format(self.name)

class UnaryOp(ASTNode):
    """Node for a one-operand operator (e.g. NOT) applied to a sub-formula."""

    def __init__(self, op: Token, operand: ASTNode):
        self.operand = operand
        self.op_token = op   # Full operator token (kind and text)
        self.op = op.type    # Just the TokenType, for quick comparisons

    def __repr__(self):
        return "UnaryOp(op='{}', operand={})".format(self.op_token.value, self.operand)

class BinaryOp(ASTNode):
    """Node for a two-operand operator (e.g. AND, OR) joining two sub-formulas."""

    def __init__(self, left: ASTNode, op: Token, right: ASTNode):
        self.left, self.right = left, right
        self.op_token = op   # Full operator token (kind and text)
        self.op = op.type    # Just the TokenType, for quick comparisons

    def __repr__(self):
        return "BinaryOp(left={}, op='{}', right={})".format(
            self.left, self.op_token.value, self.right
        )

In [30]:
class Parser:
    """Recursive-descent parser turning a Lexer's token stream into an AST.

    Precedence, from tightest to loosest binding: NOT, AND, OR, IMPLIES, IFF.
    Every binary operator is parsed left-associatively.
    """

    def __init__(self, lexer: Lexer):
        self.lexer = lexer
        self.current_token: Token = self.lexer.get_next_token()

    def _eat(self, token_type: TokenType):
        """Consume the current token if it has the expected type, else raise SyntaxError."""
        if self.current_token.type != token_type:
            raise SyntaxError(
                f"Parser error: Expected token {token_type}, "
                f"but got {self.current_token.type} (value: '{self.current_token.value}')"
            )
        self.current_token = self.lexer.get_next_token()

    def _left_assoc_chain(self, parse_operand, operator_type: TokenType) -> ASTNode:
        """Parse `operand (OP operand)*`, folding the operands to the left."""
        result = parse_operand()
        while self.current_token.type == operator_type:
            operator = self.current_token
            self._eat(operator_type)
            result = BinaryOp(result, operator, parse_operand())
        return result

    def _atom(self) -> ASTNode:
        """atom : VAR | LPAREN expression RPAREN"""
        start = self.current_token
        if start.type == TokenType.VAR:
            self._eat(TokenType.VAR)
            return Variable(start.value)
        if start.type == TokenType.LPAREN:
            self._eat(TokenType.LPAREN)
            inner = self._expression()
            self._eat(TokenType.RPAREN)
            return inner
        raise SyntaxError(f"Parser error (_atom): Expected VAR or LPAREN, got {start.type}")

    def _negation(self) -> ASTNode:
        """negation : NOT negation | atom"""
        start = self.current_token
        if start.type != TokenType.NOT:
            return self._atom()
        self._eat(TokenType.NOT)
        # Recursing into _negation allows both NOT NOT P and NOT (P AND Q).
        return UnaryOp(start, self._negation())

    def _conjunction(self) -> ASTNode:
        """conjunction : negation (AND negation)*"""
        return self._left_assoc_chain(self._negation, TokenType.AND)

    def _disjunction(self) -> ASTNode:
        """disjunction : conjunction (OR conjunction)*"""
        return self._left_assoc_chain(self._conjunction, TokenType.OR)

    def _implication(self) -> ASTNode:
        """implication : disjunction (IMPLIES disjunction)*"""
        # Left-associative: P IMPLIES Q IMPLIES R parses as ((P IMPLIES Q) IMPLIES R).
        # IMPLIES is often treated as right-associative instead; that would need
        # the rule  implication : disjunction (IMPLIES implication)?
        return self._left_assoc_chain(self._disjunction, TokenType.IMPLIES)

    def _expression(self) -> ASTNode:
        """expression : implication (IFF implication)*"""
        # IFF is handled left-associatively for simplicity.
        return self._left_assoc_chain(self._implication, TokenType.IFF)

    def parse(self) -> ASTNode:
        """Parse the whole input and return the AST root.

        Raises:
            SyntaxError: If the input is malformed or tokens remain after a
                complete expression.
        """
        root = self._expression()
        if self.current_token.type == TokenType.EOF:
            return root
        raise SyntaxError(
            f"Parser error: Unexpected token {self.current_token} "
            f"(value: '{self.current_token.value}') at end of input. Expected EOF."
        )

In [31]:
# Sample formula strings used to exercise the lexer and parser.
formula1_str = "P AND (Q OR NOT R)"
formula2_str = "P IMPLIES (Q AND P)"
formula_tautology_str = "A OR NOT A"  # This is a tautology
formula_contradiction_str = "A AND NOT A" # This is a contradiction (not a tautology)
formula_complex_tautology = "(P IMPLIES Q) IFF (NOT P OR Q)" # Tautology (the definition of implication)
malformed_formula = "P AND (Q OR" # Syntactically malformed

In [32]:
# Demo: tokenize and parse one formula, then print the AST via its __repr__.
formula = "P AND (Q OR NOT R)"
lexer = Lexer(formula)
parser = Parser(lexer)
ast = parser.parse()
print(ast)

BinaryOp(left=Variable(name='P'), op='AND', right=BinaryOp(left=Variable(name='Q'), op='OR', right=UnaryOp(op='NOT', operand=Variable(name='R'))))


## Logics Engine

In [33]:
class LogicsEngine:
    def __init__(self, axioms: Set[str]):
        """Create an engine whose initial knowledge base is the given axioms.

        Each axiom is parsed and stored in canonical string form, so formulas
        that differ only in redundant parentheses or whitespace compare equal.

        Args:
            axioms: Formula strings assumed to be true.

        Raises:
            ValueError: If any axiom is empty or cannot be lexed/parsed. The
                underlying lexer/parser error is attached as ``__cause__``.
        """
        self.known_theorems_str: Set[str] = set()
        for axiom_str in axioms:
            try:
                ast = self._parse_formula_string(axiom_str)
                canonical_str = self._ast_to_canonical_string(ast)
                if not canonical_str:
                    raise ValueError("Received empty canonical string for an axiom.")
            except (ValueError, SyntaxError) as e:
                # Chain explicitly so the original error stays in the traceback
                # as the direct cause.
                raise ValueError(f"Initialization error: Invalid axiom '{axiom_str}': {e}") from e
            self.known_theorems_str.add(canonical_str)

    def _parse_formula_string(self, formula_str: str) -> ASTNode:
        """Lex and parse *formula_str* into an AST.

        Raises:
            ValueError: If the string is empty or whitespace only, or the
                lexer rejects a character.
            SyntaxError: If the parser rejects the token sequence.
        """
        if formula_str.strip() == "":
            raise ValueError("Cannot parse an empty formula string.")
        return Parser(Lexer(formula_str)).parse()

    def _ast_to_canonical_string_recursive(self, node: ASTNode, is_operand: bool) -> str:
        """Render *node* as a fully parenthesised string.

        Args:
            node: Sub-tree to render.
            is_operand: True when the node is an operand of an enclosing
                operator; compound nodes are then wrapped in parentheses.

        Raises:
            TypeError: If the node is not a Variable, UnaryOp or BinaryOp.
        """
        if isinstance(node, Variable):
            # Variables never need parentheses.
            return node.name

        if isinstance(node, UnaryOp):
            inner = self._ast_to_canonical_string_recursive(node.operand, True)
            rendered = f"{node.op_token.value} {inner}"
        elif isinstance(node, BinaryOp):
            lhs = self._ast_to_canonical_string_recursive(node.left, True)
            rhs = self._ast_to_canonical_string_recursive(node.right, True)
            rendered = f"{lhs} {node.op_token.value} {rhs}"
        else:
            raise TypeError(f"Unknown AST node type for string conversion: {type(node)}")

        if is_operand:
            return f"({rendered})"
        return rendered

    def _ast_to_canonical_string(self, node: ASTNode) -> str:
        """Return the canonical string of *node*, with no parentheses around the root."""
        as_root = False
        return self._ast_to_canonical_string_recursive(node, as_root)

    def add_theorem_ast(self, formula_ast: ASTNode) -> bool:
        """Record *formula_ast* as known.

        Returns:
            True if the formula was not known before, False otherwise.
        """
        text = self._ast_to_canonical_string(formula_ast)
        if text in self.known_theorems_str:
            return False
        self.known_theorems_str.add(text)
        return True
    
    def _add_and_return_canonical(self, formula_ast: ASTNode) -> str:
        """Store *formula_ast* in the knowledge base and return its canonical string."""
        text = self._ast_to_canonical_string(formula_ast)
        # Adding to a set is a no-op when the formula is already known.
        self.known_theorems_str.add(text)
        return text

    def is_known(self, formula_str: str) -> bool:
        """Return True if *formula_str* parses and its canonical form is known.

        Strings that cannot be parsed are reported as not known instead of
        raising.
        """
        try:
            text = self._ast_to_canonical_string(self._parse_formula_string(formula_str))
        except (ValueError, SyntaxError):
            return False
        return text in self.known_theorems_str

    def get_known_formulas(self) -> Set[str]:
        """Return a new set holding the canonical strings of all known formulas."""
        snapshot = set(self.known_theorems_str)
        return snapshot

    # --- Inference Rules ---

    def apply_modus_ponens(self, premise_p_str: str, premise_p_implies_q_str: str) -> Optional[str]:
        """Apply Modus Ponens: P, (P IMPLIES Q) |- Q.

        The two premises may be supplied in either order. Both must already
        be known to the engine.

        Returns:
            The canonical string of the derived consequent (which is also
            added to the knowledge base), or None if the rule does not apply.
        """
        if not self.is_known(premise_p_str) or not self.is_known(premise_p_implies_q_str):
            return None
        try:
            first = self._parse_formula_string(premise_p_str)
            second = self._parse_formula_string(premise_p_implies_q_str)
        except (ValueError, SyntaxError):
            return None

        canon = self._ast_to_canonical_string
        # Try the documented order first, then the swapped one.
        for implication, fact in ((second, first), (first, second)):
            if not isinstance(implication, BinaryOp):
                continue
            if implication.op_token.type != TokenType.IMPLIES:
                continue
            if canon(implication.left) == canon(fact):
                return self._add_and_return_canonical(implication.right)
        return None

    def apply_modus_tollens(self, premise_p_implies_q_str: str, premise_not_q_str: str) -> Optional[str]:
        """Apply Modus Tollens: (P IMPLIES Q), (NOT Q) |- (NOT P).

        Returns:
            The canonical string of NOT P (also added to the knowledge base),
            or None if a premise is unknown or does not have the required shape.
        """
        if not self.is_known(premise_p_implies_q_str) or not self.is_known(premise_not_q_str):
            return None
        try:
            implication = self._parse_formula_string(premise_p_implies_q_str)
            negation = self._parse_formula_string(premise_not_q_str)
        except (ValueError, SyntaxError):
            return None

        # First premise must be an implication at the top level.
        if not isinstance(implication, BinaryOp):
            return None
        if implication.op_token.type != TokenType.IMPLIES:
            return None

        # Second premise must be a negation at the top level.
        if not isinstance(negation, UnaryOp):
            return None
        if negation.op_token.type != TokenType.NOT:
            return None

        # The negated formula has to be exactly the implication's consequent.
        consequent_text = self._ast_to_canonical_string(implication.right)
        negated_text = self._ast_to_canonical_string(negation.operand)
        if consequent_text != negated_text:
            return None

        negated_antecedent = UnaryOp(Token(TokenType.NOT, "NOT"), implication.left)
        return self._add_and_return_canonical(negated_antecedent)

    def apply_and_introduction(self, premise_p_str: str, premise_q_str: str) -> Optional[str]:
        """Apply And Introduction: P, Q |- (P AND Q).

        Returns:
            The canonical string of the conjunction (also added to the
            knowledge base), or None if either premise is not known.
        """
        if not self.is_known(premise_p_str) or not self.is_known(premise_q_str):
            return None
        try:
            left_ast = self._parse_formula_string(premise_p_str)
            right_ast = self._parse_formula_string(premise_q_str)
        except (ValueError, SyntaxError):
            return None

        and_token = Token(TokenType.AND, "AND")
        return self._add_and_return_canonical(BinaryOp(left_ast, and_token, right_ast))

    def apply_and_elimination1(self, premise_p_and_q_str: str) -> Optional[str]:
        """Apply And Elimination 1: (P AND Q) |- P.

        Returns:
            The canonical string of the left conjunct (also added to the
            knowledge base), or None if the premise is unknown or is not a
            top-level conjunction.
        """
        if not self.is_known(premise_p_and_q_str):
            return None
        try:
            conjunction = self._parse_formula_string(premise_p_and_q_str)
        except (ValueError, SyntaxError):
            return None

        if not isinstance(conjunction, BinaryOp):
            return None
        if conjunction.op_token.type != TokenType.AND:
            return None
        return self._add_and_return_canonical(conjunction.left)

    def apply_and_elimination2(self, premise_p_and_q_str: str) -> Optional[str]:
        """Apply And Elimination 2: (P AND Q) |- Q.

        Returns:
            The canonical string of the right conjunct (also added to the
            knowledge base), or None if the premise is unknown or is not a
            top-level conjunction.
        """
        if not self.is_known(premise_p_and_q_str):
            return None
        try:
            conjunction = self._parse_formula_string(premise_p_and_q_str)
        except (ValueError, SyntaxError):
            return None

        if not isinstance(conjunction, BinaryOp):
            return None
        if conjunction.op_token.type != TokenType.AND:
            return None
        return self._add_and_return_canonical(conjunction.right)

    def apply_double_negation_elimination(self, premise_not_not_p_str: str) -> Optional[str]:
        """Apply Double Negation Elimination: (NOT (NOT P)) |- P.

        Returns:
            The canonical string of P (also added to the knowledge base), or
            None if the premise is unknown or is not a doubly negated formula.
        """
        if not self.is_known(premise_not_not_p_str):
            return None
        try:
            outer = self._parse_formula_string(premise_not_not_p_str)
        except (ValueError, SyntaxError):
            return None

        def is_negation(candidate):
            return isinstance(candidate, UnaryOp) and candidate.op_token.type == TokenType.NOT

        # Both the root and its operand must be NOT nodes.
        if not is_negation(outer):
            return None
        inner = outer.operand
        if not is_negation(inner):
            return None
        return self._add_and_return_canonical(inner.operand)

    def apply_double_negation_introduction(self, p_str: str) -> Optional[str]:
        """Apply Double Negation Introduction: P |- NOT (NOT P).

        Returns:
            The canonical string of NOT (NOT P) (also added to the knowledge
            base), or None if P is not known.
        """
        if not self.is_known(p_str):
            return None
        try:
            wrapped = self._parse_formula_string(p_str)
            # Wrap the formula in NOT twice.
            for _ in range(2):
                wrapped = UnaryOp(Token(TokenType.NOT, "NOT"), wrapped)
            return self._add_and_return_canonical(wrapped)
        except (ValueError, SyntaxError):
            return None

    def _recursive_or_simplification(self, node: ASTNode, part_canonical_str: str) -> Tuple[ASTNode, bool]:
        """Replace disjunctions that contain the known part by that part.

        Walks the tree top-down. Whenever an OR node has an operand whose
        canonical string equals *part_canonical_str*, the whole OR node is
        replaced by that operand (the left operand wins if both match) and
        the replaced node is not descended into further.

        Args:
            node: Root of the sub-tree to process.
            part_canonical_str: Canonical string of the formula known to hold.

        Returns:
            A pair (new_node, changed). When nothing changed, the original
            node object is returned unchanged.
        """
        if isinstance(node, Variable):
            return node, False

        if isinstance(node, UnaryOp):
            rewritten, did_change = self._recursive_or_simplification(node.operand, part_canonical_str)
            if not did_change:
                return node, False
            return UnaryOp(node.op_token, rewritten), True

        if not isinstance(node, BinaryOp):
            # Unknown node kinds are passed through untouched.
            return node, False

        if node.op == TokenType.OR:
            left_text, right_text = (
                self._ast_to_canonical_string(child) for child in (node.left, node.right)
            )
            # 'P OR Q' collapses to whichever side is the known part.
            if left_text == part_canonical_str:
                return node.left, True
            if right_text == part_canonical_str:
                return node.right, True

        # This node was not collapsed, so look inside both children.
        lhs, lhs_changed = self._recursive_or_simplification(node.left, part_canonical_str)
        rhs, rhs_changed = self._recursive_or_simplification(node.right, part_canonical_str)
        if not (lhs_changed or rhs_changed):
            return node, False
        return BinaryOp(lhs, node.op_token, rhs), True


    def apply_or_elimination(self, part_str: str, complex_formula_str: str) -> Optional[str]:
        """Simplify a known formula using a known disjunct.

        This is a custom rule rather than textbook Or Elimination: given a
        known P and a known formula phi(A OR P), it derives phi(P) by
        replacing each matching disjunction with P.

        Args:
            part_str: Formula known to be true (e.g. 'P').
            complex_formula_str: Known formula that may contain a disjunction
                with that part as an operand.

        Returns:
            The canonical string of the simplified formula (also added to the
            knowledge base), or None if a premise is unknown or nothing could
            be simplified.
        """
        if not self.is_known(part_str) or not self.is_known(complex_formula_str):
            return None

        try:
            known_part_text = self._ast_to_canonical_string(self._parse_formula_string(part_str))
            target_ast = self._parse_formula_string(complex_formula_str)
            simplified_ast, was_simplified = self._recursive_or_simplification(
                target_ast, known_part_text
            )
            if was_simplified:
                return self._add_and_return_canonical(simplified_ast)
            # The known part is not an operand of any disjunction in the formula.
            return None
        except (ValueError, SyntaxError):
            return None

    def apply_hypothetical_syllogism(self, premise1_str: str, premise2_str: str) -> Optional[str]:
        """Apply Hypothetical Syllogism: (P IMPLIES Q), (Q IMPLIES R) |- (P IMPLIES R).

        Args:
            premise1_str: The formula P IMPLIES Q.
            premise2_str: The formula Q IMPLIES R.

        Returns:
            The canonical string of P IMPLIES R (also added to the knowledge
            base), or None if a premise is unknown, is not a top-level
            implication, or the middle terms differ.
        """
        if not self.is_known(premise1_str) or not self.is_known(premise2_str):
            return None

        try:
            first = self._parse_formula_string(premise1_str)
            second = self._parse_formula_string(premise2_str)

            # Both premises have to be implications at the top level.
            if not isinstance(first, BinaryOp) or first.op != TokenType.IMPLIES:
                return None
            if not isinstance(second, BinaryOp) or second.op != TokenType.IMPLIES:
                return None

            # The consequent of the first must be the antecedent of the second.
            middle_of_first = self._ast_to_canonical_string(first.right)
            middle_of_second = self._ast_to_canonical_string(second.left)
            if middle_of_first == middle_of_second:
                chained = BinaryOp(first.left, Token(TokenType.IMPLIES, "IMPLIES"), second.right)
                return self._add_and_return_canonical(chained)
            return None
        except (ValueError, SyntaxError):
            return None

# Setup

In [39]:
import gymnasium as gym
from gymnasium import spaces
import numpy as np # For observation space if using numerical representations
from typing import Set, List, Tuple, Dict, Any

## Env

In [40]:
class InferenceRule(Enum):
    """Inference rules an agent can select in the proof-search environment."""
    MODUS_PONENS = auto() # P, P IMPLIES Q |- Q
    MODUS_TOLLENS = auto() # P IMPLIES Q, NOT Q |- NOT P
    AND_INTRODUCTION = auto() # P, Q |- P AND Q
    AND_ELIMINATION1 = auto() # P AND Q |- P
    AND_ELIMINATION2 = auto() # P AND Q |- Q
    DOUBLE_NEGATION_ELIMINATION = auto() # NOT (NOT P) |- P
    DOUBLE_NEGATION_INTRODUCTION = auto() # P |- NOT (NOT P)
    OR_ELIMINATION = auto() # Custom rule: P, phi(A OR P) |- phi(P)
    HYPOTHETICAL_SYLLOGISM = auto() # P IMPLIES Q, Q IMPLIES R |- P IMPLIES R

class ProofSearchEnv(gym.Env):
    metadata = {'render_modes': ['human'], 'render_fps': 4}

    def __init__(self, axioms: Set[str], goal_formula_str: str,
                 rules: List[InferenceRule],
                 max_formulas_in_state: int = 20, # Max formulas agent can directly reference by index
                 max_formula_id_for_obs: int = 100, # Max ID for observation space
                 max_steps: int = 50,
                 render_mode: Optional[str] = None):
        """Build a proof-search environment.

        Args:
            axioms: Formula strings known at the start of every episode.
            goal_formula_str: Formula the agent has to derive.
            rules: Inference rules the agent may choose from; the first
                action component indexes into this list.
            max_formulas_in_state: Number of slots in the list of formulas
                that can be selected as premises by index.
            max_formula_id_for_obs: Size of each observation component, i.e.
                observation IDs lie in [0, max_formula_id_for_obs).
            max_steps: Step count at which an episode is truncated.
            render_mode: None (no rendering) or one of
                ``metadata['render_modes']``. Without this, ``render()``
                could never print because ``render_mode`` stayed unset.

        Raises:
            ValueError: If an axiom or the goal formula is invalid, or
                *render_mode* is not supported.
        """
        super().__init__()

        if render_mode is not None and render_mode not in self.metadata['render_modes']:
            raise ValueError(
                f"Unsupported render_mode '{render_mode}'; "
                f"expected None or one of {self.metadata['render_modes']}"
            )
        self.render_mode = render_mode

        self.initial_axioms = axioms.copy()
        self.goal_formula_str = goal_formula_str
        self.max_formulas_in_actionable_list = max_formulas_in_state
        self.max_steps = max_steps
        self.available_rules = rules

        self.engine = LogicsEngine(self.initial_axioms)
        try:
            self.goal_ast = self.engine._parse_formula_string(self.goal_formula_str)
            self.goal_canonical_str = self.engine._ast_to_canonical_string(self.goal_ast)
        except (ValueError, SyntaxError) as e:
            raise ValueError(f"Goal formula '{goal_formula_str}' is invalid: {e}") from e

        # For mapping formula strings to unique integer IDs for the observation vector
        self.formula_to_observation_id: Dict[str, int] = {}
        self.next_observation_id_counter: int = 1 # 0 is reserved for "empty/padding"

        # List of formula strings that the agent can select as premises by their index
        self.actionable_formulas_list: List[Optional[str]] = [None] * self.max_formulas_in_actionable_list
        self.next_actionable_slot_idx: int = 0

        # Must exist before the vocabulary is initialised: ID assignment
        # reads the per-slot ID limit from observation_space.nvec.
        self.observation_space = spaces.MultiDiscrete([max_formula_id_for_obs] * self.max_formulas_in_actionable_list)

        self._initialize_actionable_formulas_and_vocab()

        # Action: (rule_index, index_premise1, index_premise2)
        self.action_space = spaces.Tuple((
            spaces.Discrete(len(self.available_rules)),
            spaces.Discrete(self.max_formulas_in_actionable_list),
            spaces.Discrete(self.max_formulas_in_actionable_list)
        ))

        self.current_step = 0

    def _get_or_assign_observation_id(self, formula_str: str) -> int:
        """Return the observation ID of *formula_str*, assigning one if needed.

        Once all IDs below the observation-space limit are used up, new
        formulas are not registered; they all get the fallback ID
        ``limit - 1``, so distinct formulas can then share an ID.
        """
        existing_id = self.formula_to_observation_id.get(formula_str)
        if existing_id is not None:
            return existing_id

        id_limit = self.observation_space.nvec[0]
        if self.next_observation_id_counter >= id_limit:
            # Vocabulary exhausted: report the shared fallback ID without storing it.
            return id_limit - 1

        fresh_id = self.next_observation_id_counter
        self.formula_to_observation_id[formula_str] = fresh_id
        self.next_observation_id_counter = fresh_id + 1
        return fresh_id

    def _ensure_formula_in_actionable_list(self, formula_str: str):
        """Put *formula_str* into the next free premise slot, if there is one.

        A formula that is already listed is left alone. When every slot is
        taken the formula is not listed (the engine still knows it), but it
        still receives an observation ID.
        """
        if formula_str in self.actionable_formulas_list:
            return

        free_slot = self.next_actionable_slot_idx
        if free_slot < self.max_formulas_in_actionable_list:
            self.actionable_formulas_list[free_slot] = formula_str
            self.next_actionable_slot_idx = free_slot + 1

        # Register the formula in the observation vocabulary either way.
        self._get_or_assign_observation_id(formula_str)


    def _initialize_actionable_formulas_and_vocab(self):
        """Reset the premise slots and the observation vocabulary.

        Called by __init__ and reset, after ``self.engine`` has been built
        from the same axioms (so every axiom is known to parse).

        Axioms are stored in canonical form and in sorted order:
          * canonical form, so they use the same representation as the goal
            and as every formula the engine derives later;
          * sorted order, because iterating the axiom set directly gives an
            arbitrary order, which would make slot indices and observation
            IDs differ between runs.
        """
        self.formula_to_observation_id.clear()
        self.next_observation_id_counter = 1
        self.actionable_formulas_list = [None] * self.max_formulas_in_actionable_list
        self.next_actionable_slot_idx = 0

        canonical_axioms = {
            self.engine._ast_to_canonical_string(self.engine._parse_formula_string(axiom_str))
            for axiom_str in self.initial_axioms
        }

        # Add initial axioms to actionable list and observation vocabulary
        for axiom_canonical in sorted(canonical_axioms):
            self._ensure_formula_in_actionable_list(axiom_canonical)

        # Ensure goal formula is in observation vocabulary (even if not actionable initially)
        if self.goal_canonical_str:
            self._get_or_assign_observation_id(self.goal_canonical_str)


    def _get_observation(self) -> np.ndarray:
        """Return one observation ID per premise slot; empty slots map to 0."""
        slot_ids = [
            self._get_or_assign_observation_id(entry) if entry else 0
            for entry in self.actionable_formulas_list
        ]
        return np.array(slot_ids, dtype=int)

    def reset(self, seed=None, options=None) -> Tuple[np.ndarray, Dict]:
        """Start a new episode.

        Rebuilds the engine from the initial axioms, clears the step counter
        and re-initialises the premise slots. *options* is accepted but not used.

        Returns:
            The initial observation and an empty info dict.
        """
        super().reset(seed=seed)
        # A fresh engine forgets everything derived in the previous episode.
        self.engine = LogicsEngine(self.initial_axioms)
        self.current_step = 0
        self._initialize_actionable_formulas_and_vocab()
        first_observation = self._get_observation()
        info: Dict = {}
        return first_observation, info

    def step(self, action: Tuple[int, int, int]) -> Tuple[np.ndarray, float, bool, bool, Dict]:
        """Apply one inference rule to the selected premise slots.

        Args:
            action: (rule_index, premise1_slot, premise2_slot). Unary rules
                only use the first premise slot.

        Rewards:
            64.0  the derived formula is the goal (episode terminates)
             0.8  a formula new to the engine was derived
            -0.2  the derived formula was already known
            -0.3  AND introduction with the same premise twice
            -1.0  a required premise slot was empty
            -0.1  otherwise (e.g. the rule did not apply to the premises)

        Returns:
            (observation, reward, terminated, truncated, info), where the
            episode is truncated once max_steps steps have been taken.
        """
        self.current_step += 1
        rule_idx, premise1_idx, premise2_idx = action
        selected_rule = self.available_rules[rule_idx]

        reward = -0.1  # Default penalty per step
        terminated = False
        newly_derived_formula_str: Optional[str] = None

        premise1_str = self.actionable_formulas_list[premise1_idx]
        premise2_str = self.actionable_formulas_list[premise2_idx]

        known_count_before = len(self.engine.get_known_formulas())
        derived_str: Optional[str] = None

        has_first_premise = premise1_str is not None
        has_both_premises = has_first_premise and premise2_str is not None

        # Rules grouped by how many premises they take.
        binary_rules = {
            InferenceRule.MODUS_PONENS: self.engine.apply_modus_ponens,
            InferenceRule.MODUS_TOLLENS: self.engine.apply_modus_tollens,
            InferenceRule.AND_INTRODUCTION: self.engine.apply_and_introduction,
            InferenceRule.OR_ELIMINATION: self.engine.apply_or_elimination,
            InferenceRule.HYPOTHETICAL_SYLLOGISM: self.engine.apply_hypothetical_syllogism,
        }
        unary_rules = {
            InferenceRule.AND_ELIMINATION1: self.engine.apply_and_elimination1,
            InferenceRule.AND_ELIMINATION2: self.engine.apply_and_elimination2,
            InferenceRule.DOUBLE_NEGATION_ELIMINATION: self.engine.apply_double_negation_elimination,
            InferenceRule.DOUBLE_NEGATION_INTRODUCTION: self.engine.apply_double_negation_introduction,
        }

        if selected_rule in binary_rules:
            if has_both_premises:
                same_premise_twice = premise1_str == premise2_str
                if selected_rule == InferenceRule.AND_INTRODUCTION and same_premise_twice:
                    reward = -0.3  # Penalize redundant P AND P
                else:
                    derived_str = binary_rules[selected_rule](premise1_str, premise2_str)
        elif selected_rule in unary_rules:
            if has_first_premise:
                derived_str = unary_rules[selected_rule](premise1_str)
        # NOTE(review): the two branches below only run for a rule outside
        # both tables; for the nine rules above they are never reached.
        elif has_first_premise:
            reward = -0.5
        else:
            reward = -2.0

        if derived_str:
            newly_derived_formula_str = derived_str

            # The knowledge base only grows when the formula is really new.
            if len(self.engine.get_known_formulas()) > known_count_before:
                reward = 0.8
                self._ensure_formula_in_actionable_list(derived_str)
            else:
                reward = -0.2

            if derived_str == self.goal_canonical_str:
                reward = 64.0
                terminated = True
        elif not has_first_premise and selected_rule in unary_rules:
            reward = -1.0  # Unary rule pointed at an empty slot
        elif not has_both_premises and selected_rule in binary_rules:
            reward = -1.0  # Binary rule with at least one empty slot

        truncated = False
        if self.current_step >= self.max_steps:
            truncated = True

        info = {
            "newly_derived_formula": newly_derived_formula_str,
            "rule_used": selected_rule.name,
            "action_indices": (premise1_idx, premise2_idx)
        }
        return self._get_observation(), reward, terminated, truncated, info

    def render(self):
        """Print the current state to stdout when render_mode is 'human'.

        Shows the step counter, each premise slot with its observation ID,
        how many formulas the engine knows, the goal, and whether the goal
        is already known. Does nothing for any other render mode.
        """
        if self.render_mode != 'human':
            return

        print(f"--- Step: {self.current_step} ---")
        print("Actionable Formulas (agent's view for premise selection):")
        for slot, entry in enumerate(self.actionable_formulas_list):
            if not entry:
                print(f"  idx {slot}: <Empty>")
                continue
            obs_id = self.formula_to_observation_id.get(entry, "N/A")
            print(f"  idx {slot}: {entry} (obs_id: {obs_id})")

        known_total = len(self.engine.get_known_formulas())
        print(f"\nEngine's Known Formulas ({known_total}):")

        print(f"Goal: {self.goal_canonical_str}")
        if self.engine.is_known(self.goal_canonical_str):
            print(">>> Goal has been reached! <<<")
        print("-" * 20)

    def close(self):
        """Release resources; this environment holds none, so nothing happens."""
        return None

## Agent

### Actor-Critic Network

In [41]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ActorCriticNetwork(nn.Module):
    """Shared-trunk network with three policy heads and one value head.

    The trunk (``body``) is two ``Linear -> ReLU -> Dropout -> LayerNorm``
    stages of width ``hidden_dim``. On top of it sit linear heads producing
    logits over ``rule_dim`` rules, logits over ``num_actions`` choices for
    each of the two action slots, and a single scalar state value.
    """

    def __init__(self, input_dim, rule_dim, num_actions, hidden_dim=256, dropout=0.2):
        super().__init__()

        # Both stages share the same layout; only the input width differs.
        trunk_layers = []
        for in_features in (input_dim, hidden_dim):
            trunk_layers += [
                nn.Linear(in_features, hidden_dim),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.LayerNorm(hidden_dim),
            ]
        self.body = nn.Sequential(*trunk_layers)

        # Actor heads: one categorical distribution per action component.
        self.rule_head = nn.Linear(hidden_dim, rule_dim)
        self.action1_head = nn.Linear(hidden_dim, num_actions)
        self.action2_head = nn.Linear(hidden_dim, num_actions)

        # Critic head: scalar state-value estimate.
        self.value_head = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        """Return ``(rule_logits, action1_logits, action2_logits, value)`` for ``x``."""
        shared = self.body(x)
        return (
            self.rule_head(shared),
            self.action1_head(shared),
            self.action2_head(shared),
            self.value_head(shared),
        )


### Replay Buffer

In [42]:
Experience = namedtuple("Experience", field_names=["state", "action", "reward", "next_state", "done"])


class ReplayBuffer:
    """Bounded FIFO store of transitions with uniform random mini-batch sampling.

    Once ``buffer_size`` transitions are stored, appending a new one drops
    the oldest (``deque(maxlen=...)`` semantics).
    """

    def __init__(self, buffer_size, batch_size, device="cpu"):
        self.batch_size = batch_size
        self.device = device
        self.memory = deque(maxlen=buffer_size)

    def add(self, state, action, reward, next_state, done):
        """Append one transition to the buffer."""
        transition = Experience(
            state=state,
            action=action,
            reward=reward,
            next_state=next_state,
            done=done,
        )
        self.memory.append(transition)

    def sample(self):
        """Draw ``batch_size`` transitions without replacement.

        Returns a tuple ``(states, actions, rewards, next_states, dones)`` of
        tensors on ``self.device``. Each field is stacked row-wise with
        ``np.vstack``; ``actions`` is cast to long, everything else to float
        (``dones`` goes through ``uint8`` first).
        """
        batch = random.sample(self.memory, k=self.batch_size)

        def stacked(field):
            # One row per sampled transition.
            return np.vstack([getattr(item, field) for item in batch])

        def as_float(array):
            return torch.from_numpy(array).float().to(self.device)

        states = as_float(stacked("state"))
        actions = torch.from_numpy(stacked("action")).long().to(self.device)
        rewards = as_float(stacked("reward"))
        next_states = as_float(stacked("next_state"))
        dones = as_float(stacked("done").astype(np.uint8))
        return (states, actions, rewards, next_states, dones)

    def __len__(self):
        return len(self.memory)

### Actor-Critic Agent

In [71]:
from torch.distributions import Categorical

class ActorCriticAgent:
    """One-step actor-critic agent with a factored three-part policy.

    An action is a tuple ``(rule, action1, action2)``. Each component has its
    own categorical distribution produced by a shared ``ActorCriticNetwork``;
    the joint log-probability is the sum of the three component
    log-probabilities. Training uses transitions sampled from ``buffer``.
    """

    class Mode(Enum):
        TRAIN = 'train'
        EVAL = 'eval'

    def __init__(self,
        state_size,
        action_size, # num_actions for action1 and action2
        rule_size,
        hidden_size,
        buffer, # ReplayBuffer instance
        gamma=0.99, 
        lr=1e-3, 
        entropy_coeff=0.01,
        device=None
    ):
        """Build the network and optimizer and put the agent in TRAIN mode.

        Args:
            state_size: Length of the observation vector fed to the network.
            action_size: Number of choices for each of action1 and action2.
            rule_size: Number of choices for the rule component.
            hidden_size: Width of the network's hidden layers.
            buffer: Replay buffer; ``learn`` uses its ``batch_size``,
                ``__len__`` and ``sample()``.
            gamma: Discount factor for the TD target.
            lr: Adam learning rate.
            entropy_coeff: Weight of the entropy bonus in the actor loss.
            device: Torch device; defaults to CUDA when available, else CPU.
        """
        self.state_size = state_size
        self.action_size = action_size
        self.rule_size = rule_size
        self.gamma = gamma
        self.buffer = buffer
        self.entropy_coeff = entropy_coeff
        
        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = device

        self._mode = self.Mode.TRAIN
        
        self.network = ActorCriticNetwork(state_size, rule_size, action_size, hidden_size).to(self.device)
        self.optimizer = torch.optim.Adam(self.network.parameters(), lr=lr)

        self.train()

    @property
    def mode(self):
        """Current ``Mode`` (TRAIN or EVAL)."""
        return self._mode

    def eval(self):
        """Switch the agent and its network to evaluation mode."""
        self._mode = self.Mode.EVAL
        self.network.eval()

    def train(self):
        """Switch the agent and its network to training mode."""
        self._mode = self.Mode.TRAIN
        self.network.train()

    def act(self, state, deterministic=None):
        """
        Choose actions based on the current state.

        Args:
            state: NumPy array holding a single observation (no batch axis).
            deterministic: If True, selects the most likely actions.
                           If False, samples from the distribution.
                           If None, uses self.mode (deterministic for EVAL).

        Returns:
            Tuple ``(rule, action1, action2)`` of Python ints.
        """
        if deterministic is None:
            deterministic = (self.mode == self.Mode.EVAL)

        state_tensor = torch.from_numpy(state).float().unsqueeze(0).to(self.device)

        # No gradients are needed here: learn() recomputes log-probabilities
        # from the sampled batch.
        with torch.no_grad():
            rule_logits, a1_logits, a2_logits, _ = self.network(state_tensor)

        if deterministic:
            # argmax over logits picks the same index as argmax over softmax
            # probabilities, without the extra normalisation.
            rule_action = rule_logits.argmax(dim=-1)
            a1_action = a1_logits.argmax(dim=-1)
            a2_action = a2_logits.argmax(dim=-1)
        else:
            rule_action = Categorical(logits=rule_logits).sample()
            a1_action = Categorical(logits=a1_logits).sample()
            a2_action = Categorical(logits=a2_logits).sample()

        return (rule_action.item(), a1_action.item(), a2_action.item())

    def learn(self):
        """Run one gradient step on a mini-batch sampled from the buffer.

        Returns:
            The scalar total loss (actor + critic) as a float, or ``None``
            when the buffer holds fewer than ``batch_size`` transitions.
        """
        if len(self.buffer) < self.buffer.batch_size:
            return None

        # The buffer may live on a different device than the network, so
        # move every sampled tensor explicitly.
        states, actions, rewards, next_states, dones = (
            t.to(self.device) for t in self.buffer.sample()
        )
        # actions has shape (batch_size, 3): columns are rule, action1, action2.

        # Work with flat (batch,) vectors throughout. Mixing a (batch, 1)
        # advantage with (batch,) log-probabilities would silently broadcast
        # to a (batch, batch) matrix and pair each sample's log-probability
        # with every other sample's advantage.
        rewards = rewards.reshape(-1)
        dones = dones.reshape(-1)

        rule_logits, a1_logits, a2_logits, state_values = self.network(states)
        state_values = state_values.squeeze(-1)

        # Critic target: one-step TD, with no bootstrapping past terminal states.
        with torch.no_grad():
            _, _, _, next_state_values = self.network(next_states)
            next_state_values = next_state_values.squeeze(-1)
            td_targets = rewards + self.gamma * next_state_values * (1 - dones)

        advantage = td_targets - state_values
        critic_loss = advantage.pow(2).mean()

        # Actor loss: policy gradient weighted by the (detached) advantage.
        rule_dist = Categorical(logits=rule_logits)
        a1_dist = Categorical(logits=a1_logits)
        a2_dist = Categorical(logits=a2_logits)

        total_log_probs = (
            rule_dist.log_prob(actions[:, 0])
            + a1_dist.log_prob(actions[:, 1])
            + a2_dist.log_prob(actions[:, 2])
        )
        actor_loss = -(total_log_probs * advantage.detach()).mean()

        # Entropy bonus for encouraging exploration.
        total_entropy = (
            rule_dist.entropy().mean()
            + a1_dist.entropy().mean()
            + a2_dist.entropy().mean()
        )
        actor_loss = actor_loss - self.entropy_coeff * total_entropy

        total_loss = actor_loss + critic_loss

        self.optimizer.zero_grad()
        total_loss.backward()
        self.optimizer.step()

        return total_loss.item()

# Main Loop

In [72]:
# Inference rules made available to the environment. The list order is kept
# fixed because its length is used as the size of the agent's rule head.
active_rules = [
    InferenceRule.MODUS_PONENS,
    InferenceRule.MODUS_TOLLENS,
    InferenceRule.AND_INTRODUCTION,
    InferenceRule.AND_ELIMINATION1,
    InferenceRule.AND_ELIMINATION2,
    InferenceRule.DOUBLE_NEGATION_ELIMINATION,
    InferenceRule.DOUBLE_NEGATION_INTRODUCTION,
    InferenceRule.OR_ELIMINATION,
    InferenceRule.HYPOTHETICAL_SYLLOGISM,
]

# Starting formulas and the formula to be derived from them.
axioms = {
    "A AND B IMPLIES C",
    "E OR F IMPLIES B",
    # "P OR Q IMPLIES F",  (disabled axiom)
    "A",
    "E",
}
goal = "C"

env = ProofSearchEnv(
    axioms=axioms,
    goal_formula_str=goal,
    rules=active_rules,
    max_formulas_in_state=8,
    max_steps=128,
)

In [73]:
# Observation length and per-slot action count are both taken from the size
# of the environment's actionable-formula list.
agent = ActorCriticAgent(
    state_size=env.max_formulas_in_actionable_list,
    action_size=env.max_formulas_in_actionable_list,
    rule_size=len(active_rules),
    hidden_size=72,
    buffer=ReplayBuffer(buffer_size=1000, batch_size=64),
    gamma=0.9,
    lr=3e-4,
    entropy_coeff=0.01,
)

In [74]:
# Count every scalar element across all network parameter tensors.
total_params = sum(map(lambda p: p.numel(), agent.network.parameters()))
print(f"Total parameters: {total_params:,}")

Total parameters: 8,090


In [75]:
agent.train()
episodes = 256
for episode in range(episodes):
    state, info = env.reset()
    steps = 0
    terminated = truncated = False
    # An episode ends when the env reports termination or truncation.
    while not (terminated or truncated):
        action_tuple = agent.act(state)
        next_state, reward, terminated, truncated, info = env.step(action_tuple)
        # Only `terminated` is stored as the done flag, so a truncated
        # episode still bootstraps from its last next_state.
        agent.buffer.add(state, action_tuple, reward, next_state, terminated)
        agent.learn()  # one gradient step per environment step
        state = next_state
        steps += 1
    print(f"ep: {episode+1:>3}/{episodes} | steps: {steps:<3} | {'success' if terminated else 'failure'}")

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

# Results observation

In [None]:
# Roll out a single episode in EVAL mode (act() then picks the most likely
# action for each component) and log every step.
agent.eval()
state, _ = env.reset()
print("axioms", env.engine.get_known_formulas(), end="\n\n")
terminated = truncated = False
while not (terminated or truncated):
    action_tuple = agent.act(state)
    next_state, reward, terminated, truncated, info = env.step(action_tuple)
    print("env.formulas", env.engine.get_known_formulas())
    # NOTE(review): premises are looked up after env.step(); this assumes the
    # step does not move the formulas at the chosen indices - confirm.
    print(f"reward {reward}, rule {info['rule_used']} on [{env.actionable_formulas_list[action_tuple[1]]}] [{env.actionable_formulas_list[action_tuple[2]]}] returns [{info['newly_derived_formula']}]")
    print()
    state = next_state

axioms {'DBAccess', '(DBAccess OR ContactEstablished) IMPLIES MissionPossible'}

env.formulas {'DBAccess', '(DBAccess OR ContactEstablished) IMPLIES MissionPossible', 'DBAccess IMPLIES MissionPossible'}
reward 0.8, rule OR_ELIMINATION on [DBAccess] [DBAccess OR ContactEstablished IMPLIES MissionPossible] returns [DBAccess IMPLIES MissionPossible]

env.formulas {'DBAccess', '(DBAccess OR ContactEstablished) IMPLIES MissionPossible', 'MissionPossible', 'DBAccess IMPLIES MissionPossible'}
reward 64.0, rule MODUS_PONENS on [DBAccess IMPLIES MissionPossible] [DBAccess] returns [MissionPossible]

