In [None]:
import tensorflow as tf
import tensorflow.keras.models as models
import tensorflow.keras.layers as layers
from tensorflow.keras.utils import Sequence
from tensorflow import keras
import sys 
import numpy as np 
import os 
import random
import gzip
from sklearn.model_selection import train_test_split
import tensorflow.keras.callbacks as callbacks

In [None]:
# Command-line arguments: the model to train, the dataset to train on, and the
# side to train (from / to).
if len(sys.argv) < 4:
    # Exit with a usage message rather than a bare IndexError when arguments
    # are missing; sys.exit with a string reports a failure status.
    sys.exit(f"Usage: {sys.argv[0]} <model_name> <dataset> <side>")

# Extra trailing arguments are ignored, exactly as before.
model_name, dataset, side = sys.argv[1:4]

In [None]:
# Reproducibility: seed the Python, NumPy and TensorFlow RNGs with one value.
SEED = 42
for _seed_rng in (random.seed, np.random.seed, tf.random.set_seed):
    _seed_rng(SEED)

# Training hyperparameters.
learning_rate = 1e-3   # read only when building the optimizer for the 'vit' model
weight_decay = 1e-4    # handed to that optimizer through its `decay` argument
batch_size = 2048      # samples per batch served by DataGenerator
num_epochs = 1000      # upper bound on epochs passed to model.fit


In [None]:
class DataGenerator(Sequence):
    """Serve aligned ``(x, y)`` mini-batches by slicing two in-memory collections.

    Args:
        x_set: Inputs; must support ``len()`` and slice indexing.
        y_set: Targets; sliced with the same bounds as ``x_set``.
        batch_size: Positive number of samples per batch. The final batch is
            shorter when ``len(x_set)`` is not a multiple of ``batch_size``.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """

    def __init__(self, x_set, y_set, batch_size):
        # Initialise the Keras base class as well; recent Keras releases keep
        # their worker/queue settings there and warn when this call is skipped.
        super().__init__()
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")
        self.x, self.y = x_set, y_set
        self.batch_size = batch_size

    def __len__(self):
        """Return the number of batches per epoch (ceiling division)."""
        return (len(self.x) + self.batch_size - 1) // self.batch_size

    def __getitem__(self, idx):
        """Return the ``idx``-th ``(batch_x, batch_y)`` pair."""
        start = idx * self.batch_size
        stop = start + self.batch_size
        return self.x[start:stop], self.y[start:stop]


In [None]:
def CNN(conv_size, conv_depth):
    """Build an uncompiled plain convolutional classifier over board planes.

    The model takes a ``(14, 8, 8)`` input, applies ``conv_depth`` same-padded
    3x3 ReLU convolutions with ``conv_size`` filters each, flattens, and ends
    in a 64-unit ReLU layer followed by a 64-unit softmax layer.

    Args:
        conv_size: Number of filters in every convolutional layer.
        conv_depth: Number of stacked convolutional layers.

    Returns:
        A Keras ``Model`` mapping ``(batch, 14, 8, 8)`` inputs to
        ``(batch, 64)`` softmax outputs.
    """
    board_in = layers.Input(shape=(14, 8, 8))

    # Conv2D defaults to channels-last, so convolving the raw input would treat
    # it as a 14x8 grid with 8 channels. The boards are planes-first (the ViT
    # path moves axis 1 to the end for the same reason), so reorder to
    # (8, 8, 14) first. Permute dims are 1-indexed and exclude the batch axis.
    x = layers.Permute((2, 3, 1))(board_in)
    for _ in range(conv_depth):
        x = layers.Conv2D(filters=conv_size, kernel_size=3, padding='same', activation='relu')(x)
    x = layers.Flatten()(x)
    x = layers.Dense(64, activation='relu')(x)
    x = layers.Dense(64, activation='softmax')(x)

    return models.Model(inputs=board_in, outputs=x)


In [None]:
def residual(conv_size, conv_depth):
    """Build an uncompiled residual convolutional classifier over board planes.

    The model takes a ``(14, 8, 8)`` input, applies one same-padded 3x3 stem
    convolution, then ``conv_depth`` residual blocks
    (conv -> batch norm -> ReLU -> conv -> batch norm -> add skip -> ReLU),
    flattens, and ends in a 64-unit softmax layer.

    Args:
        conv_size: Number of filters in every convolutional layer.
        conv_depth: Number of residual blocks.

    Returns:
        A Keras ``Model`` mapping ``(batch, 14, 8, 8)`` inputs to
        ``(batch, 64)`` softmax outputs.
    """
    board_in = layers.Input(shape=(14, 8, 8))

    # Conv2D and BatchNormalization both default to treating the last axis as
    # channels, so the raw input would be read as a 14x8 grid with 8 channels.
    # The boards are planes-first (the ViT path moves axis 1 to the end for the
    # same reason), so reorder to (8, 8, 14) first. Permute dims are 1-indexed
    # and exclude the batch axis.
    x = layers.Permute((2, 3, 1))(board_in)

    # Stem convolution brings the input to `conv_size` channels so the skip
    # connections below add tensors of matching shape.
    x = layers.Conv2D(filters=conv_size, kernel_size=3, padding='same')(x)
    for _ in range(conv_depth):
        previous = x
        x = layers.Conv2D(filters=conv_size, kernel_size=3, padding='same')(x)
        x = layers.BatchNormalization()(x)
        x = layers.Activation('relu')(x)
        x = layers.Conv2D(filters=conv_size, kernel_size=3, padding='same')(x)
        x = layers.BatchNormalization()(x)
        # Skip connection: add the block input back before the final ReLU.
        x = layers.Add()([x, previous])
        x = layers.Activation('relu')(x)
    x = layers.Flatten()(x)
    x = layers.Dense(64, activation='softmax')(x)

    return models.Model(inputs=board_in, outputs=x)


In [None]:
def vit():
    """Build the uncompiled Vision-Transformer classifier.

    The model takes channels-last ``(8, 8, 14)`` inputs, cuts them into
    non-overlapping 2x2 patches (16 per input), embeds each patch with a
    learned position embedding, runs four pre-norm transformer blocks, and
    feeds the flattened result through an MLP head.

    Returns:
        A Keras ``Model`` whose output is 64 raw logits (no softmax applied).
    """
    # ---- Hyperparameters -------------------------------------------------
    n_classes = 64
    board_shape = (8, 8, 14)
    patch_side = 2  # Side length of the square patches extracted from the input
    n_patches = (8 // patch_side) ** 2
    embed_dim = 176
    n_heads = 3
    block_mlp_units = [embed_dim * 2, embed_dim]  # Dense widths inside each transformer block
    n_blocks = 4
    head_mlp_units = [512, 1984]  # Dense widths of the final classifier head

    class Patches(layers.Layer):
        """Split each input into flattened, non-overlapping square patches."""

        def __init__(self, patch_size):
            super().__init__()
            self.patch_size = patch_size

        def call(self, images):
            n_images = tf.shape(images)[0]
            # Using the patch size as the stride makes the windows tile the
            # input without overlap.
            window = [1, self.patch_size, self.patch_size, 1]
            tiles = tf.image.extract_patches(
                images=images,
                sizes=window,
                strides=window,
                rates=[1, 1, 1, 1],
                padding="VALID",
            )
            # Collapse the patch grid into a single sequence axis.
            return tf.reshape(tiles, [n_images, -1, tiles.shape[-1]])

    class PatchEncoder(layers.Layer):
        """Linearly project patches and add a learned per-position embedding."""

        def __init__(self, num_patches, projection_dim):
            super().__init__()
            self.num_patches = num_patches
            self.projection = layers.Dense(units=projection_dim)
            self.position_embedding = layers.Embedding(
                input_dim=num_patches, output_dim=projection_dim
            )

        def call(self, patch):
            slot_ids = tf.range(start=0, limit=self.num_patches, delta=1)
            return self.projection(patch) + self.position_embedding(slot_ids)

    def dense_stack(tensor, widths, drop_rate):
        """Apply a GELU Dense layer followed by Dropout for every width."""
        for width in widths:
            tensor = layers.Dense(width, activation=tf.nn.gelu)(tensor)
            tensor = layers.Dropout(drop_rate)(tensor)
        return tensor

    # ---- Graph construction ----------------------------------------------
    board_in = layers.Input(shape=board_shape)
    patch_seq = Patches(patch_side)(board_in)
    tokens = PatchEncoder(n_patches, embed_dim)(patch_seq)

    for _ in range(n_blocks):
        # Pre-norm self-attention with a residual connection.
        normed = layers.LayerNormalization(epsilon=1e-6)(tokens)
        attended = layers.MultiHeadAttention(
            num_heads=n_heads, key_dim=embed_dim, dropout=0.1
        )(normed, normed)
        after_attention = layers.Add()([attended, tokens])
        # Pre-norm feed-forward MLP with a second residual connection.
        hidden = layers.LayerNormalization(epsilon=1e-6)(after_attention)
        hidden = dense_stack(hidden, block_mlp_units, 0.1)
        tokens = layers.Add()([hidden, after_attention])

    # Normalise, then flatten every patch embedding into one vector per sample.
    summary = layers.LayerNormalization(epsilon=1e-6)(tokens)
    summary = layers.Flatten()(summary)
    summary = layers.Dropout(0.5)(summary)
    # Classifier head; the final Dense has no activation, so it emits logits.
    head = dense_stack(summary, head_mlp_units, 0.5)
    logits = layers.Dense(n_classes)(head)

    return keras.Model(inputs=board_in, outputs=logits)


In [None]:
def _classification_metrics():
    """Return fresh top-1 and top-5 categorical accuracy metrics."""
    return [
        keras.metrics.CategoricalAccuracy(name="accuracy"),
        keras.metrics.TopKCategoricalAccuracy(5, name="top-5-accuracy"),
    ]


# Build and compile the architecture named on the command line.
model = None
if model_name in ('CNN', 'residual'):
    # Both convolutional models end in a softmax layer, so they share one
    # compile recipe: default Adam with probability-based cross-entropy.
    build_conv_model = CNN if model_name == 'CNN' else residual
    model = build_conv_model(32, 4)
    model.compile(
        optimizer="adam",
        loss="categorical_crossentropy",
        metrics=_classification_metrics(),
    )

elif model_name == 'vit':
    model = vit()
    # NOTE(review): `decay` on Keras Adam looks like learning-rate decay, not
    # weight decay, and newer Keras optimizers may reject it -- confirm that
    # passing `weight_decay` here is intended for the installed version.
    optimizer = tf.keras.optimizers.Adam(
        learning_rate=learning_rate, decay=weight_decay
    )
    model.compile(
        optimizer=optimizer,
        # The ViT's last Dense layer has no activation, so the loss must be
        # told it receives logits.
        loss=keras.losses.CategoricalCrossentropy(from_logits=True),
        metrics=_classification_metrics(),
    )

else:
    print('Invalid model name')
    # Exit with a non-zero status so callers can detect the failure; the
    # previous bare exit() reported success.
    sys.exit(1)


In [None]:
# Root folder for everything this (model, dataset) combination produces.
model_folder = f"{model_name}_{dataset}_models"

# Make the root and its logs/models subfolders, announcing only the ones that
# did not exist yet.
for _needed_dir in (model_folder, model_folder + "/logs", model_folder + "/models"):
    if os.path.exists(_needed_dir):
        continue
    os.makedirs(_needed_dir)
    print(f"Created folder '{_needed_dir}'.")


In [None]:
print(f"Training {model_name}_{dataset}_{side}.")


def _load_gzipped_array(path):
    """Load and return the NumPy array stored in the gzip-compressed file at ``path``."""
    # The context manager closes the gzip handle even when np.load raises.
    with gzip.GzipFile(path, "r") as handle:
        return np.load(handle)


# Inputs come from boards.npy.gz; the targets file is chosen by `side`.
board = _load_gzipped_array(f"{dataset}/boards.npy.gz")
labels = _load_gzipped_array(f"{dataset}/{side}.npy.gz")

# Inputs and targets must pair up one-to-one along the first axis.
assert board.shape[0] == labels.shape[0], (
    f"boards and labels disagree on sample count: {board.shape[0]} vs {labels.shape[0]}"
)

# The ViT model is built for channels-last input, so for that model only the
# axis at position 1 is moved to the end.
if model_name == 'vit':
    print("Shape : ", board.shape)  # shape before the axis move
    board = np.moveaxis(board, source=1, destination=-1)

# Hold out 10% of the samples for validation; seeded so the split is repeatable.
(X_train, X_validate,
 y_train, y_validate) = train_test_split(board, labels, test_size=0.1, random_state=SEED)

# Drop the references to the full arrays now that the split copies exist.
board = labels = None

# Wrap each split in a batch-serving generator for model.fit.
train_gen, valid_gen = (
    DataGenerator(features, targets, batch_size)
    for features, targets in ((X_train, y_train), (X_validate, y_validate))
)

# Output locations for this run's CSV log and saved model.
run_tag = f"{model_name}_{dataset}_{side}"
csv_log_path = f"{model_folder}/logs/{run_tag}.csv"
saved_model_path = f"{model_folder}/models/{run_tag}.tf"

# NOTE(review): both schedulers watch the *training* loss even though
# validation data is supplied -- confirm 'loss' rather than 'val_loss' is intended.
training_callbacks = [
    # Reduce the learning rate after 10 epochs without improvement.
    callbacks.ReduceLROnPlateau(monitor='loss', patience=10),
    # Stop after 15 epochs without a 1e-4 improvement and restore the best weights.
    callbacks.EarlyStopping(monitor='loss', patience=15, min_delta=1e-4, restore_best_weights=True),
    # Append per-epoch metrics to the run's CSV log.
    callbacks.CSVLogger(csv_log_path, separator=",", append=True),
]

history = model.fit(
    train_gen,
    epochs=num_epochs,
    validation_data=valid_gen,
    callbacks=training_callbacks,
)

model.save(saved_model_path, save_format='tf')
print(f"Saved model : {saved_model_path}")


In [None]:
# Record the settings of this run next to its outputs.
info_path = f"{model_folder}/{model_name}_{dataset}_{side}_information.txt"
# The context manager closes the file even if one of the writes fails.
with open(info_path, "w") as info_file:
    info_file.write("Model Name : {}\n".format(model_name))
    # The placeholder was missing here, so the dataset name was never written.
    info_file.write("Dataset used : {}\n".format(dataset))
    info_file.write("Side trained on : {}\n".format(side))
    info_file.write("Seed : {}\n".format(SEED))
    info_file.write("Batch Size : {}\n".format(batch_size))
print("Done!")