# Iterative greedy policy improvement

\\[
\begin{equation*}
    \pi_{1} 
    \leq \textrm{greedy}(\pi_{1})
    = \pi_{2}
    \leq \textrm{greedy}(\pi_{2})
    = \pi_{3}
    \leq \textrm{greedy}(\pi_{3})
    = \pi_{4}
    \leq \cdots
    \leq \pi_{*}
\end{equation*}
\\]

# Works well in theory, but not very practical! Let's see why.

# Setup code from lesson on greedy policy improvement in `CartPole-v0`

In [1]:
import random

import gym
import numpy as np


class InitMod(gym.Wrapper):
    """Gym wrapper that makes every episode begin from a fixed state.

    Intended for CartPole-v0: after the wrapped environment has reset
    itself, its internal state is overwritten with ``initial_state``.
    """

    def __init__(self, env, initial_state):
        super().__init__(env)
        # State installed on the underlying environment at every reset.
        self.initial_state = initial_state

    def reset(self):
        """Reset the wrapped env, then replace its state with the fixed one.

        Returns the state now stored on the unwrapped environment.
        """
        # The observation produced by the inner reset is deliberately discarded.
        self.env.reset()
        base_env = self.unwrapped
        base_env.state = self.initial_state
        return base_env.state
    

# Build the CartPole-v0 environment and wrap it so that every reset
# starts from the same fixed state.
wrapped_env = InitMod(
    env=gym.make("CartPole-v0"),
    initial_state=np.array([0, 0.01, 0.15, 0.]),
)


def get_action_random(observation):
    """Draw an action from the uniform random policy.

    The observation is accepted for interface compatibility but is not
    used; actions 0 and 1 are each returned with probability one half.
    """
    return 0 if random.random() < 0.5 else 1


class QValue:
    """Running Monte Carlo estimate of Q-values for state-action pairs.

    Two dictionaries are maintained, both keyed by
    ``(tuple(observation), action)``:

    * ``visit_number``    -- how many return samples the pair has received
    * ``q_value_average`` -- the mean of those discounted-return samples

    Existing dictionaries may be passed in to continue from earlier
    estimates; otherwise fresh empty ones are created.
    """

    def __init__(self, gamma, visit_number=None, q_value_average=None):
        self.gamma = gamma
        self.visit_number = {} if visit_number is None else visit_number
        self.q_value_average = {} if q_value_average is None else q_value_average

    def update(self, episode_history):
        """Fold one episode into the running averages.

        ``episode_history`` is a sequence of dicts with the keys
        ``"observation"``, ``"action"`` and ``"reward"``. It is walked from
        the last step to the first so the discounted return can be built up
        incrementally; every visit of a pair contributes one sample.
        """
        discounted_return = 0
        for transition in reversed(episode_history):
            discounted_return = (self.gamma * discounted_return) + transition["reward"]
            pair = (tuple(transition["observation"]), transition["action"])
            seen = self.visit_number.get(pair, 0)
            if seen:
                # Combine the previous mean (weighted by its sample count)
                # with the new sample.
                weighted_total = seen * self.q_value_average[pair] + discounted_return
                self.q_value_average[pair] = weighted_total / (seen + 1)
            else:
                # First sample for this pair: it is the mean.
                self.q_value_average[pair] = discounted_return
            self.visit_number[pair] = seen + 1
            

class Value:
    """Running Monte Carlo estimate of state values.

    Two dictionaries are maintained, both keyed by ``tuple(observation)``:

    * ``visit_number``  -- how many return samples the state has received
    * ``value_average`` -- the mean of those discounted-return samples

    Existing dictionaries may be passed in to continue from earlier
    estimates; otherwise fresh empty ones are created.
    """

    def __init__(self, gamma, visit_number=None, value_average=None):
        self.gamma = gamma
        self.visit_number = {} if visit_number is None else visit_number
        self.value_average = {} if value_average is None else value_average

    def update(self, episode_history):
        """Fold one episode into the running averages.

        ``episode_history`` is a sequence of dicts with at least the keys
        ``"observation"`` and ``"reward"``. Steps are processed last-to-first
        so the discounted return accumulates in a single pass.
        """
        discounted_return = 0
        for transition in reversed(episode_history):
            discounted_return = self.gamma * discounted_return + transition["reward"]
            state_key = tuple(transition["observation"])
            seen = self.visit_number.get(state_key, 0)
            if seen:
                # Combine the previous mean (weighted by its sample count)
                # with the new sample.
                weighted_total = seen * self.value_average[state_key] + discounted_return
                self.value_average[state_key] = weighted_total / (seen + 1)
            else:
                # First sample for this state: it is the mean.
                self.value_average[state_key] = discounted_return
            self.visit_number[state_key] = seen + 1

## Code from lesson on greedy policy improvement in `CartPole-v0`:

# Go through `50000` episodes using random policy $\pi_{\textrm{random}}$ and calculate value and Q-value

- The longer we run the random policy, the more states we see.

In [2]:
# Number of episodes to roll out.
num_episodes = 50_000
# Discount factor handed to the value / Q-value estimators.
gamma = 0.95

In [3]:
# Estimators that accumulate return samples gathered under the random policy.
q_value_random_policy = QValue(gamma=gamma)
value_random_policy = Value(gamma=gamma)

for num_episode in range(num_episodes):
    # Roll out one full episode, recording every transition.
    episode_history = []
    observation = wrapped_env.reset()
    step_count = 0
    done = False
    while not done:
        action = get_action_random(observation)
        next_observation, reward, done, _ = wrapped_env.step(action)
        episode_history.append(dict(observation=observation, reward=reward, action=action))
        observation = next_observation
        step_count += 1
    # Fold the finished episode into both running averages.
    q_value_random_policy.update(episode_history)
    value_random_policy.update(episode_history)
    # Progress report every 10000 episodes.
    if not (num_episode + 1) % 10000:
        print(f"num state-action pairs seen after {num_episode + 1} episodes is {len(q_value_random_policy.visit_number)}")
wrapped_env.close()

num state-action pairs seen after 10000 episodes is 65523
num state-action pairs seen after 20000 episodes is 121696
num state-action pairs seen after 30000 episodes is 174034
num state-action pairs seen after 40000 episodes is 224356
num state-action pairs seen after 50000 episodes is 271638


# Practical problem 1: We still haven't seen most state-action pairs, even after 50000 episodes with $\pi_{\textrm{random}}$

- Can't take greedy actions if we don't have Q-value estimates for state-action pairs

In [4]:
# Probe a state-action pair to check whether a Q-value estimate exists for it.
state = np.array([0.1, 0.1, 0.1, 0.1])
# Keys of q_value_average have the form (tuple(observation), action).
missing_state_action_pair = (tuple(state), 1)
# No estimate was recorded for this pair, so the lookup raises KeyError.
q_value_random_policy.q_value_average[missing_state_action_pair]

KeyError: ((0.1, 0.1, 0.1, 0.1), 1)

## Code from lesson on greedy policy improvement in `CartPole-v0`:

# Sampling function for policy $\textrm{greedy}(\pi_{\textrm{random}})$

- Only approximates $\textrm{greedy}(\pi_{\textrm{random}})$

In [5]:
def get_action_greedy_policy(observation, q_value_average):
    """Sample an action that is (approximately) greedy w.r.t. Q-value estimates.

    Args:
        observation: Current state; converted to a tuple to build lookup keys.
        q_value_average: Mapping from ``(state_tuple, action)`` to the
            estimated Q-value of that pair.

    Returns:
        The chosen action (0 or 1) as a built-in ``int``. When both actions
        have (numerically) equal maximal estimates, one is picked uniformly
        at random. When the estimate for either action is missing, a random
        action is returned instead.
    """
    try:
        q_values = np.array([q_value_average[(tuple(observation), action)] for action in (0, 1)])
    except KeyError:
        # Compromise we make since we don't have Q-value estimates for all state action pairs
        return get_action_random(observation)
    # Tie breaking: choose uniformly among the actions whose estimate is maximal.
    best_actions = np.flatnonzero(np.isclose(q_values, q_values.max()))
    # np.random.choice yields a NumPy integer; convert it so that both return
    # paths of this function hand back the same built-in type.
    return int(np.random.choice(best_actions))

## Code from lesson on greedy policy improvement in `CartPole-v0`: 

# Go through `10000` episodes using *greedy policy* ($\pi_{\textrm{greedy}}$ or $\textrm{greedy}(\pi_{random})$)

In [6]:
num_episodes = 10000
# Estimators that accumulate return samples gathered under the greedy policy.
q_value_greedy_policy = QValue(gamma=gamma)
value_greedy_policy = Value(gamma=gamma)

for num_episode in range(num_episodes):
    # Roll out one full episode, acting greedily on the random-policy Q-values.
    episode_history = []
    observation = wrapped_env.reset()
    step_count = 0
    done = False
    while not done:
        action = get_action_greedy_policy(observation, q_value_random_policy.q_value_average)
        # Show which action is taken in the initial state of each episode.
        if not step_count:
            print(action)
        next_observation, reward, done, _ = wrapped_env.step(action)
        episode_history.append(dict(observation=observation, reward=reward, action=action))
        observation = next_observation
        step_count += 1
    # Fold the finished episode into both running averages.
    value_greedy_policy.update(episode_history)
    q_value_greedy_policy.update(episode_history)
wrapped_env.close()

1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1


1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1


1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1


# Practical problem 2: Another round of greedy policy improvement $\textrm{greedy}(\textrm{greedy}(\pi_{random}))$

- Another round of improvement needs Q-value estimates obtained while following $\textrm{greedy}(\pi_{random})$
- Greedy policy takes actions that maximize Q-values. In the initial state, this is `1`. So `(initial_state, 0)` is never seen while following this policy.
- If the agent cannot estimate the Q-value of a state-action pair, it cannot do another round of greedy policy improvement

In [7]:
# Same values as the initial_state given to the wrapped environment, paired with action 0.
state = np.array([0., 0.01, 0.15, 0.])
missing_state_action_pair = (tuple(state), 0)
# The greedy rollouts recorded no estimate for this pair, so the lookup raises KeyError.
q_value_greedy_policy.q_value_average[missing_state_action_pair]

KeyError: ((0.0, 0.01, 0.15, 0.0), 0)

# Solution: balance exploitation and exploration 

Simultaneously take advantage of these two aspects of the two different policies
- Policy improvement aspect of exploitation (greedy policy)
- Discovery aspect of exploration (random policy) 

# Epsilon greedy policy
- Much like the epsilon pole direction policy 
- With probability $\epsilon$, the agent takes random actions. With probability $1 - \epsilon$, it takes greedy actions

In [8]:
def get_action_epsilon_greedy_policy(observation, q_value_average, epsilon):
    """Sample an action from the epsilon-greedy policy.

    With probability ``epsilon`` a random action is taken (exploration);
    otherwise the action is chosen greedily from ``q_value_average``
    (exploitation).
    """
    explore = random.random() < epsilon
    if not explore:
        return get_action_greedy_policy(observation, q_value_average)
    return get_action_random(observation)

# Advantages

- Still takes a lot of random actions, and therefore, finds new states and state-action pairs.
- The agent is now able to estimate the Q-value for state-action pairs that the purely greedy policy never visits, such as `(initial_state, 0)`
- The agent can construct an epsilon-greedy policy with respect to the first epsilon-greedy policy, i.e. $\epsilon\textrm{-greedy}(\epsilon\textrm{-greedy}(\pi_{\textrm{random}}))$

In [9]:
num_episodes = 10000
# Probability of taking a random (exploratory) action at each step.
epsilon = 0.9
# Estimator that accumulates return samples gathered under the epsilon-greedy policy.
q_value_epsilon_greedy_policy = QValue(gamma=gamma)

for num_episode in range(num_episodes):
    # Roll out one full episode under the epsilon-greedy policy.
    episode_history = []
    observation = wrapped_env.reset()
    step_count = 0
    done = False
    while not done:
        action = get_action_epsilon_greedy_policy(observation, q_value_random_policy.q_value_average, epsilon)
        # Show which action is taken in the initial state of each episode.
        if not step_count:
            print(action)
        next_observation, reward, done, _ = wrapped_env.step(action)
        episode_history.append(dict(observation=observation, reward=reward, action=action))
        observation = next_observation
        step_count += 1
    # Fold the finished episode into the running Q-value averages.
    q_value_epsilon_greedy_policy.update(episode_history)
wrapped_env.close()

0
1
1
0
0
0
0
0
1
0
1
1
0
0
1
1
1
1
0
0
0
1
1
1
0
1
1
0
0
0
1
0
0
0
0
0
1
1
0
0
1
0
0
1
0
1
0
1
1
0
0
0
0
1
1
1
1
0
0
1
0
0
1
1
0
0
1
1
0
0
0
0
0
0
1
1
1
1
1
0
0
0
1
0
1
1
0
1
0
1
0
0
0
0
1
1
0
1
1
1
0
1
0
1
0
1
1
1
0
0
0
1
1
1
0
0
1
0
0
1
1
0
0
0
0
0
0
1
0
0
0
0
0
0
1
1
1
0
0
1
0
1
1
0
1
1
0
1
0
0
1
0
0
0
0
0
0
0
0
1
0
1
1
1
0
0
0
1
1
1
1
1
1
0
1
1
1
0
1
1
1
1
1
1
0
1
0
0
0
1
1
0
0
1
0
1
1
0
0
0
0
1
1
0
0
0
0
1
0
0
0
1
0
1
1
1
1
0
0
1
0
0
0
1
0
1
0
0
0
0
0
1
1
0
0
0
1
0
1
0
0
1
1
1
1
0
1
0
0
1
0
0
1
1
1
1
0
1
1
1
0
0
0
1
0
0
1
0
1
1
0
0
0
1
1
0
1
1
1
1
1
1
0
1
1
0
0
0
0
1
1
0
0
0
1
0
1
1
1
1
1
1
0
0
1
0
0
1
0
1
1
0
1
1
0
1
1
1
1
1
1
1
1
1
0
0
0
1
0
1
0
1
1
1
0
0
0
0
0
1
1
0
1
0
0
0
1
1
1
1
1
1
0
1
0
1
0
1
1
1
1
1
1
0
1
0
1
1
1
1
1
0
0
1
0
1
0
0
1
1
0
0
0
0
0
0
0
1
1
0
1
0
0
1
0
0
1
1
0
0
1
1
1
0
0
1
1
1
0
1
0
1
1
0
1
1
1
1
1
1
1
0
1
1
1
0
0
1
1
1
1
1
0
1
1
0
1
0
0
1
0
1
0
0
0
0
1
0
1
0
0
1
1
1
1
0
0
1
1
0
1
1
1
1
0
1
1
0
0
0
0
1
1
0
0
0
0
1
1
0
1
1
0
1
1
1
1
0
1
1
1
0
0
1
1
1
0
1
0
0


1
1
1
1
0
1
0
0
1
1
1
0
0
0
0
0
0
1
1
0
1
1
0
0
0
1
0
0
1
0
1
1
1
1
1
1
0
1
0
0
1
1
1
1
0
0
1
1
0
1
1
0
1
1
0
1
0
0
1
1
0
1
1
0
1
1
1
0
1
1
0
0
1
0
1
1
0
1
1
1
0
1
1
0
0
1
1
0
1
1
0
1
1
1
1
0
0
1
1
0
1
0
0
1
1
1
0
1
1
0
1
0
0
1
1
0
0
1
0
1
1
0
1
1
0
0
0
0
0
0
0
1
1
1
1
1
1
0
1
0
1
0
0
1
1
0
0
0
0
1
1
0
1
1
1
1
1
1
1
1
1
0
1
0
1
0
1
1
0
1
1
0
1
1
1
0
1
1
0
0
0
1
1
0
1
1
0
1
0
0
0
0
1
0
0
1
1
0
0
0
0
0
1
0
0
1
1
0
1
0
0
0
1
1
0
0
0
1
1
1
0
1
0
1
1
1
1
0
1
1
1
1
1
1
1
1
0
0
1
1
1
0
0
1
0
0
1
0
0
1
1
0
1
1
1
1
0
1
0
0
0
0
1
0
0
1
0
1
0
0
0
0
1
1
1
0
0
1
0
1
1
1
1
1
0
0
0
1
0
0
0
0
0
0
1
0
1
0
1
1
0
0
1
0
0
1
1
0
1
1
0
0
0
0
0
1
0
0
1
0
0
0
1
0
0
1
1
1
1
0
0
1
0
1
1
1
1
1
0
0
1
0
0
1
1
0
0
1
0
0
1
1
1
0
0
0
0
1
1
1
0
0
1
0
0
1
1
1
1
1
0
0
0
0
1
0
0
1
1
0
1
1
1
1
1
0
1
0
0
1
1
1
1
1
1
1
1
0
1
1
1
1
0
1
1
1
0
0
0
0
0
1
1
1
1
0
1
1
1
0
1
0
0
0
1
0
0
0
0
1
1
1
1
1
0
0
1
1
0
1
0
0
1
1
1
1
1
0
0
1
1
1
1
1
1
0
0
0
0
1
0
1
1
1
0
0
0
0
0
1
1
1
1
0
1
1
1
1
0
0
1
0
1
1
1
0
0
0
1
1
0
0
1
1
1
0
1
1
1
1


0
0
0
0
1
0
0
0
1
0
1
1
0
0
1
0
1
0
1
0
0
0
1
0
1
0
1
1
0
0
0
1
1
1
1
0
0
1
1
0
0
1
0
1
1
0
0
1
1
0
0
1
0
0
0
1
1
0
0
1
1
1
1
1
0
1
1
0
1
1
0
0
1
1
1
0
1
1
1
0
1
0
1
1
1
0
1
0
0
0
0
1
0
0
0
1
0
1
1
0
0
1
0
1
0
0
1
1
1
0
1
1
0
1
0
1
0
1
1
1
1
1
0
1
1
1
1
0
0
1
0
1
0
0
0
0
1
0
1
1
0
0
1
0
0
0
1
0
0
0
0
0
1
1
0
1
0
0
1
1
1
0
1
1
0
0
0
1
0
0
0
1
0
0
0
0
0
0
0
1
0
1
1
0
0
1
0
1
0
0
1
1
0
0
0
1
1
1
1
0
0
0
1
1
1
1
0
0
0
1
0
1
0
1
1
0
0
1
1
1
1
0
0
1
0
1
0
1
0
1
0
0
1
0
0
1
1
0
1
0
1
1
1
1
0
0
1
1
0
0
0
1
1
0
0
0
1
1
0
1
1
0
1
0
0
0
0
0
1
1
1
0
1
0
0
1
1
0
1
0
1
0
0
0
1
1
1
1
0
0
0
1
1
1
1
1
1
1
0
0
0
0
1
0
0
0
0
1
0
0
0
1
0
0
0
1
0
1
0
0
0
0
1
0
1
0
1
0
1
1
0
1
0
0
1
1
0
0
0
0
1
0
1
1
1
1
1
0
1
1
1
0
0
0
1
0
1
1
1
0
0
0
1
0
0
1
1
1
1
0
1
1
1
1
1
1
0
0
1
1
0
0
1
1
1
1
0
0
0
1
1
0
1
0
0
1
1
0
1
0
1
1
1
0
0
0
1
0
1
0
0
1
1
0
0
1
1
1
1
0
0
0
0
1
1
0
0
1
1
1
1
1
0
0
0
1
0
0
1
1
0
1
0
0
1
0
1
0
1
1
0
0
0
1
0
0
1
0
0
1
1
1
1
0
1
0
1
0
1
0
1
0
0
0
0
1
1
1
1
0
1
0
1
1
0
0
0
1
0
1
0
1
0
1
0
1
1
0
0
1


In [10]:
# Same values as the initial_state given to the wrapped environment, paired with action 0.
state = np.array([0., 0.01, 0.15, 0.])
missing_state_action_pair = (tuple(state), 0)
# Number of return samples recorded for this pair during the epsilon-greedy
# rollouts; here the key exists, so the lookup succeeds.
q_value_epsilon_greedy_policy.visit_number[missing_state_action_pair]

4576

In [12]:
# Average discounted return recorded for the same pair under the epsilon-greedy rollouts.
q_value_epsilon_greedy_policy.q_value_average[missing_state_action_pair]

9.035266493040849