In [1]:

import torch
import matplotlib.pyplot as plt
import logging

from env2048.vecttrainer import VectorizedTrainer, load_trainer_from_checkpoint
from az_resnet import AZResnet, AZResnetArchitectureParameters
from hyperparameters import LazyAZHyperparameters

# Append INFO-level records to training.log (timestamp + message).
# force=True (Python 3.8+) removes and closes any handlers already attached to the
# root logger first. Without it, basicConfig silently does nothing when the notebook
# kernel or an imported library has configured logging before this cell runs, and
# training.log would never be written.
logging.basicConfig(filename='training.log', filemode='a', level=logging.INFO,
                    format='%(asctime)s %(message)s', force=True)
logging.info('Starting training')

# Enable oneDNN JIT fusion for TorchScript graphs.
torch.jit.enable_onednn_fusion(True)
# Let cuDNN benchmark several convolution algorithms and keep the fastest one for the
# input shapes it sees (trades bitwise run-to-run determinism for speed).
torch.backends.cudnn.benchmark = True

In [2]:
# Path of a saved trainer checkpoint to resume from; leave empty to build a fresh trainer.
CHECKPOINT_FILE = ''
# Use the GPU whenever PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:


if CHECKPOINT_FILE:
    # Resume a previous run from its checkpoint, including its replay memory.
    trainer = load_trainer_from_checkpoint(CHECKPOINT_FILE, device, load_replay_memory=True)
else:
    run_tag = '' # TODO: add run tag

    # Passed to VectorizedTrainer as its first argument.
    # NOTE(review): presumably the number of boards simulated in parallel — confirm.
    num_boards = 20

    model = AZResnet(AZResnetArchitectureParameters(
        input_size=torch.Size((1, 4, 4)),  # single-channel 4x4 grid (presumably the 2048 board)
        policy_size=4,  # NOTE(review): presumably one logit per move direction — confirm
        res_channels=8,
        res_blocks=2,
        value_head_res_channels=8,
        value_head_res_blocks=2,
        policy_head_res_channels=8,
        policy_head_res_blocks=2,
        kernel_size=3,
        policy_fc_size=32,
        value_fc_size=32
    )) # TODO: specify model architecture parameters

    # Move the model to the target device *before* constructing its optimizer, as the
    # torch.optim docs recommend. Optimized parameters should already live where they
    # will be used when the optimizer is built. Moving it again later is harmless.
    model = model.to(device)

    hypers = LazyAZHyperparameters() # TODO: specify hyperparameters

    optimizer = torch.optim.AdamW(model.parameters(), lr=hypers.learning_rate)

    # NOTE(review): progression_batch_size happens to equal num_boards (20); confirm
    # whether it is meant to track num_boards or is independent.
    trainer = VectorizedTrainer(num_boards, model, optimizer, hypers, device, run_tag=run_tag, progression_batch_size=20)

# Close every open matplotlib figure.
# NOTE(review): presumably clears figures left over from a previous run of this cell — confirm.
plt.close('all')

In [None]:
# Run the training loop of the trainer built or restored in the previous cell.
# The call blocks the kernel until run_training_loop returns; because `trainer`
# lives in the kernel namespace, it can be inspected after the call ends.
# NOTE(review): presumably long-running / stopped by interrupt — confirm whether
# progress is checkpointed when the cell is interrupted.
trainer.run_training_loop()