In [24]:
from td_exercise_environment import Environment
from bokeh.io import output_notebook, show
from bokeh.layouts import widgetbox
from bokeh.models.widgets import Slider, TextInput, Button
from bokeh.models import CustomJS
# Send Bokeh output to the notebook instead of a standalone HTML file.
output_notebook()

# Grid dimensions handed to Environment below.
puzzle_width = 8
puzzle_height = 8
# Intended number of mice. In the visible cells it is only referenced by the
# commented-out random placement loop.
# NOTE(review): the hard-coded placement list contains 17 mice, not 18.
no_mouses = 18

# Start position; presumably meant to be supplied by the user rather than
# hard-coded -- TODO confirm.
start_x = 0
start_y = 0
# Note the argument order: height first, then width.
env = Environment(puzzle_height, puzzle_width)
env.set_start_state(start_x, start_y)

In [59]:
import numpy as np
# Fixed mouse layout as (x, y) grid coordinates.
# A random layout could be generated instead with:
#     [(np.random.randint(0, puzzle_width), np.random.randint(0, puzzle_height))
#      for _ in range(no_mouses)]
mouses = [
    (0, 2),
    (1, 1), (1, 3),
    (2, 1),
    (3, 1), (3, 3),
    (4, 3), (4, 4), (4, 5), (4, 6),
    (5, 1), (5, 2), (5, 3),
    (6, 1), (6, 6), (6, 7),
    (7, 1),
]
env.set_mouses(mouses)
print(mouses)

[(0, 2), (1, 1), (1, 3), (2, 1), (3, 1), (3, 3), (4, 3), (4, 4), (4, 5), (4, 6), (5, 1), (5, 2), (5, 3), (6, 1), (6, 6), (6, 7), (7, 1)]


In [66]:
from bokeh.plotting import figure, output_file, show
from bokeh.models import Label, Arrow, NormalHead
from bokeh.models.glyphs import Rect
from bokeh.io import output_notebook, push_notebook

# Grid of Rect glyphs indexed as rect_lst[row][col]; filled by draw_rectangles.
rect_lst = []
# state -> {action -> Arrow}; filled by draw_arrows.
arrow_dict = {}
# state -> {action -> Label}; filled by draw_labels, updated by update_labels.
label_dict = {}
# Length of each action arrow; draw_arrows also uses it as extra spacing
# between neighbouring states.
LINE_LENGTH = 0.5
# Width and height of one grid cell rectangle.
WIDTH=1
HEIGHT=1
# Sideways shift applied to arrows in draw_arrows; opposite directions are
# shifted to opposite sides (presumably so they do not overlap).
ARROW_OFFSET=0.15

def draw_rectangles(p, env):
    """Draw the puzzle grid on figure ``p`` and highlight the mouse cells.

    One ``Rect`` glyph is added per cell, at ``x=column`` and ``y=-row``, and
    stored in the module-level ``rect_lst`` as ``rect_lst[row][column]``.
    Cells listed in ``env.mouses`` are then recoloured red.

    Parameters
    ----------
    p : figure-like object providing ``add_glyph``.
    env : object exposing ``_size_x``, ``_size_y`` and ``mouses``, where each
        mouse is indexable as ``(x, y)``.
    """
    # Empty the grid in place first. Without this a second call would append
    # new rows after the old ones, and the mouse colouring below (as well as
    # update_rects) would keep addressing glyphs of the previous figure.
    del rect_lst[:]
    for i in range(env._size_y):
        row = []
        for j in range(env._size_x):
            rect = Rect(x=j, y=-i, width=WIDTH, height=HEIGHT,
                        fill_color="#CAB2D6", line_width=1.0, line_color="#000000")
            p.add_glyph(rect)
            row.append(rect)
        rect_lst.append(row)
    for mouse in env.mouses:
        # mouse is (x, y); the grid is indexed row (y) first.
        rect_lst[mouse[1]][mouse[0]].fill_color = "#ff0000"
        
def draw_arrows(p, env):
    """Add one arrow per possible action of every state to figure ``p``.

    Supported action types are 'd' (down), 'u' (up), 'r' (right) and
    'l' (left); any other type is skipped. The created arrows are recorded in
    the module-level ``arrow_dict`` as ``arrow_dict[state][action]``.

    Parameters
    ----------
    p : figure-like object providing ``add_layout``.
    env : object providing ``get_states()`` (rows of states) and
        ``get_possible_actions(state)``.
    """
    for states_in_row in env.get_states():
        for state in states_in_row:
            # Anchor of this state in arrow coordinates. States are spaced
            # (1 + LINE_LENGTH) apart here.
            # NOTE(review): draw_rectangles spaces cells 1 apart -- confirm
            # the two layouts are meant to differ.
            cx = state._x+state._x*LINE_LENGTH
            cy = -state._y-state._y*LINE_LENGTH
            arrows_by_action = {}
            for action in env.get_possible_actions(state):
                kind = action._type
                if kind == 'd':
                    x0 = x1 = cx-ARROW_OFFSET
                    y0 = cy-HEIGHT/2
                    y1 = cy-HEIGHT/2-LINE_LENGTH
                elif kind == 'u':
                    x0 = x1 = cx+ARROW_OFFSET
                    y0 = cy+HEIGHT/2
                    y1 = cy+HEIGHT/2+LINE_LENGTH
                elif kind == 'r':
                    x0 = cx+WIDTH/2
                    x1 = cx+WIDTH/2+LINE_LENGTH
                    y0 = y1 = cy-ARROW_OFFSET
                elif kind == 'l':
                    x0 = cx-WIDTH/2
                    x1 = cx-WIDTH/2-LINE_LENGTH
                    y0 = y1 = cy+ARROW_OFFSET
                else:
                    # Unknown action type: nothing to draw.
                    continue
                head = NormalHead(line_color="firebrick", line_width=2, size=6,
                                  fill_alpha=0.5, line_alpha=0.8)
                arrow = Arrow(end=head, x_start=x0, y_start=y0, x_end=x1, y_end=y1)
                p.add_layout(arrow)
                arrows_by_action[action] = arrow
            arrow_dict[state] = arrows_by_action

def draw_labels(p, env):
    """Attach a "0" text label to every arrow recorded in ``arrow_dict``.

    Labels are placed relative to the arrow start ('d', 'u', 'r') or the
    arrow end ('l'), added to figure ``p`` and stored in the module-level
    ``label_dict`` as ``label_dict[state][action]``. Must run after
    draw_arrows, since it looks up ``arrow_dict[state]`` for every state.

    Parameters
    ----------
    p : figure-like object providing ``add_layout``.
    env : object providing ``get_states()`` (rows of states).
    """
    for states_in_row in env.get_states():
        for state in states_in_row:
            labels_by_action = {}
            label_dict[state] = labels_by_action
            for action, arrow in arrow_dict[state].items():
                kind = action._type
                if kind == 'd':
                    lx = arrow.x_start - 0.4
                    ly = arrow.y_start - LINE_LENGTH/2
                elif kind == 'u':
                    lx = arrow.x_start + 0.05
                    ly = arrow.y_start + LINE_LENGTH/2
                elif kind == 'r':
                    lx = arrow.x_start + 0.05
                    ly = arrow.y_start - 0.2
                elif kind == 'l':
                    # Left arrows are labelled at their tip.
                    lx = arrow.x_end + 0.05
                    ly = arrow.y_end + 0.05
                else:
                    continue
                label = Label(x = lx, y = ly, text = "0", text_font_size="9pt")
                labels_by_action[action] = label
                p.add_layout(label)
                    
def update_labels(Q, handle=None):
    """Write the current Q-values into the arrow labels and refresh the plot.

    Every ``Q[state][action]`` that has a label in the module-level
    ``label_dict`` is rendered with two decimals; states and actions without
    a label are skipped.

    Parameters
    ----------
    Q : mapping of state -> mapping of action -> numeric value.
    handle : optional notebook handle, as returned by
        ``show(p, notebook_handle=True)``. When omitted, a module-level
        ``handle`` is used if one exists; otherwise ``None`` is passed on so
        that Bokeh falls back to the most recently shown notebook handle.
    """
    for state, action_values in Q.items():
        labels = label_dict.get(state)
        if labels is None:
            continue
        for action, value in action_values.items():
            if action in labels:
                labels[action].text = "{:.2f}".format(value)
    if handle is None:
        # The original relied on a global `handle` that may never have been
        # assigned; look it up defensively instead of raising NameError.
        handle = globals().get("handle")
    push_notebook(handle=handle)
                        
def update_rects(Q, env):
    """Colour the greedy path through ``env`` green.

    Starting from ``env.reset()``, repeatedly takes the highest-valued action
    according to ``Q`` and paints each visited cell in the module-level
    ``rect_lst``. The walk stops when the environment reports ``done`` (the
    final state is not painted) or when a state is reached a second time.

    This function only changes glyph colours; it does not push the change to
    the notebook, so the caller has to show or push the figure afterwards.

    Parameters
    ----------
    Q : mapping of state -> dict of action -> value. The greedy action is
        read from this argument rather than from a global agent.
    env : object providing ``reset()`` and ``step(action)``, where ``step``
        returns ``(next_state, reward, done)``.
    """
    # States are used as dictionary keys in Q, so they are hashable.
    visited = set()
    state = env.reset()
    done = False
    while not done:
        if state in visited:
            print("mouse is going in circles")
            break
        visited.add(state)
        rect_lst[state._y][state._x].fill_color = "#69ea69"
        action_values = Q[state]
        # First action with the maximal value, in dictionary order.
        best_next_action = max(action_values, key=action_values.get)
        state, _, done = env.step(best_next_action)
    
        
# Make sure Bokeh renders inline in the notebook.
output_notebook()
p = figure()

# Draw the grid; arrows and Q-value labels are currently switched off.
draw_rectangles(p, env)
# draw_arrows(p, env)
# draw_labels(p, env)
# Hide the axes so only the grid is shown.
p.axis.visible = False
show(p)
# update_labels pushes to a notebook handle; to use it, show the figure with
# the line below instead of the plain show(p) above:
# handle = show(p, notebook_handle=True)

In [67]:
from random import random
import operator

# Hyper-parameters; presumably meant to be supplied by the user, hard-coded
# for now.
num_episodes = 500
epsilon = 0.2
discount_factor = 0.8
alpha = 0.5

from td_exercise_agent import get_random_action, make_epsilon_greedy_policy, Agent

a = Agent(env)
a.init_q_values()
policy = make_epsilon_greedy_policy(a.Q, epsilon)

for episode in range(num_episodes):
    # Every episode starts from the environment's start state.
    state = env.reset()

    while True:
        # Sample an action from the epsilon-greedy distribution for this state.
        action = get_random_action(policy(state))

        # Advance the environment by one step.
        next_state, reward, done = env.step(action)

        # TD control: bootstrap from the best action value in the next state.
        next_values = a.Q[next_state]
        td_target = reward + discount_factor * max(next_values.values())
        a.Q[state][action] += alpha * (td_target - a.Q[state][action])

        # Stop once the step just taken ended the episode.
        if done:
            break

        state = next_state

# Refresh the view with the learned values.
# update_labels(a.Q)
update_rects(a.Q, env)

show(p)