In [1]:
import random
import copy
from collections import defaultdict
from collections import deque
from collections import namedtuple
import numpy as np
import unittest
import gym
from gym.spaces import Box
import os
import math
import argparse
from agent import Q, Agent, Trainer


In [None]:
class Q():
    """Tabular action-value function over a discretized observation space.

    Every observation dimension is cut into bins by a sorted array of
    boundaries; the per-dimension bin indices are combined into one integer
    state id that keys ``self.table``. Each table entry is an array of
    ``n_actions`` action values, created lazily the first time a state is seen.
    """

    def __init__(self, n_actions, observation_space, bin_size, low_bound=None, high_bound=None, initial_mean=0.0, initial_std=0.0):
        """
        Args:
            n_actions: number of discrete actions (length of each table row).
            observation_space: object exposing ``shape``, ``low`` and ``high``
                (e.g. a ``gym.spaces.Box``).
            bin_size: one bin size for every dimension, or a list/tuple with
                one entry per flattened observation dimension. Each must be >= 3.
            low_bound / high_bound: optional clipping of the space bounds; a
                scalar applied to every dimension, or a list/tuple with one
                entry per dimension (``None`` entries keep the space bound).
            initial_mean / initial_std: parameters of the normal distribution
                used to initialize the action values of a newly seen state.
        """
        self.n_actions = n_actions
        self._observation_dimension = 1
        for d in observation_space.shape:
            self._observation_dimension *= d

        if isinstance(bin_size, (list, tuple)):
            self._bin_sizes = list(bin_size)
        else:
            self._bin_sizes = [bin_size] * self._observation_dimension

        self._dimension_bins = []
        for i, low, high in self._low_high_iter(observation_space, low_bound, high_bound):
            b_size = self._bin_sizes[i]
            bins = self._make_bins(low, high, b_size)
            self._dimension_bins.append(bins)

        # if we encounter the new observation, we initialize action evaluations
        self.table = defaultdict(lambda: initial_std * np.random.randn(self.n_actions) + initial_mean)

    @classmethod
    def _make_bins(cls, low, high, bin_size):
        """Return the sorted bin boundaries for one dimension.

        ``bin_size - 2`` evenly spaced boundaries are generated starting at
        ``low`` (``high`` itself is excluded). When the boundaries reach below
        zero and zero is not already one of them, zero is inserted so that
        negative and positive values never share a bin.

        Raises:
            ValueError: if ``bin_size`` is smaller than 3 (no usable step).
        """
        if bin_size < 3:
            raise ValueError("bin_size must be at least 3, got {}".format(bin_size))

        bins = np.arange(low, high, (float(high) - float(low)) / (bin_size - 2))  # exclude both ends
        if min(bins) < 0 and 0 not in bins:
            bins = np.sort(np.append(bins, [0]))  # 0 centric bins
        return bins

    @classmethod
    def _low_high_iter(cls, observation_space, low_bound, high_bound):
        """Yield ``(index, low, high)`` for every flattened observation dimension.

        The space bounds are flattened so the index lines up with the
        flattened observation used by ``observation_to_state``. The optional
        ``low_bound`` / ``high_bound`` only ever narrow the space bounds.
        """
        lows = np.asarray(observation_space.low).flatten()
        highs = np.asarray(observation_space.high).flatten()
        for i in range(len(lows)):
            low = lows[i]
            if low_bound is not None:
                _low_bound = low_bound[i] if isinstance(low_bound, (list, tuple)) else low_bound
                low = low if _low_bound is None else max(low, _low_bound)

            high = highs[i]
            if high_bound is not None:
                _high_bound = high_bound[i] if isinstance(high_bound, (list, tuple)) else high_bound
                high = high if _high_bound is None else min(high, _high_bound)

            yield i, low, high

    def observation_to_state(self, observation):
        """Map an observation to its integer state id.

        The bin index of dimension ``d`` is used as the ``d``-th digit of a
        number written in base ``max(bin sizes)``.
        """
        unit = max(self._bin_sizes)
        state = 0
        for d, o in enumerate(np.asarray(observation).flatten()):
            # Plain Python ints keep the positional sum free of fixed-width overflow.
            digit = int(np.digitize(o, self._dimension_bins[d]))
            state = state + digit * pow(unit, d)  # bin_size numeral system
        return state

    def values(self, observation):
        """Return the action-value array of the state ``observation`` falls into."""
        state = self.observation_to_state(observation)
        return self.table[state]

In [None]:
class Agent():
    """Epsilon-greedy policy driven by an action-value table ``q``."""

    def __init__(self, q, epsilon=0.05):
        # epsilon is the probability of picking a uniformly random action.
        self.epsilon = epsilon
        self.q = q

    def act(self, observation):
        """Choose an action for ``observation``.

        With probability ``epsilon`` a random action index is drawn;
        otherwise the index of the highest action value is returned.
        """
        explore = np.random.random() < self.epsilon
        if explore:
            return np.random.choice(self.q.n_actions)
        return np.argmax(self.q.values(observation))


class Trainer():
    """Q-learning trainer that updates an agent's Q table from environment episodes."""

    def __init__(self, agent, gamma=0.95, learning_rate=0.1, learning_rate_decay=None, epsilon=0.05, epsilon_decay=None, max_step=-1):
        """
        Args:
            agent: object with ``act(observation)``, an ``epsilon`` attribute and a
                ``q`` exposing ``observation_to_state``, ``values`` and ``table``.
            gamma: discount factor applied to the best next-state value.
            learning_rate: initial step size of the Q update.
            learning_rate_decay: optional ``f(lr, episode_index) -> lr`` called
                after every episode.
            epsilon: exploration rate the agent uses while training.
            epsilon_decay: optional ``f(epsilon, episode_index) -> epsilon``
                called after every episode.
            max_step: maximum number of steps per episode; values <= 0 disable
                the limit.
        """
        self.agent = agent
        self.gamma = gamma
        self.learning_rate = learning_rate
        self.learning_rate_decay = learning_rate_decay
        self.epsilon = epsilon
        self.epsilon_decay = epsilon_decay
        self.max_step = max_step

    def train(self, env, episode_count, render=False):
        """Run ``episode_count`` training episodes on ``env``.

        The agent's epsilon is replaced by the training epsilon for the
        duration of the call and restored afterwards, even if training is
        interrupted by an exception. One progress line is printed per episode.
        """
        default_epsilon = self.agent.epsilon
        self.agent.epsilon = self.epsilon
        steps = deque(maxlen=100)  # episode lengths of the latest 100 episodes
        lr = self.learning_rate
        # Running totals give the mean of every Q value visited so far
        # without keeping (and re-averaging) the full history.
        value_total = 0.0
        update_count = 0
        try:
            for i in range(episode_count):
                step, value_sum = self._run_episode(env, lr, render)
                value_total += value_sum
                update_count += step

                mean = value_total / update_count
                steps.append(step)
                mean_step = np.mean(steps)
                print("Episode {}: {}steps(avg{}). epsilon={:.3f}, lr={:.3f}, mean q value={:.2f}".format(
                    i, step, mean_step, self.agent.epsilon, lr, mean)
                    )

                if self.epsilon_decay is not None:
                    self.agent.epsilon = self.epsilon_decay(self.agent.epsilon, i)
                if self.learning_rate_decay is not None:
                    lr = self.learning_rate_decay(lr, i)
        finally:
            self.agent.epsilon = default_epsilon

    def _run_episode(self, env, lr, render):
        """Play one episode, updating the Q table in place.

        Returns:
            ``(steps, value_sum)``: the number of steps taken and the sum of
            the pre-update Q values of the visited state/action pairs.
        """
        q = self.agent.q
        obs = env.reset()
        step = 0
        value_sum = 0.0
        done = False
        while not done:
            if render:
                env.render()

            action = self.agent.act(obs)
            next_obs, reward, done, _ = env.step(action)

            state = q.observation_to_state(obs)
            # A terminal transition has no future value to bootstrap from.
            future = 0 if done else np.max(q.values(next_obs))
            value = q.table[state][action]
            q.table[state][action] += lr * (reward + self.gamma * future - value)

            obs = next_obs
            value_sum += value
            step += 1
            # Stop once exactly max_step steps have been taken.
            if self.max_step > 0 and step >= self.max_step:
                done = True
        return step, value_sum

In [None]:
class TestCartPole(unittest.TestCase):
    # Unit tests for the observation discretization performed by Q.

    def test_make_bins(self):
        # Build Q against the real CartPole spaces, then check the boundaries of one dimension.
        cartpole = gym.make("CartPole-v0")
        table = Q(cartpole.action_space.n, cartpole.observation_space, bin_size=7, low_bound=-3, high_bound=3)

        lower, upper = -2, 3
        # expected bins: ~-2, ~-1, ~0, ~1, ~2, ~3, (3~) = 7bin, 6 boundary
        expected = np.arange(lower, upper).tolist()
        actual = table._make_bins(lower, upper, 7)
        self.assertEqual(tuple(expected), tuple(actual))

    def test_make_bins_multi_sizes(self):
        # Each dimension gets (bin size - 2) boundaries when sizes are given per dimension.
        space = Box(0, 6, (2,))
        table = Q(4, space, bin_size=[3, 5])
        for dimension, size in enumerate([3, 5]):
            self.assertEqual(size - 2, len(table._dimension_bins[dimension]))

    def test_make_bins_multi_bounds(self):
        # Per-dimension bounds narrow the range covered by the boundaries.
        space = Box(-3, 3, (2,))
        table = Q(4, space, bin_size=[3, 5], low_bound=[-2, -1], high_bound=[2, 1])
        first, second = table._dimension_bins[0], table._dimension_bins[1]
        self.assertEqual(-2, first[0])
        self.assertEqual(-1, second[0])
        self.assertLess(first[-1], 2)
        self.assertLess(second[-1], 1)

    def test_observation_to_state(self):
        # The state id is the per-dimension bin indices read as digits in base bin_size.
        bin_size = 7
        space = Box(-2, 3, (2,))
        table = Q(4, space, bin_size=bin_size, low_bound=-3, high_bound=3)

        observation = np.array([-3, 1])
        expected_state = 0 * bin_size ** 0 + 4 * bin_size ** 1
        self.assertEqual(table.observation_to_state(observation), expected_state)

In [None]:
# Directory the gym monitor records into and uploads are read from.
# `__file__` is not defined when this code runs as a notebook cell, so fall
# back to the current working directory in that case.
try:
    _RECORD_BASE_DIR = os.path.dirname(__file__)
except NameError:
    _RECORD_BASE_DIR = os.getcwd()
RECORD_PATH = os.path.join(_RECORD_BASE_DIR, "./upload")


def main(episodes, render, monitor):
    """Train a tabular Q-learning agent on CartPole-v0.

    Args:
        episodes: number of training episodes.
        render: when true, the environment is rendered on every step.
        monitor: when true, the run is recorded to ``RECORD_PATH`` through
            ``env.monitor``.
    """
    env = gym.make("CartPole-v0")

    q = Q(
        env.action_space.n,
        env.observation_space,
        bin_size=[3, 3, 8, 5],
        low_bound=[None, -0.5, None, -math.radians(50)],
        high_bound=[None, 0.5, None, math.radians(50)]
        )
    agent = Agent(q, epsilon=0.05)

    def _log_schedule(t):
        # Shared schedule: 1 - log10((t + 1) / 25), clamped by each caller below.
        return 1.0 - math.log10((t + 1) / 25)

    def learning_decay(lr, t):
        # The current rate is ignored; the value depends only on the episode index.
        return max(0.1, min(0.5, _log_schedule(t)))

    def epsilon_decay(eps, t):
        # The current epsilon is ignored; the value depends only on the episode index.
        return max(0.01, min(1.0, _log_schedule(t)))

    trainer = Trainer(
        agent,
        gamma=0.99,
        learning_rate=0.5, learning_rate_decay=learning_decay,
        epsilon=1.0, epsilon_decay=epsilon_decay,
        max_step=250)

    if monitor:
        env.monitor.start(RECORD_PATH)

    try:
        trainer.train(env, episode_count=episodes, render=render)
    finally:
        # Close the monitor even when training fails or is interrupted.
        if monitor:
            env.monitor.close()


if __name__ == "__main__":
    # Command-line entry point: either train (default) or upload a previous recording.
    # NOTE(review): inside a notebook kernel sys.argv carries the kernel's own
    # arguments, which parse_args() may reject - confirm how this cell is meant to run.
    parser = argparse.ArgumentParser(description="train & run cartpole ")
    parser.add_argument("--episode", type=int, default=1000, help="episode to train")
    parser.add_argument("--render", action="store_true", help="render the screen")
    parser.add_argument("--monitor", action="store_true", help="monitor")
    parser.add_argument("--upload", type=str, default="", help="upload key to openai gym (training is not executed)")

    args = parser.parse_args()

    if args.upload:
        # Upload-only mode: training is skipped.
        if os.path.isdir(RECORD_PATH):
            gym.upload(RECORD_PATH, api_key=args.upload)
        else:
            # Tell the user why nothing happened instead of exiting silently.
            print("Nothing to upload: {} does not exist. Run with --monitor first.".format(RECORD_PATH))
    else:
        main(args.episode, args.render, args.monitor)