In [1]:
from __future__ import division
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import tensorflow as tf
from env_current import *
from collections import deque
from utils import *
from qnetwork import *
import plotting
import time

  (fname, cnt))
  (fname, cnt))
  from ._conv import register_converters as _register_converters


In [2]:
def train(sess, env, qnet):
    
    global EXPLORATION_RATE
  
    summary_ops, summary_vars = build_summaries()

    sess.run(tf.global_variables_initializer())
    writer = tf.summary.FileWriter(SUMMARY_DIR, sess.graph)
    
    qnet.update_target()
    
    replay_buffer = ReplayBuffer(BUFFER_SIZE, RANDOM_SEED)
    
    reward_list = []
    
    for num_epi in range(MAX_EPISODES):

        s = env.reset()
        s = [list(np.unravel_index(s, env.shape))]

        ep_reward = 0
        ep_ave_max_q = 0

        for j in range(MAX_EPISODE_LEN):

            a = np.argmax(qnet.predict_q(np.reshape(s, (1, qnet.state_dim))))
    
            if np.random.rand(1) < EXPLORATION_RATE:
                s2, r, terminal, info = env.step(np.random.randint(0,qnet.action_dim))
            else:
                s2, r, terminal, info = env.step(a)
            
            s2 = list(np.unravel_index(s2, env.shape))

            replay_buffer.add(np.reshape(s, (qnet.state_dim,)), np.reshape(a, (1,)), r,
                              terminal, np.reshape(s2, (qnet.state_dim,)))

            # Keep adding experience to the memory until
            # there are at least minibatch size samples
            if replay_buffer.size() > MINIBATCH_SIZE:
                s_batch, a_batch, r_batch, t_batch, s2_batch = replay_buffer.sample_batch(MINIBATCH_SIZE)

                # Calculate targets
                target_q = qnet.predect_target(s2_batch)

                y_i = []
                for k in range(MINIBATCH_SIZE):
                    if t_batch[k]:
                        y_i.append(r_batch[k])
                    else:
                        y_i.append(r_batch[k] + GAMMA * np.amax(target_q[k]))

                # Update the critic given the targets
                predicted_q_value, _ = qnet.train(s_batch, a_batch, np.reshape(y_i, (MINIBATCH_SIZE, 1)), num_epi)

                ep_ave_max_q += np.amax(predicted_q_value)
                
                # Update target networks
                qnet.update_target()

            s = s2
            ep_reward += r

            if terminal or j == MAX_EPISODE_LEN-1:
                
                if EXPLORATION_RATE > 0.02 and terminal:
                    EXPLORATION_RATE = EXPLORATION_RATE*0.92
                
                reward_list += [ep_reward]
                
                if np.average(reward_list[-10:]) >= LR_DECAY_TRUNCATION:
                    qnet.decay_learning_rate(1)

                summary_str = sess.run(summary_ops, feed_dict={
                    summary_vars[0]: ep_reward,
                    summary_vars[1]: ep_ave_max_q / float(j),
                    summary_vars[2]: EXPLORATION_RATE,
                    summary_vars[3]: qnet.get_learning_rate()
                })

                writer.add_summary(summary_str, num_epi)
                writer.flush()

                print('| Reward: {:d} | Episode: {:d} | Qmax: {:.4f} | Exploration: {:.6f} '.format(int(ep_reward), \
                        num_epi, (ep_ave_max_q / float(j)), EXPLORATION_RATE))
                
                f = open("stats.txt", "ab")
                f.write("| Reward: " + str(int(ep_reward)) 
                        +" | Episode: " + str(num_epi) 
                        + " | Qmax: " + str(ep_ave_max_q / float(j)) 
                        + " | Exploration: " + str(EXPLORATION_RATE) + "\n")
                f.close()
                
                break
                
        if num_epi%1 == 0:
            state_list = []
            action_list = []
            world = np.zeros(env.shape)
            for state in range(env.nS):
                state = np.unravel_index(state, env.shape)
                action = qnet.predict_q(np.reshape(state, (1,state_dim)))
                action = np.argmax(action)
                state_list.append(state)
                action_list.append(action)
                
#             print np.reshape(action_list, env.shape)
                
            f = open("action.txt","ab")
            np.savetxt(f, np.reshape(action_list, env.shape), fmt="%i")
            f.write("---------------------------\n")
            f.close()
    
    

In [3]:
LEARNING_RATE = 0.0015
GAMMA = 0.99
# GAMMA = 0.7
TAU = 0.001
BUFFER_SIZE = 10**6
MINIBATCH_SIZE = 64
RANDOM_SEED = 101
MAX_EPISODES = 50000
MAX_EPISODE_LEN = 1500
file_appendix = time.ctime()[4:16].replace("  ","").replace(" ","_").replace(":","-")
SUMMARY_DIR = './results/tf_ddqn_' + file_appendix
SAVE_DIR = "./saved_model/" + file_appendix + "/ddqn.ckpt"
EXPLORATION_RATE = 0.65
LR_DECAY_TRUNCATION = -50

In [4]:
with tf.Session() as sess:
    env = CurrentWorld()
    np.random.seed(RANDOM_SEED)
    tf.set_random_seed(RANDOM_SEED)
    env.seed(RANDOM_SEED)
    
    state_dim = 2
    action_dim = 5
    
    Qnet = QNet(sess, state_dim, action_dim, LEARNING_RATE, TAU, MINIBATCH_SIZE, SAVE_DIR)
    
    train(sess, env, Qnet)

  "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "


DDQN Saved
| Reward: -1500 | Episode: 0 | Qmax: 29.0392 | Exploration: 0.650000 
| Reward: -1500 | Episode: 1 | Qmax: 31.8888 | Exploration: 0.650000 
| Reward: -1500 | Episode: 2 | Qmax: 29.5829 | Exploration: 0.650000 
| Reward: -1500 | Episode: 3 | Qmax: 27.3973 | Exploration: 0.650000 
| Reward: -1500 | Episode: 4 | Qmax: 25.8825 | Exploration: 0.650000 
| Reward: -492 | Episode: 5 | Qmax: 30.2546 | Exploration: 0.598000 
| Reward: -345 | Episode: 6 | Qmax: 32.7059 | Exploration: 0.550160 
| Reward: -84 | Episode: 7 | Qmax: 33.3009 | Exploration: 0.506147 
| Reward: -44 | Episode: 8 | Qmax: 34.2149 | Exploration: 0.465655 
| Reward: -58 | Episode: 9 | Qmax: 33.6918 | Exploration: 0.428403 
| Reward: -61 | Episode: 10 | Qmax: 33.5526 | Exploration: 0.394131 
| Reward: -101 | Episode: 11 | Qmax: 34.2039 | Exploration: 0.362600 
| Reward: -977 | Episode: 12 | Qmax: 31.9958 | Exploration: 0.333592 
| Reward: -179 | Episode: 13 | Qmax: 30.5665 | Exploration: 0.306905 
| Reward: -408 | E

| Reward: -33 | Episode: 117 | Qmax: -8.2554 | Exploration: 0.019589 
| Reward: -16 | Episode: 118 | Qmax: -8.4879 | Exploration: 0.019589 
| Reward: -59 | Episode: 119 | Qmax: -8.2034 | Exploration: 0.019589 
DDQN Saved
| Reward: -41 | Episode: 120 | Qmax: -7.9746 | Exploration: 0.019589 
| Reward: -27 | Episode: 121 | Qmax: -8.1008 | Exploration: 0.019589 
| Reward: -130 | Episode: 122 | Qmax: -7.6693 | Exploration: 0.019589 
| Reward: -15 | Episode: 123 | Qmax: -7.9116 | Exploration: 0.019589 
| Reward: -78 | Episode: 124 | Qmax: -7.4221 | Exploration: 0.019589 
| Reward: -59 | Episode: 125 | Qmax: -7.6174 | Exploration: 0.019589 
| Reward: -19 | Episode: 126 | Qmax: -7.9071 | Exploration: 0.019589 
| Reward: -185 | Episode: 127 | Qmax: -7.5433 | Exploration: 0.019589 
| Reward: -71 | Episode: 128 | Qmax: -7.7057 | Exploration: 0.019589 
| Reward: -58 | Episode: 129 | Qmax: -7.7672 | Exploration: 0.019589 
| Reward: -180 | Episode: 130 | Qmax: -7.7652 | Exploration: 0.019589 
| Rewa

| Reward: -76 | Episode: 233 | Qmax: -7.5115 | Exploration: 0.019589 
| Reward: -23 | Episode: 234 | Qmax: -7.8792 | Exploration: 0.019589 
| Reward: -31 | Episode: 235 | Qmax: -7.7979 | Exploration: 0.019589 
| Reward: -79 | Episode: 236 | Qmax: -7.7105 | Exploration: 0.019589 
| Reward: -57 | Episode: 237 | Qmax: -7.8963 | Exploration: 0.019589 
| Reward: -36 | Episode: 238 | Qmax: -7.9109 | Exploration: 0.019589 
| Reward: -30 | Episode: 239 | Qmax: -7.8181 | Exploration: 0.019589 
DDQN Saved
| Reward: -23 | Episode: 240 | Qmax: -7.9924 | Exploration: 0.019589 
| Reward: -19 | Episode: 241 | Qmax: -8.2853 | Exploration: 0.019589 
| Reward: -53 | Episode: 242 | Qmax: -7.8848 | Exploration: 0.019589 
| Reward: -53 | Episode: 243 | Qmax: -7.9476 | Exploration: 0.019589 
| Reward: -103 | Episode: 244 | Qmax: -7.8879 | Exploration: 0.019589 
| Reward: -59 | Episode: 245 | Qmax: -8.1029 | Exploration: 0.019589 
| Reward: -33 | Episode: 246 | Qmax: -8.2333 | Exploration: 0.019589 
| Reward

| Reward: -23 | Episode: 350 | Qmax: -7.6316 | Exploration: 0.019589 
| Reward: -31 | Episode: 351 | Qmax: -7.6781 | Exploration: 0.019589 
| Reward: -17 | Episode: 352 | Qmax: -7.8324 | Exploration: 0.019589 
| Reward: -15 | Episode: 353 | Qmax: -7.7688 | Exploration: 0.019589 
| Reward: -15 | Episode: 354 | Qmax: -7.7505 | Exploration: 0.019589 
| Reward: -15 | Episode: 355 | Qmax: -7.7345 | Exploration: 0.019589 
| Reward: -35 | Episode: 356 | Qmax: -7.4329 | Exploration: 0.019589 
| Reward: -90 | Episode: 357 | Qmax: -7.1960 | Exploration: 0.019589 
| Reward: -29 | Episode: 358 | Qmax: -7.5248 | Exploration: 0.019589 
| Reward: -93 | Episode: 359 | Qmax: -7.2369 | Exploration: 0.019589 
DDQN Saved
| Reward: -17 | Episode: 360 | Qmax: -7.4700 | Exploration: 0.019589 
| Reward: -19 | Episode: 361 | Qmax: -7.4975 | Exploration: 0.019589 
| Reward: -24 | Episode: 362 | Qmax: -7.3644 | Exploration: 0.019589 
| Reward: -15 | Episode: 363 | Qmax: -7.5587 | Exploration: 0.019589 
| Reward:

| Reward: -15 | Episode: 467 | Qmax: -7.2082 | Exploration: 0.019589 
| Reward: -111 | Episode: 468 | Qmax: -6.7739 | Exploration: 0.019589 
| Reward: -37 | Episode: 469 | Qmax: -6.8584 | Exploration: 0.019589 
| Reward: -90 | Episode: 470 | Qmax: -6.8048 | Exploration: 0.019589 
| Reward: -18 | Episode: 471 | Qmax: -7.2039 | Exploration: 0.019589 
| Reward: -21 | Episode: 472 | Qmax: -7.2074 | Exploration: 0.019589 
| Reward: -15 | Episode: 473 | Qmax: -7.3614 | Exploration: 0.019589 
| Reward: -16 | Episode: 474 | Qmax: -7.1585 | Exploration: 0.019589 
| Reward: -15 | Episode: 475 | Qmax: -7.1032 | Exploration: 0.019589 
| Reward: -28 | Episode: 476 | Qmax: -6.8701 | Exploration: 0.019589 
| Reward: -15 | Episode: 477 | Qmax: -7.0271 | Exploration: 0.019589 
| Reward: -27 | Episode: 478 | Qmax: -6.8595 | Exploration: 0.019589 
| Reward: -15 | Episode: 479 | Qmax: -6.9389 | Exploration: 0.019589 
DDQN Saved
| Reward: -15 | Episode: 480 | Qmax: -6.9404 | Exploration: 0.019589 
| Reward

| Reward: -15 | Episode: 584 | Qmax: -6.5964 | Exploration: 0.019589 
| Reward: -15 | Episode: 585 | Qmax: -6.7205 | Exploration: 0.019589 
| Reward: -15 | Episode: 586 | Qmax: -6.7476 | Exploration: 0.019589 
| Reward: -17 | Episode: 587 | Qmax: -6.6316 | Exploration: 0.019589 
| Reward: -31 | Episode: 588 | Qmax: -6.3729 | Exploration: 0.019589 
| Reward: -39 | Episode: 589 | Qmax: -6.4025 | Exploration: 0.019589 
| Reward: -15 | Episode: 590 | Qmax: -6.7132 | Exploration: 0.019589 
| Reward: -83 | Episode: 591 | Qmax: -6.4090 | Exploration: 0.019589 
| Reward: -15 | Episode: 592 | Qmax: -6.8321 | Exploration: 0.019589 
| Reward: -15 | Episode: 593 | Qmax: -6.9993 | Exploration: 0.019589 
| Reward: -15 | Episode: 594 | Qmax: -6.8893 | Exploration: 0.019589 
| Reward: -15 | Episode: 595 | Qmax: -7.2137 | Exploration: 0.019589 
| Reward: -126 | Episode: 596 | Qmax: -6.5028 | Exploration: 0.019589 
| Reward: -39 | Episode: 597 | Qmax: -6.4712 | Exploration: 0.019589 
| Reward: -19 | Epi

| Reward: -15 | Episode: 701 | Qmax: -6.6095 | Exploration: 0.019589 
| Reward: -17 | Episode: 702 | Qmax: -6.6505 | Exploration: 0.019589 
| Reward: -15 | Episode: 703 | Qmax: -6.5921 | Exploration: 0.019589 
| Reward: -15 | Episode: 704 | Qmax: -6.6226 | Exploration: 0.019589 
| Reward: -31 | Episode: 705 | Qmax: -6.5503 | Exploration: 0.019589 
| Reward: -21 | Episode: 706 | Qmax: -6.5722 | Exploration: 0.019589 
| Reward: -19 | Episode: 707 | Qmax: -6.5787 | Exploration: 0.019589 
| Reward: -15 | Episode: 708 | Qmax: -6.8002 | Exploration: 0.019589 
| Reward: -17 | Episode: 709 | Qmax: -6.7233 | Exploration: 0.019589 
| Reward: -17 | Episode: 710 | Qmax: -6.7112 | Exploration: 0.019589 
| Reward: -25 | Episode: 711 | Qmax: -6.4634 | Exploration: 0.019589 
| Reward: -17 | Episode: 712 | Qmax: -6.7898 | Exploration: 0.019589 
| Reward: -15 | Episode: 713 | Qmax: -6.5999 | Exploration: 0.019589 
| Reward: -37 | Episode: 714 | Qmax: -6.4840 | Exploration: 0.019589 
| Reward: -15 | Epis

| Reward: -34 | Episode: 818 | Qmax: -6.1198 | Exploration: 0.019589 
| Reward: -15 | Episode: 819 | Qmax: -6.4567 | Exploration: 0.019589 
DDQN Saved
| Reward: -15 | Episode: 820 | Qmax: -6.3643 | Exploration: 0.019589 
| Reward: -15 | Episode: 821 | Qmax: -6.4849 | Exploration: 0.019589 
| Reward: -15 | Episode: 822 | Qmax: -6.4583 | Exploration: 0.019589 
| Reward: -15 | Episode: 823 | Qmax: -6.1235 | Exploration: 0.019589 
| Reward: -20 | Episode: 824 | Qmax: -6.2423 | Exploration: 0.019589 
| Reward: -39 | Episode: 825 | Qmax: -6.1386 | Exploration: 0.019589 
| Reward: -15 | Episode: 826 | Qmax: -6.2672 | Exploration: 0.019589 
| Reward: -21 | Episode: 827 | Qmax: -6.3666 | Exploration: 0.019589 
| Reward: -15 | Episode: 828 | Qmax: -6.4367 | Exploration: 0.019589 
| Reward: -15 | Episode: 829 | Qmax: -6.3791 | Exploration: 0.019589 
| Reward: -15 | Episode: 830 | Qmax: -6.4052 | Exploration: 0.019589 
| Reward: -19 | Episode: 831 | Qmax: -6.2597 | Exploration: 0.019589 
| Reward:

| Reward: -15 | Episode: 935 | Qmax: -6.4557 | Exploration: 0.019589 
| Reward: -15 | Episode: 936 | Qmax: -6.4773 | Exploration: 0.019589 
| Reward: -15 | Episode: 937 | Qmax: -6.7384 | Exploration: 0.019589 
| Reward: -15 | Episode: 938 | Qmax: -6.5745 | Exploration: 0.019589 
| Reward: -15 | Episode: 939 | Qmax: -6.7488 | Exploration: 0.019589 
DDQN Saved
| Reward: -15 | Episode: 940 | Qmax: -6.7465 | Exploration: 0.019589 
| Reward: -15 | Episode: 941 | Qmax: -6.7463 | Exploration: 0.019589 
| Reward: -15 | Episode: 942 | Qmax: -6.5387 | Exploration: 0.019589 
| Reward: -17 | Episode: 943 | Qmax: -6.4686 | Exploration: 0.019589 
| Reward: -15 | Episode: 944 | Qmax: -6.4836 | Exploration: 0.019589 
| Reward: -15 | Episode: 945 | Qmax: -6.6085 | Exploration: 0.019589 
| Reward: -15 | Episode: 946 | Qmax: -6.6487 | Exploration: 0.019589 
| Reward: -15 | Episode: 947 | Qmax: -6.7580 | Exploration: 0.019589 
| Reward: -15 | Episode: 948 | Qmax: -6.3794 | Exploration: 0.019589 
| Reward:

| Reward: -19 | Episode: 1051 | Qmax: -6.6962 | Exploration: 0.019589 
| Reward: -19 | Episode: 1052 | Qmax: -6.8901 | Exploration: 0.019589 
| Reward: -16 | Episode: 1053 | Qmax: -6.4211 | Exploration: 0.019589 
| Reward: -20 | Episode: 1054 | Qmax: -6.2812 | Exploration: 0.019589 
| Reward: -15 | Episode: 1055 | Qmax: -7.0855 | Exploration: 0.019589 
| Reward: -15 | Episode: 1056 | Qmax: -7.0393 | Exploration: 0.019589 
| Reward: -15 | Episode: 1057 | Qmax: -6.4473 | Exploration: 0.019589 
| Reward: -15 | Episode: 1058 | Qmax: -6.8326 | Exploration: 0.019589 
| Reward: -15 | Episode: 1059 | Qmax: -6.2564 | Exploration: 0.019589 
DDQN Saved
| Reward: -15 | Episode: 1060 | Qmax: -6.2245 | Exploration: 0.019589 
| Reward: -15 | Episode: 1061 | Qmax: -6.0386 | Exploration: 0.019589 
| Reward: -15 | Episode: 1062 | Qmax: -6.2417 | Exploration: 0.019589 
| Reward: -19 | Episode: 1063 | Qmax: -6.7952 | Exploration: 0.019589 
| Reward: -17 | Episode: 1064 | Qmax: -6.2680 | Exploration: 0.019

KeyboardInterrupt: 

world = np.zeros(env.shape)
a_list = []
s_list = []
for s in range(env.nS):
    a_list += [np.argmax(P[s])]
    s_list += [np.unravel_index(s,env.shape)]
for s,a in zip(s_list,a_list):
    world[s] = a
    


%matplotlib auto
plt.imshow(world)

%matplotlib inline
matplotlib.style.use('ggplot')
plotting.plot_episode_stats(stats)

def get_optimal_path(Q,env):
    env.reset()
    start_state = env.start_state
    terminal_state = env.terminal_state
    state = np.ravel_multi_index(start_state,env.shape)
    path = [start_state]
    value = 0
    action = []
    while 1:
        next_action = np.argmax(Q[state])
        next_state,reward,done,_ = env.step(next_action)
        path += [np.unravel_index(next_state,env.shape)]
        value += reward
        action += [next_action]
        if done:
            return path, action, value
            break
        state = next_state

opt_path,action,value = get_optimal_path(Q,env)

%matplotlib auto
world = deepcopy(env.winds)
t = 0
for i in opt_path[:-1]:
    world[i] = 6
#     world[i] += action[t]
    t+=1
plt.imshow(world)
# print value