# Q-Network Learning

In [1]:
import gym
import numpy as np
import random
import tensorflow as tf
import matplotlib.pyplot as plt
%matplotlib inline

### Load the environment

In [2]:
# Create the FrozenLake-v0 environment; `env` is the object the training loop resets and steps.
env = gym.make('FrozenLake-v0')

[2018-10-21 08:02:57,659] Making new env: FrozenLake-v0


## The Q-Network Approach

### Implementing the network itself

In [3]:
# Clear TensorFlow's default graph before the network below is defined
# (presumably so re-running the notebook does not add duplicate ops to an old graph).
tf.reset_default_graph()

In [4]:
# Forward pass used to pick actions: a single weight matrix maps a 1x16 input row
# to a 1x4 row of Q values, one per action.
inputs1 = tf.placeholder(tf.float32, shape=[1, 16])
W = tf.Variable(tf.random_uniform([16, 4], minval=0, maxval=0.01))
Qout = tf.matmul(inputs1, W)
# Index of the largest Q value along axis 1, i.e. the greedy action.
predict = tf.argmax(Qout, 1)

# Training part: the loss is the sum of squared differences between the target
# Q values (fed through nextQ) and the network's predicted Q values.
nextQ = tf.placeholder(tf.float32, shape=[1, 4])
squared_error = tf.square(nextQ - Qout)
loss = tf.reduce_sum(squared_error)
trainer = tf.train.GradientDescentOptimizer(learning_rate=0.1)
updateModel = trainer.minimize(loss)

### Training the network

In [10]:
# tf.global_variables_initializer() replaces the deprecated tf.initialize_all_variables().
init = tf.global_variables_initializer()

# Set learning parameters
y = .99             # discount factor applied to the bootstrapped max Q' value
e = 0.1             # initial probability of taking a random action
num_episodes = 100
# Episodes whose per-step diagnostics are printed. An index >= num_episodes never
# triggers, so raise num_episodes above 2000 to see the second one.
debug_episodes = (0, 2000)
# One-hot encodings of the 16 states, built once instead of on every step.
# Slicing [s:s+1] keeps the row 2-D (1x16), matching the inputs1 placeholder.
one_hot_states = np.identity(16)
#create lists to contain total rewards and steps per episode
jList = []
rList = []
with tf.Session() as sess:
    sess.run(init)
    for i in range(num_episodes):
        #Reset environment and get first new observation
        s = env.reset()
        rAll = 0
        d = False
        j = 0
        #The Q-Network
        while j < 99:
            j+=1
            #Choose an action greedily (with e chance of random action) from the Q-network
            a,allQ = sess.run([predict,Qout],feed_dict={inputs1:one_hot_states[s:s+1]})
            if np.random.rand() < e:
                a[0] = env.action_space.sample()
            #Get new state and reward from environment
            s1,r,d,_ = env.step(a[0])
            #Obtain the Q' values by feeding the new state through our network
            Q1 = sess.run(Qout,feed_dict={inputs1:one_hot_states[s1:s1+1]})
            #Obtain maxQ' and set our target value for chosen action.
            maxQ1 = np.max(Q1)
            # Copy so allQ keeps the network's prediction; assigning into an alias
            # would overwrite it and make the diagnostics below print the target twice.
            targetQ = allQ.copy()
            # When the episode has ended (d is True) there is no next step to
            # bootstrap from, so the target is the immediate reward only.
            targetQ[0,a[0]] = r if d else r + y*maxQ1
            #Train our network using target and predicted Q values
            # W1 receives the value of W fetched in the same run call as the update.
            _,W1 = sess.run([updateModel,W],feed_dict={inputs1:one_hot_states[s:s+1],nextQ:targetQ})
            rAll += r
            s = s1
            if i in debug_episodes:
                print('-'*50)
                print(s1,r,d)
                print('Q old s   : ',allQ)
                print('Max Q     : ',maxQ1)
                print('Action    : ',a)
                print('Q new s   : ',Q1)
                print('Prediction: ',allQ)
                print('Target    : ',targetQ)
            if d:
                #Reduce chance of random action as we train the model.
                # 50.0 keeps this a true division even under Python 2.
                e = 1./((i/50.0) + 10)
                break
        jList.append(j)
        rList.append(rAll)
# Scale the success fraction to an actual percentage so it matches the "%" label.
success_percent = 100.0 * sum(rList) / num_episodes
print ("Percent of succesful episodes: " + str(success_percent) + "%")

--------------------------------------------------
[3]
[[1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]
[[0.00012051 0.00526388 0.00358011 0.00613403]]
0 0.0 False
[[0.00012051 0.00526388 0.00358011 0.00619599]]
0.00619599
[[0.00012051 0.00526388 0.00358011 0.00613403]]
--------------------------------------------------
[3]
[[0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]
[[0.00012051 0.00526388 0.00358011 0.0089158 ]]
1 0.0 False
[[0.00900585 0.00196688 0.00758573 0.00587726]]
0.009005855
[[0.00012051 0.00526388 0.00358011 0.0089158 ]]
--------------------------------------------------
[0]
[[0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]
[[0.0089158  0.00196688 0.00758573 0.00587726]]
1 0.0 False
[[0.00900585 0.00196688 0.00758573 0.00587726]]
0.009005855
[[0.0089158  0.00196688 0.00758573 0.00587726]]
--------------------------------------------------
[0]
[[1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]
[[0.00666274 0.00196688 0.00758573 0.00587726]]
0 0.0 False
[[0.00012

### Some statistics on network performance

We can see that the network begins to consistently reach the goal around the 750-episode mark.

In [None]:
# Total reward accumulated in each episode, in training order.
plt.plot(rList)
plt.title('Total reward per episode')
plt.xlabel('Episode')
plt.ylabel('Total reward')
# plt.show() renders the figure without echoing the Line2D repr as cell output.
plt.show()

It also begins to progress through the environment for longer than chance around the 750-episode mark.

In [None]:
# Number of steps taken in each episode (the value of the step counter j when the episode stopped).
plt.plot(jList)
plt.title('Steps per episode')
plt.xlabel('Episode')
plt.ylabel('Steps taken')
# plt.show() renders the figure without echoing the Line2D repr as cell output.
plt.show()