<a href="https://colab.research.google.com/github/mylonasc/easy-sankey/blob/master/Easy_Sankey.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## An easier way to create Sankey flow diagrams with `Plotly`

This notebook shows how to use a few simple data structures to create Sankey diagrams with `plotly`.

### TODO:
Currently the diagrams are only constructed "forward" (from the start node towards the end nodes); cycles, i.e. flows that loop back to an earlier node, are not supported. Fix that.


In [None]:
!pip install plotly

In [None]:
import plotly.graph_objects as go
import urllib, json
import plotly.express as px

In [None]:
# Alpha value that apply_opacity writes into each colour string.
opacity = 0.8


def apply_opacity(x):
    """Insert 'a' before every '(' in x and put the global opacity before every ')'.

    The opacity is formatted with '%2.2f' and is read from the module-level
    ``opacity`` at call time.
    """
    prefixed = x.replace('(', 'a(')
    return prefixed.replace(')', ',%2.2f)' % opacity)


# Plotly's qualitative "Pastel" palette with the alpha value applied to each entry.
pastel_colors = list(map(apply_opacity, px.colors.qualitative.Pastel))

In [None]:
class Node:
    """A node of a flow graph that can be rendered as a Sankey diagram.

    Attributes:
        name: label of the node.
        value: amount of flow held by the node; ``set_out_nodes`` on an
            upstream node adds to it.
        index: integer used to identify the node in the link tuples
            returned by ``dfs`` and to order the returned node list.
        color: colour string of the node.
        in_nodes: declared for incoming neighbours; no method of this
            class fills it.
        out_nodes: downstream nodes, parallel to ``out_edge_weights``.
        out_edge_weights: flow carried by the link to each downstream node.
    """

    def __init__(self, name, value, index, color='rgba(30,42,10,0.5)'):
        self.name = name
        self.value = value
        self.index = index
        self.color = color
        self.in_nodes = []
        self.out_nodes = []
        self.out_edge_weights = []

    def set_out_nodes(self, out_nodes, pct_dist=None):
        """Attach downstream nodes and push this node's value into them.

        Args:
            out_nodes: nodes that receive flow from this node.
            pct_dist: fraction of ``self.value`` sent to each node, in the
                same order as ``out_nodes``.  When omitted the value is
                split equally.  The fractions are not checked to sum to 1.

        Raises:
            ValueError: if ``pct_dist`` is given and its length differs
                from the number of ``out_nodes``.

        The method may be called more than once; each call appends to
        ``out_nodes`` and ``out_edge_weights`` so the two lists stay parallel.
        """
        targets = list(out_nodes)
        if pct_dist is None:
            if not targets:
                # Nothing to attach; also avoids dividing by zero below.
                return
            shares = [1. / len(targets)] * len(targets)
        else:
            shares = list(pct_dist)
            if len(shares) != len(targets):
                raise ValueError(
                    'pct_dist has %d entries but %d out_nodes were given'
                    % (len(shares), len(targets))
                )

        # Weights are fixed from the value this node holds right now.
        weights = [self.value * share for share in shares]

        for target, weight in zip(targets, weights):
            target.value += weight
            self.out_nodes.append(target)
            self.out_edge_weights.append(weight)

    def dfs(self, max_depth=100):
        """Collect the nodes and links reachable from this node.

        Despite the name (kept for compatibility) the graph is walked level
        by level.  Every node is expanded at most once, so shared downstream
        nodes do not produce duplicate links and cycles terminate.

        Args:
            max_depth: maximum number of levels to expand, counting this
                node as the first.  Nodes discovered on the last expanded
                level are still returned, but their own outgoing links
                are not.

        Returns:
            A tuple ``(node_list, links)``: ``node_list`` holds the
            discovered nodes sorted by ``index`` and ``links`` is a list of
            ``(weight, (source_index, target_index))`` tuples.
        """
        links = []
        frontier = [self]
        visited = {self}
        depth = 0
        while frontier and depth < max_depth:
            next_frontier = []
            for node in frontier:
                links.extend(
                    (weight, (node.index, target.index))
                    for target, weight in zip(node.out_nodes, node.out_edge_weights)
                )
                for target in node.out_nodes:
                    if target not in visited:
                        visited.add(target)
                        next_frontier.append(target)
            frontier = next_frontier
            depth += 1
        node_list = sorted(visited, key=lambda n: n.index)
        return node_list, links

# Root of the graph: the total amount that is distributed downstream.
n1 = Node(name='Total CO2', value=100., index=0, color=pastel_colors[0])
#-------------------------
# First level: the root is split between compute and data.
n2 = Node(name='Compute', value=0., index=1, color=pastel_colors[1])
n3 = Node(name='Data', value=0., index=2, color=pastel_colors[1])
# NOTE(review): these two shares add up to 1.1 rather than 1.0 - confirm the intended split.
n1.set_out_nodes(out_nodes=[n2, n3], pct_dist=[0.8, 0.3])
#-------------------------
# Second level, compute branch.
n4 = Node(name='Training', value=0, index=3, color=pastel_colors[2])
n5 = Node(name='Inference', value=0, index=4, color=pastel_colors[2])
n6 = Node(name='Experimentation', value=0, index=5, color=pastel_colors[2])
compute_nodes = [n4, n5, n6]
n2.set_out_nodes(out_nodes=compute_nodes, pct_dist=[0.15, 0.8, 0.05])
#-------------------------
# Second level, data branch; no shares given, so the value is split equally.
n7 = Node(name='Storage', value=0, index=6, color=pastel_colors[3])
n8 = Node(name='Transmission', value=0, index=7, color=pastel_colors[3])
n9 = Node(name='Collection', value=0, index=8, color=pastel_colors[3])
data_nodes = [n7, n8, n9]
n3.set_out_nodes(out_nodes=data_nodes)
#-------------------------
# Final level shared by both branches.
# NOTE(review): the name n9 is rebound here; the 'Collection' node created
# above stays in the graph only through data_nodes.
n9 = Node(name='Waste', value=0, index=9, color=pastel_colors[4])
n10 = Node(name='Recycling', value=0, index=10, color=pastel_colors[4])
final_nodes = [n9, n10]
# Every second-level node feeds both final nodes in equal parts.
for c in compute_nodes:
    c.set_out_nodes(out_nodes=final_nodes)
for d in data_nodes:
    d.set_out_nodes(out_nodes=final_nodes)

In [None]:
# Walk the graph from the root node; all_nodes is the node list sorted by
# index and all_links holds (weight, (source_index, target_index)) tuples.
all_nodes, all_links = n1.dfs()

In [None]:
# Disabled alternative: one fixed colour for every link.
# color_links = ['rgba(20,120,20,0.8)']*len(all_links)
# Colour of each node, in the same order as all_nodes.
color_nodes = [n.color for n in all_nodes]

In [None]:
# Plotly's Sankey addresses link endpoints by position in the node arrays.
# Translate each Node.index into the position of that node in all_nodes
# instead of assuming the two coincide, which only holds when the indexes
# run 0..N-1 without gaps.
node_position = {n.index: pos for pos, n in enumerate(all_nodes)}
link_sources = [node_position[src] for _weight, (src, _dst) in all_links]
link_targets = [node_position[dst] for _weight, (_src, dst) in all_links]

data = {
    'node': {
        'label': [n.name for n in all_nodes],
        'color': [n.color for n in all_nodes],
    },
    'link': {
        'source': link_sources,
        'target': link_targets,
        # Each link takes the colour of its source node.
        'color': [all_nodes[pos].color for pos in link_sources],
        # Amount of flow carried by each link.
        'value': [weight for weight, _ends in all_links],
        # Links are left unlabelled.
        'label': [''] * len(all_links),
    },
}

In [None]:
# Opacity setting kept from the original cell; nothing else in this cell reads it.
opacity = 0.4

# Node appearance and labels, taken from the prepared data dictionary.
node_style = dict(
    pad=15,
    thickness=15,
    line=dict(color="black", width=0.5),
    label=data['node']['label'],
    color=data['node']['color'],
)

# Link endpoints, sizes, labels and colours, taken from the same dictionary.
link_style = dict(
    source=data['link']['source'],
    target=data['link']['target'],
    value=data['link']['value'],
    label=data['link']['label'],
    color=data['link']['color'],
)

sankey_trace = go.Sankey(
    arrangement='snap',
    valueformat=".0f",
    valuesuffix="TWh",
    node=node_style,
    link=link_style,
)

fig = go.Figure(data=[sankey_trace])

fig.update_layout(
    title_text="Circular economy of AI models (<a href=https://mylonasc.netlify.app/>C. Mylonas</a>)",
    font_size=10,
)
fig.show()