# Value Function Approximation

In [None]:
# Test if you can run this block
import numpy as np
import random
import matplotlib.pyplot as plt

## Review NumPy Inner Product
Vectors: $\mathbf{x}, \mathbf{w}$

Inner product: $\mathbf{x}^\top \mathbf{w}$

Let's create two NumPy vectors, `x = [1, 2]` and `w = [3, 4]`.

In [None]:
x = np.array([1, 2])
w = np.array([3, 4])

# Task: Compute the inner product of x and w
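For reference, here is a minimal sketch of three equivalent ways to compute $\mathbf{x}^\top \mathbf{w}$ in NumPy: `np.dot`, the `@` operator, and an explicit elementwise product followed by a sum. The vectors are the same `x` and `w` as above.

```python
import numpy as np

x = np.array([1, 2])
w = np.array([3, 4])

# Three equivalent ways to compute the inner product x^T w = 1*3 + 2*4
ip_dot = np.dot(x, w)   # function form
ip_matmul = x @ w       # matrix-multiplication operator
ip_sum = np.sum(x * w)  # elementwise product, then sum

print(ip_dot, ip_matmul, ip_sum)  # 11 11 11
```

The `@` form is usually the most readable in code that mixes vectors and matrices.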

## Representing States with Feature Vectors
First, let's decide how to represent states with features.

Define a function $\mathbf{x}$ that maps each state to its feature vector:

$$\mathbf{x}(S_1) = [1,1]$$
$$\mathbf{x}(S_2) = [2,1]$$
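A minimal sketch of this feature map as a Python function, assuming states are identified by the names `"S1"` and `"S2"`. `FEATURES` is a hypothetical lookup table; in practice $\mathbf{x}$ would usually compute features from the raw state rather than look them up.

```python
import numpy as np

# Hypothetical lookup table from state names to feature vectors x(S)
FEATURES = {
    "S1": np.array([1, 1]),
    "S2": np.array([2, 1]),
}

def x(state):
    """Return the feature vector x(state) for a named state."""
    return FEATURES[state]

print(x("S1"), x("S2"))  # [1 1] [2 1]
```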

## Approximating Value Function
Now, instead of storing the value function $v$ in a table, let's parameterize it with a weight vector $\mathbf{w}$:

$$\hat{v}(S,\mathbf{w}) = \mathbf{x}(S)^\top \mathbf{w} = \sum_j x_j(S)\, w_j$$
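To see that the sum form and the inner product are the same computation, here is a small sketch. `v_hat` is a hypothetical helper that spells out $\sum_j x_j(S)\, w_j$ as a loop, and the weight vector is an arbitrary example value.

```python
import numpy as np

def v_hat(features, w):
    """Linear value estimate: sum_j x_j(S) * w_j, written as an explicit loop."""
    total = 0.0
    for x_j, w_j in zip(features, w):
        total += x_j * w_j
    return total

x_s2 = np.array([2, 1])   # x(S_2)
w = np.array([0.5, 3.0])  # arbitrary example weights

# The explicit sum and the inner product agree
print(v_hat(x_s2, w), x_s2 @ w)  # 4.0 4.0
```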

For example, if $\mathbf{w} = [1,1]$, the values of states $S_1$ and $S_2$ are:

In [None]:
xs1 = np.array([1, 1])
xs2 = np.array([2, 1])

# Initialize w
w = np.array([1, 1])

# Compute the value of each state, v(S, w) = x(S)^T w, for the given w
def show_all_values(w):
    val_s1 = float(xs1 @ w)
    val_s2 = float(xs2 @ w)
    return [val_s1, val_s2]

print(show_all_values(w))  # [2.0, 3.0]

So changing $\mathbf{w}$ changes the value function.

But how do we choose $\mathbf{w}$ so that the values are accurate for all states?

We want to approximate the true value function $v$. Assume for now that a black-box algorithm has given us the true values, for example $v(S_1) = 4$ and $v(S_2) = 6$.

----------------------------

You can now see the benefits of function approximation over the tabular method:
1. Visiting one state (e.g. $S_1$) also updates the values of other states (e.g. $S_2$), because the states share the weights $\mathbf{w}$.
2. It needs far less memory. Imagine 1 million states like $S_1$: the tabular method must store 1 million entries $(s, v(s))$, whereas this function approximator stores only two numbers, $\mathbf{w} = (w_1, w_2)$.

----------------------------
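A small sketch of the first benefit: one update that only visits $S_1$ (using its true value 4 as the target and a step size of 0.02, as in the training loop later in this notebook) also moves the value of $S_2$.

```python
import numpy as np

xs1 = np.array([1, 1])  # x(S1)
xs2 = np.array([2, 1])  # x(S2)
w = np.array([1.0, 1.0])
alpha = 0.02

print("before:", xs1 @ w, xs2 @ w)

# One update that only visits S1, with its true value 4 as the target
true_vs1 = 4
w = w + alpha * (true_vs1 - xs1 @ w) * xs1

# v(S2) has changed too, because S1 and S2 share the weights w
print("after: ", xs1 @ w, xs2 @ w)
```

A table would have left the entry for $S_2$ untouched after this update.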

In [None]:
# Manually try a few different values of w and observe how the value function changes
for w in [np.array([1, 1]), np.array([1, 2]), np.array([2, 2])]:
    print(w, show_all_values(w))

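Trying a few values shows that $\mathbf{w}^* = [2,2]$ is the best parameter: it reproduces $v$ exactly, since $\mathbf{x}(S_1)^\top \mathbf{w}^* = 1\cdot 2 + 1\cdot 2 = 4$ and $\mathbf{x}(S_2)^\top \mathbf{w}^* = 2\cdot 2 + 1\cdot 2 = 6$.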
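In this toy setting there are as many states as weights and the two feature vectors are linearly independent, so $\mathbf{w}^*$ can also be found by solving the linear system $\mathbf{x}(S_1)^\top\mathbf{w} = 4$, $\mathbf{x}(S_2)^\top\mathbf{w} = 6$ directly. A sketch with `np.linalg.solve`; this shortcut does not carry over to problems with more states than weights, or to the realistic case where the true values are unknown.

```python
import numpy as np

# Rows are the feature vectors x(S1) and x(S2)
X = np.array([[1.0, 1.0],
              [2.0, 1.0]])
v_true = np.array([4.0, 6.0])  # v(S1), v(S2)

# Solve X w = v for the exact weights
w_star = np.linalg.solve(X, v_true)
print(w_star)
```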

Is there a systematic way to find such a $\mathbf{w}^*$? The best $\mathbf{w}^*$ should minimize the loss function
$$J(\mathbf{w})=\mathbb{E}\left[\left(v(S)-\mathbf{x}(S)^{\top} \mathbf{w}\right)^{2}\right],$$
where the expectation is over the states $S$ the agent visits. This is because, in the ideal case, $v(S) = \mathbf{x}(S)^{\top} \mathbf{w}^*$ for every state, so $J(\mathbf{w}^*) = 0$.
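A sketch that evaluates $J(\mathbf{w})$ for the two-state example, assuming the states are visited uniformly (each with probability 1/2), which is the same assumption the training loop below makes. Under that assumption the expectation is just an average over the two states; `J` and `p` are names introduced here for illustration.

```python
import numpy as np

X = np.array([[1.0, 1.0],   # x(S1)
              [2.0, 1.0]])  # x(S2)
v_true = np.array([4.0, 6.0])
p = np.array([0.5, 0.5])    # assumed state-visit distribution

def J(w):
    """Expected squared error: sum_s p(s) * (v(s) - x(s)^T w)^2."""
    errors = v_true - X @ w
    return float(np.sum(p * errors ** 2))

print(J(np.array([1.0, 1.0])))  # 6.5
print(J(np.array([2.0, 2.0])))  # 0.0
```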

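Finding $\mathbf{w}^*$ is therefore an ordinary minimization of $J(\mathbf{w})$, so we can use gradient descent. Writing $v_{\pi}$ for the true value function $v$ (of the policy $\pi$ being evaluated) and $\hat{v}(S, \mathbf{w}) = \mathbf{x}(S)^{\top}\mathbf{w}$ for our approximation:

$$\begin{aligned} \mathbf{w} &\gets \mathbf{w}-\frac{1}{2} \alpha \nabla_{\mathbf{w}} J(\mathbf{w}) \\ &= \mathbf{w} + \alpha\, \mathbb{E}\left[\left(v_{\pi}(S)-\hat{v}(S, \mathbf{w})\right) \nabla_{\mathbf{w}} \hat{v}(S, \mathbf{w})\right] \end{aligned}$$

Stochastic gradient descent (SGD) drops the expectation and updates with the single state $S$ that was just visited:

$$\mathbf{w} \gets \mathbf{w} + \alpha \left(v_{\pi}(S)-\hat{v}(S, \mathbf{w})\right) \nabla_{\mathbf{w}} \hat{v}(S, \mathbf{w})$$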

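Since $\hat{v}(S, \mathbf{w}) = \mathbf{x}(S)^{\top} \mathbf{w}$ is linear in $\mathbf{w}$, $$\nabla_{\mathbf{w}} \hat{v}(S, \mathbf{w}) = \mathbf{x}(S),$$ so the update becomes $\mathbf{w} \gets \mathbf{w} + \alpha \left(v_{\pi}(S) - \mathbf{x}(S)^{\top}\mathbf{w}\right)\mathbf{x}(S)$.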
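To check the gradient numerically, here is a sketch that compares a central finite-difference estimate of $\nabla_{\mathbf{w}} J$ with the closed form $-2\,\mathbb{E}\left[(v(S)-\mathbf{x}(S)^\top\mathbf{w})\,\mathbf{x}(S)\right]$, again assuming uniform state visits. The difference step `eps` is an arbitrary choice.

```python
import numpy as np

X = np.array([[1.0, 1.0], [2.0, 1.0]])  # rows: x(S1), x(S2)
v_true = np.array([4.0, 6.0])
p = np.array([0.5, 0.5])  # assumed uniform state distribution

def J(w):
    return np.sum(p * (v_true - X @ w) ** 2)

def grad_J_analytic(w):
    # -2 * E[(v(S) - x(S)^T w) x(S)]
    errors = v_true - X @ w
    return -2.0 * (p * errors) @ X

def grad_J_numeric(w, eps=1e-6):
    # Central differences, one coordinate at a time
    g = np.zeros_like(w)
    for j in range(len(w)):
        e = np.zeros_like(w)
        e[j] = eps
        g[j] = (J(w + e) - J(w - e)) / (2 * eps)
    return g

w = np.array([1.0, 1.0])
print(grad_J_analytic(w), grad_J_numeric(w))
```

Because $J$ is quadratic in $\mathbf{w}$, the two estimates should agree up to floating-point rounding.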

In [7]:
def plotJs(Js, show_limit = 80):
  # Plot the loss for (at most) the first `show_limit` iterations
  shown = Js[:show_limit]
  plt.plot(range(len(shown)), shown, label='J')
  plt.ylabel('Loss Function J')
  plt.xlabel('Iteration Round')
  plt.legend()
  plt.show()

In [None]:
num_iter = 4000

# step size
alpha = 0.02

# assume an oracle gave us these values
# in practice, we don't have the oracle. The details will be discussed later.
true_vs1 = 4
true_vs2 = 6

# initialization
w = np.array([1,1])

Js = []

for i in range(num_iter):
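  # we assume that an RL agent visits each state (s1, s2) uniformly at random.
  # this is not true in real scenarios, since the agent follows its own exploration strategy
  xs, true_vs = random.choice([(xs1, true_vs1), (xs2, true_vs2)])  # xs1, xs2 defined in an earlier cell

  # compute the value of the visited state
  vs = xs @ w

  # update w with the SGD rule
  w = w + alpha * (true_vs - vs) * xs

  # compute the loss (true - current)^2, a single-sample estimate of J
  J = pow(true_vs - vs, 2)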

  # record Js
  Js.append(J)

  # Print every 100 rounds
  if i % 100 == 0:
    print(f'w: {w}, J: {J}')

print(f'Final w: {w}')
plotJs(Js)

However, in practice we do not know the true value function. (That is exactly why we started learning RL in the first place.)

What we do have access to are sequences of experience $s, a, r, s^\prime, \dots$, and we already know how to estimate the true $v$ from many such experiences with MC, TD, and related methods.

With MC, we use the return $G_t$ of an episode, simply replacing the $v_{\pi}(S)$ term in the SGD update with $G_t$:

$$\mathbf{w} \gets \mathbf{w} + \alpha\left(G_t-\hat{v}(S_t, \mathbf{w})\right) \nabla_{\mathbf{w}} \hat{v}(S_t, \mathbf{w})$$
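A sketch of computing the MC return $G_t = R_{t+1} + \gamma G_{t+1}$ for every step of one episode by walking the episode backwards, the same pattern the batch MC cell below uses. The `(state, action, reward)` tuple layout mirrors that cell, and `discounted_returns` is a hypothetical helper name.

```python
def discounted_returns(episode, gamma):
    """Return [G_0, G_1, ...] for an episode of (state, action, reward) tuples."""
    returns = []
    G = 0.0
    for _state, _action, reward in reversed(episode):
        G = float(reward) + gamma * G  # G_t = R_{t+1} + gamma * G_{t+1}
        returns.append(G)
    returns.reverse()  # back to time order
    return returns

# Two-step episode: S1, then S2, then the episode terminates
episode = [("s1", "a1", -1.4), ("s2", "a1", 6.0)]
print(discounted_returns(episode, gamma=0.9))
```

Walking backwards computes every $G_t$ in a single pass instead of re-summing the discounted tail for each step.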

Concretely, in the training loop above, instead of
```
  w = w + alpha * (true_vs - vs) * xs

  J = pow(true_vs - vs, 2)

```
we will use
```
  w = w + alpha * (Gt - vs) * xs

  J = pow(Gt - vs, 2)
```

Because each return $G_t$ is an unbiased sample of the true value $v_{\pi}(S_t)$, accumulating many such updates should still make $\hat{v}$ approximate the target.



In [None]:
# Assume that we only have two states, with feature vectors x(S1) = (1, 1) and x(S2) = (1, 2)
states = {'s1': np.array([1, 1]), 's2': np.array([1, 2])}


# Let's think about a batch MC scenario given the following episodes of (state, action, reward) steps
episodes = [
            [('s1', 'a1', '3.95'),],
            [('s2', 'a2', '4'),],
            [('s1', 'a2', '3.9'),],
            [('s1', 'a1', '-1.4'), ('s2', 'a1', '6')],
            [('s1', 'a2', '4.1'),],
            [('s1', 'a3', '-1.2'), ('s2', 'a1', '5.8')],
            [('s1', 'a2', '-1.6'), ('s2', 'a2', '6.1')],
            [('s2', 'a3', '5.9'),],
            [('s1', 'a4', '4'),],
            [('s2', 'a4', '6'),]]

gamma = 0.9
S_IDX = 0
R_IDX = 2
alpha = 0.02

num_iter = 100_000

# initialization
w = np.array([0,0])

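for episode_idx in range(num_iter):
  # randomly pick one episode (replay buffer)
  epi = random.choice(episodes)

  Gt = 0.0
  # walk the episode backwards so that G_t = r + gamma * G_{t+1}
  for exp in reversed(epi):
    xs = states[exp[S_IDX]]

    # Compute the return (rewards are stored as strings above, so convert them)
    Gt = float(exp[R_IDX]) + gamma * Gt

    # Update w toward the MC target Gt
    vs = xs @ w
    w = w + alpha * (Gt - vs) * xs

  # Print every 10,000 episodes (Gt and vs here belong to the episode's first step)
  if episode_idx % 10_000 == 0:
    J = pow(Gt - vs, 2)
    print(f'r{episode_idx} - w: {w}, J: {J}')

print(f'Final w: {w}')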
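Since every episode is sampled uniformly and each of its steps triggers one update, the batch MC updates roughly minimize the average squared error over all observed $(\mathbf{x}(S_t), G_t)$ pairs. As a sanity check, here is a sketch that collects those pairs and solves the least-squares problem directly with `np.linalg.lstsq`. The episode data is copied from the cell above with the rewards written as numbers; with a constant step size, the SGD weights are expected to fluctuate around this solution rather than match it exactly.

```python
import numpy as np

states = {"s1": np.array([1.0, 1.0]), "s2": np.array([1.0, 2.0])}
episodes = [
    [("s1", "a1", 3.95)],
    [("s2", "a2", 4.0)],
    [("s1", "a2", 3.9)],
    [("s1", "a1", -1.4), ("s2", "a1", 6.0)],
    [("s1", "a2", 4.1)],
    [("s1", "a3", -1.2), ("s2", "a1", 5.8)],
    [("s1", "a2", -1.6), ("s2", "a2", 6.1)],
    [("s2", "a3", 5.9)],
    [("s1", "a4", 4.0)],
    [("s2", "a4", 6.0)],
]
gamma = 0.9

# Collect one (feature vector, return) pair per visited step
features, returns = [], []
for epi in episodes:
    G = 0.0
    for s, _a, r in reversed(epi):
        G = r + gamma * G
        features.append(states[s])
        returns.append(G)

X = np.array(features)
G_all = np.array(returns)

# Least-squares weights minimizing sum_t (G_t - x(S_t)^T w)^2
w_ls, *_ = np.linalg.lstsq(X, G_all, rcond=None)
print("least-squares w:", w_ls)
print("v(s1), v(s2):", states["s1"] @ w_ls, states["s2"] @ w_ls)
```

With only two linearly independent feature vectors, the least-squares fit makes each state's value equal to the average of its observed returns, so the single `s2` episode with return 4 pulls $\hat{v}(S_2)$ below the other `s2` returns of about 6.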