In [None]:

import os

# Restrict OpenMP and MKL to a single thread each. OpenMP has to stay
# single-threaded to prevent contention between multiple processes, and the
# MKL setting tells numpy to only use one core. Both are set here, ahead of
# the numpy/torch imports that follow.
os.environ.update(
    {
        'OMP_NUM_THREADS': '1',
        'MKL_NUM_THREADS': '1',
    }
)

import sys
from absl import flags

import numpy as np
import torch


# Global absl flag registry; every DEFINE_* call below registers one flag on it.
# NOTE(review): absl normally rejects a second definition of the same flag, so
# re-running this cell in a live kernel will likely fail -- restart the kernel instead.
FLAGS = flags.FLAGS

# --- Environment -----------------------------------------------------------
flags.DEFINE_integer('board_size', 5, 'Board size for Go.')
flags.DEFINE_float('komi', 7.5, 'Komi rule for Go.')
flags.DEFINE_integer(
    'num_stack',
    8,
    'Stack N previous states, the state is an image of N x 2 + 1 binary planes.',
)

# --- Network architecture --------------------------------------------------
flags.DEFINE_integer('num_res_blocks', 5, 'Number of residual blocks in the neural network.')
flags.DEFINE_integer('num_filters_resnet', 64, 'Number of filters for the conv2d layers in the neural network.')
flags.DEFINE_integer(
    'num_fc_units',
    128,
    'Number of hidden units in the linear layer of the neural network.',
)

# --- MCTS search -----------------------------------------------------------
flags.DEFINE_integer(
    'num_simulations',
    200,
    'Number of simulations per MCTS search, this applies to both self-play and evaluation processes.',
)

flags.DEFINE_integer(
    'num_parallel',
    6,
    # The two literals are concatenated by the parser, so the first one must
    # end with a space to keep the rendered help text readable.
    'Number of leaves to collect before using the neural network to evaluate the positions during MCTS search, '
    '1 means no parallel search.',
)
flags.DEFINE_float(
    'c_puct_base',
    19652,
    'Exploration constants balancing priors vs. search values. Original paper use 19652',
)
flags.DEFINE_float(
    'c_puct_init',
    1.25,
    'Exploration constants balancing priors vs. search values. Original paper use 1.25',
)

# --- Rating, logging and reproducibility -----------------------------------
flags.DEFINE_float(
    'default_rating',
    1500,
    'Default elo rating, change to the rating (for black) from last checkpoint when resume training.',
)
flags.DEFINE_string(
    'logs_dir',
    './logs/go/5x5/alphago_series',
    'Path to save statistics for self-play, training, and evaluation.',
)
flags.DEFINE_string('log_level', 'INFO', '')
flags.DEFINE_integer('seed', 1, 'Seed the runtime.')

# Initialize flags. known_only=True lets parsing tolerate arguments in sys.argv
# that do not correspond to any flag defined above (e.g. the ones the notebook
# kernel is launched with).
FLAGS(sys.argv, known_only=True)

# Export the board size through the process environment.
# NOTE(review): presumably read by the alpha_zero package when it is imported
# in the next cell -- confirm against that package.
os.environ['BOARD_SIZE'] = str(FLAGS.board_size)

In [None]:
from alpha_zero.envs.go import GoEnv
from alpha_zero.core.pipeline import (
    set_seed,
    maybe_create_dir,
)
from alpha_zero.core.multi_game import run_tournament
from alpha_zero.core.quantum_net import QuantumAlphaZeroNet
from alpha_zero.core.network import AlphaZeroNet
from alpha_zero.utils.util import extract_args_from_flags_dict, create_logger

In [None]:
# Search-network agents taking part in the tournament. All of them share the
# same architecture settings and differ only in their display name and in the
# checkpoint they are associated with.
agent_configs = [
    {
        'name': label,
        'num_filters': 16,
        'max_depth': 1,
        'branching_width': 3,
        'beam_width': 1,
        'num_fc_units': 128,
        'num_search': 5,
        'load_chkpt': ckpt_path,
    }
    for label, ckpt_path in (
        (
            'search_50k',
            './checkpoints/go/5x5/search/f_16_sgf_500/chkpts_lr_.0001/training_steps_50176.ckpt',
        ),
        (
            'Search_90k',
            './checkpoints/go/5x5/search/f_16_sgf_500/chkpts_lr_.0001/training_steps_90112.ckpt',
        ),
        (
            'search_100k',
            './checkpoints/go/5x5/search/f_16_sgf_500/chkpts_lr_.0001/training_steps_100352.ckpt',
        ),
        (
            'Search_138k',
            './checkpoints/go/5x5/search/f_16_sgf_500/chkpts_lr_.0001/training_steps_138240.ckpt',
        ),
    )
]


In [None]:
# ResNet agents taking part in the tournament: one entry per agent, giving its
# display name, network size settings and associated checkpoint path.
agent_configs_resnet = [
    {
        'name': 'resnet_88k',
        'num_res_blocks': 5,
        'num_filters_resnet': 64,
        'num_fc_units': 128,
        'load_chkpt': './checkpoints/go/5x5/resnets/training_steps_88000.ckpt',
    },
]

In [None]:
def env_builder():
    """Return a new GoEnv configured from the `komi` and `num_stack` flags."""
    return GoEnv(komi=FLAGS.komi, num_stack=FLAGS.num_stack)


# Environment instance used for evaluation; its observation and action spaces
# provide the sizes the networks are built with.
eval_env = env_builder()

input_shape, num_actions = eval_env.observation_space.shape, eval_env.action_space.n

# Initialize agents
# One untrained network object is built per config entry; no checkpoint is
# loaded in this cell. Constructor arguments are passed positionally, in the
# same order this notebook has always used.

# Search-based networks, one per entry of `agent_configs`.
Agents = [
    QuantumAlphaZeroNet(
        input_shape,
        num_actions,
        config['num_filters'],
        config['max_depth'],
        config['branching_width'],
        config['beam_width'],
        config['num_fc_units'],
        config['num_search'],
    )
    for config in agent_configs
]

# ResNet networks, one per entry of `agent_configs_resnet`.
Agents_resnet = [
    AlphaZeroNet(
        input_shape,
        num_actions,
        config['num_res_blocks'],
        config['num_filters_resnet'],
        # Honour the per-agent setting (as the search agents do) instead of
        # always taking the global flag; the flag remains the fallback for
        # configs that do not specify it.
        config.get('num_fc_units', FLAGS.num_fc_units),
    )
    for config in agent_configs_resnet
]



In [None]:
# Initialize agents with metadata

# Per-agent tournament record for the QuantumAlphaZeroNet agents, keyed by the
# configured agent name. If two configs share a name, the later one wins.
agents_search = {
    cfg["name"]: {
        "network": net,
        "elo_rating": 1500,  # Initial Elo rating
        "checkpoint": cfg["load_chkpt"],
        "wins": 0,
        "lost": 0,
    }
    for cfg, net in zip(agent_configs, Agents)
}

# Same record layout for the ResNet agents.
agents_resnet = {
    cfg["name"]: {
        "network": net,
        "elo_rating": 1500,  # Initial Elo rating
        "checkpoint": cfg["load_chkpt"],
        "wins": 0,
        "lost": 0,
    }
    for cfg, net in zip(agent_configs_resnet, Agents_resnet)
}


In [None]:
# Combined roster: ResNet entries first, then search entries. Should a name
# appear in both dicts, the search entry replaces the ResNet one.
agents = {**agents_resnet, **agents_search}

In [None]:
# Seed the runtime from the `seed` flag before any tournament work starts.
set_seed(FLAGS.seed)

# Logger configured with the level given by the `log_level` flag.
logger = create_logger(FLAGS.log_level)

# Record the flag values in effect for this run; the values dict is passed
# through extract_args_from_flags_dict before being logged.
logger.info(extract_args_from_flags_dict(FLAGS.flag_values_dict()))


In [None]:
# Use the GPU when torch can see one; otherwise fall back to the CPU so that
# `learner_device` is always defined on CPU-only machines as well.
learner_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Launch the tournament between all entries of `agents` on the evaluation
# environment. The total game budget scales with the roster size (1000 times
# the number of agents); the search settings come straight from the flags.
run_tournament(
    seed=FLAGS.seed,
    agents=agents,
    env=eval_env,
    device=learner_device,
    num_games=1000 * len(agents),
    num_simulations=FLAGS.num_simulations,
    num_parallel=FLAGS.num_parallel,
    c_puct_base=FLAGS.c_puct_base,
    c_puct_init=FLAGS.c_puct_init,
    default_rating=FLAGS.default_rating,
    log_level=FLAGS.log_level,
    logs_dir=FLAGS.logs_dir,
)