In [67]:
from copy import deepcopy
from math import sqrt, log
import gym, random
import numpy as np

In [72]:
BOARD_SIZE, KOMI, UCB_CONSTANT = 5, 0, 2

In [78]:
test_env = gym.make('gym_go:go-v0', size=BOARD_SIZE, komi=KOMI, reward_method='heuristic')
test_env.reset()
test_env.step(2)
# test_env.render("terminal")
# print(np.argwhere(test_env.valid_moves()).flatten())
print(test_env.reward())

25.0


1. **Selection**: Start from root R and select successive child nodes until a leaf node L is reached. The root is the current game state and a leaf is any node that has a potential child from which no simulation (playout) has yet been initiated.
2. **Expansion**: Unless L ends the game decisively (e.g. win/loss/draw) for either player, create one (or more) child nodes and choose node C from one of them. Child nodes are any valid moves from the game position defined by L.
3. **Simulation**: Complete one random playout from node C. This step is sometimes also called playout or rollout. A playout may be as simple as choosing uniform random moves until the game is decided (for example in chess, the game is won, lost, or drawn).
4. **Backpropagation**: Use the result of the playout to update information in the nodes on the path from C to R.

In [79]:
class Node():

    def __init__(self, env, parent = None):
        self.env = env
        self.parent = parent
        self.children = {}
        self.trials = 0
        self.value = 0

    def is_leaf_node(self):
        if self.env.done: return True
        return len(self.children) == 0

    def get_max_ucb_child(self, total_trials):
        best_child = self.children[0]
        best_child_ucb = best_child.ucb(total_trials)
        for child in self.children:
            child_ucb = child.ucb(total_trials)
            if child_ucb > best_child_ucb:
                best_child = child
                best_child_ucb = child.ucb
        return best_child

    def ucb(self, total_trials):
        return self.value + (UCB_CONSTANT * sqrt(log(total_trials) / self.trials))

    def __backpropagate(self, value):
        if self.parent == None: return
        self.value += value
        self.trials += 1
        self.__backpropagate(self.parent, value)

    def rollout(self):
        if self.env.done: return self.env.reward()
        rollout_env = deepcopy(self.env)
        while not rollout_env.done:
            rollout_env.step(rollout_env.uniform_random_action())
        self.__backpropagate(self, rollout_env.reward())
        return rollout_env.reward()
        

In [86]:
class MCTS():

    def __init__(self, size = BOARD_SIZE, komi = KOMI):
        env = gym.make('gym_go:go-v0', size=size, komi=komi, reward_method='heuristic')
        env.reset()
        self.root = Node(env)
        self.iterations = 0

    def run(self):
        current_node = self.root
        while True:
            
            if current_node.is_leaf_node() and current_node != self.root:
                current_node = current_node.get_max_ucb_child()
            elif current_node.trials == 0:
                current_node.rollout()
            else:
                for action in np.argwhere(current_node.env.valid_moves()).flatten():
                    child_env = deepcopy(current_node.env)
                    child_env.step(action)
                    current_node.children.add(Node(child_env, current_node))
                current_node = random.choice(current_node.children)
                current_node.rollout()

    


In [87]:
mcts = MCTS()
mcts.run()

  logger.warn(
  logger.warn(
  logger.warn(
  logger.warn(
  logger.warn(
  logger.warn(
  logger.warn(
  logger.warn(f"{pre} is not within the observation space.")


Iteration  0
Iteration  1
Iteration  2
Iteration  3
Iteration  4
Iteration  5
Iteration  6
Iteration  7
Iteration  8
Iteration  9
Iteration 10
Iteration 11
Iteration 12
Iteration 13
Iteration 14
Iteration 15
Iteration 16
Iteration 17
Iteration 18
Iteration 19
Iteration 20
Iteration 21
Iteration 22
Iteration 23
Iteration 24
Iteration 25
Iteration 26
Iteration 27
Iteration 28
Iteration 29
Iteration 30
Iteration 31
Iteration 32
Iteration 33
Iteration 34
Iteration 35
Iteration 36
Iteration 37
Iteration 38
Iteration 39
Iteration 40
Iteration 41
Iteration 42
Iteration 43
Iteration 44
Iteration 45
Iteration 46
Iteration 47
Iteration 48
Iteration 49
Iteration 50
Iteration 51
Iteration 52
Iteration 53
Iteration 54
Iteration 55
Iteration 56
Iteration 57
Iteration 58
Iteration 59
Iteration 60
Iteration 61
Iteration 62
Iteration 63
Iteration 64
Iteration 65
Iteration 66
Iteration 67
Iteration 68
Iteration 69
Iteration 70
Iteration 71
Iteration 72
Iteration 73
Iteration 74
Iteration 75
Iteration 76

  logger.warn(
