# Off-policy divergence: Baird's MDP

This notebook empirically demonstrates `off-policy` divergence using [Baird's MDP]().

> See `exercise 11.3` in [Sutton & Barto's *Reinforcement Learning: An Introduction*](http://incompleteideas.net/book/RLbook2020.pdf)

<img src="https://lcalem.github.io/imgs/sutton/bairds.png" 
     alt="Sutton &amp; Barto summary chap 11 - Off-policy methods for approximation ..."
     width="650"
/>

More generally, off-policy divergence can occur when a method combines all three elements of the [Deadly Triad](https://arxiv.org/pdf/1812.02648v1.pdf):

 - **Function approximation**: Generalizing across states with a parameterized value function (e.g., linear function approximation or ANNs)
 - **Bootstrapping**: Using update targets that include existing estimates (as in dynamic programming or TD methods) rather than relying exclusively on actual rewards and complete returns (as in MC methods)
 - **Off-policy training**: Training on a distribution of transitions other than that produced by the target policy
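
As a rough illustration of how the three ingredients meet in a single update, here is a minimal sketch of one off-policy semi-gradient TD(0) step for a linear value function, following the general form given in Sutton & Barto, Section 11.1. The helper name `td0_off_policy_update` and its arguments (`x`, `x_next`, `rho`, `alpha`, `gamma`) are assumptions made for this illustration and are not part of the notebook's own code.

```python
import numpy as np


def td0_off_policy_update(w, x, x_next, reward, rho, alpha=0.01, gamma=0.99):
    """One semi-gradient TD(0) step for a linear value function (hypothetical helper)."""
    v = w @ x  # function approximation: v(s) = w . x(s)
    v_next = w @ x_next  # bootstrapping: the target reuses the current estimate of v(s')
    td_error = reward + gamma * v_next - v
    # off-policy training: rho = pi(a|s) / b(a|s) reweights a sample drawn from b
    return w + alpha * rho * td_error * x


# Toy numbers chosen only to make the arithmetic easy to follow.
w = np.array([1.0, 2.0])
w_new = td0_off_policy_update(
    w,
    x=np.array([1.0, 0.0]),
    x_next=np.array([0.0, 1.0]),
    reward=0.0,
    rho=1.0,
    alpha=0.1,
    gamma=1.0,
)
print(w_new)
```

Only the weights attached to the features of the current state move, but because features are shared across states, that change also shifts the bootstrapped targets of other states.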

## Parameters

In [13]:
n_states = 7  # six upper states plus one lower state
n_actions = 2  # 0 = dashed, 1 = solid
disccount_factor = 0.99  # discount factor (gamma)

In [14]:
import numpy as np


class BairdMDP:
    def __init__(self, n_states: int, n_actions: int):
        self.n_states = n_states
        self.n_actions = n_actions  # dashed & solid
        self.state = None

    def _random_state(self, n):
        # uniform probability over states 0..n-1
        return int(np.random.randint(n))

    def reset(self):
        self.state = self._random_state(self.n_states)
        return self.state, {}  # state, info

    def step(self, action: int):
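        if action < 0 or action > self.n_actions - 1:
            raise ValueError(
                f"Invalid action {action}. Must be between 0 and {self.n_actions - 1}"
            )

        if action == 0:
            # dashed: jump uniformly to one of the upper states
            self.state = self._random_state(self.n_states - 1)
        elif action == 1:
            # solid: always move to the last (lower) state
            self.state = self.n_states - 1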

        # next_state, reward, terminated, truncated, info
        return self.state, 0, False, False, {}
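
In Sutton & Barto's version of this counterexample, the behavior policy picks the dashed action with probability 6/7 and the solid action with probability 1/7, while the target policy always takes the solid action. Below is a minimal sketch of those two policies and the resulting importance sampling ratio, assuming the action encoding used in `step` above (0 = dashed, 1 = solid). The names `behavior_probs`, `target_probs`, `sample_behavior_action` and `importance_ratio` are hypothetical and only serve this illustration.

```python
import numpy as np

DASHED, SOLID = 0, 1  # assumed to match the encoding in BairdMDP.step

behavior_probs = np.array([6 / 7, 1 / 7])  # b(dashed), b(solid), same in every state
target_probs = np.array([0.0, 1.0])  # pi always takes the solid action


def sample_behavior_action(rng):
    """Draw an action from the behavior policy b."""
    return int(rng.choice(2, p=behavior_probs))


def importance_ratio(action):
    """rho = pi(a|s) / b(a|s); state-independent in this MDP."""
    return target_probs[action] / behavior_probs[action]


rng = np.random.default_rng(0)
a = sample_behavior_action(rng)
print(a, importance_ratio(a))
```

With these probabilities the ratio is 0 for dashed transitions and 7 for solid ones, so only solid transitions contribute to an importance-weighted update.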

## Estimator

In [16]:
class Estimator:
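    def __init__(self, n_states: int):
        self.n = n_states
        # One weight per state plus a shared one (8 weights for 7 states).
        # All-zero weights are a fixed point when every reward is 0; Sutton &
        # Barto start from w = (1, 1, 1, 1, 1, 1, 10, 1).
        self.weights = np.zeros(self.n + 1)
        # Row s holds the feature vector x(s); shape (n_states, n_states + 1).
        self.features = np.zeros((self.n, self.n + 1))
        upper = np.arange(self.n - 1)
        self.features[upper, upper] = 2  # upper states: 2 * w_s ...
        self.features[upper, -1] = 1  # ... + shared last weight
        self.features[-1, -2] = 1  # lower state: its own weight ...
        self.features[-1, -1] = 2  # ... + 2 * shared last weight

    def __call__(self, s: int):
        if s < 0 or s > self.n - 1:
            raise ValueError(
                f"State '{s}' out of range. Must be within 0 and {self.n - 1}"
            )
        # v(s) = w . x(s)
        return np.dot(self.features[s], self.weights)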

In [17]:
env = BairdMDP(n_states, n_actions)
V = Estimator(n_states)

# s, _ = env.reset()
# print(s)

# for i in range(10):
#     a = np.random.choice([0, 1])
#     next_s, _, _, _, _ = env.step(a)
#     print(f"S:{s} x A:{a} -> S':{next_s}")
#     s = next_s