# In [191]:
import plotly.graph_objects as go
import networkx as nx
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

# Read the BJCP 2015 style guidelines spreadsheet.
df = pd.read_excel("data/2015_Guidelines.xlsx")

# Drop the extraneous columns pandas creates for blank header cells (it names
# them "Unnamed: <n>"). Matching the prefix instead of listing
# "Unnamed: 27".."Unnamed: 29" explicitly avoids a KeyError if the sheet's
# trailing empty columns ever change, and also catches any extra ones.
df = df.drop(columns=[col for col in df.columns if str(col).startswith("Unnamed:")])

# Build a graph with one edge between each BJCP category and each of its styles.
G = nx.from_pandas_edgelist(df, source="BJCP Categories", target="Styles", edge_attr=None)

# Node positions are not computed by a layout algorithm; they are read from a
# pre-computed coordinate file (one "x,y" row per node).
coords = np.genfromtxt(r'data/coords.txt', delimiter=',')

# Maps each node name to its coordinate row; populated later in this script,
# when node colours are computed.
pos = {}

# Spreadsheet columns copied onto each style node as attributes.
# NOTE(review): 'Mouthfell' looks like a typo for 'Mouthfeel', but it has to
# match the spreadsheet header exactly -- confirm before renaming it.
attr = ['Style Family', 'Style History', 'Origin', 'ABV min', 'ABV max', 'IBUs min', 'IBUs max', 'SRM min', 'SRM max',
        'OG min', 'OG max', 'FG min', 'FG max', 'Overall Impression', 'Aroma', 'Appearance', 'Flavor', 'Mouthfell',
        'Comments', 'History', 'Characteristic Ingredients', 'Style Comparison', 'Commercial Examples', 'Notes']

# Pass 1: category (source) nodes carry no style data, so every attribute gets
# a blank placeholder. This must run before pass 2. If a name is used both as a
# category and as a style, a later row of that category would otherwise blank
# out the style's data.
for category in df['BJCP Categories']:
    for a in attr:
        G.nodes[category][a] = " "

# Pass 2: style (target) nodes get their spreadsheet values, stringified
# (blank cells therefore become the string "nan").
for _, row in df.iterrows():
    for a in attr:
        G.nodes[row['Styles']][a] = str(row[a])

        
# Plot coordinates of every node, in G.nodes() order.
nodes_loc = np.empty(shape=(len(G.nodes), 2))

# Hex colour for each integer SRM value from 1 to 40 (pale yellow to near black).
colors_dict = {1: '#FFE699', 2: '#FFD878', 3: '#FFCA5A', 4: '#FFBF42',
               5: '#FBB123', 6: '#F8A600', 7: '#F39C00', 8: '#EA8F00',
               9: '#E58500', 10: '#DE7C00', 11: '#D77200', 12: '#CF6900',
               13: '#CB6200', 14: '#C35900', 15: '#BB5100', 16: '#B54C00',
               17: '#B04500', 18: '#A63E00', 19: '#A13700', 20: '#9B3200',
               21: '#952D00', 22: '#8E2900', 23: '#882300', 24: '#821E00',
               25: '#7B1A00', 26: '#771900', 27: '#701400', 28: '#6A0E00',
               29: '#660D00', 30: '#5E0B00', 31: '#5A0A02', 32: '#600903',
               33: '#520907', 34: '#4C0505', 35: '#470606', 36: '#440607',
               37: '#3F0708', 38: '#3B0607', 39: '#3A070B', 40: '#36080A'}

# Marker colour for each node, in G.nodes() order.
colors = []

# SRM midpoints outside the table are clamped to its nearest end. Without this,
# e.g. a 0-1 range (midpoint rounds to 0) would raise an uncaught KeyError.
_srm_lo, _srm_hi = min(colors_dict), max(colors_dict)

# Coordinates are matched to nodes purely by position. A short file would
# otherwise fail mid-loop with a bare IndexError.
if len(coords) < len(G.nodes):
    raise ValueError(f"coords has {len(coords)} rows but the graph has {len(G.nodes)} nodes")

for ix, node in enumerate(G.nodes()):
    # NOTE(review): row ix of the coordinate file must correspond to the ix-th
    # node of G -- confirm whenever the spreadsheet changes.
    pos[node] = coords[ix]
    nodes_loc[ix] = pos[node]

    node_attrs = G.nodes[node]
    try:
        # Colour by the midpoint of the SRM range. float() rejects non-numeric
        # entries (including the " " category placeholder). round() rejects NaN
        # (blank cells are stored as "nan") and infinity.
        srm = round(0.5 * (float(node_attrs['SRM min']) + float(node_attrs['SRM max'])))
    except (ValueError, OverflowError):
        if node_attrs['SRM max'] == " ":
            # Blank placeholder => category node.
            colors.append('#000cb8')
        else:
            # Non-numeric SRM range (variable colour).
            colors.append('gray')
    else:
        colors.append(colors_dict[min(max(srm, _srm_lo), _srm_hi)])

# Flatten every edge into one polyline: the two endpoint coordinates followed
# by a (None, None) gap, so the line trace is broken between consecutive edges.
edges_loc = []
for src, dst in G.edges():
    edges_loc.extend([pos[src], pos[dst], [None, None]])
edges_loc = np.array(edges_loc)

# Hover label per node (in G.nodes() order): name, origin, ABV range and IBU range.
hover_text = []
for name, info in G.nodes(data=True):
    hover_text.append("".join([
        name,
        "<br>Origin: ", info['Origin'],
        "<br>ABV range: ", info['ABV min'], "-", info['ABV max'],
        "<br>IBU range: ", info['IBUs min'], "-", info['IBUs max'],
    ]))

# Marker trace: one coloured dot per node.
node_trace = dict(
    type='scatter',
    x=nodes_loc[:, 0],
    y=nodes_loc[:, 1],
    mode='markers',
    marker=dict(size=15, color=colors),
    hoverinfo='text',
    hovertext=hover_text,
)

# Line trace: all edges drawn as a single gapped polyline.
edge_trace = dict(
    type='scatter',
    x=edges_loc[:, 0],
    y=edges_loc[:, 1],
    mode='lines',
    line=dict(width=2, color='gray'),
)

nodes = [node_trace]
edges = [edge_trace]

# Edges first so the node markers are drawn on top of the lines.
data = edges + nodes

# Centred chart title.
title_style = {
    'text': 'Beer Taxonomy Chart',
    'font': {'family': 'Impact', 'color': 'black', 'size': 20},
    'x': 0.5,
    'xanchor': 'center',
    'y': 0.95,
}

# Colour-key note pinned to the top-right corner of the plotting area.
legend_note = dict(
    text="Blue nodes = BJCP Categories <br> Gray nodes indicate variable color",
    font={'family': 'Arial', 'size': 12},
    showarrow=False,
    xref="paper",
    yref="paper",
    x=0.98,
    y=0.95,
)

# Both axes are hidden: node coordinates carry no meaning of their own.
hidden_axis = dict(showgrid=False, zeroline=False, showticklabels=False)

layout = go.Layout(
    height=500,
    width=800,
    title=title_style,
    hovermode='closest',
    margin=dict(b=20, l=0, r=0, t=50),
    xaxis=hidden_axis,
    yaxis=hidden_axis,
    showlegend=False,
    annotations=[legend_note],
)

fig = go.Figure(data=data, layout=layout)
fig.show()