NOTE(review): this patch was re-flowed from a copy whose line breaks and leading whitespace had been
collapsed. The indentation of context/removed lines (including the whitespace-only changes in _load_from,
__init__ and stat) is reconstructed -- confirm against the target file before applying. Hunk offsets were
recomputed for the revised added lines; the blob id in the `index` line still refers to the original revision.

diff --git a/src/macrogen/graph.py b/src/macrogen/graph.py
index 91da7ef..919ce09 100644
--- a/src/macrogen/graph.py
+++ b/src/macrogen/graph.py
@@ -6,7 +6,7 @@
 from datetime import date, timedelta
 from io import TextIOWrapper
 from pathlib import Path
-from typing import List, Any, Dict, Tuple, Union, Sequence, Optional, Set
+from typing import List, Any, Dict, Tuple, Union, Sequence, Optional, Set, Iterable
 from warnings import warn
 from zipfile import ZipFile, ZIP_DEFLATED
 
@@ -16,6 +16,6 @@
 from .bibliography import BiblSource
 from .config import config
 from .datings import build_datings_graph
 from .fes import eades, FES_Baharev, V
-from .graphutils import expand_edges, collapse_edges_by_source, add_iweight
+from .graphutils import expand_edges, collapse_edges_by_source, add_iweight, simplify_timeline
 from .igraph_wrapper import to_igraph, nx_edges
@@ -305,9 +305,149 @@ def _load_from(self, load_from: Path):
             except nx.NetworkXError as e:
                 logger.info('Could not remove %s→%s (%d): %s', u, v, k, e)
         check_acyclic(self.dag,
-                  f'Base graph from {load_from} is not acyclic after removing conflicting and ignored edges.')
+                      f'Base graph from {load_from} is not acyclic after removing conflicting and ignored edges.')
         self.closure = nx.transitive_closure(self.dag)
 
+    def node(self, spec: Union[Reference, date, str]) -> Node:
+        """
+        Returns a node from the graph.
+
+        Args:
+            spec: A reference to the node. Can be a node (a Reference or a date), a label, or a URI.
+                Strings starting with ``faust://`` are resolved using `Witness.get`; any other string is
+                compared to each reference's label and, prefixed with ``faust://document/faustedition/``,
+                to its URI.
+
+        Returns:
+            a single node. If more than one node matches, the first match is returned and a warning is logged.
+
+        Raises:
+            KeyError: if no node can be found
+        """
+
+        def first(iterable):
+            """Returns the first item of iterable and logs a warning if there is more than one item."""
+            iterator = iter(iterable)
+            item = next(iterator)  # StopIteration on an empty iterable is translated to KeyError below
+            try:
+                second = next(iterator)
+                logger.warning('There should be only one item in iterable, but there were more '
+                               '(first: %s, second: %s)', item, second)
+            except StopIteration:
+                pass
+            return item
+
+        try:
+            if isinstance(spec, (Reference, date)):
+                return first(node for node in self.base.nodes if node == spec)
+            if spec.startswith('faust://'):
+                ref = Witness.get(spec)
+                return first(node for node in self.base.nodes if node == ref)
+            return first(ref for ref in self.base.nodes if isinstance(ref, Reference)
+                         and (ref.uri == 'faust://document/faustedition/' + spec
+                              or ref.label == spec))
+        except StopIteration:
+            raise KeyError("No node matching {!r} in the base graph.".format(spec)) from None
+
+    def add_path(self, graph: nx.MultiDiGraph, source: Node, target: Node, weight='iweight', method='dijkstra',
+                 must_exist=False, edges_from: Optional[nx.MultiDiGraph] = None) -> Optional[List[Node]]:
+        """
+        Finds the shortest path from source to target and adds it to the given graph.
+
+        Args:
+            graph: The graph in which to add the path. This is modified, of course.
+            source: source node
+            target: target node
+            weight: Attribute name to be used as weight
+            method: see `nx.shortest_path`
+            must_exist: if True, raise an exception if there is no path from source to target
+            edges_from: The graph in which the path is searched and from which its edges are taken.
+                Defaults to the base graph.
+
+        Returns:
+            the path as list of nodes, or None if there is no path and `must_exist` is False.
+
+        Raises:
+            nx.NetworkXNoPath: if there is no path and `must_exist` is True
+        """
+        if edges_from is None:
+            edges_from = self.base
+        try:
+            path = nx.shortest_path(edges_from, source, target, weight, method)
+        except nx.NetworkXNoPath:
+            if must_exist:
+                raise
+            return None  # a missing path is only an error if the caller asked for must_exist
+        graph.add_edges_from(expand_edges(edges_from, nx.utils.pairwise(path)))
+        return path
+
+    def subgraph(self, *nodes: Node, context: bool = True, path_to: Iterable[Node] = (), abs_dates: bool = True,
+                 path_from: Iterable[Node] = (), pathes: Iterable[Node] = (), keep_timeline=False) -> nx.MultiDiGraph:
+        """
+        Extracts a sensible subgraph from the base graph.
+
+        Args:
+            *nodes: node or nodes to include in the graph
+            context: If true, we add the context of the given node (or nodes), i.e. the predecessors and successors
+                from the dag
+            path_to: Nodes to which the shortest path should be included, if any
+            abs_dates: If true, each given node is connected to its closest earlier and later date node
+            path_from: Nodes from which the shortest path should be included, if any
+            pathes: Node(s) from / to which the shortest path should be included, if any
+            keep_timeline: If false, the result is passed through `simplify_timeline` before it is returned
+
+        Description:
+            This method can be used to extract an 'interesting' subgraph around one or more nodes from the base
+            graph. The resulting graph is constructed as follows:
+
+            1. The given nodes in the positional arguments are added to the set of relevant nodes.
+            2. If `context` is True, the set of relevant nodes is extended by all direct neighbours in the dag.
+            3. The subgraph of the base graph induced by the relevant nodes is extracted.
+            4. If `abs_dates` is True, for each node in nodes, we look for the closest earlier and later dating
+               node. If it is not one of the given nodes, we add the shortest path in the dag to it.
+            5. We add the shortest paths from all nodes in path_from and pathes to each node in nodes.
+            6. We add the shortest paths from each node in nodes to each node in path_to and pathes.
+            7. Unless `keep_timeline` is True, the result is simplified using `simplify_timeline`.
+
+        Returns:
+            The constructed subgraph
+        """
+        central_nodes = set(nodes)
+        relevant_nodes = set(central_nodes)
+        if context:
+            for node in central_nodes:
+                relevant_nodes |= set(self.dag.pred[node]).union(self.dag.succ[node])
+
+        subgraph = nx.subgraph(self.base, relevant_nodes).copy()
+        sources = set(path_from).union(pathes)
+        targets = set(path_to).union(pathes)
+
+        for node in central_nodes:
+            if abs_dates:
+                # default=None: a node may have no dated predecessor / successor at all
+                prev = max((d for d in self.closure.pred[node] if isinstance(d, date)), default=None)
+                if prev is not None and prev not in central_nodes:
+                    self.add_path(subgraph, prev, node, edges_from=self.dag)
+                next_ = min((d for d in self.closure.succ[node] if isinstance(d, date)), default=None)
+                if next_ is not None and next_ not in central_nodes:
+                    self.add_path(subgraph, node, next_, edges_from=self.dag)
+
+            for source in sources:
+                self.add_path(subgraph, source, node)
+            for target in targets:
+                self.add_path(subgraph, node, target)
+
+        if not keep_timeline:
+            subgraph = simplify_timeline(subgraph)
+
+        return subgraph
+
+
+def macrogenesis_graphs() -> MacrogenesisInfo:
+    """Deprecated: returns a new MacrogenesisInfo. Instantiate MacrogenesisInfo directly instead."""
+    warn("macrogenesis_graphs() is deprecated, instantiate MacrogenesisInfo directly instead", DeprecationWarning,
+         stacklevel=2)
+    return MacrogenesisInfo()
 
 
 def scc_subgraphs(graph: nx.MultiDiGraph) -> List[nx.MultiDiGraph]:
@@ -505,13 +645,14 @@ def __init__(self, graphs: MacrogenesisInfo, edge: MultiEdge):
         self.shortest_path = nx.shortest_path(graphs.base, self.v, self.u, weight='iweight')
         self.sp_weight = sum(attr.get('weight', 0) for u, v, k, attr in
                              expand_edges(graphs.base, nx.utils.pairwise(
-                                 self.shortest_path,
-                                 cyclic=True)))
+                                     self.shortest_path,
+                                     cyclic=True)))
         self.involved_cycles = {cycle for cycle in graphs.simple_cycles if in_path((self.u, self.v), cycle, True)}
-        counter_edges = {edge for cycle in self.involved_cycles for edge in nx.utils.pairwise(cycle)} - {(self.u, self.v)}
+        counter_edges = {edge for cycle in self.involved_cycles for edge in nx.utils.pairwise(cycle)} - {
+            (self.u, self.v)}
         self.removed_sources = {attr['source'] for u, v, k, attr in self.parallel_edges}
 
-    def stat(self): # TODO better name
+    def stat(self):  # TODO better name
         return dict(
                 u=self.u,
                 v=self.v,
@@ -521,9 +662,3 @@ def stat(self): # TODO better name
                 involved_cycles=len(self.involved_cycles),
                 removed_source_count=len(self.removed_sources)
         )
-
-
-
-def macrogenesis_graphs() -> MacrogenesisInfo:
-    warn("macrogenesis_graphs() is deprecated, instantiate MacrogenesisInfo directly instead", DeprecationWarning)
-    return MacrogenesisInfo()