In [139]:
import plotly.graph_objects as go
import networkx as nx
import random

# Build the reference graph. Every edge points child -> parent and carries a
# 'reference' attribute saying how strongly the child depends on the parent.
G = nx.DiGraph()
G.add_nodes_from(range(1, 10))  # nodes 1..9

# Node categories (drive the border colour in the plot)
disliked_nodes = [3]
transaction_nodes = [2, 5, 9]
data_nodes = [1, 4, 6, 7, 8]

# (child, parent, reference strength)
edges = [
    (2, 1, 'strong'),
    (3, 1, 'strong'),
    (4, 1, 'strong'),
    (4, 2, 'strong'),
    (5, 3, 'strong'),
    (5, 4, 'strong'),
    (6, 3, 'strong'),
    (6, 4, 'strong'),
    (7, 4, 'strong'),
    (7, 6, 'weak'),
    (8, 2, 'strong'),
    (8, 4, 'strong'),
    # Experiment by wiring node 9 to 7 and/or 8, with weak or strong references.
    # 8 is the stronger path; 7 may be chosen if 8 is congested.
    (9, 7, 'strong'),
]

# Attach each edge with its strength stored under the 'reference' key
G.add_edges_from(
    (child_id, parent_id, {'reference': strength})
    for child_id, parent_id, strength in edges
)

# Classify every node as 'strong' or 'weak'. Edges point child -> parent.
#   * A child with a 'strong' reference to a disliked node is weak.
#   * A child referencing (with any strength) a node that is already weak is
#     also weak, so weakness propagates down every branch below a disliked node.
# A node's status must be final before it is used to classify its children.
# A topological sort of G lists children before parents (edges run
# child -> parent), so iterating it in reverse visits parents first. This no
# longer relies on parents happening to have smaller node ids than children.
node_weak = {node: False for node in G.nodes}
for node in reversed(list(nx.topological_sort(G))):
    for child, parent, data in G.in_edges(node, data=True):
        strong_link_to_disliked = (
            parent in disliked_nodes and data['reference'] == 'strong'
        )
        if strong_link_to_disliked or node_weak[parent]:
            node_weak[child] = True

# Built after propagation finishes so every entry reflects the final status
node_type = {node: 'weak' if node_weak[node] else 'strong' for node in G.nodes}

# Assign plot positions.
# layer_map is node -> layer; a higher layer is newer in time and is drawn
# further to the right. y is random (but reproducible) so the graph reads as a
# mesh instead of stacked straight lines.
# NOTE(review): the edges 5 -> 4 (both layer 2) and 9 -> 7 (both layer 4) join
# nodes in the same layer, so they are drawn vertically and the child is not
# "newer" than its parent. Confirm this layering is intended.
layer_map = {
    1: 0,
    2: 1, 3: 1,
    4: 2, 5: 2,
    6: 3, 7: 4,
    8: 4, 9: 4
}
LAYER_SPACING = 2.5  # horizontal distance between consecutive layers

# Every node needs a position; fail here with a clear message rather than with
# a KeyError on 'pos' when the traces are built in a later cell.
missing_layers = set(G.nodes) - set(layer_map)
if missing_layers:
    raise ValueError(f"Nodes without a layer in layer_map: {sorted(missing_layers)}")

# A local, seeded RNG gives the same values as random.seed(42) did, without
# reseeding the global `random` module as a side effect.
rng = random.Random(42)
pos = {
    node: (layer * LAYER_SPACING, rng.uniform(-1, 1))
    for node, layer in layer_map.items()
}
nx.set_node_attributes(G, pos, 'pos')

In [140]:
# Build the plot traces (edge lines, then node markers).
# Edges: one [start, end, None] segment per edge. The trailing None makes
# Plotly lift the pen, so consecutive segments are not joined together.
edge_x, edge_y, edge_styles = [], [], []

for src, dst, attrs in G.edges(data=True):
    (src_x, src_y), (dst_x, dst_y) = G.nodes[src]['pos'], G.nodes[dst]['pos']
    edge_x.append([src_x, dst_x, None])
    edge_y.append([src_y, dst_y, None])
    edge_styles.append(attrs['reference'])

# Edge colour and dash pattern depend on the reference strength
def make_edge_trace(style):
    """Build one Plotly line trace containing every edge of the given reference style.

    Reads the module-level ``edge_x``, ``edge_y`` and ``edge_styles`` lists, where
    ``edge_x[i]`` / ``edge_y[i]`` is the ``[start, end, None]`` segment of edge ``i``
    and ``edge_styles[i]`` is its reference strength.

    Args:
        style: ``'strong'`` draws solid grey lines. Any other value (``'weak'``
            in this notebook) draws dotted firebrick lines.

    Returns:
        go.Scatter: a ``mode='lines'`` trace with hover disabled.
    """
    # Flatten the matching segments into one coordinate list. The None
    # separators inside each segment keep Plotly from joining separate edges.
    # A comprehension replaces sum(list_of_lists, []), which is quadratic.
    x_vals = [c for seg, s in zip(edge_x, edge_styles) if s == style for c in seg]
    y_vals = [c for seg, s in zip(edge_y, edge_styles) if s == style for c in seg]
    is_strong = style == 'strong'
    return go.Scatter(
        x=x_vals, y=y_vals,
        line=dict(
            width=1,
            color='#888' if is_strong else 'firebrick',
            dash='solid' if is_strong else 'dot',
        ),
        hoverinfo='none',
        mode='lines'
    )

# Node trace: one marker per node.
#   * The border colour encodes the node category.
#   * The fill colour and label encode the strong/weak classification.
#   * Hovering shows how many edges touch the node.
def _border_color(node):
    """Return the marker border colour for a node's category (disliked wins over transaction)."""
    if node in disliked_nodes:
        return 'firebrick'
    if node in transaction_nodes:
        return 'royalblue'
    return 'gray'


node_x, node_y, node_colors, node_borders, node_texts, hover_texts = [], [], [], [], [], []

for node in G.nodes():
    x, y = G.nodes[node]['pos']
    node_x.append(x)
    node_y.append(y)

    node_colors.append('#eee' if node_type[node] == 'strong' else '#ddd')
    node_borders.append(_border_color(node))
    node_texts.append(f"Node {node}<br>{node_type[node].capitalize()}")
    # In-degree + out-degree. This replaces scanning every edge for every node;
    # the two agree because a DAG has no self-loops.
    hover_texts.append(f"{G.degree(node)} Edge(s)")

# Put nodes in a scatter plot and label each one on the graph
node_trace = go.Scatter(
    x=node_x, y=node_y,
    mode='markers+text',
    hoverinfo='text',
    hovertext=hover_texts,
    text=node_texts,
    textposition="top center",
    marker=dict(
        color=node_colors,
        size=30,
        line=dict(width=2, color=node_borders)
    )
)

In [141]:
# Assemble the figure. Traces draw in list order: strong edges, then weak
# edges, then nodes on top.
fig = go.Figure(
    data=[
        make_edge_trace('strong'),
        make_edge_trace('weak'),
        node_trace,
    ],
    layout=go.Layout(
        title="Directed Acyclic Graph",
        showlegend=False,
        hovermode='closest',
        # The legend annotation sits below the plot area (paper y < 0). A 20px
        # bottom margin left too little room and the text could be clipped off
        # the figure, so reserve more space.
        margin=dict(b=60, l=5, r=5, t=40),
        annotations=[dict(
            # Weak edges are also red, so the legend keeps node borders and
            # edge styles apart to avoid reading dotted red edges as "disliked".
            text=(
                "Node border: Gray = Data | Blue = Transaction | Red = Disliked"
                "<br>Edge: Solid gray = Strong reference | Dotted red = Weak reference"
            ),
            showarrow=False, xref="paper", yref="paper",
            x=0.01, y=-0.02,
            xanchor="left", yanchor="top", align="left"
        )],
        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False)
    )
)

fig.show()

In [142]:
# Sanity check: the reference graph built in the first cell must not contain cycles.
# NOTE(review): this only runs after the figure has been drawn; consider moving
# it into the first cell, right after the edges are added.
assert nx.algorithms.dag.is_directed_acyclic_graph(G), "Graph must be a DAG"