In [2]:
import numpy as np
from game import Connect4
from network import ConnectNet

In [100]:
class Node():
    # Data structure to keep track of our search
    def __init__(self, state, parent=None, parent_action=None):
        self.visits = 0 # Initialize at 0 because we will set +1 after a back up
        self.state = state
        self.children = []
        self.children_move = []
        self.valid_actions = np.where(state[0,:] == 0)[0]
        self.parent = parent
        self.parent_action = parent_action
        
        # Stats mcts
        
        priors = None
        q_value = [0 for _ in range(7)]
        total_q = [0 for _ in range(7)]
        N_a = [0 for _ in range(7)]
        

    def add_child(self ,child_state, action):
        child = Node(child_state, self, action)
        self.children.append(child)
        self.children_move.append(action)

    def update(self, reward, action):
        self.total_q[action] += reward
        self.N_a[action] += 1
        self.q_value[action] = self.total_q[a] / self.N_a[action]

    def fully_explored(self):
        if len(self.children) == len(self.valid_actions):
            return True
        return False

In [99]:
def mcts(n_iter, root, turn, factor):
    for _ in range(n_iter):
        leaf, turn = selection_policy(root ,turn, factor)
        reward = rollout_policy(leaf, turn)
        backup(leaf, reward, turn)

    out = best_child(root, 0)
    return out

def selection_policy(node, turn , factor, stochastic=True):
    """
    We select the best branch until reaching a leaf node (non explored one or terminal one).
    """
    while game.check_winner(node.state) == 0:
        if (node.fully_explored() == False):
            return expand(node, turn), -turn
        else:
            node = select_child(node, factor, stochastic=stochastic)
            turn *= -1
    return node, turn

def expand(node, turn):
    not_tried_moves = [a for a in node.valid_actions if a not in node.children_move]
    move = np.random.choice(not_tried_moves)
    new_state, turn = game.step(move, node.state, turn)

    node.addChild(new_state, move)
    return node.children[-1]

def select_child(node, exploration_constant, stochastic):
    best_score = -np.inf
    for c in node.children:
        exploit = c.reward / c.visits
        explore = np.sqrt(np.log(2.0*node.visits)/float(c.visits))
        score = exploit + exploration_constant*explore
        if score > best_score:
            bestChild = c
            bestscore = score 
    return bestChild

def rollout_policy(node, turn):
    state = node.state
    while game.check_winner(state) == 0:
        random_action = np.random.choice(game.allowed_actions(state))
        state, turn = game.step(random_action, state, turn)
    return game.check_winner(state) 

def backup(node, reward, turn):
    while node != None:
        node.visits += 1 
        node.reward -= turn*reward
        node = node.parent
        turn *= -1
    return


In [13]:
game = Connect4()

In [14]:
game.init_state

array([[0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0.]])

In [104]:
c1, _ = game.step(3, game.init_state, 1)

In [105]:
c1

array([[0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 1., 0., 0., 0.]])

In [101]:
n = Node(game.init_state)

In [106]:
n.add_child(c1, 3)

In [110]:
n.children[0].parent

<__main__.Node at 0x7ff5a05239b0>

In [98]:
c = mcts(maxIter=10, root=n, turn=1, factor=1)

Children visited
[]
Children visited
[1]
Children visited
[1, 1]
Children visited
[1, 1, 1]
Children visited
[1, 1, 1, 1]
Children visited
[1, 1, 1, 1, 1]
Children visited
[1, 1, 1, 1, 1, 1]
Children visited
[1, 1, 1, 1, 1, 1, 1]
Children visited
[1, 1, 2, 1, 1, 1, 1]
Children visited
[1, 1, 2, 2, 1, 1, 1]


In [90]:
n.fully_explored()

False

In [91]:
len(n.children)

2

In [92]:
for ch in n.children:
    print(ch.reward)

1.0
1.0


In [93]:
bestChild(n, 1).state

array([[0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 1., 0., 0., 0., 0.]])