In [1]:
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
import json
import networkx as nx
import plotly.graph_objects as go

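## Plotting helpers

Both helpers draw a NetworkX graph with Plotly and expect every node to carry a `'pos'` attribute holding its (x, y) layout coordinates.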

In [2]:
def plot_graph(data_nx):
    """Draw a NetworkX graph with Plotly, coloring each node by its neighbor count.

    Adapted from https://plotly.com/python/network-graphs/.

    Parameters
    ----------
    data_nx : networkx.Graph
        Graph whose nodes each carry a ``'pos'`` attribute holding an ``(x, y)``
        pair, e.g. set with ``nx.set_node_attributes(G, layout, 'pos')``.
        A node without ``'pos'`` raises ``KeyError``.

    Returns
    -------
    None
        The figure is rendered with ``fig.show()``.
    """
    # All edges form one polyline; the None after each segment lifts the pen
    # so consecutive edges are not joined together.
    edge_x, edge_y = [], []
    for u, v in data_nx.edges():
        x0, y0 = data_nx.nodes[u]['pos']
        x1, y1 = data_nx.nodes[v]['pos']
        edge_x += [x0, x1, None]
        edge_y += [y0, y1, None]

    edge_trace = go.Scatter(
        x=edge_x, y=edge_y,
        line=dict(width=0.5, color='#888'),
        hoverinfo='none',
        mode='lines')

    # Positions and neighbor counts are collected in one pass over the nodes,
    # so the color and hover lists stay aligned with the marker coordinates.
    node_x, node_y, node_adjacencies = [], [], []
    for node in data_nx.nodes():
        x, y = data_nx.nodes[node]['pos']
        node_x.append(x)
        node_y.append(y)
        node_adjacencies.append(len(data_nx[node]))
    node_text = ['# of connections: ' + str(count) for count in node_adjacencies]

    node_trace = go.Scatter(
        x=node_x, y=node_y,
        mode='markers',
        hoverinfo='text',
        text=node_text,
        marker=dict(
            showscale=True,
            # other colorscales: 'Greys', 'Greens', 'YlOrRd', 'Bluered', 'RdBu',
            # 'Reds', 'Blues', 'Picnic', 'Rainbow', 'Portland', 'Jet', 'Hot',
            # 'Blackbody', 'Earth', 'Electric', 'Viridis'
            colorscale='YlGnBu',
            reversescale=True,
            color=node_adjacencies,
            size=10,
            colorbar=dict(
                thickness=15,
                # `titleside` is deprecated (rejected by Plotly 6+); the title's
                # side is configured under `title.side` instead.
                title=dict(text='Node Connections', side='right'),
                xanchor='left',
            ),
            line_width=2))

    fig = go.Figure(
        data=[edge_trace, node_trace],
        layout=go.Layout(
            showlegend=False,
            hovermode='closest',
            margin=dict(b=20, l=5, r=5, t=40),
            xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            yaxis=dict(showgrid=False, zeroline=False, showticklabels=False)))
    fig.show()

In [3]:
def plot_graph_labels(data_nx, labels):
    """Draw a NetworkX graph with Plotly, coloring each node by a supplied label.

    Adapted from https://plotly.com/python/network-graphs/; the example's
    neighbor-count coloring is replaced by ``labels`` (hover text still shows
    the neighbor count).

    Parameters
    ----------
    data_nx : networkx.Graph
        Graph whose nodes each carry a ``'pos'`` attribute holding an ``(x, y)``
        pair. A node without ``'pos'`` raises ``KeyError``.
    labels : sequence
        One value per node, in ``data_nx.nodes()`` order, used as the marker
        color. Plotly maps numeric values through the 'Bluered' colorscale.

    Returns
    -------
    None
        The figure is rendered with ``fig.show()``.

    Raises
    ------
    ValueError
        If ``len(labels)`` differs from the number of nodes; otherwise colors
        would silently fall out of step with the nodes they belong to.
    """
    # All edges form one polyline; the None after each segment lifts the pen
    # so consecutive edges are not joined together.
    edge_x, edge_y = [], []
    for u, v in data_nx.edges():
        x0, y0 = data_nx.nodes[u]['pos']
        x1, y1 = data_nx.nodes[v]['pos']
        edge_x += [x0, x1, None]
        edge_y += [y0, y1, None]

    edge_trace = go.Scatter(
        x=edge_x, y=edge_y,
        line=dict(width=0.5, color='#888'),
        hoverinfo='none',
        mode='lines')

    # Positions and hover text are collected in one pass over the nodes, in
    # the same order the caller is expected to have used for `labels`.
    node_x, node_y, node_text = [], [], []
    for node in data_nx.nodes():
        x, y = data_nx.nodes[node]['pos']
        node_x.append(x)
        node_y.append(y)
        node_text.append('# of connections: ' + str(len(data_nx[node])))

    if len(labels) != len(node_x):
        raise ValueError(
            f"labels has {len(labels)} entries but the graph has "
            f"{len(node_x)} nodes; pass one label per node in G.nodes() order")

    node_trace = go.Scatter(
        x=node_x, y=node_y,
        mode='markers',
        hoverinfo='text',
        text=node_text,
        marker=dict(
            showscale=True,
            # other colorscales: 'Greys', 'YlGnBu', 'Greens', 'YlOrRd', 'RdBu',
            # 'Reds', 'Blues', 'Picnic', 'Rainbow', 'Portland', 'Jet', 'Hot',
            # 'Blackbody', 'Earth', 'Electric', 'Viridis'
            colorscale='Bluered',
            reversescale=True,
            color=labels,
            size=10,
            colorbar=dict(
                thickness=15,
                # `titleside` is deprecated (rejected by Plotly 6+); the title's
                # side is configured under `title.side` instead.
                title=dict(text='Class', side='right'),
                xanchor='left',
            ),
            line_width=2))

    fig = go.Figure(
        data=[edge_trace, node_trace],
        layout=go.Layout(
            showlegend=False,
            hovermode='closest',
            margin=dict(b=20, l=5, r=5, t=40),
            xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            yaxis=dict(showgrid=False, zeroline=False, showticklabels=False)))
    fig.show()

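## Load data and plot

Choose a dataset (and, for `cd`, a topic) below. The graph is read from `<name>_graph.txt` and the node labels from `<name>_ydict.json`.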

In [50]:
file_path = './Datasets/Processed/'
# options : 'euro', 'timme', 'cd', 'conref'
dat = 'cd'
# options : 'all', 'abortion', 'marijuana', 'gayRights', or 'obama'
top = 'obama'
t_all = True

# Build the dataset's file stem: 'cd' is suffixed with the topic, 'timme' with
# '_all' when t_all is set; every other dataset keeps its bare name.
if dat == 'cd':
    dat = f'{dat}{top}'
elif dat == 'timme' and t_all:
    dat = f'{dat}_all'

graph_path = f'{file_path}{dat}_graph.txt'
y_path = f'{file_path}{dat}_ydict.json'

# read_weighted_edgelist leaves node ids as str (nodetype defaults to None),
# so they can be used directly as keys into the JSON label mapping below
# (presumably a JSON object keyed by node id; verify against the data files).
G = nx.read_weighted_edgelist(graph_path)
with open(y_path, 'r') as f:
    y_dict = json.load(f)

In [51]:
# One label per node, in G.nodes() order, so y_plot[k] colors the k-th marker.
y_plot = [y_dict[node] for node in G.nodes()]
# Bundle consumed by the plotting cells below: [graph, label dict, ordered labels].
data = [G, y_dict, y_plot]
graphs = [G]
conn = [(" " if nx.is_connected(g) else " not ") for g in graphs]
print(f"The full graph has {len(G)} nodes and is{conn[0]}connected.")

The full graph has 59 nodes and is connected.


In [52]:
# spring_layout starts from random positions; a fixed seed makes the layout,
# and therefore every figure drawn from it, reproducible on Restart & Run All.
LAYOUT_SEED = 42
poss = nx.spring_layout(data[0], iterations=10, seed=LAYOUT_SEED)
# Store the coordinates as the 'pos' node attribute, which plot_graph and
# plot_graph_labels read for every node.
nx.set_node_attributes(data[0], poss, 'pos')
plot_graph(data[0])

In [53]:
# Same graph (with the 'pos' layout set above), colored by the ordered labels.
plot_graph_labels(data_nx=data[0], labels=data[2])