In [31]:
import theano
from theano import tensor as T
import lasagne
from lasagne.layers import *
from lasagne.objectives import *
from lasagne.nonlinearities import *
import gym
import numpy as np
from skimage.color import rgb2gray

In [36]:
def q_net(env):
    height, width, nchannels = env.observation_space.shape
    nchannels = 4 # we convert to black and white and use 4 prev frames
    l_in = InputLayer((None, nchannels, height, width))
    l_conv = Conv2DLayer(l_in, num_filters=32, filter_size=3, stride=2)
    l_conv2 = Conv2DLayer(l_conv, num_filters=64, filter_size=3, stride=2)
    l_conv3 = Conv2DLayer(l_conv2, num_filters=96, filter_size=3, stride=2)
    l_conv4 = Conv2DLayer(l_conv3, num_filters=128, filter_size=3, stride=2)
    l_dense = DenseLayer(l_conv4, num_units=env.action_space.n)
    return l_dense

In [37]:
X = T.tensor4('X')
y = T.fmatrix('y') # this is a row column
action_mask = T.fmatrix('action_mask')

In [38]:
env = gym.make('Pong-v0')
l_out = q_net(gym.make('Pong-v0'))
print count_params(l_out)
for layer in get_all_layers(l_out):
    print layer, layer.output_shape

[2017-03-16 12:18:56,939] Making new env: Pong-v0
[2017-03-16 12:18:56,973] Making new env: Pong-v0


268742
<lasagne.layers.input.InputLayer object at 0x10db9d510> (None, 4, 210, 160)
<lasagne.layers.conv.Conv2DLayer object at 0x10db9d810> (None, 32, 104, 79)
<lasagne.layers.conv.Conv2DLayer object at 0x10db9d890> (None, 64, 51, 39)
<lasagne.layers.conv.Conv2DLayer object at 0x10db9da90> (None, 96, 25, 19)
<lasagne.layers.conv.Conv2DLayer object at 0x10db9dfd0> (None, 128, 12, 9)
<lasagne.layers.dense.DenseLayer object at 0x10d8e81d0> (None, 6)


In [102]:
net_out = get_output(l_out, X)
loss = ( y - (action_mask*net_out).sum(axis=1, keepdims=True) )**2
loss = loss.mean()

In [103]:
out_fn = theano.function([X], net_out)

In [104]:
params = get_all_params(l_out, trainable=True)
params

[W, b, W, b, W, b, W, b, W, b]

In [105]:
updates = lasagne.updates.rmsprop(loss, params, learning_rate=0.1)

In [106]:
train_fn = theano.function([X,y,action_mask], loss, updates=updates)

--------

In [110]:
env = gym.make('Pong-v0')
# set up replay buffer and image buffers
buf_maxlen = 4
buf = []
buf_idx = []
experience = []
gamma = 1
mb_size = 4
for t in range(10000):
    # if the buffer is not at max length, do random actions
    # to fill it up
    if len(buf) != buf_maxlen:
        a_t = env.action_space.sample()
        x, _, _, _ = env.step(a_t)
        buf.append(rgb2gray(x))
        buf_idx.append(t)
    else:
        # phi_t is going to be the 4 most recent frames
        phi_t = np.asarray([
            buf[(t-1-3)%len(buf)], 
            buf[(t-1-2)%len(buf)], 
            buf[(t-1-1)%len(buf)], 
            buf[(t-1-0)%len(buf)]
        ]).astype("float32")
        debug_t = [buf_idx[(t-1-3)%len(buf_idx)], 
                   buf_idx[(t-1-2)%len(buf_idx)], 
                   buf_idx[(t-1-1)%len(buf_idx)], 
                   buf_idx[(t-1-0)%len(buf_idx)]]
        # with probability eps, select a random action
        a_t = env.action_space.sample()
        # execute action a_t in emulator and observe reward r_t and image x_t+1
        x_t1, r_t, is_done, info = env.step(a_t)
        # insert x_t+1 into the buffer, then grab the next
        # 4 most recent frames
        if not is_done:
            buf[ (t) % len(buf) ] = rgb2gray(x_t1)
            buf_idx[ (t) % len(buf) ] = t
            phi_t1 = np.asarray([
                buf[(t-3)%len(buf)], 
                buf[(t-2)%len(buf)], 
                buf[(t-1)%len(buf)], 
                buf[(t-0)%len(buf)]
            ]).astype("float32")
            debug_t1 = [buf_idx[(t-3)%len(buf_idx)], 
                       buf_idx[(t-2)%len(buf_idx)], 
                       buf_idx[(t-1)%len(buf_idx)], 
                       buf_idx[(t-0)%len(buf_idx)]]
        else:
            phi_t1 = phi_t
        #print debug_t, debug_t1
        # add this tuple to the experience buffer
        experience.append( {"phi_t":phi_t, "a_t":a_t, "r_t":r_t, "phi_t1":phi_t1, "is_done":is_done} )
        
        if len(experience) > mb_size:
            # sample from random experience from the buffer
            idxs = [i for i in range(0, len(experience))]
            np.random.shuffle(idxs)
            rand_transitions = [ experience[idx] for idx in idxs[0:mb_size] ]
            phi_t1_minibatch = np.asarray(
                [ rand_transitions[i]["phi_t1"] for i in range(len(rand_transitions)) ]).astype("float32")
            qvalues_minibatch = out_fn(phi_t1_minibatch)
            max_qvalues_minibatch = np.max(qvalues_minibatch,axis=1)
            y_minibatch = []
            for i in range(qvalues_minibatch.shape[0]):
                if rand_transitions[i]["is_done"]:
                    y_minibatch.append([rand_transitions[i]["r_t"]])
                else:
                    y_minibatch.append([rand_transitions[i]["r_t"]+gamma*max_qvalues_minibatch[i]])
            y_minibatch = np.asarray(y_minibatch).astype("float32")

            #print qvalues_minibatch
            #print y_minibatch
            
            mask_minibatch = np.zeros(qvalues_minibatch.shape).astype("float32")
            for i in range(qvalues_minibatch.shape[0]):
                mask_minibatch[ i, np.argmax(qvalues_minibatch[i]) ] = 1.
                
            #print mask_minibatch
            
            print "qvalues_minibatch", qvalues_minibatch.shape
            print "y_minibatch", y_minibatch.shape
            print "mask_minibatch", mask_minibatch.shape
            print "phi_t1_minibatch", phi_t1_minibatch.shape
            
            this_loss = train_fn(phi_t1_minibatch, y_minibatch, mask_minibatch)
            print this_loss
            
            #print test_fn(phi_t1_minibatch, mask_minibatch)
            
            break

[2017-03-16 13:00:21,523] Making new env: Pong-v0


qvalues_minibatch (4, 6)
y_minibatch (4, 1)
mask_minibatch (4, 6)
phi_t1_minibatch (4, 4, 210, 160)
9.92179289577e-20


In [54]:
idxs = [i for i in range(0, len(experience))]

In [28]:
np.max(np.asarray([[1,2],[3,4]]),axis=1)[0]

2

In [70]:
[1,2,3,4,5][ [3,2] ]

TypeError: list indices must be integers, not list

In [99]:
?T.sum