# Value function as a discounted reward sum: the goodness of a state under a given policy

# You are vegetarian

- Action choices:
    1. eat plant-based food
    2. eat animal-based food
    
- Given any state $s$, you choose action number 1.
- $\pi(1 | s) = 1, \pi(2 | s) = 0$
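
A minimal sketch of what this deterministic policy looks like in code. The dictionary encoding and the names `vegetarian_policy` and `sample_action` are assumptions made for illustration only; they do not appear in the rest of the notebook.

```python
import random

# Hypothetical encoding of the vegetarian policy: action number -> probability.
# Action 1 = eat plant-based food, action 2 = eat animal-based food.
vegetarian_policy = {1: 1.0, 2: 0.0}

def sample_action(policy, state=None):
    """Draw one action from a policy given as {action: probability}.

    The state is ignored because this policy picks the same action in every state.
    """
    actions = list(policy.keys())
    probabilities = list(policy.values())
    return random.choices(actions, weights=probabilities, k=1)[0]

# Sampling many times only ever yields action 1, because pi(2 | s) = 0.
samples = [sample_action(vegetarian_policy) for _ in range(1000)]
print(set(samples))
```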

# You want to change the policy (maybe out of curiosity), but not by much. What is the smallest change that you can make?

- For one step (once in your life), and one step only, you are going to try animal-based food.
- For every time step after that, you are going to continue following your vegetarian policy.

# You are following a policy $\pi$ in an MDP. The smallest change to the policy is the following:
- You take a particular action $a$ for a single time step only, ignoring $\pi(a | s)$ at that time step.
- Afterwards, you continue following policy $\pi$ for all subsequent time steps.

# The action value function (which depends on that action $a$ that you took in that single time step and the state $s$ that you were in) measures the result of this slight deviation from the policy
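
One common way to write this down, using the discounted reward sum from the title. The symbols $r_t$ (reward at time step $t$), $s_0$, $a_0$ (the state and action at the first time step) and the discount factor $\gamma$ are notation assumed here; the notebook itself only introduces $\pi$, $s$ and $a$.

$$
Q_{\pi}(s, a) = \mathbb{E}\left[\sum_{t=0}^{\infty} \gamma^{t} r_{t} \,\middle|\, s_0 = s,\ a_0 = a,\ a_t \sim \pi(\cdot \mid s_t) \text{ for } t \ge 1\right]
$$

The first action is fixed to $a$ regardless of $\pi$; every later action is drawn from $\pi$.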

In [1]:
import gym

In [2]:
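class InitMod(gym.Wrapper):
    """Wrapper that makes every episode start from a fixed initial state."""

    def __init__(self, env, initial_state):
        super().__init__(env)
        self.initial_state = initial_state

    def reset(self):
        # Reset the wrapped environment first so its internal bookkeeping is cleared,
        # then overwrite the randomly sampled state with the fixed one.
        self.env.reset()
        self.unwrapped.state = self.initial_state
        return self.unwrapped.state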

In [3]:
# Sampling function for the random policy
import random 

def get_action_random_policy(observation):
    if random.random() < 0.5:
        return 0
    return 1

In [4]:
import numpy as np
pole_right_init_cartpole_env = InitMod(env=gym.make("CartPole-v0"), initial_state=np.array([0, 0.01, 0.15, 0]))

# We are interested in computing a sample of the action value function for state `[0, 0.01, 0.15, 0]` and action `1`: `5.69`

In [5]:
observation = pole_right_init_cartpole_env.reset()
step_count = 0
discounted_reward_sum = 0
gamma = 0.9
while True:
    if step_count == 0:
        # First time step only: force action 1, ignoring the policy.
        action = 1
    else:
        # Every later time step: follow the random policy.
        action = get_action_random_policy(observation)
    next_observation, reward, done, _ = pole_right_init_cartpole_env.step(action)
    observation = next_observation
    # The reward received at time step t is weighted by gamma**t.
    discounted_reward_sum += reward * (gamma**step_count)
    step_count += 1
    if done:
        break
pole_right_init_cartpole_env.close()
print(discounted_reward_sum)

5.6953279000000006


# We are interested in computing a sample of the action value function for state `[0, 0.01, 0.15, 0]` and action `0`: `4.68`

In [6]:
observation = pole_right_init_cartpole_env.reset()
step_count = 0
discounted_reward_sum = 0
gamma = 0.9
while True:
    if step_count == 0:
        # First time step only: force action 0, ignoring the policy.
        action = 0
    else:
        # Every later time step: follow the random policy.
        action = get_action_random_policy(observation)
    next_observation, reward, done, _ = pole_right_init_cartpole_env.step(action)
    observation = next_observation
    # The reward received at time step t is weighted by gamma**t.
    discounted_reward_sum += reward * (gamma**step_count)
    step_count += 1
    if done:
        break
pole_right_init_cartpole_env.close()
print(discounted_reward_sum)

4.68559


| State (s) | Action (a) | Policy ($\pi$) | Sample of $Q_{\pi}(s, a)$ from one episode |
| --- | --- | --- | --- |
| `[0, 0.01, 0.15, 0]` | 1 | random | `5.69` |
| `[0, 0.01, 0.15, 0]` | 0 | random |  `4.68` |
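
The two printed samples can be checked by hand. The sketch below assumes that every time step of the episode returned a reward of 1 (which is what CartPole-v0 is expected to do), so the discounted reward sum reduces to a geometric series. The function names are hypothetical helpers, not part of the notebook.

```python
def discounted_sum_of_unit_rewards(num_steps, gamma):
    """Sum gamma**t for t = 0 .. num_steps - 1, i.e. a reward of 1 at every step."""
    return sum(gamma ** t for t in range(num_steps))

def closed_form(num_steps, gamma):
    """Closed form of the same geometric series: (1 - gamma**n) / (1 - gamma)."""
    return (1 - gamma ** num_steps) / (1 - gamma)

# Under the reward-of-1 assumption, an 8-step episode matches the sample for action 1
# and a 6-step episode matches the sample for action 0.
for num_steps in (8, 6):
    print(num_steps, discounted_sum_of_unit_rewards(num_steps, 0.9), closed_form(num_steps, 0.9))
```

Under this assumption the episode that started with action `1` lasted 8 steps and the one that started with action `0` lasted 6 steps. Each value is a single sample, so the gap between them says little until many episodes are averaged.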

In [11]:
class QValue:
    def __init__(self, gamma, visit_number=None, q_value_average=None):
        """
        visit_number: {(observation, action): 3}
        q_value_average: {(observation, action): 4.5}
        """
        self.gamma = gamma
        if visit_number is None:
            self.visit_number = {}
        else:
            self.visit_number = visit_number
        if q_value_average is None:
            self.q_value_average = {}
        else:
            self.q_value_average = q_value_average
        
    def update(self, episode_history):
        # Walk the episode backwards so the discounted return from each step
        # can be built from the return of the step that follows it.
        backward_reward_sum = 0
        for step in reversed(episode_history):
            backward_reward_sum = (self.gamma * backward_reward_sum) + step["reward"]
            # Arrays are not hashable, so the observation is stored as a tuple.
            key = (tuple(step["observation"]), step["action"])
            visit_number = self.visit_number.get(key, 0)
            if visit_number == 0:
                self.q_value_average[key] = backward_reward_sum
            else:
                # Running mean over every visit to this (observation, action) pair.
                self.q_value_average[key] = ((visit_number * self.q_value_average[key]) + backward_reward_sum) / (visit_number + 1)
            self.visit_number[key] = visit_number + 1
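
To see what `update` computes, here is a small standalone sketch of its two ingredients on a toy episode: the backward recursion for the discounted return and the running mean. The helper names `backward_returns` and `running_mean` and the toy rewards are assumptions for illustration; they are not part of the `QValue` class.

```python
def backward_returns(rewards, gamma):
    """Discounted return from each time step, computed from the last step backwards."""
    returns = []
    running = 0.0
    for reward in reversed(rewards):
        # Return at step t = reward at t + gamma * return at step t + 1.
        running = gamma * running + reward
        returns.append(running)
    returns.reverse()
    return returns

def running_mean(old_mean, count, new_value):
    """Mean of count + 1 values, given the mean of the first count values."""
    return (count * old_mean + new_value) / (count + 1)

# Toy episode with three rewards of 1 and gamma = 0.9.
toy_returns = backward_returns([1.0, 1.0, 1.0], gamma=0.9)
print(toy_returns)

# Averaging two return samples for the same (observation, action) pair.
print(running_mean(old_mean=2.0, count=1, new_value=4.0))
```

The backward pass needs only one sweep over the episode, and the running mean avoids storing every past return.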

In [12]:
q_value_info = QValue(gamma=0.9)
num_episodes = 1000
for num_episode in range(num_episodes):
    episode_history = []
    observation = pole_right_init_cartpole_env.reset()
    while True:
        action = get_action_random_policy(observation)
        next_observation, reward, done, _ = pole_right_init_cartpole_env.step(action)
        # Record the observation the action was taken from, together with the reward it produced.
        episode_history.append({"observation": observation, "reward": reward, "action": action})
        observation = next_observation
        if done:
            break
    # Fold the finished episode into the running averages.
    q_value_info.update(episode_history)
pole_right_init_cartpole_env.close()

In [13]:
state_action_pair = ((0, 0.01, 0.15, 0), 1)
print(q_value_info.q_value_average[state_action_pair])
print(q_value_info.visit_number[state_action_pair])

8.157367163378453
477


In [14]:
state_action_pair = ((0, 0.01, 0.15, 0), 0)
print(q_value_info.q_value_average[state_action_pair])
print(q_value_info.visit_number[state_action_pair])

6.339425112972492
523


| State (s) | Action (a) | Policy ($\pi$) | Estimate of $Q_{\pi}(s, a)$ averaged over 1000 episodes |
| --- | --- | --- | --- |
| `[0, 0.01, 0.15, 0]` | 1 | random | 8.15 |
| `[0, 0.01, 0.15, 0]` | 0 | random | 6.33 |
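
These two estimates can be tied back to the value function of the state. The sketch below assumes the standard relation $V_{\pi}(s) = \sum_a \pi(a | s) Q_{\pi}(s, a)$ and plugs in the averages printed above together with the random policy's probabilities of 0.5 per action. The variable names are hypothetical.

```python
# Random policy: both actions are equally likely in every state.
policy_probabilities = {0: 0.5, 1: 0.5}

# Averages printed by the notebook for state [0, 0.01, 0.15, 0].
q_estimates = {0: 6.339425112972492, 1: 8.157367163378453}

# Weight each action value by the probability that the policy picks that action.
state_value_estimate = sum(
    policy_probabilities[action] * q_estimates[action]
    for action in policy_probabilities
)
print(state_value_estimate)
```

The result is about 7.25. It is an estimate built from sampled episodes, not an exact value.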