In [24]:
import json
import numpy as np
from keras.models import Sequential
from keras.layers.core import Dense
from keras.optimizers import sgd


class Catch(object):
    """Square grid world in which a player tries to stand on a stationary fruit.

    The environment state is a length-5 integer array
    ``[fruit_y, fruit_x, player_y, player_x, steps_left]``. Every episode
    starts with 10 steps left and is over once the counter reaches 0.

    Parameters
    ----------
    grid_size : int
        Side length of the square grid; valid coordinates are
        ``0 .. grid_size - 1``.
    """

    def __init__(self, grid_size=10):
        self.grid_size = grid_size
        self.reset()

    def _update_state(self, action):
        """Move the player by one cell and consume one time step.

        Input: action (the current state is read from ``self.state``)
        Output: none; the new state is stored in ``self.state``

        Actions (a move that would leave the grid is ignored):
            0 -> player column - 1
            1 -> player column + 1
            2 -> player row - 1
            3 -> player row + 1
        Any other value leaves the player where it is. The step counter is
        decremented regardless of whether the player actually moved.
        """
        fy, fx, py, px, t = self.state
        last = self.grid_size - 1

        if action == 0:
            if px > 0:
                px -= 1
        elif action == 1:
            if px < last:
                px += 1
        elif action == 2:
            if py > 0:
                py -= 1
        elif action == 3:
            if py < last:
                py += 1

        # The fruit never moves; only the player position and counter change.
        self.state = np.array([fy, fx, py, px, t - 1])

    def _draw_state(self):
        """Render the state as a ``(grid_size, grid_size)`` float array.

        The fruit cell is set to 1 and the player cell to 0.5. The player is
        drawn last, so it overwrites the fruit when both share a cell.
        """
        fruit_y, fruit_x, player_y, player_x = self.state[:4]
        canvas = np.zeros((self.grid_size, self.grid_size))
        canvas[fruit_y, fruit_x] = 1  # draw fruit
        canvas[player_y, player_x] = 0.5  # draw player
        return canvas

    def _get_reward(self):
        """Return 1 if the player is on the fruit's cell, otherwise 0."""
        fruit_y, fruit_x, player_y, player_x, _ = self.state
        if fruit_x == player_x and fruit_y == player_y:
            return 1
        return 0

    def _is_over(self):
        """Return True once the step counter has reached 0."""
        return bool(self.state[4] == 0)

    def observe(self):
        """Return the rendered grid flattened to shape ``(1, grid_size**2)``."""
        return self._draw_state().reshape((1, -1))

    def act(self, action):
        """Apply ``action`` and return ``(observation, reward, game_over)``."""
        self._update_state(action)
        reward = self._get_reward()
        game_over = self._is_over()
        return self.observe(), reward, game_over

    def reset(self):
        """Start a new episode with random fruit and player positions.

        ``np.random.randint`` excludes its upper bound, so ``grid_size`` (not
        ``grid_size - 1``) is passed to make every cell of the grid reachable,
        including the last row and column.
        """
        fruit_x = np.random.randint(0, self.grid_size)
        fruit_y = np.random.randint(0, self.grid_size)
        player_x = np.random.randint(0, self.grid_size)
        player_y = np.random.randint(0, self.grid_size)
        # Last entry is the number of steps left in the episode.
        self.state = np.asarray([fruit_y, fruit_x, player_y, player_x, 10])


class ExperienceReplay(object):
    """Bounded first-in-first-out store of transitions for Q-learning.

    Each memory entry has the form
    ``[[state_t, action_t, reward_t, state_tp1], game_over]``.
    """

    def __init__(self, max_memory=100, discount=.9):
        self.max_memory = max_memory
        self.memory = list()
        self.discount = discount

    def remember(self, states, game_over):
        """Append one transition, dropping the oldest one when over capacity."""
        self.memory.append([states, game_over])
        if len(self.memory) > self.max_memory:
            self.memory.pop(0)

    def get_batch(self, model, batch_size=10):
        """Sample stored transitions and build ``(inputs, targets)`` arrays.

        At most ``batch_size`` rows are produced (fewer when the memory holds
        fewer entries); entries are drawn uniformly with replacement. Each
        target row starts as the model's own prediction for ``state_t`` so
        that only the entry of the action actually taken is changed.
        """
        stored = len(self.memory)
        n_actions = model.output_shape[-1]
        first_state = self.memory[0][0][0]
        n_rows = min(stored, batch_size)

        inputs = np.zeros((n_rows, first_state.shape[1]))
        targets = np.zeros((n_rows, n_actions))

        picks = np.random.randint(0, stored, size=n_rows)
        for row, pick in enumerate(picks):
            transition, game_over = self.memory[pick]
            state_t, action_t, reward_t, state_tp1 = transition

            inputs[row:row + 1] = state_t
            # Keep the predicted values for actions that were not taken.
            targets[row] = model.predict(state_t)[0]
            best_next = np.max(model.predict(state_tp1)[0])
            # Terminal step: the target is just the reward; otherwise
            # reward_t + gamma * max_a' Q(s', a').
            targets[row, action_t] = (
                reward_t if game_over
                else reward_t + self.discount * best_next
            )
        return inputs, targets


if __name__ == "__main__":
    # parameters
    epsilon = .1  # probability of taking a random (exploratory) action
    num_actions = 4  # see Catch._update_state: column -1, column +1, row -1, row +1
    epoch = 1000  # number of training episodes
    max_memory = 500  # capacity of the experience replay buffer
    hidden_size = 100
    batch_size = 50
    grid_size = 10

    # Q-network: flattened grid in, one Q-value per action out.
    model = Sequential()
    model.add(Dense(hidden_size, input_shape=(grid_size**2,), activation='relu'))
    model.add(Dense(hidden_size, activation='relu'))
    model.add(Dense(num_actions))
    model.compile(sgd(lr=.2), "mse")

    # If you want to continue training from a previous model, just uncomment the line below
    # model.load_weights("model.h5")

    # Define environment/game
    env = Catch(grid_size)

    # Initialize experience replay object
    exp_replay = ExperienceReplay(max_memory=max_memory)

    # Train
    # NOTE: win_cnt is incremented on every step whose reward is 1, so a single
    # episode can add more than one to it.
    win_cnt = 0
    for e in range(epoch):
        loss = 0.
        env.reset()
        game_over = False
        # get initial input
        input_t = env.observe()

        while not game_over:
            input_tm1 = input_t
            # get next action (epsilon-greedy)
            if np.random.rand() <= epsilon:
                # Draw a scalar so both branches yield the same kind of value
                # (the greedy branch returns a scalar index from np.argmax).
                action = np.random.randint(0, num_actions)
            else:
                q = model.predict(input_tm1)
                action = np.argmax(q[0])

            # apply action, get rewards and new state
            input_t, reward, game_over = env.act(action)
            if reward == 1:
                win_cnt += 1

            # store experience
            exp_replay.remember([input_tm1, action, reward, input_t], game_over)

            # adapt model
            inputs, targets = exp_replay.get_batch(model, batch_size=batch_size)

            loss += model.train_on_batch(inputs, targets)
        # The last epoch index is derived from `epoch` instead of a literal 999.
        print("Epoch {:03d}/{:03d} | Loss {:.4f} | Win count {}".format(
            e, epoch - 1, loss, win_cnt))

    # Save trained model weights and architecture, this will be used by the visualization code
    model.save_weights("model.h5", overwrite=True)
    with open("model.json", "w") as outfile:
        # The architecture string is wrapped by json.dump; the visualization
        # code reads it back with json.load before calling model_from_json.
        json.dump(model.to_json(), outfile)

Epoch 000/999 | Loss 0.0141 | Win count 0
Epoch 001/999 | Loss 0.1809 | Win count 1
Epoch 002/999 | Loss 0.1261 | Win count 1
Epoch 003/999 | Loss 0.0867 | Win count 1
Epoch 004/999 | Loss 0.0817 | Win count 1
Epoch 005/999 | Loss 0.0712 | Win count 1
Epoch 006/999 | Loss 0.0385 | Win count 1
Epoch 007/999 | Loss 0.0337 | Win count 1
Epoch 008/999 | Loss 0.0452 | Win count 1
Epoch 009/999 | Loss 0.0187 | Win count 1
Epoch 010/999 | Loss 0.0199 | Win count 1
Epoch 011/999 | Loss 0.0246 | Win count 1
Epoch 012/999 | Loss 0.0187 | Win count 1
Epoch 013/999 | Loss 0.0230 | Win count 1
Epoch 014/999 | Loss 0.0164 | Win count 1
Epoch 015/999 | Loss 0.0470 | Win count 4
Epoch 016/999 | Loss 0.0523 | Win count 4
Epoch 017/999 | Loss 0.0488 | Win count 4
Epoch 018/999 | Loss 0.0627 | Win count 4
Epoch 019/999 | Loss 0.0531 | Win count 4
Epoch 020/999 | Loss 0.0306 | Win count 4
Epoch 021/999 | Loss 0.0440 | Win count 4
Epoch 022/999 | Loss 0.0402 | Win count 4
Epoch 023/999 | Loss 0.0499 | Win 

Epoch 193/999 | Loss 0.0380 | Win count 30
Epoch 194/999 | Loss 0.0419 | Win count 30
Epoch 195/999 | Loss 0.0304 | Win count 30
Epoch 196/999 | Loss 0.0177 | Win count 30
Epoch 197/999 | Loss 0.0208 | Win count 30
Epoch 198/999 | Loss 0.0249 | Win count 30
Epoch 199/999 | Loss 0.0254 | Win count 30
Epoch 200/999 | Loss 0.0198 | Win count 30
Epoch 201/999 | Loss 0.0250 | Win count 30
Epoch 202/999 | Loss 0.0394 | Win count 31
Epoch 203/999 | Loss 0.0244 | Win count 31
Epoch 204/999 | Loss 0.0142 | Win count 31
Epoch 205/999 | Loss 0.0236 | Win count 31
Epoch 206/999 | Loss 0.0159 | Win count 31
Epoch 207/999 | Loss 0.0166 | Win count 31
Epoch 208/999 | Loss 0.0165 | Win count 31
Epoch 209/999 | Loss 0.0242 | Win count 31
Epoch 210/999 | Loss 0.0188 | Win count 32
Epoch 211/999 | Loss 0.0119 | Win count 32
Epoch 212/999 | Loss 0.0218 | Win count 33
Epoch 213/999 | Loss 0.0277 | Win count 33
Epoch 214/999 | Loss 0.0280 | Win count 33
Epoch 215/999 | Loss 0.0314 | Win count 33
Epoch 216/9

Epoch 384/999 | Loss 0.0020 | Win count 39
Epoch 385/999 | Loss 0.0021 | Win count 39
Epoch 386/999 | Loss 0.0018 | Win count 39
Epoch 387/999 | Loss 0.0023 | Win count 39
Epoch 388/999 | Loss 0.0024 | Win count 39
Epoch 389/999 | Loss 0.0026 | Win count 39
Epoch 390/999 | Loss 0.0021 | Win count 39
Epoch 391/999 | Loss 0.0026 | Win count 39
Epoch 392/999 | Loss 0.0018 | Win count 39
Epoch 393/999 | Loss 0.0019 | Win count 39
Epoch 394/999 | Loss 0.0021 | Win count 39
Epoch 395/999 | Loss 0.0024 | Win count 39
Epoch 396/999 | Loss 0.0019 | Win count 39
Epoch 397/999 | Loss 0.0024 | Win count 39
Epoch 398/999 | Loss 0.0016 | Win count 39
Epoch 399/999 | Loss 0.0025 | Win count 39
Epoch 400/999 | Loss 0.0022 | Win count 39
Epoch 401/999 | Loss 0.0021 | Win count 39
Epoch 402/999 | Loss 0.0021 | Win count 39
Epoch 403/999 | Loss 0.0019 | Win count 39
Epoch 404/999 | Loss 0.0018 | Win count 39
Epoch 405/999 | Loss 0.0019 | Win count 39
Epoch 406/999 | Loss 0.0018 | Win count 39
Epoch 407/9

Epoch 575/999 | Loss 0.0425 | Win count 51
Epoch 576/999 | Loss 0.0280 | Win count 51
Epoch 577/999 | Loss 0.0228 | Win count 51
Epoch 578/999 | Loss 0.0483 | Win count 51
Epoch 579/999 | Loss 0.0213 | Win count 51
Epoch 580/999 | Loss 0.0242 | Win count 51
Epoch 581/999 | Loss 0.0337 | Win count 51
Epoch 582/999 | Loss 0.0369 | Win count 51
Epoch 583/999 | Loss 0.0522 | Win count 51
Epoch 584/999 | Loss 0.0446 | Win count 51
Epoch 585/999 | Loss 0.0326 | Win count 51
Epoch 586/999 | Loss 0.0278 | Win count 51
Epoch 587/999 | Loss 0.0399 | Win count 51
Epoch 588/999 | Loss 0.0371 | Win count 51
Epoch 589/999 | Loss 0.0354 | Win count 51
Epoch 590/999 | Loss 0.0234 | Win count 51
Epoch 591/999 | Loss 0.0340 | Win count 51
Epoch 592/999 | Loss 0.0415 | Win count 51
Epoch 593/999 | Loss 0.0552 | Win count 51
Epoch 594/999 | Loss 0.0384 | Win count 51
Epoch 595/999 | Loss 0.0344 | Win count 53
Epoch 596/999 | Loss 0.0370 | Win count 53
Epoch 597/999 | Loss 0.0543 | Win count 53
Epoch 598/9

Epoch 766/999 | Loss 0.0329 | Win count 79
Epoch 767/999 | Loss 0.0300 | Win count 79
Epoch 768/999 | Loss 0.0275 | Win count 79
Epoch 769/999 | Loss 0.0235 | Win count 79
Epoch 770/999 | Loss 0.0240 | Win count 79
Epoch 771/999 | Loss 0.0372 | Win count 79
Epoch 772/999 | Loss 0.0312 | Win count 80
Epoch 773/999 | Loss 0.0271 | Win count 80
Epoch 774/999 | Loss 0.0037 | Win count 80
Epoch 775/999 | Loss 0.0328 | Win count 80
Epoch 776/999 | Loss 0.0166 | Win count 80
Epoch 777/999 | Loss 0.0164 | Win count 80
Epoch 778/999 | Loss 0.0122 | Win count 80
Epoch 779/999 | Loss 0.0118 | Win count 80
Epoch 780/999 | Loss 0.0112 | Win count 80
Epoch 781/999 | Loss 0.0209 | Win count 80
Epoch 782/999 | Loss 0.0235 | Win count 80
Epoch 783/999 | Loss 0.0117 | Win count 80
Epoch 784/999 | Loss 0.0262 | Win count 80
Epoch 785/999 | Loss 0.0030 | Win count 80
Epoch 786/999 | Loss 0.0198 | Win count 80
Epoch 787/999 | Loss 0.0195 | Win count 80
Epoch 788/999 | Loss 0.0075 | Win count 80
Epoch 789/9

Epoch 957/999 | Loss 0.0023 | Win count 92
Epoch 958/999 | Loss 0.0121 | Win count 92
Epoch 959/999 | Loss 0.0077 | Win count 92
Epoch 960/999 | Loss 0.0165 | Win count 93
Epoch 961/999 | Loss 0.0076 | Win count 93
Epoch 962/999 | Loss 0.0021 | Win count 93
Epoch 963/999 | Loss 0.0129 | Win count 93
Epoch 964/999 | Loss 0.0115 | Win count 93
Epoch 965/999 | Loss 0.0173 | Win count 93
Epoch 966/999 | Loss 0.0222 | Win count 93
Epoch 967/999 | Loss 0.0073 | Win count 93
Epoch 968/999 | Loss 0.0159 | Win count 93
Epoch 969/999 | Loss 0.0126 | Win count 93
Epoch 970/999 | Loss 0.0147 | Win count 93
Epoch 971/999 | Loss 0.0113 | Win count 93
Epoch 972/999 | Loss 0.0064 | Win count 93
Epoch 973/999 | Loss 0.0060 | Win count 93
Epoch 974/999 | Loss 0.0110 | Win count 93
Epoch 975/999 | Loss 0.0107 | Win count 93
Epoch 976/999 | Loss 0.0190 | Win count 93
Epoch 977/999 | Loss 0.0113 | Win count 93
Epoch 978/999 | Loss 0.0105 | Win count 93
Epoch 979/999 | Loss 0.0057 | Win count 93
Epoch 980/9

In [25]:
import json
import matplotlib.pyplot as plt
import numpy as np
from keras.models import model_from_json


if __name__ == "__main__":
    # Make sure this grid size matches the value used for training
    grid_size = 10

    # model.json holds the architecture string wrapped by json.dump, so it is
    # unwrapped with json.load before being handed to model_from_json.
    with open("model.json", "r") as jfile:
        model = model_from_json(json.load(jfile))
    model.load_weights("model.h5")
    model.compile("sgd", "mse")

    def save_frame(observation, index):
        """Write one flattened observation to '<index, zero-padded to 3>.png'."""
        # Clear the current figure first: without this every imshow call adds
        # another image to the same axes and each savefig redraws all of them.
        plt.clf()
        plt.imshow(observation.reshape((grid_size,)*2),
                   interpolation='none', cmap='gray')
        plt.savefig("%03d.png" % index)

    # Define environment, game
    env = Catch(grid_size)
    c = 0  # running frame number used for the output file names
    for e in range(10):
        env.reset()
        game_over = False
        # get initial input
        input_t = env.observe()

        save_frame(input_t, c)
        c += 1
        while not game_over:
            # get next action (always greedy, no exploration)
            q = model.predict(input_t)
            action = np.argmax(q[0])

            # apply action, get new state (the reward is not needed here)
            input_t, _reward, game_over = env.act(action)

            save_frame(input_t, c)
            c += 1

ffmpeg -framerate 1 -i %03d.png output.gif