# Notes

The environment is based on PettingZoo's Connect Four: https://pettingzoo.farama.org/environments/classic/connect_four/

# Imports

In [364]:
%load_ext autoreload
%autoreload 2

import numpy as np
from collections import defaultdict, Counter, deque
import torch
import copy
from torch.utils.tensorboard import SummaryWriter

import connect_four

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [365]:
class MySummaryWriter(SummaryWriter):
    """
    `SummaryWriter` that numbers points automatically.

    Each call to `append_scalar(name, value)` logs `value` at the next step for
    that tag (0, 1, 2, ...), so callers do not have to track global steps.
    Constructor arguments (e.g. `log_dir`, `comment`) are forwarded unchanged
    to `SummaryWriter`.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.points_cnt = Counter()   # tag -> number of points logged so far

    def append_scalar(self, name, value):
        """Log `value` under tag `name` at that tag's next step."""
        step = self.points_cnt[name]
        self.points_cnt[name] += 1
        self.add_scalar(name, value, step)


TENSORBOARD = MySummaryWriter()

# Show environment

In [3]:
def show():
    """
    Play a single game of uniformly random legal moves (at most 50 plies).

    Each ply prints `(ply parity, action, (reward, done, board shape, mask))`;
    the final board is rendered if the game finishes within the ply budget.
    """
    game = connect_four.ConnectFour()
    state, legal = game.reset()

    for ply in range(50):
        move = np.random.choice(np.flatnonzero(legal == 1))
        state, legal, reward, done = game.step(move)
        print(ply % 2, move, (reward, done, state.board.shape, legal))

        if not done:
            continue
        game.render_ascii()
        break


show()

0 5 (0, False, (6, 7, 2), array([1, 1, 1, 1, 1, 1, 1], dtype=int32))
1 1 (0, False, (6, 7, 2), array([1, 1, 1, 1, 1, 1, 1], dtype=int32))
0 0 (0, False, (6, 7, 2), array([1, 1, 1, 1, 1, 1, 1], dtype=int32))
1 1 (0, False, (6, 7, 2), array([1, 1, 1, 1, 1, 1, 1], dtype=int32))
0 4 (0, False, (6, 7, 2), array([1, 1, 1, 1, 1, 1, 1], dtype=int32))
1 4 (0, False, (6, 7, 2), array([1, 1, 1, 1, 1, 1, 1], dtype=int32))
0 0 (0, False, (6, 7, 2), array([1, 1, 1, 1, 1, 1, 1], dtype=int32))
1 3 (0, False, (6, 7, 2), array([1, 1, 1, 1, 1, 1, 1], dtype=int32))
0 4 (0, False, (6, 7, 2), array([1, 1, 1, 1, 1, 1, 1], dtype=int32))
1 6 (0, False, (6, 7, 2), array([1, 1, 1, 1, 1, 1, 1], dtype=int32))
0 4 (0, False, (6, 7, 2), array([1, 1, 1, 1, 1, 1, 1], dtype=int32))
1 6 (0, False, (6, 7, 2), array([1, 1, 1, 1, 1, 1, 1], dtype=int32))
0 1 (0, False, (6, 7, 2), array([1, 1, 1, 1, 1, 1, 1], dtype=int32))
1 0 (0, False, (6, 7, 2), array([1, 1, 1, 1, 1, 1, 1], dtype=int32))
0 5 (0, False, (6, 7, 2), array([1

# Baseline agents

In [261]:
class RandomAgent:
    """Baseline agent: picks one of the legal actions uniformly at random."""

    def get_action(self, state: connect_four.HashableState, action_mask):
        # `state` is ignored; only entries of the mask equal to 1 are legal.
        legal_actions = np.flatnonzero(action_mask == 1)
        return np.random.choice(legal_actions)


class AlwaysLeftAgent:
    """Baseline agent: always plays the lowest-index legal action."""

    def get_action(self, state: connect_four.HashableState, action_mask):
        legal_actions = np.flatnonzero(action_mask == 1)
        leftmost = legal_actions[0]
        return leftmost


class AlwaysRightAgent:
    """Baseline agent: always plays the highest-index legal action."""

    def get_action(self, state: connect_four.HashableState, action_mask):
        legal_actions = np.flatnonzero(action_mask == 1)
        rightmost = legal_actions[-1]
        return rightmost


def battle(env, agent0, agent1, n_games, n_max_steps):
    """
    Play `n_games` games between `agent0` and `agent1`, alternating who moves
    first (agent0 opens the even-numbered games, agent1 the odd-numbered ones).

    A game ends when the env reports `done`. `reward == 1` is read as a win for
    the player who moved first, any other non-zero reward as a win for the
    player who moved second, and `0` as a draw. Games still running after
    `n_max_steps` plies are not counted anywhere.

    Returns a list of five counts:
        [0] agent0 moved first and won
        [1] agent0 moved second and won
        [2] agent1 moved first and won
        [3] agent1 moved second and won
        [4] draws
    """
    results = [0] * 5
    players = (agent0, agent1)

    for game in range(n_games):
        state, mask = env.reset()
        first = game % 2  # index into `players` of the opening player

        for step in range(n_max_steps):
            mover = players[(first + step) % 2]
            state, mask, reward, done = env.step(mover.get_action(state, mask))
            if not done:
                continue

            if reward == 0:
                slot = 4
            elif first == 0:
                # agent0 opened: first-mover win -> agent0, otherwise agent1 (second) won
                slot = 0 if reward == 1 else 3
            else:
                # agent1 opened: first-mover win -> agent1, otherwise agent0 (second) won
                slot = 2 if reward == 1 else 1
            results[slot] += 1
            break

    return results


def show_sample_agents_battle():
    """Run 1000-game battles between pairs of baseline agents and print the tallies."""
    env = connect_four.ConnectFour()

    matchups = [
        ("RandomA", RandomAgent(), "RandomB", RandomAgent()),
        ("Random", RandomAgent(), "AlwaysLeft", AlwaysLeftAgent()),
        ("Random", RandomAgent(), "AlwaysRight", AlwaysRightAgent()),
        ("AlwaysLeft", AlwaysLeftAgent(), "AlwaysRight", AlwaysRightAgent()),
    ]

    for name0, agent0, name1, agent1 in matchups:
        wins0_first, wins0_second, wins1_first, wins1_second, draws = battle(
            env, agent0, agent1, n_games=1000, n_max_steps=50
        )
        print(f"{name0} vs {name1}:")
        print(f"   {name0}  first turn wins: {wins0_first}")
        print(f"   {name0} second turn wins: {wins0_second}")
        print(f"   {name1}  first turn wins: {wins1_first}")
        print(f"   {name1} second turn wins: {wins1_second}")
        print(f"                      draws: {draws}")


show_sample_agents_battle()


RandomA vs RandomB:
   RandomA  first turn wins: 278
   RandomA second turn wins: 199
   RandomB  first turn wins: 298
   RandomB second turn wins: 219
                      draws: 6
Random vs AlwaysLeft:
   Random  first turn wins: 131
   Random second turn wins: 69
   AlwaysLeft  first turn wins: 431
   AlwaysLeft second turn wins: 369
                      draws: 0
Random vs AlwaysRight:
   Random  first turn wins: 115
   Random second turn wins: 85
   AlwaysRight  first turn wins: 415
   AlwaysRight second turn wins: 385
                      draws: 0
AlwaysLeft vs AlwaysRight:
   AlwaysLeft  first turn wins: 500
   AlwaysLeft second turn wins: 0
   AlwaysRight  first turn wins: 500
   AlwaysRight second turn wins: 0
                      draws: 0


# AlphaGo Zero plays Connect Four

## Experience replay

In [285]:
class ExperienceReplayEpisode:
    """
    One recorded self-play game: the visited states, the actions taken between
    them, and the final outcome for each player.
    """
    def __init__(self, first_state):
        self.states = [first_state]       # states[i] is the state in which actions[i] was chosen
        self.actions = []
        self.terminal_rewards = [0, 0]    # final outcome for player 0 and player 1 (0 until the game ends)

    def on_action(self, action: int, reward: int, done: bool, next_state: "connect_four.HashableState"):
        """
        Record `action` and the state it led to; when `done`, store the outcome.

        `reward` follows the environment convention also assumed by `battle` and
        `MCTS_Tree.backpropagate`: +1 if player 0 wins, -1 if player 1 wins,
        0 for a draw (or a non-terminal step).
        """
        if done:
            # Index by player, not by who made the last move. The winner is always
            # the player who just moved, so storing `reward` in the mover's slot
            # would give player 1 a -1 for its own wins.
            self.terminal_rewards[0] = reward
            self.terminal_rewards[1] = -reward
        self.actions.append(action)
        self.states.append(next_state)


class ExperienceReplay:
    """
    FIFO buffer of `ExperienceReplayEpisode`s holding at most `max_episodes`
    games; starting a new game evicts the oldest games beyond that limit.
    """
    def __init__(self, max_episodes):
        self.episodes = deque()
        self.max_episodes = max_episodes

    def on_reset(self, state: connect_four.HashableState):
        """Open a new episode whose first state is `state`, evicting old episodes if needed."""
        buffer = self.episodes
        buffer.append(ExperienceReplayEpisode(state))
        while len(buffer) > self.max_episodes:
            buffer.popleft()

    def on_action(self, action: int, reward: int, done: bool, next_state: connect_four.HashableState):
        """Forward one environment step to the episode currently being recorded."""
        current_episode = self.episodes[-1]
        current_episode.on_action(action, reward, done, next_state)

    def clear(self):
        """Drop every stored episode."""
        self.episodes.clear()

    def yield_training_tuples(self):
        """
        Yield `(state, action, outcome)` for every move of every stored episode,
        where `outcome` is the episode's final reward for `state.player_idx`.
        The final state of each episode (after its last action) is not yielded.
        """
        for episode in self.episodes:
            for step, action in enumerate(episode.actions):
                state = episode.states[step]
                yield state, action, episode.terminal_rewards[state.player_idx]

## NN Model

In [286]:
class AlphaGoZeroModel(torch.nn.Module):
    """
    AlphaGo Zero style network with a shared trunk and two heads:
    7 action logits (policy) and a scalar state value.

    Boards are passed through `_rotate_board` before reaching the network, so
    channel 0 holds the coins of the player to move and channel 1 the opponent's.
    """
    def __init__(self) -> None:
        super().__init__()

        # Module names and layout are part of the `state_dict` format; keep them stable.
        self.model_common = torch.nn.Sequential(
            torch.nn.Flatten(),   # (6, 7, 2) -> 84
            torch.nn.Linear(84, 32),
            torch.nn.ReLU(),
        )

        self.model_action_logits = torch.nn.Sequential(
            torch.nn.Linear(32, 7),
        )

        self.model_state_value = torch.nn.Sequential(
            torch.nn.Linear(32, 1),
        )


    def _rotate_board(self, state: connect_four.HashableState):
        """
        Return `state.board` with the player to move in channel 0 and the
        opponent in channel 1. The channels are swapped (on a copy) when
        `state.player_idx == 1`; otherwise the original board is returned.
        """
        board = state.board
        if state.player_idx == 1:
            board = np.flip(board, axis=2).copy()
        return board


    def forward(self, board):
        """
        board: float tensor of shape (B, 6, 7, 2), already rotated.
        Returns `(action_logits, state_value)` with shapes (B, 7) and (B, 1).
        """
        B = board.shape[0]
        assert board.shape == (B, 6, 7, 2)

        common = self.model_common(board)
        action_logits = self.model_action_logits(common)
        state_value = self.model_state_value(common)
        return action_logits, state_value


    def get_p_v_single_state(self, state: connect_four.HashableState, need_p: bool, need_v: bool):
        """
        Evaluate the network on one state.

        Returns `(action_probs, state_value)`:
            action_probs: numpy array of 7 softmax probabilities, or None if `need_p` is False
            state_value:  numpy scalar from the value head, or None if `need_v` is False
        """
        board = self._rotate_board(state)
        board = torch.from_numpy(board.reshape((1, 6, 7, 2))).float()
        common = self.model_common(board)

        # Both outputs must be bound even when a head is skipped.
        action_probs, state_value = None, None

        if need_p:
            action_logits = self.model_action_logits(common)
            action_probs = torch.nn.functional.softmax(action_logits, dim=1)
            action_probs = action_probs.detach().cpu().numpy()[0, :]

        if need_v:
            state_value = self.model_state_value(common)
            state_value = state_value.detach().cpu().numpy()[0, 0]

        return action_probs, state_value

    def get_rollout_action(self, state: connect_four.HashableState, actions_mask):
        """Return the action with the highest policy probability among those whose mask entry is non-zero."""
        assert actions_mask.shape == (7,)

        action_probs, _ = self.get_p_v_single_state(state, need_p=True, need_v=False)
        action_probs[actions_mask == 0] = -np.inf   # illegal actions can never be the argmax

        return np.argmax(action_probs)

    def go_train(self, tree, experience_replay, optimizer, batch_size):
        """
        Run one pass over the replay data in shuffled mini-batches of
        `batch_size`. A trailing partial batch is skipped.

        Targets for each `(state, action, reward)` tuple:
            policy: visit counts `tree.N_cnt_visits[state]`, normalized to sum to 1
            value:  `reward`, the game outcome for the player to move in `state`
        Loss = MSE(value) + soft-label cross-entropy(policy).
        """
        self.train()

        data = list(experience_replay.yield_training_tuples())

        TENSORBOARD.append_scalar('data len', len(data))

        np.random.shuffle(data)
        for bi in range(batch_size, len(data) + 1, batch_size):
            boards = []
            actual_probs = []
            actual_values = []

            for state, action, reward in data[bi - batch_size : bi]:
                boards.append(self._rotate_board(state))

                cnt_visits = tree.N_cnt_visits[state]
                actual_probs.append(cnt_visits / np.sum(cnt_visits))

                actual_values.append(reward)

            # Stack into single ndarrays first; PyTorch warns that building a
            # tensor from a list of ndarrays is extremely slow.
            boards = torch.as_tensor(np.stack(boards)).float()                 # (B, 6, 7, 2)
            actual_probs = torch.as_tensor(np.stack(actual_probs)).float()     # (B, 7)
            actual_values = torch.as_tensor(np.asarray(actual_values)).view((batch_size, 1)).float()  # (B, 1)

            pred_action_logits, pred_state_value = self.forward(boards)        # (B, 7), (B, 1)

            states_loss = torch.nn.functional.mse_loss(pred_state_value, actual_values)
            # `actual_probs` are class probabilities, not indices (soft targets).
            action_loss = torch.nn.functional.cross_entropy(pred_action_logits, actual_probs)

            TENSORBOARD.append_scalar('states_loss', states_loss.item())
            TENSORBOARD.append_scalar('action_loss', action_loss.item())

            loss = states_loss + action_loss

            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            optimizer.step()

class AlphaGoZeroAgent:
    """
    Playing agent backed by an `AlphaGoZeroModel`.

    It plays the legal action with the highest policy probability. No tree
    search is done at play time.
    """
    def __init__(self, model: AlphaGoZeroModel) -> None:
        self.model = model

    @torch.no_grad()
    def get_action(self, state: connect_four.HashableState, actions_mask):
        """Put the model in eval mode and return its greedy legal action for `state`."""
        net = self.model
        net.eval()
        return net.get_rollout_action(state, actions_mask)

## MCTS Tree

In [287]:
class MCTS_Tree:
    """
    Statistics needed by MCTS, stored per expanded state.

    Per-action arrays (one entry per entry of the actions mask):
      N_cnt_visits[s]   - visit counts N(s, a)
      W_sum_vals[s]     - summed backed-up values W(s, a)
      Q_values[s]       - mean values Q(s, a) = W / N; illegal actions are pinned to -inf
      P_action_probs[s] - policy-head priors P(s, a), renormalized over legal actions
    V_state_values[s] holds the value-head prediction v(s).

    Values are from the point of view of the player to move in `s`, assuming the
    environment reward is +1 when player 0 wins and -1 when player 1 wins.
    """
    def __init__(self, c_puct: float = 4.0) -> None:
        """
        c_puct: weight of the prior-driven exploration bonus in `get_action`.
        The default of 4.0 is the value this class used before it became a parameter.
        """
        self.c_puct = c_puct

        # Structure of all dicts: { state -> [ action -> value ] }
        self.N_cnt_visits = dict()
        self.W_sum_vals = dict()
        self.Q_values = dict()
        self.P_action_probs = dict()

        self.V_state_values = dict()  # { state -> predicted state value v(s) }

    def get_action(self, state: connect_four.HashableState):
        """
        Return argmax_a [Q(s, a) + U(s, a)] with the PUCT bonus
        U(s, a) = c_puct * P(s, a) * sqrt(sum_b N(s, b)) / (1 + N(s, a)).
        """
        N_s_a = self.N_cnt_visits[state]
        Q_s_a = self.Q_values[state]
        P_s_a = self.P_action_probs[state]
        U_s_a = self.c_puct * P_s_a * np.sqrt(np.sum(N_s_a)) / (1 + N_s_a)
        return np.argmax(Q_s_a + U_s_a)

    def add_node(self, model: AlphaGoZeroModel, state, actions_mask):
        """Expand `state`: store network priors/value and zeroed visit statistics."""
        action_probs, state_value = model.get_p_v_single_state(state, need_p=True, need_v=True)
        action_probs[actions_mask == 0] = 0

        legal_mass = np.sum(action_probs)
        if legal_mass > 0:
            action_probs /= legal_mass
        else:
            # The softmax put (numerically) all its mass on illegal actions;
            # fall back to a uniform prior over the legal ones instead of NaNs.
            action_probs = (actions_mask != 0).astype(np.float64)
            action_probs /= np.sum(action_probs)

        actions_len = len(actions_mask)
        self.N_cnt_visits[state] = np.zeros(actions_len)
        self.W_sum_vals[state] = np.zeros(actions_len)
        self.Q_values[state] = np.zeros(actions_len)
        self.Q_values[state][actions_mask == 0] = -np.inf   # never select illegal actions
        self.P_action_probs[state] = action_probs
        self.V_state_values[state] = state_value

    def backpropagate(self, sa_array, reward):
        """
        Add one visit along the selected path and fold `reward` into W and Q.

        `sa_array`: selected path in the tree [state0, action0, state1, action1, ..., stateK].
        Every (state_i, action_i) edge is updated; the last state is not.
        """
        assert len(sa_array) % 2 == 1 and len(sa_array) >= 3
        for si in range(0, len(sa_array) - 1, 2):
            state, action = sa_array[si], sa_array[si + 1]

            # reward = 1 iff player 0 wins, so the value for the player to move in
            # `state` is `reward` for player 0 and `-reward` for player 1.
            self.N_cnt_visits[state][action] += 1
            self.W_sum_vals[state][action] += reward * (1 - 2 * state.player_idx)
            self.Q_values[state][action] = self.W_sum_vals[state][action] / self.N_cnt_visits[state][action]

    def clear(self):
        """Forget every expanded state."""
        self.N_cnt_visits.clear()
        self.W_sum_vals.clear()
        self.Q_values.clear()
        self.P_action_probs.clear()
        self.V_state_values.clear()

## Agent Evaluator

In [288]:
class AgentEvaluator:
    """
    Keep the best model found so far and decide whether a candidate replaces it.

    A candidate is promoted when it wins at least 55% of `n_evaluate_games`
    games against the current best model.
    """

    def __init__(self, n_evaluate_games: int, n_max_steps: int, best_model=None) -> None:
        self.n_evaluate_games = n_evaluate_games
        self.n_max_steps = n_max_steps
        if best_model is None:
            best_model = AlphaGoZeroModel()
        self.best_model = best_model
        self.best_agent = AlphaGoZeroAgent(self.best_model)

    def clone_best_model(self):
        """Return a fresh `AlphaGoZeroModel` holding a deep copy of the best model's weights."""
        clone = AlphaGoZeroModel()
        weights = copy.deepcopy(self.best_model.state_dict())
        clone.load_state_dict(weights)
        return clone

    @torch.no_grad()
    def battle_new_candidate(self, env, candidate_model: AlphaGoZeroModel):
        """
        Play `candidate_model` against the current best model.

        Returns True (and makes the candidate the new best) when its wins reach
        55% of `n_evaluate_games`; otherwise returns False.
        """
        self.best_model.eval()
        candidate_model.eval()

        challenger = AlphaGoZeroAgent(candidate_model)
        results = battle(
            env, challenger, self.best_agent,
            n_games=self.n_evaluate_games, n_max_steps=self.n_max_steps,
        )

        wins = results[0] + results[1]
        losses = results[2] + results[3]
        promoted = wins >= 0.55 * self.n_evaluate_games

        verdict = "Candidate is a new champion" if promoted else "Candidate defeated"
        print(f"{verdict}: wins {wins}, losses {losses}, ties {results[4]}. Results={results}")

        if promoted:
            self.best_model = candidate_model
            self.best_agent = challenger
        return promoted

## Implementation

In [395]:
class AlphaGoZero_Impl:
    """
    Self-play training loop.

    Each iteration plays MCTS self-play games with the current champion,
    trains a candidate model on the collected statistics, and lets
    `agent_evaluator` decide whether the candidate becomes the new champion.
    """
    def __init__(self,
        env,
        agent_evaluator: AgentEvaluator,
        mem_max_episodes: int,
        n_iterations: int,
        games_per_iteration: int,
        simulations_cnt: int,
        learning_rate: float,
        weight_decay: float,
        batch_size: int,
        root_dirichlet_alpha: float,
        root_exploration_fraction: float,
    ):
        """
        Parameters:
        env: connect four compatible environment
        agent_evaluator: evaluates models and keeps track of the current best model
        mem_max_episodes: how many episodes to keep in experience replay
        n_iterations: how many iterations to run (one iteration = self-play, one training pass and one evaluation)
        games_per_iteration: how many games the best model plays against itself in each iteration
        simulations_cnt: how many MCTS simulations to run before each move
        learning_rate: learning rate for the model optimizer
        weight_decay: weight decay for the model optimizer (AdamW)
        batch_size: batch size for model training
        root_dirichlet_alpha: concentration parameter of the Dirichlet noise mixed into the root priors
        root_exploration_fraction: weight of that noise in the root priors (0 disables it)
        """
        self.root_env = env
        self.agent_evaluator = agent_evaluator
        self.n_iterations = n_iterations
        self.games_per_iteration = games_per_iteration
        self.simulations_cnt = simulations_cnt
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.batch_size = batch_size
        self.root_dirichlet_alpha = root_dirichlet_alpha
        self.root_exploration_fraction = root_exploration_fraction
        self.tree = MCTS_Tree()
        self.experience_replay = ExperienceReplay(max_episodes=mem_max_episodes)

    def simulate(self, env, model: AlphaGoZeroModel, state, actions_mask):
        """Play the model's greedy moves from `state` until the game ends; return the final reward."""
        while True:
            action = model.get_rollout_action(state, actions_mask)
            state, actions_mask, reward, done = env.step(action)
            if done:
                break
        return reward

    def select(self, env):
        """Walk down the tree from `env`'s current state using `tree.get_action`."""
        tree = self.tree
        state, actions_mask = env.last()
        sa_array = [state]    # selected path [state0, action0, state1, action1, ..., stateK],
        done = False
        reward = 0

        while sa_array[-1] in tree.V_state_values:
            action = tree.get_action(state)
            state, actions_mask, reward, done = env.step(action)

            sa_array.append(action)
            sa_array.append(state)
            if done:
                break

        # Guarantees:
        #    `sa_array[-1] not in tree.V_state_values` (terminal states are never added)
        #    `actions_mask, reward, done` tied to `sa_array[-1]`
        #    `env` moves to state `sa_array[-1]`
        return sa_array, actions_mask, reward, done

    def expand(self, env, model: AlphaGoZeroModel, sa_array, actions_mask):
        """Take one greedy model move from the leaf, appending it to `sa_array`; return `(reward, done)`."""
        action = model.get_rollout_action(sa_array[-1], actions_mask)
        state, actions_mask, reward, done = env.step(action)

        sa_array.append(action)
        sa_array.append(state)
        return reward, done

    def _noisy_root_priors(self, priors, actions_mask):
        """
        Return `(1 - frac) * priors + frac * noise`, where `noise` is a
        Dirichlet(alpha) sample over the legal actions (0 on illegal ones).
        `priors` itself is not modified.
        """
        legal = np.asarray(actions_mask) == 1
        noise = np.zeros(len(actions_mask))
        noise[legal] = np.random.dirichlet(np.full(int(np.sum(legal)), self.root_dirichlet_alpha))
        frac = self.root_exploration_fraction
        return priors * (1 - frac) + noise * frac

    def mcts_step(self, root_env, model: AlphaGoZeroModel):
        """Run `simulations_cnt` MCTS simulations from `root_env`'s state and sample the move to play."""
        tree = self.tree
        root_state, root_mask = root_env.last()

        # Expand the root first so exploration noise can be applied to its priors
        # before any simulation runs.
        if root_state not in tree.V_state_values:
            tree.add_node(model, root_state, root_mask)

        # Noise is applied once per move and removed afterwards. The tree persists
        # across moves and games, so writing noise into the stored priors would
        # keep diluting the network prior every time the state becomes a root.
        clean_priors = tree.P_action_probs[root_state]
        tree.P_action_probs[root_state] = self._noisy_root_priors(clean_priors, root_mask)
        try:
            for _ in range(self.simulations_cnt):
                env = root_env.copy()

                sa_array, actions_mask, reward, done = self.select(env)

                if not done:
                    tree.add_node(model, sa_array[-1], actions_mask)

                    reward, done = self.expand(env, model, sa_array, actions_mask)

                    if not done:
                        reward = self.simulate(env, model, sa_array[-1], actions_mask)

                tree.backpropagate(sa_array, reward)
        finally:
            tree.P_action_probs[root_state] = clean_priors

        # Sample the move proportionally to sqrt(N), i.e. N^(1/tau) with tau = 2.
        weights = np.sqrt(tree.N_cnt_visits[root_state])
        weights[root_mask == 0] = 0
        if np.sum(weights) == 0:
            # No visits recorded (e.g. simulations_cnt == 0): uniform over legal moves.
            weights = (root_mask == 1).astype(np.float64)
        return np.random.choice(len(root_mask), p=weights / np.sum(weights))

    def collect_tree_statistics_using_self_play(self, model: AlphaGoZeroModel):
        """Play `games_per_iteration` self-play games with MCTS, recording them in experience replay."""
        experience_replay = self.experience_replay

        root_env = self.root_env

        for game_i in range(self.games_per_iteration):
            state, mask = root_env.reset()
            experience_replay.on_reset(state)
            while True:
                action = self.mcts_step(root_env, model)
                state, mask, reward, done = root_env.step(action)
                experience_replay.on_action(action, reward, done, state)
                if done:
                    break

    def go_train(self):
        """Run `n_iterations` of self-play, candidate training and evaluation."""
        agent_evaluator = self.agent_evaluator
        experience_replay = self.experience_replay
        tree = self.tree

        candidate_model = None
        optimizer = None

        for iter_i in range(self.n_iterations):
            best_model = agent_evaluator.best_model
            best_model.eval()
            with torch.no_grad():
                self.collect_tree_statistics_using_self_play(best_model)

            if candidate_model is None:
                candidate_model = agent_evaluator.clone_best_model()
                optimizer = torch.optim.AdamW(candidate_model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)

            candidate_model.train()
            candidate_model.go_train(tree, experience_replay, optimizer, self.batch_size)
            experience_replay.clear()

            candidate_model.eval()
            with torch.no_grad():
                if agent_evaluator.battle_new_candidate(self.root_env, candidate_model):
                    # Yay, new champion! Reset all things to start improving new best model
                    candidate_model = None
                    optimizer = None
                    tree.clear()

def train_alphagozero_connectfour(agent_evaluator=None):
    """
    Train on a fresh `connect_four.ConnectFour` environment.

    Training continues from `agent_evaluator`'s current champion when one is
    given; otherwise a new evaluator (400 evaluation games, 50 plies max) is
    created. Returns `{'env': ..., 'agent_evaluator': ...}`, where the
    evaluator holds the best model found.
    """
    env = connect_four.ConnectFour()

    evaluator = (
        agent_evaluator
        if agent_evaluator is not None
        else AgentEvaluator(n_evaluate_games=400, n_max_steps=50)
    )

    hyperparams = dict(
        mem_max_episodes=100000,
        n_iterations=20,
        games_per_iteration=100,
        simulations_cnt=20,
        learning_rate=1e-2,
        weight_decay=1e-5,
        batch_size=16,
        root_dirichlet_alpha=0.3,   # like in chess
        root_exploration_fraction=0.25,
    )
    AlphaGoZero_Impl(env=env, agent_evaluator=evaluator, **hyperparams).go_train()

    return {'env': env, 'agent_evaluator': evaluator}

# Train

In [396]:
# Continue from the champion of a previous run in this session if there is one.
# On a fresh kernel `trained` does not exist yet, so training starts from scratch.
_previous_run = globals().get('trained')
trained = train_alphagozero_connectfour(
    _previous_run['agent_evaluator'] if _previous_run is not None else None
)

Candidate defeated: wins 0, losses 400, ties 0. Results=[0, 0, 200, 200, 0]
Candidate is a new champion: wins 400, losses 0, ties 0. Results=[200, 200, 0, 0, 0]
Candidate defeated: wins 0, losses 400, ties 0. Results=[0, 0, 200, 200, 0]
Candidate is a new champion: wins 400, losses 0, ties 0. Results=[200, 200, 0, 0, 0]
Candidate is a new champion: wins 400, losses 0, ties 0. Results=[200, 200, 0, 0, 0]
Candidate defeated: wins 200, losses 200, ties 0. Results=[0, 200, 0, 200, 0]
Candidate defeated: wins 200, losses 200, ties 0. Results=[0, 200, 0, 200, 0]
Candidate defeated: wins 200, losses 200, ties 0. Results=[200, 0, 200, 0, 0]
Candidate is a new champion: wins 400, losses 0, ties 0. Results=[200, 200, 0, 0, 0]
Candidate is a new champion: wins 400, losses 0, ties 0. Results=[200, 200, 0, 0, 0]
Candidate defeated: wins 200, losses 200, ties 0. Results=[0, 200, 0, 200, 0]
Candidate defeated: wins 200, losses 200, ties 0. Results=[200, 0, 200, 0, 0]
Candidate defeated: wins 0, losse

In [397]:
def show_trained(trained):
    """
    Print the battle results of a random agent (agent0) against the trained
    champion (agent1) over 1000 games; see `battle` for the list layout.
    """
    env = trained['env']
    champion = trained['agent_evaluator'].best_agent

    champion.model.eval()
    with torch.no_grad():
        tally = battle(env, RandomAgent(), champion, n_games=1000, n_max_steps=50)
        print(tally)


show_trained(trained)

[69, 39, 461, 431, 0]


# Play

In [398]:
def play_vs_ai(env, opponent, action):
    """
    Play the human move `action` and render the board; unless that ends the
    game, let `opponent` reply and render the board again.
    """
    state, action_mask, reward, done = env.step(action)
    print(f"Your move {action}: actions mask {list(action_mask)}, reward {reward}, done {done}.")
    env.render_ascii()
    if done:
        return

    reply = opponent.get_action(state, action_mask)
    state, action_mask, reward, done = env.step(reply)
    print()
    print(f"Opponent move {reply}: actions mask {list(action_mask)}, reward {reward}, done {done}")
    env.render_ascii()


# The AI opens the game; the human then answers with `play_vs_ai`.
test_opponent = AlphaGoZeroAgent(trained['agent_evaluator'].best_model)

test_env = connect_four.ConnectFour()
test_env.reset()
opening_state, opening_mask = test_env.last()
test_env.step(test_opponent.get_action(opening_state, opening_mask))
test_env.render_ascii()

.......
.......
.......
.......
.......
...0...


In [405]:
# Human plays action 4 (presumably the column index); the AI replies unless the game is over.
play_vs_ai(test_env, test_opponent, 4)

Your move 4: actions mask [1, 1, 1, 1, 1, 1, 1], reward -1, done True.
.......
...0...
..001..
..011..
..101..
.0101..


In [94]:
# TODO:
# Add symmetrical boards
# Cpuct: supplementary materials p5
# LR schedule: supplementary materials p7
# Paper: tree is new in each mcts run