In [8]:
from collections import deque, namedtuple
from datetime import datetime

import gym
from IPython.core.display import clear_output
from keras import backend as kb
from keras.layers import Dense, Activation, Dropout, Flatten, Lambda
from keras.layers.convolutional import Conv2D
from keras.models import Sequential, load_model
from keras.optimizers import Adam
from keras.layers.advanced_activations import LeakyReLU
import tensorflow as tf

RANDOM_SEED = 42
GAMMA = 0.99

Replay = namedtuple('Replay', ['prev', 'action', 'reward', 'next'])

env = gym.make('MsPacman-v0')
env.seed(RANDOM_SEED)

def dqn_loss(y_true, y_pred):
    mask = kb.one_hot(kb.argmax(y_true, axis=-1), env.action_space.n)
    loss = kb.max(kb.square(y_true - y_pred) * mask, axis=-1)
    return loss


[2017-08-05 17:39:05,912] Making new env: MsPacman-v0


In [9]:
kb.clear_session()
kb.reset_uids()

model = Sequential()
model.add(Conv2D(input_shape=(4, 80, 80), 
                 data_format='channels_first',
                 filters=16, 
                 kernel_size=(8, 8), 
                 strides=4,
                 kernel_initializer='he_normal',
                 activation='relu'))
model.add(Conv2D(filters=32, 
                 kernel_size=(4, 4), 
                 strides=2,
                 kernel_initializer='he_normal',
                 activation='relu'))
model.add(Flatten())
model.add(Dense(units=256,
                kernel_initializer='he_normal',
                activation='relu'))
model.add(Dense(units=env.action_space.n, 
                activation='linear'))
model.compile(loss=dqn_loss,
              optimizer=Adam())
model.save('/var/pylon/models/ale-latest.h5')

memory = deque(maxlen=int(1e6))
total_frames = 0
episode = 0
log_mode = 'w'


In [18]:
started = datetime.utcnow()
messages = deque(maxlen=20)

with open('/var/pylon/dqn.log', log_mode) as log:
    log_mode = 'a'
    while total_frames < 2e6:
        action = None
        total_reward = 0
        episode_frames = 0
        loss = 0
        game_over = False
        env.reset()
        prev_state = []

        while not game_over:
            epsilon = max(0.9 - total_frames / 1e6, 0) + 0.1
            if random.random() > epsilon and len(prev_state) == 4:
                action = np.argmax(model.predict_on_batch(np.array([prev_state]))[0])
            else:
                action = random.randrange(env.action_space.n)

            observation, reward, game_over, _ = env.step(action)
            image = (np
                     .dot(observation, [0.2989, 0.5870, 0.1140])
                     .reshape(observation.shape[:2])
                     [5:165:2,::2]
                     / 255)
            reward = np.clip(reward, -1., 1.)
            total_reward += reward

            next_state = prev_state[-3:] + [image]

            if len(prev_state) == 4:
                memory.append(Replay(prev=prev_state,
                                     action=action,
                                     reward=reward,
                                     next=next_state if not game_over else None))

            if total_frames > 5e4:
                sample = random.sample(memory, 32)

                x_future = np.array([s.next if s.next else np.zeros((4, 80, 80))
                                     for s in sample])
                y_future = model.predict_on_batch(x_future)

                y_action = np.array([s.action for s in sample])
                y_reward = (np.array([s.reward for s in sample])
                            + GAMMA 
                              * y_future.max(axis=1)
                              * [bool(s.next is not None) for s in sample])

                x_train = np.array([s.prev for s in sample])
                y_train = (np.eye(env.action_space.n)[y_action]
                           * y_reward[:,np.newaxis])
                loss = model.train_on_batch(x_train, y_train)

            if total_frames % int(1e5) == 0:
                model.save('/var/pylon/models/ale-{0:%Y%m%d%H%M%S}-{1:08d}.h5'.format(started, total_frames))
                model.save('/var/pylon/models/ale-latest.h5')

            episode_frames += 1
            total_frames += 1
            prev_state = next_state

        message = ("Episode: {0:5d} ended with frames: {1:5d}, score: {2:.2f}, loss: {3:.6f}, total frames: {4}"
                   .format(episode, episode_frames, total_reward, loss, total_frames))
        print(message)
        log.write(message + "\n")

        episode += 1


Episode:    91 ended with frames:   569, score: 16.00, loss: 2271.822998, total frames: 59388
Episode:    92 ended with frames:   554, score: 17.00, loss: 6.564946, total frames: 59942
Episode:    93 ended with frames:   665, score: 15.00, loss: 46.722626, total frames: 60607
Episode:    94 ended with frames:   566, score: 22.00, loss: 2879802.500000, total frames: 61173
Episode:    95 ended with frames:   541, score: 16.00, loss: 493071.187500, total frames: 61714
Episode:    96 ended with frames:   642, score: 28.00, loss: 189575.593750, total frames: 62356
Episode:    97 ended with frames:   661, score: 25.00, loss: 78978.000000, total frames: 63017
Episode:    98 ended with frames:   500, score: 20.00, loss: 50578.375000, total frames: 63517
Episode:    99 ended with frames:   677, score: 22.00, loss: 28246.908203, total frames: 64194
Episode:   100 ended with frames:   608, score: 24.00, loss: 10237.169922, total frames: 64802
Episode:   101 ended with frames:   679, score: 25.00,

KeyboardInterrupt: 

In [20]:
action = 0
total_reward = 0
episode_frames = 0
game_over = False
env.reset()
prev_state = []
while not game_over:
    if len(prev_state) == 4:
        action = np.argmax(model.predict_on_batch(np.array([prev_state]))[0])
        print(action, model.predict_on_batch(np.array([prev_state]))[0])

    observation, reward, game_over, _ = env.step(action)
    image = (np
             .dot(observation, [0.2989, 0.5870, 0.1140])
             .reshape(observation.shape[:2])
             .astype('uint8')
             [5:165:2,::2]
             / 255)
    total_reward += reward

    next_state = prev_state[-3:] + [image]

    episode_frames += 1
    total_frames += 1

    prev_state = next_state

print("Episode: {0:5d} ended with frames: {1:5d}, score: {2:6f}, loss: {3:.6f}, total frames: {4}"
      .format(0, episode_frames, total_reward, loss, total_frames))


3 [ 200.03338623  199.75456238  201.51683044  203.39750671  200.67459106
  201.45968628  199.41487122  202.1463623   201.40844727]
3 [ 200.97042847  200.87467957  202.50009155  204.23176575  201.71890259
  202.43125916  200.5009613   203.11985779  202.35676575]
3 [ 200.62509155  200.56773376  202.22265625  203.94032288  201.41944885
  202.15776062  200.17127991  202.83261108  202.06373596]
3 [ 200.58709717  200.55969238  202.21611023  203.92930603  201.40682983
  202.13449097  200.14628601  202.81214905  202.04605103]
3 [ 199.76651001  199.55397034  201.32432556  203.19006348  200.46546936
  201.26983643  199.18135071  201.93760681  201.20010376]
3 [ 199.11975098  198.70396423  200.59060669  202.59931946  199.68754578
  200.54832458  198.38949585  201.21548462  200.50756836]
3 [ 199.15617371  198.67259216  200.56738281  202.62374878  199.65591431
  200.54048157  198.39147949  201.204422    200.50065613]
3 [ 199.78555298  199.48239136  201.27212524  203.1934967   200.40960693
  201.2198

3 [ 198.91603088  198.485672    200.38801575  202.42118835  199.47961426
  200.35800171  198.16099548  201.02850342  200.31280518]
3 [ 199.4788208   199.12568665  200.95210266  202.90725708  200.08332825
  200.91223145  198.78753662  201.59725952  200.87008667]
3 [ 199.34127808  198.99720764  200.83108521  202.78631592  199.95739746
  200.80586243  198.64785767  201.48600769  200.76171875]
3 [ 199.09085083  198.76373291  200.62037659  202.58706665  199.72119141
  200.58863831  198.39335632  201.2645874   200.54219055]
3 [ 199.55857849  199.30792236  201.10470581  202.99919128  200.23124695
  201.05899048  198.94396973  201.72244263  200.99165344]
3 [ 199.16900635  198.83746338  200.70069885  202.65400696  199.80955505
  200.65354919  198.49601746  201.32003784  200.59695435]
3 [ 199.98204041  199.74703979  201.50553894  203.37995911  200.65235901
  201.4407959   199.40345764  202.12452698  201.38894653]
3 [ 200.36793518  200.25756836  201.94676208  203.72904968  201.11730957
  201.8858

3 [ 183.61271667  180.73846436  184.59220886  188.7615509   182.9190979
  184.9500885   180.51889038  185.66049194  185.34487915]
3 [ 187.95834351  185.79766846  189.10110474  192.64152527  187.62005615
  189.36412048  185.55552673  190.02076721  189.60386658]
3 [ 191.76873779  190.23077393  193.06602478  196.05892944  191.74212646
  193.21861267  189.96788025  193.84455872  193.34373474]
3 [ 195.54937744  194.51382446  196.87463379  199.40776062  195.72079468
  196.94471741  194.27079773  197.5602417   196.98951721]
3 [ 195.18682861  194.0032959   196.43933105  199.07391357  195.2631073
  196.51747131  193.80670166  197.13372803  196.58287048]
3 [ 196.15023804  195.1224823   197.42863464  199.93273926  196.30740356
  197.48353577  194.90965271  198.11027527  197.5297699 ]
3 [ 196.33651733  195.35568237  197.63383484  200.1022644   196.52468872
  197.68966675  195.13928223  198.31724548  197.73043823]
3 [ 196.33758545  195.41586304  197.68591309  200.10871887  196.56698608
  197.730773

3 [ 181.90621948  179.42112732  183.38989258  187.26660156  181.72015381
  183.61151123  178.98736572  184.50630188  184.00157166]
3 [ 182.39154053  179.91101074  183.85066223  187.69178772  182.20271301
  184.05537415  179.50382996  184.95922852  184.44667053]
3 [ 181.8062439   179.33583069  183.36094666  187.17741394  181.65911865
  183.5675354   178.90290833  184.46488953  183.91156006]
3 [ 181.89866638  179.44822693  183.47602844  187.30064392  181.76022339
  183.67382812  179.02526855  184.53735352  184.02407837]
3 [ 182.74328613  180.26515198  184.22886658  188.01574707  182.52787781
  184.44148254  179.91123962  185.25587463  184.76631165]
3 [ 183.8447876   181.37399292  185.26391602  188.9672699   183.59584045
  185.4315033   181.08677673  186.24790955  185.77232361]
3 [ 183.07769775  180.41075134  184.43487549  188.27452087  182.70347595
  184.60391235  180.16624451  185.4438324   184.97940063]
3 [ 183.2671814   180.64382935  184.66905212  188.44577026  182.90567017
  184.8549

3 [ 196.43486023  195.49261475  197.739151    200.19168091  196.64907837
  197.80166626  195.25335693  198.43408203  197.84265137]
3 [ 196.22790527  195.35176086  197.60235596  200.01873779  196.50447083
  197.67030334  195.07740784  198.28968811  197.68992615]
3 [ 195.89962769  195.00340271  197.3026886   199.73812866  196.18063354
  197.34712219  194.71824646  197.97953796  197.38145447]
3 [ 195.18965149  194.10549927  196.51164246  199.09779358  195.34835815
  196.58459473  193.8651886   197.20600891  196.63543701]
3 [ 194.83374023  193.65145874  196.11791992  198.76168823  194.9276886
  196.20150757  193.43911743  196.82237244  196.26347351]
3 [ 195.40606689  194.29876709  196.69319153  199.27038574  195.53808594
  196.76037598  194.07041931  197.40658569  196.83682251]
3 [ 195.27929688  194.22355652  196.61317444  199.16926575  195.45307922
  196.69891357  193.97731018  197.32421875  196.74905396]
3 [ 195.35542297  194.35896301  196.73120117  199.25523376  195.57795715
  196.79740

3 [ 196.0647583   195.24952698  197.53753662  199.89567566  196.49671936
  197.490448    194.92744446  198.21092224  197.58978271]
3 [ 196.11361694  195.33930969  197.61947632  199.96887207  196.57922363
  197.58984375  195.01805115  198.28309631  197.66217041]
3 [ 195.71076965  194.99745178  197.29762268  199.63108826  196.23922729
  197.26042175  194.62510681  197.96144104  197.32466125]
3 [ 195.79800415  195.11213684  197.39176941  199.7219696   196.34060669
  197.34132385  194.72505188  198.05101013  197.40733337]
3 [ 195.71131897  194.94644165  197.24867249  199.62756348  196.1893158
  197.22563171  194.5990448   197.914505    197.27267456]
3 [ 195.88093567  195.04014587  197.35058594  199.76036072  196.28337097
  197.30870056  194.73243713  198.01637268  197.39912415]
3 [ 196.05206299  195.23527527  197.51283264  199.9102478   196.47018433
  197.46925354  194.92718506  198.18516541  197.56103516]
3 [ 195.56414795  194.74212646  197.06570435  199.49900818  196.00190735
  197.03598

In [21]:
for i in range(5):
    model2 = Sequential()
    for l in model.layers[:i+1]:
        model2.add(l)
    model2.set_weights(model.get_weights())
    print(len(np.unique(model2.predict_on_batch(np.array([prev_state])))))


2566
496
496
28
9


In [None]:
model.save('/var/pylon/models/ale-{0:%Y%m%d%H%M%S}-{1:08d}.h5'.format(started, total_frames))
model.save('/var/pylon/models/ale-latest.h5')
