In [1]:
import PyTetris
import tensorflow.keras as keras
from tools import *
from time import time
import numpy as np

In [2]:
batch_size = 100

behavior_policy = policy_e_greedy
target_policy   = policy_greedy

display = True
new_model = True
model_name = "Q2"

In [3]:
if display:
    window = PyTetris.Window()
    window.set_ghost(0)
    window.set_gravity(0)

In [4]:
def build_model():
    x_input = keras.Input(shape=(10, 20, 26))
    x = x_input
    
    x = keras.layers.Conv2D(16, (3, 3), padding = 'valid')(x)
    x = keras.layers.BatchNormalization()(x)
    x = keras.layers.LeakyReLU(alpha=0.2)(x)
    
    x = keras.layers.Conv2D(16, (3, 3), padding = 'valid')(x)
    x = keras.layers.BatchNormalization()(x)
    x = keras.layers.LeakyReLU(alpha=0.2)(x)
    
    x = keras.layers.Conv2D(16, (3, 3), padding = 'valid')(x)
    x = keras.layers.BatchNormalization()(x)
    x = keras.layers.LeakyReLU(alpha=0.2)(x)
    
    x = keras.layers.Flatten()(x)
    
    x = keras.layers.Dense(128)(x)
    x = keras.layers.BatchNormalization()(x)
    x = keras.layers.LeakyReLU(alpha=0.3)(x)
    
    x = keras.layers.Dense(1)(x)

    x_output = x / 128
    return keras.Model(x_input, x_output)
if new_model:
    Q = build_model()
    Q.compile(loss = "MSE", optimizer = keras.optimizers.Adam(learning_rate=0.02))
else:
    Q = keras.models.load_model(model_name + ".h5")

In [5]:
S = PyTetris.State(10, 20)

X = []
Y = []

epoch = 1
while True:
    for batch in range(batch_size):
        s0 = S.copy()
        s1, r1, Q1 = behavior_policy(S, Q)
        s2, r2, Q2 = target_policy(s1, Q)
        
        X.append(state_to_layer(s0, s1))
        Y.append(r1 + gamma * Q2)
        
        S = s1.copy()
        
        if display:
            window.set_state(s0)
            window.tick()
    
    loss = Q.train_on_batch(np.asarray(X), np.asarray(Y))
    print("[Epoch %d]\tloss:\t%.4f" % (epoch, loss))
    Q.save(model_name + ".h5")
    X, Y = [], []
    epoch += 1
    

[Epoch 1]	loss:	1884.4742
[Epoch 2]	loss:	2107.4548
[Epoch 3]	loss:	2522.6223
[Epoch 4]	loss:	6672.0068
[Epoch 5]	loss:	11311.8408
[Epoch 6]	loss:	17036.1055
[Epoch 7]	loss:	25897.7832
[Epoch 8]	loss:	41022.6094
[Epoch 9]	loss:	78063.8672
[Epoch 10]	loss:	132360.4688
[Epoch 11]	loss:	137967.9375
[Epoch 12]	loss:	178871.0469
[Epoch 13]	loss:	392680.4062
[Epoch 14]	loss:	445647.1250
[Epoch 15]	loss:	866508.0625
[Epoch 16]	loss:	593015.8125
[Epoch 17]	loss:	1068986.3750
[Epoch 18]	loss:	3036987.0000
[Epoch 19]	loss:	2521481.5000


KeyboardInterrupt: 