# AlphaGo Zero
---
For this exercise we will implement AlphaZero, a more general form of AlphaGo Zero, and train it to play Connect 4.
We will first play against AlphaZero before it has trained at all. We will then let it train for a few cycles
and play it again to see how it has improved.

# 1. AlphaZero Configuration
---

First, let's make a class that keeps all of our hyperparameters in one place. This has been done for you.

In [0]:
 
class AlphaZeroConfig(object):
    """
    This holds the configuration parameters
    """
    def __init__(self):
        # Self-Play ==
        self.max_moves = 42
        self.num_simulations = 25

        # Root prior exploration noise.
        self.root_dirichlet_alpha = 0.3         # 0.3 for chess, 0.03 for Go and 0.15 for shogi.
        self.root_exploration_fraction = 0.25

        # UCB formula
        self.pb_c_base = 19652
        self.pb_c_init = 1.25

        # Training ==
        self.self_play_games = 30       # number of self-play games per cycle
        self.training_steps = 40        # number of passes of gradient descent over the data
        self.batch_size = 50            # size of training batch
        self.cycles = 5                 # number of policy iterations to do

        self.weight_decay = 1e-4
        self.momentum = 0.9
        self.learning_rate = 5e-4
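The UCB constants `pb_c_base` and `pb_c_init` are only used inside `ucb_score` in section 4. To get a feel for their scale, here is a minimal sketch of the same formula written as a free function on plain numbers. `ucb` is a hypothetical helper and is not part of the notebook.

```python
import math


def ucb(parent_visits: int, child_visits: int, prior: float, child_value: float,
        pb_c_base: float = 19652, pb_c_init: float = 1.25) -> float:
    """Same arithmetic as ucb_score(): the child's value plus a prior-weighted exploration bonus."""
    # The log term grows very slowly; with 25 simulations pb_c stays close to pb_c_init.
    pb_c = math.log((parent_visits + pb_c_base + 1) / pb_c_base) + pb_c_init
    # The bonus grows with the parent's visits and shrinks with this child's visits.
    pb_c *= math.sqrt(parent_visits) / (child_visits + 1)
    return pb_c * prior + child_value


print(round(ucb(parent_visits=24, child_visits=3, prior=0.2, child_value=0.5), 4))  # 0.8065
```

If the parent has not been visited yet, `math.sqrt(0)` makes the bonus vanish. The first selection at a node is therefore driven by child values alone.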

# 2. Game Definition
---
Next, let's set up the Connect 4 game. This part has been done for you. 
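The `Game` class below stores only the list of columns that have been played. It rebuilds the board on demand in `make_image`. The minimal sketch below shows that gravity rule. It uses a hypothetical helper `drop_pieces` that mirrors the first two planes of `make_image`; the real method also adds a side-to-move plane. Unlike `make_image`, which silently skips a move into a full column, this sketch raises an error.

```python
import numpy as np


def drop_pieces(history, num_rows=6, num_cols=7):
    """Replay a list of column moves; player 0 moves on even turns, player 1 on odd turns."""
    planes = np.zeros((2, num_rows, num_cols), dtype=np.float64)
    for turn, col in enumerate(history):
        # Row 0 is the top of the board, so scan from the bottom row upwards.
        for row in reversed(range(num_rows)):
            if planes[0, row, col] == 0 and planes[1, row, col] == 0:
                planes[turn % 2, row, col] = 1
                break
        else:
            # The loop found no empty cell in this column.
            raise ValueError(f"column {col} is full")
    return planes


board = drop_pieces([3, 3, 4])
print(board[0, 5, 3], board[1, 4, 3], board[0, 5, 4])  # 1.0 1.0 1.0
```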

In [0]:
import math
import numpy
from typing import List
import numpy as np
from torch.utils.data import TensorDataset
import torch
import torch.nn as nn

if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

class Node(object):

    def __init__(self, prior: float):  # prior = how good the network thought it would be
        self.visit_count = 0
        self.to_play = -1
        self.prior = prior
        self.value_sum = 0
        self.children = {}

    def expanded(self):
        return len(self.children) > 0

    def value(self):
        if self.visit_count == 0:
            return 0
        return self.value_sum / self.visit_count

class Game(object):

    def __init__(self, history=None):
        # Connect 4 specific ===
        self._num_rows = 6
        self._num_cols = 7

        self._winner = None

        # Masks to "convolve" over board and detect a winner
        self._win_masks = []
        # Horizontal wins
        for i in range(4):
            mask = np.zeros((4, 4), dtype=bool)
            mask[i, :] = True
            self._win_masks.append(mask)
        # Vertical wins
        for j in range(4):
            mask = np.zeros((4, 4), dtype=bool)
            mask[:, j] = True
            self._win_masks.append(mask)
        # Diagonal wins
        down = np.zeros((4, 4), dtype=bool)
        for i, j in zip(range(4), range(4)):
            down[i, j] = True
        self._win_masks.append(down)
        up = np.zeros((4, 4), dtype=bool)
        for i, j in zip(reversed(range(4)), range(4)):
            up[i, j] = True
        self._win_masks.append(up)

        # All games will have these ===
        self.history = history or []
        self.child_visits = []
        self.num_actions = self._num_cols  # 7 for connect 4, 512 for chess/shogi, and 722 for Go.

    def terminal(self):
        """
        Returns True if the game is finished: a player has four in a row or the board is full.
        """
        if self._winner is not None:
            return True

        image = self.make_image(len(self.history))
        # check for wins from the bottom of the board up. Wins are more likely to appear there.
        for i in reversed(range(self._num_rows - 3)):
            for j in range(self._num_cols - 3):
                for mask in self._win_masks:
                    for player in range(2):
                        test = image[player, i:i + 4, j:j + 4][mask]
                        if np.all(test == 1):
                            self._winner = player
                            return True

        # Checked after the win test so that a winning 42nd move is not scored as a draw.
        return len(self.history) == self._num_rows * self._num_cols

    def terminal_value(self, to_play):
        """
        The result of the game from the player that's going to_play? If player 1
        won then and to_play is 1 then return 1 if to_play is 2 then return -1?
        """

        # call just to ensure that state is set
        self.terminal()

        if self._winner is None and len(self.history) == 42:
            return 0
        if to_play == self._winner:
            return 1
        else:
            return -1

    def legal_actions(self):
        image = self.make_image(len(self.history))
        return [j for j in range(self._num_cols) if image[0, 0, j] == 0 and image[1, 0, j] == 0]

    def clone(self):
        return Game(list(self.history))

    def apply(self, action: int):
        self.history.append(action)

    def store_search_statistics(self, root: Node):
        sum_visits = sum(child.visit_count for child in iter(root.children.values()))
        self.child_visits.append([
            root.children[a].visit_count / sum_visits if a in root.children else 0
            for a in range(self.num_actions)
        ])

    def make_image(self, state_index: int):
        """
        Returns the board as three 6x7 planes (player 0 stones, player 1 stones,
        side to move) after the moves history[:state_index + 1] have been played.
        """
        player_0 = np.zeros((self._num_rows, self._num_cols), dtype=np.float64)
        player_1 = np.zeros((self._num_rows, self._num_cols), dtype=np.float64)
        for move_i, move in enumerate(self.history[:state_index+1]):
            for row in reversed(range(self._num_rows)):
                if player_0[row, move] == 0 and player_1[row, move] == 0:
                    if move_i % 2 == 0:
                        player_0[row, move] = 1
                    if move_i % 2 == 1:
                        player_1[row, move] = 1
                    break

        to_play = (state_index + 1) % 2 * np.ones((self._num_rows, self._num_cols), dtype=np.float64)

        return np.array([player_0, player_1, to_play], dtype=np.float64)

    def make_target(self, state_index: int):
        """
        Returns the neural network targets (value, policy) for position state_index,
        i.e. what the network should be predicting given make_image(state_index).
        """
        return (self.terminal_value(state_index % 2),  # state_index % 2 is the player whose turn it was at move state_index
                self.child_visits[state_index])

    def to_play(self):
        """
        Return the player that is about to play
        """
        return len(self.history) % 2

    def __str__(self):
        board_state = self.make_image(len(self.history))

        out = ""
        for i in range(self._num_rows):
            out += f"{i}|"
            for j in range(self._num_cols):
                if board_state[0, i, j] == 1:
                    out += " ○ "
                elif board_state[1, i, j] == 1:
                    out += " ● "
                else:
                    out += "   "
            out += "|\n"

        out += "  "
        for j in range(self._num_cols):
            out += f" \u0305{j} "
        return out
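`terminal()` slides a 4×4 window over the board. At each position it tests the ten boolean masks (four rows, four columns, two diagonals) against each player's plane. The sketch below applies the same idea to a bare 6×7 array; `build_win_masks` and `has_four` are hypothetical names. It uses the builtin `bool` and `np.all` because the `np.bool` and `np.alltrue` aliases have been removed from recent NumPy releases.

```python
import numpy as np


def build_win_masks():
    """The ten 4x4 boolean masks: four rows, four columns and two diagonals."""
    masks = []
    for k in range(4):
        row = np.zeros((4, 4), dtype=bool)
        row[k, :] = True
        col = np.zeros((4, 4), dtype=bool)
        col[:, k] = True
        masks.extend([row, col])
    masks.append(np.eye(4, dtype=bool))             # top-left to bottom-right
    masks.append(np.fliplr(np.eye(4, dtype=bool)))  # top-right to bottom-left
    return masks


def has_four(plane, masks):
    """True if one player's 0/1 plane (rows x cols) contains four in a row."""
    rows, cols = plane.shape
    for i in range(rows - 3):
        for j in range(cols - 3):
            window = plane[i:i + 4, j:j + 4]
            if any(np.all(window[m] == 1) for m in masks):
                return True
    return False


masks = build_win_masks()
plane = np.zeros((6, 7))
plane[5, 1:5] = 1                 # four in a row along the bottom row
print(has_four(plane, masks))     # True
```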

# 3. One Network Two Heads
---
Two heads are smarter than one, right? Let's implement the two-headed neural network. Recall that AlphaGo Zero uses a convolutional
residual network (ResNet) whose features feed two heads: one outputs a probability distribution $\mathbf{p}$ over all possible moves, and the other
outputs a single scalar $v$ representing the value of the current state. 

The neural network is defined as:  
$$f_\theta(s) = (\mathbf{p}, v)$$  

The game board is 6 spaces tall and 7 spaces wide. That means that we have 7 possible moves to make, hint....
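Each head flattens a convolutional feature map before its linear layers. The `in_features` values therefore follow from the board size and the head's channel count. The arithmetic is written out below in plain Python, with the channel counts taken from `value_convolv` and `policy_convolv`:

```python
rows, cols = 6, 7          # Connect 4 board
value_channels = 3         # out_channels of the value head's 1x1 conv
policy_channels = 32       # out_channels of the policy head's 1x1 conv

value_features = value_channels * rows * cols
policy_features = policy_channels * rows * cols
print(value_features, policy_features)  # 126 1344
```

The spatial size stays 6×7 through the trunk because every 3×3 convolution uses `stride=1` and `padding=1`, and the 1×1 head convolutions do not change it either.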

In [0]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        """
        Convolution Block
        """
        self.conv_block = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(num_features=128),
            nn.ReLU(inplace=True)
        )

        """
        ResNet Block
        """
        self.res_block = nn.Sequential(
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(num_features=128),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(num_features=128)
        )

        """
        Value Head
        """
        self.value_convolv = nn.Sequential(
            nn.Conv2d(in_channels=128, out_channels=3, kernel_size=1),
            nn.BatchNorm2d(num_features=3),
            nn.ReLU(inplace=True),
        )

        # TODO: The value head outputs a what?
        self.value_linear = nn.Sequential(
            nn.Linear(in_features=126, out_features=32),
            nn.ReLU(inplace=True),
            nn.Linear(in_features=32, out_features=???),
            nn.Tanh()
        )


        """
        Policy Head
        """
        self.policy_convolv = nn.Sequential(
            nn.Conv2d(in_channels=128, out_channels=32, kernel_size=1),
            nn.BatchNorm2d(num_features=32),
            nn.ReLU(inplace=True)
        )

        # TODO: the policy network outputs what?
        # how many moves can we make? hint in section header
        self.policy_linear = nn.Sequential(
            nn.Linear(6*7*32, ???),
            nn.LogSoftmax(dim=1)
        )

    def inference(self, image):
        """
        Use this for the evaluate() function in the next section, hint...
        The game class has some nifty functions that feed this palatable images...
        """
        image = torch.from_numpy(image)
        image = image.to(torch.float)
        image = image.unsqueeze(0)

        p, v = self.forward(image)

        return float(v.squeeze().detach()), p.squeeze().detach().cpu().numpy()


    def forward(self, x):
        """Perform forward."""

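        # you can change the number of residual blocks here if you'd like;
        # the paper uses 20. Note that the loop below reuses the same
        # self.res_block module, so every block shares one set of weights.
        num_blocks = 10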

        x = x.to(device)
        """
        ResNet
        """
        x = self.conv_block(x)
        for i in range(num_blocks):
            residual = x
            x = self.res_block(x)
            x += residual
            x = nn.functional.relu(x, inplace=True)

        """
        Value Head
        """
        v = self.value_convolv(x)
        v = v.view(-1, 3 * 6 * 7)
        v = self.value_linear(v)

        """
        Policy Head
        """
        p = self.policy_convolv(x)
        p = p.view(-1, 6 * 7 * 32)
        p = self.policy_linear(p)

        return p, v
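The skip connection in `forward` computes `relu(F(x) + x)`, where `F` is the conv/batch-norm stack in `res_block`. Below is a simplified NumPy sketch of the same pattern. Small dense weight matrices stand in for the convolutions and batch norm is omitted; both are simplifications made by this sketch and do not reflect what `Net` does.

```python
import numpy as np


def relu(x):
    return np.maximum(x, 0.0)


def residual_block(x, w1, w2):
    """y = relu(F(x) + x), with F(x) = w2 @ relu(w1 @ x) standing in for conv/BN/ReLU/conv/BN."""
    return relu(w2 @ relu(w1 @ x) + x)


x = np.array([1.0, -2.0, 3.0])
zero = np.zeros((3, 3))
# With F's weights at zero the block reduces to relu(x): the identity path survives.
print(residual_block(x, zero, zero))  # [1. 0. 3.]
```

The identity path is usually credited with keeping deep stacks of such blocks trainable, because the block only has to learn a correction to its input.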
    

# 4. The training pipeline
---
AlphaZero training is split into two independent parts: Network training and self-play data generation.
These two parts only communicate by transferring the latest network checkpoint
from the training to the self-play, and the finished games from the self-play
to the training.
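In this notebook the hand-off is simply the list of games that self-play returns each cycle. The code comments below, however, mention a shared replay buffer. The minimal sketch below shows one possible buffer. `ReplayBuffer`, `save_game` and `sample_positions` are hypothetical names: the buffer keeps only the most recent games and samples training positions uniformly from them.

```python
import random
from collections import deque
from types import SimpleNamespace


class ReplayBuffer:
    """Fixed-size window of finished games; the oldest games are dropped first."""

    def __init__(self, window_size: int, seed=None):
        self.games = deque(maxlen=window_size)
        self.rng = random.Random(seed)

    def save_game(self, game):
        self.games.append(game)

    def sample_positions(self, batch_size: int):
        # Each entry is (game, move_index); any object with a `history` list works.
        positions = [(g, i) for g in self.games for i in range(len(g.history))]
        return self.rng.sample(positions, min(batch_size, len(positions)))


buffer = ReplayBuffer(window_size=2, seed=0)
for moves in ([3, 3], [0, 1, 2], [6]):
    buffer.save_game(SimpleNamespace(history=moves))
print(len(buffer.games))                 # 2: the first game was evicted
print(len(buffer.sample_positions(10)))  # 4: only 3 + 1 positions remain
```

A sliding window like this lets training draw on games from several recent networks, not just the latest batch.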

In [0]:
def alphazero(config: AlphaZeroConfig, network: Net):

    # TODO: Here we'll have to do something! Remember the basic steps of the
    # algorithm:
    #     1: Create training data using the current neural network; this evaluates and improves our policy
    #     2: Improve our policy by training our neural network
    #     3: Repeat!
    #
    # We have provided helper functions that may be of use to you
    for i in range(config.cycles):
        print(f"self play {i} of {config.cycles}")
        network.eval()
        games = ??? # TODO 
        print(f"train network {i} of {config.cycles}")
        network.train()
        train_network(???, ???) # TODO

    return network

# Each self-play job is independent of all others; it takes the latest network
# snapshot, produces a game and makes it available to the training job by
# writing it to a shared replay buffer.
def run_selfplay(config: AlphaZeroConfig, network: Net):
    games = []
    for i in range(config.self_play_games):  
        if i % 10 == 0:
            print(f"game {i} of {config.self_play_games}")
        game = play_game(config, network)
        games.append(game)
    return games


# Each game is produced by starting at the initial board position, then
# repeatedly executing a Monte Carlo Tree Search to generate moves until the end
# of the game is reached.
def play_game(config: AlphaZeroConfig, network: Net):
    game = Game()
    while not game.terminal() and len(game.history) < config.max_moves:
        action, root = run_mcts(config, game, network)
        game.apply(action)
        game.store_search_statistics(root)
    return game


# Core Monte Carlo Tree Search algorithm.
# To decide on an action, we run N simulations, always starting at the root of
# the search tree and traversing the tree according to the UCB formula until we
# reach a leaf node.
def run_mcts(config: AlphaZeroConfig, game: Game, network: Net):
    root = Node(0)
    # Populate the root's child nodes, i.e. the states that the actions
    # available at this state would take you to
    evaluate(root, game, network)
    add_exploration_noise(config, root)

    for i in range(config.num_simulations):
        node = root
        scratch_game = game.clone()
        search_path = [node]

        while node.expanded():
            # Here we take one step down our search tree towards a win or loss. Note
            # that we are resetting the node variable here to be the state that our
            # game reaches given the action we took.
            #
            # On the first simulation none of the root's children are expanded yet,
            # so we only take one step before backpropagating back up the tree.
            action, node = select_child(config, node)
            scratch_game.apply(action)
            search_path.append(node)

        value = evaluate(node, scratch_game, network)
        backpropagate(search_path, value, scratch_game.to_play())
    return select_action(config, game, root), root


def select_action(config: AlphaZeroConfig, game: Game, root: Node):
    # This is where we would do a softmax sample for the first 30 moves, then
    # turn down the temperature; our game is simple enough that we will just
    # always pick the most-visited move.
    visit_counts = [(child.visit_count, action)
                    for action, child in iter(root.children.items())]
    _, action = max(visit_counts)
    return action


# Select the child with the highest UCB score.
def select_child(config: AlphaZeroConfig, node: Node):
    """
    Return the child node, i.e. action to take, that UCB likes best
    """
    _, action, child = max((ucb_score(config, node, child), action, child)
                           for action, child in iter(node.children.items()))
    return action, child


# The score for a node is based on its value, plus an exploration bonus based on
# the prior.
def ucb_score(config: AlphaZeroConfig, parent: Node, child: Node):
    pb_c = math.log((parent.visit_count + config.pb_c_base + 1) /
                    config.pb_c_base) + config.pb_c_init
    pb_c *= math.sqrt(parent.visit_count) / (child.visit_count + 1)

    prior_score = pb_c * child.prior
    value_score = child.value()
    return prior_score + value_score


# We use the neural network to obtain a value and policy prediction.
def evaluate(node: Node, game: Game, network: Net):
    # TODO:
    # Here we need to populate the input node's children dict. Note that we
    # only want to create nodes for legal next actions. The Node class takes
    # a prior on construction: Node(prior_probability). Where will we get these
    # priors? A network is passed in, that might be useful! Note that the
    # network returns a value and policy log-probabilities, so something needs
    # to be done to convert them to proper probabilities (maybe softmax, ok
    # definitely softmax). This function is also supposed to return the value,
    # which you might also be able to get from the neural network.
    #
    # It would probably be helpful to know that the policy returned from the NN
    # is in the order of the columns of our connect 4 board,
    # i.e. policy_logits[0] reflects how much our network likes column 0.
    
    value, policy_logits = ???  # TODO: take a look back at the NN for a hint;
                                # the game class may also have some useful functions for this

    # Expand the node.
    node.to_play = game.to_play() 
    policy = {a: math.exp(policy_logits[a]) for a in game.legal_actions()}
    policy_sum = sum(iter(policy.values()))
    for action, p in iter(policy.items()): 
        # this is just softmax, notice the math.exp 3 lines up
        node.children[action] = Node(p / policy_sum) 
    return ??? # TODO: what are we returning from this?


# At the end of a simulation, we propagate the evaluation all the way up the
# tree to the root.
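def backpropagate(search_path: List[Node], value: float, to_play):
    for node in search_path:
        # Values lie in [-1, 1] (tanh output and terminal_value), so the
        # other player's view of the same result is -value
        node.value_sum += value if node.to_play == to_play else -value
        node.visit_count += 1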


# At the start of each search, we add dirichlet noise to the prior of the root
# to encourage the search to explore new actions.
def add_exploration_noise(config: AlphaZeroConfig, node: Node):
    """
    Modifies the priors stored in nodes children with dirichlet noise whatever
    that is
    """
    actions = node.children.keys()
    noise = numpy.random.gamma(config.root_dirichlet_alpha, 1, len(actions))
    frac = config.root_exploration_fraction
    for a, n in zip(actions, noise):
        node.children[a].prior = node.children[a].prior * (1 - frac) + n * frac

def create_data_loader(config: AlphaZeroConfig, games: List[Game]):
    game_pos = [(g, i) for g in games for i in range(len(g.history))]

    image = np.array([g.make_image(i) for (g, i) in game_pos], dtype=np.float32)
    image = torch.from_numpy(image)
    image = image.to(torch.float)

    policy_target = np.array([g.make_target(i)[1] for (g, i) in game_pos])
    policy_target = torch.from_numpy(policy_target)
    policy_target = policy_target.to(torch.float)

    value_target = np.array([g.make_target(i)[0] for (g, i) in game_pos])
    value_target = torch.from_numpy(value_target)
    value_target = value_target.to(torch.float)

    batch_data = TensorDataset(image, policy_target, value_target)
    return torch.utils.data.DataLoader(dataset=batch_data,
                                        batch_size=config.batch_size,
                                        shuffle=True)
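The comment in `select_action` refers to a softmax sample with a temperature. In that scheme, moves are sampled in proportion to their visit counts raised to the power 1/T; the AlphaGo Zero paper used T = 1 for the first 30 moves and then T → 0. The root exploration noise is drawn from a Dirichlet distribution. The minimal sketch below implements both on plain lists using hypothetical helper names. It builds a Dirichlet(α, …, α) sample from independent Gamma(α, 1) draws divided by their sum.

```python
import numpy as np


def sample_with_temperature(visit_counts, actions, temperature, rng):
    """Sample an action with probability proportional to visit_count ** (1 / temperature)."""
    counts = np.asarray(visit_counts, dtype=np.float64)
    if temperature == 0:
        # Zero temperature: deterministic choice of the most-visited action.
        return actions[int(np.argmax(counts))]
    weights = counts ** (1.0 / temperature)
    return actions[int(rng.choice(len(actions), p=weights / weights.sum()))]


def mix_dirichlet_noise(priors, alpha, fraction, rng):
    """Blend priors with Dirichlet(alpha, ..., alpha) noise; the result still sums to 1."""
    noise = rng.gamma(alpha, 1.0, size=len(priors))
    noise /= noise.sum()  # normalized Gamma draws form a Dirichlet sample
    return [(1 - fraction) * p + fraction * n for p, n in zip(priors, noise)]


rng = np.random.default_rng(0)
print(sample_with_temperature([2, 20, 3], actions=[0, 3, 6], temperature=0, rng=rng))  # 3
noisy = mix_dirichlet_noise([0.5, 0.3, 0.2], alpha=0.3, fraction=0.25, rng=rng)
print(round(sum(noisy), 6))  # 1.0
```

For very small but nonzero temperatures the power can overflow, which is one reason to treat T = 0 as a separate argmax branch.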

# 5. To improve is to change, to be perfect is to change often
--- 
Bonus points if you can tell me who said this without googling it. =)

Recall that the loss function is:
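$$l = (z - v)^2 - \boldsymbol{\pi}^{\top} \log \mathbf{p} + c\lVert\theta\rVert^2$$

The first term is the mean-squared error between the target value $z$ and the predicted value $v$. The second is the cross-entropy between the target policy $\boldsymbol{\pi}$ (the MCTS visit-count distribution) and the predicted policy $\mathbf{p}$. 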

The last term is the L2 regularization. We can take care of this in the optimizer and it is known as weight_decay. Our config class has a field for this parameter. Hint hint....
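For a single position, the first two terms are easy to evaluate by hand. The NumPy sketch below uses made-up numbers; z, v, π and p are illustrative values, not taken from a real game. The regularization term is left to the optimizer, as described above.

```python
import numpy as np


def alphazero_loss(z, v, pi, p):
    """(z - v)^2 - pi . log(p) for one position (the L2 term is left to the optimizer)."""
    value_loss = (z - v) ** 2
    policy_loss = -np.dot(pi, np.log(p))
    return float(value_loss + policy_loss)


pi = np.array([0.0, 1.0, 0.0])   # search visited only the middle move
p = np.array([0.25, 0.5, 0.25])  # network's predicted move probabilities
print(round(alphazero_loss(z=1.0, v=0.5, pi=pi, p=p), 4))  # 0.9431
```

The cross-entropy term is smallest when $\mathbf{p}$ matches $\boldsymbol{\pi}$. If $\mathbf{p}$ can contain exact zeros, `np.log` returns `-inf` and `0 * -inf` gives `nan`. It is safer to work with the network's log-softmax output directly.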

In [0]:
def train_network(config: AlphaZeroConfig, games: List[Game]):
    
    # TODO: add the L2 regularization here. The text above has a hint.
    # Weight decay takes care of our L2 regularization, so it doesn't need to be in the loss function
    optimizer = torch.optim.SGD(
        network.parameters(),
        lr=config.learning_rate,
        momentum=config.momentum,
        ???
    )


    # The data loader reshuffles on every pass, so it only needs to be built once.
    data_loader = create_data_loader(config, games)
    for i in range(config.training_steps):
        update_weights(optimizer, network, data_loader, i)


def update_weights(optimizer, network, data_loader, batch_num):
    # Loop over each subset of data
    for image, policy_target, value_target in data_loader:
        # Zero out the optimizer's gradient buffer
        optimizer.zero_grad()
        
        # TODO: get the policy and the value from the network
        policy, value = ???
        
        # convert data to correct type
        policy = policy.exp()
        value = value.squeeze()

        # TODO: Compute the loss here
        # for the value_target and policy_target add .to(device) to make the tensors happy because 
        # we like happy tensors. The value and the policy do not need it, they are happy tensors already.
        # Also nn.functional has nifty functions for computing loss
        value_loss = ???
        policy_loss = ???

        loss = value_loss + policy_loss

        # Use backpropagation to compute the derivative of the loss with respect to the parameters
        loss.backward()

        # Use the derivative information to update the parameters
        optimizer.step()
    print("Batch: %d    Loss: %f" % (batch_num, loss))


# 6. Challenge Accepted!
---
Making an algorithm that can play against itself is cool and all, but do you think your massive human brain can beat it? To find out, we need to make it possible to play against the singularity. This part has been done for you.

In [0]:
def get_human_action(i: int, game: Game):
    while True:
        print(f"Player {i} choose move please: ", end='')
        human_action = input()
        try:
            if int(human_action) not in game.legal_actions():
                print("illegal action")
            else:
                return int(human_action)
        except ValueError:
            print("illegal action")

def interactive_game(config: AlphaZeroConfig, network: Net, player_1_human=False):
    play_again = 'y'
    while play_again == 'y':
        game = Game()
        print(game)
        while not game.terminal():
            if player_1_human:
                for i in range(2):
                    game.apply(get_human_action(i, game))
                    print(game)
                    if game.terminal():
                        break
            else:
                game.apply(get_human_action(0, game))
                print(game)

                ai_action, _ = run_mcts(config, game, network)
                print(f"ai chooses {ai_action}")
                game.apply(ai_action)
                print(game)
        win_string = {-1: "lost", 1: "won", 0: "tied"}
        print(f"player 0 {win_string[game.terminal_value(0)]}")
        print(f"player 1 {win_string[game.terminal_value(1)]}")
        while True:
            print("Play again? y or n?", end='')
            play_again = input()
            try:
                if play_again != 'y' and play_again != 'n':
                    print("I didn't understand")
                else:
                    break
            except ValueError:
                print("illegal action")

# 7. Let the Singularity Begin!
---
Now try playing the game with the untrained network. If you change the hyperparameters in the config class, make sure to rerun this cell.

In [12]:

print("Device: %s" % device)
network = Net().to(device)
config = AlphaZeroConfig()
interactive_game(config, network)


Device: cuda:0


# 8. I need more data to beat you human
---
Now train the network for a few cycles and observe its change in behavior

In [13]:
# this will train the network on self play games
alphazero(config, network)
interactive_game(config, network)

self play 0 of 5
game 0 of 30
game 10 of 30
game 20 of 30
train network 0 of 5
Batch: 0    Loss: 1.261694
Batch: 1    Loss: 1.116446
Batch: 2    Loss: 1.341696
Batch: 3    Loss: 0.909902
Batch: 4    Loss: 1.124683
Batch: 5    Loss: 1.176385
Batch: 6    Loss: 0.939910
Batch: 7    Loss: 1.149868
Batch: 8    Loss: 0.891490
Batch: 9    Loss: 1.200505
Batch: 10    Loss: 1.013344
Batch: 11    Loss: 0.928823
Batch: 12    Loss: 0.666806
Batch: 13    Loss: 0.627564
Batch: 14    Loss: 0.985499
Batch: 15    Loss: 0.615684
Batch: 16    Loss: 0.554353
Batch: 17    Loss: 0.730892
Batch: 18    Loss: 0.651778
Batch: 19    Loss: 0.558712
Batch: 20    Loss: 0.582786
Batch: 21    Loss: 0.609675
Batch: 22    Loss: 0.466753
Batch: 23    Loss: 0.523876
Batch: 24    Loss: 0.704813
Batch: 25    Loss: 0.547899
Batch: 26    Loss: 0.490743
Batch: 27    Loss: 0.487145
Batch: 28    Loss: 0.389103
Batch: 29    Loss: 0.394802
Batch: 30    Loss: 0.924192
Batch: 31    Loss: 0.693056
Batch: 32    Loss: 0.501567

# Question
---
Each time a new training cycle starts, the loss jumps back up compared with the last gradient-descent steps of the previous cycle. Is this an artifact, or why does it make sense?

# 9. Round Two!
---
OK, our algorithm just turned bits into straight gains. Let's play it again.

In [0]:
interactive_game(config, network)


# 10. Final Thoughts
---
OK, so it's not a Connect 4 champ after a few cycles, but it does get better. Time permitting, run a few more cycles to see how good you can get it.