In [1]:
import torch
import logging
import matplotlib.pyplot as plt

# Turn on cuDNN benchmark mode; per the PyTorch docs this lets cuDNN time
# several convolution algorithms and pick the fastest one.
torch.backends.cudnn.benchmark = True

from core.evaluation.mcts_hypers import MCTSHypers
from core.resnet import TurboZeroResnet, TurboZeroArchParams
from core.training.training_hypers import TurboZeroHypers
from envs.othello.evaluator import OthelloMCTS
from envs.othello.trainer import OthelloTrainer, load_checkpoint

# Route root-logger output to a log file opened in append mode, keeping
# INFO-and-above records, each prefixed with a timestamp.
logging.basicConfig(
    filename='training_othello.log',
    filemode='a',
    level=logging.INFO,
    format='%(asctime)s %(message)s',
)
# Mark the start of this run in the log.
logging.info('Starting training')

In [2]:
# Count handed to the training evaluator and to the trainer / checkpoint loader.
NUM_PARALLEL_ENVS = 100

# Path of a checkpoint to resume from; an empty string means "build a fresh trainer".
CHECKPOINT_FILE = ''

# Prefer CUDA when PyTorch reports it as available, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# disables JIT compilation
DEBUG = True

In [3]:
# Build `trainer`: construct everything from scratch when no checkpoint path is
# configured, otherwise restore the trainer from CHECKPOINT_FILE.
if not CHECKPOINT_FILE:
    # Tag passed to the trainer to identify this run.
    run_tag = 'test_othello_1'

    # Network architecture settings: 2x8x8 input, 64 policy outputs, and a
    # Tanh activation on the value output.
    arch_params = TurboZeroArchParams(
        input_size=torch.Size([2, 8, 8]),
        policy_size=64,
        res_channels=4,
        res_blocks=1,
        value_head_res_channels=4,
        value_head_res_blocks=1,
        policy_head_res_channels=4,
        policy_head_res_blocks=1,
        kernel_size=3,
        value_output_activation=torch.nn.Tanh(),
    )

    # Training hyperparameters, left at the library defaults.
    hypers = TurboZeroHypers()

    # MCTS settings shared by both evaluators below.
    eval_hypers = MCTSHypers(num_iters=100)

    # Model on the selected device, optimized with AdamW at the configured learning rate.
    model = TurboZeroResnet(arch_params).to(DEVICE)
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=hypers.learning_rate,
    )

    trainer = OthelloTrainer(
        # Evaluator sized by NUM_PARALLEL_ENVS (presumably the one used for training; confirm).
        OthelloMCTS(NUM_PARALLEL_ENVS, DEVICE, 8, eval_hypers, debug=DEBUG),
        # Evaluator sized by hypers.test_episodes_per_epoch (presumably the one used for testing; confirm).
        OthelloMCTS(hypers.test_episodes_per_epoch, DEVICE, 8, eval_hypers, debug=DEBUG),
        NUM_PARALLEL_ENVS,
        DEVICE,
        # NOTE(review): second device argument is pinned to CPU; confirm its role in OthelloTrainer.
        torch.device('cpu'),
        model,
        optimizer,
        hypers,
        run_tag=run_tag,
    )
else:
    trainer = load_checkpoint(NUM_PARALLEL_ENVS, CHECKPOINT_FILE, DEVICE, debug=DEBUG)

# Close every open matplotlib figure before training starts.
plt.close('all')

In [None]:
# Kick off training by calling the trainer's training_loop() with no arguments.
trainer.training_loop()