In [1]:
import torch
import logging
import matplotlib.pyplot as plt

# Let cuDNN benchmark the available convolution algorithms and reuse the
# fastest one; this pays off when input shapes stay constant between calls.
torch.backends.cudnn.benchmark = True

from core.evaluation.mcts_hypers import MCTSHypers
from core.resnet import TurboZeroResnet, TurboZeroArchParams
from core.training.training_hypers import TurboZeroHypers
from envs.othello.evaluator import OthelloMCTS
from envs.othello.trainer import OthelloTrainer, load_checkpoint
from core.utils.custom_activations import Tanh0to1

# Route INFO-and-above records to a log file opened in append mode, with each
# record prefixed by its timestamp, then mark the start of this run.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s %(message)s',
    filename='training_othello.log',
    filemode='a',
)
logging.info('Starting training')

In [3]:
# --- Run configuration -------------------------------------------------------

# Number of environments handed to the training evaluator and the trainer.
NUM_PARALLEL_ENVS = 8192

# Path of a checkpoint to resume from; leave empty to build a fresh trainer.
CHECKPOINT_FILE = ''

# Use the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Debug flag forwarded to the evaluators / checkpoint loader; enabling it
# disables JIT compilation.
DEBUG = False

In [4]:
# Build the trainer: resume from CHECKPOINT_FILE when one is given, otherwise
# assemble a new model, optimizer and pair of MCTS evaluators from scratch.
if not CHECKPOINT_FILE:
    # Network architecture.
    # NOTE(review): input_size (2, 8, 8) and policy_size 65 presumably mean two
    # planes over an 8x8 board and 64 squares plus a pass move -- confirm.
    arch_params = TurboZeroArchParams(
        input_size=torch.Size((2, 8, 8)),
        policy_size=65,
        res_channels=4,
        res_blocks=1,
        value_head_res_channels=4,
        value_head_res_blocks=1,
        policy_head_res_channels=4,
        policy_head_res_blocks=1,
        kernel_size=3,
        value_output_activation=Tanh0to1(),
    )

    # Training-loop hyperparameters.
    hypers = TurboZeroHypers(
        learning_rate=3e-4,
        replay_memory_size=30_000,
        replay_memory_min_size=10_000,
        minibatch_size=2048,
        minibatches_per_update=1,
        train_episodes_per_epoch=10_000,
        test_episodes_per_epoch=1_000,
        temperature_train=1.0,
        temperature_test=0.1,
    )

    # MCTS settings: the two configurations differ only in their Dirichlet
    # parameters (alpha/epsilon), which are larger for training.
    eval_hypers_train = MCTSHypers(
        num_iters=250,
        puct_coeff=1.41,
        dirichlet_alpha=0.5,
        dirichlet_epsilon=0.25,
    )
    eval_hypers_test = MCTSHypers(
        num_iters=250,
        puct_coeff=1.41,
        dirichlet_alpha=0.1,
        dirichlet_epsilon=0.1,
    )

    run_tag = 'test_othello_1'

    model = TurboZeroResnet(arch_params).to(DEVICE)
    optimizer = torch.optim.AdamW(model.parameters(), lr=hypers.learning_rate)

    # The training evaluator spans NUM_PARALLEL_ENVS environments; the test
    # evaluator is sized by the number of test episodes per epoch.
    # NOTE(review): the positional 8 is presumably the board size -- confirm
    # against the OthelloMCTS signature.
    train_evaluator = OthelloMCTS(NUM_PARALLEL_ENVS, DEVICE, 8, eval_hypers_train, debug=DEBUG)
    test_evaluator = OthelloMCTS(hypers.test_episodes_per_epoch, DEVICE, 8, eval_hypers_test, debug=DEBUG)

    trainer = OthelloTrainer(
        train_evaluator,
        test_evaluator,
        NUM_PARALLEL_ENVS,
        DEVICE,
        torch.device('cpu'),
        model,
        optimizer,
        hypers,
        run_tag=run_tag,
    )
else:
    trainer = load_checkpoint(NUM_PARALLEL_ENVS, CHECKPOINT_FILE, DEVICE, debug=DEBUG)

# Close any matplotlib figures that are still open before training starts.
plt.close('all')

In [None]:
# Start training with the trainer built (or restored) in the previous cell.
# NOTE(review): presumably long-running; the stopping condition is defined by
# the trainer's training_loop, not here -- confirm before a full run.
trainer.training_loop()