In [18]:
from dataclasses import dataclass, field
from enum import Enum
from typing import List, Union, Dict, Tuple

# Parameters
# Default puzzle file: PuzzleParser() opens and parses this path when it is constructed.
filepath = "Caves/easy/path_e1.txt"

# <editor-fold desc="Semantics and structure for FOL">
class LogicOperator(Enum):
    """
    Operator and punctuation tokens accepted in a LogicExpression.

    Each member's value is an ``(id, operand_count)`` pair:
      * ``id`` is a unique integer identifying the token;
      * ``operand_count`` is 2 for binary connectives, 1 for NOT, and 0 for
        grouping/punctuation tokens and the quantifiers FORALL/EXISTS, whose
        argument lists are comma-separated and variable-length.
    """
    # binary / unary connectives
    OR = (1, 2)
    AND = (2, 2)
    NOT = (3, 1)
    # grouping and list punctuation
    LPAREN = (4, 0)
    RPAREN = (5, 0)
    COMMA = (6, 0)
    # quantifiers: FORALL(...) and EXISTS(...)
    FORALL = (7, 0)
    EXISTS = (8, 0)
    # derived binary connectives
    IMPLIES = (9, 2)
    IFF = (10, 2)

class LogicTerminal(Enum):
    """Three-valued truth constants: unknown, false and true."""
    # not yet determined; every variable starts here
    U = -1
    # definitely false
    F = 0
    # definitely true
    T = 1

class LogicExpression:
    """
    An ordered token list for a logic formula, together with its variable table.

    Tokens are LogicOperator members, LogicTerminal constants, or ``str`` variable
    names. Every variable starts out as LogicTerminal.U (unknown).

    The object doubles as a cursor over its tokens: ``iter_count`` is the index of
    the next token, ``peek()`` returns that token without consuming it, and
    ``next()`` consumes it. ``iter()`` rewinds the cursor and returns the object
    itself, so only one traversal can be active at a time.
    """

    def __init__(self, expression: Union[List[Union[LogicOperator, LogicTerminal, str]], None] = None):
        # variable name -> current truth value
        self.variables = {}
        # ordered token list (see set_expression)
        self.expression = []
        # cursor: index of the token returned by the next peek()/next()
        self.iter_count = 0
        if expression:
            self.set_expression(expression)

    def __str__(self):
        """Space-separated tokens: variable names verbatim, operators/constants by member name."""
        return " ".join(token if isinstance(token, str) else token.name for token in self.expression)

    def __next__(self):
        """
        Consume and return the token under the cursor.

        :raises StopIteration: once every token has been consumed.
        """
        if self.iter_count < len(self.expression):
            value = self.expression[self.iter_count]
            self.iter_count += 1
            return value
        raise StopIteration

    def __iter__(self):
        """Rewind the cursor and return self (the expression is its own iterator)."""
        self.iter_count = 0
        return self

    def peek(self) -> Union[LogicOperator, LogicTerminal, str, None]:
        """Return the token under the cursor without consuming it, or None at the end."""
        if self.iter_count < len(self.expression):
            return self.expression[self.iter_count]
        return None

    def set_expression(self, expression: List[Union[LogicOperator, LogicTerminal, str]]):
        """
        Replace the stored expression.

        :param expression: ordered sequence of tokens, each one of:
            -- a LogicOperator member,
            -- a LogicTerminal constant,
            -- a ``str``, naming a variable of the formula.
        Every variable found is registered with value LogicTerminal.U. Variables of a
        previously stored expression are discarded, and the cursor is rewound.
        """
        # Copy so later mutation of the caller's list cannot desynchronise the token
        # list from the variable table built below.
        tokens = list(expression)
        # Rebuild the table from scratch; otherwise variables of an earlier expression
        # would linger in get_variables().
        self.variables = {token: LogicTerminal.U for token in tokens if isinstance(token, str)}
        self.expression = tokens
        self.iter_count = 0

    def set_variable(self, variable, value) -> bool:
        """
        Assign a truth value to a variable of the current expression.

        :param variable: variable name appearing in the expression.
        :param value: a LogicTerminal; a plain ``bool`` is converted to T / F.
        :return: True if the variable exists and was set; False (nothing changed) otherwise.
        """
        if variable not in self.variables:
            return False
        if isinstance(value, bool):
            # The evaluator recognises only LogicTerminal members: a raw False would
            # be treated as neither F nor U and an AND containing it would yield T.
            value = LogicTerminal.T if value else LogicTerminal.F
        self.variables[variable] = value
        return True

    def get_variable(self, variable):
        """
        Return the current value of ``variable``.

        :raises KeyError: if the variable is not part of the expression.
        """
        return self.variables[variable]

    def get_variables(self):
        """Return the live variable table (not a copy); callers mutating it change this expression."""
        return self.variables

@dataclass
class LogicFunction:
    """
    A named function application whose arguments are variable names or nested
    LogicFunction instances, e.g. ``f(x, g(y))``.
    """
    name: str
    args: List[Union[str, 'LogicFunction']]

    def __str__(self):
        """Render as ``name(arg1, arg2, ...)``, recursing into nested functions."""
        # str() on a plain string returns it unchanged, so one expression covers both
        # argument kinds. The separator is built outside the f-string: re-using the
        # outer quote inside an f-string is a SyntaxError before Python 3.12.
        joined = ', '.join(str(arg) for arg in self.args)
        return f"{self.name}({joined})"

@dataclass
class ParseNode:
    """
    One node of the tree produced by ExpressionParser.

    ``nodeType`` is the token the node stands for: a LogicOperator for an operator
    node, or a LogicTerminal constant / ``str`` variable name for a leaf.
    ``children`` holds the operand sub-trees in left-to-right order.
    """
    # Operator, constant or variable name; None only for a default-constructed node.
    nodeType: Union['LogicOperator', 'LogicTerminal', str, None] = None
    # Operand sub-trees; empty for leaves. default_factory gives each node its own list.
    children: List['ParseNode'] = field(default_factory=list)

class ExpressionParser:
    """
    Recursive-descent parser that turns a LogicExpression token stream into a ParseNode tree.

    Grammar, loosest to tightest binding:
        iff     := implies (IFF implies)*          left-associative
        implies := or (IMPLIES implies)?           right-associative
        or      := and (OR and)*                   left-associative
        and     := not (AND not)*                  left-associative
        not     := NOT not | atom
        atom    := variable | constant | ( iff )
                 | FORALL ( iff (, iff)* ) | EXISTS ( iff (, iff)* )

    IMPLIES and IFF are rewritten on the fly into NOT/OR/AND. FORALL and EXISTS
    become AND and OR over their arguments. The resulting tree therefore contains
    only NOT, AND and OR operator nodes plus variable and constant leaves.
    """

    def __init__(self):
        # The LogicExpression being parsed; its own cursor (iter_count / peek / next)
        # is advanced in place while parsing.
        self.expression = None

    def parse_expression(self, expression: LogicExpression) -> ParseNode:
        """
        Parse every token of ``expression`` into a single tree.

        :param expression: token stream to parse; its cursor is rewound to the start.
        :return: root of the tree (NOT/AND/OR nodes over variables and constants).
        :raises ValueError: on malformed input or tokens left over after a complete formula.
        """
        self.expression = expression
        self.expression.iter_count = 0
        node = self._parse_iff()
        if self.expression.peek() is not None:
            raise ValueError(f"Expression not empty after parse: {self.expression.peek()}")
        return node

    def _parse_iff(self) -> ParseNode:
        """iff := implies (IFF implies)*"""
        node = self._parse_implies()
        while self.expression.peek() == LogicOperator.IFF:
            next(self.expression)
            rhs = self._parse_implies()
            node = self._reduce_iff(node, rhs)
        return node

    def _parse_implies(self) -> ParseNode:
        """implies := or (IMPLIES implies)?  -- recursion on the right gives right-associativity."""
        left = self._parse_or()
        if self.expression.peek() == LogicOperator.IMPLIES:
            next(self.expression)
            right = self._parse_implies()
            return self._reduce_implies(left, right)
        return left

    def _parse_or(self) -> ParseNode:
        """or := and (OR and)*  -- built as nested binary OR nodes."""
        node = self._parse_and()
        while self.expression.peek() == LogicOperator.OR:
            next(self.expression)
            rhs = self._parse_and()
            node = ParseNode(nodeType=LogicOperator.OR, children=[node, rhs])
        return node

    def _parse_and(self) -> ParseNode:
        """and := not (AND not)*  -- built as nested binary AND nodes."""
        node = self._parse_not()
        while self.expression.peek() == LogicOperator.AND:
            next(self.expression)
            rhs = self._parse_not()
            node = ParseNode(nodeType=LogicOperator.AND, children=[node, rhs])
        return node

    def _parse_not(self) -> ParseNode:
        """not := NOT not | atom"""
        if self.expression.peek() == LogicOperator.NOT:
            next(self.expression)
            child = self._parse_not()
            return ParseNode(nodeType=LogicOperator.NOT, children=[child])
        return self._parse_atom()

    def _parse_atom(self) -> ParseNode:
        """
        Parse a variable, a constant, a parenthesised formula or a quantifier list.

        :raises ValueError: at end of input, on a missing ')' or on any other token.
        """
        tok = self.expression.peek()
        if tok is None:
            raise ValueError("Unexpected end of expression")

        # FORALL(...) -> AND(children), EXISTS(...) -> OR(children)
        if tok == LogicOperator.FORALL:
            return self._parse_multivalued(operator=LogicOperator.AND)
        if tok == LogicOperator.EXISTS:
            return self._parse_multivalued(operator=LogicOperator.OR)

        # variable names and constants become leaves
        if isinstance(tok, (str, LogicTerminal)):
            next(self.expression)
            return ParseNode(nodeType=tok)

        # parenthesised sub-expression
        if tok == LogicOperator.LPAREN:
            next(self.expression)  # '('
            node = self._parse_iff()
            if self.expression.peek() != LogicOperator.RPAREN:
                raise ValueError("Expected ')'")
            next(self.expression)  # ')'
            return node

        raise ValueError(f"Unexpected value: {tok}")

    def _parse_multivalued(self, operator: LogicOperator) -> ParseNode:
        """
        Parse FORALL(...) or EXISTS(...) as a comma-separated list of full formulas.

        :param operator: AND for FORALL, OR for EXISTS.
        :return: one ``operator`` node over all list items, or the item itself when
            the list holds a single formula.
        :raises ValueError: if the '(' or ')' around the list is missing.
        """
        next(self.expression)  # the quantifier token

        if self.expression.peek() != LogicOperator.LPAREN:
            raise ValueError("Expected '(' after n-ary operator")
        next(self.expression)  # '('

        children: List[ParseNode] = [self._parse_iff()]
        while self.expression.peek() == LogicOperator.COMMA:
            next(self.expression)  # ','
            children.append(self._parse_iff())

        if self.expression.peek() != LogicOperator.RPAREN:
            raise ValueError("Expected ')' to close n-ary list")
        next(self.expression)  # ')'

        # A one-item conjunction/disjunction is just that item. Wrapping it in a
        # one-child AND/OR node would be rejected by ExpressionEvaluator, whose AND/OR
        # require at least two operands.
        if len(children) == 1:
            return children[0]
        return ParseNode(nodeType=operator, children=children)

    def _reduce_implies(self, p: ParseNode, q: ParseNode) -> ParseNode:
        """Rewrite p -> q as (NOT p) OR q."""
        not_p = ParseNode(nodeType=LogicOperator.NOT, children=[p])
        return ParseNode(nodeType=LogicOperator.OR, children=[not_p, q])

    def _reduce_iff(self, p: ParseNode, q: ParseNode) -> ParseNode:
        """Rewrite p <-> q as (p -> q) AND (q -> p); the p and q subtrees are shared, not copied."""
        return ParseNode(nodeType=LogicOperator.AND,
                         children=[self._reduce_implies(p, q), self._reduce_implies(q, p)])

    @staticmethod
    def pretty_print(node: ParseNode, indent: str = "", is_last: bool = True):
        """
        Print the tree rooted at ``node`` to stdout using box-drawing connectors.

        :param indent: prefix carried down from ancestors.
        :param is_last: whether ``node`` is the last child of its parent (chooses the connector).
        """
        branch = "└── " if is_last else "├── "
        next_indent = indent + ("    " if is_last else "│   ")

        # Enum tokens are shown by member name; variable names (and None) via str().
        node_type = node.nodeType
        label = node_type.name if isinstance(node_type, (LogicOperator, LogicTerminal)) else str(node_type)

        print(indent + branch + label)

        last_index = len(node.children) - 1
        for i, child in enumerate(node.children):
            ExpressionParser.pretty_print(child, next_indent, i == last_index)

class ExpressionEvaluator:
    """
    Parses a LogicExpression and evaluates it under strong Kleene three-valued logic.

    The tree is built and evaluated once, at construction, into ``self.evaluation``.
    ``self.environment`` is the expression's own variable dict, so values assigned
    later via set_variable are seen by subsequent ``eval_tree`` calls.
    """
    # Lattice order: F < U < T  (AND picks the lowest operand, OR the highest)

    def __init__(self, expression: LogicExpression):
        self.expression: LogicExpression = expression
        self.environment: dict = expression.get_variables()
        tree = ExpressionParser().parse_expression(expression)
        self.parsed_expression: ParseNode = tree
        self.evaluation: LogicTerminal = self.eval_tree(tree)

    @staticmethod
    def _eval_not(arguments: List[LogicTerminal]) -> LogicTerminal:
        """
        Negate a single operand: T <-> F, anything else gives U.

        :param arguments: must contain exactly one value.
        :raises ValueError: on any other argument count.
        """
        count = len(arguments)
        if count != 1:
            raise ValueError(f"Expected 1 argument, got {count}")
        (operand,) = arguments
        if operand is LogicTerminal.T:
            return LogicTerminal.F
        if operand is LogicTerminal.F:
            return LogicTerminal.T
        return LogicTerminal.U

    @staticmethod
    def _eval_and(arguments: List[LogicTerminal]) -> LogicTerminal:
        """
        Conjunction: any F gives F; otherwise any U gives U; otherwise T.

        :param arguments: must contain at least two values.
        :raises ValueError: on fewer than two arguments.
        """
        count = len(arguments)
        if count < 2:
            raise ValueError(f"Expected at least 2 arguments, got {count}")
        for dominant in (LogicTerminal.F, LogicTerminal.U):
            if dominant in arguments:
                return dominant
        return LogicTerminal.T

    @staticmethod
    def _eval_or(arguments: List[LogicTerminal]) -> LogicTerminal:
        """
        Disjunction: any T gives T; otherwise any U gives U; otherwise F.

        :param arguments: must contain at least two values.
        :raises ValueError: on fewer than two arguments.
        """
        count = len(arguments)
        if count < 2:
            raise ValueError(f"Expected at least 2 arguments, got {count}")
        for dominant in (LogicTerminal.T, LogicTerminal.U):
            if dominant in arguments:
                return dominant
        return LogicTerminal.F

    def eval_tree(self, node: ParseNode) -> LogicTerminal:
        """
        Recursively evaluate ``node``.

        Variable leaves are looked up in ``self.environment`` (KeyError if missing),
        constant leaves evaluate to themselves, and NOT/AND/OR nodes combine their
        evaluated children.

        :raises ValueError: for any other node type.
        """
        kind = node.nodeType

        if isinstance(kind, str):
            return self.environment[kind]

        if isinstance(kind, LogicTerminal):
            return kind

        if isinstance(kind, LogicOperator):
            # Children are always evaluated first, even for an operator without a handler.
            operands = [self.eval_tree(child) for child in node.children]
            handlers = {
                LogicOperator.NOT: self._eval_not,
                LogicOperator.AND: self._eval_and,
                LogicOperator.OR: self._eval_or,
            }
            handler = handlers.get(kind)
            if handler is not None:
                return handler(operands)

        raise ValueError(f"Cannot evaluate node type: {kind}")

# </editor-fold>

# Wumpus World and FOL machine

class SquareSafetyLevel(Enum):
    """Safety verdict for a grid square; looked up by member name from a puzzle's RESOLUTION line."""
    SAFE = +1     # known to be safe
    RISKY = 0     # cannot be decided either way
    UNSAFE = -1   # known to be dangerous

class PuzzleParser:
    """
    Loads a Wumpus-world puzzle description file.

    Recognised lines:
        GRID: <a>x<b>                     -> self.size = (a, b)
        ARROWS: <n>  (or ARROW: <n>)      -> self.arrows
        QUERY: (<row>,<col>)              -> self.query
        RESOLUTION: <SAFE|RISKY|UNSAFE>   -> self.resolution
        (<row>,<col>) <breeze> <stench>?  -> one visited square per line; the final
            character of the line is discarded (presumably a terminator -- confirm
            against the cave files), and a percept counts as True when its field
            ends in 'T'.
    """

    def __init__(self, puzzle_path: Union[str, None] = None):
        """
        :param puzzle_path: file to load; defaults to the module-level ``filepath``.

        Load failures are reported on stdout rather than raised. Fields parsed before
        the failure keep their parsed values; the rest keep the defaults below.
        """
        self.size: Tuple[int, int] = (-1, -1)
        self.arrows: int = -1
        # Visited (row, col) -> percepts {"Breeze": bool, "Stench": bool}
        self.path: Dict[Tuple[int, int], Dict[str, bool]] = {}
        self.query: Tuple[int, int] = (-1, -1)
        self.resolution: SquareSafetyLevel = SquareSafetyLevel.RISKY
        source = filepath if puzzle_path is None else puzzle_path
        try:
            self.parse_puzzle(source)
        except FileNotFoundError:
            print(f"File {source} not found")
        except Exception as err:
            # Best-effort load: keep going, but report what actually went wrong.
            print(f"Bad File: {source} ({err!r})")

    def parse_puzzle(self, filepath: str):
        """
        Parse ``filepath`` and store the results on ``self``.

        :raises FileNotFoundError: if the file does not exist.
        :raises ValueError: on a malformed numeric field or path line.
        :raises KeyError: on an unknown RESOLUTION name.
        """
        steps: List[str] = []
        with open(filepath) as file:
            for line in file:
                if 'GRID: ' in line:
                    grid = line.replace('GRID: ', '').strip()
                    self.size = tuple(map(int, grid.split('x')))

                # The old code tested for 'ARROW: ' but stripped 'ARROWS: ', so neither
                # spelling parsed. Accept both and take the value after the colon.
                if 'ARROWS: ' in line or 'ARROW: ' in line:
                    self.arrows = int(line.rsplit(':', 1)[1])

                if 'QUERY: ' in line:
                    query = line.replace('QUERY: (', '').strip()[:-1]  # drop both parentheses
                    self.query = tuple(map(int, query.split(',')))

                if 'RESOLUTION: ' in line:
                    self.resolution = SquareSafetyLevel[line.replace('RESOLUTION: ', '').strip()]

                if line.startswith('('):
                    steps.append(line.strip())

        for step in steps:
            # split() without an argument tolerates repeated spaces between fields.
            position, breeze, stench = step[:-1].split()
            row, col = map(int, position[1:-1].split(','))  # strip the parentheses
            self.path[(row, col)] = {"Breeze": breeze[-1] == 'T', "Stench": stench[-1] == 'T'}










"""

a = ExpressionParser()
# x AND (NOT y) OR T
expression = LogicExpression([
    "x",                          # variable
    LogicOperator.IFF,            # AND
    LogicOperator.LPAREN,         # (
        LogicOperator.NOT,        # NOT
        "y",                      # variable
    LogicOperator.RPAREN,         # )
    LogicOperator.IFF,             # OR
    LogicTerminal.T               # constant True
])

print(expression)
root = a.parse_expression(expression)
a.pretty_print(root)
print(ExpressionEvaluator(expression).evaluation)

"""




<__main__.PuzzleParser object at 0x00000168BF2473B0>


'\n\na = ExpressionParser()\n# x AND (NOT y) OR T\nexpression = LogicExpression([\n    "x",                          # variable\n    LogicOperator.IFF,            # AND\n    LogicOperator.LPAREN,         # (\n        LogicOperator.NOT,        # NOT\n        "y",                      # variable\n    LogicOperator.RPAREN,         # )\n    LogicOperator.IFF,             # OR\n    LogicTerminal.T               # constant True\n])\n\nprint(expression)\nroot = a.parse_expression(expression)\na.pretty_print(root)\nprint(ExpressionEvaluator(expression).evaluation)\n\n'