In [1]:
import torch, gym, time
import numpy as np
import matplotlib.pyplot as plt
from joblib import Parallel, delayed

from src.CNN import CNN

In [2]:
# Run on the first CUDA GPU when one is visible to torch, otherwise on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Install a blanket "ignore" filter so library warnings do not clutter cell output.
# NOTE: this hides every warning category for the rest of the session.
from warnings import filterwarnings
filterwarnings("ignore")

In [3]:
def _choose_action(player, go_env):
    """Return the action index for the side to move in ``go_env``.

    A ``None`` player asks the environment for a uniformly random action.
    Any other player is scored with ``player.forward(go_env.state())`` and the
    highest-scoring legal move is returned.
    """
    if player is None:
        return go_env.uniform_random_action()

    # Inference only: no autograd graph is needed just to pick a move.
    with torch.no_grad():
        scores = player.forward(go_env.state()).detach().cpu().numpy()

    # valid_moves() is used as a mask (assumed non-zero == legal).
    # Illegal moves are set to -inf instead of being multiplied by the mask:
    # multiplying turns them into 0, which wins the argmax whenever every
    # legal move has a negative score and makes the player pick an illegal move.
    legal = np.asarray(go_env.valid_moves()) != 0
    return np.where(legal, scores, -np.inf).argmax()


def play_game(black_player=None, white_player=None, max_moves=300):
    """Play a single 5x5 Go game and return the environment's final reward.

    Args:
        black_player: Object exposing ``forward(state)`` that returns a tensor
            of per-action scores, or ``None`` to play uniformly random moves.
        white_player: Same as ``black_player``, for the second side to move.
        max_moves: Maximum number of rounds; each round is one black move
            followed by one white move, so up to ``2 * max_moves`` actions.

    Returns:
        The value of ``go_env.reward()`` after the game ends or the round
        limit is hit. Callers in this notebook treat a positive value as a
        black win, a negative value as a white win and anything else as a draw.
    """
    go_env = gym.make('gym_go:go-v0', size=5, komi=0, reward_method='heuristic')
    try:
        go_env.reset()

        # With two model players, start from two random moves
        # (presumably so that argmax players do not replay one identical game).
        if black_player is not None and white_player is not None:
            go_env.step(go_env.uniform_random_action())
            go_env.step(go_env.uniform_random_action())

        for _ in range(max_moves):
            # Black moves first, then white; stop as soon as the game is over.
            if go_env.done:
                break
            go_env.step(_choose_action(black_player, go_env))

            if go_env.done:
                break
            go_env.step(_choose_action(white_player, go_env))

        return go_env.reward()
    finally:
        # Release the environment even if a player or the env raises.
        go_env.close()

In [8]:
def get_win_percent(black_player=None, white_player=None, n_games=800, max_moves=300, n_jobs=6):
    """Play ``n_games`` games in parallel and summarise the outcomes as percentages.

    Args:
        black_player: Forwarded to ``play_game``; ``None`` means random moves.
        white_player: Forwarded to ``play_game``; ``None`` means random moves.
        n_games: Number of independent games to play.
        max_moves: Per-game round limit forwarded to ``play_game``.
        n_jobs: Number of joblib workers used to run the games.

    Returns:
        A ``(black %, white %, draw %)`` tuple. A game counts for black when
        its result is positive, for white when it is negative, and as a draw
        otherwise. With ``n_games <= 0`` no games are played and all three
        percentages are ``0.0``.
    """
    # Nothing to play: avoid dividing by zero below.
    if n_games <= 0:
        return 0.0, 0.0, 0.0

    results = Parallel(n_jobs=n_jobs)(
        delayed(play_game)(black_player, white_player, max_moves) for _ in range(n_games)
    )

    black = sum(1 for res in results if res > 0)
    white = sum(1 for res in results if res < 0)
    # Everything that is neither a black nor a white win is a draw.
    draws = len(results) - black - white
    return (black/n_games)*100, (white/n_games)*100, (draws/n_games)*100

In [5]:
# Load the five saved CNN checkpoints from src/models/1000-games/ for evaluation.
models = []
for i in range(5):
    model = CNN()
    model.to(device)
    # map_location keeps loading working on the CPU fallback even when the
    # checkpoint tensors were saved from a CUDA device.
    state_dict = torch.load(f"src/models/1000-games/{i}-times.pth", map_location=device)
    model.load_state_dict(state_dict)
    # The models are only used for evaluation below, so switch off
    # training-time behaviour (dropout, batch-norm updates) if CNN has any.
    model.eval()
    models.append(model)

In [9]:
# Evaluate every loaded model against the random player, once per colour.
# brandom: the model plays black; wrandom: the model plays white.
# Each entry is the (black %, white %, draw %) tuple from get_win_percent.
brandom = Parallel(n_jobs=3)(
    delayed(get_win_percent)(black_player=net) for net in models
)
wrandom = Parallel(n_jobs=3)(
    delayed(get_win_percent)(white_player=net) for net in models
)


In [10]:
# Show both result lists, one per line: model-as-black first, then model-as-white.
print(brandom, wrandom, sep="\n")

[(83.75, 16.25, 0.0), (87.0, 13.0, 0.0), (87.0, 13.0, 0.0), (85.5, 14.499999999999998, 0.0), (87.0, 13.0, 0.0)]
[(14.75, 85.25, 0.0), (13.750000000000002, 86.25, 0.0), (17.25, 82.75, 0.0), (18.0, 82.0, 0.0), (14.000000000000002, 86.0, 0.0)]
