In [1]:
"""Trains the AlphaZero agent on a single machine for the game of Go."""
import os

# Restrict OpenMP and MKL to one thread each so that multiple processes do not
# contend for cores. These variables are set here, before numpy/torch are
# imported below, so the libraries see them when they initialise.
os.environ.update({
    'OMP_NUM_THREADS': '1',
    'MKL_NUM_THREADS': '1',
})

import sys
from absl import flags

import numpy as np
import torch
from torch.optim.lr_scheduler import MultiStepLR

FLAGS = flags.FLAGS

# NOTE: absl forbids defining the same flag twice from the same module, so
# re-running this cell in a live kernel will typically raise DuplicateFlagError.
# Restart the kernel instead of re-executing this cell.

# --- Game / environment ------------------------------------------------------
flags.DEFINE_integer('board_size', 9, 'Board size for Go.')
flags.DEFINE_float('komi', 7.5, 'Komi rule for Go.')

flags.DEFINE_integer(
    'num_stack',
    8,
    'Stack N previous states, the state is an image of N x 2 + 1 binary planes.',
)

# --- Network / quantum-search architecture -----------------------------------
flags.DEFINE_integer('num_filters', 236, 'Number of filters for the conv2d layers in the neural network.')
flags.DEFINE_integer('max_depth', 10, 'Maximum depth for quantum search.')
flags.DEFINE_integer('branching_width', 3, 'Branching width for quantum search.')
flags.DEFINE_integer('beam_width', 1, 'Beam width for quantum search.')
flags.DEFINE_integer(
    'num_fc_units',
    128,
    'Number of hidden units in the linear layer of the neural network.',
)
flags.DEFINE_integer('num_search', 1, 'Number of search modules for quantum search.')

# --- Training data -----------------------------------------------------------
flags.DEFINE_integer(
    'batch_size',
    1024,
    # The two literals are implicitly concatenated, so the trailing space matters.
    'Number of samples per training step. To avoid overfitting, we want to make sure the agent '
    'only sees ~10% of samples in the replay over one checkpoint. '
    'That is, batch_size * ckpt_interval <= replay_capacity * 0.1',
)

# NOTE(review): 'argument_data' reads like a misspelling of "augment data". The
# name is kept because it is a command-line flag and is forwarded as the
# `argument_data` keyword of supervised_learner_loop.
flags.DEFINE_bool(
    'argument_data',
    True,
    'Apply random rotation and mirroring to the training data, default on.',
)

# --- Optimisation (SGD + MultiStepLR) ----------------------------------------
flags.DEFINE_float('init_lr', 0.01, 'Initial learning rate.')
flags.DEFINE_float('lr_decay', 0.1, 'Multiplicative learning rate decay factor applied at each milestone.')
flags.DEFINE_multi_integer(
    'lr_milestones',
    [10000, 20000, 40000],
    'The number of training steps at which the learning rate will be decayed.',
)
flags.DEFINE_float('l2_regularization', 1e-4, 'The L2 regularization parameter applied to weights.')
flags.DEFINE_float('sgd_momentum', 0.9, 'Momentum factor for the SGD optimizer.')

flags.DEFINE_integer(
    'max_training_steps',
    int(5e4),
    'Number of training steps (measured in network parameter update, one batch is one training step).',
)

# --- Checkpointing / logging -------------------------------------------------
flags.DEFINE_integer('ckpt_interval', 500, 'The frequency (in training step) to create new checkpoint.')
flags.DEFINE_integer('log_interval', 20, 'The frequency (in training step) to log training statistics.')

# ckpt_dir / logs_dir are overwritten later in this notebook from the architecture flags.
flags.DEFINE_string('ckpt_dir', '', 'Checkpoint directory (to be generated dynamically)')
flags.DEFINE_string('logs_dir', '', 'Logs directory (to be generated dynamically)')
flags.DEFINE_string(
    'dataset_dir',
    'go_dataset.pth',
    'Path to the Go training dataset.',
)

flags.DEFINE_string('log_level', 'INFO', 'Logging level name used when creating the logger.')
flags.DEFINE_integer('seed', 1, 'Seed the runtime.')


# Parse flags. Under Jupyter, sys.argv holds the kernel's own arguments
# (e.g. its connection file), so unknown arguments must be tolerated.
FLAGS(sys.argv, known_only=True)

# NOTE(review): presumably read by the alpha_zero Go modules; it is set here,
# before they are imported later in the notebook — confirm where it is consumed.
os.environ['BOARD_SIZE'] = str(FLAGS.board_size)

In [3]:
def generate_folder_name(depth, search, branching, filters, beam):
    """Build a run-folder name that encodes the quantum-search settings.

    Each value is written after a short tag and an underscore, with no
    separator between fields, e.g. ``generate_folder_name(10, 1, 3, 236, 1)``
    returns ``'d_10s_1br_3f_236be_1'``.

    Args:
        depth: maximum search depth (tag ``d``).
        search: number of search modules (tag ``s``).
        branching: branching width (tag ``br``).
        filters: number of conv filters (tag ``f``).
        beam: beam width (tag ``be``).

    Returns:
        The concatenated folder name string.
    """
    tagged_values = (
        ('d', depth),
        ('s', search),
        ('br', branching),
        ('f', filters),
        ('be', beam),
    )
    return ''.join(f'{tag}_{value}' for tag, value in tagged_values)

In [4]:
# Derive a run-specific sub-directory from the architecture flags so runs with
# different quantum-search settings keep separate checkpoints and logs.
folder_name = generate_folder_name(
    FLAGS.max_depth, FLAGS.num_search, FLAGS.branching_width, FLAGS.num_filters, FLAGS.beam_width
)

# Use the configured board size instead of a hard-coded "9x9", so a run with a
# different --board_size is not filed under the 9x9 directory.
board_tag = f'{FLAGS.board_size}x{FLAGS.board_size}'

# Update ckpt_dir and logs_dir with the generated folder name
FLAGS.ckpt_dir = f'./checkpoints/go/{board_tag}/quantum/{folder_name}'
FLAGS.logs_dir = f'./logs/go/{board_tag}/quantum/{folder_name}'

In [None]:
from alpha_zero.envs.go import GoEnv
from alpha_zero.core.pipeline import (
    supervised_learner_loop,
    set_seed,
    maybe_create_dir,
)
from alpha_zero.core.quantum_net import QuantumAlphaZeroNet
from alpha_zero.utils.util import extract_args_from_flags_dict, create_logger

In [3]:
def env_builder():
    """Create a Go environment configured from the komi / num_stack flags."""
    return GoEnv(komi=FLAGS.komi, num_stack=FLAGS.num_stack)


# Seed before the env and network are built so that weight initialisation is
# reproducible. The training cell seeds again before training starts.
# NOTE(review): assumes set_seed seeds the torch (and numpy) RNGs — confirm in
# alpha_zero.core.pipeline.
set_seed(FLAGS.seed)

# The env is only used here to read the observation / action space sizes.
eval_env = env_builder()

input_shape = eval_env.observation_space.shape
num_actions = eval_env.action_space.n


def network_builder():
    """Construct the quantum-search AlphaZero network sized for the Go env."""
    return QuantumAlphaZeroNet(
        input_shape,
        num_actions,
        FLAGS.num_filters,
        FLAGS.max_depth,
        FLAGS.branching_width,
        FLAGS.beam_width,
        FLAGS.num_fc_units,
        FLAGS.num_search,
    )


network = network_builder()
# NOTE(review): torch.compile returns a wrapper module whose state_dict keys are
# prefixed with '_orig_mod.'; confirm supervised_learner_loop saves/loads
# checkpoints in a way that tolerates this.
network = torch.compile(network)

optimizer = torch.optim.SGD(
    network.parameters(),
    lr=FLAGS.init_lr,
    momentum=FLAGS.sgd_momentum,
    weight_decay=FLAGS.l2_regularization,
)
# Multiply the learning rate by lr_decay at each step count in lr_milestones.
lr_scheduler = MultiStepLR(optimizer, milestones=FLAGS.lr_milestones, gamma=FLAGS.lr_decay)

In [None]:
# Total number of scalar elements across all network parameters (trainable or not).
total_params = sum(map(torch.numel, network.parameters()))
print(f" Total number of parameters: {total_params}")

In [None]:
# Display the observation shape reported by the Go env (shown as the cell output).
input_shape

In [None]:
# Allow reduced-precision internal computation (e.g. TF32) for float32 matmuls
# on hardware that supports it, trading a little precision for speed.
torch.set_float32_matmul_precision('high')

set_seed(FLAGS.seed)

maybe_create_dir(FLAGS.ckpt_dir)
maybe_create_dir(FLAGS.logs_dir)

logger = create_logger(FLAGS.log_level)

# Record the full flag configuration for this run.
logger.info(extract_args_from_flags_dict(FLAGS.flag_values_dict()))

# Fall back to CPU; previously learner_device was left undefined (NameError)
# on machines without CUDA.
learner_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

supervised_learner_loop(
    seed=FLAGS.seed,
    network=network,
    data_dir=FLAGS.dataset_dir,
    device=learner_device,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
    logger=logger,
    argument_data=FLAGS.argument_data,
    batch_size=FLAGS.batch_size,
    ckpt_interval=FLAGS.ckpt_interval,
    log_interval=FLAGS.log_interval,
    max_training_steps=FLAGS.max_training_steps,
    # NOTE(review): presumably an early-stopping patience (units defined by
    # supervised_learner_loop) — consider promoting to a flag.
    patience=1000,
    ckpt_dir=FLAGS.ckpt_dir,
    logs_dir=FLAGS.logs_dir,
)
