From 2fe479e0019eddd386ca9da9de0368b4c03bbeba Mon Sep 17 00:00:00 2001 From: ialarmedalien Date: Tue, 21 Oct 2025 09:30:39 -0700 Subject: [PATCH] Allow detect_cycles to receive a list of nodes instead of just a single node --- linkml_runtime/utils/schemaview.py | 31 +++++++++++++++++------- tests/test_utils/test_schemaview.py | 37 +++++++++++++++++++---------- 2 files changed, 47 insertions(+), 21 deletions(-) diff --git a/linkml_runtime/utils/schemaview.py b/linkml_runtime/utils/schemaview.py index f3ee0036..f0a15f47 100644 --- a/linkml_runtime/utils/schemaview.py +++ b/linkml_runtime/utils/schemaview.py @@ -8,6 +8,7 @@ import uuid import warnings from collections import defaultdict, deque +from collections.abc import Iterable from copy import copy, deepcopy from dataclasses import dataclass from enum import Enum @@ -44,7 +45,7 @@ from linkml_runtime.utils.pattern import PatternResolver if TYPE_CHECKING: - from collections.abc import Callable, Iterable, Mapping + from collections.abc import Callable, Mapping from types import NotImplementedType from linkml_runtime.utils.metamodelcore import URI, URIorCURIE @@ -96,8 +97,13 @@ class OrderedBy(Enum): BLACK = 2 -def detect_cycles(f: Callable[[Any], Iterable[Any] | None], x: Any) -> None: - """Detect cycles in a graph, using function `f` to walk the graph, starting at node `x`. +def detect_cycles( + f: Callable[[Any], Iterable[Any] | None], + node_list: Iterable[Any], +) -> None: + """Detect cycles in a graph, using function `f` to walk the graph. + + Input is supplied as a list of nodes that are used to populate the `todo` stack. Uses the classic white/grey/black colour coding algorithm to track which nodes have been explored. In this case, "node" refers to any element in a schema and "neighbours" are elements that can be reached from that @@ -107,21 +113,28 @@ def detect_cycles(f: Callable[[Any], Iterable[Any] | None], x: Any) -> None: GREY: node is being processed; processing includes exploring all neighbours reachable via f(node) BLACK: node and all of its neighbours (and their neighbours, etc.) have been processed - A directed cycle reachable from node `x` raises a ValueError. + A directed cycle reachable from a node or its neighbours raises a ValueError. :param f: function that returns an iterable of neighbouring nodes (parents or children) :type f: Callable[[Any], Iterable[Any] | None] - :param x: graph node - :type x: Any - :raises ValueError: if a cycle is discovered through repeated calls to f(x) + :param node_list: list or other iterable of values to process + :type node_list: Iterable[Any] + :raises ValueError: if a cycle is discovered through repeated calls to f(node) """ + # ensure we have some nodes to start the analysis + if not node_list or not isinstance(node_list, Iterable) or isinstance(node_list, str): + err_msg = "detect_cycles requires a list of values to process" + raise ValueError(err_msg) + # keep track of the processing state of nodes in the graph processing_state: dict[Any, int] = {} # Stack entries are (node, processed_flag). # processed_flag == True means all neighbours (nodes generated by running `f(node)`) # have been added to the todo stack and the node can be marked BLACK. - todo: list[tuple[Any, bool]] = [(x, False)] + + # initialise the todo stack with entries set to False + todo: list[tuple[Any, bool]] = [(node, False) for node in node_list] while todo: node, processed_flag = todo.pop() @@ -173,7 +186,7 @@ def _closure( :rtype: list[str | ElementName | ClassDefinitionName | EnumDefinitionName | SlotDefinitionName | TypeDefinitionName] """ if kwargs and kwargs.get("detect_cycles"): - detect_cycles(f, x) + detect_cycles(f, [x]) rv = [x] if reflexive else [] visited = [] diff --git a/tests/test_utils/test_schemaview.py b/tests/test_utils/test_schemaview.py index a44c4cae..ccd3476d 100644 --- a/tests/test_utils/test_schemaview.py +++ b/tests/test_utils/test_schemaview.py @@ -2981,12 +2981,25 @@ def test_class_name_mappings() -> None: assert {snm_def.name: snm for snm, snm_def in view.slot_name_mappings().items()} == slot_names +""" +Tests of the detect_cycles function, which can identify cyclic relationships between classes, types, and other schema elements. +""" + + +@pytest.mark.parametrize("dodgy_input", [None, [], set(), {}, 12345, 123.45, "some string", ()]) +def test_detect_cycles_input_error(dodgy_input: Any) -> None: + """Ensure that `detect_cycles` throws an error if input is not supplied in the appropriate form.""" + with pytest.raises(ValueError, match="detect_cycles requires a list of values to process"): + detect_cycles(lambda x: x, dodgy_input) + + @pytest.fixture(scope="module") def sv_cycles_schema() -> SchemaView: """A schema containing cycles!""" return SchemaView(INPUT_DIR_PATH / "cycles.yaml") +# metadata for elements in the `sv_cycles_schema` CYCLES = { TYPES: { # types in cycles, either directly or via ancestors @@ -3028,8 +3041,8 @@ def sv_cycles_schema() -> SchemaView: # key: class name, value: class ancestors 1: { "BaseClass": {"BaseClass"}, - "MixinA": {"MixinA"}, - "MixinB": {"MixinB"}, + "MixinA": {"MixinA"}, # no ID slot + "MixinB": {"MixinB"}, # no ID slot "NonCycleClassA": {"NonCycleClassA", "BaseClass"}, "NonCycleClassB": {"MixinA", "NonCycleClassB", "NonCycleClassA", "BaseClass"}, "NonCycleClassC": {"MixinB", "NonCycleClassC", "NonCycleClassA", "BaseClass"}, @@ -3048,10 +3061,10 @@ def test_detect_type_cycles_error(sv_cycles_schema: SchemaView, target: str, cyc """Test detection of cycles in the types segment of the cycles schema.""" if fn == "detect_cycles": with pytest.raises(ValueError, match=f"Cycle detected at node '{cycle_start_node}'"): - detect_cycles(lambda x: sv_cycles_schema.type_parents(x), target) + detect_cycles(sv_cycles_schema.type_parents, [target]) elif fn == "graph_closure": with pytest.raises(ValueError, match=f"Cycle detected at node '{cycle_start_node}'"): - graph_closure(lambda x: sv_cycles_schema.type_parents(x), target, detect_cycles=True) + graph_closure(sv_cycles_schema.type_parents, target, detect_cycles=True) else: with pytest.raises(ValueError, match=f"Cycle detected at node '{cycle_start_node}'"): sv_cycles_schema.type_ancestors(type_name=target, detect_cycles=True) @@ -3062,9 +3075,9 @@ def test_detect_type_cycles_error(sv_cycles_schema: SchemaView, target: str, cyc def test_detect_type_cycles_no_cycles(sv_cycles_schema: SchemaView, target: str, expected: set[str], fn: str) -> None: """Ensure that types without cycles in their ancestry do not throw an error.""" if fn == "detect_cycles": - detect_cycles(lambda x: sv_cycles_schema.type_parents(x), target) + detect_cycles(sv_cycles_schema.type_parents, [target]) elif fn == "graph_closure": - got = graph_closure(lambda x: sv_cycles_schema.type_parents(x), target, detect_cycles=True) + got = graph_closure(sv_cycles_schema.type_parents, target, detect_cycles=True) assert set(got) == expected else: got = sv_cycles_schema.type_ancestors(target, detect_cycles=True) @@ -3077,11 +3090,11 @@ def test_detect_class_cycles_error(sv_cycles_schema: SchemaView, target: str, cy """Test detection of class cycles in the cycles schema.""" if fn == "detect_cycles": with pytest.raises(ValueError, match=f"Cycle detected at node '{cycle_start_node}'"): - detect_cycles(lambda x: sv_cycles_schema.class_parents(x), target) + detect_cycles(sv_cycles_schema.class_parents, [target]) elif fn == "graph_closure": with pytest.raises(ValueError, match=f"Cycle detected at node '{cycle_start_node}'"): - graph_closure(lambda x: sv_cycles_schema.class_parents(x), target, detect_cycles=True) + graph_closure(sv_cycles_schema.class_parents, target, detect_cycles=True) else: with pytest.raises(ValueError, match=f"Cycle detected at node '{cycle_start_node}'"): sv_cycles_schema.class_ancestors(target, detect_cycles=True) @@ -3092,9 +3105,9 @@ def test_detect_class_cycles_error(sv_cycles_schema: SchemaView, target: str, cy def test_detect_class_cycles_no_cycles(sv_cycles_schema: SchemaView, target: str, expected: set[str], fn: str) -> None: """Ensure that classes without cycles in their ancestry do not throw an error.""" if fn == "detect_cycles": - detect_cycles(lambda x: sv_cycles_schema.class_parents(x), target) + detect_cycles(sv_cycles_schema.class_parents, [target]) elif fn == "graph_closure": - got = graph_closure(lambda x: sv_cycles_schema.class_parents(x), target, detect_cycles=True) + got = graph_closure(sv_cycles_schema.class_parents, target, detect_cycles=True) assert set(got) == expected else: got = sv_cycles_schema.class_ancestors(target, detect_cycles=True) @@ -3116,10 +3129,10 @@ def check_recursive_id_slots(class_name: str) -> list[str]: # classes with a cycle in the class identifier slot range are cunningly named if "IdentifierCycle" in target: with pytest.raises(ValueError, match="Cycle detected at node "): - detect_cycles(lambda x: check_recursive_id_slots(x), target) + detect_cycles(check_recursive_id_slots, [target]) else: - detect_cycles(lambda x: check_recursive_id_slots(x), target) + detect_cycles(check_recursive_id_slots, [target]) @pytest.mark.parametrize(