In [1]:
import networkx as nx

In [4]:
def initial_nodes(G):
    """Return the nodes of the directed graph G that have no incoming edges.

    Nodes are reported in the iteration order of ``G.nodes``. The result is
    an empty list when every node has at least one incoming edge.
    """
    return [node for node in G.nodes if G.in_degree(node) == 0]

In [5]:
def augmented_DAG_by_root(G, root='*'):
    """Return a copy of the acyclic digraph G that has a single in-degree-0 node.

    If G already has exactly one node of in-degree 0, an unmodified copy of G
    is returned. Otherwise a new node ``root`` is added to the copy, together
    with one edge from ``root`` to every node that has in-degree 0 in G.
    G itself is never modified.

    Parameters
    ----------
    G : networkx directed graph
        The graph to augment.
    root : hashable, optional
        Label of the node to add (default ``'*'``). It must not already be a
        node of G when a root has to be added.

    Returns
    -------
    The (possibly augmented) copy of G, or None after printing a message when
    ``nx.is_directed_acyclic_graph(G)`` is false or when ``root`` is already a
    node of G and augmentation is required.
    """
    if not nx.is_directed_acyclic_graph(G):
        print("The directed graph must be acyclic.")
        return

    G1 = G.copy()
    head_nodes = initial_nodes(G1)

    # Does not augment if G already has a root
    if len(head_nodes) == 1:
        return G1

    # Reusing an existing node as the root would silently attach the new edges
    # to it; if that node is itself a head node this creates a self-loop and
    # the result is no longer acyclic.
    if root in G1:
        print("The root label {!r} is already a node of the graph.".format(root))
        return

    G1.add_node(root)
    G1.add_edges_from((root, v) for v in head_nodes)
    return G1

In [6]:
def spanning_arborescence(G):
    """Return a spanning arborescence of the unweighted directed graph G.

    G must be acyclic and have exactly one node of in-degree 0; that node is
    used as the root. The graph is explored from the root one layer at a
    time, and every other node is attached to the first node that discovers
    it.

    Returns
    -------
    A new ``nx.DiGraph`` containing the root and the chosen tree edges (node
    and edge attributes of G are not copied), or None after printing a
    message when G does not satisfy the precondition.
    """
    # Only look for in-degree-0 nodes once G is known to be a DAG, and do it
    # a single time.
    roots = initial_nodes(G) if nx.is_directed_acyclic_graph(G) else []
    if len(roots) != 1:
        print("The directed graph must be acyclic with exactly one node of in-degree 0.")
        return

    root = roots[0]
    discovered = {root}  # set membership keeps each lookup O(1)
    current_node_layer = [root]
    spanning_arb_edges = []

    while current_node_layer:
        # Construct the next node layer using only undiscovered successors
        # of the nodes in current_node_layer
        next_node_layer = []
        for n in current_node_layer:
            for s in G.successors(n):
                if s in discovered:
                    continue
                discovered.add(s)
                next_node_layer.append(s)
                spanning_arb_edges.append((n, s))
        current_node_layer = next_node_layer

    spanning_arb = nx.DiGraph()
    # Add the root explicitly so that a single-node G yields a single-node
    # arborescence rather than an empty graph.
    spanning_arb.add_node(root)
    spanning_arb.add_edges_from(spanning_arb_edges)
    return spanning_arb

In [None]:
# Example DAG with two in-degree-0 nodes (1 and 5); node 4 is the only node
# without outgoing edges.
G = nx.DiGraph([(1, 2), (1, 3), (2, 4), (3, 4), (5, 6), (6, 7), (7, 4)])

In [7]:
# Add an artificial root above the in-degree-0 nodes of G.
H = augmented_DAG_by_root(G)

In [16]:
# Display the edges of the augmented graph.
list(H.edges)

[(1, 2), (1, 3), (2, 4), (3, 4), (5, 6), (6, 7), (7, 4), ('*', 1), ('*', 5)]

In [13]:
# Extract a spanning arborescence of H, rooted at its only in-degree-0 node.
K = spanning_arborescence(H)

In [15]:
# Display the tree edges; each non-root node should appear as a target exactly once.
list(K.edges)

[('*', 1), ('*', 5), (1, 2), (1, 3), (5, 6), (2, 4), (6, 7)]