# Introduction to Reinforcement Learning

Here is a simple implementation of DeepMind's AlphaZero algorithm for Connect 4.

See [this blog post](https://towardsdatascience.com/from-scratch-implementation-of-alphazero-for-connect4-f73d4554002a) for more information.

The AlphaZero architecture is probably overkill for this problem (there are 19 residual blocks, each consisting of two convolutional layers). I recommend stealing these ideas and simplifying the architecture.

## $Q$-Learning

Before we look at the code, let's get a general sense of the algorithm. AlphaZero borrows key ideas from $Q$-learning: both learn, from experience, how good each move is in each state.

See the [Wikipedia article on $Q$-learning](https://en.wikipedia.org/wiki/Q-learning).

I like to think about $Q$-learning through the analogy of a classic text-based adventure game. Here's a good [example](https://www.kongregate.com/games/rete/dont-shit-your-pants).

How does this work? A player can find themselves in many different states. In each state, some *moves* (actions) bring you closer to a desired state (winning the game / minimizing your loss) than others. $Q$-learning learns an action-value function $Q(\textrm{State}, \textrm{Action})$ that estimates how good it is to take a given action in a given state. The **Policy** then follows from $Q$: in each state, take the action with the highest value, $\pi(\textrm{State}) = \arg\max_{\textrm{Action}} Q(\textrm{State}, \textrm{Action})$. A good estimate of $Q$, and with it a good policy, is what we want to learn.
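The update rule behind $Q$-learning fits in a few lines. Below is a minimal tabular sketch on a hypothetical five-room corridor "adventure": the rooms, the +1 reward for reaching the exit, and the hyperparameters `alpha`, `gamma` and `epsilon` are illustrative assumptions, not part of the AlphaZero code that follows. Each step nudges $Q(s, a)$ toward $r + \gamma \max_{a'} Q(s', a')$.

```python
import random

# Hypothetical toy "text adventure": rooms 0..4 along a corridor.
# Walking into room 4 (the exit) gives reward +1 and ends the episode.
N_ROOMS = 5
EXIT = N_ROOMS - 1
ACTIONS = ["left", "right"]


def step(room, action):
    """Deterministic transition: returns (next_room, reward, done)."""
    if action == "left":
        next_room = max(0, room - 1)
    else:
        next_room = min(EXIT, room + 1)
    done = next_room == EXIT
    return next_room, (1.0 if done else 0.0), done


def greedy(Q, room, rng):
    """Action with the highest Q value, breaking ties at random."""
    best = max(Q[room].values())
    return rng.choice([a for a in ACTIONS if Q[room][a] == best])


def q_learning(episodes=500, alpha=0.5, gamma=0.9, epsilon=0.1, seed=0):
    rng = random.Random(seed)
    # Q[room][action]: estimated discounted return of taking `action` in `room`
    Q = {room: {a: 0.0 for a in ACTIONS} for room in range(N_ROOMS)}
    for _ in range(episodes):
        room, done = 0, False
        while not done:
            # Epsilon-greedy: usually exploit the current estimates, sometimes explore
            action = rng.choice(ACTIONS) if rng.random() < epsilon else greedy(Q, room, rng)
            next_room, reward, done = step(room, action)
            # Q-learning target: reward plus the discounted value of the best next action
            target = reward if done else reward + gamma * max(Q[next_room].values())
            Q[room][action] += alpha * (target - Q[room][action])
            room = next_room
    return Q


Q = q_learning()
# The policy is read off Q: pick the highest-valued action in each room
policy = {room: max(ACTIONS, key=lambda a: Q[room][a]) for room in range(EXIT)}
print(policy)
```

After training, the greedy policy should walk right toward the exit from every room. AlphaZero does not keep a table like this; it replaces it with a neural network and a tree search, but the idea of learning how good each move is from experience carries over.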

But how is AlphaZero doing this specifically? To answer that question fully, you probably want to familiarize yourself with [AlphaGo](https://deepmind.com/research/case-studies/alphago-the-story-so-far) first. But [here](https://deepmind.com/blog/article/alphazero-shedding-new-light-grand-games-chess-shogi-and-go) is a brief overview with links to the technical articles and other resources.

## AlphaGo - The algorithm that changed everything

If you have not seen the AlphaGo documentary, I recommend it: https://www.alphagomovie.com/

# AlphaZero for Connect 4

Let's get into it$\dots$

Again, the original implementation and description can be found [here](https://towardsdatascience.com/from-scratch-implementation-of-alphazero-for-connect4-f73d4554002a).

## Imports

In [None]:
# Data
import numpy as np
import pandas as pd

import os
import pickle

# Basic
import collections
import math
import copy
import datetime
from tqdm import tqdm

# Visual
import matplotlib.pyplot as plt
import seaborn as sns; sns.set()
%matplotlib inline

from matplotlib.table import Table

# NN
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.multiprocessing as mp
from torch.nn.utils import clip_grad_norm_

## Board Class

In [None]:
class board():
    """Connect 4 board with 6 rows and 7 columns; row 0 is the top row.
    Player 0 plays "O" and moves first, player 1 plays "X"."""
    def __init__(self):
        self.init_board = np.full([6, 7], " ", dtype=str)
        self.player = 0  # 0 -> "O" to move, 1 -> "X" to move
        self.current_board = self.init_board

    def drop_piece(self, column):
        if self.current_board[0, column] != " ":
            return "Invalid move"  # column is full
        # Find the lowest empty cell in the column (row 5 is the bottom)
        row = 5
        while self.current_board[row, column] != " ":
            row -= 1
        self.current_board[row, column] = "O" if self.player == 0 else "X"
        self.player = 1 - self.player  # hand the move to the other player

    def check_winner(self):
        # self.player is the side to move next, so look for four in a row
        # belonging to the player who just moved
        piece = "O" if self.player == 1 else "X"
        b = self.current_board
        # Directions: vertical, horizontal, "\" diagonal, "/" diagonal
        for d_row, d_col in [(1, 0), (0, 1), (1, 1), (1, -1)]:
            for row in range(6):
                for col in range(7):
                    # Skip starting cells whose line of four would leave the board
                    # (negative indices would silently wrap around in NumPy)
                    if not (0 <= row + 3 * d_row < 6 and 0 <= col + 3 * d_col < 7):
                        continue
                    if all(b[row + k * d_row, col + k * d_col] == piece for k in range(4)):
                        return True
        return False

    def actions(self):  # returns all legal moves (columns that are not full)
        return [col for col in range(7) if self.current_board[0, col] == " "]
            

## Visualization - Helper Function

In [None]:
def view_board(np_data, fmt='{:s}', bkg_colors=['pink', 'pink']):
    data = pd.DataFrame(np_data, columns=['0','1','2','3','4','5','6'])
    fig, ax = plt.subplots(figsize=[7,7])
    ax.set_axis_off()
    tb = Table(ax, bbox=[0,0,1,1])
    nrows, ncols = data.shape
    width, height = 1.0 / ncols, 1.0 / nrows

    for (i,j), val in np.ndenumerate(data):
        idx = [j % 2, (j + 1) % 2][i % 2]
        color = bkg_colors[idx]

        tb.add_cell(i, j, width, height, text=fmt.format(val), 
                    loc='center', facecolor=color)

    for i, label in enumerate(data.index):
        tb.add_cell(i, -1, width, height, text=label, loc='right', 
                    edgecolor='none', facecolor='none')

    for j, label in enumerate(data.columns):
        tb.add_cell(-1, j, width, height/2, text=label, loc='center', 
                           edgecolor='none', facecolor='none')
    tb.set_fontsize(24)
    ax.add_table(tb)
    return fig

### Example

In [None]:
game = board()
game.drop_piece(0)
game.drop_piece(1)
game.drop_piece(1)
state = game.current_board
state

In [None]:
view_board(state)
plt.show()

## AlphaZero Ingredients

1. $\textrm{Board State} \mapsto \textrm{Policy, Value}$
    * Policy: probability distribution over the possible next moves
    * Value: prediction of the game outcome given the current state
        * White wins: +1, Draw: 0, Black wins: -1 (in the code below, White is "O", who moves first)
    * This is accomplished with a neural net.
2. Think $n$ moves ahead.
    * This is accomplished with a limited tree search algorithm (Monte-Carlo Tree Search).
3. Learn from self-play.
    * The net is trained on games it plays against itself, and a newly trained net replaces the previous best only if it beats it head-to-head ("survival of the fittest").
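The three ingredients combine into an outer training loop. The sketch below is hypothetical: `alphazero_loop`, `self_play`, `train` and `evaluate` are illustrative names (nets are identified by an integer id), not functions from this notebook, and the stand-in callables only exercise the control flow.

```python
import itertools
from typing import Any, Callable, List


def alphazero_loop(
    self_play: Callable[[int], List[Any]],
    train: Callable[[int, List[Any]], int],
    evaluate: Callable[[int, int], int],
    num_iterations: int,
    best: int = 0,
) -> int:
    """Hypothetical outer loop; returns the id of the best net at the end."""
    for _ in range(num_iterations):
        # Ingredients 1 + 2: the best net guides MCTS to generate (state, policy, value) data
        examples = self_play(best)
        # Fit a candidate net to the MCTS policies and the game outcomes
        candidate = train(best, examples)
        # Ingredient 3: keep whichever net wins the head-to-head evaluation
        best = evaluate(best, candidate)
    return best


# Stand-in callables: each training run yields a new id, and the evaluator
# accepts the first two candidates and rejects the rest.
ids = itertools.count(1)
final = alphazero_loop(
    self_play=lambda best: [("state", "policy", 0)],
    train=lambda best, examples: next(ids),
    evaluate=lambda best, candidate: candidate if candidate <= 2 else best,
    num_iterations=4,
)
print(final)  # → 2
```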

### Deep Convolutional Residual Neural Net

* 1 Convolutional Block:
    * 128 filters ($3\times 3$ kernel, stride 1)
    * Batch Norm + ReLU
* 19 Residual Blocks:
    * 128 filters ($3\times 3$ kernel, stride 1)
    * Batch Norm + ReLU
    * 128 filters ($3\times 3$ kernel, stride 1)
    * Batch Norm + **Residual Connection** + ReLU
* 1 Output Block:
    * Policy: 
        * Convolution of 32 filters ($1\times 1$, stride 1)
        * Batch Norm + ReLU + Linear + SoftMax
    * Value:
        * Convolution of 3 filters ($1\times 1$, stride 1)
        * Batch Norm + ReLU + Linear + ReLU + Linear + Tanh
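To put "probably overkill" in numbers, the sketch below counts the learnable parameters of the network described above, using the layer shapes of the `ConvBlock`, `ResBlock` and `OutBlock` classes further down (the `ResBlock` convolutions use `bias=False`; the other layers keep PyTorch's default bias). BatchNorm running statistics are buffers rather than parameters, so they are not counted. The helper names are illustrative.

```python
def conv2d_params(c_in, c_out, k, bias=True):
    """Weights of a k x k Conv2d layer, plus one bias per output channel if enabled."""
    return c_in * c_out * k * k + (c_out if bias else 0)


def batchnorm_params(c):
    """Learnable scale and shift per channel (running mean/var are buffers)."""
    return 2 * c


def linear_params(n_in, n_out):
    """Weight matrix plus bias vector of a Linear layer."""
    return n_in * n_out + n_out


ROWS, COLS, FILTERS, RES_BLOCKS = 6, 7, 128, 19

conv_block = conv2d_params(3, FILTERS, 3) + batchnorm_params(FILTERS)
res_block = 2 * (conv2d_params(FILTERS, FILTERS, 3, bias=False) + batchnorm_params(FILTERS))
value_head = (conv2d_params(FILTERS, 3, 1) + batchnorm_params(3)
              + linear_params(3 * ROWS * COLS, 32) + linear_params(32, 1))
policy_head = (conv2d_params(FILTERS, 32, 1) + batchnorm_params(32)
               + linear_params(32 * ROWS * COLS, 7))

total = conv_block + RES_BLOCKS * res_block + value_head + policy_head
print(f"residual blocks: {RES_BLOCKS * res_block:,} of {total:,} parameters")
# → residual blocks: 5,613,056 of 5,634,993 parameters
```

Over 99% of the roughly 5.6 million parameters sit in the 19 residual blocks, so that tower is the first thing to shrink when simplifying the architecture for a 6 x 7 board.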

### Monte-Carlo Tree Search

The neural network's policy output provides prior probabilities that steer the tree search toward promising moves, and its value output evaluates the positions the search reaches.

A game can be thought of as a tree: the root is the initial state of the game, and each child of a node is a state reachable by one legal move from its parent. It is impractical (and in certain games, impossible) to brute-force search all game states for the best line of play, so we have to balance the exploration-exploitation tradeoff: spend simulations on moves that already look good, while still trying moves we know little about. AlphaZero does this with Monte-Carlo Tree Search (MCTS).
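At each node, the search picks the child that maximises $Q + U$, where $Q = W / (1 + N)$ is the average backed-up value and $U = c_{\textrm{puct}} \sqrt{N_{\textrm{parent}}}\, P / (1 + N)$ rewards moves with a high prior $P$ and few visits $N$. The sketch below mirrors the form of `UCTNode.child_Q`, `child_U` and `best_child` further down; the constant `c_puct` is an assumption added for illustration (the notebook has no such constant, which amounts to `c_puct = 1`), and the numbers are made up.

```python
import math


def select_child(priors, visits, total_values, parent_visits, legal, c_puct=1.0):
    """Return the legal move maximising Q + U, in the same form as UCTNode.best_child."""
    def score(move):
        # Exploitation: average value backed up through this child so far
        q = total_values[move] / (1 + visits[move])
        # Exploration: large for high-prior, rarely visited moves, shrinking as N grows
        u = c_puct * math.sqrt(parent_visits) * priors[move] / (1 + visits[move])
        return q + u
    return max(legal, key=score)


priors = [0.6, 0.3, 0.1]          # hypothetical network priors P
visits = [10, 2, 0]               # N: visits per child
total_values = [4.0, 1.5, 0.0]    # W: summed backed-up values per child
print(select_child(priors, visits, total_values, parent_visits=12, legal=[0, 1, 2]))  # → 1
```

Move 0 has the highest prior and the most visits, yet move 1 is chosen: its average value is higher (0.5 against about 0.36) and its exploration bonus has not yet decayed.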

### Self-Play Evaluation

After each iteration, in which the neural net is trained on MCTS self-play data, the newly trained net is pitted against its previous version, with both sides again choosing moves through MCTS guided by their own net. The net that performs better (e.g. wins the majority of games; `evaluate_nets` below requires the new net to win at least 55% of them) is used for the next iteration. This ensures that the best net never gets worse, at least as measured in head-to-head play.
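The gating decision itself is small. This sketch mirrors the single-process path of `arena.evaluate` and `evaluate_nets` (only wins by the new net count, so draws count against it); the function name `should_replace` is illustrative. Requiring 55% rather than just over 50%, the same margin used by AlphaGo Zero, presumably makes it less likely that a net which won a small match by luck replaces a better one.

```python
def should_replace(results, threshold=0.55):
    """Decide whether the newly trained net becomes the new best.

    results: winner of each evaluation game: "current" (new net), "best",
             or None for a draw.
    """
    win_ratio = sum(r == "current" for r in results) / len(results)
    return win_ratio >= threshold


print(should_replace(["current"] * 6 + ["best"] * 3 + [None]))  # 0.6 -> True
print(should_replace(["current"] * 5 + [None] * 5))              # 0.5 -> False
```

With several processes, `evaluate_nets` averages the per-process win ratios before applying the same threshold.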

-----

## Board Data Class

In [None]:
class board_data(Dataset):
    def __init__(self, dataset):  # dataset = np.array of (state, policy, value) rows
        self.X = dataset[:, 0]
        self.y_p, self.y_v = dataset[:, 1], dataset[:, 2]

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        # Encoded states are (6, 7, 3); PyTorch convolutions expect channels first: (3, 6, 7)
        return self.X[idx].transpose(2, 0, 1).astype(np.int64), self.y_p[idx], self.y_v[idx]

## ConvBlock Class

In [None]:
class ConvBlock(nn.Module):
    def __init__(self):
        super(ConvBlock, self).__init__()
        self.action_size = 7
        self.conv1 = nn.Conv2d(3, 128, 3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(128)

    def forward(self, s):
        s = s.view(-1, 3, 6, 7)  # batch_size x channels x board_x x board_y
        s = F.relu(self.bn1(self.conv1(s)))
        return s


## ResBlock Class

In [None]:
class ResBlock(nn.Module):
    def __init__(self, inplanes=128, planes=128, stride=1, downsample=None):
        super(ResBlock, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = F.relu(self.bn1(out))
        out = self.conv2(out)
        out = self.bn2(out)
        out += residual
        out = F.relu(out)
        return out

## OutBlock Class

In [None]:
class OutBlock(nn.Module):
    def __init__(self):
        super(OutBlock, self).__init__()
        self.conv = nn.Conv2d(128, 3, kernel_size=1) # value head
        self.bn = nn.BatchNorm2d(3)
        self.fc1 = nn.Linear(3*6*7, 32)
        self.fc2 = nn.Linear(32, 1)
        
        self.conv1 = nn.Conv2d(128, 32, kernel_size=1) # policy head
        self.bn1 = nn.BatchNorm2d(32)
        self.logsoftmax = nn.LogSoftmax(dim=1)
        self.fc = nn.Linear(6*7*32, 7)
    
    def forward(self,s):
        v = F.relu(self.bn(self.conv(s))) # value head
        v = v.view(-1, 3*6*7)  # batch_size X channel X height X width
        v = F.relu(self.fc1(v))
        v = torch.tanh(self.fc2(v))
        
        p = F.relu(self.bn1(self.conv1(s))) # policy head
        p = p.view(-1, 6*7*32)
        p = self.fc(p)
        p = self.logsoftmax(p).exp()
        return p, v

## ConnectNet Class

In [None]:
class ConnectNet(nn.Module):
    def __init__(self):
        super(ConnectNet, self).__init__()
        self.conv = ConvBlock()
        for block in range(19):
            setattr(self, "res_%i" % block,ResBlock())
        self.outblock = OutBlock()
    
    def forward(self,s):
        s = self.conv(s)
        for block in range(19):
            s = getattr(self, "res_%i" % block)(s)
        s = self.outblock(s)
        return s

## AlphaLoss Class

In [None]:
class AlphaLoss(torch.nn.Module):
    def __init__(self):
        super(AlphaLoss, self).__init__()

    def forward(self, y_value, value, y_policy, policy):
        value_error = (value - y_value) ** 2
        policy_error = torch.sum((-policy* 
                                (1e-8 + y_policy.float()).float().log()), 1)
        total_error = (value_error.view(-1).float() + policy_error).mean()
        return total_error
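Per position, `AlphaLoss` computes the squared value error plus the policy cross-entropy, $(z - v)^2 - \sum_a \pi(a) \log p(a)$, and averages it over the batch. Because `forward` takes the log of its `y_policy` argument, that argument has to be the network's predicted distribution and `policy` the MCTS target for this to be the usual cross-entropy. The plain-Python sketch below (the name `alpha_loss` and the numbers are illustrative) shows the arithmetic for a single position. The AlphaZero paper's loss also includes an L2 weight penalty, which `AlphaLoss` does not.

```python
import math


def alpha_loss(z, v, pi, p, eps=1e-8):
    """Per-position loss (z - v)^2 - sum_a pi(a) * log p(a).

    z:  game outcome used as the value target (+1 / 0 / -1)
    v:  value predicted by the net
    pi: MCTS visit distribution used as the policy target
    p:  move probabilities predicted by the net
    """
    value_error = (z - v) ** 2
    # eps guards against log(0), mirroring the 1e-8 in AlphaLoss
    policy_error = -sum(t * math.log(eps + q) for t, q in zip(pi, p))
    return value_error + policy_error


pi = [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0]     # MCTS chose column 1 every time
p = [0.1, 0.5, 0.1, 0.1, 0.1, 0.05, 0.05]    # the net gave column 1 half the mass
print(alpha_loss(z=1.0, v=0.5, pi=pi, p=p))  # 0.25 + ln 2, about 0.943
```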

## Save / Load - Helper Functions

In [None]:
def save_as_pickle(filename, data):
    completeName = os.path.join("./datasets/",\
                                filename)
    with open(completeName, 'wb') as output:
        pickle.dump(data, output)

def load_pickle(filename):
    completeName = os.path.join("./datasets/",\
                                filename)
    with open(completeName, 'rb') as pkl_file:
        data = pickle.load(pkl_file)
    return data

## UCTNode Class

In [None]:
class UCTNode():
    def __init__(self, game, move, parent=None):
        self.game = game # state s
        self.move = move # action index
        self.is_expanded = False
        self.parent = parent  
        self.children = {}
        self.child_priors = np.zeros([7], dtype=np.float32)
        self.child_total_value = np.zeros([7], dtype=np.float32)
        self.child_number_visits = np.zeros([7], dtype=np.float32)
        self.action_idxes = []
        
    @property
    def number_visits(self):
        return self.parent.child_number_visits[self.move]

    @number_visits.setter
    def number_visits(self, value):
        self.parent.child_number_visits[self.move] = value
    
    @property
    def total_value(self):
        return self.parent.child_total_value[self.move]
    
    @total_value.setter
    def total_value(self, value):
        self.parent.child_total_value[self.move] = value
    
    def child_Q(self):
        return self.child_total_value / (1 + self.child_number_visits)
    
    def child_U(self):
        return math.sqrt(self.number_visits) * (
            abs(self.child_priors) / (1 + self.child_number_visits))
    
    def best_child(self):
        if self.action_idxes != []:
            bestmove = self.child_Q() + self.child_U()
            bestmove = self.action_idxes[np.argmax(bestmove[self.action_idxes])]
        else:
            bestmove = np.argmax(self.child_Q() + self.child_U())
        return bestmove
    
    def select_leaf(self):
        # Walk down from this node, always following the child with the best Q + U score
        current = self
        while current.is_expanded:
            best_move = current.best_child()
            current = current.maybe_add_child(best_move)
        return current
    
    def add_dirichlet_noise(self,action_idxs,child_priors):
        valid_child_priors = child_priors[action_idxs] # select only legal moves entries in child_priors array
        valid_child_priors = 0.75*valid_child_priors + 0.25*np.random.dirichlet(np.zeros([len(valid_child_priors)], \
                                                                                          dtype=np.float32)+192)
        child_priors[action_idxs] = valid_child_priors
        return child_priors
    
    def expand(self, child_priors):
        self.is_expanded = True
        action_idxs = self.game.actions(); c_p = child_priors
        if action_idxs == []:
            self.is_expanded = False
        self.action_idxes = action_idxs
        c_p[[i for i in range(len(child_priors)) if i not in action_idxs]] = 0.000000000 # mask all illegal actions
        if self.parent.parent is None: # add dirichlet noise to child_priors in root node
            c_p = self.add_dirichlet_noise(action_idxs,c_p)
        self.child_priors = c_p
    
    def decode_n_move_pieces(self,board,move):
        board.drop_piece(move)
        return board
            
    def maybe_add_child(self, move):
        if move not in self.children:
            copy_board = copy.deepcopy(self.game) # make copy of board
            copy_board = self.decode_n_move_pieces(copy_board,move)
            self.children[move] = UCTNode(
              copy_board, move, parent=self)
        return self.children[move]
    
    def backup(self, value_estimate: float):
        current = self
        while current.parent is not None:
            current.number_visits += 1
            if current.game.player == 1: # same as current.parent.game.player = 0
                current.total_value += (1*value_estimate) # value estimate +1 = O wins
            elif current.game.player == 0: # same as current.parent.game.player = 1
                current.total_value += (-1*value_estimate)
            current = current.parent
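`UCTNode.add_dirichlet_noise` mixes noise into the root priors as $P' = 0.75\,P + 0.25\,\eta$ with $\eta \sim \textrm{Dir}(\alpha)$ over the legal moves, so that self-play keeps trying moves the net currently dislikes. The notebook uses $\alpha = 192$; with such a large $\alpha$ the sample is close to uniform, so the noise mostly flattens the priors. A small $\alpha$ gives spiky noise that boosts a few random moves instead; the AlphaZero paper scaled $\alpha$ inversely with the typical number of legal moves (0.3 for chess, 0.03 for Go). The sketch below uses only the standard library (a Dirichlet sample is normalised Gamma draws); its function and parameter names are illustrative.

```python
import random


def add_dirichlet_noise(priors, legal, alpha, epsilon=0.25, rng=random):
    """Return a copy of `priors` with Dirichlet(alpha) noise mixed into the legal moves.

    priors: prior probabilities for all 7 columns (illegal ones already 0)
    legal:  indices of the legal columns
    """
    # A Dirichlet(alpha, ..., alpha) sample is a vector of i.i.d. Gamma(alpha, 1)
    # draws divided by their sum
    draws = [rng.gammavariate(alpha, 1.0) for _ in legal]
    total = sum(draws)
    noisy = list(priors)
    for move, draw in zip(legal, draws):
        noisy[move] = (1 - epsilon) * priors[move] + epsilon * draw / total
    return noisy


priors = [0.0, 0.5, 0.3, 0.2, 0.0, 0.0, 0.0]
print(add_dirichlet_noise(priors, legal=[1, 2, 3], alpha=0.3, rng=random.Random(0)))
```

Illegal columns stay at zero, and the result still sums to 1 because both the masked priors and the noise do.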

## MCTS_c4

In [None]:
class DummyNode(object):
    def __init__(self):
        self.parent = None
        self.child_total_value = collections.defaultdict(float)
        self.child_number_visits = collections.defaultdict(float)

def UCT_search(game_state, num_reads, net, temp):
    # `temp` is not used here; the move temperature is applied afterwards in get_policy
    root = UCTNode(game_state, move=None, parent=DummyNode())
    device = next(net.parameters()).device  # works on CPU-only machines as well as GPU
    for i in range(num_reads):
        leaf = root.select_leaf()
        if leaf.game.check_winner() == True:  # the player who just moved has won
            # Back up the exact result (+1 = "O" wins) rather than the net's estimate
            leaf.backup(1.0 if leaf.game.player == 1 else -1.0); continue
        if leaf.game.actions() == []:  # board full: draw
            leaf.backup(0.0); continue
        encoded_s = encode_board(leaf.game).transpose(2, 0, 1)
        encoded_s = torch.from_numpy(encoded_s).float().to(device)
        with torch.no_grad():
            child_priors, value_estimate = net(encoded_s)
        child_priors = child_priors.detach().cpu().numpy().reshape(-1)
        leaf.expand(child_priors)  # expand() masks out illegal moves
        leaf.backup(value_estimate.item())
    return root

def do_decode_n_move_pieces(board,move):
    board.drop_piece(move)
    return board

def get_policy(root, temp=1):
    # pi(a) is proportional to N(a)**(1/temp): temp=1 follows the visit counts,
    # while temp -> 0 approaches always playing the most-visited move
    visits = root.child_number_visits ** (1 / temp)
    return visits / np.sum(visits)

def MCTS_self_play(connectnet, num_games, start_idx, cpu, args, iteration):
    
    if not os.path.isdir("./datasets/iter_%d" % iteration):
        if not os.path.isdir("datasets"):
            os.mkdir("datasets")
        os.mkdir("datasets/iter_%d" % iteration)
        
    for idxx in tqdm(range(start_idx, num_games + start_idx)):
        current_board = board()
        checkmate = False
        dataset = [] # to get state, policy, value for neural network training
        states = []
        value = 0
        move_count = 0
        while checkmate == False and current_board.actions() != []:
            if move_count < 11:
                t = args.temperature_MCTS
            else:
                t = 0.1
            states.append(copy.deepcopy(current_board.current_board))
            board_state = copy.deepcopy(encode_board(current_board))
            root = UCT_search(current_board,777,connectnet,t)
            policy = get_policy(root, t); print("[CPU: %d]: Game %d POLICY:\n " % (cpu, idxx), policy)
            current_board = do_decode_n_move_pieces(current_board,\
                                                    np.random.choice(np.array([0,1,2,3,4,5,6]), \
                                                                     p = policy)) # decode move and move piece(s)
            dataset.append([board_state,policy])
            print("[Iteration: %d CPU: %d]: Game %d CURRENT BOARD:\n" % (iteration, cpu, idxx), current_board.current_board,current_board.player); print(" ")
            if current_board.check_winner() == True: # if somebody won
                if current_board.player == 0: # black wins
                    value = -1
                elif current_board.player == 1: # white wins
                    value = 1
                checkmate = True
            move_count += 1
        dataset_p = []
        for idx,data in enumerate(dataset):
            s,p = data
            if idx == 0:
                dataset_p.append([s,p,0])
            else:
                dataset_p.append([s,p,value])
        del dataset
        save_as_pickle("iter_%d/" % iteration +\
                       "dataset_iter%d_cpu%i_%i_%s" % (iteration, cpu, idxx, datetime.datetime.today().strftime("%Y-%m-%d")), dataset_p)

def run_MCTS(args, start_idx=0, iteration=0):
    net_to_play = "%s_iter%d.pth.tar" % (args.neural_net_name, iteration)
    net = ConnectNet()
    cuda = torch.cuda.is_available()
    if cuda:
        net.cuda()
    net.eval()

    # Load this iteration's checkpoint, or save the freshly initialised net as its checkpoint
    os.makedirs("./model_data/", exist_ok=True)
    current_net_filename = os.path.join("./model_data/", net_to_play)
    if os.path.isfile(current_net_filename):
        # map_location="cpu" lets GPU-saved checkpoints load anywhere;
        # load_state_dict then copies the weights onto the net's own device
        checkpoint = torch.load(current_net_filename, map_location="cpu")
        net.load_state_dict(checkpoint['state_dict'])
    else:
        torch.save({'state_dict': net.state_dict()}, current_net_filename)

    if args.MCTS_num_processes > 1:
        mp.set_start_method("spawn", force=True)
        net.share_memory()
        num_processes = min(args.MCTS_num_processes, mp.cpu_count())
        processes = []
        for i in range(num_processes):
            p = mp.Process(target=MCTS_self_play,
                           args=(net, args.num_games_per_MCTS_process, start_idx, i, args, iteration))
            p.start()
            processes.append(p)
        for p in processes:
            p.join()

    elif args.MCTS_num_processes == 1:
        with torch.no_grad():
            MCTS_self_play(net, args.num_games_per_MCTS_process, start_idx, 0, args, iteration)
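`get_policy` turns the root's visit counts into a move distribution $\pi(a) \propto N(a)^{1/\tau}$. `MCTS_self_play` uses `args.temperature_MCTS` for the first 11 moves (more varied openings) and $\tau = 0.1$ afterwards, and `arena` plays with $\tau = 0.1$ throughout. The plain-Python sketch below shows the effect of the temperature on hypothetical visit counts; the function name `visit_policy` is illustrative, and dividing by the largest count first is an addition that avoids overflow for very small $\tau$.

```python
def visit_policy(visit_counts, temp=1.0):
    """pi(a) proportional to N(a) ** (1 / temp), as in get_policy."""
    # Dividing by the largest count first keeps N ** (1 / temp) from overflowing
    # when temp is small; the common factor cancels in the normalisation.
    largest = max(visit_counts)
    weights = [(n / largest) ** (1.0 / temp) for n in visit_counts]
    total = sum(weights)
    return [w / total for w in weights]


visits = [0, 10, 40, 150, 60, 10, 7]  # hypothetical root visit counts for columns 0..6
warm = visit_policy(visits, temp=1.0)  # proportional to the visit counts
cold = visit_policy(visits, temp=0.1)  # nearly all mass on the most-visited column
print([round(p, 3) for p in warm])
print([round(p, 3) for p in cold])
```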

## Encoder / Decoder

In [None]:
def encode_board(board):
    board_state = board.current_board
    encoded = np.zeros([6,7,3]).astype(int)
    encoder_dict = {"O":0, "X":1}
    for row in range(6):
        for col in range(7):
            if board_state[row,col] != " ":
                encoded[row, col, encoder_dict[board_state[row,col]]] = 1
    if board.player == 1:
        encoded[:,:,2] = 1 # player to move
    return encoded

def decode_board(encoded):
    decoded = np.zeros([6,7]).astype(str)
    decoded[decoded == "0.0"] = " "
    decoder_dict = {0:"O", 1:"X"}
    for row in range(6):
        for col in range(7):
            for k in range(2):
                if encoded[row, col, k] == 1:
                    decoded[row, col] = decoder_dict[k]
    cboard = board()
    cboard.current_board = decoded
    cboard.player = encoded[0,0,2]
    return cboard

## Evaluation

In [None]:
class arena():
    def __init__(self, current_cnet, best_cnet):
        self.current = current_cnet
        self.best = best_cnet
    
    def play_round(self):
        if np.random.uniform(0,1) <= 0.5:
            white = self.current; black = self.best; w = "current"; b = "best"
        else:
            white = self.best; black = self.current; w = "best"; b = "current"
        current_board = board()
        checkmate = False
        dataset = []
        value = 0; t = 0.1
        while checkmate == False and current_board.actions() != []:
            dataset.append(copy.deepcopy(encode_board(current_board)))
            print(""); print(current_board.current_board)
            if current_board.player == 0:
                root = UCT_search(current_board,777,white,t)
                policy = get_policy(root, t); print("Policy: ", policy, "white = %s" %(str(w)))
            elif current_board.player == 1:
                root = UCT_search(current_board,777,black,t)
                policy = get_policy(root, t); print("Policy: ", policy, "black = %s" %(str(b)))
            current_board = do_decode_n_move_pieces(current_board,\
                                                    np.random.choice(np.array([0,1,2,3,4,5,6]), \
                                                                     p = policy)) # decode move and move piece(s)
            if current_board.check_winner() == True: # someone wins
                if current_board.player == 0: # black wins
                    value = -1
                elif current_board.player == 1: # white wins
                    value = 1
                checkmate = True
        dataset.append(encode_board(current_board))
        if value == -1:
            dataset.append(f"{b} as black wins")
            return b, dataset
        elif value == 1:
            dataset.append(f"{w} as white wins")
            return w, dataset
        else:
            dataset.append("Nobody wins")
            return None, dataset
    
    def evaluate(self, num_games, cpu):
        current_wins = 0
        for i in range(num_games):
            with torch.no_grad():
                winner, dataset = self.play_round(); print("%s wins!" % winner)
            if winner == "current":
                current_wins += 1
            save_as_pickle("evaluate_net_dataset_cpu%i_%i_%s_%s" % (cpu,i,datetime.datetime.today().strftime("%Y-%m-%d"),\
                                                                     str(winner)),dataset)
        print("Current_net wins ratio: %.5f" % (current_wins/num_games))
        save_as_pickle("wins_cpu_%i" % (cpu),\
                                             {"best_win_ratio": current_wins/num_games, "num_games":num_games})
        
def fork_process(arena_obj, num_games, cpu): # make arena picklable
    arena_obj.evaluate(num_games, cpu)

def evaluate_nets(args, iteration_1, iteration_2) :
    current_net="%s_iter%d.pth.tar" % (args.neural_net_name, iteration_2); best_net="%s_iter%d.pth.tar" % (args.neural_net_name, iteration_1)
    current_net_filename = os.path.join("./model_data/",\
                                    current_net)
    best_net_filename = os.path.join("./model_data/",\
                                    best_net)
    
    current_cnet = ConnectNet()
    best_cnet = ConnectNet()
    cuda = torch.cuda.is_available()
    if cuda:
        current_cnet.cuda()
        best_cnet.cuda()
    
    if not os.path.isdir("./evaluator_data/"):
        os.mkdir("evaluator_data")
    
    if args.MCTS_num_processes > 1:
        mp.set_start_method("spawn",force=True)
        
        current_cnet.share_memory(); best_cnet.share_memory()
        current_cnet.eval(); best_cnet.eval()
        
        checkpoint = torch.load(current_net_filename)
        current_cnet.load_state_dict(checkpoint['state_dict'])
        checkpoint = torch.load(best_net_filename)
        best_cnet.load_state_dict(checkpoint['state_dict'])
         
        processes = []
        if args.MCTS_num_processes > mp.cpu_count():
            num_processes = mp.cpu_count()
        else:
            num_processes = args.MCTS_num_processes
        with torch.no_grad():
            for i in range(num_processes):
                p = mp.Process(target=fork_process,args=(arena(current_cnet,best_cnet), args.num_evaluator_games, i))
                p.start()
                processes.append(p)
            for p in processes:
                p.join()
               
        wins_ratio = 0.0
        for i in range(num_processes):
            stats = load_pickle("wins_cpu_%i" % (i))
            wins_ratio += stats['best_win_ratio']
        wins_ratio = wins_ratio/num_processes
        if wins_ratio >= 0.55:
            return iteration_2
        else:
            return iteration_1
            
    elif args.MCTS_num_processes == 1:
        current_cnet.eval(); best_cnet.eval()
        checkpoint = torch.load(current_net_filename)
        current_cnet.load_state_dict(checkpoint['state_dict'])
        checkpoint = torch.load(best_net_filename)
        best_cnet.load_state_dict(checkpoint['state_dict'])
        arena1 = arena(current_cnet=current_cnet, best_cnet=best_cnet)
        arena1.evaluate(num_games=args.num_evaluator_games, cpu=0)
        
        stats = load_pickle("wins_cpu_%i" % (0))
        if stats["best_win_ratio"] >= 0.55:
            return iteration_2
        else:
            return iteration_1

## Play Against the Computer

In [None]:
# def play_game(net):
#     # Asks human what he/she wanna play as
#     white = None; black = None
#     while (True):
#         play_as = input("What do you wanna play as? (\"O\"/\"X\")? Note: \"O\" starts first, \"X\" starts second\n")
#         if play_as == "O":
#             black = net; break
#         elif play_as == "X":
#             white = net; break
#         else:
#             print("I didn't get that.")
#     current_board = board()
#     checkmate = False
#     dataset = []
#     value = 0; t = 0.1; moves_count = 0
#     while checkmate == False and current_board.actions() != []:
#         if moves_count <= 5:
#             t = 1
#         else:
#             t = 0.1
#         moves_count += 1
#         dataset.append(copy.deepcopy(encode_board(current_board)))
#         print(current_board.current_board); print(" ")
#         if current_board.player == 0:
#             if white != None:
#                 print("AI is thinking........")
#                 root = UCT_search(current_board,777,white,t)
#                 policy = get_policy(root, t)
#             else:
#                 while(True):
#                     col = input("Which column do you wanna drop your piece? (Enter 1-7)\n")
#                     if int(col) in [1,2,3,4,5,6,7]:
#                         policy = np.zeros([7], dtype=np.float32); policy[int(col)-1] += 1
#                         break
#         elif current_board.player == 1:
#             if black != None:
#                 print("AI is thinking.............")
#                 root = UCT_search(current_board,777,black,t)
#                 policy = get_policy(root, t)
#             else:
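#                 while(True):
#                     col = input("Which column do you want to drop your piece into? (Enter 1-7)\n")
#                     # Validate the string first so non-numeric input doesn't raise a ValueError
#                     if col.strip() in ["1", "2", "3", "4", "5", "6", "7"]:
#                         policy = np.zeros([7], dtype=np.float32); policy[int(col)-1] += 1
#                         break
#                     print("I didn't get that.")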
#         current_board = do_decode_n_move_pieces(current_board,\
#                                                 np.random.choice(np.array([0,1,2,3,4,5,6]), \
#                                                                  p = policy)) # decode move and move piece(s)
#         if current_board.check_winner() == True: # someone wins
#             if current_board.player == 0: # black wins
#                 value = -1
#             elif current_board.player == 1: # white wins
#                 value = 1
#             checkmate = True
#     dataset.append(encode_board(current_board))
#     print(current_board.current_board); print(" ")
#     if value == -1:
#         if play_as == "O":
#             dataset.append(f"AI as black wins"); print("YOU LOSE!!!!!!!")
#         else:
#             dataset.append(f"Human as black wins"); print("YOU WIN!!!!!!!")
#         return "black", dataset
#     elif value == 1:
#         if play_as == "O":
#             dataset.append(f"Human as white wins"); print("YOU WIN!!!!!!!!!!!")
#         else:
#             dataset.append(f"AI as white wins"); print("YOU LOSE!!!!!!!")
#         return "white", dataset
#     else:
#         dataset.append("Nobody wins"); print("DRAW!!!!!")
#         return None, dataset

# if __name__ == "__main__":
#     best_net="c4_current_net_trained1_iter6.pth.tar"
#     best_net_filename = os.path.join("./model_data/",\
#                                     best_net)
#     best_cnet = ConnectNet()
#     cuda = torch.cuda.is_available()
#     if cuda:
#         best_cnet.cuda()
#     best_cnet.eval()
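#     # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine
#     checkpoint = torch.load(best_net_filename, map_location=torch.device("cuda" if cuda else "cpu"))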
#     best_cnet.load_state_dict(checkpoint['state_dict'])
#     play_again = True
#     while(play_again == True):
#         play_game(best_cnet)
#         while(True):
#             again = input("Do you want to play again? (Y/N)\n")
#             if again.lower() in ["y", "n"]:
#                 if again.lower() == "n":
#                     play_again = False; break
#                 else:
#                     break
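When the AI moves, `play_game` turns an MCTS search into a move distribution with `get_policy(root, t)`. It uses `t = 1` for the first few moves and `t = 0.1` afterwards. `get_policy` is defined earlier in the notebook. Assume it follows the AlphaZero rule $\pi(a) \propto N(a)^{1/t}$, where $N(a)$ is the visit count of the child reached by action $a$. Under that assumption, the effect of the temperature can be sketched in plain Python. `visit_count_policy` is a hypothetical helper, not the notebook's function.

```python
import random


def visit_count_policy(visit_counts, temperature):
    """Convert MCTS visit counts into move probabilities pi(a) proportional to N(a)**(1/temperature)."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    max_count = max(visit_counts)
    if max_count == 0:
        raise ValueError("at least one move must have been visited")
    # Normalising by the largest count keeps every weight in [0, 1],
    # so small temperatures (large exponents) cannot overflow
    weights = [(n / max_count) ** (1.0 / temperature) for n in visit_counts]
    total = sum(weights)
    return [w / total for w in weights]


# Hypothetical visit counts for the 7 columns after a search
counts = [10, 30, 60, 0, 0, 0, 0]
exploratory = visit_count_policy(counts, 1.0)  # roughly [0.1, 0.3, 0.6, 0, 0, 0, 0]
near_greedy = visit_count_policy(counts, 0.1)  # about 0.999 of the mass on column index 2
# Sample a column, analogous to how the game loop samples from `policy`
column = random.choices(range(7), weights=exploratory)[0]
print(exploratory, near_greedy, column)
```

With `t = 1` the distribution is proportional to the visit counts, so the AI still varies its openings. With `t = 0.1` almost all the probability lands on the most-visited column, so play is close to greedy.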
                

## Train

def load_state(net, optimizer, scheduler, args, iteration, new_optim_state=True):
    """ Loads the saved model (and optionally the optimizer and scheduler) states if a checkpoint exists """
    base_path = "./model_data/"
    checkpoint_path = os.path.join(base_path, "%s_iter%d.pth.tar" % (args.neural_net_name, iteration))
    start_epoch, checkpoint = 0, None
    if os.path.isfile(checkpoint_path):
        # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        checkpoint = torch.load(checkpoint_path, map_location=device)
    if checkpoint is not None:
        if (len(checkpoint) == 1) or new_optim_state:
            net.load_state_dict(checkpoint['state_dict'])
        else:
            start_epoch = checkpoint['epoch']
            net.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            scheduler.load_state_dict(checkpoint['scheduler'])
    return start_epoch


def load_results(iteration):
    """ Loads the saved per-epoch losses if they exist """
    losses_path = "./model_data/losses_per_epoch_iter%d.pkl" % iteration
    if os.path.isfile(losses_path):
        losses_per_epoch = load_pickle("losses_per_epoch_iter%d.pkl" % iteration)
    else:
        losses_per_epoch = []
    return losses_per_epoch


def train(net, dataset, optimizer, scheduler, start_epoch, cpu, args, iteration):
    torch.manual_seed(cpu)
    cuda = torch.cuda.is_available()
    net.train()
    criterion = AlphaLoss()
    
    train_set = board_data(dataset)
    train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, num_workers=0, pin_memory=False)
    losses_per_epoch = load_results(iteration + 1)
    
    update_size = max(1, len(train_loader)//10)  # at least 1, so fewer than 10 batches can't cause a modulo by zero
    print("Update step size: %d" % update_size)
    for epoch in range(start_epoch, args.num_epochs):
        total_loss = 0.0
        losses_per_batch = []
        for i,data in enumerate(train_loader,0):
            state, policy, value = data
            state, policy, value = state.float(), policy.float(), value.float()
            if cuda:
                state, policy, value = state.cuda(), policy.cuda(), value.cuda()
            policy_pred, value_pred = net(state) # policy_pred = torch.Size([batch, 7]) value_pred = torch.Size([batch, 1])
            loss = criterion(value_pred[:,0], value, policy_pred, policy)
            loss = loss/args.gradient_acc_steps
            loss.backward()
            # Step once every gradient_acc_steps batches (counted by batch index i, not by epoch),
            # clipping the accumulated gradient just before the update
            if ((i + 1) % args.gradient_acc_steps) == 0:
                clip_grad_norm_(net.parameters(), args.max_norm)
                optimizer.step()
                optimizer.zero_grad()
                
            total_loss += loss.item()
            if i % update_size == (update_size - 1):    # print every update_size mini-batches
                losses_per_batch.append(args.gradient_acc_steps*total_loss/update_size)
                print('[Iteration %d] Process ID: %d [Epoch: %d, %5d/ %d points] total loss per batch: %.3f' %
                      (iteration, os.getpid(), epoch + 1, (i + 1)*args.batch_size, len(train_set), losses_per_batch[-1]))
                print("Policy (actual, predicted):",policy[0].argmax().item(),policy_pred[0].argmax().item())
                print("Policy data:", policy[0]); print("Policy pred:", policy_pred[0])
                print("Value (actual, predicted):", value[0].item(), value_pred[0,0].item())
                #print("Conv grad: %.7f" % net.conv.conv1.weight.grad.mean().item())
                #print("Res18 grad %.7f:" % net.res_18.conv1.weight.grad.mean().item())
                print(" ")
                total_loss = 0.0
        
        scheduler.step()
        if len(losses_per_batch) >= 1:
            losses_per_epoch.append(sum(losses_per_batch)/len(losses_per_batch))
        if (epoch % 2) == 0:
            save_as_pickle("losses_per_epoch_iter%d.pkl" % (iteration + 1), losses_per_epoch)
            torch.save({
                    'epoch': epoch + 1,
                    'state_dict': net.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict(),
                }, os.path.join("./model_data/",
                    "%s_iter%d.pth.tar" % (args.neural_net_name, (iteration + 1))))
        '''
        # Early stopping
        if len(losses_per_epoch) > 50:
            if abs(sum(losses_per_epoch[-4:-1])/3-sum(losses_per_epoch[-16:-13])/3) <= 0.00017:
                break
        '''
    fig = plt.figure()
    ax = fig.add_subplot(111)  # one full-size plot; 222 would draw it in a single quadrant of a 2x2 grid
    ax.scatter([e for e in range(start_epoch, (len(losses_per_epoch) + start_epoch))], losses_per_epoch)
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss per batch")
    ax.set_title("Loss vs Epoch")
    plt.savefig(os.path.join("./model_data/", "Loss_vs_Epoch_iter%d_%s.png" % ((iteration + 1), datetime.datetime.today().strftime("%Y-%m-%d"))))
    plt.show()
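`train` minimizes `AlphaLoss`, which is defined earlier in the notebook and not shown here. Assume it implements the AlphaZero loss without the L2 regularization term. Then a single position with network outputs $(p, v)$ and targets $(\pi, z)$ contributes $(z - v)^2 - \pi^\top \log p$, and the batch loss is the mean of these terms. The plain-Python `alpha_loss` below is a hypothetical re-implementation for illustration. Its argument order mirrors the `criterion(value_pred[:,0], value, policy_pred, policy)` call, and it assumes `policy_preds` are already probabilities rather than logits.

```python
import math


def alpha_loss(value_preds, value_targets, policy_preds, policy_targets, eps=1e-8):
    """Batch mean of (z - v)^2 - sum_a pi(a) * log p(a), the AlphaZero loss without L2.

    value_preds / value_targets: one float per position (v and z).
    policy_preds / policy_targets: one probability vector per position (p and pi).
    """
    per_position = []
    for v, z, p, pi in zip(value_preds, value_targets, policy_preds, policy_targets):
        value_error = (z - v) ** 2
        # Cross-entropy between the search policy pi and the network policy p;
        # eps avoids log(0) when the network gives a move zero probability
        policy_error = -sum(t * math.log(q + eps) for t, q in zip(pi, p))
        per_position.append(value_error + policy_error)
    return sum(per_position) / len(per_position)


# One position: the game was won (z = 1), the net predicted v = 0.5 and a uniform policy,
# while the search put all its visits on the first column
loss = alpha_loss([0.5], [1.0], [[1 / 7] * 7], [[1, 0, 0, 0, 0, 0, 0]])
print(loss)  # approximately 0.25 + log(7), about 2.196
```

The value term pulls $v$ toward the game outcome, and the policy term pulls $p$ toward the MCTS visit distribution. Together they are how search results are distilled back into the network.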

    
def train_connectnet(args, iteration, new_optim_state):
    # gather data
    data_path="./datasets/iter_%d/" % iteration
    datasets = []
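    for file in os.listdir(data_path):
        filename = os.path.join(data_path, file)
        with open(filename, 'rb') as fo:
            datasets.extend(pickle.load(fo, encoding='bytes'))
    # dtype=object keeps each (state, policy, value) record intact; NumPy >= 1.24 raises on ragged input without it
    datasets = np.array(datasets, dtype=object)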
    
    # train net
    net = ConnectNet()
    cuda = torch.cuda.is_available()
    if cuda:
        net.cuda()
    optimizer = optim.Adam(net.parameters(), lr=args.lr, betas=(0.8, 0.999))
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[50,100,150,200,250,300,400], gamma=0.77)
    start_epoch = load_state(net, optimizer, scheduler, args, iteration, new_optim_state)
    
    train(net, datasets, optimizer, scheduler, start_epoch, 0, args, iteration)
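`train_connectnet` pairs Adam with `MultiStepLR(milestones=[50,100,150,200,250,300,400], gamma=0.77)`, and `train` calls `scheduler.step()` once per epoch. Under PyTorch's documented behaviour, the learning rate is multiplied by `gamma` each time the epoch counter reaches a milestone. The rate used in epoch $e$ is therefore $\text{lr} \cdot \gamma^{m(e)}$, where $m(e)$ is the number of milestones $\le e$. The sketch below reproduces that closed form without PyTorch. `multistep_lr` is a hypothetical helper, and it assumes training starts from a fresh scheduler (`start_epoch = 0`), which is what `new_optim_state=True` gives.

```python
from bisect import bisect_right


def multistep_lr(base_lr, milestones, gamma, epoch):
    """Learning rate in use during `epoch` (0-based) under a MultiStepLR-style schedule."""
    # bisect_right counts the milestones <= epoch, i.e. how many decays have already happened
    return base_lr * gamma ** bisect_right(sorted(milestones), epoch)


milestones = [50, 100, 150, 200, 250, 300, 400]
for epoch in (0, 49, 50, 150, 299):
    print(epoch, multistep_lr(0.001, milestones, 0.77, epoch))
```

With the default `--num_epochs 300`, the last epoch index is 299, so only the first five milestones fire during a single call to `train`. The final learning rate is $0.001 \cdot 0.77^5 \approx 2.7 \times 10^{-4}$, and the 300 and 400 milestones are never reached.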


## Main Pipeline

from argparse import ArgumentParser

parser = ArgumentParser()
parser.add_argument("--iteration", type=int, default=0, help="Current iteration number to resume from")
parser.add_argument("--total_iterations", type=int, default=1000, help="Total number of iterations to run")
parser.add_argument("--MCTS_num_processes", type=int, default=5, help="Number of processes to run MCTS self-plays")
parser.add_argument("--num_games_per_MCTS_process", type=int, default=120, help="Number of games to simulate per MCTS self-play process")
parser.add_argument("--temperature_MCTS", type=float, default=1.1, help="Temperature for first 10 moves of each MCTS self-play")
parser.add_argument("--num_evaluator_games", type=int, default=100, help="Number of games to play when evaluating neural nets")
parser.add_argument("--neural_net_name", type=str, default="cc4_current_net_", help="Name of neural net")
parser.add_argument("--batch_size", type=int, default=32, help="Training batch size")
parser.add_argument("--num_epochs", type=int, default=300, help="Number of training epochs")
parser.add_argument("--lr", type=float, default=0.001, help="Learning rate")
parser.add_argument("--gradient_acc_steps", type=int, default=1, help="Number of steps of gradient accumulation")
parser.add_argument("--max_norm", type=float, default=1.0, help="Maximum gradient norm for clipping")
args = parser.parse_args([])  # empty argument list: use the defaults inside a notebook

for i in range(args.iteration, args.total_iterations): 
    run_MCTS(args, start_idx=0, iteration=i)
    train_connectnet(args, iteration=i, new_optim_state=True)
    if i >= 1:
        winner = evaluate_nets(args, i, i + 1)
        counts = 0
        while (winner != (i + 1)):
            run_MCTS(args, start_idx=(counts + 1)*args.num_games_per_MCTS_process, iteration=i)
            counts += 1
            train_connectnet(args, iteration=i, new_optim_state=True)
            winner = evaluate_nets(args, i, i + 1)
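The `--gradient_acc_steps` flag makes `train` divide each batch loss by `args.gradient_acc_steps` before calling `backward()`. Because `backward()` adds into the existing `.grad` buffers, you can accumulate $k$ equally sized micro-batches this way and then step once. That gives the same gradient as one batch $k$ times larger. The sketch below checks this on a toy one-parameter least-squares model instead of `ConnectNet`. `grad_mse`, `accumulated_grad` and `step_batch_indices` are hypothetical helpers. Equal micro-batch sizes are an assumption, and a shorter final batch breaks the exact equivalence.

```python
def grad_mse(w, xs, ys):
    """Gradient d/dw of mean((w * x - y) ** 2) over one batch."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def accumulated_grad(w, xs, ys, acc_steps):
    """Gradient obtained by splitting a batch into acc_steps equal micro-batches.

    Mirrors `loss = loss / acc_steps; loss.backward()`: each micro-batch adds its
    gradient scaled by 1/acc_steps to the running total, as backward() does with .grad.
    """
    if len(xs) % acc_steps != 0:
        raise ValueError("this sketch assumes equally sized micro-batches")
    size = len(xs) // acc_steps
    total = 0.0
    for k in range(acc_steps):
        chunk = slice(k * size, (k + 1) * size)
        total += grad_mse(w, xs[chunk], ys[chunk]) / acc_steps
    return total


def step_batch_indices(num_batches, acc_steps):
    """Batch indices i after which optimizer.step() runs when stepping every acc_steps batches."""
    return [i for i in range(num_batches) if (i + 1) % acc_steps == 0]


xs, ys, w = [1.0, 2.0, 3.0, 4.0], [2.0, 4.0, 6.0, 8.0], 1.5
print(grad_mse(w, xs, ys), accumulated_grad(w, xs, ys, acc_steps=2))  # both -7.5
print(step_batch_indices(10, 3))  # [2, 5, 8]
```

With the default `--gradient_acc_steps 1`, every batch triggers a step and the division by `args.gradient_acc_steps` has no effect.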

# Resources / References

## Game Theory

* https://en.wikipedia.org/wiki/Game_theory
* https://en.wikipedia.org/wiki/Complete_information
* https://en.wikipedia.org/wiki/Perfect_information

## Residual Networks

* https://towardsdatascience.com/residual-network-implementing-resnet-a7da63c7b278

## Monte-Carlo Tree Search

* https://en.wikipedia.org/wiki/Monte_Carlo_tree_search

## Upper Confidence Bounds Applied to Trees

* https://www.chessprogramming.org/UCT

## Various RL Algorithms (in PyTorch)

* https://github.com/p-christ/Deep-Reinforcement-Learning-Algorithms-with-PyTorch

## Curated Resources on RL

* https://github.com/aikorea/awesome-rl
* https://github.com/tigerneil/awesome-deep-rl

## AlphaGo

* https://en.wikipedia.org/wiki/AlphaGo

## AlphaZero

* https://en.wikipedia.org/wiki/AlphaZero

## Inverse Reinforcement Learning (in PyTorch)

* https://github.com/reinforcement-learning-kr/lets-do-irl

## Flappy Bird RL Tutorial (in PyTorch)

* https://www.toptal.com/deep-learning/pytorch-reinforcement-learning-tutorial

## AlphaZero for Connect 4 (in PyTorch)

* https://towardsdatascience.com/from-scratch-implementation-of-alphazero-for-connect4-f73d4554002a

## Collection of Optimizers for PyTorch

* https://github.com/jettify/pytorch-optimizer

# Further Reading

## Solving Games with Imperfect Information

* https://www.aaai.org/ocs/index.php/AAAI/AAAI14/paper/viewFile/8407/8476
* https://en.wikipedia.org/wiki/Libratus
* https://www.youtube.com/watch?v=2dX0lwaQRX0
* https://arxiv.org/pdf/1811.00164.pdf

## Counterfactual Regret Minimization (for Poker) in Python

* https://github.com/tansey/pycfr

## Inverse Reinforcement Learning

* https://ai.stanford.edu/~ang/papers/icml00-irl.pdf
* https://arxiv.org/pdf/1806.06877.pdf
* https://en.wikipedia.org/wiki/Apprenticeship_learning
* https://en.wikipedia.org/wiki/Reinforcement_learning#Inverse_reinforcement_learning
* https://jangirrishabh.github.io/2016/07/09/virtual-car-IRL/