In [None]:
import plotly.graph_objects as go
import networkx as nx
import matplotlib.pyplot as plt

import dash
from dash import dcc
from dash import html
import json
from dash.dependencies import Input, Output

In [None]:
import math
from typing import List
from itertools import chain


def addEdge(start, end, edge_x, edge_y, lengthFrac=0.9, arrowPos = "end",
            arrowLength=0.025, arrowAngle = 30, dotSize=15):
    """Append one directed edge (a line plus an optional arrowhead) to Plotly coordinate lists.

    Every stroke is appended as ``a, b, None`` (for both x and y). The ``None``
    separators stop Plotly from joining consecutive strokes into one polyline.

    Parameters
    ----------
    start, end : (x, y) pairs
        Edge endpoints in data coordinates.
    edge_x, edge_y : list
        Coordinate lists that are extended in place and also returned.
    lengthFrac : float
        Fraction of the start-end distance that the drawn segment should cover,
        before the dot-size correction below. The trimmed segment stays centered
        between the endpoints.
    arrowPos : None, 'end', 'middle' or 'mid'
        Where to draw the arrowhead. ``None`` draws no arrowhead; any value
        other than 'middle'/'mid' places it at the (trimmed) end of the segment.
    arrowLength : float
        Length of each arrowhead stroke, in data units.
    arrowAngle : float
        Angle in degrees between each arrowhead stroke and the edge.
    dotSize : float
        Plotly marker size used for the nodes. The segment is shortened in
        proportion to it, which evens out the gaps when edge lengths vary.

    Returns
    -------
    tuple
        ``(edge_x, edge_y)``: the same list objects, extended. Zero-length
        edges (start == end, e.g. self-loops) are skipped and leave both lists
        unchanged.
    """
    x0, y0 = start
    x1, y1 = end
    vx, vy = x1 - x0, y1 - y0

    length = math.hypot(vx, vy)
    if length == 0:
        # No direction to draw, and the dot-size correction below divides by length.
        return edge_x, edge_y

    # Empirical conversion from Plotly marker size to data-space length
    # (0.0565 data units per 20 marker-size units).
    # NOTE(review): only calibrated for the plot scale it was tuned on.
    dotSizeConversion = .0565 / 20
    lengthFrac = lengthFrac - dotSize * dotSizeConversion / length

    # Trim the segment symmetrically so that it covers only lengthFrac of the distance.
    skipX = vx * (1 - lengthFrac)
    skipY = vy * (1 - lengthFrac)
    x0, x1 = x0 + skipX / 2, x1 - skipX / 2
    y0, y1 = y0 + skipY / 2, y1 - skipY / 2

    edge_x.extend([x0, x1, None])
    edge_y.extend([y0, y1, None])

    if arrowPos is not None:
        if arrowPos in ('middle', 'mid'):
            pointx = x0 + (x1 - x0) / 2
            pointy = y0 + (y1 - y0) / 2
        else:
            pointx, pointy = x1, y1

        # Direction of travel, taken from the untrimmed endpoints. atan2 covers
        # all four quadrants; a plain atan(dx/dy) cannot tell a left-pointing
        # horizontal edge from a right-pointing one.
        theta = math.atan2(vy, vx)
        spread = math.radians(arrowAngle)

        # Two strokes trail back from the tip, rotated -/+ arrowAngle off the edge.
        for wing in (theta - spread, theta + spread):
            edge_x.extend([pointx, pointx - arrowLength * math.cos(wing), None])
            edge_y.extend([pointy, pointy - arrowLength * math.sin(wing), None])

    return edge_x, edge_y


In [None]:
def load_json(json_path= 'data.json'):
    """Read and parse a JSON file.

    Parameters
    ----------
    json_path : str or path-like, default 'data.json'
        File to read. Relative paths resolve against the current working directory.

    Returns
    -------
    The decoded JSON value (for the pipeline file, a dict).
    """
    with open(json_path, 'rb') as j:
        # json.load accepts a binary file and detects UTF-8/16/32 itself.
        return json.load(j)

In [None]:
def createDiGraphFromJson(data):
    """Build a directed graph from the first pipeline in ``data``.

    Every entry of ``data["pipelines"][0]["nodes"]`` becomes a node carrying
    the attribute ``color="red"``. Every entry of ``...["edges"]`` is turned
    into a tuple and added as a directed edge.
    """
    pipeline = data["pipelines"][0]
    edge_list = [tuple(edge) for edge in pipeline["edges"]]

    graph = nx.DiGraph()
    graph.add_nodes_from(pipeline["nodes"], color="red")
    graph.add_edges_from(edge_list)
    return graph

In [None]:
def create_edge_trace(graph, node_positions):
    """Build a single Plotly line trace containing every edge of ``graph``.

    Parameters
    ----------
    graph : networkx graph
        Its ``edges()`` are drawn, each with an arrowhead added by ``addEdge``.
    node_positions : mapping
        Maps each node to an (x, y) pair, e.g. the result of ``nx.spring_layout``.

    Returns
    -------
    go.Scatter
        A 'lines' trace. Individual strokes are separated by ``None`` entries.
    """
    edge_x = []
    edge_y = []

    for source, target in graph.edges():
        # Exactly one addEdge call per edge; calling it twice duplicates every
        # line and arrowhead in the trace.
        edge_x, edge_y = addEdge(node_positions[source], node_positions[target],
                                 edge_x, edge_y)

    return go.Scatter(
        x=edge_x, y=edge_y,
        line=dict(width=0.9, color='black'),
        hoverinfo='none',
        mode='lines',
    )

In [None]:
def create_node_traces(G,node_positions, data):
    """Build one Plotly marker trace per pipeline status.

    The statuses come from ``data["pipelines"][0]["status"]``, paired
    element-by-element with ``data["pipelines"][0]["nodes"]``.

    Parameters
    ----------
    G : networkx graph
        Its nodes are the candidates for drawing.
    node_positions : mapping
        Maps each node to an (x, y) pair, e.g. the result of ``nx.spring_layout``.
    data : dict
        The parsed pipeline JSON.

    Returns
    -------
    list of go.Scatter
        ``[pending (yellow), failed (red), running (green)]``, in that order.
        Nodes whose status is missing, or is not one of these three, are not drawn.
    """
    pipeline = data["pipelines"][0]
    # Look statuses up by node name rather than by position in G.nodes().
    # G can hold extra nodes (ones that appear only in edges) or fewer nodes
    # (duplicate names collapse), and either case would misalign indices.
    status_by_node = dict(zip(pipeline["nodes"], pipeline["status"]))

    coords = {"running": ([], []), "pending": ([], []), "failed": ([], [])}
    for node in G.nodes():
        bucket = coords.get(status_by_node.get(node))
        if bucket is None:
            continue
        x, y = node_positions[node]
        bucket[0].append(x)
        bucket[1].append(y)

    def _marker_trace(status, color, name):
        # Shared styling for the three status traces; only the colour and legend name differ.
        xs, ys = coords[status]
        return go.Scatter(
            x=xs, y=ys,
            mode='markers',
            marker_symbol="circle",
            marker_size=60,
            marker_color=color,
            line_width=5,
            name=name,
        )

    return [
        _marker_trace("pending", "yellow", "Pending"),
        _marker_trace("failed", "red", "Failed"),
        # NOTE(review): the "running" status is labelled "Done" in the legend — confirm intended.
        _marker_trace("running", "green", "Done"),
    ]

In [None]:
# Dash app: a heading, the live graph, and an Interval timer whose ticks
# drive the `update` callback.
app = dash.Dash()
app.layout = html.Div(
    html.Div(
        children=[
            html.H4('Test'),
            dcc.Graph(id='live-update-graph'),
            dcc.Interval(
                id='interval-component',
                interval=1 * 1000,  # tick period in milliseconds (1 s)
                n_intervals=0,
            ),
        ]
    )
)

@app.callback(Output('live-update-graph', 'figure'),
              Input('interval-component', 'n_intervals'))
def update(n):
    """Rebuild the pipeline figure from the JSON file on every interval tick.

    ``n`` (the tick count) is required by the callback signature but is not
    used: the figure depends only on what ``load_json()`` returns.
    """
    data = load_json()
    G = createDiGraphFromJson(data)
    # A fixed seed makes the layout reproducible, so an unchanged graph keeps
    # its node positions across refreshes.
    node_positions = nx.spring_layout(G, seed=41)

    edge_trace = create_edge_trace(G, node_positions)
    node_traces = create_node_traces(G, node_positions, data)
    traces = [edge_trace] + node_traces

    # Label each node with its name, placed at its layout position.
    annotations = [
        dict(text=node, showarrow=False, x=pos[0], y=pos[1])
        for node, pos in node_positions.items()
    ]

    fig = go.Figure(
        data=traces,
        layout=go.Layout(
            # `titlefont` is deprecated (removed in plotly 6); `title.font` is the supported form.
            title=dict(text='<br>Network graph made with Python', font=dict(size=16)),
            showlegend=True,
            margin=dict(b=30, l=10, r=10, t=40),
            annotations=annotations,
            xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        ),
    )
    return fig

# `run_server` was deprecated in Dash 2 and removed in Dash 3; `run` replaces it.
# The reloader is disabled because it re-executes the process, which does not
# work from inside a notebook kernel.
if hasattr(app, "run"):
    app.run(debug=True, use_reloader=False)
else:  # older Dash releases only provide run_server
    app.run_server(debug=True, use_reloader=False)