# Lab 14: MuZero

In this lab, we experiment with a reinforcement learning method claimed to be a prototype of AGI (Artificial General Intelligence): MuZero.

Material in this lab comes from several sources:
- https://www.deepmind.com/blog/muzero-mastering-go-chess-shogi-and-atari-without-rules
- https://github.com/werner-duvaud/muzero-general
- https://medium.com/applied-data-science/how-to-build-your-own-muzero-in-python-f77d5718061a
- https://arxiv.org/src/1911.08265v1/anc/pseudocode.py
- https://github.com/suragnair/alpha-zero-general
- https://arxiv.org/pdf/1911.08265.pdf

## Artificial General Intelligence (AGI)

AI methods that solve a single problem such as classification, object detection, or text translation are called "Narrow AI" methods.
Systems that employ narrow AI cannot switch to a new task without modification by an external actor. AI methods that can adapt to new
problems the way humans can are called "General AI" methods.

*Artificial general intelligence* (AGI) is the capability of learning or understanding any intellectual task that humans can. AGI is
some AI researchers' ultimate goal and a favorite topic for science fiction writers and prognosticators about the future (try one of
our favorite "future science" books, [Life 3.0 by Max Tegmark](https://www.amazon.com/Life-3-0-Being-Artificial-Intelligence/dp/1101946598)).


A [2020 survey by the Global Catastrophic Risk Institute](https://gcrinstitute.org/papers/055_agi-2020.pdf) identifies 72 active AGI R&D projects spread across 37 countries.

## MuZero

We've read about DeepMind's 2016 AlphaGo system, which was the first program that could play Go at a master level.
In 2018, DeepMind released AlphaZero, which could learn to play Go from scratch without relying on example games by humans.
AlphaZero could master three games rather than just one: Go, chess and shogi. In 2020, DeepMind introduced MuZero, a
step toward more general-purpose algorithms: it learns to master games without being told their rules.

Here is [a preprint of DeepMind's MuZero paper on arXiv](https://arxiv.org/abs/1911.08265).

Recall that most of the RL methods we've looked at so far have been *model-free*, meaning they don't try to predict
the consequences of an action. This is sensible, because it dramatically simplifies the learning algorithm to something
tractable.

It's clear, however, that humans do some planning when getting ready to perform some task. With some introspection,
you will probably realize that when we plan how to perform a task, we usually mentally practice our steps, visualizing
what the world/environment is going to look like as we take the steps in our plan.

But doing this prediction obviously requires some model of the world/environment! In the MDP formulation of RL, the model is a
distribution $p(s' \mid s, a)$. To even begin to think about this distribution, we need to know what $\mathcal{S}$
and $\mathcal{A}$ are. How could we build a system that, like a human, can learn any task by predicting the
consequences of its actions, without hard-coding the structure of the world $\mathcal{S}$ and what it can do $\mathcal{A}$?

This is the clever contribution of MuZero. It eschews the model-free RL architectures we have mainly adopted thus
far, instead building the model from scratch as it learns, then using that model to decide what actions are best for the task
at hand in a way similar to the lookahead search of AlphaGo and AlphaZero.

With this approach, MuZero set a new state-of-the-art result on the Atari benchmark, while also matching AlphaZero
in Go, chess and shogi.

Here's a [summary from DeepMind](https://www.deepmind.com/blog/muzero-mastering-go-chess-shogi-and-atari-without-rules)
of the relative advances represented by AlphaGo, AlphaZero, and MuZero:
<img src="img/alphago_summary.jpeg" title="alphago_summary" style="width: 800px;" />


### MuZero vs. AlphaZero

AlphaZero can in principle learn any game but requires knowledge of the rules of how pieces move and which moves are legal.
This would obviously be troublesome in environments like Atari. To get AlphaZero to learn Atari, you would have to model the
effects of each action on the environment explicitly and hard-code them into the agent. You would basically be reimplementing
the game itself!

MuZero uses much less knowledge. We only need to tell it what moves are legal in the current position and when the game is over.

During exploration, MuZero, through its model, has to learn the rules of the game implicitly, like when you are first learning
chess by playing against your older brother and he only stops you from making illegal moves while he crushes you!

### MuZero models

Predicting the consequences of our actions without replicating the entire environment in our heads is a difficult task
that we are pretty good at. MuZero approximates this capability by modeling just the aspects of the environment that are
important to the decision making process. Three elements of the environment are critical to planning:

- Value: how good is the current state?
- Reward: how good was the last action?
- Dynamics: what will happen if I take a particular action?

MuZero uses all three of these elements, which it learns with deep neural networks, to understand what happens
when it takes certain actions and to plan accordingly.

### Pseudocode

Alongside the [MuZero paper](https://arxiv.org/pdf/1911.08265.pdf), DeepMind have released [Python pseudocode](https://arxiv.org/src/1911.08265v1/anc/pseudocode.py) detailing the interactions between each part of the algorithm.

In [None]:
# Lint as: python3
"""Pseudocode description of the MuZero algorithm."""
# pylint: disable=unused-argument
# pylint: disable=missing-docstring
# pylint: disable=g-explicit-length-test

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import math
import typing
from typing import Dict, List, Optional

import numpy
import tensorflow as tf

##########################
####### Helpers ##########

MAXIMUM_FLOAT_VALUE = float('inf')

KnownBounds = collections.namedtuple('KnownBounds', ['min', 'max'])


class MinMaxStats(object):
    """A class that holds the min-max values of the tree."""

    def __init__(self, known_bounds: Optional[KnownBounds]):
        self.maximum = known_bounds.max if known_bounds else -MAXIMUM_FLOAT_VALUE
        self.minimum = known_bounds.min if known_bounds else MAXIMUM_FLOAT_VALUE

    def update(self, value: float):
        self.maximum = max(self.maximum, value)
        self.minimum = min(self.minimum, value)

    def normalize(self, value: float) -> float:
        if self.maximum > self.minimum:
            # We normalize only when we have set the maximum and minimum values.
            return (value - self.minimum) / (self.maximum - self.minimum)
        return value


class MuZeroConfig(object):

    def __init__(self,
                 action_space_size: int,
                 max_moves: int,
                 discount: float,
                 dirichlet_alpha: float,
                 num_simulations: int,
                 batch_size: int,
                 td_steps: int,
                 num_actors: int,
                 lr_init: float,
                 lr_decay_steps: float,
                 visit_softmax_temperature_fn,
                 known_bounds: Optional[KnownBounds] = None):
        
        ### Self-Play
        self.action_space_size = action_space_size
        self.num_actors = num_actors

        self.visit_softmax_temperature_fn = visit_softmax_temperature_fn
        self.max_moves = max_moves
        self.num_simulations = num_simulations
        self.discount = discount

        # Root prior exploration noise.
        self.root_dirichlet_alpha = dirichlet_alpha
        self.root_exploration_fraction = 0.25

        # UCB formula
        self.pb_c_base = 19652
        self.pb_c_init = 1.25

        # If we already have some information about which values occur in the
        # environment, we can use them to initialize the rescaling.
        # This is not strictly necessary, but establishes identical behaviour to
        # AlphaZero in board games.
        self.known_bounds = known_bounds

        ### Training
        self.training_steps = int(1000e3)
        self.checkpoint_interval = int(1e3)
        self.window_size = int(1e6)
        self.batch_size = batch_size
        self.num_unroll_steps = 5
        self.td_steps = td_steps

        self.weight_decay = 1e-4
        self.momentum = 0.9

        # Exponential learning rate schedule
        self.lr_init = lr_init
        self.lr_decay_rate = 0.1
        self.lr_decay_steps = lr_decay_steps

    def new_game(self):
        return Game(self.action_space_size, self.discount)


def make_board_game_config(action_space_size: int, max_moves: int,
                           dirichlet_alpha: float,
                           lr_init: float) -> MuZeroConfig:

    def visit_softmax_temperature(num_moves, training_steps):
        if num_moves < 30:
            return 1.0
        else:
            return 0.0  # Play according to the max.

    return MuZeroConfig(
            action_space_size=action_space_size,
            max_moves=max_moves,
            discount=1.0,
            dirichlet_alpha=dirichlet_alpha,
            num_simulations=800,
            batch_size=2048,
            td_steps=max_moves,  # Always use Monte Carlo return.
            num_actors=3000,
            lr_init=lr_init,
            lr_decay_steps=400e3,
            visit_softmax_temperature_fn=visit_softmax_temperature,
            known_bounds=KnownBounds(-1, 1))


def make_go_config() -> MuZeroConfig:
    return make_board_game_config(
            action_space_size=362, max_moves=722, dirichlet_alpha=0.03, lr_init=0.01)


def make_chess_config() -> MuZeroConfig:
    return make_board_game_config(
            action_space_size=4672, max_moves=512, dirichlet_alpha=0.3, lr_init=0.1)


def make_shogi_config() -> MuZeroConfig:
    return make_board_game_config(
            action_space_size=11259, max_moves=512, dirichlet_alpha=0.15, lr_init=0.1)


def make_atari_config() -> MuZeroConfig:

    def visit_softmax_temperature(num_moves, training_steps):
        if training_steps < 500e3:
            return 1.0
        elif training_steps < 750e3:
            return 0.5
        else:
            return 0.25

    return MuZeroConfig(
        action_space_size=18,
        max_moves=27000,  # Half an hour at action repeat 4.
        discount=0.997,
        dirichlet_alpha=0.25,
        num_simulations=50,
        batch_size=1024,
        td_steps=10,
        num_actors=350,
        lr_init=0.05,
        lr_decay_steps=350e3,
        visit_softmax_temperature_fn=visit_softmax_temperature)


class Action(object):

    def __init__(self, index: int):
        self.index = index

    def __hash__(self):
        return self.index

    def __eq__(self, other):
        return self.index == other.index

    def __gt__(self, other):
        return self.index > other.index


class Player(object):
    pass


class Node(object):

    def __init__(self, prior: float):
        self.visit_count = 0
        self.to_play = -1
        self.prior = prior
        self.value_sum = 0
        self.children = {}
        self.hidden_state = None
        self.reward = 0

    def expanded(self) -> bool:
        return len(self.children) > 0

    def value(self) -> float:
        if self.visit_count == 0:
            return 0
        return self.value_sum / self.visit_count


class ActionHistory(object):
    """Simple history container used inside the search.

    Only used to keep track of the actions executed.
    """

    def __init__(self, history: List[Action], action_space_size: int):
        self.history = list(history)
        self.action_space_size = action_space_size

    def clone(self):
        return ActionHistory(self.history, self.action_space_size)

    def add_action(self, action: Action):
        self.history.append(action)

    def last_action(self) -> Action:
        return self.history[-1]

    def action_space(self) -> List[Action]:
        return [Action(i) for i in range(self.action_space_size)]

    def to_play(self) -> Player:
        return Player()


class Environment(object):
    """The environment MuZero is interacting with."""

    def step(self, action):
        pass


class Game(object):
    """A single episode of interaction with the environment."""

    def __init__(self, action_space_size: int, discount: float):
        self.environment = Environment()  # Game specific environment.
        self.history = []
        self.rewards = []
        self.child_visits = []
        self.root_values = []
        self.action_space_size = action_space_size
        self.discount = discount

    def terminal(self) -> bool:
        # Game specific termination rules.
        pass

    def legal_actions(self) -> List[Action]:
        # Game specific calculation of legal actions.
        return []

    def apply(self, action: Action):
        reward = self.environment.step(action)
        self.rewards.append(reward)
        self.history.append(action)

    def store_search_statistics(self, root: Node):
        sum_visits = sum(child.visit_count for child in root.children.values())
        action_space = (Action(index) for index in range(self.action_space_size))
        self.child_visits.append([
            root.children[a].visit_count / sum_visits if a in root.children else 0
            for a in action_space
        ])
        self.root_values.append(root.value())

    def make_image(self, state_index: int):
        # Game specific feature planes.
        return []

    def make_target(self, state_index: int, num_unroll_steps: int, td_steps: int,
                    to_play: Player):
        # The value target is the discounted root value of the search tree N steps
        # into the future, plus the discounted sum of all rewards until then.
        targets = []
        for current_index in range(state_index, state_index + num_unroll_steps + 1):
            bootstrap_index = current_index + td_steps
            if bootstrap_index < len(self.root_values):
                value = self.root_values[bootstrap_index] * self.discount**td_steps
            else:
                value = 0

            for i, reward in enumerate(self.rewards[current_index:bootstrap_index]):
                value += reward * self.discount**i  # pytype: disable=unsupported-operands

            if current_index < len(self.root_values):
                targets.append((value, self.rewards[current_index],
                               self.child_visits[current_index]))
            else:
                # States past the end of games are treated as absorbing states.
                targets.append((0, 0, []))
        return targets

    def to_play(self) -> Player:
        return Player()

    def action_history(self) -> ActionHistory:
        return ActionHistory(self.history, self.action_space_size)


class ReplayBuffer(object):

    def __init__(self, config: MuZeroConfig):
        self.window_size = config.window_size
        self.batch_size = config.batch_size
        self.buffer = []

    def save_game(self, game):
        if len(self.buffer) >= self.window_size:
            self.buffer.pop(0)
        self.buffer.append(game)

    def sample_batch(self, num_unroll_steps: int, td_steps: int):
        games = [self.sample_game() for _ in range(self.batch_size)]
        game_pos = [(g, self.sample_position(g)) for g in games]
        return [(g.make_image(i), g.history[i:i + num_unroll_steps],
                 g.make_target(i, num_unroll_steps, td_steps, g.to_play()))
                for (g, i) in game_pos]

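    def sample_game(self) -> Game:
        # Sample game from buffer either uniformly or according to some priority.
        # Uniform sampling is used here.
        return self.buffer[numpy.random.randint(len(self.buffer))]

    def sample_position(self, game) -> int:
        # Sample position from game either uniformly or according to some priority.
        # Uniform sampling over the moves played is used here.
        return numpy.random.randint(len(game.history))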


class NetworkOutput(typing.NamedTuple):
    value: float
    reward: float
    policy_logits: Dict[Action, float]
    hidden_state: List[float]


class Network(object):

    def initial_inference(self, image) -> NetworkOutput:
        # representation + prediction function
        return NetworkOutput(0, 0, {}, [])

    def recurrent_inference(self, hidden_state, action) -> NetworkOutput:
        # dynamics + prediction function
        return NetworkOutput(0, 0, {}, [])

    def get_weights(self):
        # Returns the weights of this network.
        return []

    def training_steps(self) -> int:
        # How many steps / batches the network has been trained for.
        return 0


class SharedStorage(object):

    def __init__(self):
        self._networks = {}

    def latest_network(self) -> Network:
        if self._networks:
            return self._networks[max(self._networks.keys())]
        else:
            # policy -> uniform, value -> 0, reward -> 0
            return make_uniform_network()

    def save_network(self, step: int, network: Network):
        self._networks[step] = network


##### End Helpers ########
##########################


# MuZero training is split into two independent parts: Network training and
# self-play data generation.
# These two parts only communicate by transferring the latest network checkpoint
# from the training to the self-play, and the finished games from the self-play
# to the training.
def muzero(config: MuZeroConfig):
    storage = SharedStorage()
    replay_buffer = ReplayBuffer(config)

    for _ in range(config.num_actors):
        launch_job(run_selfplay, config, storage, replay_buffer)

    train_network(config, storage, replay_buffer)

    return storage.latest_network()


##################################
####### Part 1: Self-Play ########


# Each self-play job is independent of all others; it takes the latest network
# snapshot, produces a game and makes it available to the training job by
# writing it to a shared replay buffer.
def run_selfplay(config: MuZeroConfig, storage: SharedStorage,
                 replay_buffer: ReplayBuffer):
    while True:
        network = storage.latest_network()
        game = play_game(config, network)
        replay_buffer.save_game(game)


# Each game is produced by starting at the initial board position, then
# repeatedly executing a Monte Carlo Tree Search to generate moves until the end
# of the game is reached.
def play_game(config: MuZeroConfig, network: Network) -> Game:
    game = config.new_game()

    while not game.terminal() and len(game.history) < config.max_moves:
        # At the root of the search tree we use the representation function to
        # obtain a hidden state given the current observation.
        root = Node(0)
        current_observation = game.make_image(-1)
        expand_node(root, game.to_play(), game.legal_actions(),
                    network.initial_inference(current_observation))
        add_exploration_noise(config, root)

        # We then run a Monte Carlo Tree Search using only action sequences and the
        # model learned by the network.
        run_mcts(config, root, game.action_history(), network)
        action = select_action(config, len(game.history), root, network)
        game.apply(action)
        game.store_search_statistics(root)
    return game


# Core Monte Carlo Tree Search algorithm.
# To decide on an action, we run N simulations, always starting at the root of
# the search tree and traversing the tree according to the UCB formula until we
# reach a leaf node.
def run_mcts(config: MuZeroConfig, root: Node, action_history: ActionHistory,
             network: Network):
    min_max_stats = MinMaxStats(config.known_bounds)

    for _ in range(config.num_simulations):
        history = action_history.clone()
        node = root
        search_path = [node]

        while node.expanded():
            action, node = select_child(config, node, min_max_stats)
            history.add_action(action)
            search_path.append(node)

        # Inside the search tree we use the dynamics function to obtain the next
        # hidden state given an action and the previous hidden state.
        parent = search_path[-2]
        network_output = network.recurrent_inference(parent.hidden_state,
                                                     history.last_action())
        expand_node(node, history.to_play(), history.action_space(), network_output)

        backpropagate(search_path, network_output.value, history.to_play(),
                      config.discount, min_max_stats)


def select_action(config: MuZeroConfig, num_moves: int, node: Node,
                  network: Network):
    visit_counts = [
        (child.visit_count, action) for action, child in node.children.items()
    ]
    t = config.visit_softmax_temperature_fn(
        num_moves=num_moves, training_steps=network.training_steps())
    _, action = softmax_sample(visit_counts, t)
    return action


# Select the child with the highest UCB score.
def select_child(config: MuZeroConfig, node: Node,
                 min_max_stats: MinMaxStats):
    _, action, child = max(
        (ucb_score(config, node, child, min_max_stats), action,
         child) for action, child in node.children.items())
    return action, child


# The score for a node is based on its value, plus an exploration bonus based on
# the prior.
def ucb_score(config: MuZeroConfig, parent: Node, child: Node,
              min_max_stats: MinMaxStats) -> float:
    pb_c = math.log((parent.visit_count + config.pb_c_base + 1) /
                  config.pb_c_base) + config.pb_c_init
    pb_c *= math.sqrt(parent.visit_count) / (child.visit_count + 1)

    prior_score = pb_c * child.prior
    value_score = min_max_stats.normalize(child.value())
    return prior_score + value_score


# We expand a node using the value, reward and policy prediction obtained from
# the neural network.
def expand_node(node: Node, to_play: Player, actions: List[Action],
                network_output: NetworkOutput):
    node.to_play = to_play
    node.hidden_state = network_output.hidden_state
    node.reward = network_output.reward
    policy = {a: math.exp(network_output.policy_logits[a]) for a in actions}
    policy_sum = sum(policy.values())
    for action, p in policy.items():
        node.children[action] = Node(p / policy_sum)


# At the end of a simulation, we propagate the evaluation all the way up the
# tree to the root.
def backpropagate(search_path: List[Node], value: float, to_play: Player,
                  discount: float, min_max_stats: MinMaxStats):
    # Walk from the leaf back up to the root, discounting the value at each step.
    for node in reversed(search_path):
        node.value_sum += value if node.to_play == to_play else -value
        node.visit_count += 1
        min_max_stats.update(node.value())

        value = node.reward + discount * value


# At the start of each search, we add dirichlet noise to the prior of the root
# to encourage the search to explore new actions.
def add_exploration_noise(config: MuZeroConfig, node: Node):
    actions = list(node.children.keys())
    noise = numpy.random.dirichlet([config.root_dirichlet_alpha] * len(actions))
    frac = config.root_exploration_fraction
    for a, n in zip(actions, noise):
        node.children[a].prior = node.children[a].prior * (1 - frac) + n * frac


######### End Self-Play ##########
##################################

##################################
####### Part 2: Training #########


def train_network(config: MuZeroConfig, storage: SharedStorage,
                  replay_buffer: ReplayBuffer):
    network = Network()
    # TF1 graph-mode API; in TensorFlow 2 these live under tf.compat.v1.
    learning_rate = config.lr_init * config.lr_decay_rate**(
        tf.compat.v1.train.get_global_step() / config.lr_decay_steps)
    optimizer = tf.compat.v1.train.MomentumOptimizer(learning_rate, config.momentum)

    for i in range(config.training_steps):
        if i % config.checkpoint_interval == 0:
            storage.save_network(i, network)
        batch = replay_buffer.sample_batch(config.num_unroll_steps, config.td_steps)
        update_weights(optimizer, network, batch, config.weight_decay)
    storage.save_network(config.training_steps, network)


def scale_gradient(tensor, scale):
    # Keeps the forward value but scales the gradient by `scale`.
    # TensorFlow has no tf.scale_gradient; this stop_gradient trick is a common substitute.
    return tensor * scale + tf.stop_gradient(tensor) * (1 - scale)


def update_weights(optimizer: tf.compat.v1.train.Optimizer, network: Network, batch,
                   weight_decay: float):
    loss = 0
    for image, actions, targets in batch:
        # Initial step, from the real observation.
        value, reward, policy_logits, hidden_state = network.initial_inference(image)
        predictions = [(1.0, value, reward, policy_logits)]

        # Recurrent steps, from action and previous hidden state.
        for action in actions:
            value, reward, policy_logits, hidden_state = network.recurrent_inference(
                hidden_state, action)
            predictions.append((1.0 / len(actions), value, reward, policy_logits))

            # Halve the gradient flowing back into the dynamics function.
            hidden_state = scale_gradient(hidden_state, 0.5)

        for prediction, target in zip(predictions, targets):
            gradient_scale, value, reward, policy_logits = prediction
            target_value, target_reward, target_policy = target

            l = (
                scalar_loss(value, target_value) +
                scalar_loss(reward, target_reward) +
                tf.nn.softmax_cross_entropy_with_logits(
                    logits=policy_logits, labels=target_policy))

            loss += scale_gradient(l, gradient_scale)

    for weights in network.get_weights():
        loss += weight_decay * tf.nn.l2_loss(weights)

    optimizer.minimize(loss)


def scalar_loss(prediction, target) -> float:
    # MSE in board games, cross entropy between categorical values in Atari.
    # Only the board-game (MSE) version is written out here.
    return tf.square(prediction - target)

######### End Training ###########
##################################

################################################################################
############################# End of pseudocode ################################
################################################################################


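# Helpers and stubs used by the pseudocode above.
import threading


def softmax_sample(distribution, temperature: float):
    # distribution is a list of (visit_count, action) pairs. Sample an entry with
    # probability proportional to visit_count ** (1 / temperature);
    # temperature 0 means always take the most visited action.
    counts = numpy.array([count for count, _ in distribution], dtype=numpy.float64)
    if temperature == 0:
        index = int(numpy.argmax(counts))
    else:
        weights = counts ** (1.0 / temperature)
        index = int(numpy.random.choice(len(distribution), p=weights / weights.sum()))
    return distribution[index]


def launch_job(f, *args):
    # run_selfplay loops forever, so each actor runs in a background thread;
    # calling f(*args) directly would block before training ever starts.
    threading.Thread(target=f, args=args, daemon=True).start()


def make_uniform_network():
    return Network()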


### The MuZero entry point (muzero)

We'll take a look at the essential parts of the pseudocode.
The `muzero` function provides an overview of the entire process.

<img src="img/muzero_sum.png" title="muzero sum" style="width: 800px;" />

In [None]:
def muzero(config: MuZeroConfig):
    storage = SharedStorage()
    replay_buffer = ReplayBuffer(config)

    for _ in range(config.num_actors):
        launch_job(run_selfplay, config, storage, replay_buffer)

    train_network(config, storage, replay_buffer)

    return storage.latest_network()

The entry-point function `muzero` is passed a `MuZeroConfig` object, which stores important information
about parameter settings such as the `action_space_size` (number of possible actions) and `num_actors`
(the number of parallel game simulations to run).

There are two independent parts to the MuZero algorithm: self-play (creating game data) and training
(producing improved versions of the neural network models). The two halves communicate only through the `SharedStorage`
and `ReplayBuffer` objects, which store network checkpoints and finished games, respectively.

### Shared Storage and the Replay Buffer

The SharedStorage object contains methods for saving a version of the neural network and retrieving the latest neural network from the store.

In [None]:
class SharedStorage(object):

    def __init__(self):
        self._networks = {}

    def latest_network(self) -> Network:
        if self._networks:
            return self._networks[max(self._networks.keys())]
        else:
            # policy -> uniform, value -> 0, reward -> 0
            return make_uniform_network()

    def save_network(self, step: int, network: Network):
        self._networks[step] = network

### Replay buffer

The ReplayBuffer stores data from previous games. The ReplayBuffer class contains a sample_batch method to sample a batch of observations from the buffer.

The default `batch_size` of MuZero for chess is 2048. That many games are sampled from the buffer, and one position is chosen from each.

A single batch is a list of tuples, where each tuple consists of three elements:
- `g.make_image(i)`: the observation at the chosen position
- `g.history[i:i + num_unroll_steps]`: a list of the next `num_unroll_steps` actions taken after the chosen position (if they exist)
- `g.make_target(i, num_unroll_steps, td_steps, g.to_play())`: a list of the targets that will be used to train the neural networks. Specifically, this is a list of `(target_value, target_reward, target_policy)` tuples.

For each observation in the batch, we will be ‘unrolling’ the position num_unroll_steps into the future using the actions provided. For the initial position, we will use the initial_inference function to predict the value, reward and policy and compare these to the target value, target reward and target policy. For subsequent actions, we will use the recurrent_inference function to predict the value, reward and policy and compare to the target value, target reward and target policy. This way, all three networks are utilised in the predictive process and therefore the weights in all three networks will be updated.

During training, the model is unrolled alongside the collected experience, and at each step it predicts the information saved during self-play: the value function $v$ predicts the discounted sum of observed rewards $u$ (bootstrapped from a later search value), the policy estimate $p$ predicts the stored search outcome $\pi$, and the reward estimate $r$ predicts the last observed reward $u$. Every step after the first is produced by `recurrent_inference`.

<img src="img/unroll.gif" title="" style="width: 800px;" />
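
The value target mentioned above is an $n$-step bootstrapped return. As a minimal sketch, the standalone function below repeats the arithmetic that `Game.make_target` performs for a single position, using plain lists instead of a `Game` object. The name `value_target` and the example numbers are illustrative assumptions, not part of DeepMind's pseudocode. As in the pseudocode, `rewards[t]` is the reward received after the action at step `t`, and `root_values[t]` is the MCTS root value recorded at step `t`.

```python
from typing import List


def value_target(rewards: List[float], root_values: List[float],
                 index: int, td_steps: int, discount: float) -> float:
    """n-step value target for position `index`, mirroring Game.make_target."""
    bootstrap_index = index + td_steps
    # Bootstrap from the search value td_steps into the future, if it exists.
    if bootstrap_index < len(root_values):
        value = root_values[bootstrap_index] * discount ** td_steps
    else:
        value = 0.0  # past the end of the episode: no bootstrap term
    # Add the discounted rewards observed between index and bootstrap_index.
    for i, reward in enumerate(rewards[index:bootstrap_index]):
        value += reward * discount ** i
    return value


rewards = [0.0, 1.0, 0.0, 2.0, 0.0]
root_values = [0.5, 0.6, 0.7, 0.8, 0.9]
print(value_target(rewards, root_values, index=0, td_steps=3, discount=0.5))
```

For `index=0` and `td_steps=3` this is $0.8 \cdot 0.5^3 + 1.0 \cdot 0.5 = 0.6$. The board-game configuration sets `td_steps=max_moves`, so the bootstrap term always vanishes and the target becomes the plain Monte Carlo return.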

In [None]:
class ReplayBuffer(object):

    def __init__(self, config: MuZeroConfig):
        self.window_size = config.window_size
        self.batch_size = config.batch_size
        self.buffer = []

    def save_game(self, game):
        if len(self.buffer) >= self.window_size:
            self.buffer.pop(0)
        self.buffer.append(game)

    ...

The <code>window_size</code> parameter limits the maximum number of games stored in the buffer. In MuZero, it is set to 1,000,000, so the buffer holds the latest million games.
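
With a Python list, `self.buffer.pop(0)` shifts every remaining element, so evicting the oldest game costs time proportional to the buffer size. A `collections.deque` created with `maxlen` performs the same first-in, first-out eviction in constant time. The class below is a hypothetical variant (`DequeReplayBuffer` and `sample_games` are not in the pseudocode) that keeps only the storage part of `ReplayBuffer` and assumes uniform sampling.

```python
import collections
import random
from typing import Any, List


class DequeReplayBuffer:
    """FIFO game buffer: once full, saving a game discards the oldest one."""

    def __init__(self, window_size: int, batch_size: int):
        self.batch_size = batch_size
        # A deque with maxlen drops items from the left when it is full.
        self.buffer = collections.deque(maxlen=window_size)

    def save_game(self, game: Any) -> None:
        self.buffer.append(game)

    def sample_games(self) -> List[Any]:
        # Uniform sampling with replacement, as a stand-in for prioritized sampling.
        return [random.choice(self.buffer) for _ in range(self.batch_size)]


buf = DequeReplayBuffer(window_size=3, batch_size=4)
for game_id in range(5):
    buf.save_game(game_id)
print(list(buf.buffer))  # [2, 3, 4]
print(buf.sample_games())
```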

### Self-play (run_selfplay)

MuZero launches <code>num_actors</code> parallel game environments, which run independently.

The method plays thousands of games against itself. In the process, the games are saved
to a buffer, and then training utilizes the data from those games.
This step is the same as in AlphaZero.

In [None]:
def run_selfplay(config: MuZeroConfig, storage: SharedStorage,
                 replay_buffer: ReplayBuffer):
    while True:
        network = storage.latest_network()
        game = play_game(config, network)
        replay_buffer.save_game(game)

### Monte Carlo tree search (MCTS)

All of the algorithms from AlphaGo to MuZero use Monte Carlo Tree Search (MCTS) to select the next best move.
We "play out" future scenarios from the current position, evaluate the resulting states with a value network, and
after a period of tree search, choose the action that maximises the future expected value.

The figure below illustrates Monte Carlo Tree Search in MuZero.
Starting at the current state, we use a *representation function* $h$ to map from the observation to an embedding of
the state, which is denoted $s_0$. Using the *dynamics function* $g$ and the *prediction function* $f$,
we can build a search tree over future states reachable from $s_0$ based on our past experience.

<img src="img/mcts.gif" title="MCTS" style="width: 800px;" />
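
Each simulation descends from the root by repeatedly picking the child with the highest score, which `ucb_score` in the pseudocode computes. The sketch below is a standalone variant of that formula that takes plain numbers instead of `Node` objects, using the pseudocode's default constants `pb_c_base=19652` and `pb_c_init=1.25`. It assumes the child's value has already been rescaled to $[0, 1]$, which the pseudocode does with `MinMaxStats`. The example visit counts, priors and values are made up for illustration.

```python
import math


def ucb_score(parent_visits: int, child_visits: int, prior: float,
              normalized_value: float, pb_c_base: float = 19652,
              pb_c_init: float = 1.25) -> float:
    """pUCT score with the same arithmetic as ucb_score in the pseudocode."""
    pb_c = math.log((parent_visits + pb_c_base + 1) / pb_c_base) + pb_c_init
    # The exploration bonus grows with the parent's visits and shrinks with the child's.
    pb_c *= math.sqrt(parent_visits) / (child_visits + 1)
    return pb_c * prior + normalized_value


# Two children of a root visited 100 times: a well-explored child with a good
# value, and a rarely visited child whose prior is high.
explored = ucb_score(parent_visits=100, child_visits=60, prior=0.2, normalized_value=0.7)
novel = ucb_score(parent_visits=100, child_visits=2, prior=0.5, normalized_value=0.4)
print(round(explored, 3), round(novel, 3))
```

Here the rarely visited child wins despite its lower value estimate, because its exploration bonus is still large. As its visit count grows, the bonus shrinks and the value term dominates.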

The diagram below compares the MCTS processes in AlphaZero and MuZero. AlphaZero has only one learned function, the
prediction network $f$, which estimates a policy and a value for each state. It has no need for $h$ because it works
directly with a game-specific state representation, and no need for $g$ because it simulates moves with the known game rules.

<img src="img/alphagovsmuzero.png" title="AlphaZero VS MuZero" style="width: 800px;" />

MuZero, on the other hand, does not know the rules of the game, so it has to learn $g$ (as well as $h$) from scratch.

### More on the AlphaZero and MuZero networks

AlphaZero has just one neural network:

1. $f: \mathcal{S} \rightarrow \mathbb{R}^{|\mathcal{A}|} \times \mathbb{R}$ outputs a probability distribution over the actions indicating their optimality, along
   with the estimated value of state $s$.

The prediction is made every time MCTS hits an unexplored leaf node, so that it can immediately assign an estimated value to the new position and a probability to each subsequent action. The values are backfilled up the tree, back to the root node, so that after many simulations, the root node has a good idea of the future value of the current state, having explored many different possible futures.

As already discussed, MuZero uses three neural networks. Although the
actual state of the environment is unknown, we model it as a simple vector of reals, i.e., $\mathcal{S} = \mathbb{R}^d$.

1. Representation $h: \mathcal{O} \rightarrow \mathcal{S}$. Calculates an embedding of the observation intended to serve as a proxy for the actual state of the environment. Since
   $h$ is learned, in
   practice the embedding should encode the attributes of the observation that are most useful for predicting future rewards, values, and policies.
2. Prediction $f$: Same as $f$ in AlphaZero.
3. Dynamics $g: \mathcal{S} \times \mathcal{A} \rightarrow \mathcal{S} \times \mathbb{R}$. Maps current state $s$ and chosen action $a$ to the new
   state $s'$ and immediate reward $r$.
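To make the three signatures concrete, here is a toy sketch that uses random linear maps in NumPy. The sizes `OBS_DIM`, `STATE_DIM`, `NUM_ACTIONS` and the one-hot action encoding are assumptions for illustration. The real networks are deep (residual) networks trained end to end.

```python
import numpy as np

rng = np.random.default_rng(0)
OBS_DIM, STATE_DIM, NUM_ACTIONS = 8, 4, 3  # hypothetical sizes

W_h = rng.normal(size=(STATE_DIM, OBS_DIM))
W_g = rng.normal(size=(STATE_DIM, STATE_DIM + NUM_ACTIONS))
w_r = rng.normal(size=STATE_DIM + NUM_ACTIONS)
W_p = rng.normal(size=(NUM_ACTIONS, STATE_DIM))
w_v = rng.normal(size=STATE_DIM)


def h(observation: np.ndarray) -> np.ndarray:
    """Representation: observation -> hidden state s in R^d."""
    return np.tanh(W_h @ observation)


def g(state: np.ndarray, action: int):
    """Dynamics: (s, a) -> (s', r)."""
    x = np.concatenate([state, np.eye(NUM_ACTIONS)[action]])
    return np.tanh(W_g @ x), float(w_r @ x)


def f(state: np.ndarray):
    """Prediction: s -> (policy logits over actions, value)."""
    return W_p @ state, float(w_v @ state)


# Unroll an imagined trajectory: h once, then g for each action, f at every step.
s = h(rng.normal(size=OBS_DIM))
for a in [0, 2, 1]:
    logits, value = f(s)
    s, reward = g(s, a)
    print(logits.shape, round(value, 3), round(reward, 3))
```

Note that only $h$ ever sees a real observation. Every later step works purely in the learned state space.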

<img src="img/muzero-network.png" title="MuZero Network" style="width: 800px;" />


MuZero uses the experience it collects when interacting with the environment to train its neural networks.
This experience includes the observations and rewards from the environment.

The process of using these networks is visualized in the following diagram.

<img src="img/experience.gif" title="" style="width: 800px;" />

In terms of the pseudocode, there are two key inference functions used to move through the MCTS tree making predictions:
- `initial_inference` for the current state. Calls $h$ followed by $f$.
- `recurrent_inference` for moving between states inside the MCTS tree. Calls $g$ followed by $f$.


In [None]:
import typing
from typing import Dict, List


class NetworkOutput(typing.NamedTuple):
    value: float
    reward: float
    policy_logits: Dict[Action, float]
    hidden_state: List[float]


class Network(object):

    def initial_inference(self, image) -> NetworkOutput:
        # representation + prediction function
        return NetworkOutput(0, 0, {}, [])

    def recurrent_inference(self, hidden_state, action) -> NetworkOutput:
        # dynamics + prediction function
        return NetworkOutput(0, 0, {}, [])

    def get_weights(self):
        # Returns the weights of this network.
        return []

    def training_steps(self) -> int:
        # How many steps / batches the network has been trained for.
        return 0
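The `Network` above is a placeholder that always returns zeros. The hypothetical `ToyNetwork` below sketches what a concrete subclass has to do. It composes the functions as described: `initial_inference` applies $h$ then $f$, and `recurrent_inference` applies $g$ then $f$. It uses tiny random linear layers and integer actions, and it omits `get_weights` and `training_steps`. It also redeclares `NetworkOutput` so the cell runs on its own.

```python
import typing
from typing import Dict, List

import numpy as np


class NetworkOutput(typing.NamedTuple):
    value: float
    reward: float
    policy_logits: Dict[int, float]
    hidden_state: List[float]


class ToyNetwork:
    def __init__(self, obs_dim: int, state_dim: int, num_actions: int, seed: int = 0):
        rng = np.random.default_rng(seed)
        self.num_actions = num_actions
        self.W_h = rng.normal(size=(state_dim, obs_dim))                  # h
        self.W_g = rng.normal(size=(state_dim, state_dim + num_actions))  # g: next state
        self.w_r = rng.normal(size=state_dim + num_actions)               # g: reward
        self.W_p = rng.normal(size=(num_actions, state_dim))              # f: policy
        self.w_v = rng.normal(size=state_dim)                             # f: value

    def _predict(self, state: np.ndarray, reward: float) -> NetworkOutput:
        logits = self.W_p @ state
        return NetworkOutput(
            value=float(self.w_v @ state),
            reward=reward,
            policy_logits={a: float(logits[a]) for a in range(self.num_actions)},
            hidden_state=state.tolist(),
        )

    def initial_inference(self, image) -> NetworkOutput:
        # Representation h, then prediction f; there is no reward at the root.
        state = np.tanh(self.W_h @ np.asarray(image, dtype=float))
        return self._predict(state, reward=0.0)

    def recurrent_inference(self, hidden_state, action: int) -> NetworkOutput:
        # Dynamics g, then prediction f.
        x = np.concatenate([np.asarray(hidden_state, dtype=float),
                            np.eye(self.num_actions)[action]])
        return self._predict(np.tanh(self.W_g @ x), reward=float(self.w_r @ x))


net = ToyNetwork(obs_dim=6, state_dim=4, num_actions=3)
root = net.initial_inference([0.1] * 6)
child = net.recurrent_inference(root.hidden_state, action=2)
print(sorted(child.policy_logits), len(child.hidden_state))  # [0, 1, 2] 4
```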

### Playing a game

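Playing a game is a loop that ends when a terminal condition is met or the maximum number of moves (<code>max_moves</code>) is reached.

At every move, MCTS starts over from a fresh root node (<code>root = Node(0)</code> inside the loop). The search tree is not carried over from the previous move.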

In [None]:
def play_game(config: MuZeroConfig, network: Network) -> Game:
    game = config.new_game()

    while not game.terminal() and len(game.history) < config.max_moves:
        # At the root of the search tree we use the representation function to
        # obtain a hidden state given the current observation.
        root = Node(0)
        current_observation = game.make_image(-1)
        expand_node(root, game.to_play(), game.legal_actions(),
                    network.initial_inference(current_observation))
        add_exploration_noise(config, root)

        # We then run a Monte Carlo Tree Search using only action sequences and the
        # model learned by the network.
        run_mcts(config, root, game.action_history(), network)
        action = select_action(config, len(game.history), root, network)
        game.apply(action)
        game.store_search_statistics(root)
    return game

Each <code>Node</code> stores key statistics:
- the number of times it has been visited (<code>visit_count</code>)
- whose turn it is (<code>to_play</code>)
- the predicted prior probability of choosing the action that leads to this node (<code>prior</code>)
- the backfilled value sum of the node (<code>value_sum</code>)
- its child nodes (<code>children</code>)
- the hidden state it corresponds to (<code>hidden_state</code>)
- the predicted reward received by moving to this node (<code>reward</code>)


In [None]:
class Node(object):

    def __init__(self, prior: float):
        self.visit_count = 0
        self.to_play = -1
        self.prior = prior
        self.value_sum = 0
        self.children = {}
        self.hidden_state = None
        self.reward = 0

    def expanded(self) -> bool:
        return len(self.children) > 0

    def value(self) -> float:
        if self.visit_count == 0:
            return 0
        return self.value_sum / self.visit_count

Then we ask the game for the current observation:

In [None]:
current_observation = game.make_image(-1)

Next, we expand the root node using the known legal actions provided by the game and the inference about the current observation provided by the `initial_inference` function.

In [None]:
def expand_node(node: Node, to_play: Player, actions: List[Action],
                network_output: NetworkOutput):
    node.to_play = to_play
    node.hidden_state = network_output.hidden_state
    node.reward = network_output.reward
    policy = {a: math.exp(network_output.policy_logits[a]) for a in actions}
    policy_sum = sum(policy.values())
    for action, p in policy.items():
        node.children[action] = Node(p / policy_sum)
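`expand_node` turns the logits of the *legal* actions into priors with a softmax, so illegal actions never get a child. However, `math.exp` can overflow for large logits. A common remedy is to subtract the largest legal logit first, which leaves the normalised result unchanged. The sketch below shows this, using a hypothetical helper `legal_priors`.

```python
import math
from typing import Dict, Hashable, List


def legal_priors(policy_logits: Dict[Hashable, float],
                 legal_actions: List[Hashable]) -> Dict[Hashable, float]:
    # Softmax restricted to legal actions, shifted by the max logit for stability.
    m = max(policy_logits[a] for a in legal_actions)
    exps = {a: math.exp(policy_logits[a] - m) for a in legal_actions}
    total = sum(exps.values())
    return {a: e / total for a, e in exps.items()}


logits = {0: 1000.0, 1: 1000.0, 2: 5.0}   # math.exp(1000.0) alone would overflow
print(legal_priors(logits, [0, 1]))        # {0: 0.5, 1: 0.5}; action 2 is illegal
```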

In [None]:
expand_node(root, game.to_play(), game.legal_actions(), network.initial_inference(current_observation))

We need to add exploration noise to the actions of the root node. This ensures that MCTS explores a range of possible actions,
not only the action it currently believes to be optimal. For chess, we use `root_dirichlet_alpha = 0.3`.

In [None]:
def add_exploration_noise(config: MuZeroConfig, node: Node):
    actions = list(node.children.keys())
    noise = numpy.random.dirichlet([config.root_dirichlet_alpha] * len(actions))
    frac = config.root_exploration_fraction
    for a, n in zip(actions, noise):
        node.children[a].prior = node.children[a].prior * (1 - frac) + n * frac

In [None]:
add_exploration_noise(config, root)
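Here is a small numerical sketch of the mixing rule in `add_exploration_noise`. It uses `root_dirichlet_alpha = 0.3` as above and `root_exploration_fraction = 0.25`, the value used in the paper's pseudocode. The priors are hypothetical. Both the priors and the Dirichlet sample sum to 1, so the mixed priors still form a probability distribution. An alpha below 1 makes the noise sparse, concentrating it on a few actions.

```python
import numpy as np

rng = np.random.default_rng(42)
priors = np.array([0.7, 0.2, 0.1])          # hypothetical network priors
alpha, frac = 0.3, 0.25                     # root_dirichlet_alpha, root_exploration_fraction

noise = rng.dirichlet([alpha] * len(priors))
mixed = priors * (1 - frac) + noise * frac   # same rule as add_exploration_noise

print(mixed, mixed.sum())
```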

### MCTS run function

Because MuZero has no knowledge of the environment rules, it also does not know the bounds of the values it may encounter during the learning process. The <code>MinMaxStats</code> object stores the minimum and maximum values observed in the search tree so far. MuZero uses them to normalise its Q values before combining them with the exploration bonus. Alternatively, the object can be initialised with known bounds for a game such as chess, where values lie in $[-1, 1]$.
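The `MinMaxStats` class used by `run_mcts` is not shown in this lab. The sketch below follows the structure of the paper's pseudocode. One simplification is an assumption: `known_bounds` is an optional `(min, max)` tuple here, whereas the pseudocode uses a `KnownBounds` named tuple.

```python
from typing import Optional, Tuple


class MinMaxStats:
    """Tracks the min/max values seen in the search tree to normalise Q."""

    def __init__(self, known_bounds: Optional[Tuple[float, float]] = None):
        if known_bounds is not None:
            self.minimum, self.maximum = known_bounds
        else:
            self.minimum, self.maximum = float("inf"), float("-inf")

    def update(self, value: float) -> None:
        self.maximum = max(self.maximum, value)
        self.minimum = min(self.minimum, value)

    def normalize(self, value: float) -> float:
        # Map into [0, 1] once at least two distinct values have been seen.
        if self.maximum > self.minimum:
            return (value - self.minimum) / (self.maximum - self.minimum)
        return value


stats = MinMaxStats()
for v in [2.0, -1.0, 5.0]:
    stats.update(v)
print(stats.normalize(2.0))  # 0.5
```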

The main MCTS loop runs <code>num_simulations</code> times. One simulation is a single pass down the tree until a leaf node (i.e., an unexpanded node) is reached, followed by backpropagation:
1. The <code>history</code> is initialized with the list of actions taken since the start of the game. The current node is initialized to the root node, so the search path initially contains only the root.
2. MuZero first traverses down the MCTS tree, always selecting the action with the highest UCB (Upper Confidence Bound) score.
3. The UCB score balances the estimated value of the action, $Q(s,a)$, with an exploration bonus. The bonus is based on the prior probability of selecting the action, $P(s,a)$, and the number of times the action has already been selected, $N(s,a)$:

$$a^k=\arg\max_a\left[Q(s,a) + P(s,a) \cdot \frac{\sqrt{\sum_b N(s,b)}}{1+N(s,a)}\left(c_1 + \log\left(\frac{\sum_b N(s,b)+c_2+1}{c_2}\right)\right)\right]$$

4. The <code>recurrent_inference</code> function is called with the hidden state of the leaf's parent and the action leading to the leaf. The dynamics network returns the predicted reward and the new hidden state, and the prediction network returns the policy and value of that new hidden state. The leaf is then expanded with these predictions.
5. The value predicted by the network is backpropagated along the search path, from the leaf back to the root.
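`select_child` relies on `ucb_score`, which is not shown in this lab. The sketch below implements the formula above with $c_1 = 1.25$ and $c_2 = 19652$, the constants reported in the MuZero paper. To keep the cell self-contained, it does not take `Node` objects. Instead it takes the parent visit count ($\sum_b N(s,b)$), the child visit count, the prior, and an already-normalised Q value, such as the output of `MinMaxStats.normalize`. This flattened signature is an assumption.

```python
import math


def ucb_score(parent_visits: int, child_visits: int, prior: float,
              normalized_q: float, c1: float = 1.25, c2: float = 19652.0) -> float:
    """pUCT score from the formula above; parent_visits stands for sum_b N(s, b)."""
    exploration = c1 + math.log((parent_visits + c2 + 1) / c2)
    exploration *= math.sqrt(parent_visits) / (1 + child_visits)
    return normalized_q + prior * exploration


# An unvisited child with a high prior can outrank a well-visited child with a better Q.
unvisited = ucb_score(parent_visits=10, child_visits=0, prior=0.6, normalized_q=0.0)
visited = ucb_score(parent_visits=10, child_visits=9, prior=0.2, normalized_q=0.5)
print(unvisited, visited)
```

For small visit counts the logarithmic term is close to 0, so $c_1$ dominates. The $c_2$ term only grows the exploration weight once the parent has been visited many thousands of times.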

In [None]:
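def select_child(config: MuZeroConfig, node: Node,
                 min_max_stats: MinMaxStats):
    # Return the (action, child) pair with the highest UCB score. Using key=
    # means Action and Node objects are never compared when two scores tie.
    action, child = max(
        node.children.items(),
        key=lambda item: ucb_score(config, node, item[1], min_max_stats))
    return action, child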

In [None]:
def backpropagate(search_path: List[Node], value: float, to_play: Player,
                  discount: float, min_max_stats: MinMaxStats):
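    # Walk from the leaf back up to the root so each node receives the
    # discounted return from its own position in the path.
    for node in reversed(search_path):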
        node.value_sum += value if node.to_play == to_play else -value
        node.visit_count += 1
        min_max_stats.update(node.value())

        value = node.reward + discount * value
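Here is a worked example of backpropagation with rewards and discounting. It is a self-contained sketch for a single-player setting: every node has the same `to_play`, so there are no sign flips. The value travels from the leaf towards the root, so the path is traversed in reverse. Each node first receives the current return estimate. That estimate then becomes the parent's estimate via `value = node.reward + discount * value`. The minimal `PathNode` class is hypothetical.

```python
from typing import List


class PathNode:
    def __init__(self, reward: float = 0.0):
        self.reward = reward      # reward predicted on the transition into this node
        self.value_sum = 0.0
        self.visit_count = 0


def backpropagate_single_player(search_path: List[PathNode], value: float,
                                discount: float) -> None:
    for node in reversed(search_path):   # leaf first, root last
        node.value_sum += value
        node.visit_count += 1
        value = node.reward + discount * value


root, child, leaf = PathNode(), PathNode(reward=1.0), PathNode(reward=0.5)
backpropagate_single_player([root, child, leaf], value=2.0, discount=0.9)
print(leaf.value_sum, child.value_sum, root.value_sum)
# leaf gets 2.0, child gets 0.5 + 0.9 * 2.0 = 2.3, root gets 1.0 + 0.9 * 2.3 = 3.07
```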

In [None]:
def run_mcts(config: MuZeroConfig, root: Node, action_history: ActionHistory,
             network: Network):
    min_max_stats = MinMaxStats(config.known_bounds)

    for _ in range(config.num_simulations):
        history = action_history.clone()
        node = root
        search_path = [node]

        while node.expanded():
            action, node = select_child(config, node, min_max_stats)
            history.add_action(action)
            search_path.append(node)

        # Inside the search tree we use the dynamics function to obtain the next
        # hidden state given an action and the previous hidden state.
        parent = search_path[-2]
        network_output = network.recurrent_inference(parent.hidden_state,
                                                     history.last_action())
        expand_node(node, history.to_play(), history.action_space(), network_output)

        backpropagate(search_path, network_output.value, history.to_play(),
                      config.discount, min_max_stats)

In [None]:
run_mcts(config, root, game.action_history(), network)

After <code>num_simulations</code> passes through the tree, the search stops. An action is then chosen based on the number of times each child of the root node has been visited.

For the first 30 moves, the temperature $T$ of the softmax is set to 1, so the probability of selecting each action is proportional to the number of times it has been visited. After the first 30 moves, $T = 0$, and the action with the maximum number of visits is selected.

$$p_a = \frac{N(a)^{1/T}}{\sum_b N(b)^{1/T}}$$

The number of visits may seem a strange metric on which to select the final action, but it isn't really. The UCB selection criterion within MCTS is designed to spend more and more simulations on actions it believes are truly high-value, once it has sufficiently explored the alternatives early in the search.

The chosen action is then applied to the real environment, and the relevant values are appended to the following lists in the game object:
- game.rewards — a list of true rewards received at each turn of the game
- game.history — a list of actions taken at each turn of the game
- game.child_visits — a list of action probability distributions from the root node at each turn of the game
- game.root_values — a list of values of the root node at each turn of the game

All of the game data (rewards, history, child_visits, root_values) is saved to the replay buffer and the actor is then free to start a new game.

In [None]:
def select_action(config: MuZeroConfig, num_moves: int, node: Node,
                  network: Network):
    visit_counts = [
        (child.visit_count, action) for action, child in node.children.items()
    ]
    t = config.visit_softmax_temperature_fn(
        num_moves=num_moves, training_steps=network.training_steps())
    _, action = softmax_sample(visit_counts, t)
    return action

def visit_softmax_temperature(num_moves, training_steps):
    if num_moves < 30:
        return 1.0
    else:
        return 0.0  # Play according to the max.
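`softmax_sample` is not defined in this lab. The sketch below implements the visit-count formula above. It treats temperature 0 as the limit $T \to 0$, which means picking the most-visited action (the first one, on ties). The call shape matches how `select_action` uses it: a list of `(visit_count, action)` pairs and a temperature, returning one pair. The helper `visit_probabilities` and the optional `rng` argument are assumptions.

```python
import random
from typing import Hashable, List, Optional, Tuple


def visit_probabilities(visit_counts: List[Tuple[int, Hashable]],
                        temperature: float) -> List[float]:
    # p_a = N(a)^(1/T) / sum_b N(b)^(1/T); assumes at least one nonzero count.
    weights = [n ** (1.0 / temperature) for n, _ in visit_counts]
    total = sum(weights)
    return [w / total for w in weights]


def softmax_sample(visit_counts: List[Tuple[int, Hashable]], temperature: float,
                   rng: Optional[random.Random] = None) -> Tuple[int, Hashable]:
    if temperature == 0:
        return max(visit_counts, key=lambda pair: pair[0])  # greedy: most visits
    probs = visit_probabilities(visit_counts, temperature)
    return (rng or random).choices(visit_counts, weights=probs, k=1)[0]


counts = [(5, "a"), (20, "b"), (1, "c")]
print(softmax_sample(counts, temperature=0))       # (20, 'b')
print(visit_probabilities(counts, temperature=1))  # proportional to 5 : 20 : 1
```

With $T = 1$ the probabilities are exactly the normalised visit counts. Values of $T$ between 0 and 1 sharpen the distribution towards the most-visited action.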

### The training

The <code>train_network</code> function first creates a new <code>Network</code> object, which stores randomly initialised instances of MuZero's three neural networks. It then sets the learning rate to decay with the number of training steps completed. It also creates the optimiser (gradient descent with momentum) that computes the magnitude and direction of the weight updates at each training step.

The last part of the function loops over <code>training_steps</code> (1,000,000 in the paper, for chess). At each step, it samples a batch of positions from the replay buffer and uses it to update the networks. The network is saved to storage every <code>checkpoint_interval</code> batches (1,000).

There are therefore two final parts to cover: how MuZero creates a batch of training data, and how it uses that batch to update the weights of the three neural networks.

In [None]:
def train_network(config: MuZeroConfig, storage: SharedStorage,
                  replay_buffer: ReplayBuffer):
    network = Network()
    learning_rate = config.lr_init * config.lr_decay_rate**(
        tf.train.get_global_step() / config.lr_decay_steps)
    optimizer = tf.train.MomentumOptimizer(learning_rate, config.momentum)

    for i in range(config.training_steps):
        if i % config.checkpoint_interval == 0:
            storage.save_network(i, network)
        batch = replay_buffer.sample_batch(config.num_unroll_steps, config.td_steps)
        update_weights(optimizer, network, batch, config.weight_decay)
    storage.save_network(config.training_steps, network)
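The learning-rate line implements exponential decay: the rate is multiplied by `lr_decay_rate` every `lr_decay_steps` training steps. The exponent is a real number, so the decay is smooth rather than stepwise. Here is a plain-Python sketch; the numbers are hypothetical, not the paper's settings.

```python
def learning_rate(step: int, lr_init: float, lr_decay_rate: float,
                  lr_decay_steps: float) -> float:
    # Same form as config.lr_init * config.lr_decay_rate ** (step / config.lr_decay_steps)
    return lr_init * lr_decay_rate ** (step / lr_decay_steps)


# Hypothetical settings: start at 0.05 and shrink by 10x every 350k steps.
for step in [0, 175_000, 350_000, 700_000]:
    print(step, learning_rate(step, lr_init=0.05, lr_decay_rate=0.1, lr_decay_steps=350_000))
```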

### The MuZero loss function

The loss function of MuZero is shown as:

$$ \mathcal{L}_t(\theta) = \sum_{k=0}^{K} \left[ \mathcal{L}^r(u_{t+k}, r_t^k) + \mathcal{L}^v(z_{t+k}, v_t^k) + \mathcal{L}^p(\pi_{t+k}, p_t^k) \right] + c\|\theta\|^2 $$

$K$ is the <code>num_unroll_steps</code> variable. There are three losses we are trying to minimise:
1. The difference between the predicted reward $k$ steps ahead of turn $t$ ($r$) and the actual reward ($u$)
2. The difference between the predicted value $k$ steps ahead of turn $t$ ($v$) and the TD target value ($z$)
3. The difference between the predicted policy $k$ steps ahead of turn $t$ ($p$) and the MCTS policy ($\pi$)

These losses are summed over the rollout to generate the loss for a given position in the batch. There is also a regularisation term to penalise large weights in the network.
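The NumPy sketch below computes the loss for one position. It assumes MSE for $\mathcal{L}^r$ and $\mathcal{L}^v$; the pseudocode notes that board games use MSE, while Atari uses a cross-entropy over categorical values. It uses softmax cross-entropy for $\mathcal{L}^p$. The `gradient_scale` factors in `update_weights` only rescale gradients, not the forward value, so they are omitted here. The function and argument names are hypothetical.

```python
import numpy as np


def log_softmax(logits: np.ndarray) -> np.ndarray:
    shifted = logits - logits.max()
    return shifted - np.log(np.exp(shifted).sum())


def muzero_loss(predictions, targets, weights, c: float) -> float:
    """predictions: [(value, reward, policy_logits)] for k = 0..K;
    targets: [(z, u, pi)] for the same steps; weights: list of arrays."""
    total = 0.0
    for (v, r, logits), (z, u, pi) in zip(predictions, targets):
        total += (v - z) ** 2                                                # L^v
        total += (r - u) ** 2                                                # L^r
        total += -np.sum(np.asarray(pi) * log_softmax(np.asarray(logits)))   # L^p
    total += c * sum(np.sum(w ** 2) for w in weights)                       # c * ||theta||^2
    return float(total)


preds = [(0.5, 0.0, [0.0, 0.0])]   # K = 0: only the initial step
targs = [(1.0, 0.0, [1.0, 0.0])]
print(muzero_loss(preds, targs, weights=[np.array([1.0, 2.0])], c=0.01))
# 0.25 (value) + 0.0 (reward) + log 2 (policy) + 0.05 (L2), about 0.9931
```

Note that `tf.nn.l2_loss` computes half the sum of squares. The pseudocode's regulariser therefore differs from $c\|\theta\|^2$ by a factor of 2, which is absorbed into `weight_decay`.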

In [None]:
def scale_gradient(tensor, scale):
    # Leaves the forward value unchanged but multiplies the gradient by `scale`.
    # (The paper's pseudocode writes this as tf.scale_gradient, which is not a
    # TensorFlow API.)
    return tensor * scale + tf.stop_gradient(tensor) * (1 - scale)


def update_weights(optimizer: tf.train.Optimizer, network: Network, batch,
                   weight_decay: float):
    loss = 0
    for image, actions, targets in batch:
        # Initial step, from the real observation.
        value, reward, policy_logits, hidden_state = network.initial_inference(
            image)
        predictions = [(1.0, value, reward, policy_logits)]

        # Recurrent steps, from action and previous hidden state.
        for action in actions:
            value, reward, policy_logits, hidden_state = network.recurrent_inference(
                hidden_state, action)
            predictions.append((1.0 / len(actions), value, reward, policy_logits))

            # Halve the gradient flowing back into the dynamics function.
            hidden_state = scale_gradient(hidden_state, 0.5)

        for prediction, target in zip(predictions, targets):
            gradient_scale, value, reward, policy_logits = prediction
            target_value, target_reward, target_policy = target

            # scalar_loss comes from the paper's pseudocode (MSE for board games).
            l = (
                scalar_loss(value, target_value) +
                scalar_loss(reward, target_reward) +
                tf.nn.softmax_cross_entropy_with_logits(
                    logits=policy_logits, labels=target_policy))

            loss += scale_gradient(l, gradient_scale)

    # L2 regularisation on all network weights.
    for weights in network.get_weights():
        loss += weight_decay * tf.nn.l2_loss(weights)

    optimizer.minimize(loss)

## In lab practice (MuZero)

Because MuZero needs a lot of GPU resources to run, we won't require you to run it. If you want to try it, download the sample code from
[this GitHub repository](https://github.com/werner-duvaud/muzero-general) and open it in Visual Studio Code.
You can also run it in Google Colab
by following [these instructions](https://stackoverflow.com/questions/51194303/how-to-run-a-python-script-in-a-py-file-from-a-google-colab-notebook).

The code in this repository follows the same concepts as the pseudocode, but it is written to run end to end in Gym environments.
The code uses multiple GPUs, and we could not modify it to run on a single GPU in the time available, so please do not run it on the shared server.

The games that you can run with this code (without modification) are:
- Cartpole (Tested with the fully connected network)
- Lunar Lander (Tested in deterministic mode with the fully connected network)
- Gridworld (Tested with the fully connected network)
- Tic-tac-toe (Tested with the fully connected network and the residual network)
- Connect4 (Slightly tested with the residual network)
- Gomoku
- Twenty-One / Blackjack (Tested with the residual network)
- Atari Breakout

**Note**: you may need to install the following extra requirements:
- gym[classic_control]
- nevergrad
- numpy
- ray
- seaborn
- tensorboard
- torch
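Before launching, you can check which of these packages your environment can import. Here is a small sketch using `importlib.util.find_spec`. It assumes the import names match the package names listed above; `gym[classic_control]` is imported as `gym`.

```python
import importlib.util
from typing import List

REQUIRED = ["gym", "nevergrad", "numpy", "ray", "seaborn", "tensorboard", "torch"]


def missing_modules(names: List[str]) -> List[str]:
    # find_spec returns None when a top-level module cannot be found.
    return [name for name in names if importlib.util.find_spec(name) is None]


print(missing_modules(REQUIRED))  # lists whichever packages still need installing
```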

### Running MuZero

To run MuZero, use:

In [None]:
!python muzero.py

The script then shows an interactive command-line menu. Enter a number to select a game, then choose whether to train a model, load pretrained weights, or play the game.

#### Running TensorBoard to show the results

In [None]:
!tensorboard --logdir ./results

<img src="img/tensorboard.png" title="" style="width: 800px;" />

#### Checking CPU/GPU/RAM usage

While the program is running, you can monitor CPU, GPU, and RAM usage in the Ray dashboard at http://localhost:8265.

<img src="img/ray_dashboard.png" title="" style="width: 800px;" />

## In lab practice 2 (AlphaZero)

The AlphaZero implementation is more lightweight and friendlier to try! It should run on an ordinary PC without problems.

Please download the code from the [GitHub repository](https://github.com/suragnair/alpha-zero-general).

Before running, install a library named <code>coloredlogs</code>, which displays colored log text in the terminal.

In [None]:
!pip install coloredlogs

### Running AlphaZero

As with MuZero, you can run AlphaZero on the **Othello** game with the following command:

In [None]:
!python main.py

To change the selected GPU, for example from device 0 to device 1, open the file <code>NNet.py</code> (in <code>othello/pytorch/</code>). Replace every <code>xxx.cuda()</code> call with <code>xxx.to("cuda:1")</code>.

Here is the complete file after changing the automatic <code>.cuda()</code> calls to <code>"cuda:1"</code>:

In [None]:
import os
import sys
import time

import numpy as np
from tqdm import tqdm

sys.path.append('../../')
from utils import *
from NeuralNet import NeuralNet

import torch
import torch.optim as optim

from .OthelloNNet import OthelloNNet as onnet

args = dotdict({
    'lr': 0.001,
    'dropout': 0.3,
    'epochs': 10,
    'batch_size': 64,
    'cuda': torch.cuda.is_available(),
    'num_channels': 512,
})


class NNetWrapper(NeuralNet):
    def __init__(self, game):
        self.nnet = onnet(game, args)
        self.board_x, self.board_y = game.getBoardSize()
        self.action_size = game.getActionSize()

        if args.cuda:
            # self.nnet.cuda()
            self.nnet.to("cuda:1")

    def train(self, examples):
        """
        examples: list of examples, each example is of form (board, pi, v)
        """
        optimizer = optim.Adam(self.nnet.parameters())

        for epoch in range(args.epochs):
            print('EPOCH ::: ' + str(epoch + 1))
            self.nnet.train()
            pi_losses = AverageMeter()
            v_losses = AverageMeter()

            batch_count = int(len(examples) / args.batch_size)

            t = tqdm(range(batch_count), desc='Training Net')
            for _ in t:
                sample_ids = np.random.randint(len(examples), size=args.batch_size)
                boards, pis, vs = list(zip(*[examples[i] for i in sample_ids]))
                boards = torch.FloatTensor(np.array(boards).astype(np.float64))
                target_pis = torch.FloatTensor(np.array(pis))
                target_vs = torch.FloatTensor(np.array(vs).astype(np.float64))

                # predict
                if args.cuda:
                    # boards, target_pis, target_vs = boards.contiguous().cuda(), target_pis.contiguous().cuda(), target_vs.contiguous().cuda()
                    boards, target_pis, target_vs = boards.contiguous().to("cuda:1"), target_pis.contiguous().to("cuda:1"), target_vs.contiguous().to("cuda:1")

                # compute output
                out_pi, out_v = self.nnet(boards)
                l_pi = self.loss_pi(target_pis, out_pi)
                l_v = self.loss_v(target_vs, out_v)
                total_loss = l_pi + l_v

                # record loss
                pi_losses.update(l_pi.item(), boards.size(0))
                v_losses.update(l_v.item(), boards.size(0))
                t.set_postfix(Loss_pi=pi_losses, Loss_v=v_losses)

                # compute gradient and do SGD step
                optimizer.zero_grad()
                total_loss.backward()
                optimizer.step()

    def predict(self, board):
        """
        board: np array with board
        """
        # timing
        start = time.time()

        # preparing input
        board = torch.FloatTensor(board.astype(np.float64))
        # if args.cuda: board = board.contiguous().cuda()
        if args.cuda: board = board.contiguous().to("cuda:1")
        board = board.view(1, self.board_x, self.board_y)
        self.nnet.eval()
        with torch.no_grad():
            pi, v = self.nnet(board)

        # print('PREDICTION TIME TAKEN : {0:03f}'.format(time.time()-start))
        return torch.exp(pi).data.cpu().numpy()[0], v.data.cpu().numpy()[0]

    def loss_pi(self, targets, outputs):
        return -torch.sum(targets * outputs) / targets.size()[0]

    def loss_v(self, targets, outputs):
        return torch.sum((targets - outputs.view(-1)) ** 2) / targets.size()[0]

    def save_checkpoint(self, folder='checkpoint', filename='checkpoint.pth.tar'):
        filepath = os.path.join(folder, filename)
        if not os.path.exists(folder):
            print("Checkpoint Directory does not exist! Making directory {}".format(folder))
            os.mkdir(folder)
        else:
            print("Checkpoint Directory exists! ")
        torch.save({
            'state_dict': self.nnet.state_dict(),
        }, filepath)

    def load_checkpoint(self, folder='checkpoint', filename='checkpoint.pth.tar'):
        # https://github.com/pytorch/examples/blob/master/imagenet/main.py#L98
        filepath = os.path.join(folder, filename)
        if not os.path.exists(filepath):
            # Raising a plain string is a TypeError in Python 3; raise an exception instead.
            raise FileNotFoundError("No model in path {}".format(filepath))
        map_location = None if args.cuda else 'cpu'
        checkpoint = torch.load(filepath, map_location=map_location)
        self.nnet.load_state_dict(checkpoint['state_dict'])
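There is an alternative to editing every `.cuda()` call as described above. This is a different technique: restrict which GPUs CUDA exposes to the process with the `CUDA_VISIBLE_DEVICES` environment variable. It must be set before CUDA is initialised, so the safest place is before importing `torch`. Physical GPU 1 then appears to PyTorch as `cuda:0`, and the unmodified `.cuda()` calls land on it.

```python
import os

# Must run before torch (or any CUDA library) initialises CUDA.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"   # expose only physical GPU 1

# import torch   # torch.cuda.device_count() would now report a single device
```

From a shell, the equivalent is to prefix the command, for example `CUDA_VISIBLE_DEVICES=1 python main.py`.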


To change the game, edit the import lines marked <code># Change here</code> in <code>main.py</code> (shown below). In the current version, the PyTorch implementation supports only one game, Othello. If you have Keras or TensorFlow, you can also try Connect4 and Gobang (a five-in-a-row game, also known as Gomoku).

In [None]:
import logging

import coloredlogs

from Coach import Coach
from othello.OthelloGame import OthelloGame as Game        # Change here
from othello.pytorch.NNet import NNetWrapper as nn         # Change here
from utils import *

### To do on your own

- Add a self-play simulator that runs a fixed model for n loops.
- Try a different game.