In [1]:
import torch, gym, time
import numpy as np
import matplotlib.pyplot as plt
from joblib import Parallel, delayed

from src.CNN import CNN

In [2]:
from warnings import filterwarnings

# Run models on the first CUDA GPU when one is visible, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Suppress every warning for the rest of the session.
# NOTE(review): this also hides deprecation warnings from torch/gym; consider
# narrowing to a specific category/module once the noisy source is known.
filterwarnings("ignore")

In [3]:
def _choose_action(go_env, player):
    """Return the next action for `player`, or a uniformly random legal action if it is None.

    A model player picks the highest-scoring *legal* action. Illegal actions are
    masked with -inf rather than multiplied by the 0/1 validity mask: with a
    multiplicative mask an illegal action scores 0 and would win the argmax
    whenever every legal action scores <= 0 (e.g. raw logits).
    """
    if player is None:
        return go_env.uniform_random_action()
    # Inference only: no autograd graph is needed for move selection.
    with torch.no_grad():
        scores = player(go_env.state()).cpu().numpy()
    valid = go_env.valid_moves()
    # Broadcasting mirrors the original `scores * valid`; argmax is over the flattened result.
    masked = np.where(valid > 0, scores, -np.inf)
    return int(masked.argmax())


def play_game(black_player=None, white_player=None, max_moves=300):
    """Play one 5x5 Go game and return the environment's final reward.

    Args:
        black_player: model used for black's moves, or None for a uniformly
            random player.
        white_player: model used for white's moves, or None for a uniformly
            random player.
        max_moves: maximum number of moves *per player*; the game is cut off
            after at most 2 * max_moves plies if it has not ended on its own.

    Returns:
        ``go_env.reward()`` at the end of play. Callers treat > 0 as a black
        win, < 0 as a white win and 0 as a draw. The exact scale depends on
        gym-go's 'heuristic' reward method.

    When both sides are models, two uniformly random opening moves (one per
    side) are played first. Model move choice is a deterministic argmax, so this
    is presumably what makes repeated games differ.
    """
    go_env = gym.make('gym_go:go-v0', size=5, komi=0, reward_method='heuristic')
    try:
        go_env.reset()

        if black_player is not None and white_player is not None:
            go_env.step(go_env.uniform_random_action())
            go_env.step(go_env.uniform_random_action())

        # Even plies are black's turn, odd plies are white's.
        players = (black_player, white_player)
        for ply in range(2 * max_moves):
            if go_env.done:
                break
            go_env.step(_choose_action(go_env, players[ply % 2]))
        return go_env.reward()
    finally:
        go_env.close()

In [4]:
def get_win_percent(black_player=None, white_player=None, n_games=1000, max_moves=300, n_jobs=6):
    """Play `n_games` games in parallel and return outcome percentages.

    Args:
        black_player: model for black, or None for a random player.
        white_player: model for white, or None for a random player.
        n_games: number of games to play; must be positive.
        max_moves: per-player move cap forwarded to ``play_game``.
        n_jobs: number of joblib workers (default 6, the previous hard-coded value).

    Returns:
        Tuple ``(black_win_pct, white_win_pct, draw_pct)``, each in [0, 100].
        A positive ``play_game`` reward counts as a black win, a negative reward
        as a white win, and anything else (zero or NaN) as a draw.

    Raises:
        ValueError: if ``n_games`` is not positive. The check runs before the
            worker pool starts, instead of failing with ZeroDivisionError afterwards.
    """
    if n_games <= 0:
        raise ValueError(f"n_games must be positive, got {n_games}")

    rewards = Parallel(n_jobs=n_jobs)(
        delayed(play_game)(black_player, white_player, max_moves) for _ in range(n_games)
    )

    black = sum(1 for reward in rewards if reward > 0)
    white = sum(1 for reward in rewards if reward < 0)
    draws = n_games - black - white
    return (black / n_games) * 100, (white / n_games) * 100, (draws / n_games) * 100

In [7]:
# Load the trained checkpoints 0..N_MODELS-1 for the round-robin below.
N_MODELS = 5
MODEL_DIR = "src/models/1000-games"

models = []
for i in range(N_MODELS):
    model = CNN()
    model.to(device)
    # map_location makes checkpoints saved on a GPU loadable when `device`
    # fell back to the CPU; without it torch.load tries to restore CUDA tensors.
    state_dict = torch.load(f"{MODEL_DIR}/{i}-times.pth", map_location=device)
    model.load_state_dict(state_dict)
    # Inference only: put dropout / batch-norm layers (if CNN has any) into eval behaviour.
    model.eval()
    models.append(model)

In [8]:
# Games per (black, white) pairing; every reported percentage is a multiple of 1/167.
N_GAMES_PER_PAIR = 167

n_models = len(models)
# results[b, w] holds the (black %, white %, draw %) tuple for model b (black) vs.
# model w (white). Object dtype keeps tuples intact; cells not yet played stay 0.
results = np.zeros((n_models, n_models), dtype=object)

for black_index, black_player in enumerate(models):
    print(f"Model {black_index}:")
    for white_index, white_player in enumerate(models):
        # perf_counter is monotonic, so the elapsed time is immune to wall-clock jumps.
        start = time.perf_counter()
        black, white, draw = get_win_percent(black_player, white_player, n_games=N_GAMES_PER_PAIR)
        elapsed = time.perf_counter() - start
        results[black_index, white_index] = black, white, draw
        print(f"  {white_index}: b={black:2.2f}%, w={white:2.2f}%, d={draw:2.2f}%, took {elapsed:4.2f} seconds")

Model 0:
  0: b=46.71%, w=51.50%, d=1.80%, took 7.09 seconds
  1: b=48.50%, w=49.70%, d=1.80%, took 6.27 seconds
  2: b=40.72%, w=59.28%, d=0.00%, took 7.28 seconds
  3: b=41.32%, w=55.69%, d=2.99%, took 7.43 seconds
  4: b=37.13%, w=59.88%, d=2.99%, took 9.83 seconds
Model 1:
  0: b=43.71%, w=56.29%, d=0.00%, took 6.35 seconds
  1: b=32.93%, w=63.47%, d=3.59%, took 7.65 seconds
  2: b=35.93%, w=61.08%, d=2.99%, took 8.17 seconds
  3: b=38.92%, w=59.88%, d=1.20%, took 6.09 seconds
  4: b=40.72%, w=58.08%, d=1.20%, took 8.26 seconds
Model 2:
  0: b=34.73%, w=65.27%, d=0.00%, took 7.93 seconds
  1: b=38.32%, w=56.89%, d=4.79%, took 7.94 seconds
  2: b=40.72%, w=58.68%, d=0.60%, took 6.55 seconds
  3: b=38.92%, w=61.08%, d=0.00%, took 7.14 seconds
  4: b=40.12%, w=57.49%, d=2.40%, took 7.82 seconds
Model 3:
  0: b=36.53%, w=59.88%, d=3.59%, took 7.74 seconds
  1: b=37.72%, w=59.28%, d=2.99%, took 7.00 seconds
  2: b=37.72%, w=60.48%, d=1.80%, took 8.79 seconds
  3: b=32.93%, w=64.67%, d=2

In [9]:
# Model 0 playing black vs. model 3 playing white: (black %, white %, draw %).
print(results[0, 3])

(41.31736526946108, 55.688622754491014, 2.9940119760479043)
