In [1]:
import torch
import logging
import matplotlib.pyplot as plt

# Route INFO-and-above records to an append-mode log file so successive runs
# accumulate in one place.
# `force=True` replaces any handlers already attached to the root logger.
# Without it, `basicConfig` silently does nothing when the root logger is
# already configured (e.g. by the notebook kernel, or by an earlier run of this
# cell with different settings), and no records would reach the file.
logging.basicConfig(
    filename='training_othello.log',
    filemode='a',
    level=logging.INFO,
    format='%(asctime)s %(message)s',
    force=True,
)
logging.info('Starting training')

In [2]:
from envs.othello.trainer import OthelloTrainer
from core.hyperparameters import LZHyperparameters
from core.lz_resnet import LZArchitectureParameters, LZResnet

NUM_PARALLEL_ENVS = 10
DEVICE = torch.device('cpu')

# Seed torch's global RNG so torch-driven randomness (e.g. parameter
# initialisation of torch modules created below) is repeatable across
# Restart & Run All. Randomness drawn from other sources (Python `random`,
# NumPy, ...) is not covered by this call.
SEED = 0
torch.manual_seed(SEED)

# cuDNN autotuning only concerns CUDA convolutions; it has no effect while
# DEVICE is the CPU. Left enabled so switching DEVICE to a GPU needs no
# further change here.
torch.backends.cudnn.benchmark = True

# Architecture settings for the LZResnet model: one residual block and four
# channels for each of the un-prefixed, value-head and policy-head groups.
arch_params = LZArchitectureParameters(
    # input / output sizes
    input_size=torch.Size([2, 8, 8]),
    policy_size=64,
    # convolution kernel
    kernel_size=3,
    # residual stack (un-prefixed settings)
    res_blocks=1,
    res_channels=4,
    # value head
    value_head_res_blocks=1,
    value_head_res_channels=4,
    value_output_activation=torch.nn.Tanh(),
    # policy head
    policy_head_res_blocks=1,
    policy_head_res_channels=4,
)

# Training hyperparameters, grouped by the part of the run they configure.
hypers = LZHyperparameters(
    # update settings
    learning_rate=1e-4,
    minibatch_size=128,
    minibatches_per_update=1,
    policy_factor=1,
    # replay memory
    replay_memory_size=1000,
    replay_memory_min_size=1,
    # episodes per epoch
    episodes_per_epoch=5,
    eval_episodes_per_epoch=5,
    # search settings (train vs. eval/test)
    num_iters_train=20,
    iter_depth_train=3,
    num_iters_eval=20,
    iter_depth_test=3,
    mcts_c_puct=1.0,
    # epsilon schedule (every value is zero in this configuration)
    epsilon_start=0.0,
    epsilon_end=0.0,
    epsilon_decay_per_epoch=0.0,
)

# NOTE(review): run_tag is not referenced again in the visible cells; confirm
# whether it should be handed to the trainer or used elsewhere.
run_tag = 'test_othello_1'

# Build the network on the configured device, then an AdamW optimizer over its
# parameters using the learning rate from `hypers`.
model = LZResnet(arch_params).to(DEVICE)
optimizer = torch.optim.AdamW(model.parameters(), lr=hypers.learning_rate)

trainer = OthelloTrainer(
    num_parallel_envs=NUM_PARALLEL_ENVS,
    model=model,
    optimizer=optimizer,
    hypers=hypers,
    device=DEVICE,
    debug=True,
)

# Close every matplotlib figure that is currently open.
plt.close('all')

In [None]:
# Start training with the trainer built in the previous cell.
trainer.training_loop()