In [None]:
!pip install gym
!pip install gym-maze-trustycoder83

In [1]:
import gym
import gym_maze
import numpy as np 
import random
import time
import os
# Create video device.
# Select SDL's "dummy" video driver so that a display surface can be created
# without a real screen (e.g. on a headless notebook server). The variable is
# set before pygame is imported so it is already in place when the display is
# initialised on the last line.
os.environ['SDL_VIDEODRIVER']='dummy'
import pygame
# Open a 640x480 display surface up front; presumably the maze renderer needs a
# video mode to exist -- TODO confirm. The returned Surface is echoed as the
# cell output.
pygame.display.set_mode((640,480))

pygame 1.9.6
Hello from the pygame community. https://www.pygame.org/contribute.html


<Surface(640x480x8 SW)>

### Create environment

In [90]:
# Problem size shared by the tabular agents below.
action_n = 4      # number of discrete actions the agents choose from
state_n = 5 * 5   # one state per cell of the 5x5 grid (get_state relies on sqrt(state_n))

# The 5x5 sample maze; presumably registered with gym by the `gym_maze` import.
env = gym.make('maze-sample-5x5-v0')

### Creating state

In [76]:
def get_state(obs, width=None):
    """Flatten a 2-D grid observation into a single integer state index.

    The index is computed row-major as ``obs[1] * width + obs[0]``.

    Parameters
    ----------
    obs : sequence
        Observation whose first two entries are read as ``obs[0]`` and
        ``obs[1]`` (presumably the column and row of the agent -- confirm
        against the environment).
    width : number, optional
        Number of cells per row. When omitted, the grid is assumed to be
        square and the width is taken as ``sqrt(state_n)`` from the notebook's
        global ``state_n``, which is the original behaviour.

    Returns
    -------
    int
        The flattened state index (truncated towards zero by ``int``).
    """
    if width is None:
        # Backward-compatible default: square grid sized by the global state count.
        width = np.sqrt(state_n)
    return int(obs[1] * width + obs[0])

### Agent

In [77]:
class RandomAgent:
    """Baseline agent that picks every action uniformly at random."""

    def __init__(self, action_n):
        # Size of the discrete action set; valid actions are 0 .. action_n - 1.
        self.action_n = action_n

    def get_action(self, state):
        """Return a random action index in ``[0, action_n)``.

        ``state`` is accepted but not used by this agent.
        """
        return random.randrange(self.action_n)

### Cross-entropy method

In [78]:
class CEM():
    """Tabular cross-entropy-method agent.

    The policy is a ``(state_n, action_n)`` table whose row ``s`` holds the
    probabilities of each action in state ``s``. It starts out uniform and is
    re-estimated from elite trajectories by ``update_policy``.
    """

    def __init__(self, state_n, action_n):
        self.state_n = state_n
        self.action_n = action_n
        # Uniform initial policy: every action equally likely in every state.
        self.policy = np.ones((self.state_n, self.action_n)) / self.action_n

    def get_action(self, state):
        """Sample an action index for ``state`` from the current policy row."""
        return int(np.random.choice(self.action_n, p=self.policy[state]))

    def update_policy(self, elite_trajectories):
        """Re-estimate the policy from the state/action pairs of elite trajectories.

        Each trajectory is a dict with ``'states'`` and ``'actions'`` lists;
        they are paired with ``zip``, so a trailing state without an action is
        ignored. For every state seen in the elite set the new action
        probabilities are the empirical action frequencies; states not seen
        are reset to uniform.

        If the elite set contributes no state/action pair at all (for example
        an empty list), the current policy is left unchanged rather than being
        wiped back to uniform.

        The policy table is updated in place. Returns ``None``.
        """
        counts = np.zeros((self.state_n, self.action_n))
        for trajectory in elite_trajectories:
            for state, action in zip(trajectory['states'], trajectory['actions']):
                counts[state, action] += 1

        visits = counts.sum(axis=1)
        if not visits.any():
            # Nothing was observed, so there is nothing to learn from; keep
            # whatever has been learned so far.
            return None

        visited = visits > 0
        # Empirical action frequencies for the states that appear in the elite set.
        self.policy[visited] = counts[visited] / visits[visited, None]
        # No evidence for the remaining states: fall back to uniform.
        self.policy[~visited] = 1.0 / self.action_n
        return None

### Creating agent trajectory

In [79]:
def get_trajectory(agent, trajectory_len):
    """Roll out one episode in the global ``env`` and record what happened.

    Parameters
    ----------
    agent : object
        Anything with a ``get_action(state)`` method returning an action that
        ``env.step`` accepts.
    trajectory_len : int
        Maximum number of steps to take before the rollout is cut off.

    Returns
    -------
    dict
        ``'states'``: visited state indices, starting with the initial state;
        ``'actions'``: the action taken at each step;
        ``'total_reward'``: the sum of the step rewards.
        If the episode finishes, the state reached by the final step is not
        recorded, so ``states`` and ``actions`` have equal length; if the step
        limit is hit instead, ``states`` holds one extra trailing entry.
    """
    current = get_state(env.reset())
    visited = [current]
    taken = []
    reward_sum = 0

    for _step in range(trajectory_len):
        move = agent.get_action(current)
        taken.append(move)

        observation, reward, finished, _info = env.step(move)
        reward_sum += reward
        current = get_state(observation)

        if finished:
            # Episode is over: stop without recording the final state.
            break
        visited.append(current)

    return {'states': visited, 'actions': taken, 'total_reward': reward_sum}


### Elite trajectories

In [80]:
def get_elite_trajectories(trajectories, q_param):
    """Select the trajectories whose total reward beats the ``q_param`` quantile.

    Parameters
    ----------
    trajectories : list of dict
        Each dict must carry a ``'total_reward'`` entry.
    q_param : float
        Quantile level passed to ``np.quantile``.

    Returns
    -------
    list of dict
        The trajectories, in their original order, whose total reward is
        strictly greater than the quantile. The list is empty when no reward
        exceeds it, e.g. when all rewards are equal.
    """
    rewards = [item['total_reward'] for item in trajectories]
    threshold = np.quantile(rewards, q=q_param)

    elite = []
    for item, reward in zip(trajectories, rewards):
        if reward > threshold:
            elite.append(item)
    return elite

## Test

### Random agent

In [91]:
# Smoke test: roll out one short (at most 20-step) episode with the random
# baseline and show the recorded states, actions and total reward.
agent = RandomAgent(action_n=action_n)
trajectory = get_trajectory(agent=agent, trajectory_len=20)
print(trajectory)

{'states': [0, 1, 2, 2, 3, 3, 2, 2, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1], 'actions': [3, 3, 0, 3, 3, 2, 1, 2, 0, 2, 2, 1, 2, 2, 3, 2, 2, 3, 2, 3], 'total_reward': -0.08000000000000004}


### Cross-entropy agent

In [92]:
# Training settings.
q_param = 0.9          # quantile of total reward a trajectory must exceed to count as elite
trajectory_len = 100   # step limit per rollout
trajectory_n = 100     # rollouts sampled per policy update
episode_n = 50         # number of sample-then-update iterations

agent = CEM(state_n, action_n)

for _ in range(episode_n):
    # Sample a batch of rollouts with the current policy.
    trajectories = [get_trajectory(agent, trajectory_len) for _ in range(trajectory_n)]

    # Report the batch-average total reward as a progress signal.
    mean_total_reward = np.mean([rollout['total_reward'] for rollout in trajectories])
    print(mean_total_reward)

    # Refit the policy on the best rollouts; skip the update when none
    # strictly exceeds the quantile.
    elite_trajectories = get_elite_trajectories(trajectories, q_param)
    if elite_trajectories:
        agent.update_policy(elite_trajectories)

-0.37932000000000016
0.3755999999999997
0.7961199999999998
0.8821599999999996
0.92124
0.9362
0.9350800000000001
0.9363999999999999
0.9367599999999999
0.9359599999999999
0.9356799999999998
0.93548
0.9359599999999998
0.9359199999999999
0.9355999999999998
0.9358
0.9357999999999999
0.9351599999999999
0.9365200000000001
0.9347599999999998
0.9361999999999998
0.9356799999999998
0.9368799999999997
0.9359599999999998
0.9357599999999999
0.9351599999999999
0.93536
0.9354799999999998
0.93604
0.93648
0.9358799999999999
0.9363999999999999
0.9364400000000002
0.9361200000000003
0.93612
0.93612
0.93628
0.9363199999999998
0.9358399999999999
0.9358
0.9352399999999997
0.9361599999999999
0.9359199999999999
0.9354399999999998
0.9355199999999999
0.9357599999999998
0.9357999999999999
0.93556
0.9361999999999998
0.9365599999999997
