# Init

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import xarray as xr
import numpy as np
import pandas as pd

import holoviews as hv
from holoviews.streams import Pipe, Buffer

from collections import deque, defaultdict

import streamz
import streamz.dataframe

import random, sys, gym, math, bokeh, pdb, time

from simple_agent import SimpleAgent
from drop_pick_agent import DropPickAgent
from decomp_agent import DecompAgent

hv.extension('bokeh')

# Setup

In [None]:
def train_episode(env, agent):
    """Play one episode to termination, letting the agent learn at every step.

    Args:
        env: Environment exposing ``reset()`` and ``step(action)``; ``step``
            must return a 4-tuple ``(next_state, reward, done, info)``.
        agent: Object exposing ``select_action(state)`` and
            ``step(state, action, reward, next_state, done)``.

    Returns:
        A pair ``(transitions, total_reward)`` where ``transitions`` is the
        list of ``(state, action, reward)`` tuples in the order visited and
        ``total_reward`` is the sum of all rewards of the episode.
    """
    observation = env.reset()
    transitions = []
    total_reward = 0
    finished = False
    while not finished:
        chosen = agent.select_action(observation)
        successor, reward, finished, _info = env.step(chosen)
        # Let the agent learn from the transition before it is recorded.
        agent.step(observation, chosen, reward, successor, finished)

        total_reward += reward
        transitions.append((observation, chosen, reward))

        observation = successor
    return transitions, total_reward

In [None]:
# Number of emitted data frames that are combined into one batch
# (used as the partition size of the training stream).
batch_size = 1

# Only show this many batches at once
# (used as the `length` of the plotting Buffer).
max_batches_to_show = 100000

# Window size, in batches, of the rolling statistic behind the smoothed curve.
# NOTE(review): the pipeline applies `.rolling(rolling_size).median()`, i.e. a
# rolling median rather than an average — confirm which one is intended.
rolling_size = 10

In [None]:
def aggregate(df):
    """Collapse a batch of emitted frames into a single-row frame.

    Args:
        df: DataFrame with a numeric ``x`` column, one row per emitted frame.

    Returns:
        A one-row DataFrame that keeps the index of the first row of ``df``
        and whose ``x`` is the mean of the batch truncated to an integer.
        ``df`` itself is not modified.
    """
    # Reduce the column to a scalar before int(): calling int() on a
    # one-element Series is deprecated in pandas, and averaging the whole
    # frame would break as soon as a second column is added.
    batch_mean = int(df['x'].mean())
    # Work on a copy of the first row so the caller's frame stays untouched.
    first_row = df.head(1).copy()
    first_row['x'] = batch_mean
    return first_row

# Source stream: the training loop emits single-row DataFrames into it.
training_stream = streamz.Stream()
# Group emissions into tuples of `batch_size`, concatenate each tuple into one
# frame and reduce that frame to a single row.
training_batched_stream = training_stream.partition(batch_size).map(pd.concat).map(aggregate)

# Example frame describing the layout of the streamed data (one column `x`).
example = pd.DataFrame({'x': [0]}, index=[0])
training_sdf = streamz.dataframe.DataFrame(training_batched_stream, example=example)

# Both buffers get the same explicit length; without it the smoothed buffer
# would fall back to the Buffer default and keep a different amount of history
# than the raw one, so the two curves would not line up when overlaid.
training_raw_buffer = Buffer(training_sdf, length=max_batches_to_show)
training_smooth_buffer = Buffer(
    training_sdf.x.rolling(rolling_size).median(),
    length=max_batches_to_show,
)
training_raw_dmap = hv.DynamicMap(hv.Curve, streams=[training_raw_buffer]).relabel('raw')
training_smooth_dmap = hv.DynamicMap(hv.Curve, streams=[training_smooth_buffer]).relabel('smooth')

# Run

In [None]:
# Environment and agent used by the training cell.
env = gym.make('Taxi-v2')
agent = DecompAgent()
# `episode_i` and `best_sample_avg` are initialised here rather than in the
# training cell, so re-running the training cell keeps accumulating them.
episode_i = 0
best_sample_avg = -np.inf

In [None]:
%%opts Curve [width=700 height=200 show_grid=True tools=['hover']]
# Display the live curve of the raw batch values; the overlay with the
# smoothed curve is currently commented out.
training_raw_dmap # * training_smooth_dmap

In [None]:
num_episodes = 100000
window = 100
# Sliding window holding the returns of the most recent `window` episodes.
episode_returns = deque(maxlen=window)
# Lowest return seen in this run; the matching episode is kept in
# `min_return_episode` for later inspection.
min_return = np.inf
for i in range(num_episodes):
    episode, episode_return = train_episode(env, agent)
    episode_returns.append(episode_return)
    if episode_return < min_return:
        min_return = episode_return
        min_return_episode = episode

    # Best average over a full window of episode returns. Statistics are only
    # tracked once the window has filled up.
    if len(episode_returns) >= window:
        # Compute the window mean once and reuse it for both the running best
        # and the emitted data point.
        sample_average = np.mean(episode_returns)
        best_sample_avg = max(best_sample_avg, sample_average)
        # Report every 100th iteration of this run: push a point to the live
        # plot and rewrite the progress line in place ('\r').
        if i % 100 == 0:
            training_stream.emit(pd.DataFrame({'x': sample_average}, index=[episode_i]))
            sys.stdout.write(f"\rEpisode: {episode_i} best avg: {best_sample_avg}")
    # `episode_i` is not reset by this cell, unlike the loop index `i`.
    episode_i += 1
    

In [None]:
# Lowest episode return recorded by the training loop.
min_return

In [None]:
# Coordinates shared by every DataArray below.
# NOTE(review): this assumes the sub-agent's Q table is indexed as
# (row, col, dest, action) with shape (5, 5, 4, n_actions) — confirm.
grid_coords = [('row', range(5)), ('col', range(5)), ('dest', range(4))]

# Despite the name, this is the maximum Q-value over the action axis.
average_q = np.max(agent.sub_agent.Q, axis=3)
average_q_da = xr.DataArray(average_q, coords=grid_coords, name='average_q')

# Greedy action per state, drawn as unit-length arrows.
policy = np.argmax(agent.sub_agent.Q, axis=3)
policy_r = xr.DataArray(np.ones(policy.shape), coords=grid_coords, name='mag')
policy_action = xr.DataArray(policy, coords=grid_coords, name='angle')

# Arrow direction (in degrees) for each action index. Every comparison is made
# against the untouched action indices, so an angle written for one action can
# never be mistaken for another action index. Indices without an entry are
# passed through unchanged.
action_degrees = {0: 90, 1: 270, 2: 0, 3: 180}
policy_theta = policy_action
for action_index, degrees in action_degrees.items():
    policy_theta = xr.where(policy_action == action_index, math.radians(degrees), policy_theta)

ds = hv.Dataset(xr.merge([policy_theta, policy_r, average_q_da]))
# Use the HoloViews element: `gv` (geoviews) is never imported in this
# notebook, so `gv.VectorField` raised a NameError.
policy_field = ds.to(hv.VectorField, ['col', 'row'], ['angle', 'mag'])
average_q_img = ds.to(hv.Image, ['col', 'row'], ['average_q'])

In [None]:
%%opts VectorField [width=200, height=200, invert_yaxis=True] (scale=1.5, line_width=3, color='black')
# Show the policy arrows with the row/col axis ranges fixed to (-0.5, 4.5).
policy_field.redim.range(row=(-0.5, 4.5), col=(-0.5, 4.5))

In [None]:
# One tuple per step of the lowest-return episode: the decoded state fields
# from the third-from-last back to the first (in reverse order), followed by
# the action taken and the reward received.
path = [
    tuple(agent.decode_state(state)[-3::-1]) + (action, reward)
    for state, action, reward in min_return_episode
]
path

In [None]:
%%opts Image [width=200, height=200, invert_yaxis=True] 
%%opts Points [invert_yaxis=True] (size=10, color='red')
%%opts Curve  [invert_yaxis=True] (line_width=1, color='red')
# Overlay the value image with the recorded path, drawn both as a line and as
# individual points.
average_q_img * hv.Curve(path) * hv.Points(path)