# Implementing sampling functions for any environment with discrete actions

## Background: environments which have more than 2 actions

Let's say an environment has 3 actions: `0`, `1` and `2`. An example of such an environment is `MountainCar-v0`, which you saw in Chapter 1.

In this case, the original `get_action_greedy_policy()` (from the video) won't work anymore. This is because it hardcodes the action choices.

```
def get_action_greedy_policy(observation, q_value_average):
    """Sampling function for greedy policy
    """
    try:
        q_values = np.array([q_value_average[(tuple(observation), action)] for action in (0, 1)])    # hardcodes action choices
    except KeyError:
        return get_action_random(observation)    # hardcodes action choices
    return np.argmax(q_values)
```

The `get_action_random()` function also hardcodes the action choices.

```
def get_action_random(observation):
    """Sampling function for random policy
    """
    if random.random() < 0.5:
        return 0    # hardcoding action choices
    return 1    # hardcoding action choices
```

To avoid this hardcoding, we can use `env.action_space.n`, which returns the number of actions allowed in the environment. Run the cells below to see how this works.

In [1]:
# env.action_space.n returns 2 for CartPole-v0 because there are two possible actions
import gym 

cartpole_env = gym.make("CartPole-v0")
print(cartpole_env.action_space.n)

2


In [2]:
# env.action_space.n returns 3 for MountainCar-v0 because there are three possible actions
mountaincar_env = gym.make("MountainCar-v0")
print(mountaincar_env.action_space.n)

3


The more general sampling function that works for both `CartPole-v0` and `MountainCar-v0` would do the following.

1. Accept `env` as an argument.
2. Use `env.action_space.n` and `env.action_space.sample()` to avoid hardcoding the action choices. 
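The same two steps also apply to the random policy. Here is a minimal sketch, not code from the video: `get_action_random_general()` is a hypothetical name, and `FakeDiscreteSpace` and `FakeEnv` are stand-ins I made up so the snippet runs without `gym` installed. With a real environment you would pass the `gym` env instead.

```python
import random


class FakeDiscreteSpace:
    """Hypothetical stand-in for a discrete gym action space (illustration only)."""

    def __init__(self, n):
        self.n = n    # number of allowed actions

    def sample(self):
        # Pick one of the actions 0, 1, ..., n - 1 uniformly at random
        return random.randrange(self.n)


class FakeEnv:
    """Hypothetical stand-in for an environment with n_actions discrete actions."""

    def __init__(self, n_actions):
        self.action_space = FakeDiscreteSpace(n_actions)


def get_action_random_general(env, observation):
    """Sampling function for random policy (hypothetical general version)
    """
    return env.action_space.sample()    # avoids hardcoding action choices


three_action_env = FakeEnv(3)
actions = [get_action_random_general(three_action_env, None) for _ in range(1000)]
print(sorted(set(actions)))    # the distinct actions that were sampled
```

Because the function only asks the environment for a sample, it never needs to know how many actions there are.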

I have rewritten `get_action_greedy_policy()` below to implement this change. I have called it `get_action_greedy_policy_general()`.

Study it carefully and then run the cell to load the function into memory.

In [3]:
import numpy as np

def get_action_greedy_policy_general(env, observation, q_value_average):
    """Sampling function for greedy policy
    """
    try:
        q_values = np.array([q_value_average[(tuple(observation), action)] for action in range(env.action_space.n)])    # avoids hardcoding action choices
    except KeyError:
        return env.action_space.sample()    # avoids hardcoding action choices
    return np.argmax(q_values)

Notice that I am still using `np.argmax()`. When several actions share the max Q-value, `np.argmax()` always returns the first of them, so ties are not broken randomly.
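To see how `np.argmax()` treats ties, here is a small standalone sketch. The arrays are made-up Q-values, used only for illustration.

```python
import numpy as np

# Actions 1 and 2 are tied for the max value
q_values_tied = np.array([-5, -3, -3])
first_max_index = np.argmax(q_values_tied)
print(first_max_index)    # 1, the first index where the max occurs

# All three actions are tied
q_values_all_tied = np.array([2, 2, 2])
print(np.argmax(q_values_all_tied))    # 0, again the first index
```

The result is deterministic: calling `np.argmax()` again on the same array gives the same index.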

## When there are more actions, we can have more complicated ties!

For example, look at the fictitious Q-value dictionary (for a random policy) for `MountainCar-v0` below.

Run the cell to load the dictionary into memory.

In [4]:
q_value_average_with_ties_complicated = {((0.1, 0.01), 0): -5,    # action 0 doesn't correspond to the max Q-value
                                         ((0.1, 0.01), 1): -3,    # action 1 and 2 are tied for the max Q-value
                                         ((0.1, 0.01), 2): -3,    
                                         ((0.2, 0.02), 0): -7,    # the state np.array([0.2, 0.02]) does not have a tie
                                         ((0.2, 0.02), 1): -10,
                                         ((0.2, 0.02), 2): -15
                                         }

For the state `np.array([0.1, 0.01])`, actions `1` and `2` are tied for the max Q-value, while action `0` has a lower Q-value. In this case, we should break the tie by randomly choosing between `1` and `2`. We should never choose `0`.
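As a hint, here is one possible way to list the tied actions for that state. This is a sketch under assumptions: `n_actions` is hardcoded to `3` here (in practice it would come from `env.action_space.n`), the dictionary is a shortened copy of the one above, and ties are detected with exact equality, which works for these values but may not suit noisy floating point averages.

```python
# Shortened copy of the fictitious Q-value dictionary, tied state only
q_value_average = {((0.1, 0.01), 0): -5,
                   ((0.1, 0.01), 1): -3,
                   ((0.1, 0.01), 2): -3,
                   }

state = (0.1, 0.01)
n_actions = 3    # assumption: would be env.action_space.n in practice

# Q-value of every action in this state, in action order
q_values = [q_value_average[(state, action)] for action in range(n_actions)]

# Every action whose Q-value equals the max is part of the tie
best_q_value = max(q_values)
tied_actions = [action for action, q_value in enumerate(q_values) if q_value == best_q_value]
print(tied_actions)    # [1, 2]
```

Choosing randomly among `tied_actions` is left for the exercise.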

However, the more general sampling function that I wrote, `get_action_greedy_policy_general()`, still uses `np.argmax()`. So for this state it will always return `1`, the first index with the max Q-value, and never `2`.

Verify by running the cell below.

In [5]:
for _ in range(10):
    print(get_action_greedy_policy_general(mountaincar_env, np.array([0.1, 0.01]), q_value_average_with_ties_complicated))

1
1
1
1
1
1
1
1
1
1


## Implement a greedy sampling function with tie breaking that works for any environment with discrete actions (no matter how many)

- In the last exercise, you wrote `get_action_greedy_policy_random_tie_break()` taking inspiration from `get_action_greedy_policy()` (from the video).
- This would work in the `CartPole-v0` environment.
- Now, write a function `get_action_greedy_policy_general_random_tie_break()` that would work for `CartPole-v0`, `MountainCar-v0` or any environment with discrete actions (no matter how many).
- You can take inspiration from `get_action_greedy_policy_general()` that I already wrote and modify it to break ties properly.

Ready? Write your code below.

In [None]:
def get_action_greedy_policy_general_random_tie_break(env, observation, q_value_average):
    # Implement the greedy policy with random tie breaking for any env with discrete actions

## Check if your implementation works as expected

Run the cell below. It will run `get_action_greedy_policy_general_random_tie_break()` 10 times on the tied state `np.array([0.1, 0.01])` in `MountainCar-v0`.

If your implementation is correct, it will choose `1` some of the time and `2` the rest of the time. It should never choose `0`.

In [None]:
for _ in range(10):
    print(get_action_greedy_policy_general_random_tie_break(mountaincar_env, 
                                                            np.array([0.1, 0.01]), 
                                                            q_value_average_with_ties_complicated
                                                            )
         )

## Did it work? If yes, then congrats! Now we can do policy improvement not just in `CartPole-v0`, but in any environment with discrete actions!