# Testing the different models produced in this project

In [3]:
from joblib import Parallel, delayed
import gym
from src.MCTS import Monte_Carlo_Tree_Search
from gym_go.gogame import random_action
from copy import deepcopy

In [2]:
from warnings import filterwarnings

# Ignore every warning emitted from this point on in the current interpreter
# (keeps the notebook output limited to the results below).
filterwarnings("ignore")

# Board size handed to every Monte_Carlo_Tree_Search constructed below.
BOARD_SIZE = 5

cuda:0


### Win rates: MCTS vs. a random player (MCTS as black, then as white)

In [3]:
def _play_vs_random(mcts: Monte_Carlo_Tree_Search, go_env: gym.Env,
                    mcts_is_black: bool, max_turns: int) -> gym.Env:
    """Play one game on ``go_env`` between ``mcts`` and a random mover.

    The environment is reset first. Each round is black's move followed by
    white's move; ``mcts_is_black`` decides which side the search plays. The
    random side picks moves with ``random_action(go_env.state())``.

    Play stops as soon as a step reports ``done``, or after a full round once
    more than ``max_turns`` rounds have been played. In the latter case the
    game may be unfinished (``done`` was never reported).

    Returns:
        The same ``go_env`` object, left in its final state.
    """
    go_env.reset()
    done = go_env.done
    rounds_played = 0
    while not done:
        # Black moves first in each round, then white.
        for mcts_to_move in (mcts_is_black, not mcts_is_black):
            if mcts_to_move:
                action = mcts.get_move_from_env(go_env).action
            else:
                action = random_action(go_env.state())
            _, _, done, _ = go_env.step(action)
            if done:
                return go_env
        rounds_played += 1
        # Safety cap so a game that never terminates cannot loop forever.
        if rounds_played > max_turns:
            break
    return go_env


def mcts_black_random_white(mcts: Monte_Carlo_Tree_Search, go_env: gym.Env,
                            max_turns: int = 300) -> gym.Env:
    """Reset ``go_env`` and play ``mcts`` as black against a random white.

    Args:
        mcts: Search object whose ``get_move_from_env`` supplies black's moves.
        go_env: Go environment to play on; it is reset and mutated in place.
        max_turns: Round cap. Play stops after a full round once more than
            this many rounds have been played (default 300, matching the
            previous hard-coded limit).

    Returns:
        ``go_env`` after the game ended or the round cap was reached.
    """
    return _play_vs_random(mcts, go_env, mcts_is_black=True,
                           max_turns=max_turns)


def random_black_mcts_white(mcts: Monte_Carlo_Tree_Search, go_env: gym.Env,
                            max_turns: int = 300) -> gym.Env:
    """Reset ``go_env`` and play a random black against ``mcts`` as white.

    Args:
        mcts: Search object whose ``get_move_from_env`` supplies white's moves.
        go_env: Go environment to play on; it is reset and mutated in place.
        max_turns: Round cap. Play stops after a full round once more than
            this many rounds have been played (default 300, matching the
            previous hard-coded limit).

    Returns:
        ``go_env`` after the game ended or the round cap was reached.
    """
    return _play_vs_random(mcts, go_env, mcts_is_black=False,
                           max_turns=max_turns)

In [4]:

def play_black_game():
    """Play one game: a fresh MCTS as black vs. a random white player.

    The game is played on a deep copy of the searcher's ``env``.

    Returns:
        1 if the finished environment's ``reward()`` is positive (counted as
        an MCTS win as black), otherwise 0.
    """
    searcher = Monte_Carlo_Tree_Search(BOARD_SIZE, None)
    final_env = mcts_black_random_white(searcher, deepcopy(searcher.env))
    return int(final_env.reward() > 0)

def play_white_game():
    """Play one game: a random black player vs. a fresh MCTS as white.

    The game is played on a deep copy of the searcher's ``env``.

    Returns:
        1 if the finished environment's ``reward()`` is negative (counted as
        an MCTS win as white), otherwise 0.
    """
    searcher = Monte_Carlo_Tree_Search(BOARD_SIZE, None)
    final_env = random_black_mcts_white(searcher, deepcopy(searcher.env))
    return int(final_env.reward() < 0)

# Number of independent games per colour, and joblib worker processes.
games = 100
n_jobs = 4

# play_black_game / play_white_game each return 1 for an MCTS win and 0
# otherwise, so the sum of the results is the win count.
mcts_black_wins = Parallel(n_jobs=n_jobs)(
    delayed(play_black_game)() for _ in range(games)
)
# Fixed one-decimal formatting avoids float artifacts such as
# "7.000000000000001 %" that (wins / games) * 100 can produce.
print("Win rate as black:", f"{100 * sum(mcts_black_wins) / games:.1f}", "%")

mcts_white_wins = Parallel(n_jobs=n_jobs)(
    delayed(play_white_game)() for _ in range(games)
)
print("Win rate as white:", f"{100 * sum(mcts_white_wins) / games:.1f}", "%")

Win rate as black: 0.0 %
Win rate as white: 90.0 %
