# Function


In [19]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import fiona
import shapely
from shapely.geometry import shape,mapping, Point, Polygon, MultiPolygon
import networkx as nx
import pickle
import matplotlib.ticker as mticker
from ipywidgets import interact, widgets

In [23]:

# Zone ids that are kept when ``manhattan=True``.
_MANHATTAN_ZONES = frozenset([
    4, 12, 13, 24, 41, 42, 43, 45, 48, 50, 68, 74, 75,
    79, 87, 88, 90, 100, 103, 104, 105, 107, 113, 114, 116, 120,
    125, 127, 128, 137, 140, 141, 142, 143, 144, 148, 151, 152, 153,
    158, 161, 162, 163, 164, 166, 170, 186, 194, 202, 209, 211, 224,
    229, 230, 231, 232, 233, 234, 236, 237, 238, 239, 243, 244, 246,
    249, 261, 262, 263,
])


def _collect_optimal_actions(data, t, allowed_zones=None):
    """Classify the best action of every zone that has a Q-row at time ``t``.

    ``data`` is indexed with keys whose first element is the start zone and
    whose second element is the time step; each value is a row of Q-values
    whose argmax is taken as the chosen action.  Action ``0`` means "wait";
    any other action is the id of the destination zone.

    Parameters
    ----------
    data : mapping
        The unpickled Q-table.
    t : int
        Time step to select.
    allowed_zones : collection of int, optional
        When given, a row is kept only if its start zone and (for non-wait
        actions) its destination zone are both in this collection.

    Returns
    -------
    (nodes, wait, stay, edges)
        ``nodes``: set of every start/action id that was kept,
        ``wait``: set of zones whose best action is to wait,
        ``stay``: set of zones whose best action targets the same zone,
        ``edges``: list of ``(start, end)`` pairs for moves to another zone.
    """
    nodes, wait, stay = set(), set(), set()
    edges = []

    for key in data:
        if key[1] != t:
            continue
        q_values = data[key]
        # A row made only of zeros carries no optimal action.
        if not np.any(np.not_equal(q_values, 0)):
            continue

        start = int(key[0])
        end = int(np.argmax(q_values))

        if allowed_zones is not None:
            # Action 0 is "wait", not a zone id, so it must not be tested
            # against the zone filter (otherwise waits are never reported).
            if start not in allowed_zones:
                continue
            if end != 0 and end not in allowed_zones:
                continue

        nodes.update((start, end))
        if end == 0:
            wait.add(start)
        elif start == end:
            stay.add(start)
        else:
            edges.append((start, end))

    return nodes, wait, stay, edges


def plot_optimal_q(q_path='../data/emp_Q_shA.pkl', shp_file='../data/taxi_zones/taxi_zones.shp', t=32, all_nodes=False, manhattan=False):
    """Plot the optimal action of each taxi zone at time step ``t``.

    Zones are drawn at the centroid of their polygon.  A blue outline marks a
    zone whose best action is to wait, a red outline a zone whose best action
    stays inside the zone, and a red arrow a move to another zone.  Every
    other zone gets a thin black outline.

    Parameters
    ----------
    q_path : str
        Path of the pickled Q-table.
    shp_file : str
        Path of the taxi-zone shapefile.  A feature's zone id is taken to be
        its fiona feature id plus one.
    t : int
        Time step to display.
    all_nodes : bool
        If true, draw every zone (restricted to Manhattan when ``manhattan``
        is set); otherwise draw only zones that take part in an action.
    manhattan : bool
        If true, restrict the plot to the Manhattan zone ids.

    Returns
    -------
    None
    """
    # NOTE(security): pickle.load can execute arbitrary code; only open
    # Q-tables that come from a trusted source.
    with open(q_path, 'rb') as q_file:
        data = pickle.load(q_file)

    zone_filter = _MANHATTAN_ZONES if manhattan else None
    nodes, wait, stay, edges = _collect_optimal_actions(data, t, zone_filter)

    G = nx.DiGraph()
    with fiona.open(shp_file) as taxi_zones:
        for zone in taxi_zones:
            zone_id = int(zone['id']) + 1

            if all_nodes:
                keep = (not manhattan) or zone_id in _MANHATTAN_ZONES
            else:
                keep = zone_id in nodes
            if not keep:
                continue

            # shapely.geometry.shape replaces asShape, which Shapely 2.0 removed.
            geom = shapely.geometry.shape(zone['geometry'])
            G.add_node(zone_id, pos=geom.centroid.coords[0])

    # Outline style per node, in graph order: blue = wait, red = stay.
    color = []
    width = []
    for zone_id in G.nodes:
        if zone_id in wait:
            color.append('blue')
            width.append(3)
        elif zone_id in stay:
            color.append('red')
            width.append(3)
        else:
            color.append('black')
            width.append(1)

    # Only edges between placed zones can be drawn; adding any other edge
    # would create a node without a position and break the per-node styling.
    G.add_edges_from((s, e) for (s, e) in edges if s in G and e in G)

    p = nx.get_node_attributes(G, 'pos')
    fig = plt.figure(3, figsize=(27, 27))
    # Figure 3 is reused if it is still open, so start from a clean canvas.
    fig.clf()
    node_art = nx.draw_networkx_nodes(G, pos=p, node_color='white', node_size=500,
                                      edgecolors=color, linewidths=width)
    nx.draw_networkx_labels(G, pos=p, font_size=10)
    nx.draw_networkx_edges(G, pos=p, width=3, edge_color='red')
    # Re-apply the outline colours on the node artist itself rather than on
    # whatever happens to be the first collection of the current axes.
    if node_art is not None and color:
        node_art.set_edgecolor(color)
    fig.suptitle('Optimal Action for Taxi Zones', fontsize=30, y=0.9)
    #plt.savefig('../optimal_q.png', bbox_inches = 'tight')
    plt.show()

In [24]:
def interactive_plot(q_path='../data/emp_Q_shA.pkl'):
    """Show ``plot_optimal_q`` with interactive controls.

    The Q-table at ``q_path`` is read once to find the smallest and largest
    time step among its keys (second key element); those bound the ``t``
    slider.  ``all_nodes`` and ``manhattan`` are exposed as radio buttons,
    and the path text box starts at ``q_path`` so the plot uses the same
    file the slider range was computed from.

    Parameters
    ----------
    q_path : str
        Path of the pickled Q-table.

    Raises
    ------
    ValueError
        If the Q-table has no entries.
    """
    # NOTE(security): pickle.load can execute arbitrary code; only open
    # Q-tables that come from a trusted source.
    with open(q_path, 'rb') as q_file:
        data = pickle.load(q_file)

    times = [key[1] for key in data]
    if not times:
        raise ValueError('Q-table %r contains no entries' % (q_path,))

    t_min, t_max = int(min(times)), int(max(times))
    # Start the slider at 0 when that is inside the range, else at the
    # nearest bound.
    t_start = min(max(0, t_min), t_max)

    interact(plot_optimal_q,
             q_path=widgets.Text(value=str(q_path)),
             all_nodes=widgets.RadioButtons(options=[True, False], value=False),
             manhattan=widgets.RadioButtons(options=[True, False], value=False),
             t=widgets.IntSlider(min=t_min, max=t_max, step=1, value=t_start))

In [25]:
# Launch the interactive viewer on the Q-table stored at ../data/emp_Q_shA.pkl.
interactive_plot(q_path='../data/emp_Q_shA.pkl')

interactive(children=(Text(value='../data/emp_Q_shA.pkl', description='q_path'), Text(value='../data/taxi_zone…