In [1]:
from TrainerSeparateParallel import Trainer
from TetrisModel import TetrisModel
from Pretrainer import Pretrainer
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import backend as K
import matplotlib.pyplot as plt
import pickle
import glob
import time

In [2]:
piece_dim = 8
key_dim = 12
depth = 32
gamma = 0.99
lam = 0.95
temperature = 1.0
num_players = 16
display_rows = 4

In [3]:
# pretrainer = Pretrainer(gamma=gamma)

In [4]:
# players_data = pretrainer._load_data()

In [5]:
# pretrainer._load_dset(players_data)

In [6]:
# max_len = pretrainer._max_len
max_len = 10

In [7]:
# gt_dset = pretrainer._cache_dset()

In [8]:
agent = TetrisModel(piece_dim=piece_dim,
                    key_dim=key_dim,
                    depth=depth,
                    num_heads=4,
                    num_layers=4,
                    max_length=max_len,
                    out_dim=key_dim)

In [9]:
agent_optimizer = keras.optimizers.Adam(1e-5, clipnorm=1.5)
agent.compile(optimizer=agent_optimizer)

In [10]:
logits, piece_scores, key_scores = agent((tf.random.uniform((32, 28, 10, 1)),
                                          tf.random.uniform((32, 7), minval=0, maxval=8, dtype=tf.int32),
                                          tf.random.uniform((32, max_len), minval=0, maxval=key_dim, dtype=tf.int32)), return_scores=True)
agent.summary(), tf.shape(logits), tf.shape(piece_scores), tf.shape(key_scores)

Model: "tetris_model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 sequential (Sequential)     (32, 70, 32)              37312     
                                                                 
 seq_embedding (SeqEmbedding  multiple                 256       
 )                                                               
                                                                 
 seq_embedding_1 (SeqEmbeddi  multiple                 384       
 ng)                                                             
                                                                 
 piece_dec_0 (DecoderLayer)  multiple                  37984     
                                                                 
 piece_dec_1 (DecoderLayer)  multiple                  37984     
                                                                 
 piece_dec_2 (DecoderLayer)  multiple                 

(None,
 <tf.Tensor: shape=(3,), dtype=int32, numpy=array([32, 10, 12])>,
 <tf.Tensor: shape=(5,), dtype=int32, numpy=array([ 4, 32,  4,  7, 70])>,
 <tf.Tensor: shape=(5,), dtype=int32, numpy=array([ 4, 32,  4, 10,  7])>)

In [11]:
critic = TetrisModel(piece_dim=piece_dim,
                     key_dim=key_dim,
                     depth=depth,
                     num_heads=4,
                     num_layers=4,
                     max_length=max_len,
                     out_dim=1)

In [12]:
critic_optimizer = keras.optimizers.Adam(1e-5, clipnorm=1.5)
critic.compile(optimizer=critic_optimizer)

In [13]:
values, piece_scores, key_scores = critic((tf.random.uniform((32, 28, 10, 1)),
                                           tf.random.uniform((32, 7), minval=0, maxval=8, dtype=tf.int32),
                                           tf.random.uniform((32, max_len), minval=0, maxval=key_dim, dtype=tf.int32)), return_scores=True)
critic.summary(), tf.shape(values), tf.shape(piece_scores), tf.shape(key_scores)

Model: "tetris_model_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 sequential_9 (Sequential)   (32, 70, 32)              37312     
                                                                 
 seq_embedding_2 (SeqEmbeddi  multiple                 256       
 ng)                                                             
                                                                 
 seq_embedding_3 (SeqEmbeddi  multiple                 384       
 ng)                                                             
                                                                 
 piece_dec_0 (DecoderLayer)  multiple                  37984     
                                                                 
 piece_dec_1 (DecoderLayer)  multiple                  37984     
                                                                 
 piece_dec_2 (DecoderLayer)  multiple               

(None,
 <tf.Tensor: shape=(3,), dtype=int32, numpy=array([32, 10,  1])>,
 <tf.Tensor: shape=(5,), dtype=int32, numpy=array([ 4, 32,  4,  7, 70])>,
 <tf.Tensor: shape=(5,), dtype=int32, numpy=array([ 4, 32,  4, 10,  7])>)

In [14]:
agent_checkpoint = tf.train.Checkpoint(model=agent, optim=agent.optimizer)
agent_checkpoint.restore('agent_checkpoint/finetuned/ckpt-9')

<tensorflow.python.checkpoint.checkpoint.CheckpointLoadStatus at 0x22f9f6dab50>

In [15]:
critic_checkpoint = tf.train.Checkpoint(model=critic, optim=critic.optimizer)
critic_checkpoint.restore('critic_checkpoint/finetuned/ckpt-9')

<tensorflow.python.checkpoint.checkpoint.CheckpointLoadStatus at 0x22c951c8b20>

In [16]:
ref_agent = TetrisModel(piece_dim=piece_dim,
                        key_dim=key_dim,
                        depth=depth,
                        num_heads=4,
                        num_layers=4,
                        max_length=max_len,
                        out_dim=key_dim)

In [17]:
logits, piece_scores, key_scores = ref_agent((tf.random.uniform((1, 28, 10, 1)),
                                              tf.random.uniform((1, 7), minval=0, maxval=8, dtype=tf.int32),
                                              tf.random.uniform((1, max_len), minval=0, maxval=key_dim, dtype=tf.int32)), return_scores=True)
tf.shape(logits), tf.shape(piece_scores), tf.shape(key_scores)

(<tf.Tensor: shape=(3,), dtype=int32, numpy=array([ 1, 10, 12])>,
 <tf.Tensor: shape=(5,), dtype=int32, numpy=array([ 4,  1,  4,  7, 70])>,
 <tf.Tensor: shape=(5,), dtype=int32, numpy=array([ 4,  1,  4, 10,  7])>)

In [18]:
ref_checkpoint = tf.train.Checkpoint(model=ref_agent)
ref_checkpoint.restore('agent_checkpoint/finetuned/ckpt-9')

<tensorflow.python.checkpoint.checkpoint.CheckpointLoadStatus at 0x22c9513fa60>

In [19]:
# epochs = 10

In [20]:
# actor_losses, critic_losses, accs = pretrainer.train(agent, critic, gt_dset, epochs)

In [21]:
%matplotlib qt

In [22]:
# plt.plot(actor_losses)
# plt.plot(critic_losses)
# plt.plot(accs)

In [23]:
# agent_checkpoint.save('agent_checkpoint/pretrained/ckpt')
# critic_checkpoint.save('critic_checkpoint/pretrained/ckpt')

In [24]:
agent_checkpoint = tf.train.Checkpoint(model=agent, optim=agent.optimizer)
agent_checkpoint_manager = tf.train.CheckpointManager(agent_checkpoint, 'agent_checkpoint/finetuned', max_to_keep=5)

In [25]:
critic_checkpoint = tf.train.Checkpoint(model=critic, optim=critic.optimizer)
critic_checkpoint_manager = tf.train.CheckpointManager(critic_checkpoint, 'critic_checkpoint/finetuned', max_to_keep=5)

In [26]:
trainer = Trainer(agent=agent,
                  critic=critic,
                  ref_model=ref_agent,
                  max_len=max_len,
                  num_players=num_players,
                  display_rows=display_rows,
                  gamma=gamma,
                  lam=lam,
                  temperature=temperature,
                  max_episode_steps=1000)

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mmichaelsherrick[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [52]:
while True:
    if __name__ == '__main__':
        trainer.train(gens=100, train_steps=100, training_actor=True)
        agent_checkpoint_manager.save()
        critic_checkpoint_manager.save()

BrokenProcessPool: A process in the process pool was terminated abruptly while the future was running or pending.

In [37]:
agent_checkpoint_manager.save()
critic_checkpoint_manager.save()

'critic_checkpoint/finetuned\\ckpt-9'

In [None]:
trainer.save_demo('Demo.gif', max_steps=1000)

In [29]:
episode_data = trainer.player.run_episode(agent, critic, max_steps=1000, greedy=True, renderer=trainer.renderer)

In [31]:
episode_boards, episode_pieces, episode_inputs, episode_actions, episode_probs, episode_values, episode_rewards = episode_data

In [32]:
episode_advantages, episode_returns = trainer._compute_gae(episode_values, episode_rewards, trainer.gamma, trainer.lam)

In [33]:
fig, ax = plt.subplots()
ax.plot(episode_returns, label='Returns')
ax.plot(episode_values, label='Values')
ax.legend()
tf.reduce_sum(episode_rewards)

<tf.Tensor: shape=(), dtype=float32, numpy=154.17001>

In [36]:
fig, ax = plt.subplots()
ax.plot(episode_rewards, label='Rewards')
ax.plot(episode_advantages, label='Advantages')
ax.legend()

<matplotlib.legend.Legend at 0x22e8b5b1fd0>

In [31]:
trainer.wandb_run.finish()

VBox(children=(Label(value='0.357 MB of 0.357 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
critic_loss,▂▁▃▁▃▁▁▁▃▄▁▁▂▄▁▅▁▁▁▂▁▁▄▂▁▁▁▂▃▁▁█▅▂▅▃▁▁▁▁
entropy,▆▅▆▆▆▄▆▆▄▆▆▅▇▇▅▅▁▅▃▆▅▃▅▇██▇▇▆█▅██▅▄▆▇█▅▅
kl_div,▄▂▄▂▃▄▂▂▁▁▃▂▆▄▂▄▅▂▂▂▅▇▃▃▃▃▄█▅▃▃▆▆▅▅▄▃▃▆▂
ppo_loss,▃▆▄▄▅▆▆▅▁▂▆▄▇▄▄▂▆▃█▁▃▁▄▂▅▅▄▁▅▆▃█▄▄▄▅▃▅▂▆
reward,▇▇▅▄▇▇▄▃▇▇▆▇█▇▄▄▇▅▇▄▆▄▇▇▇▇▇▅▁▇██▃▃█▄▅▇▇▁
reward_per_piece,▇▇▅▇▆▆▆▇▆▆▅▇█▇▆▄▃▇▇▇█▄▇▆▅▇▆▇▇▇▇▆▇▇▇▅▁▇▇▆
unclipped_proportion,▄▅█▂▃▄▄▄▂▆▆▅▅▃▅▂▅▃▂▁▆▆▁▅▂▄▆▅▇▂▄▄▁▄▄▆▂▆▄▆

0,1
critic_loss,5.3345
entropy,-0.10619
kl_div,0.36793
ppo_loss,0.03429
reward,59.595
reward_per_piece,0.11685
unclipped_proportion,0.935
