In [90]:
import networkx as nx
import plotly.graph_objects as go

In [91]:
# Single source of truth for the edges: each key is an undirected edge
# (node_a, node_b) and each value is the text drawn on that edge.
# NOTE(review): the labels look like per-step outcomes ('0', '1', '0/1');
# confirm their intended meaning with the author.
edge_labels = {
    ('ana', 'pericardial_effusion'): '1',
    ('pericardial_effusion', 'anti_smith_antibody'): '0',
    ('anti_smith_antibody', 'proteinuria'): '0',
    ('proteinuria', 'joint_involvement'): '1',
    ('joint_involvement', 'seizure'): '0',
    ('seizure', 'fever'): '0',
    ('fever', 'delirium'): '0',
    ('delirium', 'low_c3'): '0',
    ('low_c3', 'anti_dsdna_antibody'): '0/1',
    ('anti_dsdna_antibody', 'Lupus'): '1',
}

# The edge list is derived from the label mapping so the two can never drift apart.
edges = list(edge_labels)

# Nodes in the order they are chained by the edges above. This order also drives
# the per-node colors and the node text labels used when plotting.
nodes = ['ana', 'pericardial_effusion', 'anti_smith_antibody', 'proteinuria', 'joint_involvement', 'seizure',
         'fever', 'delirium', 'low_c3', 'anti_dsdna_antibody', 'Lupus']

# Every edge endpoint must be a declared node; otherwise networkx would silently
# create an extra node that has no color assigned.
assert all(a in nodes and b in nodes for a, b in edges), "edge references an undeclared node"

# Nodes are added before edges so that G.nodes() iterates in the same order as `nodes`.
G = nx.Graph()
G.add_nodes_from(nodes)
G.add_edges_from(edges)

# One color per node: the first node is orange, the last is green, and every
# node in between is '#C70039'.
node_colors = ['orange'] + ['#C70039'] * (len(nodes) - 2) + ['green']
# The formula above only yields one color per node when there are at least two nodes.
assert len(node_colors) == len(nodes), "node_colors must have exactly one entry per node"

In [92]:
# Force-directed (Fruchterman-Reingold) layout: maps each node to a 2-D position.
# The algorithm starts from random initial positions, so a fixed seed is passed
# to make the figure identical on every run (the value 42 itself is arbitrary).
pos = nx.fruchterman_reingold_layout(G, seed=42)

In [93]:
# Flatten all edges into two parallel coordinate lists. Each edge contributes
# its two endpoints followed by a None entry, which separates it from the
# next edge in the list.
edge_x = []
edge_y = []
for src, dst in G.edges():
    (src_x, src_y), (dst_x, dst_y) = pos[src], pos[dst]
    edge_x += [src_x, dst_x, None]
    edge_y += [src_y, dst_y, None]

In [94]:
# Node coordinates, listed in G.nodes() iteration order.
node_x = [pos[name][0] for name in G.nodes()]
node_y = [pos[name][1] for name in G.nodes()]

In [95]:
# Line trace that draws every edge as a grey segment.
# No text styling (textfont / textposition) is set here: plotly only renders
# text for a scatter trace whose `mode` includes 'text' and that has a `text`
# value, so on a pure 'lines' trace those attributes have no visible effect.
# Edge labels are added to the figure separately as text.
edge_trace = go.Scatter(
    x=edge_x,
    y=edge_y,
    mode='lines',
    line=dict(width=3, color='#888'),
    hoverinfo='none',
)

In [96]:
# Marker and font settings for the node trace, pulled out for readability.
node_marker = {
    'showscale': False,
    'colorscale': 'YlGnBu',
    'size': 15,
    'color': node_colors,  # one explicit color per node
}
node_font = {'size': 14, 'color': 'black', 'family': 'Arial, sans-serif'}

# Scatter trace drawing each node as a colored marker with its name underneath.
node_trace = go.Scatter(
    x=node_x,
    y=node_y,
    mode='markers+text',
    text=nodes,
    textposition='bottom center',
    textfont=node_font,
    marker=node_marker,
    hoverinfo='text',
)

In [97]:
# Assemble the figure: edges are listed first and nodes second; the layout has
# no legend, no margins, and no grid or zero lines on either axis.
axis_style = dict(showgrid=False, zeroline=False)
figure_layout = go.Layout(
    showlegend=False,
    hovermode='closest',
    margin=dict(b=0, l=0, r=0, t=0),
    xaxis=dict(axis_style),
    yaxis=dict(axis_style),
)
fig = go.Figure(data=[edge_trace, node_trace], layout=figure_layout)

In [98]:
# Add edge labels: one text trace holding every label, each placed at the
# midpoint of its edge.
# Dedicated label_* names are used so the edge_x / edge_y coordinate lists that
# feed the edge line trace are not overwritten; rebinding them here would break
# that trace if its cell were re-run afterwards.
label_x = []
label_y = []
label_text = []
for (src, dst), label in edge_labels.items():
    src_x, src_y = pos[src]
    dst_x, dst_y = pos[dst]
    label_x.append((src_x + dst_x) / 2)
    label_y.append((src_y + dst_y) / 2)
    label_text.append(label)

fig.add_trace(go.Scatter(x=label_x, y=label_y, mode='text', text=label_text, showlegend=False))

In [99]:
# Display the assembled figure (edge lines, node markers and edge labels).
fig.show()