-
Notifications
You must be signed in to change notification settings - Fork 1.2k
dag: support output as target #4908
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. Weβll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,6 @@ | |
import logging | ||
|
||
from dvc.command.base import CmdBase, append_doc_link | ||
from dvc.exceptions import DvcException | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
@@ -30,80 +29,91 @@ def _show_dot(G): | |
return dot_file.getvalue() | ||
|
||
|
||
def _build(G, target=None, full=False, outs=False): | ||
def _collect_targets(repo, target, outs): | ||
if not target: | ||
return [] | ||
|
||
pairs = repo.collect_granular(target) | ||
if not outs: | ||
return [stage.addressing for stage, _ in pairs] | ||
|
||
targets = [] | ||
for stage, info in pairs: | ||
if not info: | ||
targets.extend([str(out) for out in stage.outs]) | ||
continue | ||
|
||
for out in repo.outs_trie.itervalues(prefix=info.parts): # noqa: B301 | ||
targets.extend(str(out)) | ||
|
||
return targets | ||
|
||
|
||
def _transform(repo, outs): | ||
import networkx as nx | ||
|
||
from dvc.repo.graph import get_pipeline, get_pipelines | ||
if outs: | ||
G = repo.outs_graph | ||
|
||
def _relabel(out): | ||
return str(out) | ||
|
||
if target: | ||
H = get_pipeline(get_pipelines(G), target) | ||
if not full: | ||
descendants = nx.descendants(G, target) | ||
descendants.add(target) | ||
H.remove_nodes_from(set(G.nodes()) - descendants) | ||
else: | ||
H = G | ||
G = repo.graph | ||
|
||
if outs: | ||
G = nx.DiGraph() | ||
for stage in H.nodes: | ||
G.add_nodes_from(stage.outs) | ||
def _relabel(stage): | ||
return stage.addressing | ||
|
||
for from_stage, to_stage in nx.edge_dfs(H): | ||
G.add_edges_from( | ||
[ | ||
(from_out, to_out) | ||
for from_out in from_stage.outs | ||
for to_out in to_stage.outs | ||
] | ||
) | ||
H = G | ||
return nx.relabel_nodes(G, _relabel, copy=True) | ||
|
||
def _relabel(node): | ||
from dvc.stage import Stage | ||
|
||
return node.addressing if isinstance(node, Stage) else str(node) | ||
def _filter(G, targets, full): | ||
import networkx as nx | ||
|
||
return nx.relabel_nodes(H, _relabel, copy=False) | ||
if not targets: | ||
return G | ||
|
||
H = G.copy() | ||
if not full: | ||
descendants = set() | ||
for target in targets: | ||
descendants.update(nx.descendants(G, target)) | ||
descendants.add(target) | ||
H.remove_nodes_from(set(G.nodes()) - descendants) | ||
|
||
undirected = H.to_undirected() | ||
connected = set() | ||
for target in targets: | ||
connected.update(nx.node_connected_component(undirected, target)) | ||
|
||
H.remove_nodes_from(set(H.nodes()) - connected) | ||
|
||
return H | ||
|
||
|
||
def _build(repo, target=None, full=False, outs=False): | ||
targets = _collect_targets(repo, target, outs) | ||
G = _transform(repo, outs) | ||
return _filter(G, targets, full) | ||
|
||
|
||
class CmdDAG(CmdBase): | ||
def run(self): | ||
try: | ||
target = None | ||
if self.args.target: | ||
stages = self.repo.collect(self.args.target) | ||
if len(stages) > 1: | ||
logger.error( | ||
f"'{self.args.target}' contains more than one stage " | ||
"{stages}, please specify one stage" | ||
) | ||
return 1 | ||
target = stages[0] | ||
|
||
G = _build( | ||
self.repo.graph, | ||
target=target, | ||
full=self.args.full, | ||
outs=self.args.outs, | ||
) | ||
|
||
if self.args.dot: | ||
logger.info(_show_dot(G)) | ||
else: | ||
from dvc.utils.pager import pager | ||
|
||
pager(_show_ascii(G)) | ||
|
||
return 0 | ||
except DvcException: | ||
msg = "failed to show " | ||
if self.args.target: | ||
msg += f"a pipeline for '{target}'" | ||
else: | ||
msg += "pipelines" | ||
logger.exception(msg) | ||
return 1 | ||
|
||
G = _build( | ||
self.repo, | ||
target=self.args.target, | ||
full=self.args.full, | ||
outs=self.args.outs, | ||
) | ||
|
||
if self.args.dot: | ||
logger.info(_show_dot(G)) | ||
else: | ||
from dvc.utils.pager import pager | ||
|
||
pager(_show_ascii(G)) | ||
|
||
return 0 | ||
|
||
|
||
def add_parser(subparsers, parent_parser): | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,7 +3,7 @@ | |
from contextlib import contextmanager | ||
from functools import wraps | ||
|
||
from funcy import cached_property, cat, first | ||
from funcy import cached_property, cat | ||
from git import InvalidGitRepositoryError | ||
|
||
from dvc.config import Config | ||
|
@@ -23,7 +23,8 @@ | |
|
||
from ..stage.exceptions import StageFileDoesNotExistError, StageNotFound | ||
from ..utils import parse_target | ||
from .graph import check_acyclic, get_pipeline, get_pipelines | ||
from .graph import build_graph, build_outs_graph, get_pipeline, get_pipelines | ||
from .trie import build_outs_trie | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
@@ -289,7 +290,7 @@ def check_modified_graph(self, new_stages): | |
# | ||
# [1] https://github.com/iterative/dvc/issues/2671 | ||
if not getattr(self, "_skip_graph_checks", False): | ||
self._collect_graph(self.stages + new_stages) | ||
build_graph(self.stages + new_stages) | ||
|
||
def _collect_inside(self, path, graph): | ||
import networkx as nx | ||
|
@@ -448,114 +449,17 @@ def used_cache( | |
|
||
return cache | ||
|
||
def _collect_graph(self, stages): | ||
"""Generate a graph by using the given stages on the given directory | ||
|
||
The nodes of the graph are the stage's path relative to the root. | ||
|
||
Edges are created when the output of one stage is used as a | ||
dependency in other stage. | ||
|
||
The direction of the edges goes from the stage to its dependency: | ||
|
||
For example, running the following: | ||
|
||
$ dvc run -o A "echo A > A" | ||
$ dvc run -d A -o B "echo B > B" | ||
$ dvc run -d B -o C "echo C > C" | ||
|
||
Will create the following graph: | ||
|
||
ancestors <-- | ||
| | ||
C.dvc -> B.dvc -> A.dvc | ||
| | | ||
| --> descendants | ||
| | ||
------- pipeline ------> | ||
| | ||
v | ||
(weakly connected components) | ||
|
||
Args: | ||
stages (list): used to build a graph, if None given, collect stages | ||
in the repository. | ||
|
||
Raises: | ||
OutputDuplicationError: two outputs with the same path | ||
StagePathAsOutputError: stage inside an output directory | ||
OverlappingOutputPathsError: output inside output directory | ||
CyclicGraphError: resulting graph has cycles | ||
""" | ||
import networkx as nx | ||
from pygtrie import Trie | ||
|
||
from dvc.exceptions import ( | ||
OutputDuplicationError, | ||
OverlappingOutputPathsError, | ||
StagePathAsOutputError, | ||
) | ||
|
||
G = nx.DiGraph() | ||
stages = stages or self.stages | ||
outs = Trie() # Use trie to efficiently find overlapping outs and deps | ||
|
||
for stage in filter(bool, stages): # bug? not using it later | ||
for out in stage.outs: | ||
out_key = out.path_info.parts | ||
|
||
# Check for dup outs | ||
if out_key in outs: | ||
dup_stages = [stage, outs[out_key].stage] | ||
raise OutputDuplicationError(str(out), dup_stages) | ||
|
||
# Check for overlapping outs | ||
if outs.has_subtrie(out_key): | ||
parent = out | ||
overlapping = first(outs.values(prefix=out_key)) | ||
else: | ||
parent = outs.shortest_prefix(out_key).value | ||
overlapping = out | ||
if parent and overlapping: | ||
msg = ( | ||
"Paths for outs:\n'{}'('{}')\n'{}'('{}')\n" | ||
"overlap. To avoid unpredictable behaviour, " | ||
"rerun command with non overlapping outs paths." | ||
).format( | ||
str(parent), | ||
parent.stage.addressing, | ||
str(overlapping), | ||
overlapping.stage.addressing, | ||
) | ||
raise OverlappingOutputPathsError(parent, overlapping, msg) | ||
|
||
outs[out_key] = out | ||
|
||
for stage in stages: | ||
out = outs.shortest_prefix(PathInfo(stage.path).parts).value | ||
if out: | ||
raise StagePathAsOutputError(stage, str(out)) | ||
|
||
# Building graph | ||
G.add_nodes_from(stages) | ||
for stage in stages: | ||
for dep in stage.deps: | ||
if dep.path_info is None: | ||
continue | ||
|
||
dep_key = dep.path_info.parts | ||
overlapping = [n.value for n in outs.prefixes(dep_key)] | ||
if outs.has_subtrie(dep_key): | ||
overlapping.extend(outs.values(prefix=dep_key)) | ||
|
||
G.add_edges_from((stage, out.stage) for out in overlapping) | ||
check_acyclic(G) | ||
|
||
return G | ||
@cached_property | ||
def outs_trie(self): | ||
return build_outs_trie(self.stages) | ||
|
||
@cached_property | ||
def graph(self): | ||
return self._collect_graph(self.stages) | ||
return build_graph(self.stages, self.outs_trie) | ||
|
||
@cached_property | ||
def outs_graph(self): | ||
return build_outs_graph(self.graph, self.outs_trie) | ||
|
||
@cached_property | ||
def pipelines(self): | ||
|
@@ -648,6 +552,8 @@ def close(self): | |
self.scm.close() | ||
|
||
def _reset(self): | ||
self.__dict__.pop("outs_trie", None) | ||
self.__dict__.pop("outs_graph", None) | ||
|
||
self.__dict__.pop("graph", None) | ||
self.__dict__.pop("stages", None) | ||
self.__dict__.pop("pipelines", None) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This whole
_transform
method is only here becausestr(stage)
is not usingstage.addressing
. Need to consider changingStage.__str__
to it.