Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 72 additions & 62 deletions dvc/command/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import logging

from dvc.command.base import CmdBase, append_doc_link
from dvc.exceptions import DvcException

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -30,80 +29,91 @@ def _show_dot(G):
return dot_file.getvalue()


def _build(G, target=None, full=False, outs=False):
def _collect_targets(repo, target, outs):
if not target:
return []

pairs = repo.collect_granular(target)
if not outs:
return [stage.addressing for stage, _ in pairs]

targets = []
for stage, info in pairs:
if not info:
targets.extend([str(out) for out in stage.outs])
continue

for out in repo.outs_trie.itervalues(prefix=info.parts): # noqa: B301
targets.extend(str(out))

return targets


def _transform(repo, outs):
import networkx as nx

from dvc.repo.graph import get_pipeline, get_pipelines
if outs:
G = repo.outs_graph

def _relabel(out):
return str(out)

if target:
H = get_pipeline(get_pipelines(G), target)
if not full:
descendants = nx.descendants(G, target)
descendants.add(target)
H.remove_nodes_from(set(G.nodes()) - descendants)
else:
H = G
G = repo.graph

if outs:
G = nx.DiGraph()
for stage in H.nodes:
G.add_nodes_from(stage.outs)
def _relabel(stage):
return stage.addressing
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This whole _transform method is only here because str(stage) is not using stage.addressing. Need to consider changing Stage.__str__ to it.


for from_stage, to_stage in nx.edge_dfs(H):
G.add_edges_from(
[
(from_out, to_out)
for from_out in from_stage.outs
for to_out in to_stage.outs
]
)
H = G
return nx.relabel_nodes(G, _relabel, copy=True)

def _relabel(node):
from dvc.stage import Stage

return node.addressing if isinstance(node, Stage) else str(node)
def _filter(G, targets, full):
import networkx as nx

return nx.relabel_nodes(H, _relabel, copy=False)
if not targets:
return G

H = G.copy()
if not full:
descendants = set()
for target in targets:
descendants.update(nx.descendants(G, target))
descendants.add(target)
H.remove_nodes_from(set(G.nodes()) - descendants)

undirected = H.to_undirected()
connected = set()
for target in targets:
connected.update(nx.node_connected_component(undirected, target))

H.remove_nodes_from(set(H.nodes()) - connected)

return H


def _build(repo, target=None, full=False, outs=False):
targets = _collect_targets(repo, target, outs)
G = _transform(repo, outs)
return _filter(G, targets, full)


class CmdDAG(CmdBase):
def run(self):
try:
target = None
if self.args.target:
stages = self.repo.collect(self.args.target)
if len(stages) > 1:
logger.error(
f"'{self.args.target}' contains more than one stage "
"{stages}, please specify one stage"
)
return 1
target = stages[0]

G = _build(
self.repo.graph,
target=target,
full=self.args.full,
outs=self.args.outs,
)

if self.args.dot:
logger.info(_show_dot(G))
else:
from dvc.utils.pager import pager

pager(_show_ascii(G))

return 0
except DvcException:
msg = "failed to show "
if self.args.target:
msg += f"a pipeline for '{target}'"
else:
msg += "pipelines"
logger.exception(msg)
return 1
Comment on lines -99 to -106
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unrelated πŸ™ , but we are gradually going away from KO errors like this in other places, so did it here too.

G = _build(
self.repo,
target=self.args.target,
full=self.args.full,
outs=self.args.outs,
)

if self.args.dot:
logger.info(_show_dot(G))
else:
from dvc.utils.pager import pager

pager(_show_ascii(G))

return 0


def add_parser(subparsers, parent_parser):
Expand Down
122 changes: 14 additions & 108 deletions dvc/repo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from contextlib import contextmanager
from functools import wraps

from funcy import cached_property, cat, first
from funcy import cached_property, cat
from git import InvalidGitRepositoryError

from dvc.config import Config
Expand All @@ -23,7 +23,8 @@

from ..stage.exceptions import StageFileDoesNotExistError, StageNotFound
from ..utils import parse_target
from .graph import check_acyclic, get_pipeline, get_pipelines
from .graph import build_graph, build_outs_graph, get_pipeline, get_pipelines
from .trie import build_outs_trie

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -289,7 +290,7 @@ def check_modified_graph(self, new_stages):
#
# [1] https://github.com/iterative/dvc/issues/2671
if not getattr(self, "_skip_graph_checks", False):
self._collect_graph(self.stages + new_stages)
build_graph(self.stages + new_stages)

def _collect_inside(self, path, graph):
import networkx as nx
Expand Down Expand Up @@ -448,114 +449,17 @@ def used_cache(

return cache

def _collect_graph(self, stages):
"""Generate a graph by using the given stages on the given directory

The nodes of the graph are the stage's path relative to the root.

Edges are created when the output of one stage is used as a
dependency in other stage.

The direction of the edges goes from the stage to its dependency:

For example, running the following:

$ dvc run -o A "echo A > A"
$ dvc run -d A -o B "echo B > B"
$ dvc run -d B -o C "echo C > C"

Will create the following graph:

ancestors <--
|
C.dvc -> B.dvc -> A.dvc
| |
| --> descendants
|
------- pipeline ------>
|
v
(weakly connected components)

Args:
stages (list): used to build a graph, if None given, collect stages
in the repository.

Raises:
OutputDuplicationError: two outputs with the same path
StagePathAsOutputError: stage inside an output directory
OverlappingOutputPathsError: output inside output directory
CyclicGraphError: resulting graph has cycles
"""
import networkx as nx
from pygtrie import Trie

from dvc.exceptions import (
OutputDuplicationError,
OverlappingOutputPathsError,
StagePathAsOutputError,
)

G = nx.DiGraph()
stages = stages or self.stages
outs = Trie() # Use trie to efficiently find overlapping outs and deps

for stage in filter(bool, stages): # bug? not using it later
for out in stage.outs:
out_key = out.path_info.parts

# Check for dup outs
if out_key in outs:
dup_stages = [stage, outs[out_key].stage]
raise OutputDuplicationError(str(out), dup_stages)

# Check for overlapping outs
if outs.has_subtrie(out_key):
parent = out
overlapping = first(outs.values(prefix=out_key))
else:
parent = outs.shortest_prefix(out_key).value
overlapping = out
if parent and overlapping:
msg = (
"Paths for outs:\n'{}'('{}')\n'{}'('{}')\n"
"overlap. To avoid unpredictable behaviour, "
"rerun command with non overlapping outs paths."
).format(
str(parent),
parent.stage.addressing,
str(overlapping),
overlapping.stage.addressing,
)
raise OverlappingOutputPathsError(parent, overlapping, msg)

outs[out_key] = out

for stage in stages:
out = outs.shortest_prefix(PathInfo(stage.path).parts).value
if out:
raise StagePathAsOutputError(stage, str(out))

# Building graph
G.add_nodes_from(stages)
for stage in stages:
for dep in stage.deps:
if dep.path_info is None:
continue

dep_key = dep.path_info.parts
overlapping = [n.value for n in outs.prefixes(dep_key)]
if outs.has_subtrie(dep_key):
overlapping.extend(outs.values(prefix=dep_key))

G.add_edges_from((stage, out.stage) for out in overlapping)
check_acyclic(G)

return G
@cached_property
def outs_trie(self):
return build_outs_trie(self.stages)

@cached_property
def graph(self):
return self._collect_graph(self.stages)
return build_graph(self.stages, self.outs_trie)

@cached_property
def outs_graph(self):
return build_outs_graph(self.graph, self.outs_trie)

@cached_property
def pipelines(self):
Expand Down Expand Up @@ -648,6 +552,8 @@ def close(self):
self.scm.close()

def _reset(self):
self.__dict__.pop("outs_trie", None)
self.__dict__.pop("outs_graph", None)
Comment on lines +555 to +556
Copy link
Contributor Author

@efiop efiop Nov 20, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Both of these will be very handy in #4847 and potentially other places. Plus we build some of them as a side-effect when validating the DAG, so we can cache them along the way.

self.__dict__.pop("graph", None)
self.__dict__.pop("stages", None)
self.__dict__.pop("pipelines", None)
Loading