Original code taken from [https://gist.github.com/EderSantana/c7222daa328f0e885093](https://gist.github.com/EderSantana/c7222daa328f0e885093)

# Installation
To be able to run the animation below, make sure you have the latest version of matplotlib, by running `pip3 install matplotlib --upgrade`

In [1]:
%matplotlib inline
import json
import numpy as np
import random
from keras.models import Sequential
from keras.layers.core import Dense
from keras.optimizers import sgd
import matplotlib.pyplot as plt
import matplotlib.animation
import IPython.display
import time

Using TensorFlow backend.


## Define the game environment and replay classes
The idea in this Catch game is that there is fruit falling, and the user gets to move a basket so that they catch the fruit. If they catch it, they win and the game is over. If they miss it, they lose and the game is over. We are trying to teach the computer to play this game.

In [2]:
class Catch(object):
    def __init__(self, grid_size=10):
        '''        
        Initializes internal state.
        '''
        self.grid_size = grid_size
        self.min_basket_center = 1
        self.max_basket_center = self.grid_size-2
        self.reset()

    def _update_state(self, action):
        '''
        Input: action (0 for left, 1 for stay, 2 for right)
        
        Moves basket according to action. Moves fruit down. Updates state to reflect these movements
        '''
        if action == 0:  # left
            movement = -1
        elif action == 1:  # stay
            movement = 0
        elif action == 2: # right
            movement = 1
        else:
            raise Exception('Invalid action {}'.format(action))
        fruit_x, fruit_y, basket_center = self.state
        # move the basket unless this would move it off the edge of the grid
        new_basket_center = min(max(self.min_basket_center, basket_center + movement), self.max_basket_center)
        # move fruit down
        fruit_y += 1
        out = np.asarray([fruit_x, fruit_y, new_basket_center])
        self.state = out

    def _draw_state(self):
        '''
        Returns a 2D numpy array with 1s (white squares) at the locations of the fruit and basket and
        0s (black squares) everywhere else.
        '''
        im_size = (self.grid_size, self.grid_size)
        canvas = np.zeros(im_size)
        
        fruit_x, fruit_y, basket_center = self.state
        canvas[fruit_y, fruit_x] = 1  # draw fruit
        canvas[-1, basket_center-1:basket_center + 2] = 1  # draw 3-pixel basket
        return canvas

    def _get_reward(self):
        '''
        Returns 1 if the fruit was caught, -1 if it was dropped, and 0 if it is still in the air.
        '''
        fruit_x, fruit_y, basket_center = self.state
        if fruit_y == self.grid_size-1:
            if abs(fruit_x - basket_center) <= 1:
                return 1 # it caught the fruit
            else:
                return -1 # it dropped the fruit
        else:
            return 0 # the fruit is still in the air

    def observe(self):
        '''
        Returns the current canvas, as a 1D array.
        '''
        canvas = self._draw_state()
        return canvas.reshape((1, -1))

    def act(self, policy):
        '''
        Input: policy (a 10x10x10 array of actions for each possible state of fruit_x, fruit_y, and basket_center
        with 0 for left, 1 for stay, 2 for right)
        
        Returns:
            current canvas (as a 1D array)
            reward received after this action
            True if game is over and False otherwise
        '''
        fruit_x, fruit_y, basket_center = self.state
        action = policy[fruit_x][fruit_y][basket_center]
#         print("Policy:")
#         print(policy)
#         print("Action:")
#         print(action)
        self._update_state(action)
        observation = self.observe()
        reward = self._get_reward()
        game_over = (reward != 0) # if the reward is zero, the fruit is still in the air
        return observation, reward, game_over

    def reset(self):
        '''
        Updates internal state
            fruit in a random column in the top row
            basket center in a random column
        '''
        fruit_x = random.randint(0, self.grid_size-1)
        fruit_y = 0
        basket_center = random.randint(self.min_basket_center, self.max_basket_center)
        self.state = np.asarray([fruit_x, fruit_y, basket_center])

In [3]:
class ExperienceReplay(object):
    def __init__(self, max_memory=100, discount=.9):
        self.max_memory = max_memory
        self.memory = list()
        self.discount = discount

    def remember(self, states, game_over):
        '''
        Input:
            states: [starting_observation, action_taken, reward_received, new_observation]
            game_over: boolean
        Add the states and game over to the internal memory array. If the array is longer than
        self.max_memory, drop the oldest memory
        '''
        self.memory.append([states, game_over])
        if len(self.memory) > self.max_memory:
            del self.memory[0]

    def get_batch(self, model, batch_size=10):
        '''
        Randomly chooses batch_size memories, possibly repeating.
        For each of these memories, updates the models current best guesses about the value of taking a
            certain action from the starting state, based on the reward received and the model's current
            estimate of how valuable the new state is.
        '''
        len_memory = len(self.memory)
        num_actions = model.output_shape[-1] # the number of possible actions
        env_dim = self.memory[0][0][0].shape[1] # the number of pixels in the image
        input_size = min(len_memory, batch_size)
        inputs = np.zeros((input_size, env_dim))
        targets = np.zeros((input_size, num_actions))
        for i, idx in enumerate(np.random.randint(0, len_memory, size=input_size)):
            starting_observation, action_taken, reward_received, new_observation = self.memory[idx][0]
            game_over = self.memory[idx][1]

            # Set the input to the state that was observed in the game before an action was taken
            inputs[i:i+1] = starting_observation
            
            # Start with the model's current best guesses about the value of taking each action from this state
            targets[i] = model.predict(starting_observation)[0]
            
            # Now we need to update the value of the action that was taken                      
            if game_over: 
                # if the game is over, give the actual reward received
                targets[i, action_taken] = reward_received
            else:
                # if the game is not over, give the reward received (always zero in this particular game)
                # plus the maximum reward predicted for state we got to by taking this action (with a discount)
                Q_sa = np.max(model.predict(new_observation)[0])
                targets[i, action_taken] = reward_received + self.discount * Q_sa
        return inputs, targets

##  Functions for creating, training, and visualizing the model

In [4]:
# parameters
epsilon = .1  # probability of exploration (choosing a random action instead of the current best one)
num_actions = 3  # [move_left, stay, move_right]
max_memory = 500
hidden_size = 100
batch_size = 50
grid_size = 10

def run_episode(env, policy, grid_size=10, episode_len=100):
    total_reward = 0
    obs = env.reset()
    for t in range(episode_len):
        fruit_x, fruit_y, basket_center = env.state
        obs, reward, done = env.act(policy)
        total_reward += reward
        if done:
            # print('Episode finished after {} timesteps.'.format(t+1))
            break
    return total_reward

def evaluate_policy(env, policy, n_episodes=100):
    total_rewards = 0.0
    for _ in range(n_episodes):
        total_rewards += run_episode(env, policy)
    return total_rewards / n_episodes

def gen_random_policy():
    return np.random.choice(3, size=((10,10,10)))

def crossover(policy1, policy2, grid_size=10):
    new_policy = policy1.copy()
    for i in range(10):
        for j in range(10):
            for k in range(10):
                rand = np.random.uniform()
                if rand > 0.5:
                    new_policy[i][j][k] = policy2[i][j][k]
    return new_policy

def mutation(policy, p=0.05):
    new_policy = policy.copy()
    for i in range(10):
        for j in range(10):
            for k in range(10):
                rand = np.random.uniform()
                if rand < p:
                    new_policy[i] = np.random.choice(4)
    return new_policy

def build_model():
    '''
     Returns three initialized objects: the model, the environment, and the replay.
    '''
    model = Sequential()
    model.add(Dense(hidden_size, input_shape=(grid_size**2,), activation='relu'))
    model.add(Dense(hidden_size, activation='relu'))
    model.add(Dense(num_actions))
    model.compile(sgd(lr=.2), "mse")

    # Define environment/game
    env = Catch()

    # Initialize experience replay object
    exp_replay = ExperienceReplay(max_memory=max_memory)
    
    return model, env, exp_replay

def create_animation(model, env, num_games):
    '''
    Inputs:
        model and env objects as returned from build_model
        num_games: integer, the number of games to be included in the animation
        
    Returns: a matplotlib animation object
    '''
    # Animation code from 
    # https://matplotlib.org/examples/animation/dynamic_image.html
    # https://stackoverflow.com/questions/35532498/animation-in-ipython-notebook/46878531#46878531
    
    # First, play the games and collect all of the images for each observed state
    observations = []
    for _ in range(num_games):
        env.reset()
        observation = env.observe()
        observations.append(observation)
        game_over = False
        while game_over == False:
            q = model.predict(observation)
            action = np.argmax(q[0])
            
            # apply action, get rewards and new state
            observation, reward, game_over = env.act(policy)
            observations.append(observation)
            
    fig = plt.figure()
    image = plt.imshow(np.zeros((grid_size, grid_size)),interpolation='none', cmap='gray', animated=True, vmin=0, vmax=1)
    
    def animate(observation):
        image.set_array(observation.reshape((grid_size, grid_size)))
        return [image]
   
    animation = matplotlib.animation.FuncAnimation(fig, animate, frames=observations, blit=True, )
    return animation

In [5]:
random.seed(4904) # we kick bot
np.random.seed(4904)
model, env, exp_replay = build_model()
## Policy search
n_policy = 100
n_steps = 1
# start = time.time()

# note to self: the underscore can be used as a throwaway value when iterating through something,
# like when you want to do something x times but don't care about the value of x
policy_pop = [gen_random_policy() for _ in range(n_policy)]

env.reset()
env.state
policy_pop
test_policy = policy_pop[0]

for idx in range(n_steps):
    policy_scores = [evaluate_policy(env, p) for p in policy_pop]
    print('Generation %d : max score = %0.2f' %(idx+1, max(policy_scores)))
    policy_ranks = list(reversed(np.argsort(policy_scores)))
    elite_set = [policy_pop[x] for x in policy_ranks[:5]]
    select_probs = np.array(policy_scores) / np.sum(policy_scores)
    child_set = [crossover(
        policy_pop[np.random.choice(range(n_policy), p=select_probs)],
        policy_pop[np.random.choice(range(n_policy), p=select_probs)])
        for _ in range(n_policy - 5)]
    mutated_list = [mutation(p) for p in child_set]
    policy_pop = elite_set
    policy_pop += mutated_list
policy_score = [evaluate_policy(env, p) for p in policy_pop]
best_policy = policy_pop[np.argmax(policy_score)]

end = time.time()
print('Best policy score = %0.2f.'
        %(np.max(policy_score)))

Generation 1 : max score = -0.02


Exception: Invalid action 3

In [None]:
# test_policy[0][0][0]
a = env.act(test_policy)
fig = plt.figure()
image = plt.imshow(np.zeros((grid_size, grid_size)),interpolation='none', cmap='gray', animated=True, vmin=0, vmax=1)

def animate(observation):
    image.set_array(observation.reshape((grid_size, grid_size)))
    return [image]

animation = matplotlib.animation.FuncAnimation(fig, animate, frames=a, blit=True, )