Skip to content

Commit

Permalink
Merge fb3b8bb into 10db27f
Browse files Browse the repository at this point in the history
  • Loading branch information
ncilfone authored Jan 11, 2022
2 parents 10db27f + fb3b8bb commit 846062c
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 13 deletions.
2 changes: 1 addition & 1 deletion spock/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def _get_general_arguments(arguments: dict, config_dag: Graph):
dictionary of general level parameters
"""
config_names = {n.__name__ for n in config_dag.nodes}
config_names = config_dag.node_names
return {
key: value
for key, value in arguments.items()
Expand Down
37 changes: 25 additions & 12 deletions spock/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,29 @@ def __init__(self, input_classes):
"Please correct your @spock decorated classes by removing any cyclic references"
)

@property
def dag(self):
"""Returns the DAG"""
return self._dag

@property
def nodes(self):
"""Returns the node names/input_classes"""
"""Returns the input_classes/nodes"""
return self._input_classes

@property
def node_names(self):
"""Returns the node names"""
return {f"{k.__name__}" for k in self.nodes}

@property
def node_map(self):
return {f"{k.__name__}": k for k in self.nodes}

@property
def roots(self):
"""Returns the roots of the dependency graph"""
return [k for k, v in self._dag.items() if len(v) == 0]
return [self.node_map[k] for k, v in self.dag.items() if len(v) == 0]

def _build(self):
"""Builds a dictionary of nodes and their edges (essentially builds the DAG)
Expand All @@ -55,15 +69,14 @@ def _build(self):
"""
# Build a dictionary of all nodes (base spock classes)
nodes = {val: [] for val in self._input_classes}
node_names = [f"{k.__module__}.{k.__name__}" for k in nodes.keys()]
nodes = {val: [] for val in self.node_names}
# Iterate thorough all of the base spock classes to get the dependencies and reverse dependencies
for input_class in self._input_classes:
dep_classes = _find_all_spock_classes(input_class)
for v in dep_classes:
if f"{v.__module__}.{v.__name__}" not in node_names:
for input_class in self.nodes:
dep_names = {f"{v.__name__}" for v in _find_all_spock_classes(input_class)}
for v in dep_names:
if v not in self.node_names:
raise ValueError(
f"Missing @spock decorated class -- `{v.__name__}` was not passed as an *arg to "
f"Missing @spock decorated class -- `{v}` was not passed as an *arg to "
f"ConfigArgBuilder"
)
nodes.get(v).append(input_class)
Expand All @@ -78,9 +91,9 @@ def _has_cycles(self):
"""
# DFS w/ recursion stack for DAG cycle detection
visited = {key: False for key in self._dag.keys()}
visited = {key: False for key in self.dag.keys()}
all_nodes = list(visited.keys())
recursion_stack = {key: False for key in self._dag.keys()}
recursion_stack = {key: False for key in self.dag.keys()}
# Recur for all edges
for node in all_nodes:
if visited.get(node) is False:
Expand All @@ -106,7 +119,7 @@ def _cycle_dfs(self, node: Type, visited: dict, recursion_stack: dict):
# Update recursion stack
recursion_stack.update({node: True})
# Recur through the edges
for val in self._dag.get(node):
for val in self.dag.get(node):
if visited.get(val) is False:
# The the recursion returns True then work it up the stack
if self._cycle_dfs(val, visited, recursion_stack) is True:
Expand Down

0 comments on commit 846062c

Please sign in to comment.