# Notes

Env based on: https://pettingzoo.farama.org/environments/classic/connect_four/

# Imports

In [1]:
%load_ext autoreload
%autoreload 2

import numpy as np
from collections import defaultdict, Counter
import torch

import connect_four

# Show environment

In [3]:
def show():
    # env = pettingzoo.classic.connect_four_v3.env(render_mode="human")
    env = connect_four.ConnectFour()

    state, mask = env.reset()

    for step in range(50):
        I = np.where(mask == 1)[0]
        action = np.random.choice(I)

        state, mask, reward, done = env.step(action)

        print(step % 2, action, (reward, done, state.board.shape, mask))

        if done:
            env.render_ascii()
            break

show()

0 5 (0, False, (6, 7, 2), array([1, 1, 1, 1, 1, 1, 1], dtype=int32))
1 1 (0, False, (6, 7, 2), array([1, 1, 1, 1, 1, 1, 1], dtype=int32))
0 0 (0, False, (6, 7, 2), array([1, 1, 1, 1, 1, 1, 1], dtype=int32))
1 1 (0, False, (6, 7, 2), array([1, 1, 1, 1, 1, 1, 1], dtype=int32))
0 4 (0, False, (6, 7, 2), array([1, 1, 1, 1, 1, 1, 1], dtype=int32))
1 4 (0, False, (6, 7, 2), array([1, 1, 1, 1, 1, 1, 1], dtype=int32))
0 0 (0, False, (6, 7, 2), array([1, 1, 1, 1, 1, 1, 1], dtype=int32))
1 3 (0, False, (6, 7, 2), array([1, 1, 1, 1, 1, 1, 1], dtype=int32))
0 4 (0, False, (6, 7, 2), array([1, 1, 1, 1, 1, 1, 1], dtype=int32))
1 6 (0, False, (6, 7, 2), array([1, 1, 1, 1, 1, 1, 1], dtype=int32))
0 4 (0, False, (6, 7, 2), array([1, 1, 1, 1, 1, 1, 1], dtype=int32))
1 6 (0, False, (6, 7, 2), array([1, 1, 1, 1, 1, 1, 1], dtype=int32))
0 1 (0, False, (6, 7, 2), array([1, 1, 1, 1, 1, 1, 1], dtype=int32))
1 0 (0, False, (6, 7, 2), array([1, 1, 1, 1, 1, 1, 1], dtype=int32))
0 5 (0, False, (6, 7, 2), array([1

# Baseline agents

In [2]:
class RandomAgent:
    def get_action(self, state: connect_four.HashableState, action_mask):
        I = np.where(action_mask == 1)[0]
        action = np.random.choice(I)
        return action


class AlwaysLeftAgent:
    def get_action(self, state: connect_four.HashableState, action_mask):
        I = np.where(action_mask == 1)[0]
        return I[0]


class AlwaysRightAgent:
    def get_action(self, state: connect_four.HashableState, action_mask):
        I = np.where(action_mask == 1)[0]
        return I[-1]


def battle(env, agent0, agent1, n_games, n_max_steps):
    """
    Play `n_games` where each game is at most `n_max_steps` turns.
    Return counter of game results as array `[agent 1 wins, agent 2 wins, draws]`
    """

    results = [0, 0, 0]
    agents = [agent0, agent1]

    for game in range(n_games):
        winner = 2
        state, mask = env.reset()
        who = game % 2  # Switch sides every other game

        for step in range(n_max_steps):
            action = agents[who].get_action(state, mask)

            state, mask, reward, done = env.step(action)

            if done:
                if reward != 0:
                    winner = who
                break
            who = 1 - who

        results[winner] += 1

    return results


def show_sample_agents_battle():
    env = connect_four.ConnectFour()

    for name0, agent0, name1, agent1 in [
        ("Random", RandomAgent(), "Random", RandomAgent()),
        ("Random", RandomAgent(), "AlwaysLeft", AlwaysLeftAgent()),
        ("Random", RandomAgent(), "AlwaysRight", AlwaysRightAgent()),
        ("AlwaysLeft", AlwaysLeftAgent(), "AlwaysRight", AlwaysRightAgent()),
    ]:
        battle_results = battle(env, agent0, agent1, n_games=500, n_max_steps=50)
        print(f"{name0} vs {name1}: wins {battle_results[0]}, loses {battle_results[1]}, draws {battle_results[2]}")

show_sample_agents_battle()


Random vs Random: wins 244, loses 255, draws 1
Random vs AlwaysLeft: wins 108, loses 392, draws 0
Random vs AlwaysRight: wins 99, loses 401, draws 0
AlwaysLeft vs AlwaysRight: wins 250, loses 250, draws 0


# AlphaGo Zero plays Connect Four

In [12]:
class AlphaGoZeroModel(torch.nn.Module):
    """
    Implement AlphaGo Zero model with two heads: for action probabilities and state value
    """
    def __init__(self) -> None:
        super().__init__()

        self.model_common = torch.nn.Sequential(
            torch.nn.Flatten(),   # (6, 7, 2) -> 84
            torch.nn.Linear(84, 32),
            torch.nn.ReLU(),
        )

        self.model_action_probs = torch.nn.Sequential(
            torch.nn.Linear(32, 7),
            torch.nn.Softmax(dim=1),
        )

        self.model_state_value = torch.nn.Sequential(
            torch.nn.Linear(32, 1),
        )


    def forward(self, board):
        B = board.shape[0]
        assert board.shape == (B, 6, 7, 2)

        common = self.model_common(board)
        action_probs = self.model_action_probs(common)
        state_value = self.model_state_value(common)

        # action_probs.shape == (B, 7)
        # state_value.shape == (B, 1)
        return action_probs, state_value


    def get_p_v_single_state(self, state: connect_four.HashableState, need_p: bool, need_v: bool):
        board, player_idx = state.board, state.player_idx
        assert board.shape == (6, 7, 2)

        if player_idx == 1:
            # Make current player to have coins at board[:, :, 0], and opponent at board[:, :, 1]
            board = np.flip(board, axis=2).copy()

        board = torch.from_numpy(board.reshape((1, 6, 7, 2))).float()
        common = self.model_common(board)

        action_probs, state_value = None, None

        if need_p:
            action_probs = self.model_action_probs(common)
            action_probs = action_probs.detach().cpu().numpy()[0, :]

        if need_v:
            state_value = self.model_state_value(common)
            state_value = state_value.detach().cpu().numpy()[0, 0]

        return action_probs, state_value


    def get_rollout_action(self, state: connect_four.HashableState, actions_mask):
        assert actions_mask.shape == (7,)

        action_probs, _ = self.get_p_v_single_state(state, need_p=True, need_v=False)
        action_probs[actions_mask == 0] = -np.inf

        return np.argmax(action_probs)

    def go_train(self, tree, optimizer, batch_size):
        self.train()

        all_states = list(tree.V_state_values.keys())
        np.random.shuffle(all_states)
        for bi in range(batch_size, len(all_states) + 1, batch_size):
            states = all_states[bi - batch_size : bi]

            board = [state.board for state in states]
            board = torch.as_tensor(board).float()
            # print(f"[go_train] board {board.shape} {board.dtype}")   # (B, 7, 8, 2)

            action_probs, state_value = self.forward(board)
            # print(f"[go_train] action_probs {action_probs.shape} {action_probs.dtype}")  # (B, 7)
            # print(f"[go_train] state_value {state_value.shape} {state_value.dtype}")  # (B, 1)

            actual_probs = np.array([ tree.N_cnt_visits[state] for state in states ])
            # print(f"[go_train] actual_probs {actual_probs.shape} {actual_probs.dtype}")  # (B, 7)
            actual_probs /= np.sum(actual_probs, axis=1).reshape((batch_size, 1))

            actual_values = np.array([ np.max(tree.Q_values[state]) for state in states ])
            # print(f"[go_train] actual_values {actual_values.shape} {actual_values.dtype}")  # (B,)

            actual_probs = torch.as_tensor(actual_probs)
            actual_values = torch.as_tensor(actual_values).view((batch_size, 1))

            states_loss = torch.sum(torch.abs((actual_values - state_value)**2))

            action_loss = torch.sum(actual_probs * torch.log(action_probs))

            loss = states_loss + action_loss

            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            optimizer.step()

class AlphaGoZeroAgent:
    """
    Trained agent which can play Connect Four game using `AlphaGoZeroModel` model
    """
    def __init__(self, model: AlphaGoZeroModel) -> None:
        self.model = model

    @torch.no_grad()
    def get_action(self, state: connect_four.HashableState, actions_mask):
        self.model.eval()
        return self.model.get_rollout_action(state, actions_mask)


class MCTS_Tree:
    """
    Tree structure which keeps various statistics needed for MCTS
    """
    def __init__(self) -> None:
        # Structure of all dicts: { state -> [ action -> value ] }
        self.N_cnt_visits = dict()
        self.W_sum_vals = dict()
        self.Q_values = dict()
        self.P_action_probs = dict()

        self.V_state_values = dict()  # { state -> predicted state value v(s) }

    def get_action(self, env, state: connect_four.HashableState, actions_mask):
        N_s_a = self.N_cnt_visits[state]
        Q_s_a = self.Q_values[state]
        P_s_a = self.P_action_probs[state]
        U_s_a = 4 * P_s_a * np.sum(N_s_a) / (1 + N_s_a)
        action = np.argmax(Q_s_a + U_s_a)
        return action

    def add_node(self, model: AlphaGoZeroModel, state, actions_mask):
        action_probs, state_value = model.get_p_v_single_state(state, need_p=True, need_v=True)
        action_probs[actions_mask == 0] = 0
        action_probs /= np.sum(action_probs)

        actions_len = len(actions_mask)
        self.N_cnt_visits[state] = np.zeros(actions_len)
        self.W_sum_vals[state] = np.zeros(actions_len)
        self.Q_values[state] = np.zeros(actions_len)
        self.Q_values[state][actions_mask == 0] = -np.inf
        self.P_action_probs[state] = action_probs
        self.V_state_values[state] = state_value

    def backpropagate(self, sa_array, reward):
        # `sa_array`: selected path in tree [state0, action0, state1, action1, ..., stateK],

        assert len(sa_array) % 2 == 1 and len(sa_array) >= 3
        for si in range(0, len(sa_array) - 1, 2):
            state, action = sa_array[si], sa_array[si + 1]

            # reward = 1 iff player 0 wins
            # if `state.player_idx` == 0 -> then it gets reward of 1
            # if `state.player_idx` == 1 -> then it gets reward of -1
            # if reward is -1, then all values are opposite

            self.N_cnt_visits[state][action] += 1
            self.W_sum_vals[state][action] += reward * (1 - 2 * state.player_idx)
            self.Q_values[state][action] = self.W_sum_vals[state][action] / self.N_cnt_visits[state][action]


In [17]:
class AlphaGoZero_Impl:
    def __init__(self,
        env,
        games_per_iteration,
        simulations_cnt,
        learning_rate,
        weight_decay,
        batch_size
    ):
        """
        Parameters:
        games_per_iteration: how many games best player plays against itself in each iteration
        simulations_cnt: how many MCTS simulations run in each iteration
        learning_rate: learning rate for model optimizer
        weight_decay: weight decay for model training loss
        batch_size: batch size for model training
        """
        self.root_env = env
        self.games_per_iteration = games_per_iteration
        self.simulations_cnt = simulations_cnt
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.batch_size = batch_size

    @torch.no_grad()
    def simulate(self, env, model: AlphaGoZeroModel, state, actions_mask):
        model.eval()

        while True:
            action = model.get_rollout_action(state, actions_mask)
            state, actions_mask, reward, done = env.step(action)
            if done:
                break
        return reward

    def select(self, env, tree: MCTS_Tree):
        state, actions_mask = env.last()
        sa_array = [state]    # selected path [state0, action0, state1, action1, ..., stateK],
        done = False
        reward = 0

        while sa_array[-1] in tree.V_state_values:
            action = tree.get_action(env, state, actions_mask)
            state, actions_mask, reward, done = env.step(action)

            sa_array.append(action)
            sa_array.append(state)
            if done:
                break

        # Guarantees:
        #    `sa_array[-1] not in tree.states_set`
        #    `actions_mask, reward, done` tied to `sa_array[-1]`
        #    `env` moves to state `sa_array[-1]`
        return sa_array, actions_mask, reward, done

    def expand(self, env, model: AlphaGoZeroModel, sa_array, actions_mask):
        action = model.get_rollout_action(sa_array[-1], actions_mask)
        state, actions_mask, reward, done = env.step(action)

        sa_array.append(action)
        sa_array.append(state)
        return reward, done

    def mcts_step(self, root_env, model: AlphaGoZeroModel, tree: MCTS_Tree):
        for si in range(self.simulations_cnt):
            env = root_env.copy()

            sa_array, actions_mask, reward, done = self.select(env, tree)

            tree.add_node(model, sa_array[-1], actions_mask)

            if not done:
                reward, done = self.expand(env, model, sa_array, actions_mask)

                if not done:
                    reward = self.simulate(env, model, sa_array[-1], actions_mask)

            tree.backpropagate(sa_array, reward)

        state, actions_mask = root_env.last()
        N_s_a = tree.N_cnt_visits[state]
        pi_probs = N_s_a / np.sum(N_s_a)
        pi_probs[actions_mask == 0] = 0
        pi_probs /= np.sum(pi_probs)
        return np.random.choice(len(actions_mask), p=pi_probs)

    def improve_policy_using_self_play(self, model: AlphaGoZeroModel, optimizer):
        root_env = self.root_env
        tree = MCTS_Tree()

        for game_i in range(self.games_per_iteration):
            root_env.reset()
            while True:
                action = self.mcts_step(root_env, model, tree)

                state, mask, reward, done = root_env.step(action)
                if done:
                    break

        model.go_train(tree, optimizer, self.batch_size)

    def go_train(self):
        model = AlphaGoZeroModel()
        optimizer = torch.optim.AdamW(model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
        self.improve_policy_using_self_play(model, optimizer)
        return model


def train_alphagozero_connectfour():
    env = connect_four.ConnectFour()
    impl = AlphaGoZero_Impl(
        env=env,
        games_per_iteration=10,
        simulations_cnt=100,
        learning_rate=1e-3,
        weight_decay=1e-5,
        batch_size=8,
    )
    model = impl.go_train()
    return env, model

trained = train_alphagozero_connectfour()

  actual_probs /= np.sum(actual_probs, axis=1).reshape((batch_size, 1))


In [18]:
def show_trained(trained):
    env, model = trained
    print(battle(env, RandomAgent(), AlphaGoZeroAgent(model), n_games=1000, n_max_steps=50))

show_trained(trained)

[201, 799, 0]
