diff --git a/.gitignore b/.gitignore index 646f537a..347f7941 100755 --- a/.gitignore +++ b/.gitignore @@ -2,7 +2,7 @@ __pycache__/ *.py[cod] *$py.class -"pyreason/.cache_status.yaml +pyreason/.cache_status.yaml # Cache status file when initialized # Keep the initial false version in repo, ignore when it becomes true .DS_STORE diff --git a/tests/unit/disable_jit/interpretations/test_ground_rule_helpers.py b/tests/unit/disable_jit/interpretations/test_ground_rule_helpers.py index c5430145..52f58c0e 100644 --- a/tests/unit/disable_jit/interpretations/test_ground_rule_helpers.py +++ b/tests/unit/disable_jit/interpretations/test_ground_rule_helpers.py @@ -919,6 +919,8 @@ def get_clauses(self): return self._clauses def get_thresholds(self): return self._thresholds def get_annotation_function(self): return self._ann_fn def get_edges(self): return self._rule_edges + def get_head_function(self): return ["", ""] + def get_head_function_vars(self): return [[], []] def _shim_typed_list(monkeypatch): class _ListShim: diff --git a/tests/unit/disable_jit/interpretations/test_interpretation_common.py b/tests/unit/disable_jit/interpretations/test_interpretation_common.py index 74a5763f..069af6b1 100644 --- a/tests/unit/disable_jit/interpretations/test_interpretation_common.py +++ b/tests/unit/disable_jit/interpretations/test_interpretation_common.py @@ -71,9 +71,11 @@ def add_edge(*args): _ground_rule_fn = _py(interpretation._ground_rule) if "num_ga" in inspect.signature(_ground_rule_fn).parameters: def ground_rule(*args, **kwargs): + kwargs.setdefault('head_functions', ()) return _ground_rule_fn(*args, num_ga=[0], **kwargs) else: def ground_rule(*args, **kwargs): + kwargs.setdefault('head_functions', ()) return _ground_rule_fn(*args, **kwargs) ns.ground_rule = ground_rule ns.update_rule_trace = _py(interpretation._update_rule_trace) diff --git a/tests/unit/disable_jit/interpretations/test_interpretation_init.py b/tests/unit/disable_jit/interpretations/test_interpretation_init.py index dcfd36ce..6d9dc0e9 100644 --- a/tests/unit/disable_jit/interpretations/test_interpretation_init.py +++ b/tests/unit/disable_jit/interpretations/test_interpretation_init.py @@ -76,8 +76,9 @@ def test_interpretation_init_neighbors(shim_types): False, False, False, - 0, False, + "", + True, ) assert set(interp.neighbors["n1"]) == {"n2"} assert set(interp.neighbors["n2"]) == {"n1"} @@ -211,8 +212,9 @@ def build_interp(): False, False, False, - 0, False, + "", + True, ) interp.time = 5 interp.prev_reasoning_data = [2, 0] diff --git a/tests/unit/disable_jit/interpretations/test_reason_core.py b/tests/unit/disable_jit/interpretations/test_reason_core.py index 872c6171..1c657e2c 100644 --- a/tests/unit/disable_jit/interpretations/test_reason_core.py +++ b/tests/unit/disable_jit/interpretations/test_reason_core.py @@ -174,6 +174,7 @@ def is_static(self): "allow_ground_rules": True, "max_facts_time": 0, "annotation_functions": {}, + "head_functions": (), "convergence_mode": "perfect_convergence", "convergence_delta": 0, "verbose": False, @@ -220,6 +221,7 @@ def run(**overrides): params["allow_ground_rules"], params["max_facts_time"], params["annotation_functions"], + params["head_functions"], params["convergence_mode"], params["convergence_delta"], params["verbose"],