-
Notifications
You must be signed in to change notification settings - Fork 0
/
engine.py
86 lines (71 loc) · 2.91 KB
/
engine.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import copy
import typing as t
import interpreter
import parse
class CycleError(Exception):
    """Raised when inference comes back to a hypothesis it is already evaluating.

    This signals that the rules reachable from the hypothesis form a cycle.
    """
class ContradictionError(Exception):
    """Raised when a hypothesis and its negation are both inferred to be true."""
class InferenceEngine:
    """Backward-chaining inference engine.

    Inference starts from a hypothesis and works back through the rules
    that conclude it, checking whether the recorded facts confirm it.
    """

    def __init__(self) -> None:
        # Fact name -> True for every fact registered through add_fact().
        self.facts: t.Dict[str, bool] = {}
        # Rule target -> rule; rules sharing a target are merged by add_rule().
        self.rules: t.Dict[str, interpreter.Rule] = {}

    def add_fact(self, fact: interpreter.Fact) -> None:
        """Register ``fact.name`` as a known true fact."""
        self.facts[fact.name] = True

    def add_rule(self, rule: interpreter.Rule) -> None:
        """Register ``rule`` under its target.

        When a rule already exists for the same target, the new rule is
        combined with the existing one using ``|=`` (intended as a logical
        OR of the two rules).
        """
        if rule.target not in self.rules:
            self.rules[rule.target] = rule
        else:
            self.rules[rule.target] |= rule

    def _ast(self, s: str) -> t.Optional[parse.AST]:
        """Return the AST of the rule targeting ``s``, or None if there is none."""
        rule = self.rules.get(s)
        return rule.ast if rule else None

    def infer_hypothesis(self, h: str) -> bool:
        """Infer whether hypothesis ``h`` holds, starting with an empty path.

        Raises:
            CycleError: If inference loops back on a hypothesis and no
                alternative branch resolves it.
            ContradictionError: If a hypothesis and its negation both hold.
        """
        return self._infer(h, set())

    def infer(self, ast: t.Optional[parse.AST], seen: t.Set[str]) -> bool:
        """Evaluate an AST node, dispatching on its operator name.

        Args:
            ast: Node to evaluate; a falsy value (e.g. no rule found)
                evaluates to False.
            seen: Hypothesis names already visited on the current path,
                used by ``_infer`` for cycle detection. The set may be
                mutated.

        Returns:
            The truth value of the node.

        Raises:
            CycleError: If evaluation loops back to a hypothesis in ``seen``
                and no branch recovers from it.
            ContradictionError: If a hypothesis and its negation both hold.
        """
        if not ast:
            return False
        if ast.name == "+":
            # AND: the left operand works on a copy so the names it visits
            # are not treated as already seen while evaluating the right one.
            return self.infer(ast.operands[0], copy.copy(seen)) and self.infer(
                ast.operands[1], seen
            )
        if ast.name == "|":
            # OR: a cycle in the left operand only makes that operand False;
            # the right operand can still establish the result.
            try:
                loperand = self.infer(ast.operands[0], copy.copy(seen))
            except CycleError:
                loperand = False
            return loperand or self.infer(ast.operands[1], seen)
        if ast.name == "!":
            return not self.infer(ast.operands[0], seen)
        # Any other node name is treated as a hypothesis name.
        return self._infer(ast.name, seen)

    def _infer(self, hypothesis: str, seen: t.Set[str]) -> bool:
        """Infer ``hypothesis`` while tracking visited names in ``seen``.

        Both the hypothesis and its negation (the same name with a leading
        ``!`` added or removed) are evaluated so contradictions are caught.

        Args:
            hypothesis: Name to infer, optionally prefixed with ``!``.
            seen: Names already visited on the current path; ``hypothesis``
                is added to it.

        Returns:
            True if the hypothesis is a known fact or its rule evaluates
            to True, otherwise False.

        Raises:
            CycleError: If ``hypothesis`` is already in ``seen``, or if
                evaluating the negation hits a cycle while the hypothesis
                itself could not be established.
            ContradictionError: If the hypothesis and its negation both hold.
        """
        if hypothesis.startswith("!"):
            nhypothesis = hypothesis[1:]
        else:
            nhypothesis = "!" + hypothesis

        exists = self.facts.get(hypothesis, False)
        nexists = self.facts.get(nhypothesis, False)

        if hypothesis in seen:
            raise CycleError("Cycle detected.")
        seen.add(hypothesis)

        # The positive side runs on a copy so names it visits do not count
        # as seen while the negation is evaluated below.
        exists = exists or self.infer(self._ast(hypothesis), copy.copy(seen))
        try:
            nexists = nexists or self.infer(self._ast(nhypothesis), seen)
        except CycleError:
            # A cycle on the negation side is tolerated when the hypothesis
            # itself has already been established.
            if exists:
                return exists
            # Re-raise the original exception so its traceback is preserved
            # instead of wrapping it in a new, implicitly chained CycleError.
            raise

        if exists and nexists:
            raise ContradictionError(
                "Cannot infer hypothesis, {} and {} are both True.".format(
                    hypothesis, nhypothesis
                )
            )
        return exists