# **Project AlphaFour**
In this project I implemented the alphaGo algorithm depicted in this paper: https://augmentingcognition.com/assets/Silver2017a.pdf for Connect4, I also implemented a fast and efficient Connect4 environment that has integers as rows instead of arrays(though the bottleneck is not the environment but the neural net, but it was a fun challenge to optimize the environment).  

# setup

imports

In [1]:
import torch
import random
from torch import nn, optim, stack, tensor, flatten
from torch.nn import functional as F
from matplotlib import pyplot as plt
from torchsummary import summary
import numpy as np
import os

configuration

In [2]:
config = {
    "device" : 'cpu',
    "mem_size" : 1000,
    "mcts_iterations" : 240,
    'num_games_per_epoch' : 15,
    'num_mcts_iterations_increase' : 20,
    'batch_size' : 32,
    'initial_tau' : 1.2,
    'tau_multiplier' : 1.1,
    'lr' : 3e-4,
    'max_tau' : 2.0,
    'num_training_steps_after_game' : 10,
    'saving_dir' : 'models',
    'num_blocks': 15,
    'hidden_size': 200
}

# Environment

The environment class saves each row as an integer where each square in the row takes two bits of the row (in total there are 14 bits occupied), if a quare is empty it would have "00" in it, if player one placed a piece in the square then it would have "01" in it and if player two played the square it would have "10". We can use the binary format to check for win efficiently etc. . The environment class also is responsible to provide the state translation to tensor.

In [3]:
class EffConnect4:
    def reset():
        board = [0 for _ in range(6)]
        current_player = 1

        return (board, current_player)

    def get_row(state, row_number):
      board, _ = state
      row = board[row_number]
      row = [row % 4, (row //4) % 4, (row //16) % 4, (row //64) % 4, (row //256) % 4, (row //1024) % 4, (row //4096) % 4 ]
      return row

    def check_win_cols(board):
        lines_0_1, lines_1_2, lines_2_3, lines_3_4, lines_4_5 = \
         (board[0] & board[1]), (board[1] & board[2]), (board[2] & board[3])\
        ,(board[3] & board[4]), (board[4] & board[5])
        res = (lines_0_1 & lines_2_3) | (lines_1_2 & lines_3_4) | (lines_2_3 & lines_4_5)
        return bool(res & 10922) * 2 + bool(res & 5461) * 1

    def check_win_ascending_diags(board):
        lines_0_1, lines_1_2, lines_2_3, lines_3_4, lines_4_5 = \
         (board[0] << 10 & board[1] << 8) , (board[1] << 8 & board[2] << 6), (board[2] << 6 & board[3] << 4)\
        ,(board[3] << 4 & board[4] << 2), (board[4] << 2 & board[5])
        res = (lines_0_1 & lines_2_3) | (lines_1_2 & lines_3_4) | (lines_2_3 & lines_4_5)
        return bool(res & 11184810) * 2 + bool(res & 5592405) * 1

    def check_win_descending_diags(board):
        lines_0_1, lines_1_2, lines_2_3, lines_3_4, lines_4_5 = \
         (board[0] & board[1] << 2) , (board[1] << 2 & board[2] << 4), (board[2] << 4 & board[3] << 6)\
        ,(board[3] << 6 & board[4] << 8), (board[4] << 8 & board[5] << 10)
        res = (lines_0_1 & lines_2_3) | (lines_1_2 & lines_3_4) | (lines_2_3 & lines_4_5)
        return bool(res & 11184810) * 2 + bool(res & 5592405) * 1

    def check_a_row(board, i):
        return  2 * bool((board[i] & 10880 == 10880) | (board[i] & 2720 == 2720) | (board[i] & 680 == 680) | (board[i] & 170 == 170)) or\
                1 * bool((board[i] & 5440 == 5440) | (board[i] & 1360 == 1360) | (board[i] & 340 == 340) | (board[i] & 85 == 85))

    def check_win_rows(board):
        return EffConnect4.check_a_row(board, 0) | EffConnect4.check_a_row(board, 1) | EffConnect4.check_a_row(board, 2)\
              | EffConnect4.check_a_row(board, 3) | EffConnect4.check_a_row(board, 4) | EffConnect4.check_a_row(board, 5)

    def check_win(state):
        board, current_player = state
        return EffConnect4.check_win_cols(board) | EffConnect4.check_win_rows(board) \
              | EffConnect4.check_win_descending_diags(board) | EffConnect4.check_win_ascending_diags(board)

    def get_action_mask(board):
        row = ~(board[-1] | (board[-1] >> 1)) & 5461
        row = [row % 4, (row //4) % 4, (row //16) % 4, (row //64) % 4, (row //256) % 4, (row //1024) % 4, (row //4096) % 4 ]
        return row

    def get_legal_actions(state):
        board, _ = state
        mask = EffConnect4.get_action_mask(board)
        return [i for i in range(7) if mask[i]]

    def play_action(state, action):
        board, current_player = state
        board = list(board)
        x = (current_player << 2 * action)
        y = (3 << 2 * action)
        for i in range(6):
            if (board[i] & y):
                continue
            board[i] += x
            current_player = (3 - current_player)
            return (board, current_player), EffConnect4.get_action_mask(board), EffConnect4.is_terminal((board, current_player))
        raise "no action was played"

    def print_board(state):
        board, current_player = state
        print(f'Player {current_player} to play:')
        print("board:")
        print("-" * 20)
        for i in range(6):
          print(' , '.join(map(str, EffConnect4.get_row(state, 6 - i - 1))))
        print("-" * 20)

    ''' translate the board and current player to one dimensional tensor '''
    def states_to_tensor(states, device):
        state_tensors = []
        for state in states:
            board, current_player = state
            state_array = [current_player]
            for i in range(len(board)):
                state_array.extend(EffConnect4.get_row(state, i))
            state_tensor = (tensor(state_array, dtype=torch.float32, device=device) - 1)/2
            state_tensors.append(state_tensor)
        return torch.stack(state_tensors)

    ''' translate the mask of action to a tensor with -inf for unavilable actions '''
    def action_masks_to_tensor(masks, device):
        return torch.stack([torch.where(tensor(mask) == 1, torch.tensor(0.0), torch.tensor(float('-inf')))
                            for mask in masks])

    def is_full(state):
        first_row = state[0][-1]
        return ((first_row | (first_row >> 1)) & 5461) == 5461

    def is_terminal(state):
        return (EffConnect4.is_full(state) or (EffConnect4.check_win(state) != 0))

    def get_value(state):
        # 3 - current_player gives the other player
        return (EffConnect4.check_win(state) == state[1]) - (EffConnect4.check_win(state) == (3 - state[1]))


In [4]:
import time
t1 = time.time()
env = EffConnect4
state = env.reset()
while not env.is_terminal(state):
  env.print_board(state)
  print(env.get_legal_actions(state))
  action = random.choice(env.get_legal_actions(state))
  state, _, _ = env.play_action(state, action)
  print(f"The action {action} was just played")

env.print_board(state)
print(f"value = {env.get_value(state)}")
print(time.time() - t1)

  # print(f'is winning player {(i % 2) + 1}: {env.is_win((i % 2) + 1)}')

Player 1 to play:
board:
--------------------
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
--------------------
[0, 1, 2, 3, 4, 5, 6]
The action 0 was just played
Player 2 to play:
board:
--------------------
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
1 , 0 , 0 , 0 , 0 , 0 , 0
--------------------
[0, 1, 2, 3, 4, 5, 6]
The action 4 was just played
Player 1 to play:
board:
--------------------
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
1 , 0 , 0 , 0 , 2 , 0 , 0
--------------------
[0, 1, 2, 3, 4, 5, 6]
The action 4 was just played
Player 2 to play:
board:
--------------------
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 1 , 0 , 0
1 

# Model

In [5]:
class LinearResidualBlock(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()

        self.linear1 = nn.Linear(in_features, out_features)
        self.linear2 = nn.Linear(out_features, out_features)

        self.norm1 = nn.LayerNorm(out_features)
        self.norm2 = nn.LayerNorm(out_features)

        self.activation = nn.ReLU()

    def forward(self, x):
        identity = x
        out = self.linear1(x)
        out = self.norm1(out)
        out = self.activation(out)
        out = self.linear2(out)
        out = self.norm2(out)
        out = self.activation(out)
        return out + identity

In [6]:
class AlphaFourModel(nn.Module):
    def __init__(self, num_blocks, hidden_dim):
        super().__init__()
        self.first_block = nn.Sequential(
            nn.Linear(6*7 + 1, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU()
        )
        self.blocks = nn.ModuleList([LinearResidualBlock(hidden_dim, hidden_dim) for _ in range(num_blocks)])
        self.policy_layers = nn.Sequential(
            nn.Linear(hidden_dim, 500),
            nn.LayerNorm(500),
            nn.ReLU(),
            nn.Linear(500, 7)
        )
        self.value_layers = nn.Sequential(
            nn.Linear(hidden_dim, 500),
            nn.LayerNorm(500),
            nn.ReLU(),
            nn.Linear(500, 1)
        )

    def forward(self, states, masks):
        x = self.first_block(states)

        for block in self.blocks:
            x = block(x)

        distributions = torch.softmax(self.policy_layers(x) + masks, dim = -1)
        val = F.tanh(self.value_layers(x))
        return val, distributions

    def save(self, optimizer, file_name, version):
        if not os.path.isdir(config['saving_dir']):
            os.makedirs(config['saving_dir'])
        print(f"saving model into: {file_name} version: {version}")
        torch.save(self.state_dict(), os.path.join(config['saving_dir'], file_name + '_model_' + f'_version_{version}'))
        torch.save(optimizer.state_dict(),  os.path.join(config['saving_dir'], file_name + '_optimizer_' + f'_version_{version}'))

    def load(self, optimizer, file_name, version):
        print(f"loading model and optimizer from: {file_name} version: {version}")
        self.load_state_dict(torch.load(os.path.join(config['saving_dir'], file_name + '_model_' + f'_version_{version}')))
        optimizer.load_state_dict(torch.load(os.path.join(config['saving_dir'], file_name + '_optimizer_' + f'_version_{version}')))

In [7]:
model = AlphaFourModel(config['num_blocks'], config['hidden_size']).to(config['device'])

state = env.reset()
for i in range(4):
  action = random.choice(env.get_legal_actions(state))
  state, _, _ = env.play_action(state, action)

env.print_board(state)

print(env.get_action_mask(state))

print(model(env.states_to_tensor([state], config['device']), env.action_masks_to_tensor([env.get_action_mask(state)], config['device'])))

Player 1 to play:
board:
--------------------
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
1 , 2 , 0 , 2 , 1 , 0 , 0
--------------------
[0, 1, 1, 1, 1, 1, 1]
(tensor([[-0.0423]], grad_fn=<TanhBackward0>), tensor([[0.0000, 0.1368, 0.1312, 0.1445, 0.1191, 0.1077, 0.3608]],
       grad_fn=<SoftmaxBackward0>))


# MCTS

 here I created a monte carlo tree search that get a model, an environment, and adapter that process the data from the environment and pass it to the model.
 It will need to get a state and then run num_simulations of simulations using the model and then it'll return the policy as calculated by the mcts algorithm of the alphazero-paper.https://augmentingcognition.com/assets/Silver2017a.pdf

MCTS search as a function gets an MCTS node that can be an empty root node but can also be a partially filled node, it than runs num_steps extension steps and the proccess of an extension step goes as follow:
start at root as current node:
1.if current node is terminal:
    backprop its value back
2.for each valid next action:
    if action was already played ask the son node for its ucb
    else regard the son node value as zero and

In [8]:
def run_n_simulations(root, num_simulations, env, model):
    if env.is_terminal(root.state):
        return
    for i in range(num_simulations):
        run_simulation(root, env, model)

In [9]:
def run_simulation(root, env, model):
    current_node = root
    visited_edges = []
    while True:
        sum_actions_played = current_node.visit_count()
        if not current_node.edges:
            print(current_node.__dict__)
        edge = max(current_node.edges, key=lambda e: e.uct(sum_actions_played))
        visited_edges.insert(0, edge)
        # when we get the value of the node after the edge we need to multiply it by -1
        # because the edge is played by the other player and we want to measure
        # Q(last_state(viewed from the former player), action)
        if edge.visit_count == 0:
            new_node = edge.open_edge(current_node.state, env, model)
            backprop(val = -1 * new_node.value, edges = visited_edges)
            break
        current_node = edge.node_after
        if env.is_terminal(current_node.state):
            backprop(val = -1 * current_node.value, edges = visited_edges)
            break

In [10]:
def backprop(val, edges):
    for edge in edges:
        edge.visit_count += 1
        edge.Q += val
        val *= -1

In [11]:
class MCTSEdge():
    def __init__(self, action, probability):
        self.action = action
        self.probability = probability
        self.visit_count = 0
        self.Q = 0
        self.node_after = None

    def open_edge(self, last_state, env, model):
        if self.node_after:
            raise f"edge was already open:\nstate: {state},\naction: {self.action}\nstate_after: {self.state_after}"
        new_state, new_state_action_mask, is_terminal = env.play_action(last_state, self.action)
        self.node_after = MCTSNode(new_state, new_state_action_mask, is_terminal, env, model)

        return self.node_after

    def uct(self, sum_action_played_from_father_state ,c_puct = 1):
        U = c_puct * self.probability * (((sum_action_played_from_father_state + 1) ** 0.5) / (1 + self.visit_count))
        Q = self.Q / (self.visit_count + 1)
        return Q + U


In [12]:
class MCTSNode():
    def __init__(self, state, action_mask, is_terminal, env, model, device = config["device"]):
        self.state = state
        self.action_mask = action_mask
        self.edges = []
        self.value = 0

        if is_terminal:
            self.init_terminal(env)
        else:
            self.init_non_terminal(state, action_mask, env, model, device)


    def init_non_terminal(self, new_state, new_state_action_mask, env , model, device):
        new_state_tensor = env.states_to_tensor([new_state], device=device)
        new_state_action_mask_tensor = env.action_masks_to_tensor([new_state_action_mask], device=device)

        value, probs = model(new_state_tensor, new_state_action_mask_tensor)

        self.value = value.cpu().item()

        for i in range(len(new_state_action_mask)):
            if new_state_action_mask[i]:
                self.edges.append(MCTSEdge(i, probs[0][i].cpu().item()))
        if not self.edges:
            raise f"this state do not possess any actions, and should be terminal or possess actions\n{new_state_tensor.__dict__}"

    def init_terminal(self, env):
        self.value = env.get_value(self.state)

    def visit_count(self):
        return sum([edge.visit_count for edge in self.edges])

    def extract_policy(self, tau):
        action_count = [0] * len(self.action_mask) # we create a list in the size of action_space
        for edge in self.edges:
            action_count[edge.action] = edge.visit_count ** tau

        policy_tensor = tensor(action_count)
        policy_tensor = policy_tensor / sum(policy_tensor)
        return policy_tensor

let's try to make a node

In [13]:
import time
t1 = time.time()
env = EffConnect4
model = AlphaFourModel(config['num_blocks'], config['hidden_size']).to(config['device'])
root_state = env.reset()

root_node = MCTSNode(root_state, env.get_action_mask(state[0]), False, env, model)
print(root_node.edges[0].uct(0))
print([e.probability for e in root_node.edges])

run_n_simulations(root_node, 20, env, model)

for edge in root_node.edges:
    print(edge.__dict__)

print(root_node.extract_policy(2))

0.13138513267040253
[0.13138513267040253, 0.06733889877796173, 0.13779036700725555, 0.19699271023273468, 0.12018755823373795, 0.1973654329776764, 0.14893987774848938]
{'action': 0, 'probability': 0.13138513267040253, 'visit_count': 1, 'Q': -0.5141957402229309, 'node_after': <__main__.MCTSNode object at 0x7a927453c760>}
{'action': 1, 'probability': 0.06733889877796173, 'visit_count': 1, 'Q': -0.5036095380783081, 'node_after': <__main__.MCTSNode object at 0x7a927455cca0>}
{'action': 2, 'probability': 0.13779036700725555, 'visit_count': 1, 'Q': -0.4948214292526245, 'node_after': <__main__.MCTSNode object at 0x7a927453f040>}
{'action': 3, 'probability': 0.19699271023273468, 'visit_count': 3, 'Q': 0.5038596093654633, 'node_after': <__main__.MCTSNode object at 0x7a92745470a0>}
{'action': 4, 'probability': 0.12018755823373795, 'visit_count': 1, 'Q': -0.503229558467865, 'node_after': <__main__.MCTSNode object at 0x7a927453ca90>}
{'action': 5, 'probability': 0.1973654329776764, 'visit_count': 1

# Replay Buffer and terminal states Replay Buffer

In [14]:
class ReplayBuffer:
    def __init__(self, mem_size, state_shape, n_actions):
        self.mem_counter = 0
        self.max_size = mem_size
        self.states = torch.zeros((mem_size, *state_shape), dtype = torch.float32)
        self.policies = torch.zeros((mem_size, n_actions), dtype = torch.float32)
        self.values = torch.zeros((mem_size, 1), dtype = torch.float32)
        self.masks = torch.zeros((mem_size, n_actions), dtype = torch.float32)

    def sample(self, batch_size):
        max_index = min(self.mem_counter, self.max_size)
        indexes = np.array(random.sample(range(max_index), min(batch_size, max_index)), dtype = np.intc)

        return self.states[indexes], self.policies[indexes], self.values[indexes], self.masks[indexes]

    def enter(self, states, policies, values, masks):
        for i in range(len(states)):
            self.states[self.mem_counter % self.max_size] = states[i]
            self.policies[self.mem_counter % self.max_size] = policies[i]
            self.values[self.mem_counter % self.max_size] = values[i]
            self.masks[self.mem_counter % self.max_size] = masks[i]
            self.mem_counter += 1

# Game playing function

In [15]:
def delete_edges_except(mcts_node, action):
    mcts_node.edges = [e for e in mcts_node.edges if e.action == action]

In [16]:
def write_states_masks_policy_value_to_replay_buffer(state_policy_value, replay_buffer, env, device=config['device']):
    states, policies, values = list(zip(*state_policy_value))
    policies_tensor = torch.stack(policies)
    values_tensor = torch.stack(values)
    states_tensor = env.states_to_tensor(states, device)
    masks_tensor = env.action_masks_to_tensor([env.get_action_mask(state) for state in states], device)

    replay_buffer.enter(states_tensor, policies_tensor, values_tensor, masks_tensor)

In [17]:
def play_game(num_simulations_per_turn, tau, env, model, replay_buffer, show_game = False):
    state_policy_value = []

    root_state = env.reset()
    current_node = MCTSNode(root_state, env.get_action_mask(state[0]), False, env, model)

    while not env.is_terminal(current_node.state):
        if show_game:
            print("first the state was:")
            env.print_board(current_node.state)
            print(env.get_action_mask(current_node.state[0]))
        run_n_simulations(current_node, num_simulations_per_turn - current_node.visit_count(), env, model)
        policy_dist = current_node.extract_policy(tau)
        action = torch.multinomial(policy_dist, num_samples=1).item()

        if show_game: print(f"action ditribution is {policy_dist}")
        if show_game: print(f"then action {action} was taken")

        # this is in order that the python garbage collector will collect the other edges
        delete_edges_except(current_node, action)

        state_policy_value.insert(0, [current_node.state, policy_dist, None])

        edge = next((edge for edge in current_node.edges if edge.action == action), -1)
        current_node = edge.node_after

    val = -1 * env.get_value(current_node.state)
    for i in range(len(state_policy_value)):
        state_policy_value[i][2] = tensor(val)
        val *= -1

    write_states_masks_policy_value_to_replay_buffer(state_policy_value, replay_buffer, env)

    return current_node



In [18]:
env = EffConnect4
model = AlphaFourModel(config['num_blocks'], config['hidden_size']).to(config['device'])
rb = ReplayBuffer(1000, state_shape=(43,), n_actions=7)

end_node = play_game(20, 2, env, model, rb, True)

env.print_board(end_node.state)

first the state was:
Player 1 to play:
board:
--------------------
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
--------------------
[1, 1, 1, 1, 1, 1, 1]
action ditribution is tensor([0.0056, 0.0056, 0.0056, 0.0056, 0.9494, 0.0225, 0.0056])
then action 4 was taken
first the state was:
Player 2 to play:
board:
--------------------
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 1 , 0 , 0
--------------------
[1, 1, 1, 1, 1, 1, 1]
action ditribution is tensor([0.0056, 0.0056, 0.0056, 0.0056, 0.9494, 0.0225, 0.0056])
then action 5 was taken
first the state was:
Player 1 to play:
board:
--------------------
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 1 , 2 , 0
--------------------
[1

# Training loop

In [19]:
def learn_from_batch(model, optimizer, batch):
    optimizer.zero_grad()

    states, policies, values, masks = batch
    values_pred, probs_pred = model(states, masks)

    policy_loss = torch.nn.functional.kl_div(probs_pred, policies)

    value_loss = torch.nn.functional.mse_loss(values_pred, values)

    l2_reg = torch.tensor(0.)
    for param in model.parameters():
        l2_reg += torch.norm(param, 2)

    loss = 0.1 * l2_reg - policy_loss + value_loss

    loss.backward()
    optimizer.step()

In [20]:
def training_loop(num_epochs, env, rb, model, optimizer):
    batch_size = config['batch_size']
    num_games_per_epoch = config['num_games_per_epoch']
    num_simulations_per_turn = config['mcts_iterations']
    num_training_steps_after_game = config['num_training_steps_after_game']
    tau = config['initial_tau']
    for i in range(num_epochs):
        for j in range(num_games_per_epoch):
            if j == 0:
                print(f"Game No.{i}")
                print(3*(20*"-" + "\n"))
                play_game(num_simulations_per_turn, tau, env, model, rb, True)
            play_game(num_simulations_per_turn, tau, env, model, rb)

            for _ in range(num_training_steps_after_game):
                batch = rb.sample(batch_size)
                learn_from_batch(model, optimizer, batch)

        tau = min(tau * config['tau_multiplier'], config['max_tau'])
        num_simulations_per_turn += config['num_mcts_iterations_increase']

        model.save(optimizer=optimizer, file_name="alpha_four", version=i)


In [21]:
env = EffConnect4
rb = ReplayBuffer(1000, state_shape=(43,), n_actions=7)

model = AlphaFourModel(config['num_blocks'], config['hidden_size']).to(config['device'])
optimizer = optim.Adam(model.parameters(), lr = config['lr'])

In [22]:
training_loop(12, env, rb, model, optimizer)

Game No.0
--------------------
--------------------
--------------------

first the state was:
Player 1 to play:
board:
--------------------
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
--------------------
[1, 1, 1, 1, 1, 1, 1]
action ditribution is tensor([0.2068, 0.0969, 0.1016, 0.1755, 0.1207, 0.2015, 0.0969])
then action 0 was taken
first the state was:
Player 2 to play:
board:
--------------------
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
1 , 0 , 0 , 0 , 0 , 0 , 0
--------------------
[1, 1, 1, 1, 1, 1, 1]
action ditribution is tensor([0.2177, 0.1018, 0.0971, 0.1656, 0.1257, 0.1809, 0.1112])
then action 1 was taken
first the state was:
Player 1 to play:
board:
--------------------
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0




first the state was:
Player 1 to play:
board:
--------------------
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
--------------------
[1, 1, 1, 1, 1, 1, 1]
action ditribution is tensor([0.0142, 0.8501, 0.0547, 0.0177, 0.0273, 0.0258, 0.0102])
then action 1 was taken
first the state was:
Player 2 to play:
board:
--------------------
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 1 , 0 , 0 , 0 , 0 , 0
--------------------
[1, 1, 1, 1, 1, 1, 1]
action ditribution is tensor([3.9642e-04, 6.5531e-04, 9.9672e-01, 6.5531e-04, 5.1777e-04, 3.9642e-04,
        6.5531e-04])
then action 2 was taken
first the state was:
Player 1 to play:
board:
--------------------
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 1 , 2 , 

# Playing against it

In [23]:
def play_against_comp(num_simulations_per_turn, tau, env, model):
    count = 0
    root_state = env.reset()
    current_node = MCTSNode(root_state, env.get_action_mask(state[0]), False, env, model)

    to_start = int(input("do you want to start? y/n") == 'n')

    while not env.is_terminal(current_node.state):
        env.print_board(current_node.state)

        if count % 2 == to_start:
            action = int(input("pls enter your move:")) - 1
            delete_edges_except(current_node, action)
            edge = next((edge for edge in current_node.edges if edge.action == action), -1)
            if edge.node_after:
                current_node = edge.node_after
            else:
                current_node = edge.open_edge(current_node.state, env, model)

        else:
            run_n_simulations(current_node, num_simulations_per_turn - current_node.visit_count(), env, model)
            policy_dist = current_node.extract_policy(tau)
            action = torch.multinomial(policy_dist, num_samples=1).item()
            delete_edges_except(current_node, action)
            edge = next((edge for edge in current_node.edges if edge.action == action), -1)
            current_node = edge.node_after

        print(f"then action {action} was taken")

        count += 1

    if count % 2 == to_start:
        print("you lost :(")
    else:
        print("you won :)")

    return current_node

In [26]:
env = EffConnect4
model = AlphaFourModel(config['num_blocks'], config['hidden_size']).to(config['device'])
optimizer = optim.Adam(model.parameters(), lr = config['lr'])

model.load(optimizer, "alpha_four", version=8)

end_node = play_against_comp(1600, 2, env, model)
env.print_board(end_node.state)

  self.load_state_dict(torch.load(os.path.join(config['saving_dir'], file_name + '_model_' + f'_version_{version}')))
  optimizer.load_state_dict(torch.load(os.path.join(config['saving_dir'], file_name + '_optimizer_' + f'_version_{version}')))


loading model and optimizer from: alpha_four version: 8
do you want to start? y/nn
Player 1 to play:
board:
--------------------
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
--------------------
then action 4 was taken
Player 2 to play:
board:
--------------------
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 1 , 0 , 0
--------------------
pls enter your move:4
then action 3 was taken
Player 1 to play:
board:
--------------------
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 2 , 1 , 0 , 0
--------------------
then action 4 was taken
Player 2 to play:
board:
--------------------
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 

# Games I played against it

In [31]:
end_node = play_against_comp(1600, 2, env, model)
env.print_board(end_node.state)

do you want to start? y/nn
Player 1 to play:
board:
--------------------
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
--------------------
then action 4 was taken
Player 2 to play:
board:
--------------------
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 1 , 0 , 0
--------------------
pls enter your move:4
then action 3 was taken
Player 1 to play:
board:
--------------------
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 2 , 1 , 0 , 0
--------------------
then action 4 was taken
Player 2 to play:
board:
--------------------
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 1 , 0 , 0
0 , 0 , 0 , 2 , 1 , 0 , 0
--------

In [29]:
end_node = play_against_comp(1600, 2, env, model)
env.print_board(end_node.state)

do you want to start? y/ny
Player 1 to play:
board:
--------------------
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
--------------------
pls enter your move:4
then action 3 was taken
Player 2 to play:
board:
--------------------
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 1 , 0 , 0 , 0
--------------------
then action 6 was taken
Player 1 to play:
board:
--------------------
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 1 , 0 , 0 , 2
--------------------
pls enter your move:4
then action 3 was taken
Player 2 to play:
board:
--------------------
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 0 , 0 , 0 , 0
0 , 0 , 0 , 1 , 0 , 0 , 0
0 , 0 , 0 , 