In [1]:

import torch
import matplotlib.pyplot as plt
import logging

# Send timestamped INFO+ records to an append-mode log file. This runs before the
# project imports that follow, so anything those modules log at import time (if any)
# also lands in the file.
# NOTE(review): logging.basicConfig is a no-op when the root logger already has
# handlers (e.g. if something configured logging earlier in this kernel).
logging.basicConfig(
    filename='training_2048.log',
    filemode='a',
    format='%(asctime)s %(message)s',
    level=logging.INFO,
)
logging.info('Starting training')

from envs._2048.trainer import _2048Trainer
from core.hyperparameters import LZHyperparameters
from core.lz_resnet import LZArchitectureParameters, LZResnet


In [2]:
# --------- SETUP ---------
# Build the trainer: either resume from CHECKPOINT_FILE or construct a fresh
# model / hyperparameters / optimizer from the settings below.
from envs._2048.trainer import init_2048_trainer_from_checkpoint


device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

CHECKPOINT_FILE = ''  # path to a saved checkpoint; leave empty to start a fresh run
NUM_PARALLEL_ENVS = 5  # you can go much higher on a GPU, depending on the model size / state size

# performance: let cuDNN benchmark conv algorithms and cache the fastest one
# (beneficial when input sizes do not change between batches)
torch.backends.cudnn.benchmark = True

if CHECKPOINT_FILE:
    # Resume a previous run from the saved checkpoint.
    trainer: _2048Trainer = init_2048_trainer_from_checkpoint(NUM_PARALLEL_ENVS, CHECKPOINT_FILE, device)
else:
    # Fresh run.
    run_tag = ''  # TODO: add run tag

    # TODO: specify model architecture parameters
    model_architecture = LZArchitectureParameters(
        input_size=torch.Size((1, 4, 4)),  # presumably one channel over the 4x4 board — confirm
        policy_size=4,  # presumably one output per move direction — confirm
        res_channels=16,
        res_blocks=8,
        value_head_res_channels=16,
        value_head_res_blocks=4,
        policy_head_res_channels=16,
        policy_head_res_blocks=4,
        kernel_size=3,
        policy_fc_size=32,
        value_fc_size=32,
    )

    hypers = LZHyperparameters(
        # TODO: I strongly recommend changing the default hyperparameters
        learning_rate=1e-4,
        num_iters_train=5,
        iter_depth_train=2,
        num_iters_eval=5,
        iter_depth_test=3,
        replay_memory_size=10000,
        replay_memory_min_size=1,
        minibatch_size=4096,
        minibatches_per_update=2,
        episodes_per_epoch=100000,
        epsilon_decay_per_epoch=0.1,
        eval_episodes_per_epoch=0,
    )

    model = LZResnet(model_architecture).to(device)

    trainer = _2048Trainer(
        NUM_PARALLEL_ENVS,
        model=model,
        optimizer=torch.optim.AdamW(model.parameters(), lr=hypers.learning_rate),
        hypers=hypers,
        device=device,
    )

# Close any matplotlib figures still open in this kernel (e.g. from a previous run).
plt.close('all')


In [None]:
# Run training with the trainer built in the setup cell (fresh, or resumed from CHECKPOINT_FILE).
# NOTE(review): presumably long-running (episodes_per_epoch=100000 in the fresh-run config);
# confirm how the loop terminates and whether/when it writes checkpoints.
trainer.training_loop()