In [1]:
# Cap the MKL, NumExpr and OpenMP thread pools at one thread each.
# These variables are set before torch is imported further down; presumably so
# the limits are already in place when those libraries initialise and so that
# several worker processes do not oversubscribe the CPU -- TODO confirm.
import os
os.environ["MKL_NUM_THREADS"] = "1" 
os.environ["NUMEXPR_NUM_THREADS"] = "1" 
os.environ["OMP_NUM_THREADS"] = "1" 


# Suppress every warning of category UserWarning for the rest of the session.
import warnings
warnings.filterwarnings("ignore", category=UserWarning)

from hyperparameters import AZ_HYPERPARAMETERS
import torch
# Also restrict torch itself to a single CPU thread in this process.
torch.set_num_threads(1)

import matplotlib.pyplot as plt

# Send INFO-and-above log records to training.log. filemode='a' appends to an
# existing file, and every record is prefixed with a timestamp.
import logging
logging.basicConfig(filename='training.log', filemode='a', level=logging.INFO, format='%(asctime)s %(message)s')
logging.info('Starting training')

from az_resnet import AZResnetArchitectureParameters, AZResnet
from trainer import AlphaZeroTrainer, init_trainer_from_checkpoint
from env2048.trainer import _2048Trainer

# Depending on the cloud provider you are using, you may need to uncomment the
# lines below; they set the RLIMIT_NOFILE resource limit to (200000, 200000).
# import resource
# resource.setrlimit(
#     resource.RLIMIT_NOFILE,
#     (200000, 200000))

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# --- Run configuration -------------------------------------------------------

# Checkpoint file to resume from. Keep this an empty string to start a new run.
CHECKPOINT_PATH: str = ""

# Passed as `load_replay_memory` when a checkpoint is restored.
LOAD_REPLAY_MEMORY: bool = True

# NOTE(review): the three settings below are not referenced by any cell visible
# in this notebook -- confirm where (or whether) they are consumed.
PLOT_EVERY: int = 25
NUM_COLLECTION_PROCS: int = 8
NUM_TRAIN_PROCS: int = 1

In [3]:
if not CHECKPOINT_PATH:
    # No checkpoint was given: assemble a brand-new trainer.

    # name your run here
    run_tag = 'resnet2048'

    hypers = AZ_HYPERPARAMETERS()
    # Override any hyperparameters you want to change here, for example:
    # hypers.mcts_iters_train = 100
    # ...

    # Describe the network, then build it together with its optimizer.
    model_arch_params = AZResnetArchitectureParameters(
        input_size=torch.Size([4, 4, 4]),
        policy_size=4,
        res_channels=16,
        res_blocks=8,
        value_head_res_channels=16,
        value_head_res_blocks=0,
        policy_head_res_channels=16,
        policy_head_res_blocks=0,
        kernel_size=3,
        policy_fc_size=32,
        value_fc_size=32,
    )
    model = AZResnet(model_arch_params)

    # Learning rate and weight decay are read from the hyperparameter object.
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=hypers.learning_rate,
        weight_decay=hypers.weight_decay,
        amsgrad=True,
    )

    trainer = _2048Trainer(model, optimizer, hypers, run_tag=run_tag)
    logging.info(f'Created new trainer module with tag {run_tag}')
else:
    # A checkpoint path was supplied: restore the trainer from it instead.
    trainer = init_trainer_from_checkpoint(
        CHECKPOINT_PATH,
        load_replay_memory=LOAD_REPLAY_MEMORY,
    )
    logging.info(f'Loaded model from {CHECKPOINT_PATH}')

# Close every open matplotlib figure, whichever branch ran.
plt.close('all')

In [4]:
# Start training with the trainer constructed in the previous cell.
# NOTE(review): presumably long-running -- the saved output of this cell ends
# in a KeyboardInterrupt, i.e. the recorded run was stopped manually.
trainer.run_training_loop()

KeyboardInterrupt: 