In [1]:
import numpy as np
from game import Connect4
from network import ConnectNet

In [2]:
import torch

In [3]:
class Node:
    """One game position in the search tree.

    Attributes:
        state: the game state stored at this node.
        turn: player to move in ``state`` (as produced by ``game.step``).
        id: ``str(state)`` -- the same string MCTS uses as the key in its
            ``tree`` dict.
        edges: outgoing ``Edge`` objects; empty until the node is expanded.
    """

    def __init__(self, state, turn):
        self.state = state
        self.turn = turn
        self.id = str(state)
        self.edges = []

    def is_leaf(self):
        """True while no outgoing edges have been created for this node."""
        return not self.edges
    
class Edge:
    """A move from ``parent`` to ``child`` together with its search statistics.

    Attributes:
        id: ``(parent.id, str(action))``.
        turn: the player who chooses this move, i.e. ``parent.turn``.
        N: visit count.
        W: accumulated value.
        Q: mean value ``W / N`` (kept up to date by MCTS.backup).
        P: prior probability given at expansion time.
    """

    def __init__(self, parent, child, prior, action):
        self.parent, self.child, self.action = parent, child, action
        self.id = (parent.id, str(action))
        self.turn = parent.turn

        # Search statistics start empty; only the prior is known up front.
        self.N = self.W = self.Q = 0
        self.P = prior

In [346]:
class MCTS:
    """PUCT tree search guided by a policy/value network.

    Nodes are linked through ``Edge`` objects and also stored in ``self.tree``,
    a dict keyed by ``str(state)``, so a position reached by different move
    orders is represented by a single ``Node``.

    Args:
        root: initial game state.
        turn: player to move at ``root`` (the notebook passes 1).
        game: object exposing ``check_winner``, ``allowed_actions`` and ``step``.
        net: callable mapping a state tensor to ``(probs, value)`` tensors.
        args: dict with ``alpha`` / ``epsilon`` (Dirichlet noise mixed into the
            root priors), ``cpuct`` (exploration constant), ``device`` (torch
            device for the network input) and ``n_sim`` (simulations per ``act``).

    Value convention: every value handled by ``backup`` is from the point of
    view of the player to move at the evaluated leaf (``leaf.turn``).
    NOTE(review): this assumes the network's value head uses the same
    convention -- TODO confirm against the training code.
    """

    # Number of actions; length of the policy returned by ``act``.
    N_ACTIONS = 7

    def __init__(self, root, turn, game, net, args):
        self.root = Node(root, turn)
        self.game = game
        self.net = net
        self.args = args
        self.tree = {}
        self.tree[str(root)] = self.root

    def select_leaf(self):
        """Descend from the root, following the edge with the best Q + U.

        Returns:
            (leaf, winner, branch): the leaf ``Node`` reached,
            ``game.check_winner(leaf.state)``, and the list of edges traversed.
        """
        branch = []
        node = self.root
        cpuct = self.args['cpuct']

        while not node.is_leaf():
            edges = node.edges
            # Parent visit count, clamped to 1 so the priors still drive the
            # first selection (sqrt(0) would zero every U term).
            parent_visits = max(sum(edge.N for edge in edges), 1)
            sqrt_visits = np.sqrt(parent_visits)

            # Dirichlet noise is mixed into the priors at the root only.
            if node is self.root:
                epsilon = self.args['epsilon']
                noise = np.random.dirichlet([self.args['alpha']] * len(edges))
            else:
                epsilon = 0
                noise = np.zeros(len(edges))

            best_score = -np.inf
            best_edge = None
            for edge, eta in zip(edges, noise):
                prior = (1 - epsilon) * edge.P + epsilon * eta
                score = edge.Q + cpuct * prior * sqrt_visits / (1 + edge.N)
                if score > best_score:
                    best_score = score
                    best_edge = edge

            branch.append(best_edge)
            node = best_edge.child

        return node, self.game.check_winner(node.state), branch

    def expand_evaluate(self, leaf, winner):
        """Score ``leaf`` and, if the game continues there, expand its children.

        Args:
            leaf: the ``Node`` returned by ``select_leaf``.
            winner: ``game.check_winner(leaf.state)``; 0 = game continues,
                1 / -1 = that player has won, 2 = presumably a draw.

        Returns:
            float value of the leaf for the player to move there (``leaf.turn``).
        """
        if winner in (1, -1):
            # Convert the absolute winner into the leaf player's perspective
            # (+1 if leaf.turn won, -1 otherwise) so backup credits the side
            # that actually made the winning move.
            return float(winner * leaf.turn)
        if winner == 2:
            return 1e-4

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            x = torch.tensor(leaf.state).float().to(self.args['device'])
            probs, value = self.net(x)
        probs = probs.detach().squeeze().cpu().numpy()
        value = float(value.detach().squeeze().cpu().numpy())

        allowed = list(self.game.allowed_actions(leaf.state))

        # Mask illegal moves and renormalise over the legal ones.
        mask = np.zeros_like(probs)
        for action in allowed:
            mask[action] = 1
        probs = probs * mask
        total = probs.sum()
        if total > 0:
            probs = probs / total
        elif allowed:
            # Network put no (representable) mass on any legal move: fall back
            # to a uniform prior instead of dividing by zero.
            probs = mask / mask.sum()

        for action in allowed:
            next_state, next_turn = self.game.step(action, leaf.state, leaf.turn)
            key = str(next_state)
            node = self.tree.get(key)
            if node is None:
                node = Node(next_state, next_turn)
                self.tree[key] = node
            # probs is indexed by action, not by position within ``allowed``.
            leaf.edges.append(Edge(leaf, node, probs[action], action))
        return value

    def backup(self, leaf, value, branch):
        """Add ``value`` (from ``leaf.turn``'s view) to every edge in ``branch``.

        Edges chosen by ``leaf.turn`` receive ``+value``, edges chosen by the
        opponent ``-value``; visit counts and mean values are updated.
        """
        for edge in branch:
            sign = 1 if edge.turn == leaf.turn else -1
            edge.N += 1
            edge.W += sign * value
            edge.Q = edge.W / edge.N

    def search(self):
        """Run one simulation: select a leaf, evaluate/expand it, back up."""
        leaf, winner, branch = self.select_leaf()
        value = self.expand_evaluate(leaf, winner)
        self.backup(leaf, value, branch)

    def act(self, tau):
        """Run ``args['n_sim']`` simulations and return the root search policy.

        Args:
            tau: temperature; ``pi`` is proportional to ``N ** (1 / tau)``.
                ``tau == 0`` returns a one-hot on the most visited action.

        Returns:
            (pi, values): ``pi`` is a numpy array of length ``N_ACTIONS``
            summing to 1; ``values`` is a list holding each root edge's ``Q``
            at its action index (0 where there is no edge).

        Raises:
            ValueError: if no root edge has been visited (e.g. ``n_sim`` < 2 or
                a terminal root), so no policy can be formed.
        """
        for _ in range(self.args['n_sim']):
            self.search()

        counts = np.zeros(self.N_ACTIONS)
        values = [0] * self.N_ACTIONS
        for edge in self.root.edges:
            counts[edge.action] = edge.N
            values[edge.action] = edge.Q

        max_count = counts.max()
        if max_count == 0:
            raise ValueError("no visited root edges: cannot build a policy")

        if tau == 0:
            pi = np.zeros(self.N_ACTIONS)
            pi[np.argmax(counts)] = 1.0
        else:
            # Scale by the max before exponentiating so small tau cannot overflow.
            pi = (counts / max_count) ** (1.0 / tau)
            pi /= pi.sum()
        return pi, values

    def rollout_policy(self, node, turn):
        """Play uniformly random legal moves from ``node`` until the game ends.

        Returns the final ``game.check_winner`` result. Not called by ``search``.
        """
        state = node.state
        winner = self.game.check_winner(state)
        while winner == 0:
            action = np.random.choice(self.game.allowed_actions(state))
            state, turn = self.game.step(action, state, turn)
            winner = self.game.check_winner(state)
        return winner

    def set_root(self, node_id):
        """Re-root the search at the existing node keyed by ``node_id`` (``str(state)``)."""
        self.root = self.tree[node_id]

In [347]:
# Game environment; MCTS calls its check_winner / allowed_actions / step,
# and the search below starts from jeu.init_state.
jeu = Connect4()

In [348]:
# Policy/value network: called as net(state_tensor) -> (probs, value).
net = ConnectNet()

In [349]:
# Load the saved initial weights onto the CPU.
# NOTE(review): torch.load unpickles the file -- only load checkpoints you trust.
# NOTE(review): if ConnectNet contains dropout/batch-norm layers, call
# net.eval() before using it for search -- TODO confirm.
net.load_state_dict(torch.load('networks_saved/net_init.pth',
                               map_location='cpu'))

In [354]:
# Search hyper-parameters read by MCTS (select_leaf, expand_evaluate, act).
args = dict(
    alpha=0.03,     # Dirichlet concentration for the root noise
    epsilon=0.2,    # weight of that noise in the root priors
    cpuct=1.,       # exploration constant in the U term
    device="cpu",   # torch device the network input is moved to
    n_sim=10,       # simulations per call to MCTS.act
)

In [355]:
# Fresh search tree rooted at the initial position with player 1 to move.
mcts = MCTS(jeu.init_state, 1, jeu, net, args)

In [356]:
# Run args['n_sim'] simulations and read the visit-count policy at temperature 1.
# Re-running this cell keeps searching in the same tree (counts accumulate).
pi, values = mcts.act(1.)

In [357]:
# Normalised visit-count policy over the 7 actions.
pi

array([0.11111111, 0.        , 0.        , 0.22222222, 0.66666667,
       0.        , 0.        ])

In [358]:
# Mean value Q of each root edge, in the order the edges were created
# (the order of game.allowed_actions), not indexed by action.
for e in mcts.root.edges:
    print(e.Q)

0.7561668157577515
0
0
-0.25872988253831863
0.4237784097592036
0
0


In [165]:
# Scratch check: position after player 1 plays action 0 from the start.
# NOTE(review): the scratch cells below ran out of order (165, 167, 166, ...);
# restart and run top to bottom to confirm they still reproduce.
a1,_ = jeu.step(0, jeu.init_state, 1)

In [167]:
# Position after player 1 plays action 3 from the start
# (named a4, presumably 1-based column numbering).
a4,_ = jeu.step(3, jeu.init_state, 1)

In [166]:
# Raw (probs, value) for a1, evaluated with autograd on (grad_fn in the output).
# The value is the negation of the first root Q printed above, consistent with
# the value being from the player-to-move's view -- TODO confirm.
net(torch.tensor(a1).float())

(tensor([[0.0626, 0.0357, 0.0063, 0.6230, 0.0010, 0.1753, 0.0960]],
        grad_fn=<ExpBackward>), tensor([[-0.7562]], grad_fn=<TanhBackward>))

In [168]:
# Raw (probs, value) network output for a4.
net(torch.tensor(a4).float())

(tensor([[0.1201, 0.0305, 0.7931, 0.0255, 0.0023, 0.0256, 0.0028]],
        grad_fn=<ExpBackward>), tensor([[0.1401]], grad_fn=<TanhBackward>))

In [320]:
# From a4, player -1 replies with action 2.
a43,_ = jeu.step(2, a4, -1)

In [322]:
# Raw (probs, value) network output for a43.
net(torch.tensor(a43).float())

(tensor([[0.2767, 0.0425, 0.0249, 0.2779, 0.0484, 0.0199, 0.3097]],
        grad_fn=<ExpBackward>), tensor([[-0.3774]], grad_fn=<TanhBackward>))