In [18]:
from collections import deque, namedtuple
from datetime import datetime

from ale_python_interface import ALEInterface
from IPython.core.display import clear_output
from keras import backend as kb
from keras.layers import Dense, Activation, Dropout, Flatten, Lambda
from keras.layers.convolutional import Conv2D
from keras.models import Sequential
from keras.optimizers import Adam
from keras.layers.advanced_activations import LeakyReLU
import tensorflow as tf

RANDOM_SEED = 42
GAMMA = 0.99

screen_buffer = np.empty((210, 160), dtype=np.uint8)

Replay = namedtuple('Replay', ['prev', 'action', 'reward', 'next'])


ale = ALEInterface()
ale.setInt(b'random_seed', random.randrange(10000))
ale.setBool(b'display_screen', False)
ale.setBool(b'sound', False)
ale.loadROM(b'/var/roms/ms_pacman.bin')

ACTIONS = ale.getMinimalActionSet()

def dqn_loss(y_true, y_pred):
    # keep only argmax action reward, calculate absolute error
    mask = kb.one_hot(kb.argmax(y_true, axis=-1), len(ACTIONS))
    loss = kb.max(kb.abs(y_true - y_pred) * mask, axis=-1)

    # huber clipped loss
    CLIP = 1.
    condition = loss < CLIP
    squared_loss = 0.5 * kb.square(loss)
    linear_loss = CLIP * (loss - 0.5 * CLIP)
    return tf.where(condition, squared_loss, linear_loss)


In [19]:
kb.clear_session()
kb.reset_uids()

model = Sequential()
model.add(Conv2D(input_shape=(4, 80, 80), 
                 data_format='channels_first',
                 filters=16, 
                 kernel_size=(8, 8), 
                 strides=4,
                 kernel_initializer='he_normal',
                 activation='relu'))
model.add(Conv2D(filters=32, 
                 kernel_size=(4, 4), 
                 strides=2,
                 kernel_initializer='he_normal',
                 activation='relu'))
model.add(Flatten())
model.add(Dense(units=256,
                kernel_initializer='he_normal',
                activation='relu'))
model.add(Dense(units=len(ACTIONS), 
                activation='linear'))
model.compile(loss=dqn_loss,
              optimizer='rmsprop')

memory = deque(maxlen=int(1e6))
total_frames = 0
episode = 0
log_mode = 'w'


In [20]:
started = datetime.utcnow()
prev_state = []

with open('/var/pylon/dqn.log', log_mode) as log:
    log_mode = 'a'
    while total_frames < 100000:
        action = None
        total_reward = 0
        episode_frames = 0
        loss = 0
        print_rewards = True

        while not ale.game_over():
            if episode_frames % 4 == 0:
                epsilon = max(0.9 - total_frames / 1e6, 0) + 0.1
                if random.random() > epsilon and len(prev_state) == 4:
                    rewards = model.predict_on_batch(np.array([prev_state]) / 255)[0]
                    action = np.argmax(rewards)
                    if print_rewards:
                        print(rewards)
                        print_rewards = False
                else:
                    action = random.randrange(len(ACTIONS))

            reward = ale.act(ACTIONS[action])
            reward = np.clip(reward, -1., 1.)
            total_reward += reward

            ale.getScreenGrayscale(screen_buffer)
            next_state = prev_state[-3:] + [screen_buffer[5:165:2,::2].copy()]

            if episode_frames > 90:
                memory.append(Replay(prev=prev_state,
                                     action=action,
                                     reward=reward,
                                     next=next_state if not ale.game_over() else None))

                if total_frames > 5e4 and episode_frames % 4 == 0:
                    sample = random.sample(memory, 32)

                    x_future = np.array([s.next if s.next else np.zeros((4, 80, 80))
                                         for s in sample]) / 255
                    y_future = model.predict_on_batch(x_future)

                    y_action = np.array([s.action for s in sample])
                    y_reward = (np.array([s.reward for s in sample])
                                + GAMMA 
                                  * y_future.max(axis=1)
                                  * [bool(s.next is not None) for s in sample])

                    x_train = np.array([s.prev for s in sample]) / 255
                    y_train = (np.eye(len(ACTIONS))[y_action]
                               * y_reward[:,np.newaxis])

                    loss = model.train_on_batch(x_train, y_train)

                if total_frames % int(1e5) == 0:
                    model.save('/var/pylon/models/ale-{0:%Y%m%d%H%M%S}-{1:08d}.h5'.format(started, total_frames))
                    model.save('/var/pylon/models/ale-latest.h5')

            episode_frames += 1
            total_frames += 1
            prev_state = next_state

        message = ("Episode: {0:5d} ended with frames: {1:5d}, score: {2:.2f}, loss: {3:.6f}, total frames: {4}"
                   .format(episode, episode_frames, total_reward, loss, total_frames))
        print(message)
        log.write(message + "\n")

        ale.reset_game()
        episode += 1


[-0.56751084  0.70651287 -0.2473015   0.20678973 -0.09428598  0.19464289
  0.61059624  0.51802731 -0.5351522 ]
Episode:     0 ended with frames:  2185, score: 31.00, loss: 0.000000, total frames: 2185
[-0.58425671  0.7428416  -0.29892993  0.14424779 -0.19347283  0.1940058
  0.5355221   0.50214714 -0.56052166]
Episode:     1 ended with frames:  2001, score: 31.00, loss: 0.000000, total frames: 4186
[-0.57629901  0.76552033 -0.3471936   0.09386498 -0.18738113  0.19354799
  0.53398031  0.55440414 -0.54927891]
Episode:     2 ended with frames:  1905, score: 27.00, loss: 0.000000, total frames: 6091
[-0.61949611  0.76315659 -0.33202702  0.14482383 -0.16712457  0.15538055
  0.52627319  0.46185955 -0.5174309 ]
Episode:     3 ended with frames:  1761, score: 16.00, loss: 0.000000, total frames: 7852
[-0.61533993  0.69955802 -0.28931266  0.17148399 -0.17954233  0.20871013
  0.53057009  0.47613508 -0.52411115]
Episode:     4 ended with frames:  2145, score: 26.00, loss: 0.000000, total frames: 9

Episode:    39 ended with frames:  2081, score: 38.00, loss: 4074.792969, total frames: 78768
[ 780933.875   766186.1875  770340.4375  782114.4375  776757.0625
  774861.1875  779200.1875  776344.6875  774533.5625]
Episode:    40 ended with frames:  1913, score: 19.00, loss: 2599.190430, total frames: 80681
[ 341269.1875   345556.375    347542.75     345372.15625  346982.09375
  342732.625    340880.09375  340993.4375   350034.625  ]
Episode:    41 ended with frames:  1585, score: 15.00, loss: 1760.607422, total frames: 82266
[ 316523.9375   314103.625    312477.96875  314924.1875   320166.46875
  313798.15625  314262.75     313780.625    315432.21875]
Episode:    42 ended with frames:  2145, score: 21.00, loss: 1545.102539, total frames: 84411
[ 253761.46875   256376.578125  255193.671875  258185.234375  255611.890625
  258774.46875   253709.46875   257552.859375  255421.421875]
Episode:    43 ended with frames:  2305, score: 23.00, loss: 3264.426758, total frames: 86716
[ 356709.4375 

In [None]:
ale = ALEInterface()
ale.setInt(b'random_seed', int(random.randrange(1000)))
ale.setBool(b'display_screen', False)
ale.setBool(b'sound', False)
ale.loadROM(b'/var/roms/ms_pacman.bin')

ale.getScreenGrayscale(screen_buffer)
prev_state = [screen_buffer[5:165:2,::2].copy() for i in range(4)]

action = None
total_reward = 0
episode_frames = 0
while not ale.game_over():
    action = np.argmax(model.predict_on_batch(np.array([prev_state]) / 255)[0])
    print(action, model.predict_on_batch(np.array([prev_state]) / 255)[0])

    reward = ale.act(ACTIONS[action])
    total_reward += reward

    ale.getScreenGrayscale(screen_buffer)
    next_state = prev_state[-3:] + [screen_buffer[5:165:2,::2].copy()]

    episode_frames += 1
    total_frames += 1

    if total_frames % int(1e6) == 0:
        model.save('/var/pylon/models/ale-{0:%Y%m%d%H%M%S}-{1:08d}.h5'.format(started, total_frames))
        model.save('/var/pylon/models/ale-latest.h5')

    prev_state = next_state

print("Episode: {0:5d} ended with frames: {1:5d}, score: {2:6d}, loss: {3:.6f}, total frames: {4}"
      .format(0, episode_frames, total_reward, loss, total_frames))
ale.reset_game()


7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.3054049   2.28209376]
7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.3054049   2.28209376]
7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.3054049   2.28209376]
7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.3054049   2.28209376]
7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.3054049   2.28209376]
7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.3054049   2.28209376]
7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.3054049   2.28209376]
7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.3054049   2.28209376]
7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.30540

7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.3054049   2.28209376]
7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.3054049   2.28209376]
7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.3054049   2.28209376]
7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.3054049   2.28209376]
7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.3054049   2.28209376]
7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.3054049   2.28209376]
7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.3054049   2.28209376]
7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.3054049   2.28209376]
7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.30540

7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.3054049   2.28209376]
7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.3054049   2.28209376]
7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.3054049   2.28209376]
7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.3054049   2.28209376]
7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.3054049   2.28209376]
7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.3054049   2.28209376]
7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.3054049   2.28209376]
7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.3054049   2.28209376]
7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.30540

7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.3054049   2.28209376]
7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.3054049   2.28209376]
7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.3054049   2.28209376]
7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.3054049   2.28209376]
7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.3054049   2.28209376]
7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.3054049   2.28209376]
7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.3054049   2.28209376]
7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.3054049   2.28209376]
7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.30540

7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.3054049   2.28209376]
7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.3054049   2.28209376]
7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.3054049   2.28209376]
7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.3054049   2.28209376]
7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.3054049   2.28209376]
7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.3054049   2.28209376]
7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.3054049   2.28209376]
7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.3054049   2.28209376]
7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.30540

7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.3054049   2.28209376]
7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.3054049   2.28209376]
7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.3054049   2.28209376]
7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.3054049   2.28209376]
7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.3054049   2.28209376]
7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.3054049   2.28209376]
7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.3054049   2.28209376]
7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.3054049   2.28209376]
7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.30540

7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.3054049   2.28209376]
7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.3054049   2.28209376]
7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.3054049   2.28209376]
7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.3054049   2.28209376]
7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.3054049   2.28209376]
7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.3054049   2.28209376]
7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.3054049   2.28209376]
7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.3054049   2.28209376]
7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.30540

7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.3054049   2.28209376]
7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.3054049   2.28209376]
7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.3054049   2.28209376]
7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.3054049   2.28209376]
7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.3054049   2.28209376]
7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.3054049   2.28209376]
7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.3054049   2.28209376]
7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.3054049   2.28209376]
7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.30540

7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.3054049   2.28209376]
7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.3054049   2.28209376]
7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.3054049   2.28209376]
7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.3054049   2.28209376]
7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.3054049   2.28209376]
7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.3054049   2.28209376]
7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.3054049   2.28209376]
7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.3054049   2.28209376]
7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.30540

7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.3054049   2.28209376]
7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.3054049   2.28209376]
7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.3054049   2.28209376]
7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.3054049   2.28209376]
7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.3054049   2.28209376]
7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.3054049   2.28209376]
7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.3054049   2.28209376]
7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.3054049   2.28209376]
7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.30540

7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.3054049   2.28209376]
7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.3054049   2.28209376]
7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.3054049   2.28209376]
7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.3054049   2.28209376]
7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.3054049   2.28209376]
7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.3054049   2.28209376]
7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.3054049   2.28209376]
7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.3054049   2.28209376]
7 [ 2.27306032  2.28147602  2.28262067  2.28257513  2.28589153  2.29913521
  2.28865647  2.30540

In [13]:
for i in range(5):
    model2 = Sequential()
    for l in model.layers[:i+1]:
        model2.add(l)
    model2.set_weights(model.get_weights())
    print(len(np.unique(model2.predict_on_batch(np.array([prev_state]) / 255))))

model.get_weights()[-1]

3056
1681
1681
1
9


array([ 2.29775023,  2.29657555,  2.29960728,  2.29406786,  2.30479908,
        2.31189585,  2.31715941,  2.29673624,  2.30769515], dtype=float32)

In [None]:
model.save('/var/pylon/models/ale-{0:%Y%m%d%H%M%S}-{1:08d}.h5'.format(started, total_frames))
model.save('/var/pylon/models/ale-latest.h5')
