In [1]:



from train import MCTS_HYPERPARAMETERS, load_from_checkpoint, collect_episode, train, rotate_training_examples

from model_3d import MonteCarlo3d
from utils import input_to_tensor_3d
import torch.multiprocessing as mp
from train import save_checkpoint, load_from_checkpoint

In [3]:
# Checkpoint selection switches.
# NOTE(review): none of the live cells visible in this notebook reads MODEL_TO_LOAD or
# LOAD_REPLAY_MEMORY; the checkpoint-loading cell further down passes a literal filename
# and load_replay_memory=True directly. Confirm whether these are still meant to be used.
MODEL_TO_LOAD = "" # place filename of checkpoint here, otherwise leave empty
LOAD_REPLAY_MEMORY = True

# HYPERPARAMETERS (will be ignored if loading from a checkpoint)
# The checkpoint-loading cell rebinds `hyperparameters` from load_from_checkpoint's return
# value, so the defaults built here only survive if that cell is not run.
hyperparameters = MCTS_HYPERPARAMETERS() # use kwargs to specify non-default values


In [3]:
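# NOTE: stale scratch code, kept for reference only. It does not match the live cells:
# it unpacks nine values from load_from_checkpoint(filename, load_replay_memory=...) whereas
# the live loading cell passes the model class as a second argument and unpacks seven; it
# binds `metric_history` where the training cell reads `metrics_history`; and it uses names
# this notebook never imports (_2048Env, MCTS_Evaluator, ReplayMemory, MetricsHistory, torch).
# if MODEL_TO_LOAD: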
#     env, mcts, episode, model, optimizer, hyperparameters, metric_history, replay_memory, run_tag = load_from_checkpoint(MODEL_TO_LOAD, load_replay_memory=LOAD_REPLAY_MEMORY)
# else:
#     env = _2048Env()
#     model = MonteCarlo3d()
#     mcts = MCTS_Evaluator(model, env, input_to_tensor_3d, training=True)
#     replay_memory = ReplayMemory(hyperparameters.replay_memory_size)
#     optimizer = torch.optim.AdamW(model.parameters(), lr=hyperparameters.lr, weight_decay=hyperparameters.weight_decay)
#     metric_history = MetricsHistory()
#     run_tag = ''     
#     episode = 0



In [2]:
# Resume from a saved checkpoint. Rebinds the notebook globals model, optimizer,
# hyperparameters, metrics_history, replay_memory and run_tag from the values returned by
# load_from_checkpoint; the first returned value is discarded. MonteCarlo3d is passed as
# the second argument (presumably the model class to instantiate -- confirm against train.py).
_, model, optimizer, hyperparameters, metrics_history, replay_memory, run_tag = load_from_checkpoint('3DMCTS_ep2200.pt', MonteCarlo3d, load_replay_memory=True)
# Override the episode target on the loaded hyperparameters; the training cell's
# range(metrics_history.episodes, hyperparameters.num_episodes) ends at this value.
hyperparameters.num_episodes = 10000

In [None]:
def enque_and_train(results):
    """Pool callback: store one finished episode and run one training step.

    Registered as the ``callback`` of ``apply_async(collect_episode, ...)``, so it
    is called with that function's return value. Works on the notebook globals
    ``replay_memory``, ``hyperparameters``, ``model``, ``optimizer`` and
    ``metrics_history``.

    Args:
        results: 4-tuple ``(training_examples, reward, moves, high_square)``.

    Returns:
        None. Effects are: the replay memory is extended, and, once it holds at
        least ``hyperparameters.minibatch_size`` items, one ``train`` call is
        made, the metrics history is updated and plotted, and a checkpoint is
        saved every ``hyperparameters.checkpoint_every`` recorded episodes.

    Exceptions raised while processing are caught and their traceback printed
    rather than propagated. ``multiprocessing.Pool`` invokes callbacks on its
    result-handler thread; an exception escaping a callback stops that thread,
    after which every pending ``AsyncResult.wait()`` blocks forever.
    """
    import traceback  # local import keeps this cell independent of the import cell

    try:
        training_examples, reward, moves, high_square = results
        # rotate_training_examples presumably augments the episode -- confirm in train.py.
        replay_memory.extend(rotate_training_examples(training_examples))

        # Guard: no training (and no metrics entry) until a full minibatch can be sampled.
        if replay_memory.size() < hyperparameters.minibatch_size:
            print(f'Replay memory size not large enough, {replay_memory.size()} < {hyperparameters.minibatch_size}')
            return

        minibatch = replay_memory.sample(hyperparameters.minibatch_size)
        value_loss, prob_loss, total_loss = train(minibatch, model, optimizer, tensor_conversion_fn=input_to_tensor_3d)

        new_best = metrics_history.add_history({
            'reward': reward,
            'game_moves': moves,
            'prob_loss': prob_loss,
            'value_loss': value_loss,
            'total_loss': total_loss,
            'high_square': high_square
        })
        metrics_history.plot_history(window_size=100)
        if new_best:
            print('*** NEW BEST REWARD ***')
        print(f'[EPISODE {metrics_history.episodes}] Total Loss: {total_loss}, Prob Loss {prob_loss}, Value Loss {value_loss}, Reward {reward}, Moves: {moves}, Highest Square: {high_square}')

        if metrics_history.episodes % hyperparameters.checkpoint_every == 0:
            print('Saving model checkpoint...')
            save_checkpoint(metrics_history.episodes, model, optimizer, hyperparameters, metrics_history, replay_memory, run_tag='3DMCTS', save_replay_memory=True)
            print('Saved model checkpoint!')
    except Exception:
        # Report and carry on: letting this escape would stall the whole pool (see docstring).
        print('enque_and_train failed while processing an episode result:')
        traceback.print_exc()
    
    

# Leave one core free, but never ask for fewer than one worker: on a single-core machine
# cpu_count() - 1 is 0 and Pool(0) raises ValueError.
with mp.Pool(max(1, mp.cpu_count() - 1)) as p:
    # Queue one collect_episode task per episode still outstanding, i.e. from the number
    # of episodes already recorded in metrics_history up to hyperparameters.num_episodes.
    # Each finished episode is passed to enque_and_train; worker-side errors are printed.
    results = [
        p.apply_async(collect_episode, (model, hyperparameters, input_to_tensor_3d, ), callback=enque_and_train, error_callback=print)
        for _ in range(metrics_history.episodes, hyperparameters.num_episodes)
    ]
    # Block until every task has completed before the `with` block tears the pool down.
    for r in results:
        r.wait()

In [None]:
# from train import test_network

# test_network(model, hyperparameters, input_to_tensor_3d, debug_print=True)