## Course FSD311 - Monte Carlo Tree Search and AlphaZero-like algorithms

In this course, we will study the famous Monte Carlo Tree Search (MCTS) algorithm as well as AlphaZero, a more recent improvement that relies on neural networks. Both MCTS and AlphaZero became famous for their results at the game of Go. Notably, the AlphaGo family of programs, from which AlphaZero descends, beat the world's best Go players, a challenge that was thought for decades to be out of reach for AI. Recent versions of the algorithm demonstrated astonishing results: the same AlphaZero algorithm, with the same hyperparameters, can be trained from scratch to beat the best chess software (Stockfish) and to surpass the previous AlphaGo programs at Go, each in a matter of hours of training.

While both algorithms have mostly been used for two-player games, they can be adapted to single-player games, i.e. standard Markov Decision Processes (MDPs). More recently, MuZero, a model-based version of AlphaZero, achieved state-of-the-art results on the Atari benchmark, outperforming standard deep RL algorithms. In this notebook, we will develop a single-player version of both algorithms, MCTS and AlphaZero. We will first make them work on the simple CartPole environment and then, once the algorithm is functional, we will test it on the game FlappyBird.

### First imports and environment setting

<div class="alert alert-block alert-info">
    In this course, we assume that our environments provide <b>get_state</b> and <b>set_state</b> functions that allow restarting the environment from any state. Standard Gym environments do not have this property, which is why we rely on a custom CartPole environment rather than on the standard CartPole-v0 Gym environment.
</div>

In [None]:
import numpy as np
%cd ../env
from cartpole import CartPole

In [None]:
env = CartPole()

# once the environment has been reset, get_state and set_state might be used
obs = env.reset()

# let's play a random action
new_obs, reward, done, info = env.step(env.action_space.sample())

# we can save the state
state = env.get_state()

# play a new random action
new_obs_2, _, _, _ = env.step(env.action_space.sample())

# the observation has changed
print(np.array_equal(new_obs_2, new_obs))

# we reset to the previous state
new_obs_3 = env.set_state(state)

# and check if the observations are equal now
print(np.array_equal(new_obs_3, new_obs))

Let's play some episodes with a random agent and print both the time per episode and the cumulative reward per episode.

In [None]:
# Play N episodes with a random agent, print the mean score per episode, its std 
# as well as mean and std time per episodes
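One possible way to fill the cell above is sketched here. It assumes, as the CartPole cell above does, the old Gym-style API where `reset()` returns an observation and `step()` returns `(obs, reward, done, info)`; the function name `evaluate_random_agent` is our own.

```python
import time

import numpy as np


def evaluate_random_agent(env_creator, num_episodes=10):
    """Play `num_episodes` episodes with uniformly random actions.

    Assumes reset() returns an observation and step() returns
    (obs, reward, done, info), as with the custom CartPole used here.
    """
    env = env_creator()
    scores, durations = [], []
    for _ in range(num_episodes):
        start = time.perf_counter()
        env.reset()
        done, total_reward = False, 0.0
        while not done:
            _, reward, done, _ = env.step(env.action_space.sample())
            total_reward += reward
        durations.append(time.perf_counter() - start)
        scores.append(total_reward)
    scores, durations = np.array(scores), np.array(durations)
    print(f"score per episode: {scores.mean():.1f} +/- {scores.std():.1f}")
    print(f"time per episode: {durations.mean():.4f}s +/- {durations.std():.4f}s")
    return scores, durations
```

For instance, `evaluate_random_agent(lambda: CartPole(), num_episodes=20)` gives a baseline to compare the MCTS agents against.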
    

### The MCTS algorithm

The principle of MCTS is simple: we build a tree in which nodes represent environment states/observations and edges represent actions. Thus, if we start from a node $s_0$ and follow the edge corresponding to an action $a_0$, the node at the end of the edge holds the observation $s_1$ obtained in the environment by applying $a_0$ from $s_0$. We also store in this node the reward $r_1$ that we obtained. The goal is to build the tree so that it explores the state-action space efficiently, in order to find a sequence of actions, starting from the initial state $s_0$, that maximizes the sum of the rewards obtained.

The algorithm operates in two consecutive phases. During the first phase, simulations are run from the root node to explore the state-action space. During the second phase, an action is chosen according to the visit counts of the root node's children. The child obtained by taking the selected action from the root node becomes the new root node. Both phases are then executed again, and so on, until we reach a terminal state or a maximum number of steps specified by the user.

#### MCTS simulations

When an MCTS node is created, we create three arrays: one containing the visit counts of its potential children, one containing the Q-values of its potential children and one containing priors over its potential children. Each array has one entry per action available in the environment, and all are initialized to 0. The Q-value of a node estimates the expected (discounted) sum of rewards obtained from its parent node when the action leading to this node is chosen. The prior of a node is a prior probability that the action leading to this node is chosen from its parent node. In standard MCTS, we do not have access to meaningful priors, so we set them all to 1/(number of actions). However, these priors will play an important role in AlphaZero.

MCTS simulations are usually split into 3 phases:
- node selection
- value and priors estimation
- value backup

<b>Node selection:</b> Starting from a node $s_t$, we select the next node to visit with a criterion that uses the nodes' priors and Q-value estimates. Several different criteria can be defined. When a node is selected, we instantiate the corresponding Python object and store it as one of the children of the current node.

<b>Value and priors estimation:</b> When a node is selected, we estimate priors over its potential children as well as the V-value of this node. As said, in standard MCTS, the priors are always set to one over the number of actions. The V-value corresponds to the expected (discounted) sum of rewards obtained from this node. In standard MCTS, to estimate it, we simply perform one or several random rollouts from this node and compute the mean of the sums of rewards encountered. Note that this estimate has a very high variance, which is good for exploration but also significantly decreases the search speed and accuracy. AlphaZero notably proposes a better alternative to estimate this value. If the node is terminal, i.e. corresponds to the end of an episode, we estimate its value as the final reward obtained.
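A single random rollout estimate can be sketched as follows. It assumes the environment has already been placed in the node's state (for instance with `env.set_state(node.state)`) and that `step()` returns `(obs, reward, done, info)`; the function name and the `max_steps` safety cap are our own additions, not part of the course code.

```python
def random_rollout_value(env, gamma=0.997, max_steps=500):
    """Estimate V(s) with one random rollout from the env's *current* state.

    Returns the discounted sum of rewards collected until the episode ends
    or `max_steps` steps have been played.
    """
    value, discount = 0.0, 1.0
    for _ in range(max_steps):
        _, reward, done, _ = env.step(env.action_space.sample())
        value += discount * reward
        discount *= gamma
        if done:
            break
    return value
```

Averaging several such rollouts lowers the variance at the cost of more environment steps. In standard MCTS the matching priors are simply uniform, e.g. `np.ones((1, n)) / n` with `n = env.action_space.n`.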

<b>Value backup:</b> Once the node V-value $v$ has been estimated, we backpropagate this value in the tree up to the root node to update the Q-values of all the nodes along the path. Note that in the version of the algorithm we propose in this course, we discount the sum of rewards by a factor $\gamma$, so both the V-values and the Q-values estimate a discounted sum of rewards. During the backup, we also update the visit counts $N$ of the nodes.

<div class="alert alert-block alert-info">
We repeat the three phases $n_{simu}$ times, where $n_{simu}$ is a hyperparameter of the algorithm.
</div>

### Value backup expression

We get an estimated V-value $v_t$ at node $s_t$ and would like to update all Q-values and visit counts $N$ along the path up to the root node. To do so, for $k = t \dots 0$, we form a $(t-k)$-step estimate of the cumulative discounted reward, bootstrapping from the value estimate $v_t$:

\begin{equation*}
G^k = \sum_{\tau = 0}^{t-1-k} \gamma^{\tau} r_{k+1+\tau} + \gamma^{t-k} v_t
\end{equation*}

Then, for $k = t \dots 1$, we update the statistics of each edge $(s_{k-1}, a_{k})$ in the simulation path as follows:

\begin{equation*}
\begin{aligned}
&Q(s_{k-1}, a_{k}) := \dfrac{N(s_{k-1}, a_{k}) \cdot Q(s_{k-1}, a_{k}) + G^k}{N(s_{k-1}, a_{k}) + 1}\\
&N(s_{k-1}, a_{k}) := N(s_{k-1}, a_{k}) + 1
\end{aligned}
\end{equation*}
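The two update equations can be transcribed directly. This sketch works on plain Python lists indexed by depth rather than on the `Node` objects used in the course code; the function name and list layout are illustrative assumptions. It uses the recursion $G^{k-1} = r_k + \gamma G^k$, which follows from the definition of $G^k$ and avoids recomputing each sum.

```python
def backup_path(rewards, leaf_value, q_values, visit_counts, gamma):
    """Apply the value-backup equations along one simulation path.

    rewards[k] is r_k for k = 1..t (rewards[0] is unused).
    q_values[k] and visit_counts[k] hold Q(s_{k-1}, a_k) and N(s_{k-1}, a_k)
    for k = 1..t (index 0 unused). Both lists are updated in place.
    Returns G^0.
    """
    t = len(rewards) - 1
    g = leaf_value  # G^t = v_t
    for k in range(t, 0, -1):
        # running mean of the returns observed through edge (s_{k-1}, a_k)
        q_values[k] = (visit_counts[k] * q_values[k] + g) / (visit_counts[k] + 1)
        visit_counts[k] += 1
        # G^{k-1} = r_k + gamma * G^k
        g = rewards[k] + gamma * g
    return g
```

In the `Node` class, the same loop runs while following `parent` pointers, reading $r_k$ from each node's `reward` attribute.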

### Action selection

Once the simulations have been run, we need to select the next root node (which corresponds to the action we will actually play). To do so, we rely on the visit counts of the children of the current root node $s_t$. We want to select the most visited node but also keep some randomness for exploration: this is the classic exploration/exploitation trade-off.

There are two main approaches to this final selection. The most common one is known as the <i>robust child</i>:
in this approach, the children's visit counts, raised to the power $1/\tau$, are normalized into probabilities; we call the resulting vector the <b>tree policy</b> for the state $s_t$.

In other words, we compute the tree policy $\pi^{\text{mcts}}$ as:

\begin{equation*}
\pi^{\text{mcts}}(s_{t}, a) = \frac{N(s_{t}, a)^{1/\tau}}{\sum\limits_{a'} N(s_{t}, a')^{1/\tau}}
\end{equation*}

where $\tau$ is a temperature parameter that can be tuned to balance exploration and exploitation. When $\tau \rightarrow 0$, the tree policy tends to a Dirac distribution centered on the most visited action. When $\tau \rightarrow \infty$, it tends to a uniform distribution over the actions. Here, $N(s, a)$ is the number of visits, during the simulations, of the node obtained by starting at node $s$ and selecting action $a$.

The other final selection method, known as <i>max child</i>, uses a similar process with the children's values instead of their visit counts.

Finally, the action is sampled from the tree policy:
\begin{equation*}
a_t \sim \pi^{\text{mcts}}(s_t, \cdot)
\end{equation*}

<div class="alert alert-block alert-warning">
Be careful: applying this expression directly may overflow when visit counts are large, especially for small $\tau$. To avoid this, we can divide all visit counts by the maximum visit count before applying the formula: the common factor cancels out in the ratio, so the probabilities are unchanged.
</div>
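The tree policy and the overflow trick can be sketched in a few lines of NumPy. This is one possible implementation, not the course solution; the function names are our own, and it assumes at least one child has been visited (which holds once `num_simulations >= 1`).

```python
import numpy as np


def tree_policy_from_visits(visit_counts, temperature=1.0):
    """pi(s_t, a) proportional to N(s_t, a) ** (1 / temperature).

    Counts are first divided by their maximum: the common factor cancels
    in the ratio but keeps the powers from overflowing when tau is small.
    """
    counts = np.asarray(visit_counts, dtype=np.float64)
    counts = counts / counts.max()
    scores = counts ** (1.0 / temperature)
    return scores / scores.sum()


def sample_action(visit_counts, temperature=1.0, rng=None):
    """Sample a_t from the tree policy; returns (action, policy)."""
    rng = np.random.default_rng() if rng is None else rng
    policy = tree_policy_from_visits(visit_counts, temperature)
    return int(rng.choice(len(policy), p=policy)), policy
```

In `MCTS.compute_action`, the visit counts would be `node.child_number_visits` and the temperature `self.params['temperature']`.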

<img src="mcts_recap.png" alt="MCTS principle." title="MCTS" />

### MCTS code

We will now write the MCTS node class that contains the attributes and methods described above. First, you will be asked to write the selection criterion and the backup expressions as you think they should be done. Don't be afraid to try naive expressions. Later, you will be given better criteria.

First, we define a "fake" parent for the tree root node. This node is only a placeholder used to propagate information inside the tree; unlike the other nodes, it does not correspond to a state. It has depth 0, no parent, and simply holds a copy of the environment that is used for the simulations. This node is just a coding "trick".

In [None]:
import collections


class RootParentNode:
    """Placeholder parent of the root node: stores the root's Q-value and visit count."""

    def __init__(self, env):
        self.parent = None
        self.child_q_value = collections.defaultdict(float)
        self.child_number_visits = collections.defaultdict(float)
        self.depth = 0
        self.env = env

Now, let us write the main node class. Please complete the <b>backup</b> and <b>best_action</b> methods.

In [None]:
class Node:
    def __init__(self, action, reward, obs, state, mcts, depth, done, parent=None):

        self.env = parent.env
        self.action = action  # Action used to go to this state
        self.done = done

        self.is_expanded = False
        self.parent = parent
        self.children = {}
        self.depth = depth

        self.action_space_size = self.env.action_space.n
        
        self.child_q_value = np.zeros([self.action_space_size], dtype=np.float32)  # Q
        self.child_priors = np.zeros([self.action_space_size], dtype=np.float32)  # P
        self.child_number_visits = np.zeros([self.action_space_size], dtype=np.float32)  # N

        self.reward = reward
        self.obs = obs
        self.state = state

        self.mcts = mcts

    @property
    def number_visits(self):
        return self.parent.child_number_visits[self.action]

    @number_visits.setter
    def number_visits(self, value):
        self.parent.child_number_visits[self.action] = value

    @property
    def q_value(self):
        return self.parent.child_q_value[self.action]

    @q_value.setter
    def q_value(self, value):
        self.parent.child_q_value[self.action] = value

    # todo: complete this method
    def best_action(self):
        # compute a score based on the children priors and values
        child_score = 
        return np.argmax(child_score)

    def select(self):
        current_node = self
        while current_node.is_expanded:
            best_action = current_node.best_action()
            current_node = current_node.get_child(best_action)
        return current_node

    def expand(self, child_priors):
        self.is_expanded = True
        self.child_priors = child_priors

    def get_child(self, action):
        # create a node for a selected action and attach it to the current node.
        if action not in self.children:

            self.env.set_state(self.state)
            obs, reward, done, info = self.env.step(action)
            next_state = self.env.get_state()

            self.children[action] = Node(
                obs=obs,
                done=done,
                state=next_state,
                action=action,
                depth=self.depth+1,
                parent=self,
                reward=reward,
                mcts=self.mcts,
            )
        return self.children[action]

    # todo: write a backup function
    def backup(self, value):
        # update current node
        self.q_value = 
        self.number_visits += 1
        # update all other nodes up to root node
        current = self.parent
        while current.parent is not None:
            current.q_value =
            current.number_visits += 1
            current = current.parent

Now we can write the MCTS class.

In [None]:
class MCTS:
    def __init__(self, mcts_param):
        self.params = mcts_param

    def compute_priors_and_value(self, node):
        env = node.env
        env.set_state(node.state)
        
        # todo: write the value estimation, we may estimate it with a single random rollout now
        value =

        # todo: compute the priors, the resulting tensors must have a (1, num_actions) shape
        priors = 
        return priors, value

    def compute_action(self, node):
        # Run simulations
        for _ in range(self.params['num_simulations']):
            leaf = node.select()
            if leaf.done:
                value = leaf.reward
            else:
                child_priors, value = self.compute_priors_and_value(leaf)
                leaf.expand(child_priors)
            leaf.backup(value)

        # Compute Tree policy target (TPT): todo: complete the tree policy computation
        tree_policy = 
        
        # Choose action according to tree policy
        action = np.random.choice(np.arange(node.action_space_size), p=tree_policy)
        return action, node.children[action]

Finally, we define an MCTS agent class. 

In [None]:
class MCTSAgent:
    def __init__(self, env_creator, config):
        self.env = env_creator()
        self.env_creator = env_creator
        self.config = config

    def play_episode(self):
        
        obs = self.env.reset()
        env_state = self.env.get_state()

        done = False
        t = 0
        total_reward = 0.0

        mcts = MCTS(self.config)

        root_node = Node(
            state=env_state,
            done=False,
            obs=obs,
            reward=0,
            action=None,
            parent=RootParentNode(env=self.env_creator()),
            mcts=mcts,
            depth=0
        )

        while not done:
            t += 1
            # compute action choice
            action, root_node = mcts.compute_action(root_node)
            # remove the old part of the tree that we won't use anymore
            root_node.parent = RootParentNode(env=self.env_creator())

            # take action
            obs, reward, done, info = self.env.step(action)
            total_reward += reward

        return t, total_reward

### Let's test it

<div class="alert alert-block alert-info">
Test the MCTS agent on CartPole and compare it with a random agent. Also assess the importance of the different hyperparameters.
</div>

In [None]:
env_creator = lambda: CartPole()

mcts_config = {
    "num_simulations": 10,
    "gamma": 0.997,
    "temperature": 1.0
}

agent = MCTSAgent(env_creator, mcts_config)

#todo: assess the agent performance and study the hyperparameters importances

***

## Selection improvement during Simulations 

In this part, we propose a way to deal with the exploration/exploitation dilemma that arises when selecting nodes during the simulations, using the UCT algorithm and one of its variants. But first, we'll walk through some theory to explain how this algorithm came to be.

### The Multi-Arm Bandit problem


<img src="mab1.png" alt="Multi-Arm Bandit illustration (alt text)" title="Multi-Arm Bandit" />

#### Idea of the problem

You stand in front of multiple slot machines whose mechanisms you know nothing about. You have a fixed number of coins. How do you make the most of your coins by just playing the machines?

#### Formal description

A Bernoulli multi-armed bandit can be described as a tuple of ⟨$A$,$R$⟩, where:
* we have $K$ machines with reward probabilities $\{\theta_1, \theta_2, \dots, \theta_K\}$
* At each time step $t$, we take an action $a$ on one slot machine and receive a reward $r$.

* $A$ is a set of actions, each referring to the interaction with one slot machine. The value of action $a$ is the expected reward, $Q(a) = \mathbb{E}[r|a]=\theta$. If action $a_t$ at the time step $t$ is on the $i$-th machine, then $Q(a_t)=\theta_i$

* $R$ is a reward function. In the case of a Bernoulli bandit, we observe a reward $r$ in a stochastic fashion: at time step $t$, $r_t = R(a_t)$ returns reward 1 with probability $Q(a_t)$ and 0 otherwise.

Note that the Multi-Arm Bandit is a simplified version of a Markov decision process, with no state $S$.


The goal of this problem is to maximize the cumulative reward $\sum_{t=1}^{T} r_t$. If we know the optimal action $a^*$ with the best reward, the goal can equivalently be expressed as <b>minimizing the regret</b>, i.e. the loss incurred by not picking the optimal action. We denote by $\theta^*$ the reward probability of the optimal action $a^*$:

\begin{equation*}
\theta^* = Q(a^*) = \max\limits_{a \in A} Q(a) = \max\limits_{1 \leq i \leq K} \theta_i
\end{equation*}

Our loss function $L$ is the total regret we might have by not selecting the optimal action up to the time step $T$:

\begin{equation*}
L_T = \mathbb{E} \left[ \sum_{t=1}^T (\theta^* - Q(a_t)) \right]
\end{equation*}

### Exploration VS Exploitation in the Multi Arm Bandit setting

We could explore randomly and pull levers without any prior: that is an opportunity to try out options and gain knowledge about some actions. However, a fully random search means we may select an action we already know to be bad, which we want to avoid.
On the other hand, if we are really risk-averse, we could explore just a little to try each action and then follow a greedy policy. However, if we are too greedy and never come back to previously tried actions, we could miss actions that are really good (in terms of mean reward) just because we had bad luck the first time we tried them. Here's the classic exploration vs. exploitation dilemma again!

### UCB1, a solution to this dilemma

One solution to this dilemma is "optimism in the face of uncertainty": we favor the exploration of actions with a strong potential to be optimal.

The Upper Confidence Bounds (UCB) algorithm measures this potential by an upper confidence bound of the reward value. We take actions that have a high potential, observe their rewards and use them to refine our upper confidence bound estimate so that it describes the potential of this action with even more precision.

Formally, we denote by $U_t(a)$ the upper confidence bound for action $a$ at trial $t$, and by $\hat{Q_t}(a)$ the mean of the rewards observed for action $a$ after choosing it $N_t(a)$ times. Since our rewards are bounded in $[0, 1]$, we can use [Hoeffding's inequality](http://cs229.stanford.edu/extra-notes/hoeffding.pdf) to write:

\begin{equation*}
\mathbb{P}(Q(a) > \hat{Q_t}(a) + U_t(a)) \leq e^{-2 N_t(a) U_t(a)^2}
\end{equation*}

The right-hand side of the above inequality bounds the probability that the true value of action $a$ exceeds our estimate $\hat{Q_t}(a)$ plus the bound. We want this probability to decrease quickly as the number of trials grows. In UCB1, we choose $e^{-2 N_t(a) U_t(a)^2} = t^{-4}$. Thus we can write our bound as:

\begin{equation*}
U_t(a) = \sqrt{\dfrac{2 \ln{t}}{N_t(a)}}
\end{equation*}
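Solving for $U_t(a)$ (taking the non-negative root) makes this step explicit:

\begin{equation*}
\begin{aligned}
e^{-2 N_t(a) U_t(a)^2} = t^{-4}
&\iff 2 N_t(a) U_t(a)^2 = 4 \ln t \\
&\iff U_t(a) = \sqrt{\dfrac{2 \ln t}{N_t(a)}}
\end{aligned}
\end{equation*}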


We thus have a way to efficiently express the potential reward of an action as $\hat{Q_t}(a) + U_t(a)$. The rightmost term of this expression is an <b>incentive to explore actions whose value we estimate with less confidence</b>. Observe that *not* selecting an action increases this incentive (because $t$ grows while $N_t(a)$ stays fixed), whereas selecting it generally decreases it (because $N_t(a)$ grows).



UCB1 solution to the Multi Arm Bandit is then:

for each trial $t$, with each action $a_i$ being previously tried $N(a_i)$ times :
* select action $a_t$ such that: $a_t = \text{argmax}_{a \in A} \ Q(a) \ + \ \sqrt{\dfrac{2 \ln{t}}{N(a)}}$
* observe reward $r_t(a)$
* use this reward to update expected reward from this action $Q(a_t) := \dfrac{N(a_t) Q(a_t) + r_t(a)}{N(a_t) + 1}$ 
* update number of trials of this action: $N(a_t) := N(a_t) \ + \ 1$
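The UCB1 loop above can be sketched on a simulated Bernoulli bandit. This is an illustrative simulation, not course code: the function name, the `seed` argument and the choice to play each arm once before using the bound (so that $N(a) > 0$ in the denominator) are our own assumptions. It also accumulates the regret $\sum_t (\theta^* - Q(a_t))$ of the chosen actions.

```python
import numpy as np


def ucb1_bernoulli(thetas, horizon, seed=0):
    """Run UCB1 on a Bernoulli bandit with success probabilities `thetas`.

    Returns the pull counts N(a), the empirical means Q(a) and the regret
    sum_t (theta* - theta_{a_t}) of the chosen actions.
    """
    rng = np.random.default_rng(seed)
    thetas = np.asarray(thetas, dtype=np.float64)
    k = len(thetas)
    counts = np.zeros(k)
    q = np.zeros(k)
    regret = 0.0
    for t in range(1, horizon + 1):
        if t <= k:
            a = t - 1  # play every arm once so that N(a) > 0 in the bound
        else:
            a = int(np.argmax(q + np.sqrt(2.0 * np.log(t) / counts)))
        r = float(rng.random() < thetas[a])  # Bernoulli reward
        q[a] = (counts[a] * q[a] + r) / (counts[a] + 1)
        counts[a] += 1
        regret += thetas.max() - thetas[a]
    return counts, q, regret
```

Running it with, say, `thetas=[0.2, 0.5, 0.8]` should show most pulls going to the last arm while the other two are still tried occasionally.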

Now let's see how this strategy translates to MCTS.

## Advanced selection methods for MCTS

### UCT

MCTS faces the same exploration vs. exploitation dilemma, so it is natural that some algorithms borrowed the logic of UCB1 described above and carried it over to trees. This is how the original UCT (Upper Confidence bounds applied to Trees) was designed. Note that proving this algorithm's properties requires further developments, which won't be detailed here.

The idea is to take each node as a separate Multi-Arm Bandit during selection process, with the slight modification that the node's action reward will take discounted future rewards into account - hence the necessity to have a [good backup function](#Value-backup-expression). Starting from root, we'll apply UCB1 to choose an action until we find a non-expanded node. This means we'll have to compute both a Q-value $Q$ and a utility function $U$ for each edge.

Formally, this means that starting from node $s$ from which the set of available actions is $A_s$, we select an edge using:

\begin{equation*}
a = \underset{a \in A_s}{\operatorname{argmax}} \ Q(s,a) +  U(s,a) 
\end{equation*}


Writing the number of times the edge $a$ was chosen from node $s$ as $N(s,a)$ and the number of visits of node $s$ as $N(s)$, UCB1 allows us to write the utility function : 

\begin{equation*}
U(s,a) = 2 \ C_P \ \sqrt{ \frac{\ln N(s)}{N(s,a)}}
\end{equation*}

where $C_P$ is a hyperparameter tuned for each problem (with $C_P = \sqrt{2}/2$, this term is exactly the UCB1 bound).

Observe that each visit of the parent node increases the incentive to visit each of its children, whereas a visit to one of its children decreases the incentive to visit that particular child.
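A vectorized UCT score can be sketched as below. Giving unvisited children an infinite score (so that each child is tried once before the bound applies) is a common convention that we assume here; it is not stated in the text above. The function name is our own, and clamping $N(s)$ to at least 1 only avoids $\ln 0$.

```python
import numpy as np


def uct_scores(child_q, child_n, parent_n, cp=np.sqrt(2) / 2):
    """Q(s, a) + 2 * C_P * sqrt(ln N(s) / N(s, a)) for every action a.

    Unvisited actions (N(s, a) = 0) get an infinite score so that each
    child is tried once before the bound is used.
    """
    child_q = np.asarray(child_q, dtype=np.float64)
    child_n = np.asarray(child_n, dtype=np.float64)
    scores = np.full(child_q.shape, np.inf)
    visited = child_n > 0
    scores[visited] = child_q[visited] + 2.0 * cp * np.sqrt(
        np.log(max(parent_n, 1.0)) / child_n[visited]
    )
    return scores
```

Inside `Node.best_action`, this could be called as `uct_scores(self.child_q_value, self.child_number_visits, self.number_visits, self.mcts.params['cp_coefficient'])` before taking the `argmax`.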

<div class="alert alert-block alert-info">
    Go back to the <b>Node.best_action</b> method and implement the UCT selection method.
    This requires a new hyperparameter $C_P$; however, we don't want you to spend much time, if any, optimizing it.
</div>

<div class="alert alert-block alert-info">
    Run a few trials of this new MCTS in the cell below and compare its results to your first model.
</div>

In [None]:
env_creator = lambda: CartPole()

cp = np.sqrt(2)/2

mcts_config = {
    "num_simulations": 10,
    "gamma": 0.997,
    "temperature": 1.0,
    "cp_coefficient": cp,
}

agent = MCTSAgent(env_creator, mcts_config)

# assess the agent performance with the same simulation budget

### PUCT


Numerous variations of UCT exist. Since our ultimate goal today is AlphaZero, we suggest using the <b>AlphaZero-style PUCT</b>. Compared with UCT, PUCT weights the exploration term of each action by a prior probability over the actions and, in its AlphaZero form, replaces the constant $C_P$ of the UCT formula with an adaptive coefficient.
This coefficient is now called $C(s)$ and depends on the number of visits of the parent node.

We continue denoting:
* the set of available actions from state $s$ : $A_s$
* the number of visits of node $s$: $N(s)$
* the number of visits of edge $a \in A_s$ : $N(s,a)$.

AlphaZero's PUCT also makes use of a prior $P(s,a)$ for each action $a$ available from state $s$. Since we don't have clever priors with MCTS (the tree's "default policy" is random), we set $\forall \ a \ \in \ A_s, \ \ P(s,a) = \frac{1}{|A_s|}$ (this will change with AlphaZero).

The adaptive coefficient of the upper bound $C(s)$ makes use of 2 hyperparameters: $c_1$ and $c_2$. It is written as:

\begin{equation*}
C(s) = \log{ \left( \dfrac{1 + c_2 + N(s)}{c_2} \right) } \ + \ c_1 
\end{equation*}

Note that when $N(s)$ is small compared with $c_2$ (e.g. $N(s) = 0$ with a large $c_2$), $C(s) \approx c_1$, so $c_1$ plays the role of the *initial* upper-bound coefficient. The log term of this definition progressively increases the overall weight of the utility as the parent node is visited, to compensate for priors that penalize some actions and thus to preserve some exploration incentive. The $c_2$ hyperparameter controls the rate at which this increase occurs.

The selection process becomes now: from node $s$, we select an edge using:

\begin{equation*}
a = \underset{a \in A_s}{\operatorname{argmax}} \ Q(s,a) \ + \ U(s,a)
\end{equation*}

where

$$U(s,a) = C(s) \ P(s,a) \ \dfrac{\sqrt{N(s)}}{(1 + N(s,a))} $$

Here is your PUCT (AlphaZero version)!
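The PUCT score, with $C(s)$ as defined above, can be sketched in NumPy as follows. The function name is our own; the default `c1` and `c2` values are the ones used in the configuration cell of this section.

```python
import numpy as np


def puct_scores(child_q, child_priors, child_n, parent_n, c1=1.25, c2=19652):
    """Q(s, a) + C(s) * P(s, a) * sqrt(N(s)) / (1 + N(s, a)) for every action a."""
    child_q = np.asarray(child_q, dtype=np.float64)
    child_priors = np.asarray(child_priors, dtype=np.float64)
    child_n = np.asarray(child_n, dtype=np.float64)
    # adaptive coefficient C(s) = log((1 + c2 + N(s)) / c2) + c1
    c_s = np.log((1.0 + c2 + parent_n) / c2) + c1
    u = c_s * child_priors * np.sqrt(parent_n) / (1.0 + child_n)
    return child_q + u
```

Unlike the UCT sketch, no special case is needed for unvisited children: the $1 + N(s,a)$ denominator is never zero. In `Node.best_action`, `c1` and `c2` would come from `self.mcts.params['c1_coefficient']` and `self.mcts.params['c2_coefficient']`.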

<div class="alert alert-block alert-info">
    Go back once again to the <b>Node.best_action</b> method and implement the PUCT selection method.
    We provide experimentally validated hyperparameters $c_1$, $c_2$, however the variance of this algorithm can still be high.
</div>

<div class="alert alert-block alert-info">
    Run a few trials of this new MCTS. It's ready to be turned into an AlphaZero agent!
</div>

In [None]:
env_creator = lambda: CartPole()

mcts_config = {
    "num_simulations": 10,
    "gamma": 0.997,
    "temperature": 1.0,
    "c1_coefficient": 1.25,
    "c2_coefficient": 19652
}

agent = MCTSAgent(env_creator, mcts_config)

# assess the agent performance and study the hyperparameters importances

### A final improvement: Q-values online normalization

The values of the coefficients $c_1$ and $c_2$ have a severe impact on search performance in practice. Furthermore, their optimal values depend on the scale of the $Q$-values, which may vary from one environment to another, as well as during training for AlphaZero. Thus, a simple trick consists in normalizing all $Q$-values online, using the maximum and minimum $Q$-values observed in the tree.

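This can easily be done by adding two methods to the MCTS class (with `self.max_q_value = -np.inf` and `self.min_q_value = np.inf` initialized in its `__init__`):

```python
    def update_q_value_stats(self, q_value):
        self.max_q_value = max(self.max_q_value, q_value)
        self.min_q_value = min(self.min_q_value, q_value)

    def normalize_q_value(self, q_value):
        if self.max_q_value > self.min_q_value:
            return (q_value - self.min_q_value) / (self.max_q_value - self.min_q_value)
        else:
            return q_value
```

And by updating the node $Q$-value setter:

```python
    @q_value.setter
    def q_value(self, value):
        self.parent.child_q_value[self.action] = value
        self.mcts.update_q_value_stats(value)
```

The normalized values, e.g. `self.mcts.normalize_q_value(self.child_q_value)`, are then used in place of the raw children $Q$-values in <b>Node.best_action</b>.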

You can add these changes to your MCTS and see how they improve the results! This will also be a key ingredient for AlphaZero.

***

# AlphaZero

Several algorithms of this family have been developed to play the game of Go: first AlphaGo, then AlphaGo Zero, and finally AlphaZero, which was developed to play chess, shogi and Go with the same hyperparameters. More recently, a new algorithm, MuZero, was developed; it also handles single-player games (such as Atari) and relies on model-based reinforcement learning to remove the need to reset the environment to a given state. In this notebook, we develop a single-player version of the AlphaZero algorithm that exploits the tricks introduced in MuZero, but without using model-based RL.

The main idea behind AlphaZero is to speed up the search with a neural network that returns, for a given observation, a value and priors over the available actions. The algorithm is trained iteratively: we run searches to produce data to train the neural network; after each training phase, the network's predictions become more accurate, which speeds up the next tree searches, and so on.

<div class="alert alert-block alert-success">    
    This is the main strength of this approach: <b>the MCTS search improves across time</b>. The first searches naively try all solutions. During training, the algorithm learns to "remember" bad solutions and stops testing them, resulting in a better and faster exploration, and thus in better results.
</div>

<img src="az_training.png" alt="AlphaZero training process." title="AZ training" />

### AlphaZero neural network

The AlphaZero neural network takes a batch of environment observations and returns a batch of priors and values. In this notebook, we assume that the observations are simple vectors, so we use a simple dense neural network with two shared dense layers of num_hidden neurons each. This value, as well as the other hyperparameters, is passed to the class through a config dict argument. After both shared layers, we add two heads: one outputs logits over the actions (these logits are later transformed into probabilities through a softmax) and the other outputs the value.

<div class="alert alert-block alert-success">
We note $f_{\theta}$ the neural net, $p_t$ and $v_t$ the priors and value estimated by the NN for a given observation $s_t$, thus: $p_t, v_t = f_{\theta}(s_t)$
</div>

<div class="alert alert-block alert-warning">
To compute the value in this notebook, we rely on two tricks. First, we use a scaling function $h(x) = \text{sign}(x)\left( \sqrt{|x|+1} - 1 \right) + 0.001x$ to scale the values returned by the MCTS that are used to train the value head of the network. Second, we assume that the values always lie between two thresholds $v_{\text{min}}$ and $v_{\text{max}}$, so we discretize the support between $v_{\text{min}}$ and $v_{\text{max}}$ into integers and learn a distribution instead of a scalar value. Under this transformation, a scalar is represented as a linear combination of its two adjacent supports, such that the original value can be recovered by $x = x_{\text{low}} \cdot p_{\text{low}} + x_{\text{high}} \cdot p_{\text{high}}$. As an example, a target of 3.7 would be represented as a weight of 0.3 on the support for 3 and a weight of 0.7 on the support for 4. The value head of the network thus outputs $v_{\text{max}} - v_{\text{min}} + 1$ logits. To compute the final value, these logits are transformed into probabilities, the expectation over the support is computed, and the result is passed through $h^{-1}$.
</div>
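<b>utils.py</b> already provides these functions; the sketch below is an independent NumPy version written only to make the arithmetic concrete, not a copy of utils.py. It assumes the MuZero form $h(x) = \text{sign}(x)(\sqrt{|x|+1} - 1) + \varepsilon x$ with $\varepsilon = 0.001$, inverts it in closed form by solving the quadratic in $\sqrt{|x|+1}$, and clips values to $[v_{\text{min}}, v_{\text{max}}]$ before projecting (the clipping is our assumption). All function names are hypothetical.

```python
import numpy as np

EPS = 0.001


def h(x):
    """Value scaling h(x) = sign(x) (sqrt(|x| + 1) - 1) + eps * x."""
    x = np.asarray(x, dtype=np.float64)
    return np.sign(x) * (np.sqrt(np.abs(x) + 1.0) - 1.0) + EPS * x


def h_inv(y):
    """Closed-form inverse of h (quadratic in u = sqrt(|x| + 1))."""
    y = np.asarray(y, dtype=np.float64)
    u = (np.sqrt(1.0 + 4.0 * EPS * (np.abs(y) + 1.0 + EPS)) - 1.0) / (2.0 * EPS)
    return np.sign(y) * (u ** 2 - 1.0)


def scalar_to_support(x, v_min, v_max):
    """Project a batch of scalars onto the integer support v_min..v_max."""
    x = np.clip(np.asarray(x, dtype=np.float64).reshape(-1), v_min, v_max)
    low = np.floor(x).astype(np.int64)
    p_high = x - low
    probs = np.zeros((x.size, v_max - v_min + 1))
    rows = np.arange(x.size)
    probs[rows, low - v_min] = 1.0 - p_high
    high = np.minimum(low + 1, v_max)  # at x == v_max, p_high is 0
    probs[rows, high - v_min] += p_high
    return probs


def support_to_scalar(probs, v_min, v_max):
    """Expectation of a batch of distributions over the support."""
    support = np.arange(v_min, v_max + 1, dtype=np.float64)
    return probs @ support


def value_from_logits(logits, v_min, v_max):
    """Softmax over the support, expectation, then h^{-1}."""
    z = logits - logits.max(axis=-1, keepdims=True)
    probs = np.exp(z) / np.exp(z).sum(axis=-1, keepdims=True)
    return h_inv(support_to_scalar(probs, v_min, v_max))
```

Training targets would then be built as `scalar_to_support(h(z), v_min, v_max)`, while predictions go through `value_from_logits`.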

<div class="alert alert-block alert-info">
The first trick ensures that the values always lie in a controllable range, whatever the environment. The second trick allows fitting the value head with a cross-entropy loss instead of a standard mean-squared error. Furthermore, it seems to be easier to train a neural network to output a distribution rather than a single scalar. One reason might be the increased number of degrees of freedom. Another might be that the outputs of the last layer stay in a controlled range: each neuron only needs to output small values, rather than a single neuron having to output both small and large values when the environment rewards are large.
</div>

The functions that project values onto the support and compute a value from a distribution over the support, as well as the functions implementing $h$ and $h^{-1}$, are given in <b>utils.py</b>. Note that all these functions take and return batches, so that they can be used easily with deep learning libraries.

In [None]:
import torch
import torch.nn as nn
from alphazero.utils import scaling_func_inv, compute_value_from_support_torch


class AlphaZeroModel(nn.Module):
    def __init__(self, obs_space, action_space, model_config):
        # initialize nn.Module first, before registering any sub-module
        super().__init__()

        self.num_actions = action_space.n
        self.obs_dim = obs_space.shape[0]
        self.value_min_val = model_config['value_support_min_val']
        self.value_max_val = model_config['value_support_max_val']
        self.value_support_size = self.value_max_val - self.value_min_val + 1

        num_hidden = model_config['num_hidden']
        self.decision_function_shared = # complete it
        self.policy_head = # complete it
        self.value_head = # complete it

    def forward(self, obs):
        x = self.decision_function_shared(obs)
        logits = self.policy_head(x)
        values_support_logits = self.value_head(x)
        return logits, values_support_logits

    def compute_priors_and_value(self, obs):
        with torch.no_grad():
            obs = torch.from_numpy(obs).float()
            #TODO: complete here
            
            # this function takes a batch of numpy observations,
            # computes a batch of priors (as probabilities, not raw logits)
            # and a batch of values (scalar and unscaled with h^{-1})
            
            # both priors and values returned are numpy as well
            return prior.cpu().numpy(), value.cpu().numpy()
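The computation expected in <b>compute_priors_and_value</b> can be sketched as follows, in NumPy rather than PyTorch so that it runs on its own: a softmax turns the policy logits into priors, the value logits are turned into a distribution whose expectation over the support gives a scaled value, and $h^{-1}$ unscales it. `priors_and_value_from_logits`, `softmax` and `h_inv` are hypothetical names; in the notebook the helpers imported from <b>utils.py</b> would presumably play these roles, but their exact signatures are not assumed here.

```python
import numpy as np

EPS = 0.001  # same coefficient as in h


def softmax(logits):
    """Row-wise softmax, shifted by the max for numerical stability."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)


def h_inv(y):
    """Inverse of h(x) = sign(x) * (sqrt(|x| + 1) - 1) + EPS * x."""
    s = (np.sqrt(1 + 4 * EPS * (np.abs(y) + 1 + EPS)) - 1) / (2 * EPS)
    return np.sign(y) * (s ** 2 - 1)


def priors_and_value_from_logits(policy_logits, value_logits, v_min, v_max):
    """Map a batch of network outputs to (priors, unscaled scalar values)."""
    priors = softmax(policy_logits)                  # (batch, num_actions)
    support = np.arange(v_min, v_max + 1)            # (support_size,)
    scaled_values = softmax(value_logits) @ support  # expectation, (batch,)
    return priors, h_inv(scaled_values)
```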

### The AlphaZero Agent

AlphaZero training happens in two phases that are repeated until convergence. During the first phase, the agent plays episodes using its MCTS. The experience collected during the episodes is stored in what we call a <b>Replay Buffer</b>. All transitions are stored independently, a transition being a tuple $\left( s_t, a_t, \pi_t, z_t \right)$, where $\pi_t$ is the tree policy and $z_t$ is a value target, see below. During the second phase, we sample batches of transitions from the replay buffer and train the neural network.

<div class="alert alert-block alert-success">
Having a replay buffer allows us to sample batches of weakly correlated data, i.e. the transitions inside a batch come from different episodes and different timesteps. Batches are thus closer to i.i.d., which improves gradient descent. The replay buffer is a major trick used in plenty of deep RL algorithms. It was popularized by <a href=https://www.cs.toronto.edu/~vmnih/docs/dqn.pdf >DQN</a>.
</div>

<div class="alert alert-block alert-info">
    The replay buffer code is given in <b>replay_buffer.py</b>.
</div>
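To make the mechanism concrete, here is a minimal sketch of a FIFO replay buffer. It is not the code of <b>replay_buffer.py</b>: the class name `SimpleReplayBuffer` is hypothetical, and its interface is only assumed from how the agent below uses the buffer (`add` receives the dict of per-timestep lists of one episode, `sample` returns a dict of numpy arrays whose first dimension is the batch size). Sampling is uniform with replacement.

```python
import collections

import numpy as np


class SimpleReplayBuffer:
    """Hypothetical FIFO replay buffer storing transitions independently."""

    def __init__(self, buffer_size):
        # the oldest transitions are dropped once buffer_size is reached
        self.storage = collections.deque(maxlen=buffer_size)

    def __len__(self):
        return len(self.storage)

    def add(self, transitions):
        # transitions: dict of equal-length lists, one entry per timestep
        keys = list(transitions.keys())
        num_steps = len(transitions[keys[0]])
        for t in range(num_steps):
            self.storage.append({k: transitions[k][t] for k in keys})

    def sample(self, batch_size):
        # uniform sampling with replacement, mixing episodes and timesteps
        indices = np.random.randint(len(self.storage), size=batch_size)
        samples = [self.storage[int(i)] for i in indices]
        return {k: np.stack([np.asarray(s[k]) for s in samples]) for k in samples[0]}
```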

We consider a batch of transitions $\left( s_t, a_t, \pi_t, z_t \right)$ and would like to train our network to improve its predictions. In AlphaZero, the MCTS can be viewed as a <b>policy improvement operator</b>: the tree policy $\pi_t$ and the value target $z_t$ are expected to be closer to the optimal policy and optimal value than the network outputs. Thus, to improve our network we simply solve a supervised learning problem: we minimize a distance between the priors and values returned by the network and the tree policies and value targets.

### Computing the value target

To train the neural network, we need both a tree policy and a value target. We saw above how to compute tree policies, so we only need to define a value target. A first and simple way to define a value target is to take the sum (or the discounted sum in our case) of the rewards encountered from timestep $t$ until the end of the episode. This is the strategy used in the original AlphaZero. This strategy is commonly called Monte Carlo estimation. In this notebook, we will use an n-step target as described in the [MuZero paper](https://arxiv.org/abs/1911.08265). N-step bootstrapping is a standard technique in RL that relies on the Bellman equation. It has a lower variance than Monte Carlo estimation, at the cost of some bias introduced by the bootstrapped value estimate.

First we introduce the tree value estimate $\nu_t$. For a node $s_t$, we define $\nu_t$ as:

\begin{equation*}
\nu_t = \underset{a \in A_s}{\operatorname{\mathbb{E}}} \ Q(s_t, a)
\end{equation*}

where $Q(s_t, a)$ is the Q-value of the child of $s_t$ reached with action $a$, as updated during the backups. $\nu_t$ is a good estimate of the value $V_t$ at node $s_t$. Then, we compute the value target $z_t$ at node $s_t$ as:

\begin{equation*}
z_t = r_{t+1} + \gamma r_{t+2} + \dots + \gamma^{n-1} r_{t+n} + \gamma^n \nu_{t+n}
\end{equation*}

This target $z_t$ might be computed at the end of each episode. We will do it in a <b>postprocess_transitions</b> function.

<div class="alert alert-block alert-info">
    The function that computes the value targets from a trajectory of rewards and tree values is given in <b>utils.py</b>.
</div>
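A minimal NumPy sketch of this n-step target, assuming that `rewards[t]` holds $r_{t+1}$ (the reward received after acting at step $t$), that `tree_values[t]` holds $\nu_t$, and that nothing is bootstrapped past the end of the episode (the episode end is treated as terminal). The name `compute_n_step_targets` is hypothetical and need not match the helper in <b>utils.py</b>.

```python
import numpy as np


def compute_n_step_targets(rewards, tree_values, gamma, n):
    """z_t = r_{t+1} + ... + gamma^(n-1) r_{t+n} + gamma^n nu_{t+n}."""
    rewards = np.asarray(rewards, dtype=np.float64)
    tree_values = np.asarray(tree_values, dtype=np.float64)
    T = len(rewards)
    targets = np.zeros(T)
    for t in range(T):
        # near the end of the episode, fewer than n rewards remain
        horizon = min(n, T - t)
        discounts = gamma ** np.arange(horizon)
        z = np.sum(discounts * rewards[t:t + horizon])
        # bootstrap with the tree value only if step t + n exists
        if t + n < T:
            z += gamma ** n * tree_values[t + n]
        targets[t] = z
    return targets
```

For example, with rewards `[1, 1, 1, 1]`, tree values `[10, 20, 30, 40]`, $\gamma = 0.5$ and $n = 2$, the first target is $1 + 0.5 \cdot 1 + 0.25 \cdot 30 = 9$.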

### AlphaZero loss function

AlphaZero loss is simple: it is the sum of two losses, a policy loss and a value loss. For a given state, the policy loss is a distance between the network prior predictions and the tree policy, while the value loss is a distance between the network value prediction and the value target. For the priors, a common choice is the cross-entropy. For the values, a mean squared error is commonly used. However, in this notebook we use a slightly different strategy.
Remember that the value head of the network outputs a distribution over a support that represents a scaled value. Thus, we scale the value targets with function $h$ and project the scaled targets onto the support. Then, we minimize the cross-entropy between the predicted distribution over the support and the projected target distribution.

<div class="alert alert-block alert-info">
    The function to project values on the support as well as the scaling functions are given in <b>utils.py</b>.
</div>

<div class="alert alert-block alert-info">
    The cross-entropy loss is also given in <b>utils.py</b>. Be careful: the network predictions must be raw logits, while the targets must be probabilities (non-negative and summing to one, e.g. logits that have been passed through a softmax).
</div>
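A NumPy sketch of this loss, with hypothetical names and independent from <b>utils.py</b>. The cross-entropy takes raw logits and target probabilities, the value targets are assumed to be already scaled by $h$ (as done in <b>postprocess_transitions</b>), and the value loss is weighted by a coefficient, as in the agent's config.

```python
import numpy as np


def log_softmax(logits):
    """Numerically stable row-wise log-softmax."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))


def cross_entropy(logits, target_probs):
    """Batch mean of -sum_k target_k * log softmax(logits)_k."""
    return float(np.mean(-np.sum(target_probs * log_softmax(logits), axis=1)))


def project_on_support(x, v_min, v_max):
    """Spread each scalar over its two adjacent integer support points."""
    x = np.clip(np.asarray(x, dtype=np.float64), v_min, v_max)
    size = v_max - v_min + 1
    low = np.floor(x).astype(int) - v_min
    high = np.minimum(low + 1, size - 1)
    p_high = x - np.floor(x)
    probs = np.zeros((x.shape[0], size))
    rows = np.arange(x.shape[0])
    np.add.at(probs, (rows, low), 1.0 - p_high)
    np.add.at(probs, (rows, high), p_high)
    return probs


def alphazero_loss(policy_logits, value_logits, tree_policies,
                   scaled_value_targets, v_min, v_max, value_loss_coefficient):
    """Policy cross-entropy plus weighted value cross-entropy on the support."""
    policy_loss = cross_entropy(policy_logits, tree_policies)
    target_probs = project_on_support(scaled_value_targets, v_min, v_max)
    value_loss = value_loss_coefficient * cross_entropy(value_logits, target_probs)
    return policy_loss + value_loss, policy_loss, value_loss
```

In the PyTorch agent, the same computation would apply `torch.nn.functional.log_softmax` to the network outputs so that gradients flow through the loss.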

### AlphaZero MCTS

The AlphaZero MCTS is very similar to the one presented above. The only difference is how the values and priors are estimated. Instead of computing the value with random rollouts and taking the priors as uniform over the actions, we make a forward pass through the AlphaZero neural network. Once the network is trained, its value and prior estimates are much more accurate than random rollouts and uniform priors, which drastically speeds up the tree search.

One potential issue is that, at the beginning of training, the neural network may output prior distributions that are not uniform enough, thus slowing down or even preventing exploration. To prevent this phenomenon and boost exploration, we can add noise to the priors returned by the neural network:

\begin{equation*}
p_t := (1 - \epsilon)p_t + \epsilon \eta(\alpha)
\end{equation*}

where $\epsilon$ and $\alpha$ are two hyperparameters and $\eta(\alpha)$ is a sample from a symmetric Dirichlet distribution with concentration parameter $\alpha$. We will take $\epsilon = 0.2$ and $\alpha$ equal to one over the number of actions.
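A minimal sketch of this mixing in NumPy. `add_dirichlet_noise` is a hypothetical helper; it draws the noise with `numpy.random.Generator.dirichlet` using a symmetric concentration vector, and keeps the shape of `priors`, which matters because the model returns priors with a leading batch dimension of 1. Since both the priors and the noise sum to one, the mixture is still a probability distribution.

```python
import numpy as np


def add_dirichlet_noise(priors, alpha, epsilon, rng=None):
    """Return (1 - epsilon) * priors + epsilon * eta, eta ~ Dir(alpha, ..., alpha)."""
    rng = np.random.default_rng() if rng is None else rng
    noise = rng.dirichlet([alpha] * priors.size)  # one weight per action, sums to 1
    return (1 - epsilon) * priors + epsilon * noise.reshape(priors.shape)
```

In the MCTS class below, `alpha` would correspond to `self.config['dir_noise']` and `epsilon` to `self.config['dir_epsilon']`.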

<div class="alert alert-block alert-info">
The node class introduced above does not need to change. We only need to change the <b>compute_priors_and_value</b> method of the MCTS class, and to compute the tree value estimate and return the tree policy at the end of the <b>compute_action</b> function.
</div>

In [None]:
class AlphaZeroMCTS(MCTS):
    def __init__(self, mcts_param, model):
        MCTS.__init__(self, mcts_param)
        self.model = model
        self.config = mcts_param
        
    def add_noise_to_priors(self, priors):
        noise = np.random.dirichlet([self.config['dir_noise']] * priors.size)
        # complete this method, assuming that epsilon is stored in self.config['dir_epsilon']
        return priors
    
    def compute_priors_and_value(self, node):
        obs = np.expand_dims(node.obs, axis=0)  # add batch size of 1
        priors, value = self.model.compute_priors_and_value(obs)
        if self.config['add_dirichlet_noise']:
            priors = self.add_noise_to_priors(priors)
        return priors, value
    
    def compute_action(self, node):
        # Run simulations
        for _ in range(self.params['num_simulations']):
            leaf = node.select()
            if leaf.done:
                value = leaf.reward
            else:
                child_priors, value = self.compute_priors_and_value(leaf)
                leaf.expand(child_priors)
            leaf.backup(value)

        # Compute Tree policy target (TPT)
        # TODO: complete the tree policy computation
        tree_policy = None

        # Compute Tree value
        # TODO: complete the tree value computation (nu_t)
        tree_value = None
        
        # Choose action according to tree policy
        action = np.random.choice(np.arange(node.action_space_size), p=tree_policy)
        
        return tree_policy, action, tree_value, node.children[action]
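One common way to fill in the two TODOs above, sketched under assumptions: the tree policy is taken proportional to the children visit counts raised to the power $1/\tau$, with $\tau$ the `temperature` of the MCTS config and $\tau = 0$ meaning greedy, as in AlphaZero; $\nu_t$ is taken as the expectation of the children Q-values under that tree policy. The arrays `child_visits` and `child_q_values` are hypothetical inputs; how they are read from the node depends on the node class defined earlier in the notebook.

```python
import numpy as np


def tree_policy_from_visits(child_visits, temperature):
    """pi(a) proportional to N(s, a)^(1 / temperature); greedy if temperature == 0."""
    visits = np.asarray(child_visits, dtype=np.float64)
    if temperature == 0:
        policy = np.zeros_like(visits)
        policy[np.argmax(visits)] = 1.0
        return policy
    # divide by the max first to avoid overflow when 1 / temperature is large
    scaled = (visits / visits.max()) ** (1.0 / temperature)
    return scaled / scaled.sum()


def tree_value_estimate(tree_policy, child_q_values):
    """nu_t as the expectation of the children Q-values under the tree policy."""
    return float(np.dot(tree_policy, child_q_values))
```

This assumes at least one child has been visited, which holds as soon as a few simulations have been run from the root.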

### AlphaZero agent code

Your turn to write the AlphaZero agent code :)

In [None]:
from alphazero.utils import scaling_func


class AlphaZero:
    def __init__(self, env_creator, config):
        self.env_creator = env_creator
        self.env = env_creator()
        self.config = config
        self.mcts_config = config['mcts_config']
        self.mcts_config.update(config)
        self.model = AlphaZeroModel(self.env.observation_space, self.env.action_space, config['model_config'])
        self.replay_buffer = ReplayBuffer(config['buffer_size'])
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.config['lr'])
        self.total_num_steps = 0

    def play_episode(self):
        transitions = {
            "observations": [],
            "actions": [],
            "rewards": [],
            "tree_policies": [],
            "tree_values": [],
        }
        # TODO: complete this method
        
        # play one episode with mcts and store the observations, actions, rewards,
        # tree policies and tree values at each timestep in the dictionary transitions

        return transitions

    def postprocess_transitions(self, transitions):
        
        # TODO: complete this method
        
        # transitions dict flows directly into this function when an episode has been played
        # compute the value targets from the rewards and tree values
        # the parameter gamma is in self.config['gamma'] and the parameter n is in
        # self.config['n_steps']
        value_targets = None  # TODO: complete it
        
        # we scale the value targets using function h
        value_targets = scaling_func(value_targets, mode='numpy')
        
        # we transform the np array into a list of numpy arrays, one per transition
        transitions['value_targets'] = np.split(value_targets, len(value_targets))

        # we don't store useless arrays in the buffer
        del transitions['rewards']
        del transitions['tree_values']

        return transitions

    def compute_loss(self, batch):
        
        # TODO: complete this method
        
        # compute AlphaZero loss in this function
        # batch is a dict of transitions with keys: 'observations', 'tree_policies', 'value_targets'
        # each key maps to a numpy array whose first dimension equals the batch size
        
        # first we get supports parameters
        v_support_minv, v_support_maxv = self.model.value_min_val, self.model.value_max_val
        
        # transform numpy vectors to torch tensors
        observations = torch.from_numpy(batch['observations']).float()
        mcts_policies = torch.from_numpy(batch["tree_policies"]).float()
        value_targets = torch.from_numpy(batch["value_targets"]).float()[:, 0]
        
        # compute losses

        policy_loss = None  # TODO: complete it
        value_loss = None  # TODO: complete it

        # compute total loss
        # we rescale the value loss with a coefficient given as an hyperparameter
        value_loss = self.config['value_loss_coefficient'] * value_loss
        total_loss = policy_loss + value_loss
        return total_loss, policy_loss, value_loss

    def train(self):
        # we train the agent for several epochs. In this notebook we define an epoch as the succession
        # of data generation (we play episodes with the MCTS) and training (we sample 
        # batches of data in the replay buffer and train on them)
        for _ in range(self.config['num_epochs']):
            episode_rewards = []
            num_steps = 0
            for _ in range(self.config['num_episodes_per_epoch']):
                # play an episode
                transitions = self.play_episode()
                episode_rewards.append(np.sum(transitions['rewards']))
                num_steps += len(transitions['rewards'])
                # process the transitions
                transitions = self.postprocess_transitions(transitions)
                # store them in the replay buffer
                self.replay_buffer.add(transitions)

            avg_rewards = np.mean(episode_rewards)
            max_rewards = np.max(episode_rewards)
            min_rewards = np.min(episode_rewards)
            self.total_num_steps += num_steps

            s = 'Num timesteps sampled so far {}'.format(self.total_num_steps)
            s += ', mean accumulated reward: {}'.format(avg_rewards)
            s += ', min accumulated reward: {}'.format(min_rewards)
            s += ', max accumulated reward: {}'.format(max_rewards)
            print(s)

            # we want the buffer to contain a minimum number of transitions:
            # only start training once enough timesteps have been collected
            if self.total_num_steps >= self.config['learning_starts']:
                
                # perform one SGD step per sampled transition
                for _ in range(num_steps):
                    # sample transitions in the replay buffer
                    batch = self.replay_buffer.sample(self.config['batch_size'])
                    # compute loss
                    total_loss, policy_loss, value_loss = self.compute_loss(batch)
                    # do backprop
                    self.optimizer.zero_grad()
                    total_loss.backward()
                    self.optimizer.step()

### Ready to be tested !

Let's test our AlphaZero agent! Please first try with the given hyper-parameters. As results vary between seeds, don't hesitate to rerun several times. Then feel free to play with the parameters to see their impact. You can also compare with standard MCTS ;)

In [None]:
# create an env_creator function
env_creator = lambda: CartPole()

# define the config with the hyper-parameters
config = {
    'buffer_size': 1000,
    'batch_size': 256,
    'lr': 1e-3,

    'gamma': 0.997,
    'n_steps': 10,

    'num_epochs': 100,
    'num_episodes_per_epoch': 5,
    'learning_starts': 500,  # number of timesteps to sample before SGD

    'value_loss_coefficient': 0.2,

    'model_config': {
        'value_support_min_val': 0,
        'value_support_max_val': 30,
        'num_hidden': 32,
    },

    'mcts_config': {
        'num_simulations': 20,
        "temperature": 1.0,
        "c1_coefficient": 1.25,
        "c2_coefficient": 19652,
        'add_dirichlet_noise': True,
        'dir_noise': 0.5,
        'dir_epsilon': 0.2,
    }
}

# instantiate the agent
agent = AlphaZero(env_creator, config)

# train it
agent.train()