In [1]:
#open matplotlib backend
%matplotlib

Using matplotlib backend: Qt5Agg


In [2]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.widgets import Button
import search_tool
from search_tool import Graph
from networkx.drawing.nx_pydot import graphviz_layout
import networkx as nx

In [3]:
# Define the search problem here. Node entries are either a bare name or a
# (name, value) pair -- the value is presumably a heuristic estimate; confirm in
# search_tool.Graph. Edges are (u, v, weight) triples. The two following
# positional arguments appear to be the start and goal nodes.
graph_nodes = ['A', ('B', 6), ('C', 4), ('D', 3)]
graph_edges = [('A', 'B', 5), ('B', 'C', 4), ('B', 'D', 3)]
graph = Graph(graph_nodes, graph_edges, 'A', 'C', weighted=True)

# Run the search. The three per-step histories are replayed by the
# visualization cell (which colours nodes by membership in each of them).
f_h, e_h, p_h = search_tool.BFS(graph)

In [4]:
#helper functions for clarity
def draw_color_coded_graph(graph, color_map, ax):
    """Redraw ``graph`` on ``ax`` with the given node colours.

    The axes is cleared first so repeated calls (one per button click) replace
    the previous drawing instead of stacking new artists on top of it. Without
    the clear, the semi-transparent nodes (``alpha=0.9``) would blend with the
    colours of earlier steps.

    Parameters
    ----------
    graph : search_tool.Graph
        Object exposing ``nx_graph`` (the networkx graph to draw), ``start``
        (used as the root of the "dot" layout) and ``weighted`` (when truthy,
        edge ``'weight'`` attributes are drawn as red labels).
    color_map : str or list
        Passed to networkx as ``node_color``: a single colour for every node,
        or one colour per node in ``graph.nx_graph`` iteration order.
    ax : matplotlib.axes.Axes
        Axes to draw on.
    """
    ax.clear()
    pos = graphviz_layout(graph.nx_graph, prog="dot", root=graph.start)
    if graph.weighted:
        costs = nx.get_edge_attributes(graph.nx_graph, 'weight')
        # ax must be explicit: otherwise networkx draws on plt.gca(), which is
        # the last-created button axes once the navigation buttons exist.
        nx.draw_networkx_edge_labels(
            graph.nx_graph, pos, edge_labels=costs, font_color='red', ax=ax
        )
        nx.draw_networkx(
            graph.nx_graph, pos, with_labels=True, alpha=0.9,
            node_color=color_map, ax=ax,
        )
    else:
        nx.draw_networkx(
            graph.nx_graph, pos, with_labels=True, node_color=color_map, ax=ax
        )
    # Redraw the figure that owns this axes (not whatever pyplot considers current).
    ax.figure.canvas.draw_idle()


def draw_graph_on_axis(graph, ax):
    """Redraw ``graph`` on ``ax`` with every node in the default cyan colour.

    Used to reset the display to the uncoloured graph (the Restart button).
    Shares the layout, weight-label and clearing behaviour of
    ``draw_color_coded_graph``.
    """
    draw_color_coded_graph(graph, 'cyan', ax)

In [5]:
# Create the figure that hosts the interactive graph display. With the GUI
# backend selected by `%matplotlib` (Qt5Agg in the recorded run) this opens a
# separate local window, so it needs a display and will fail on a remote /
# headless kernel.
fig, ax = plt.subplots()
# NOTE(review): visualize_graph is not given `ax`; presumably it draws on the
# current pyplot axes (the one just created) -- confirm in search_tool.Graph.
graph.visualize_graph()
# Raise the bottom edge of the subplot area to 20% of the figure height to
# leave room for the navigation buttons added below.
fig.subplots_adjust(bottom=0.2)

class Index:
    """Button callbacks that replay a search's recorded history step by step.

    ``f_h``, ``e_h`` and ``p_h`` are parallel per-step sequences; entry ``i``
    of each is a collection of nodes. For step ``i`` a node is coloured green
    if it is in ``p_h[i]``, else yellow if in ``f_h[i]``, else purple if in
    ``e_h[i]``, else cyan. (Presumably path / frontier / explored -- TODO
    confirm against search_tool.)

    Stepping past either end wraps around. With an empty history the step
    buttons do nothing.
    """

    ind = 0  # index of the step currently shown (class default kept for compatibility)

    def __init__(self, graph, f_h, e_h, p_h, ax):
        self.graph = graph
        self.f_h = f_h
        self.e_h = e_h
        self.p_h = p_h
        self.ax = ax

    def _color_map(self, i):
        """Return one colour per node of ``self.graph.nx_graph`` for step ``i``."""
        path, frontier, explored = self.p_h[i], self.f_h[i], self.e_h[i]
        colors = []
        for node in self.graph.nx_graph:
            if node in path:
                colors.append('green')
            elif node in frontier:
                colors.append('yellow')
            elif node in explored:
                colors.append('purple')
            else:
                colors.append('cyan')
        return colors

    def _show(self, step):
        """Wrap ``step`` into range, store it as ``ind`` and redraw that step.

        Does nothing when the history is empty (the modulo would otherwise
        raise ZeroDivisionError).
        """
        if not self.f_h:
            return
        self.ind = step % len(self.f_h)
        draw_color_coded_graph(self.graph, self._color_map(self.ind), self.ax)

    def next(self, event):
        """Show the following step, wrapping from the last step to the first."""
        self._show(self.ind + 1)

    def prev(self, event):
        """Show the previous step, wrapping from the first step to the last."""
        self._show(self.ind - 1)

    def restart(self, event):
        """Redraw the uncoloured graph and reset the step counter to 0."""
        draw_graph_on_axis(self.graph, self.ax)
        self.ind = 0

    def last(self, event):
        """Jump to the final recorded step."""
        self._show(len(self.f_h) - 1)

# Wire the navigation buttons. Axes rectangles are [left, bottom, width, height]
# in figure coordinates; left to right the buttons are Restart, Previous, Next, End.
callback = Index(graph, f_h, e_h, p_h, ax)
axprev = fig.add_axes([0.6, 0.05, 0.1, 0.075])
axnext = fig.add_axes([0.7, 0.05, 0.1, 0.075])
axend = fig.add_axes([0.8, 0.05, 0.1, 0.075])
axrestart = fig.add_axes([0.5, 0.05, 0.1, 0.075])
# Each add_axes call makes the new axes the figure's current axes, so plt.gca()
# would now be the Restart button. Restore the graph axes as current so any
# drawing that goes through pyplot state lands on the graph, not on a button.
fig.sca(ax)
# The Button objects must stay referenced (module-level names here) or they
# can be garbage-collected and stop responding to clicks.
bnext = Button(axnext, 'Next')
bnext.on_clicked(callback.next)
bprev = Button(axprev, 'Previous')
bprev.on_clicked(callback.prev)
brestart = Button(axrestart, 'Restart')
brestart.on_clicked(callback.restart)
bend = Button(axend, 'End')
bend.on_clicked(callback.last)
plt.show()