In [1]:
import json
import torch
import random

from tqdm import tqdm

from game import Board, Game
from mcts_pure import MCTSPlayer as MCTS_Pure
from mcts_alphaZero import MCTSPlayer
from nn_architecture import PolicyValueNet
from human_play import Random

In [2]:
data = './data/battle_example.json'
n_rounds = 3

f = open(data, encoding='utf-8')
data = json.loads(f.read())

width = data["board"]["board_width"]
height = data["board"]["board_height"]
n_in_row = data["board"]["n_in_row"]
board = Board(width=width, height=height, n_in_row=n_in_row)
game = Game(board)

# ############### human VS AI ###################
player1_data = data["player1"]
player1_policy = PolicyValueNet(width, height, player1_data["nn_information"], model_file=player1_data["model_path"])
player1_policy_num_params = player1_policy.num_params()
player1 = MCTSPlayer(player1_policy.policy_value_fn, c_puct=5, n_playout=400)  # set larger n_playout for better performance
print(f"Number of Parameters(CNN): {player1_policy_num_params}")

player2_data = data["player2"]
player2_policy = PolicyValueNet(width, height, player2_data["nn_information"], model_file=player2_data["model_path"])
player2_policy_num_params = player2_policy.num_params()
player2 = MCTSPlayer(player2_policy.policy_value_fn, c_puct=5, n_playout=400)
print(f"Number of Parameters(GNN): {player2_policy_num_params}")

player1_wins = 0
player2_wins = 0
ties = 0
# set start_player=0 for human first
# for i in tqdm(range(n_rounds)):
#     winner = game.start_play(player1, player2, start_player=random.randint(0, 1), is_shown=0)
#     if winner == 1:
#         player1_wins += 1
#     elif winner == -1:
#         ties += 1
#     else: 
#         player2_wins += 1

# print(f"Number of Rounds: {n_rounds}\t CNN wins: {player1_wins}\t GNN wins: {player2_wins}\t ties: {ties}")

Number of Parameters(CNN): 105317
Number of Parameters(GNN): 30277


In [5]:
# Compete with Random Opponent
random_player = Random()

for epoch in range(50, 301, 50):
    cnn_player_data = data["player1"]
    cnn_player_policy = PolicyValueNet(width, height, cnn_player_data["nn_information"], model_file=f'./model/cnn/epoch_{epoch}.model')
    cnn_player = MCTSPlayer(cnn_player_policy.policy_value_fn, c_puct=5, n_playout=400) 

    win_ratio_list = []
    random_wins = 0
    ties = 0
    cnn_player_wins = 0
    for i in tqdm(range(10)):
        winner = game.start_play(random_player, cnn_player, start_player=random.randint(0, 1), is_shown=0)
        if winner == 1:
            random_wins += 1
        elif winner == -1:
            ties += 1
        else: 
            cnn_player_wins += 1
    win_ratio = 1.0*(cnn_player_wins + 0.5*ties) / 10
    print(f"Epoch {epoch}, Win Ratio: {win_ratio}")
    win_ratio_list.append(win_ratio)
win_ratio_list


100%|██████████| 10/10 [04:04<00:00, 24.46s/it]


Epoch 50, Win Ratio: 1.0


100%|██████████| 10/10 [02:46<00:00, 16.68s/it]


Epoch 100, Win Ratio: 0.9


100%|██████████| 10/10 [02:39<00:00, 15.92s/it]


Epoch 150, Win Ratio: 1.0


100%|██████████| 10/10 [02:49<00:00, 16.97s/it]


Epoch 200, Win Ratio: 1.0


100%|██████████| 10/10 [02:33<00:00, 15.40s/it]


Epoch 250, Win Ratio: 1.0


100%|██████████| 10/10 [02:40<00:00, 16.04s/it]

Epoch 300, Win Ratio: 1.0





[1.0]