# Greedy policy improvement ($\pi_{\textrm{greedy}}$ or $\textrm{greedy}(\pi)$) in `CartPole-v0`

# Set up the wrapped environment

In [16]:
import gym
import numpy as np


class InitMod(gym.Wrapper):
    """Wrapper that makes every episode of the wrapped env start from a fixed state.

    Parameters:
        env: the environment to wrap (``gym.make("CartPole-v0")`` in this notebook).
        initial_state: array-like installed as ``env.unwrapped.state`` on every reset.
            NOTE(review): presumably ordered as CartPole's
            [x, x_dot, theta, theta_dot] -- confirm against the env documentation.
    """
    def __init__(self, env, initial_state):
        super().__init__(env)
        self.initial_state = initial_state

    def reset(self):
        """Reset the wrapped env, force its state to ``initial_state`` and return it.

        The env state and the returned observation are independent copies, so
        in-place changes to either can never alter ``self.initial_state`` and
        leak into later episodes.
        """
        # The underlying reset is still called; the observation it returns is
        # deliberately discarded because the state is overwritten right after.
        self.env.reset()
        start_state = np.array(self.initial_state)  # np.array copies by default
        self.unwrapped.state = start_state
        return start_state.copy()
    

# Wrap CartPole-v0 so that every reset starts from the same hand-picked state
wrapped_env = InitMod(
    env=gym.make("CartPole-v0"),
    initial_state=np.array([0, 10.0, 0., 0.]),
)

# Sampling function for random policy

In [2]:
import random

def get_action_random(observation):
    """Pick action 0 or 1 with equal probability.

    The observation is accepted so the function has the shape of a policy,
    but it is not used. Exactly one ``random.random()`` draw is consumed per call.
    """
    draw = random.random()
    return 0 if draw < 0.5 else 1

# Helpers for calculating values of states and action-values (Q-values) of state-action pairs

In [3]:
class QValue:
    """Running Monte Carlo estimate of Q-values for (state, action) pairs.

    ``visit_number`` maps ``(tuple(observation), action)`` to the number of
    return samples seen so far, and ``q_value_average`` maps the same key to
    the mean of those samples. Existing dictionaries may be passed in to
    continue from earlier estimates; otherwise fresh ones are created.
    """
    def __init__(self, gamma, visit_number=None, q_value_average=None):
        self.gamma = gamma
        self.visit_number = {} if visit_number is None else visit_number
        self.q_value_average = {} if q_value_average is None else q_value_average

    def update(self, episode_history):
        """Fold one episode into the running averages.

        ``episode_history`` is a sequence of dicts with the keys
        ``"observation"``, ``"action"`` and ``"reward"``. It is walked backwards
        so the discounted return from each step is built incrementally; every
        occurrence of a pair in the episode contributes one sample.
        """
        counts = self.visit_number
        averages = self.q_value_average
        discounted_return = 0
        for transition in reversed(episode_history):
            discounted_return = (self.gamma * discounted_return) + transition["reward"]
            pair = (tuple(transition["observation"]), transition["action"])
            seen = counts.get(pair, 0)
            if seen:
                # new mean = (old sum + new sample) / (new count)
                averages[pair] = (seen * averages[pair] + discounted_return) / (seen + 1)
            else:
                averages[pair] = discounted_return
            counts[pair] = seen + 1
            

class Value:
    """Running Monte Carlo estimate of state values.

    ``visit_number`` maps ``tuple(observation)`` to the number of return
    samples seen so far, and ``value_average`` maps the same key to the mean
    of those samples. Existing dictionaries may be passed in to continue from
    earlier estimates; otherwise fresh ones are created.
    """
    def __init__(self, gamma, visit_number=None, value_average=None):
        self.gamma = gamma
        self.visit_number = {} if visit_number is None else visit_number
        self.value_average = {} if value_average is None else value_average

    def update(self, episode_history):
        """Fold one episode into the running averages.

        ``episode_history`` is a sequence of dicts with the keys
        ``"observation"`` and ``"reward"``. It is walked backwards so the
        discounted return from each step is built incrementally; every
        occurrence of a state in the episode contributes one sample.
        """
        counts = self.visit_number
        averages = self.value_average
        discounted_return = 0
        for transition in reversed(episode_history):
            discounted_return = self.gamma * discounted_return + transition["reward"]
            state = tuple(transition["observation"])
            seen = counts.get(state, 0)
            if seen:
                # new mean = (old sum + new sample) / (new count)
                averages[state] = (seen * averages[state] + discounted_return) / (seen + 1)
            else:
                averages[state] = discounted_return
            counts[state] = seen + 1

# Run `10000` episodes with the random policy $\pi_{\textrm{random}}$ and estimate its state values and Q-values

In [4]:
# Number of episodes sampled per policy
num_episodes = 10_000
# Discount factor used for every return computed in this notebook
gamma = 0.95

In [17]:
# Estimators filled from episodes generated by the random policy
q_value_random_policy = QValue(gamma=gamma)
value_random_policy = Value(gamma=gamma)

for num_episode in range(num_episodes):
    observation = wrapped_env.reset()
    episode_history = []
    step_count = 0
    done = False
    while not done:
        action = get_action_random(observation)
        if not step_count:
            # Show the first action of each episode
            print(action)
        next_observation, reward, done, _ = wrapped_env.step(action)
        # Record the observation the action was taken in, not the one it led to
        episode_history.append(
            dict(observation=observation, reward=reward, action=action)
        )
        observation = next_observation
        step_count += 1
    # Both estimators learn from the same finished episode
    q_value_random_policy.update(episode_history)
    value_random_policy.update(episode_history)
wrapped_env.close()

1
0
1
1
1
0
0
0
0
0
1
0
1
1
0
1
1
1
1
0
0
1
1
1
0
1
0
0
0
1
1
1
1
0
1
1
0
1
0
1
0
0
1
0
1
0
1
1
0
1
0
1
0
0
0
0
0
0
1
0
1
0
1
1
1
1
0
1
0
1
0
1
1
0
0
0
0
0
0
1
0
0
0
0
0
1
1
1
1
1
0
1
0
1
1
0
1
1
1
0
0
1
1
0
0
1
1
0
0
1
1
0
1
1
1
1
0
0
1
1
1
1
1
1
0
0
0
0
1
1
0
1
0
1
0
1
0
0
0
1
1
0
0
0
1
1
0
1
1
0
1
1
0
0
0
1
0
1
1
0
0
1
0
1
0
0
1
1
1
0
1
0
1
1
0
1
0
0
0
1
1
1
0
0
1
1
0
1
1
1
0
1
0
0
1
1
1
0
1
0
1
1
0
1
1
0
1
0
1
0
1
1
0
1
0
0
1
1
1
0
0
0
0
1
0
0
0
0
0
1
1
1
0
0
0
0
1
1
1
1
0
0
0
1
1
0
1
0
0
1
1
0
0
0
0
1
1
1
0
0
1
0
0
1
0
1
1
1
1
1
1
0
0
1
1
0
1
0
1
1
1
1
0
0
0
0
0
0
1
0
0
1
0
0
1
1
1
0
0
1
1
0
1
1
1
0
0
1
1
0
0
1
0
1
1
1
1
0
0
0
0
0
1
0
0
1
1
0
0
0
1
1
0
1
1
0
0
1
0
1
0
1
0
0
1
0
1
0
0
0
0
0
0
1
1
1
1
1
0
0
1
1
1
1
0
1
0
0
1
1
0
0
1
1
1
1
1
0
1
1
0
0
0
1
1
0
1
0
1
1
0
1
1
0
0
1
0
0
1
1
1
0
1
0
0
0
1
1
1
0
1
0
0
0
1
1
0
0
0
1
1
1
0
0
0
0
0
1
1
0
0
1
0
0
0
1
1
0
1
1
1
1
0
0
0
1
0
1
0
0
1
0
1
1
1
1
0
1
0
0
1
1
1
1
1
0
1
0
0
1
0
1
1
0
0
1
0
0
0
1
1
0
0
1
0
1
0
1
1
0
1
1
1
0
0
1
1
1
1
1


1
0
0
0
1
0
1
0
1
1
0
1
0
0
1
1
0
0
1
0
0
0
1
1
1
0
1
1
1
0
1
1
1
0
0
1
1
0
1
0
1
0
0
1
1
0
1
0
0
1
1
0
1
1
0
0
0
0
1
1
1
0
1
0
0
0
1
0
1
1
1
1
1
1
1
1
0
0
1
0
1
1
0
1
1
1
1
1
1
0
1
0
0
0
1
1
1
1
0
0
1
0
0
1
0
0
0
0
1
1
0
1
0
1
0
0
1
0
1
1
1
1
1
0
0
1
1
1
0
1
0
0
0
1
0
0
0
0
1
0
1
0
0
1
1
1
1
0
1
0
1
0
0
0
0
0
1
1
1
0
1
1
0
0
1
1
0
1
1
0
0
0
0
0
1
1
0
1
0
1
0
0
0
0
0
1
1
0
1
0
0
1
0
0
1
0
1
0
0
1
1
0
1
0
0
1
0
0
1
1
1
1
1
0
1
0
1
0
0
1
0
1
0
1
1
0
1
0
0
0
1
1
0
1
0
1
0
0
1
1
0
1
0
0
1
0
1
1
0
0
1
0
0
1
0
1
0
1
1
1
1
0
1
1
0
0
1
0
1
0
0
1
0
0
1
0
0
0
0
0
0
1
1
1
1
1
1
1
0
1
1
1
0
0
0
1
1
1
1
1
0
0
0
1
0
1
1
0
0
0
0
1
0
1
1
1
0
1
0
0
1
0
0
0
0
0
1
1
1
1
0
0
1
1
1
0
0
0
1
1
0
1
1
0
1
0
1
1
1
1
1
1
0
0
0
1
1
0
1
1
1
1
1
0
1
1
1
0
1
1
0
1
1
1
0
1
0
0
0
1
0
1
1
0
1
1
0
0
0
0
0
1
1
0
1
1
0
0
0
0
0
1
1
0
1
0
0
0
1
1
1
1
1
1
0
0
1
1
0
1
0
0
1
1
0
1
1
0
1
0
0
0
0
0
0
0
0
0
0
1
1
0
0
1
1
1
1
1
0
0
0
0
1
0
0
1
0
0
1
1
1
1
0
0
1
0
1
0
1
0
0
1
1
0
0
0
1
0
0
0
0
0
1
1
1
1
1
1
1
0
0
1
0
0
0
0
0
0
1
0


1
0
0
0
0
1
1
0
0
0
0
1
1
0
0
0
1
1
0
1
0
0
1
0
0
0
0
0
0
1
0
1
0
0
0
0
1
1
0
0
0
0
1
1
1
0
0
0
1
0
1
0
1
0
1
1
0
1
1
0
1
0
0
1
1
0
0
0
1
1
1
0
0
1
0
1
1
0
0
1
0
0
1
0
1
1
0
0
0
0
0
1
1
0
1
1
0
1
0
0
0
0
1
1
1
1
0
1
0
0
1
1
1
1
1
1
0
0
0
0
0
0
1
0
0
1
1
0
1
0
0
1
1
1
0
0
1
0
1
0
0
1
0
1
1
1
0
0
1
1
1
0
0
1
1
1
0
0
1
1
1
1
1
0
0
1
1
1
1
0
0
1
1
0
0
0
1
1
1
1
1
1
1
0
1
0
0
1
0
0
1
1
1
0
1
0
0
1
0
0
0
0
0
1
1
0
0
1
0
0
0
0
0
0
0
1
0
0
1
0
0
1
0
1
0
0
0
0
0
1
1
1
1
1
0
1
1
0
0
1
0
0
1
0
1
0
0
1
0
0
1
0
0
1
1
1
1
0
1
1
1
1
1
0
1
1
0
1
1
0
0
0
1
1
1
0
0
0
0
1
1
1
0
0
1
0
1
1
0
1
0
0
1
1
1
0
0
0
0
0
1
0
0
1
1
1
1
0
1
1
0
1
1
0
1
1
0
0
1
0
1
0
1
1
0
1
0
0
0
0
0
1
1
0
1
1
0
1
0
1
0
1
1
1
1
0
0
0
0
0
1
0
1
0
0
1
0
1
1
1
0
1
0
0
0
1
1
1
1
0
1
1
1
0
0
0
0
1
0
0
1
1
0
1
0
0
0
1
1
1
1
0
0
1
1
0
1
1
1
0
0
1
0
0
0
0
0
0
1
0
1
1
0
0
0
1
1
0
1
1
1
1
1
1
0
0
0
0
0
0
1
1
1
1
1
0
1
0
1
0
1
0
1
1
0
0
1
1
0
0
0
0
1
1
0
1
0
1
0
0
0
1
1
0
0
1
1
1
1
0
0
1
0
1
1
1
1
0
1
1
0
1
1
0
1
0
1
1
1
0
1
1
1
0
0
1
1
0
1
1


# Sampling function for policy $\textrm{greedy}(\pi_{\textrm{random}})$

In [1]:
def get_action_greedy_policy(observation, q_value_average):
    """Return the action (0 or 1) with the larger estimated Q-value.

    Parameters:
        observation: the current observation; it is converted with ``tuple()``
            to build the lookup key.
        q_value_average: mapping from ``(tuple(observation), action)`` to the
            estimated Q-value of that pair.

    Returns:
        A plain Python ``int``, the same type ``get_action_random`` returns.
        For example, Q-values ``[2, 5]`` for actions ``(0, 1)`` give ``1``.
        Ties resolve to action 0 because ``np.argmax`` returns the first maximum.
        If an estimate is missing for either action, the choice is delegated
        to ``get_action_random``.
    """
    state = tuple(observation)
    try:
        q_values = [q_value_average[(state, action)] for action in (0, 1)]
    except KeyError:
        # No estimate for at least one action in this state: fall back to random
        return get_action_random(observation)
    # np.argmax yields a NumPy integer; convert so both policies return int
    return int(np.argmax(q_values))

# Go through `10000` episodes using *greedy policy* ($\pi_{\textrm{greedy}}$ or $\textrm{greedy}(\pi)$) and calculate values of states (need this data to compare the two policies)

In [18]:
# State-value estimator filled from episodes generated by the greedy policy
value_greedy_policy = Value(gamma=gamma)

for num_episode in range(num_episodes):
    observation = wrapped_env.reset()
    episode_history = []
    step_count = 0
    done = False
    while not done:
        # Act greedily with respect to the Q-values estimated for the random policy
        action = get_action_greedy_policy(
            observation, q_value_random_policy.q_value_average
        )
        if not step_count:
            # Show the first action of each episode
            print(action)
        next_observation, reward, done, _ = wrapped_env.step(action)
        # Record the observation the action was taken in, not the one it led to
        episode_history.append(
            dict(observation=observation, reward=reward, action=action)
        )
        observation = next_observation
        step_count += 1
    value_greedy_policy.update(episode_history)
wrapped_env.close()

0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0


0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0


0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0


# Which states can we compare?

- We don't have good value estimates for all states, only for those that were visited many times.

In [11]:
# Most-visited state under the random policy, as a (state, visit count) pair
max(value_random_policy.visit_number.items(), key=lambda state_count: state_count[1])

((0.0, 0.01, 0.15, 0.0), 10000)

In [12]:
# Least-visited state under the random policy, as a (state, visit count) pair
min(value_random_policy.visit_number.items(), key=lambda state_count: state_count[1])

((0.29949134350161566,
  1.3635969384704734,
  -0.1964374636086723,
  -1.8074266406273143),
 1)

# Compare $\pi_{\textrm{random}}$ and $\textrm{greedy}(\pi_{\textrm{random}})$

In [19]:
# States visited under both policies, kept in the random policy's dict order
common_states = [
    visited_state
    for visited_state in value_random_policy.visit_number
    if visited_state in value_greedy_policy.visit_number
]

In [20]:
# Rank shared states by the smaller of their two visit counts, largest first,
# so the states at the front have many samples under both policies.
sorted_common_states = list(common_states)
sorted_common_states.sort(
    key=lambda state: min(
        value_random_policy.visit_number[state],
        value_greedy_policy.visit_number[state],
    ),
    reverse=True,
)

In [21]:
# Print both value estimates for the first three states of sorted_common_states
for state in sorted_common_states[:3]:
    print(f"Value of {state} given random policy is {value_random_policy.value_average[state]}")
    print(f"Value of {state} given greedy policy is {value_greedy_policy.value_average[state]}")

Value of (0.0, 10.0, 0.0, 0.0) given random policy is 9.353292424894313
Value of (0.0, 10.0, 0.0, 0.0) given greedy policy is 9.73315833440939
Value of (0.2, 9.804878048780488, 0.0, 0.2926829268292683) given random policy is 8.88388194129123
Value of (0.2, 9.804878048780488, 0.0, 0.2926829268292683) given greedy policy is 9.192798246746513
Value of (0.3960975609756098, 10.0, 0.005853658536585366, 0.0) given random policy is 8.341690057353452
Value of (0.3960975609756098, 10.0, 0.005853658536585366, 0.0) given greedy policy is 8.623998154469287
