In [1]:
import PyTetris
import tensorflow.keras as keras
from tools import *
import os
from time import time
import numpy as np

In [2]:
# Number of transitions collected per gradient step (the training loop makes
# one train_on_batch call per batch).
batch_size = 100

# Act with the behavior policy; bootstrap regression targets with the target
# policy. Both helpers are presumably provided by the `tools` star import.
behavior_policy, target_policy = policy_e_greedy, policy_greedy

# Runtime switches.
display, new_model = True, True  # render each step / build a fresh model instead of loading
model_name = "Q2"                # checkpoint path is model_name + ".h5"

In [3]:
# Optional live view of the game; the training loop pushes states into it.
# set_ghost(0) / set_gravity(0) presumably disable the ghost piece and
# automatic falling -- TODO confirm against the PyTetris API.
if display:
    window = PyTetris.Window()
    for _switch_off in (window.set_ghost, window.set_gravity):
        _switch_off(0)

In [4]:
def _conv_bn_lrelu(x, filters):
    """Apply one 3x3 'valid' Conv2D -> BatchNormalization -> LeakyReLU(0.2) block."""
    x = keras.layers.Conv2D(filters, (3, 3), padding='valid')(x)
    x = keras.layers.BatchNormalization()(x)
    return keras.layers.LeakyReLU(alpha=0.2)(x)


def build_model(input_shape=(10, 20, 26), conv_filters=16, n_conv_blocks=3,
                dense_units=128, output_scale=128):
    """Build the (uncompiled) value network used as ``Q`` in the training loop.

    Architecture:
        ``n_conv_blocks`` x [Conv2D(conv_filters, 3x3, 'valid') -> BatchNorm -> LeakyReLU(0.2)]
        -> Flatten -> Dense(dense_units) -> BatchNorm -> LeakyReLU(0.3)
        -> Dense(1) -> divide by ``output_scale``.

    Every default reproduces the original hard-coded network, so
    ``build_model()`` builds the same model (same layer order) as before.

    Args:
        input_shape: Shape of one input sample, excluding the batch axis.
            NOTE(review): presumably the width x height x planes produced by
            ``state_to_layer`` -- TODO confirm against tools.
        conv_filters: Number of filters in each convolution.
        n_conv_blocks: Number of conv/BN/activation blocks. Each 'valid' 3x3
            convolution shrinks both spatial dimensions by 2.
        dense_units: Width of the hidden dense layer.
        output_scale: Divisor applied to the raw Dense(1) output.

    Returns:
        A ``keras.Model`` mapping ``input_shape`` inputs to one scalar each.
    """
    x_input = keras.Input(shape=input_shape)
    x = x_input

    for _ in range(n_conv_blocks):
        x = _conv_bn_lrelu(x, conv_filters)

    x = keras.layers.Flatten()(x)

    x = keras.layers.Dense(dense_units)(x)
    x = keras.layers.BatchNormalization()(x)
    x = keras.layers.LeakyReLU(alpha=0.3)(x)

    x = keras.layers.Dense(1)(x)

    # Fixed rescaling of the network output (purpose not stated in the
    # notebook -- presumably to keep early predictions small; TODO confirm).
    x_output = x / output_scale
    return keras.Model(x_input, x_output)
if not new_model:
    # Resume from the checkpoint that the training loop writes every epoch.
    Q = keras.models.load_model(model_name + ".h5")
else:
    Q = build_model()
    # NOTE(review): 0.02 is 20x Adam's default learning rate (1e-3). The loss
    # log printed by the training cell rises over the run instead of falling,
    # so consider lowering it.
    Q.compile(
        optimizer=keras.optimizers.Adam(learning_rate=0.02),
        loss="MSE",
    )

In [None]:
S = PyTetris.State(10, 20)

X = []
Y = []

epoch = 1
# Runs until interrupted; a checkpoint is written after every epoch.
while True:
    # Roll out batch_size consecutive transitions and build one-step
    # bootstrapped regression targets for them.
    for batch in range(batch_size):
        s0 = S.copy()
        # Act with the behavior policy (Q1 is not used below).
        s1, r1, Q1 = behavior_policy(S, Q)
        # Value estimate from the target policy's step out of s1 (s2, r2 unused).
        s2, r2, Q2 = target_policy(s1, Q)

        # Input: state_to_layer(s0, s1); target: r1 + gamma * Q2.
        # NOTE(review): gamma is not defined in this notebook -- presumably
        # exported by `from tools import *`; TODO confirm.
        # NOTE(review): no game-over handling is visible here: S is never
        # reset and the target always bootstraps. Confirm the tools policies
        # handle terminal states.
        X.append(state_to_layer(s0, s1))
        Y.append(r1 + gamma * Q2)

        S = s1.copy()

        if display:
            window.set_state(s0)
            window.tick()

    loss = Q.train_on_batch(np.asarray(X), np.asarray(Y))
    print("[Epoch %d]\tloss:\t%.4f" % (epoch, loss))

    # This loop only ends on interrupt. Saving straight over the checkpoint
    # could leave a truncated .h5 (unloadable via new_model = False) if the
    # interrupt lands mid-save. Instead, write a temp file in the same
    # directory and swap it in with os.replace.
    tmp_path = model_name + ".tmp.h5"
    Q.save(tmp_path)
    os.replace(tmp_path, model_name + ".h5")

    X, Y = [], []
    epoch += 1
    

[Epoch 1]	loss:	1999.2729
[Epoch 2]	loss:	5098.1206
[Epoch 3]	loss:	1899.2599
[Epoch 4]	loss:	5195.5918
[Epoch 5]	loss:	5199.8452
[Epoch 6]	loss:	5090.0317
[Epoch 7]	loss:	8490.6699
[Epoch 8]	loss:	8390.7100
[Epoch 9]	loss:	10102.7744
[Epoch 10]	loss:	8404.1592
[Epoch 11]	loss:	8537.2197
[Epoch 12]	loss:	13364.3213
[Epoch 13]	loss:	6932.3208
[Epoch 14]	loss:	19692.5312
[Epoch 15]	loss:	13327.8066
[Epoch 16]	loss:	8527.2842
[Epoch 17]	loss:	10155.7441
[Epoch 18]	loss:	8579.0459
[Epoch 19]	loss:	16544.6016
[Epoch 20]	loss:	10110.1309
[Epoch 21]	loss:	6891.4800
[Epoch 22]	loss:	14928.5146
[Epoch 23]	loss:	6867.7637
[Epoch 24]	loss:	3876.4268
[Epoch 25]	loss:	3712.1550
[Epoch 26]	loss:	557.5839
[Epoch 27]	loss:	5251.7612
[Epoch 28]	loss:	2106.7319
[Epoch 29]	loss:	8489.1758
[Epoch 30]	loss:	10117.7480
[Epoch 31]	loss:	7001.9917
[Epoch 32]	loss:	10312.0234
[Epoch 33]	loss:	15001.8291
[Epoch 34]	loss:	14961.9336
[Epoch 35]	loss:	24597.2676
[Epoch 36]	loss:	17019.5293
[Epoch 37]	loss:	15332.8