In [None]:
# Mount Google Drive so the last cell can write the checkpoint under drive/MyDrive.
# Outside Colab (no `google.colab` package) skip the mount instead of aborting the
# whole notebook at its first cell; the checkpoint cell then needs a writable
# drive/MyDrive directory relative to the working directory.
try:
    from google.colab import drive
except ImportError:
    drive = None
    print("google.colab not available; skipping Google Drive mount")
else:
    drive.mount('/content/drive')

In [None]:
import math
import random
from collections import deque
from random import shuffle

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [None]:
# Hyper-parameters for self-play and training.
batch_size=64
numIters=128                                # Total number of training iterations
num_simulations=15                          # Total number of MCTS simulations to run when deciding on a move to play
numEps=100                                  # Number of full games (episodes) to run during each iteration
numItersForTrainExamplesHistory=20
num_epochs=2                                # Number of epochs of training per iteration
learning_rate=5e-4
SEED=0                                      # Seed for python `random`, numpy and torch so runs are reproducible

# Seed every RNG the notebook uses (shuffle, np.random in training/selection, torch init).
# Seeding here, before the model cell, also makes the network initialisation reproducible.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

In [None]:
class Connect4Model(nn.Module):
    """Convolutional policy/value network for Connect-Four.

    The convolutional trunk takes a (batch, 1, H, W) board tensor. Every conv keeps
    H and W (3x3 with padding 1, or 1x1), and the last ConvBlock has one output
    channel, so the flattened features have H*W entries. `board_size` must therefore
    equal H*W. Two linear heads produce the policy and the value.
    """

    def __init__(self, board_size, action_size):
        super().__init__()
        self.layers = nn.ModuleList(
            [
                ConvBlock(in_channels=1, out_channels=32, kernel_size=3, stride=1, padding=1),
                ConvBlock(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1),
                ConvBlock(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
                ResidualBlock(channels=128, num_repeats=4),
                ConvBlock(in_channels=128, out_channels=64, kernel_size=3, stride=1, padding=1),
                ResidualBlock(channels=64, num_repeats=4),
                ConvBlock(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=1),
                ConvBlock(in_channels=32, out_channels=1, kernel_size=3, stride=1, padding=1),
                nn.Flatten()
            ]
        )

        # Two heads on our network
        self.action_head = nn.Linear(in_features=board_size, out_features=action_size)
        self.value_head = nn.Linear(in_features=board_size, out_features=1)

    def forward(self, x):
        """Return (action probabilities via softmax over dim 1, value squashed by tanh into [-1, 1])."""
        for layer in self.layers:
            x = layer(x)
        action_logits = self.action_head(x)
        value_logit = self.value_head(x)

        return F.softmax(action_logits, dim=1), torch.tanh(value_logit)

    def predict(self, board):
        """Evaluate `board` without gradients and return the first batch element's outputs.

        The forward pass runs in eval mode, so BatchNorm uses its running statistics.
        The module's previous train/eval mode is restored afterwards, so calling
        predict does not silently switch the model out of training mode.

        Returns:
            (pi, v) as numpy arrays for batch element 0: pi has one probability per
            action and v has shape (1,).
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                pi, v = self.forward(board)
        finally:
            self.train(was_training)

        return pi.cpu().numpy()[0], v.cpu().numpy()[0]

class ConvBlock(nn.Module):
    """Conv2d followed by BatchNorm2d and ReLU.

    Any extra keyword arguments (kernel_size, stride, padding, ...) are passed
    straight to nn.Conv2d.
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__()
        # The attribute names form the state_dict keys; keep them stable for checkpoints.
        self.conv = nn.Conv2d(in_channels, out_channels, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        normalized = self.bn(self.conv(x))
        return self.relu(normalized)
        
class ResidualBlock(nn.Module):
    """A stack of `num_repeats` bottleneck residual units applied in sequence.

    Each unit is a 1x1 ConvBlock that halves the channel count, followed by a 3x3
    ConvBlock (padding 1) that restores it. The unit's output is added back to its
    input: x <- x + unit(x).
    """

    def __init__(self, channels, num_repeats=1):
        super().__init__()
        self.num_repeats = num_repeats
        bottleneck = channels // 2
        self.layers = nn.ModuleList([
            nn.Sequential(
                ConvBlock(channels, bottleneck, kernel_size=1),
                ConvBlock(bottleneck, channels, kernel_size=3, padding=1),
            )
            for _ in range(num_repeats)
        ])

    def forward(self, x):
        out = x
        for unit in self.layers:
            out = out + unit(out)
        return out

In [None]:
def ucb_score(parent, child, c_puct=1.0):
    """Score for the action that moves from `parent` to `child`; higher is better.

    score = -Q(child) + c_puct * P(child) * sqrt(N(parent)) / (N(child) + 1)

    Args:
        parent: node exposing `visit_count`.
        child: node exposing `prior`, `visit_count` and `value()`.
        c_puct: exploration constant scaling the prior term. The default 1.0
            reproduces the un-weighted formula exactly.

    Returns:
        The value term plus the exploration term. An unvisited child contributes a
        value term of 0.
    """
    prior_score = c_puct * (child.prior * math.sqrt(parent.visit_count) / (child.visit_count + 1))
    if child.visit_count > 0:
        # The value of the child is from the perspective of the opposing player
        value_score = -child.value()
    else:
        value_score = 0

    return value_score + prior_score

class Node:
    """A node of the MCTS tree.

    Attributes:
        visit_count: number of backed-up simulations through this node.
        to_play: player to move at this node (children get the opposite sign).
        prior: prior probability assigned when the parent was expanded.
        value_sum: sum of backed-up values; value() returns its mean.
        children: dict mapping action index -> child Node.
        state: board stored by expand() (None until then).
    """

    def __init__(self, prior, to_play):
        self.visit_count = 0
        self.to_play = to_play
        self.prior = prior
        self.value_sum = 0
        self.children = {}
        self.state = None

    def expanded(self):
        """True once at least one child has been added."""
        return bool(self.children)

    def value(self):
        """Mean backed-up value, or 0 for a node that was never visited."""
        return self.value_sum / self.visit_count if self.visit_count else 0

    def select_action(self, temperature):
        """Pick an action from the children's visit counts.

        temperature == 0 picks the most visited action, temperature == inf picks
        uniformly at random, and otherwise samples with probability proportional to
        visit_count ** (1 / temperature).
        """
        actions = list(self.children)
        counts = np.array([self.children[a].visit_count for a in actions])
        if temperature == 0:
            return actions[np.argmax(counts)]
        if temperature == float("inf"):
            return np.random.choice(actions)
        # See paper appendix Data Generation
        weights = counts ** (1 / temperature)
        weights = weights / sum(weights)
        return np.random.choice(actions, p=weights)

    def select_child(self):
        """Return (action, child) with the highest UCB score; (-1, None) if there are no children."""
        best = (-1, None)
        best_score = -np.inf
        for action, child in self.children.items():
            score = ucb_score(self, child)
            # Strict comparison: on ties the first child in insertion order wins.
            if score > best_score:
                best_score = score
                best = (action, child)
        return best

    def expand(self, state, to_play, action_probs):
        """Record `state`/`to_play` and add a child for every action with non-zero prior."""
        self.to_play = to_play
        self.state = state
        child_player = self.to_play * -1
        for action, prob in enumerate(action_probs):
            if prob != 0:
                self.children[action] = Node(prior=prob, to_play=child_player)

    def __repr__(self):
        """Debug summary: state, prior (2 decimals), visit count and mean value."""
        prior_text = "{0:.2f}".format(self.prior)
        return "{} Prior: {} Count: {} Value: {}".format(
            str(self.state), prior_text, self.visit_count, self.value()
        )

class MCTS:
    """Monte Carlo Tree Search guided by a policy/value network.

    `game` must provide get_valid_moves, get_next_state, get_canonical_board and
    get_reward_for_player. `model.predict` is called with a (1, 1, *board.shape)
    float32 tensor and must return (action_probs, value).
    """

    def __init__(self, game, model):
        self.game = game
        self.model = model

    def _evaluate(self, model, state):
        """Evaluate `state` with `model`; return (prior over legal moves, value).

        The network policy is masked with game.get_valid_moves and renormalised. If the
        masked policy has no usable mass (zero or non-finite sum), fall back to a
        uniform prior over the legal moves. Dividing by that sum would give NaN priors,
        and NaN != 0 would make Node.expand create children for illegal moves.
        """
        board = torch.as_tensor(state.reshape(1, 1, *state.shape), dtype=torch.float32, device=device)
        action_probs, value = model.predict(board)
        valid_moves = self.game.get_valid_moves(state)
        masked = action_probs * valid_moves  # mask invalid moves
        total = np.sum(masked)
        if not np.isfinite(total) or total <= 0:
            masked = np.asarray(valid_moves, dtype=np.float64)
            total = np.sum(masked)
        if total > 0:
            masked = masked / total
        return masked, value

    def run(self, model, state, to_play):
        """Run `num_simulations` simulations from `state` and return the root Node.

        `state` is treated as the board from the perspective of the player to move:
        moves are always applied as player 1 and the board is then flipped.
        """
        root = Node(0, to_play)

        # EXPAND root
        action_probs, _ = self._evaluate(model, state)
        root.expand(state, to_play, action_probs)
        if not root.expanded():
            # No legal moves: nothing to search (search_path[-2] below would fail).
            return root

        for _ in range(num_simulations):
            node = root
            search_path = [node]

            # SELECT
            while node.expanded():
                action, node = node.select_child()
                search_path.append(node)

            parent = search_path[-2]
            # Now we're at a leaf node and we would like to expand
            # Players always play from their own perspective
            next_state, _ = self.game.get_next_state(parent.state, player=1, action=action)
            # Get the board from the perspective of the other player
            next_state = self.game.get_canonical_board(next_state, player=-1)

            # The value of the new state from the perspective of the other player
            value = self.game.get_reward_for_player(next_state, player=1)
            if value is None:
                # Game not over: EXPAND the leaf with the network's evaluation.
                action_probs, value = self._evaluate(model, next_state)
                node.expand(next_state, parent.to_play * -1, action_probs)

            self.backpropagate(search_path, value, parent.to_play * -1)

        return root

    def backpropagate(self, search_path, value, to_play):
        """Propagate a simulation result up `search_path`.

        `value` is from the perspective of `to_play`. Nodes whose player to move is
        `to_play` receive +value and all others -value. Every node's visit count is
        incremented.
        """
        for node in reversed(search_path):
            node.value_sum += value if node.to_play == to_play else -value
            node.visit_count += 1

In [None]:
class Connect4Game:
    """Connect-Four rules on a (rows, columns) numpy board.

    Cells hold 0 (empty), 1 or -1 (the two players). Row 0 is the top row: a piece
    dropped into a column lands in the lowest empty row.
    """

    def __init__(self):
        self.columns = 7
        self.rows = 6

    def get_init_board(self):
        """Return an empty (rows, columns) integer board."""
        return np.zeros((self.rows, self.columns), dtype=int)

    def get_board_size(self):
        return self.rows * self.columns

    def get_action_size(self):
        return self.columns

    def get_next_state(self, board, player, action):
        """Drop a `player` piece into column `action`.

        The input board is not modified.

        Returns:
            (new_board, -player): the perspective passes to the other player.

        Raises:
            ValueError: if the column is already full. Without this check the target
                row would be -1 and the bottom cell would be silently overwritten.
        """
        b = np.copy(board)
        occupied = np.flatnonzero(b[:, action])
        if occupied.size == 0:
            row = self.rows - 1
        elif occupied[0] == 0:
            raise ValueError("column {} is full".format(action))
        else:
            row = occupied[0] - 1
        b[row, action] = player
        return (b, -player)

    def has_legal_moves(self, board):
        """True if any column still has an empty top cell."""
        return any(board[0, col] == 0 for col in range(self.columns))

    def get_valid_moves(self, board):
        """Return a list with 1 for each column that can still be played and 0 otherwise."""
        return [1 if board[0, col] == 0 else 0 for col in range(self.columns)]

    def is_win(self, board, player):
        """True if `player` has four in a row horizontally, vertically or diagonally."""
        # (row step, column step): horizontal, vertical, down-right, up-right.
        for d_row, d_col in ((0, 1), (1, 0), (1, 1), (-1, 1)):
            for i in range(self.rows):
                for j in range(self.columns):
                    end_i, end_j = i + 3 * d_row, j + 3 * d_col
                    if not (0 <= end_i < self.rows and 0 <= end_j < self.columns):
                        continue
                    if all(board[i + k * d_row, j + k * d_col] == player for k in range(4)):
                        return True
        return False

    def get_reward_for_player(self, board, player):
        """Return 1 if `player` has won, -1 if the opponent has won, None while moves remain, else 0 (draw)."""
        if self.is_win(board, player):
            return 1
        if self.is_win(board, -player):
            return -1
        if self.has_legal_moves(board):
            return None
        return 0

    def get_canonical_board(self, board, player):
        """Return the board from `player`'s perspective (pieces multiplied by `player`)."""
        return player * board

In [None]:
# Instantiate the game and a freshly initialised network on the selected device.
game = Connect4Game()
board_size, action_size = game.get_board_size(), game.get_action_size()
model = Connect4Model(board_size=board_size, action_size=action_size).to(device)

In [None]:
def exceute_episode(temp_threshold=15):
    """Play one self-play game with MCTS and return its training examples.

    For the first `temp_threshold` moves the move is sampled from the MCTS
    visit-count distribution (temperature 1); after that the most-visited move is
    played (temperature 0). Always playing the argmax made every episode of an
    iteration identical, because the network is evaluated in eval mode and the
    search has no randomness. The `numEps` games then added no new information.

    Args:
        temp_threshold: number of opening moves played with temperature 1.

    Returns:
        list of (canonical_board, action_probs, value) tuples. `value` is the final
        game outcome (+1 win, -1 loss, 0 draw) from the perspective of the player to
        move on that board.
    """
    train_examples = []
    current_player = 1
    state = game.get_init_board()
    move_count = 0

    while True:
        canonical_board = game.get_canonical_board(state, current_player)

        mcts = MCTS(game, model)
        root = mcts.run(model, canonical_board, to_play=1)

        # Policy target: normalised root visit counts (0 for unexpanded actions).
        visit_counts = [0 for _ in range(game.get_action_size())]
        for a, child in root.children.items():
            visit_counts[a] = child.visit_count
        action_probs = visit_counts / np.sum(visit_counts)
        train_examples.append((canonical_board, current_player, action_probs))

        temperature = 1 if move_count < temp_threshold else 0
        action = root.select_action(temperature=temperature)
        move_count += 1
        state, current_player = game.get_next_state(state, current_player, action)
        reward = game.get_reward_for_player(state, current_player)

        if reward is not None:
            # `reward` is from the perspective of `current_player` (the side to move after
            # the final move); negate it for positions where the other side was to move.
            return [
                (hist_state, hist_action_probs, reward if hist_player == current_player else -reward)
                for hist_state, hist_player, hist_action_probs in train_examples
            ]

def learn():
    """Alternate self-play and training for `numIters` iterations.

    Each iteration plays `numEps` self-play games. Training then uses the shuffled
    examples from the most recent `numItersForTrainExamplesHistory` iterations, not
    only the current one, so the configured history length takes effect.
    """
    history = deque(maxlen=numItersForTrainExamplesHistory)

    for i in range(1, numIters + 1):
        print("{}/{}".format(i, numIters))

        iteration_examples = []
        for _ in range(numEps):
            iteration_examples.extend(exceute_episode())
        # The deque drops the oldest iteration once the window is full.
        history.append(iteration_examples)

        train_examples = [example for examples in history for example in examples]
        shuffle(train_examples)
        train(train_examples)

def train(examples):
    """Fit the global `model` on (board, target_policy, target_value) examples.

    Runs `num_epochs` epochs. Each epoch has len(examples) // batch_size minibatches
    (at least one), sampled uniformly with replacement, and minimises
    loss_pi + loss_v with a fresh Adam optimizer. After each epoch it prints the
    running mean losses and one predicted/target policy pair.
    Prints a notice and does nothing when `examples` is empty.
    """
    if len(examples) == 0:
        print("No training examples; skipping training")
        return

    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    pi_losses = []
    v_losses = []
    # At least one minibatch per epoch. With fewer than `batch_size` examples the
    # floor division is 0, which would skip training entirely and leave `out_pi` /
    # `target_pis` undefined for the summary printed below.
    num_batches = max(1, len(examples) // batch_size)

    for epoch in range(num_epochs):
        model.train()

        for _ in range(num_batches):
            sample_ids = np.random.randint(len(examples), size=batch_size)
            boards, pis, vs = zip(*(examples[i] for i in sample_ids))
            boards = torch.as_tensor(np.array(boards), dtype=torch.float32, device=device)
            target_pis = torch.as_tensor(np.array(pis), dtype=torch.float32, device=device)
            target_vs = torch.as_tensor(np.array(vs), dtype=torch.float32, device=device)

            # Boards are stacked as (batch, rows, cols); add the single input channel.
            out_pi, out_v = model(boards.unsqueeze(1))
            l_pi = loss_pi(target_pis, out_pi)
            l_v = loss_v(target_vs, out_v)
            total_loss = l_pi + l_v

            pi_losses.append(l_pi.item())
            v_losses.append(l_v.item())

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

        print()
        print("Policy Loss", np.mean(pi_losses))
        print("Value Loss", np.mean(v_losses))
        print("Examples:")
        print(out_pi[0].detach())
        print(target_pis[0])

def loss_pi(targets, outputs, eps=1e-8):
    """Cross-entropy between target policies and predicted action probabilities.

    Args:
        targets: target distributions, summed over dim 1 (one row per example).
        outputs: predicted probabilities with the same shape (softmax output).
        eps: lower clamp applied to `outputs` before the log. A float32 softmax can
            underflow to exactly 0; log(0) = -inf and 0 * -inf = NaN would poison the
            loss even where the target probability is 0.

    Returns:
        Scalar tensor: batch mean of -sum(targets * log(outputs)).
    """
    log_probs = torch.log(torch.clamp(outputs, min=eps))
    loss = -(targets * log_probs).sum(dim=1)
    return loss.mean()

def loss_v(targets, outputs):
    """Mean squared error between target values and the flattened value-head output."""
    residual = targets - outputs.view(-1)
    return residual.pow(2).sum() / targets.size(0)

In [None]:
# Run the full self-play + training loop (numIters iterations of numEps games each).
# This is the long-running cell.
learn()

In [None]:
# Persist only the learned weights (state_dict). The path is relative; it presumably
# resolves to the Drive mount from the first cell when the working directory is /content.
checkpoint_path = "drive/MyDrive/AlphaZero.pth"
torch.save(model.state_dict(), checkpoint_path)