In [2]:
from dataclasses import dataclass, field
from enum import Enum, auto
from typing import List, Union, Dict, Tuple, Optional, Any, TypeAlias

# Parameters
# Path of the puzzle file that PuzzleParser opens (a relative path, so it is
# resolved against the current working directory).
filepath: str = "Caves/easy/path_e1.txt"

# <editor-fold desc="Semantics and structure for FOL">

# Terms
@dataclass(frozen=True)
class Variable:
    """A named logic variable.

    Frozen, so instances are hashable and can be used as keys in an
    evaluation environment.
    """

    name: str

    def __str__(self) -> str:
        return self.name

@dataclass(frozen=True)
class Constant:
    """A named logic constant.

    Frozen, so instances are hashable and can be used as keys in an
    evaluation environment.
    """

    name: str

    def __str__(self) -> str:
        return self.name

class LogicOperator(Enum):
    """Operator and punctuation tokens of the logic expression language.

    Member order fixes the auto() values (XOR == 1 ... IFF == 11).
    """

    # Binary connectives
    XOR = auto()
    OR = auto()
    AND = auto()
    # Unary connective
    NOT = auto()
    # Punctuation
    LPAREN = auto()
    RPAREN = auto()
    COMMA = auto()
    # N-ary list forms: ALL(...) / ANY(...)
    ALL = auto()
    ANY = auto()
    # Binary connectives that the parser rewrites into NOT/OR/AND
    IMPLIES = auto()
    IFF = auto()

    def __str__(self):
        # Render as the bare member name, e.g. "AND" rather than "LogicOperator.AND".
        return self.name

class LogicTerminal(Enum):
    """Three-valued truth constant: Unknown, False or True."""

    U = auto()  # Unknown
    F = auto()  # False
    T = auto()  # True

    def __str__(self):
        # Human-readable label; "NULL" is only a defensive fallback.
        readable = {"U": "Unknown", "F": "False", "T": "True"}
        return readable.get(self.name, "NULL")

    def __bool__(self):
        # Only True is truthy; Unknown and False both test as falsy.
        return self is LogicTerminal.T

@dataclass(frozen=True)
class Predicate:
    """A predicate application such as ``Breeze((1, 2))``.

    ``args`` must be a tuple (immutable) so that the frozen dataclass stays
    hashable and can be used as a dictionary key.
    """

    name: str
    args: Tuple[Any, ...] = field(default_factory=tuple)

    def __str__(self):
        rendered_args = ", ".join(str(arg) for arg in self.args)
        return f"{self.name}({rendered_args})"

# Any single token that may appear in a LogicExpression or label a ParseNode.
Term: TypeAlias = Union[Variable, Constant, LogicOperator, Predicate, LogicTerminal, None]

# Expressions

class LogicExpression:
    """A flat, ordered sequence of tokens (Terms) forming one logic formula.

    The object doubles as a read cursor over its tokens: ``peek()`` and
    ``next()`` read/advance ``iter_count``, which ExpressionParser relies on
    while parsing. Plain ``for`` iteration uses an independent iterator and
    does not move that cursor.
    """

    def __init__(self, expression: Optional[List[Term]] = None):
        self.expression: List[Term] = []
        self.iter_count = 0  # cursor used by peek() / __next__
        if expression:
            self.set_expression(expression)

    def __contains__(self, item):
        """Return True if *item* occurs in this expression.

        A single token is tested by membership. A LogicExpression is tested as a
        contiguous run of tokens (an empty expression is trivially contained).
        Any other object is never contained.
        """
        term_types = (Variable, Constant, LogicOperator, Predicate, LogicTerminal)
        if isinstance(item, term_types):
            return item in self.expression
        if isinstance(item, LogicExpression):
            needle = item.expression
            width = len(needle)
            return any(
                self.expression[start:start + width] == needle
                for start in range(len(self.expression) - width + 1)
            )
        return False

    def __str__(self):
        """Space-separated rendering of the tokens, e.g. ``"x AND NOT y"``."""
        def render(value):
            # Operators and terminals print by member name ("AND", "T"),
            # not via LogicTerminal.__str__ ("True").
            if isinstance(value, (LogicOperator, LogicTerminal)):
                return value.name
            return str(value)  # also covers plain str tokens and None -> "None"

        return " ".join(render(value) for value in self.expression)

    def __next__(self):
        """Return the token under the cursor and advance it; StopIteration at the end."""
        if self.iter_count < len(self.expression):
            value = self.expression[self.iter_count]
            self.iter_count += 1
            return value
        raise StopIteration

    def __iter__(self):
        """Iterate over the tokens with an independent iterator.

        This deliberately does not touch ``iter_count``: looping over an
        expression must not rewind an in-progress parse, and nested loops over
        the same expression must not interfere with each other.
        """
        return iter(self.expression)

    def __len__(self):
        return len(self.expression)

    def __add__(self, other):
        """Return a NEW expression: these tokens followed by *other*.

        *other* may be another LogicExpression (its tokens are appended) or a
        single Term. Neither operand is modified.
        """
        tail = other.expression if isinstance(other, LogicExpression) else [other]
        return LogicExpression(self.expression + tail)

    def __iadd__(self, other):
        """In-place append so ``expr += token`` keeps extending the same object."""
        if isinstance(other, LogicExpression):
            self.expression.extend(other.expression)
        else:
            self.expression.append(other)
        return self

    def peek(self) -> Term:
        """Return the token under the cursor without consuming it, or None at the end."""
        return self.expression[self.iter_count] if self.iter_count < len(self.expression) else None

    def set_expression(self, expression: List[Term]):
        """Replace the tokens with a copy of *expression* and rewind the cursor.

        Copying keeps later mutations of the caller's list (or of this
        expression) from leaking across.
        """
        self.expression = list(expression)
        self.iter_count = 0

@dataclass
class ParseNode:
    """A node of the parse tree built by ExpressionParser.

    Leaves hold an atom (Predicate, Variable, Constant or LogicTerminal) and
    have no children; inner nodes hold a LogicOperator and its operands in order.
    """

    nodeType: Term = None  # operator or atom stored at this node
    children: List['ParseNode'] = field(default_factory=list)  # operand sub-trees, left to right

class ExpressionParser:
    """Recursive-descent parser from a LogicExpression token stream to a ParseNode tree.

    Binding strength, loosest first: IFF, IMPLIES, XOR, OR, AND, NOT, then atoms
    (Predicate / Variable / Constant / LogicTerminal leaves, parenthesised
    sub-expressions, and ALL(...) / ANY(...) lists). IFF, XOR, OR and AND group
    to the left; IMPLIES groups to the right.

    IMPLIES and IFF are rewritten while parsing:

    * ``p IMPLIES q`` becomes ``NOT p OR q``.
    * ``p IFF q`` becomes ``(p -> q) AND (q -> p)``.

    The resulting tree therefore contains only NOT, AND, OR, XOR and leaves.
    """

    def __init__(self):
        # Token stream being parsed; its iter_count acts as the read cursor.
        self.expression: Optional[LogicExpression] = None

    def parse_expression(self, expression: LogicExpression) -> ParseNode:
        """Parse *expression* from its first token and return the root ParseNode.

        Raises:
            ValueError: on a syntax error, or if tokens remain after a complete formula.
        """
        self.expression = expression
        self.expression.iter_count = 0  # always start from the beginning of the stream
        node = self._parse_iff()
        leftover = self.expression.peek()
        if leftover is not None:
            raise ValueError(f"Expression not empty after parse: {leftover}")
        return node

    # -- token helpers -----------------------------------------------------

    def _accept(self, operator: LogicOperator) -> bool:
        """Consume the next token if it equals *operator*; return whether it did."""
        if self.expression.peek() == operator:
            next(self.expression)
            return True
        return False

    def _expect(self, operator: LogicOperator, message: str) -> None:
        """Consume *operator* or raise ValueError(*message*)."""
        if not self._accept(operator):
            raise ValueError(message)

    def _parse_left_assoc(self, operator: LogicOperator, parse_operand) -> ParseNode:
        """Parse ``operand (operator operand)*`` into a left-nested chain of binary nodes."""
        node = parse_operand()
        while self._accept(operator):
            node = ParseNode(nodeType=operator, children=[node, parse_operand()])
        return node

    # -- grammar, loosest binding first ------------------------------------

    def _parse_iff(self) -> ParseNode:
        node = self._parse_implies()
        while self._accept(LogicOperator.IFF):
            node = self._reduce_iff(node, self._parse_implies())
        return node

    def _parse_implies(self) -> ParseNode:
        left = self._parse_xor()
        if self._accept(LogicOperator.IMPLIES):
            # Recursing on the right-hand side makes IMPLIES right-associative.
            return self.reduce_implies(left, self._parse_implies())
        return left

    def _parse_xor(self) -> ParseNode:
        return self._parse_left_assoc(LogicOperator.XOR, self._parse_or)

    def _parse_or(self) -> ParseNode:
        return self._parse_left_assoc(LogicOperator.OR, self._parse_and)

    def _parse_and(self) -> ParseNode:
        return self._parse_left_assoc(LogicOperator.AND, self._parse_not)

    def _parse_not(self) -> ParseNode:
        if self._accept(LogicOperator.NOT):
            return ParseNode(nodeType=LogicOperator.NOT, children=[self._parse_not()])
        return self._parse_atom()

    def _parse_atom(self) -> ParseNode:
        tok = self.expression.peek()
        if tok is None:
            raise ValueError("Unexpected end of expression")

        if tok == LogicOperator.ALL:
            return self._parse_multivalued(operator=LogicOperator.AND)
        if tok == LogicOperator.ANY:
            return self._parse_multivalued(operator=LogicOperator.OR)

        if isinstance(tok, (Predicate, Variable, Constant, LogicTerminal)):
            next(self.expression)  # consume the leaf token
            return ParseNode(nodeType=tok)

        if self._accept(LogicOperator.LPAREN):
            node = self._parse_iff()
            self._expect(LogicOperator.RPAREN, "Expected ')'")
            return node

        raise ValueError(f"Unexpected value: {tok}")

    def _parse_multivalued(self, operator: LogicOperator) -> ParseNode:
        """Parse ``ALL(a, b, ...)`` / ``ANY(a, b, ...)`` into one n-ary *operator* node."""
        next(self.expression)  # consume ALL / ANY
        self._expect(LogicOperator.LPAREN, "Expected '(' after quantifier")
        children: List[ParseNode] = [self._parse_iff()]
        while self._accept(LogicOperator.COMMA):
            children.append(self._parse_iff())
        self._expect(LogicOperator.RPAREN, "Expected ')' to close n-ary list")
        if len(children) == 1:
            # AND / OR over a single operand is just that operand; don't emit a
            # one-child AND/OR node.
            return children[0]
        return ParseNode(nodeType=operator, children=children)

    # -- rewrites ------------------------------------------------------------

    @staticmethod
    def reduce_implies(p: ParseNode, q: ParseNode) -> ParseNode:
        """Rewrite ``p IMPLIES q`` as ``NOT p OR q``."""
        not_p = ParseNode(nodeType=LogicOperator.NOT, children=[p])
        return ParseNode(nodeType=LogicOperator.OR, children=[not_p, q])

    @staticmethod
    def _reduce_iff(p: ParseNode, q: ParseNode) -> ParseNode:
        """Rewrite ``p IFF q`` as ``(p -> q) AND (q -> p)`` (p and q subtrees are shared)."""
        return ParseNode(nodeType=LogicOperator.AND,
                         children=[ExpressionParser.reduce_implies(p, q),
                                   ExpressionParser.reduce_implies(q, p)])

    @staticmethod
    def pretty_print(node: ParseNode, indent: str = "", is_last: bool = True):
        """Print *node* and its descendants as an ASCII tree to stdout."""
        branch = "└── " if is_last else "├── "
        next_indent = indent + ("    " if is_last else "│   ")
        if isinstance(node.nodeType, (LogicOperator, LogicTerminal)):
            label = node.nodeType.name
        else:
            label = str(node.nodeType)
        print(indent + branch + label)
        for i, child in enumerate(node.children):
            ExpressionParser.pretty_print(child, next_indent, i == len(node.children) - 1)

class ExpressionEvaluator:
    """Evaluate a logic expression in three-valued logic (T / F / U).

    Predicate, Variable and Constant leaves are looked up in ``environment``.
    A leaf with no entry evaluates to ``LogicTerminal.U``. The result is
    computed once, in the constructor, and stored in ``evaluation``.
    """

    def __init__(self, expression: LogicExpression | ParseNode,
                 environment: Optional[Dict[Any, LogicTerminal]] = None):
        self.environment: Dict[Any, LogicTerminal] = environment or {}
        # Accept either raw tokens (parsed here) or an already-built parse tree.
        if isinstance(expression, ParseNode):
            self.parsed_expression = expression
        else:
            self.parsed_expression = ExpressionParser().parse_expression(expression)
        self.evaluation: LogicTerminal = self.eval_tree(self.parsed_expression)

    def __eq__(self, other):
        """Two evaluators compare equal when their ``evaluation`` results are equal."""
        if not isinstance(other, ExpressionEvaluator):
            # Let Python fall back to its default handling instead of raising AttributeError.
            return NotImplemented
        return self.evaluation == other.evaluation

    # Equality is by result value, so instances are intentionally unhashable
    # (this is also what defining __eq__ alone implies).
    __hash__ = None

    @staticmethod
    def expression_equality(expression1: LogicExpression, expression2: LogicExpression) -> LogicTerminal:
        """Evaluate ``expression1 IFF expression2`` with an empty environment.

        With no environment, every Predicate/Variable/Constant leaf is Unknown.
        The result is therefore the three-valued truth of the biconditional, not
        a proof that the two formulas are logically equivalent.
        """
        parser = ExpressionParser()
        p = parser.parse_expression(expression1)
        q = parser.parse_expression(expression2)
        iff = ParseNode(nodeType=LogicOperator.AND, children=[
            ExpressionParser.reduce_implies(p, q),
            ExpressionParser.reduce_implies(q, p),
        ])
        return ExpressionEvaluator(iff).evaluation

    @staticmethod
    def _eval_not(args: List[LogicTerminal]) -> LogicTerminal:
        """Swap T and F; Unknown stays Unknown."""
        if len(args) != 1: raise ValueError(f"NOT expects 1 arg, got {len(args)}")
        a = args[0]
        if a is LogicTerminal.T: return LogicTerminal.F
        if a is LogicTerminal.F: return LogicTerminal.T
        return LogicTerminal.U

    @staticmethod
    def _eval_and(args: List[LogicTerminal]) -> LogicTerminal:
        """Return F if any operand is F, else U if any is U, else T."""
        if len(args) < 2: raise ValueError(f"AND expects >=2 args, got {len(args)}")
        if LogicTerminal.F in args: return LogicTerminal.F
        if LogicTerminal.U in args: return LogicTerminal.U
        return LogicTerminal.T

    @staticmethod
    def _eval_or(args: List[LogicTerminal]) -> LogicTerminal:
        """Return T if any operand is T, else U if any is U, else F."""
        if len(args) < 2: raise ValueError(f"OR expects >=2 args, got {len(args)}")
        if LogicTerminal.T in args: return LogicTerminal.T
        if LogicTerminal.U in args: return LogicTerminal.U
        return LogicTerminal.F

    @staticmethod
    def _eval_xor(args: List[LogicTerminal]) -> LogicTerminal:
        """U if any operand is U; otherwise T iff EXACTLY one operand is T.

        For more than two operands this is "exactly one", not parity.
        """
        if len(args) < 2: raise ValueError(f"XOR expects >=2 args, got {len(args)}")
        if any(a is LogicTerminal.U for a in args):
            return LogicTerminal.U
        true_count = sum(1 for a in args if a is LogicTerminal.T)
        return LogicTerminal.T if true_count == 1 else LogicTerminal.F

    def eval_tree(self, node: ParseNode) -> LogicTerminal:
        """Recursively evaluate *node* (children are evaluated before the operator).

        Raises:
            ValueError: for an operator other than NOT/AND/OR/XOR (e.g. an
                unreduced IMPLIES/IFF), or a node type that cannot be evaluated.
        """
        node_type = node.nodeType

        if isinstance(node_type, (Predicate, Variable, Constant)):
            return self.environment.get(node_type, LogicTerminal.U)

        if isinstance(node_type, LogicTerminal):
            return node_type

        if isinstance(node_type, LogicOperator):
            child_vals = [self.eval_tree(c) for c in node.children]
            if node_type is LogicOperator.NOT: return self._eval_not(child_vals)
            if node_type is LogicOperator.AND: return self._eval_and(child_vals)
            if node_type is LogicOperator.OR:  return self._eval_or(child_vals)
            if node_type is LogicOperator.XOR: return self._eval_xor(child_vals)
            raise ValueError(f"Unsupported operator in eval: {node_type}")

        raise ValueError(f"Cannot evaluate node type: {node_type}")


# </editor-fold>

# <editor-fold desc="Wumpus World">

class Safety(Enum):
    """Safety classification of a cell.

    Also the type of a puzzle file's RESOLUTION value, which is looked up by
    member name.
    """

    SAFE = auto()
    RISKY = auto()
    UNSAFE = auto()
    UNKNOWN = auto()  # used as the initial / fallback classification

class PuzzleParser:
    """Load a Wumpus-world puzzle description from a text file.

    Lines are matched by content:

    * ``GRID: RxC``: sets ``size`` to a tuple of ints.
    * ``ARROWS: n`` (or ``ARROW: n``): sets ``arrows``.
    * ``QUERY: (r,c)``: sets ``query``.
    * ``RESOLUTION: NAME``: sets ``resolution`` (NAME must be a Safety member name).
    * Lines starting with ``(``: path steps, stored in ``path``.

    A missing or malformed file does not raise out of the constructor. A
    message is printed and the instance is falsy (see ``__bool__``).
    """

    def __init__(self, path: Optional[str] = None):
        # Defaults to the module-level ``filepath`` parameter so existing callers keep working.
        self.filepath: str = filepath if path is None else path
        self.size: Tuple[int, int] = (-1, -1)
        self.arrows: int = -1
        # Relates a (row, col) position to {"Breeze": bool, "Stench": bool}.
        self.path: Dict[Tuple[int, int], Dict[str, bool]] = {}
        self.query: Tuple = (-1, -1)
        self.resolution: Safety = Safety.UNKNOWN
        self.file_read = False

        try:
            self.parse_puzzle()
            self.file_read = True
        except FileNotFoundError:
            print(f"File {self.filepath} not found")
        except Exception as err:  # best-effort load: report and stay falsy
            print(f"Bad File: {self.filepath} ({err})")

    def __bool__(self):
        """True only if the puzzle file was read and parsed successfully."""
        return self.file_read

    def parse_puzzle(self):
        """Parse ``self.filepath`` into this instance's fields.

        Raises:
            FileNotFoundError: if the file does not exist.
            ValueError / KeyError: on malformed lines or an unknown RESOLUTION name.
        """
        steps: List[str] = []
        with open(self.filepath, encoding="utf-8") as file:
            for line in file:
                if 'GRID: ' in line:
                    grid = line.replace('GRID: ', '').strip()
                    self.size = tuple(map(int, grid.split('x')))

                if 'ARROWS: ' in line or 'ARROW: ' in line:
                    # Accept both spellings; strip the longer prefix first.
                    self.arrows = int(line.replace('ARROWS: ', '').replace('ARROW: ', '').strip())

                if 'QUERY: ' in line:
                    query = line.replace('QUERY: (', '').strip()[:-1]  # drops both parentheses
                    self.query = tuple(map(int, query.split(',')))

                if 'RESOLUTION: ' in line:
                    self.resolution = Safety[line.replace('RESOLUTION: ', '').strip()]

                if line.startswith('('):
                    steps.append(line.strip())

        for step in steps:
            # NOTE(review): assumes a step looks like "(r,c) <..>T|F <..>T|F" followed by
            # exactly one trailing character (removed by step[:-1]) — confirm against
            # the puzzle files.
            position, breeze, stench = tuple(step[:-1].split(' '))
            row, col = tuple(map(int, position[1:-1].split(',')))  # strip "(" and ")"
            self.path[(row, col)] = {
                "Breeze": breeze[-1] == 'T',  # last character encodes the boolean
                "Stench": stench[-1] == 'T',
            }

    def get_size(self):
        """Return the grid size parsed from the GRID line, or (-1, -1) if absent."""
        return self.size

    def get_path(self):
        """Return the mapping of visited positions to their Breeze/Stench observations."""
        return self.path



class KnowledgeBase:
    """Facts (and, eventually, rules) about the puzzle loaded by PuzzleParser."""

    def __init__(self):
        # Knowledge is keyed by a Predicate (or a plain name) and maps to a truth value,
        # e.g. Predicate("Breeze", ...) -> LogicTerminal.T, or "Wumpus" -> LogicTerminal.T
        self.rules: Optional[Dict[Union[Predicate, str], LogicTerminal]] = None  # TODO: Write rules to use to solve puzzle
        # Start as an empty dict so get_puzzle_facts() can insert into it directly.
        self.facts: Dict[Union[Predicate, str], LogicTerminal] = {}

        self.puzzle = PuzzleParser()
        if not self.puzzle: return  # puzzle failed to load; nothing further to set up

    def get_neighbors(self, square: tuple) -> Tuple[Tuple[int, int], ...]:
        """Return the orthogonal neighbours of *square* that lie inside the grid.

        A neighbour is kept when ``0 <= row < size[0]`` and ``0 <= col < size[1]``,
        with ``size`` taken from the puzzle. Candidates are generated in the
        order (+1, 0), (-1, 0), (0, +1), (0, -1).
        NOTE(review): the bounds assume 0-based coordinates — confirm the puzzle
        files do not use 1-based positions.
        """
        rows, cols = self.puzzle.get_size()
        row, col = square[0], square[1]
        return tuple(
            (row + d_row, col + d_col)
            for d_row, d_col in ((1, 0), (-1, 0), (0, 1), (0, -1))
            if 0 <= row + d_row < rows and 0 <= col + d_col < cols
        )

    def get_puzzle_facts(self):
        """Record the Breeze and Stench observation of every visited cell in ``self.facts``.

        Each observation is stored as T or F under ``Predicate(sense, neighbors)``.
        """
        if self.facts is None:  # tolerate callers that reset facts to None
            self.facts = {}
        for cell, senses in self.puzzle.get_path().items():
            neighbors = self.get_neighbors(cell)
            # NOTE(review): the predicate's argument is the cell's neighbour tuple,
            # not the cell itself — confirm this keying is intended.
            for sense in ("Breeze", "Stench"):
                self.facts[Predicate(sense, neighbors)] = (
                    LogicTerminal.T if senses[sense] else LogicTerminal.F
                )


class InferenceEngine:
    """Reasoning front-end over a KnowledgeBase.

    Everything here is still a placeholder: unify() never succeeds, derive()
    derives nothing, and classify() always answers Safety.UNKNOWN.
    """
    # TODO: Write this class

    def __init__(self):
        # Building the KnowledgeBase also builds a PuzzleParser, which reads the puzzle file.
        self.knowledge = KnowledgeBase()

    def get_knowledge_base(self):
        """Return the KnowledgeBase owned by this engine."""
        return self.knowledge

    def unify(self, a: Any, b: Any, theta: Optional[Dict[Any, Any]] = None) -> Optional[Dict[Any, Any]]:
        """
        Try to unify two terms/literals/predicates.
        Return a substitution dict if successful, else None.

        Not implemented yet: always reports failure (None).
        """
        # TODO: implement Robinson unification (occurs check, recursive structure)
        substitution = None
        return substitution

    def derive(self, max_iterations: int = 1000) -> int:
        """
        Saturate the KB by forward-chaining: repeatedly apply rules to add new facts.
        Returns number of new facts derived.

        Not implemented yet: derives nothing and returns 0.
        """
        # TODO: forward chaining / resolution
        derived_facts = 0
        return derived_facts

    def classify(self, cell: Tuple[int, int]) -> Safety:
        """
        Use the KB to classify a cell as SAFE/UNSAFE/RISKY/UNKNOWN.

        Not implemented yet: every cell is Safety.UNKNOWN.
        """
        # TODO: encode domain-specific predicates, e.g., SAFE(r,c), PIT(r,c), WUMPUS(r,c)
        verdict = Safety.UNKNOWN
        return verdict

class OutputWriter:
    """Produces the final solution report for a puzzle (not implemented yet)."""
    # TODO: Write this class

    def __init__(self):
        engine = InferenceEngine()
        self.engine = engine
        self.knowledge = engine.get_knowledge_base()

    def write_result(self, kb: KnowledgeBase) -> None:
        """
        Emit a final report describing how the puzzle was solved:
        - metrics
        - facts used
        - (optionally) rules and/or proof traces

        Currently a no-op.
        """
        return None

# </editor-fold>

"""
a = ExpressionParser()
# x IFF (NOT y) IFF T
expression = LogicExpression([
    Variable("x"),                # variable
    LogicOperator.IFF,            # IFF
    LogicOperator.LPAREN,         # (
        LogicOperator.NOT,        # NOT
        Variable("y"),            # variable
    LogicOperator.RPAREN,         # )
    LogicOperator.IFF,            # IFF
    LogicTerminal.T               # constant True
])

print(expression)
root = a.parse_expression(expression)
a.pretty_print(root)
print(ExpressionEvaluator(expression).evaluation)

"""



x IFF LPAREN NOT y RPAREN IFF T
└── AND
    ├── OR
    │   ├── NOT
    │   │   └── AND
    │   │       ├── OR
    │   │       │   ├── NOT
    │   │       │   │   └── x
    │   │       │   └── NOT
    │   │       │       └── y
    │   │       └── OR
    │   │           ├── NOT
    │   │           │   └── NOT
    │   │           │       └── y
    │   │           └── x
    │   └── T
    └── OR
        ├── NOT
        │   └── T
        └── AND
            ├── OR
            │   ├── NOT
            │   │   └── x
            │   └── NOT
            │       └── y
            └── OR
                ├── NOT
                │   └── NOT
                │       └── y
                └── x
LogicTerminal.U
