In [38]:
import keras
from keras.layers import Input, Dense, Conv2D, Flatten, BatchNormalization, Activation
from keras.layers.merge import concatenate
from keras.models import Model
from keras.callbacks import TensorBoard, ModelCheckpoint, ReduceLROnPlateau
from datetime import datetime
import sqlite3
from blosc import decompress
from msgpack import unpackb
import pandas as pd
import numpy as np

# Load data

In [2]:
con = sqlite3.connect("games.sqlite")

In [3]:
def unpack_state(row):
    vec = unpackb(decompress(row.board_state))
    arr = np.array(np.reshape(np.frombuffer(vec, np.uint8), (8, 4, 8)).T, np.float32)
    arr[:, :, 7] /= 100
    return [arr]

In [4]:
def unpack_moves(row):
    vec = unpackb(decompress(row.mcts_moves))
    return [np.reshape(np.array(vec, np.float32), (4, 4, 8)).T]

In [27]:
positions = pd.read_sql_query("select positions.*, games.outcome from positions, games where positions.game_id=games.id", con)
positions['state_tensor'] = positions.apply(unpack_state, axis=1)
positions['moves_tensor'] = positions.apply(unpack_moves, axis=1)
len(positions)

138531

In [28]:
positions.head()

Unnamed: 0,id,game_id,move_number,board_state,mcts_moves,mcts_score,outcome,state_tensor,moves_tensor
0,1,1,1,b'\x02\x01\x01\x01\x03\x01\x00\x00\x03\x01\x00...,b'\x02\x01\x01\x01\x83\x02\x00\x00\x83\x02\x00...,0.499667,0.0,"[[[[0. 1. 0. 0. 0. 1. 0. 0.], [0. 1. 0. 0. 0. ...","[[[[0. 0. 0. 0.], [0. 0. 0. 0.], [0. 0. 0. 0.]..."
1,2,1,2,b'\x02\x01\x01\x01\x03\x01\x00\x00\x03\x01\x00...,b'\x02\x01\x01\x01\x83\x02\x00\x00\x83\x02\x00...,0.5065,0.0,"[[[[0. 1. 0. 0. 0. 0. 0. 0.01], ...","[[[[0. 0. 0. 0.], [0. 0. 0. 0.], [0. 0. 0. 0.]..."
2,3,1,3,b'\x02\x01\x01\x01\x03\x01\x00\x00\x03\x01\x00...,b'\x02\x01\x01\x01\x83\x02\x00\x00\x83\x02\x00...,0.495167,0.0,"[[[[0. 1. 0. 0. 0. 1. 0. 0.02], ...","[[[[0. 0. 0. 0.], [0. 0. 0. 0.], [0. 0. 0. 0.]..."
3,4,1,4,b'\x02\x01\x01\x01\x03\x01\x00\x00\x03\x01\x00...,b'\x02\x01\x01\x01\x83\x02\x00\x00\x83\x02\x00...,0.4895,0.0,"[[[[0. 1. 0. 0. 0. 0. 0. 0.03], ...","[[[[0. 0. 0. 0.], [0. 0. 0. 0.], [0. 0. 0. 0.]..."
4,5,1,5,b'\x02\x01\x01\x01\x03\x01\x00\x00\x03\x01\x00...,b'\x02\x01\x01\x01\x83\x02\x00\x00\x83\x02\x00...,0.448167,0.0,"[[[[0. 1. 0. 0. 0. 1. 0. 0.04], ...","[[[[0. 0. 0. 0.], [0. 0. 0. 0.], [0. 0. 0. 0.]..."


# Outcome model

In [30]:
start = Input(shape=positions.loc[0].state_tensor[0].shape)

x = start
for i in range(100):
    prev_layer = x
    x = Conv2D(4, (3, 3), padding='same', activation='relu')(prev_layer)
    x = BatchNormalization()(x)
    x = concatenate([prev_layer, x])
    
x = Conv2D(64, (1, 1), padding='same', activation='relu')(x)
x = Flatten()(x)
x = Dense(64, activation='relu')(x)
x = Dense(64, activation='relu')(x)
x = Dense(64, activation='relu')(x)
x = Dense(1, activation='sigmoid')(x)

model = Model(start, x)

## Train/validate split

In [32]:
split = positions.groupby(lambda r: positions.loc[r].game_id % 5 == 0)
val = split.get_group(True)
train = split.get_group(False)

In [54]:
def extract_tensors(dataframe):
    inputs = []
    outcomes = []
    mcts_probs = []
    for _, row in dataframe.sample(frac=1).iterrows():
        inputs.append(row.state_tensor[0])
        outcomes.append(row.outcome)
        mcts_probs.append(row.moves_tensor[0].ravel())
    return np.array(inputs), np.array(outcomes), np.array(mcts_probs)

In [55]:
val_in, val_outcomes, val_probs = extract_tensors(val)
train_in, train_outcomes, train_probs = extract_tensors(train)

In [35]:
savedir = 'logs-outcome/' + str(datetime.now())
tbcb = TensorBoard(log_dir=savedir, histogram_freq=0, write_graph=True, write_images=True)
mccb = ModelCheckpoint(savedir+'/model.{epoch:04d}-{loss:.4f}-{val_loss:.4f}.hdf5',
                       monitor='val_loss', save_best_only=False, period=5)
redlr = ReduceLROnPlateau('loss', factor=0.1, cooldown=1, verbose=1, patience=4)
callbacks = [tbcb, mccb, redlr]

model.compile(loss='mean_absolute_error', optimizer='adam')
model.fit(train_in, train_outcomes, 1024, 100, verbose=1, callbacks=callbacks, validation_data=(val_in, val_outcomes))

Train on 111093 samples, validate on 27438 samples
Epoch 1/100
Epoch 2/100
Epoch 3/100

KeyboardInterrupt: 

# MCTS probability model

In [41]:
start = Input(shape=positions.loc[0].state_tensor[0].shape)

x = start
for i in range(100):
    prev_layer = x
    x = Conv2D(4, (3, 3), padding='same', activation='relu')(prev_layer)
    x = BatchNormalization()(x)
    x = concatenate([prev_layer, x])
    
x = Conv2D(64, (1, 1), padding='same', activation='relu')(x)
x = Conv2D(4, (1, 1), padding='same', activation='relu')(x)
x = Flatten()(x)
x = Activation('softmax')(x)

model = Model(start, x)

In [62]:
savedir = 'logs-probs/' + str(datetime.now())
tbcb = TensorBoard(log_dir=savedir, histogram_freq=0, write_graph=True, write_images=True)
mccb = ModelCheckpoint(savedir+'/model.{epoch:04d}-{loss:.4f}-{val_loss:.4f}.hdf5',
                       monitor='val_loss', save_best_only=False, period=5)
redlr = ReduceLROnPlateau('loss', factor=0.1, cooldown=1, verbose=1, patience=4)
callbacks = [tbcb, mccb, redlr]

model.compile(loss='categorical_crossentropy', optimizer=keras.optimizers.RMSprop(2e-4))
model.fit(train_in, train_probs, 1024, 100, verbose=1, callbacks=callbacks, validation_data=(val_in, val_probs))

Train on 111093 samples, validate on 27438 samples
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100

KeyboardInterrupt: 

In [67]:
preds = np.reshape(model.predict(np.array(positions.loc[0].state_tensor)), (8, 4, 4))

In [68]:
preds[:, :, 0]

array([[2.63240895e-06, 2.63240895e-06, 2.63240895e-06, 2.63240895e-06],
       [2.63240895e-06, 2.63240895e-06, 2.63240895e-06, 2.63240895e-06],
       [2.63240895e-06, 2.63240895e-06, 2.63240895e-06, 2.63240895e-06],
       [2.63240895e-06, 2.63240895e-06, 2.63240895e-06, 2.63240895e-06],
       [2.63240895e-06, 2.63240895e-06, 3.14325507e-06, 2.63240895e-06],
       [8.47379706e-05, 1.67024240e-01, 1.22421496e-01, 1.35151908e-01],
       [6.87249030e-06, 9.28240661e-06, 2.28467979e-05, 2.15322034e-05],
       [2.63240895e-06, 1.28239972e-05, 8.73178124e-06, 2.63240895e-06]],
      dtype=float32)

In [75]:
positions.loc[0].moves_tensor[0][:, :, 0]

array([[0.        , 0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        , 0.        ],
       [0.        , 0.18533333, 0.135     , 0.16066666],
       [0.        , 0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        , 0.        ]], dtype=float32)