diff --git a/dvc/command/dag.py b/dvc/command/dag.py index e408be644a..bc95369146 100644 --- a/dvc/command/dag.py +++ b/dvc/command/dag.py @@ -2,7 +2,6 @@ import logging from dvc.command.base import CmdBase, append_doc_link -from dvc.exceptions import DvcException logger = logging.getLogger(__name__) @@ -30,80 +29,91 @@ def _show_dot(G): return dot_file.getvalue() -def _build(G, target=None, full=False, outs=False): +def _collect_targets(repo, target, outs): + if not target: + return [] + + pairs = repo.collect_granular(target) + if not outs: + return [stage.addressing for stage, _ in pairs] + + targets = [] + for stage, info in pairs: + if not info: + targets.extend([str(out) for out in stage.outs]) + continue + + for out in repo.outs_trie.itervalues(prefix=info.parts): # noqa: B301 + targets.extend(str(out)) + + return targets + + +def _transform(repo, outs): import networkx as nx - from dvc.repo.graph import get_pipeline, get_pipelines + if outs: + G = repo.outs_graph + + def _relabel(out): + return str(out) - if target: - H = get_pipeline(get_pipelines(G), target) - if not full: - descendants = nx.descendants(G, target) - descendants.add(target) - H.remove_nodes_from(set(G.nodes()) - descendants) else: - H = G + G = repo.graph - if outs: - G = nx.DiGraph() - for stage in H.nodes: - G.add_nodes_from(stage.outs) + def _relabel(stage): + return stage.addressing - for from_stage, to_stage in nx.edge_dfs(H): - G.add_edges_from( - [ - (from_out, to_out) - for from_out in from_stage.outs - for to_out in to_stage.outs - ] - ) - H = G + return nx.relabel_nodes(G, _relabel, copy=True) - def _relabel(node): - from dvc.stage import Stage - return node.addressing if isinstance(node, Stage) else str(node) +def _filter(G, targets, full): + import networkx as nx - return nx.relabel_nodes(H, _relabel, copy=False) + if not targets: + return G + + H = G.copy() + if not full: + descendants = set() + for target in targets: + descendants.update(nx.descendants(G, target)) + 
def _build(repo, target=None, full=False, outs=False):
    """Build the graph to display for ``repo``.

    Resolves ``target`` to node labels, relabels the stage graph (or the
    outputs graph when ``outs`` is set) and narrows it down to those labels.
    """
    labels = _collect_targets(repo, target, outs)
    return _filter(_transform(repo, outs), labels, full)


class CmdDAG(CmdBase):
    """Command that prints the DAG as DOT text or as paged ASCII art."""

    def run(self):
        """Render the graph selected by the parsed arguments; return 0."""
        args = self.args
        dag = _build(
            self.repo, target=args.target, full=args.full, outs=args.outs,
        )

        if not args.dot:
            # Imported lazily, only when the pager is actually needed.
            from dvc.utils.pager import pager

            pager(_show_ascii(dag))
            return 0

        logger.info(_show_dot(dag))
        return 0
get_pipeline, get_pipelines +from .graph import build_graph, build_outs_graph, get_pipeline, get_pipelines +from .trie import build_outs_trie logger = logging.getLogger(__name__) @@ -289,7 +290,7 @@ def check_modified_graph(self, new_stages): # # [1] https://github.com/iterative/dvc/issues/2671 if not getattr(self, "_skip_graph_checks", False): - self._collect_graph(self.stages + new_stages) + build_graph(self.stages + new_stages) def _collect_inside(self, path, graph): import networkx as nx @@ -448,114 +449,17 @@ def used_cache( return cache - def _collect_graph(self, stages): - """Generate a graph by using the given stages on the given directory - - The nodes of the graph are the stage's path relative to the root. - - Edges are created when the output of one stage is used as a - dependency in other stage. - - The direction of the edges goes from the stage to its dependency: - - For example, running the following: - - $ dvc run -o A "echo A > A" - $ dvc run -d A -o B "echo B > B" - $ dvc run -d B -o C "echo C > C" - - Will create the following graph: - - ancestors <-- - | - C.dvc -> B.dvc -> A.dvc - | | - | --> descendants - | - ------- pipeline ------> - | - v - (weakly connected components) - - Args: - stages (list): used to build a graph, if None given, collect stages - in the repository. - - Raises: - OutputDuplicationError: two outputs with the same path - StagePathAsOutputError: stage inside an output directory - OverlappingOutputPathsError: output inside output directory - CyclicGraphError: resulting graph has cycles - """ - import networkx as nx - from pygtrie import Trie - - from dvc.exceptions import ( - OutputDuplicationError, - OverlappingOutputPathsError, - StagePathAsOutputError, - ) - - G = nx.DiGraph() - stages = stages or self.stages - outs = Trie() # Use trie to efficiently find overlapping outs and deps - - for stage in filter(bool, stages): # bug? 
not using it later - for out in stage.outs: - out_key = out.path_info.parts - - # Check for dup outs - if out_key in outs: - dup_stages = [stage, outs[out_key].stage] - raise OutputDuplicationError(str(out), dup_stages) - - # Check for overlapping outs - if outs.has_subtrie(out_key): - parent = out - overlapping = first(outs.values(prefix=out_key)) - else: - parent = outs.shortest_prefix(out_key).value - overlapping = out - if parent and overlapping: - msg = ( - "Paths for outs:\n'{}'('{}')\n'{}'('{}')\n" - "overlap. To avoid unpredictable behaviour, " - "rerun command with non overlapping outs paths." - ).format( - str(parent), - parent.stage.addressing, - str(overlapping), - overlapping.stage.addressing, - ) - raise OverlappingOutputPathsError(parent, overlapping, msg) - - outs[out_key] = out - - for stage in stages: - out = outs.shortest_prefix(PathInfo(stage.path).parts).value - if out: - raise StagePathAsOutputError(stage, str(out)) - - # Building graph - G.add_nodes_from(stages) - for stage in stages: - for dep in stage.deps: - if dep.path_info is None: - continue - - dep_key = dep.path_info.parts - overlapping = [n.value for n in outs.prefixes(dep_key)] - if outs.has_subtrie(dep_key): - overlapping.extend(outs.values(prefix=dep_key)) - - G.add_edges_from((stage, out.stage) for out in overlapping) - check_acyclic(G) - - return G + @cached_property + def outs_trie(self): + return build_outs_trie(self.stages) @cached_property def graph(self): - return self._collect_graph(self.stages) + return build_graph(self.stages, self.outs_trie) + + @cached_property + def outs_graph(self): + return build_outs_graph(self.graph, self.outs_trie) @cached_property def pipelines(self): @@ -648,6 +552,8 @@ def close(self): self.scm.close() def _reset(self): + self.__dict__.pop("outs_trie", None) + self.__dict__.pop("outs_graph", None) self.__dict__.pop("graph", None) self.__dict__.pop("stages", None) self.__dict__.pop("pipelines", None) diff --git a/dvc/repo/graph.py 
def _overlapping_outs(outs_trie, key):
    """Return the outs in ``outs_trie`` that overlap with the path ``key``.

    An out overlaps when its key is a prefix of ``key`` or when ``key`` is a
    prefix of the out's key.
    """
    overlapping = [node.value for node in outs_trie.prefixes(key)]
    if outs_trie.has_subtrie(key):
        overlapping.extend(outs_trie.values(prefix=key))
    return overlapping


def build_graph(stages, outs_trie=None):
    """Generate a graph by using the given stages on the given directory

    The nodes of the graph are the stages themselves.

    Edges are created when the output of one stage is used as a
    dependency in other stage.

    The direction of the edges goes from the stage to its dependency:

    For example, running the following:

        $ dvc run -o A "echo A > A"
        $ dvc run -d A -o B "echo B > B"
        $ dvc run -d B -o C "echo C > C"

    Will create the following graph:

           ancestors <--
                       |
            C.dvc -> B.dvc -> A.dvc
            |          |
            |          --> descendants
            |
            ------- pipeline ------>
                       |
                       v
          (weakly connected components)

    Args:
        stages (list): used to build a graph from
        outs_trie: trie of the outputs of ``stages`` keyed by path parts;
            built with ``build_outs_trie(stages)`` when not given.

    Raises:
        OutputDuplicationError: two outputs with the same path
            (raised while building the trie, i.e. when none is passed in)
        OverlappingOutputPathsError: output inside output directory
            (raised while building the trie, i.e. when none is passed in)
        StagePathAsOutputError: stage inside an output directory
        CyclicGraphError: resulting graph has cycles
            (presumably raised by ``check_acyclic``)
    """
    import networkx as nx

    from dvc.exceptions import StagePathAsOutputError

    from ..path_info import PathInfo
    from .trie import build_outs_trie

    G = nx.DiGraph()

    # Use trie to efficiently find overlapping outs and deps.
    # Compare against None explicitly: a trie without any outs is falsy and
    # would otherwise be thrown away and rebuilt.
    if outs_trie is None:
        outs_trie = build_outs_trie(stages)

    for stage in stages:
        out = outs_trie.shortest_prefix(PathInfo(stage.path).parts).value
        if out:
            raise StagePathAsOutputError(stage, str(out))

    # Building graph
    G.add_nodes_from(stages)
    for stage in stages:
        for dep in stage.deps:
            if dep.path_info is None:
                continue

            overlapping = _overlapping_outs(outs_trie, dep.path_info.parts)
            G.add_edges_from((stage, out.stage) for out in overlapping)
    check_acyclic(G)

    return G


# NOTE: using stage graph instead of just list of stages to make sure that it
# has already passed all the sanity checks like cycles/overlapping outputs and
# so on.
def build_outs_graph(graph, outs_trie):
    """Build a graph of outputs from the stage graph.

    Every value stored in ``outs_trie`` becomes a node. For each stage, an
    edge is added from each of its outs to every out that overlaps with one
    of the stage's dependencies.

    Args:
        graph: stage graph, as returned by ``build_graph``.
        outs_trie: trie of outputs keyed by path parts.

    Returns:
        networkx.DiGraph: graph whose nodes are output objects.
    """
    import networkx as nx

    G = nx.DiGraph()

    G.add_nodes_from(outs_trie.values())
    for stage in graph.nodes():
        for dep in stage.deps:
            # Same guard as in build_graph: a dependency without path_info
            # cannot be looked up in the trie and must be skipped instead of
            # failing on ``None.parts``.
            if dep.path_info is None:
                continue

            overlapping = _overlapping_outs(outs_trie, dep.path_info.parts)
            for from_out in stage.outs:
                G.add_edges_from((from_out, out) for out in overlapping)
    return G


def build_outs_trie(stages):
    """Collect the outputs of ``stages`` into a trie keyed by path parts.

    Args:
        stages: iterable of stages; falsy entries are ignored.

    Returns:
        pygtrie.Trie: maps ``out.path_info.parts`` to the output object.

    Raises:
        OutputDuplicationError: two outputs share the same path.
        OverlappingOutputPathsError: one output path lies inside another.
    """
    from funcy import first
    from pygtrie import Trie

    from dvc.exceptions import (
        OutputDuplicationError,
        OverlappingOutputPathsError,
    )

    outs = Trie()

    for stage in filter(bool, stages):  # bug? not using it later
        for out in stage.outs:
            out_key = out.path_info.parts

            # Check for dup outs
            if out_key in outs:
                dup_stages = [stage, outs[out_key].stage]
                raise OutputDuplicationError(str(out), dup_stages)

            # Check for overlapping outs
            if outs.has_subtrie(out_key):
                # An already registered out lives below the new one.
                parent = out
                overlapping = first(outs.values(prefix=out_key))
            else:
                # The new out may live below an already registered one.
                parent = outs.shortest_prefix(out_key).value
                overlapping = out
            if parent and overlapping:
                msg = (
                    "Paths for outs:\n'{}'('{}')\n'{}'('{}')\n"
                    "overlap. To avoid unpredictable behaviour, "
                    "rerun command with non overlapping outs paths."
                ).format(
                    str(parent),
                    parent.stage.addressing,
                    str(overlapping),
                    overlapping.stage.addressing,
                )
                raise OverlappingOutputPathsError(parent, overlapping, msg)

            outs[out_key] = out

    return outs
+ ).format( + str(parent), + parent.stage.addressing, + str(overlapping), + overlapping.stage.addressing, + ) + raise OverlappingOutputPathsError(parent, overlapping, msg) + + outs[out_key] = out + + return outs diff --git a/tests/unit/command/test_dag.py b/tests/unit/command/test_dag.py index 5c30825fc0..42f1fcf8fc 100644 --- a/tests/unit/command/test_dag.py +++ b/tests/unit/command/test_dag.py @@ -23,7 +23,7 @@ def test_dag(tmp_dir, dvc, mocker, fmt): @pytest.fixture -def graph(tmp_dir, dvc): +def repo(tmp_dir, dvc): tmp_dir.dvc_gen("a", "a") tmp_dir.dvc_gen("b", "b") @@ -42,46 +42,68 @@ def graph(tmp_dir, dvc): ) dvc.run(no_exec=True, deps=["a", "h"], outs=["j"], cmd="cmd4", name="4") - return dvc.graph + return dvc -def test_build(graph): - assert nx.is_isomorphic(_build(graph), graph) +def test_build(repo): + assert nx.is_isomorphic(_build(repo), repo.graph) -def test_build_target(graph): - (stage,) = filter( - lambda s: hasattr(s, "name") and s.name == "3", graph.nodes() - ) - G = _build(graph, target=stage) +def test_build_target(repo): + G = _build(repo, target="3") assert set(G.nodes()) == {"3", "b.dvc", "a.dvc"} assert set(G.edges()) == {("3", "a.dvc"), ("3", "b.dvc")} -def test_build_target_with_outs(graph): - (stage,) = filter( - lambda s: hasattr(s, "name") and s.name == "3", graph.nodes() - ) - G = _build(graph, target=stage, outs=True) +def test_build_target_with_outs(repo): + G = _build(repo, target="3", outs=True) assert set(G.nodes()) == {"a", "b", "h", "i"} assert set(G.edges()) == { - ("h", "a"), - ("h", "b"), ("i", "a"), ("i", "b"), + ("h", "a"), + ("h", "b"), } -def test_build_full(graph): - (stage,) = filter( - lambda s: hasattr(s, "name") and s.name == "3", graph.nodes() - ) - G = _build(graph, target=stage, full=True) - assert nx.is_isomorphic(G, graph) +def test_build_granular_target_with_outs(repo): + G = _build(repo, target="h", outs=True) + assert set(G.nodes()) == {"a", "b", "h"} + assert set(G.edges()) == { + ("h", "a"), + ("h", "b"), 
+ } + + +def test_build_full(repo): + G = _build(repo, target="3", full=True) + assert nx.is_isomorphic(G, repo.graph) + + +# NOTE: granular or not, full outs DAG should be the same +@pytest.mark.parametrize("granular", [True, False]) +def test_build_full_outs(repo, granular): + target = "h" if granular else "3" + G = _build(repo, target=target, outs=True, full=True) + assert set(G.nodes()) == {"j", "i", "d", "b", "g", "f", "e", "a", "h"} + assert set(G.edges()) == { + ("d", "a"), + ("e", "a"), + ("f", "b"), + ("g", "b"), + ("h", "a"), + ("h", "b"), + ("i", "a"), + ("i", "b"), + ("j", "a"), + ("j", "h"), + } -def test_show_ascii(graph): - assert [line.rstrip() for line in _show_ascii(graph).splitlines()] == [ +def test_show_ascii(repo): + assert [ + line.rstrip() for line in _show_ascii(repo.graph).splitlines() + ] == [ " +----------------+ +----------------+", # noqa: E501 " | stage: 'a.dvc' | | stage: 'b.dvc' |", # noqa: E501 " *+----------------+**** +----------------+", # noqa: E501 @@ -100,8 +122,8 @@ def test_show_ascii(graph): ] -def test_show_dot(graph): - assert _show_dot(graph) == ( +def test_show_dot(repo): + assert _show_dot(repo.graph) == ( "strict digraph {\n" "stage;\n" "stage;\n" diff --git a/tests/unit/repo/test_repo.py b/tests/unit/repo/test_repo.py index 1516d3c81a..44ae5d0d4b 100644 --- a/tests/unit/repo/test_repo.py +++ b/tests/unit/repo/test_repo.py @@ -105,29 +105,29 @@ def test_collect_optimization_on_stage_name(tmp_dir, dvc, mocker, run_copy): def test_skip_graph_checks(tmp_dir, dvc, mocker, run_copy): # See https://github.com/iterative/dvc/issues/2671 for more info - mock_collect_graph = mocker.patch("dvc.repo.Repo._collect_graph") + mock_build_graph = mocker.patch("dvc.repo.build_graph") # sanity check tmp_dir.gen("foo", "foo text") dvc.add("foo") run_copy("foo", "bar", single_stage=True) - assert mock_collect_graph.called + assert mock_build_graph.called # check that our hack can be enabled - mock_collect_graph.reset_mock() + 
mock_build_graph.reset_mock() dvc._skip_graph_checks = True tmp_dir.gen("baz", "baz text") dvc.add("baz") run_copy("baz", "qux", single_stage=True) - assert not mock_collect_graph.called + assert not mock_build_graph.called # check that our hack can be disabled - mock_collect_graph.reset_mock() + mock_build_graph.reset_mock() dvc._skip_graph_checks = False tmp_dir.gen("quux", "quux text") dvc.add("quux") run_copy("quux", "quuz", single_stage=True) - assert mock_collect_graph.called + assert mock_build_graph.called def test_branch_config(tmp_dir, scm):