In [1]:
import random

In [2]:
import model
import selfplay
import game
import config

In [7]:
from visualise import vis

In [3]:
def get_probs_with_log(self, g: game.GameState):
    """Search from `g` and return the root's per-action visit fractions.

    A fresh `selfplay.Node` is built for `g` and `self.simulate` is called on
    it `config.num_simulate` times.

    Returns:
        A list of length `config.num_actions`. Entry ``i`` is the visit count
        of the root's child for action ``i`` divided by ``root.visit - 1``,
        or ``0.0`` when the root has no child for that action.
        (The ``- 1`` presumably discounts the simulation that expanded the
        root itself -- confirm against `simulate`.)
    """
    root = selfplay.Node(g)
    for _ in range(config.num_simulate):
        self.simulate(root)
    # self.print_tree(root)

    visit_fractions = []
    for action in range(config.num_actions):
        if action in root.children:
            child = root.children[action]
            visit_fractions.append(child.visit / (root.visit - 1))
        else:
            # No child was created for this action.
            visit_fractions.append(0.0)
    return visit_fractions

def my_print_tree(self, n: selfplay.Node, depth=0, all=True):
    """Print node `n` and its descendants, at most two levels deep.

    Each printed line is indented by `depth` spaces and starts with ``>>``
    for a node with a non-zero visit count or ``--`` otherwise, followed by
    the node's state, prior probability, value, visit count and value sum.

    Args:
        n: node to print; must expose `state`, `prob`, `value`, `visit`,
            `value_sum` and a `children` mapping.
        depth: current depth below the node the printout started from;
            nodes at depth 2 or more are not printed.
        all: when true, print unvisited nodes too; when false, print only
            nodes whose visit count is non-zero.
    """
    marker = ">>" if n.visit else "--"
    if depth < 2 and (all or n.visit):
        print(depth*" " + f"{marker} position: {n.state}, prob: {n.prob:0.2}, value: {n.value:0.2},"
                            f" visit: {n.visit}, value_sum: {n.value_sum:0.2}")
        for ch in n.children.values():
            # Forward `all`: without it the children fall back to the default
            # (True) and the caller's filter only applies to the root.
            self.print_tree(ch, depth + 1, all)

# Monkey-patch the MCTS class so its `print_tree` and `get_probs` attributes
# point at the notebook-local functions defined in this cell.
selfplay.MCTS.print_tree = my_print_tree
selfplay.MCTS.get_probs = get_probs_with_log

In [4]:
# Create a model with default arguments and an MCTS searcher built around it.
net = model.Model()
mcts = selfplay.MCTS(net)

In [10]:
# Play one game from the initial position, sampling each move from the
# search's visit probabilities.
g = game.GameState()
history, move_history = [], []

while not g.terminated():
    probs = mcts.get_probs(g)
    print(f"g={g.to_image()} {probs=} {sum(probs)=}")
    print(vis(str(g.to_image())))
    # Record [state, search probabilities, value placeholder]; the third slot
    # is filled in once the game outcome is known.
    history.append([g, probs, 0])
    # Sample an action index, weighted by the search probabilities.
    (action,) = random.choices(range(config.num_actions), weights=probs)
    print(f"{action=}")
    move_history.append(action)
    g = g.next_state(action)

# Show the final position.
print(f"g={g.to_image()} {g.terminated()=}")
print(vis(str(g.to_image())))


g=[[0 0 0 0 0 0 0 0 0]] probs=[0.05263157894736842, 0.10526315789473684, 0.2631578947368421, 0.05263157894736842, 0.05263157894736842, 0.15789473684210525, 0.05263157894736842, 0.05263157894736842, 0.21052631578947367] sum(probs)=0.9999999999999998
. . .
. . .
. . .

action=8
g=[[0 0 0 0 0 0 0 0 1]] probs=[0.0, 0.05263157894736842, 0.10526315789473684, 0.10526315789473684, 0.0, 0.7368421052631579, 0.0, 0.0, 0.0] sum(probs)=1.0
. . .
. . .
. . X

action=5
g=[[ 0  0  0  0  0 -1  0  0  1]] probs=[0.0, 0.0, 0.6842105263157895, 0.15789473684210525, 0.0, 0.0, 0.10526315789473684, 0.05263157894736842, 0.0] sum(probs)=1.0
. . .
. . O
. . X

action=2
g=[[ 0  0  1  0  0 -1  0  0  1]] probs=[0.0, 0.0, 0.0, 0.3157894736842105, 0.0, 0.0, 0.5789473684210527, 0.10526315789473684, 0.0] sum(probs)=1.0
. . X
. . O
. . X

action=6
g=[[ 0  0  1  0  0 -1 -1  0  1]] probs=[0.10526315789473684, 0.10526315789473684, 0.0, 0.5789473684210527, 0.10526315789473684, 0.0, 0.0, 0.10526315789473684, 0.0] sum(probs)=1

In [6]:
# log game outcome
winner = g.winner()
# print(move_history, winner)
# Fill in the value slot (index 2) of every recorded position.
# When winner == 0 (presumably a draw -- confirm against GameState.winner)
# the placeholder 0 written during play is left unchanged.
if winner != 0:
    for record in history:
        # Use a dedicated name for the recorded position. Rebinding `g` here
        # would replace the final game state, so re-running this cell would
        # call winner() on a mid-game position.
        position = record[0]
        # NOTE(review): positions whose player() equals the winner get -1 and
        # the others +1 -- confirm this sign convention against player().
        record[2] = -1 if position.player() == winner else 1