In [None]:
import os
from keras.preprocessing.image import ImageDataGenerator
from keras.models import Sequential
from keras.layers import Dense, Flatten, Conv2D, MaxPool2D, BatchNormalization, Activation, Dropout, Input
from keras.utils import plot_model
from keras.callbacks import EarlyStopping
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from keras.callbacks import ReduceLROnPlateau

In [None]:
import seaborn as sns
import tensorflow as tf
import numpy as np

In [None]:
# Target (height, width) in pixels that every image is resized to when loaded.
input_length_h = 64
input_length_w = 64

In [None]:
# Rescale pixel values from [0, 255] to [0, 1].
data_generator = ImageDataGenerator(rescale = 1./255)

# One batch is meant to hold the whole dataset so the next cell can pull every
# image into memory with a single call. 551 is presumably the number of images
# under ./data/ — TODO confirm. class_mode='categorical' yields one-hot labels.
train_generator = data_generator.flow_from_directory(
    './data/',
    target_size = (input_length_h, input_length_w),
    batch_size = 551,
    class_mode='categorical',
    seed = 42,  # fixed shuffle seed so the loaded sample order is reproducible
)

# If the directory holds more images than one batch, loading a single batch
# would silently drop the rest; fail loudly instead.
if train_generator.samples > train_generator.batch_size:
    raise ValueError(
        f"Found {train_generator.samples} images but batch_size is "
        f"{train_generator.batch_size}; increase batch_size to load them all."
    )

In [None]:
# Pull a single batch from the directory iterator. With the batch size chosen in
# the loading cell, this is intended to be the entire dataset: X_data holds the
# rescaled images and y_data the matching labels.
X_data, y_data = next(train_generator)

In [None]:
# Hold out 10% of the images for evaluation. Stratifying on the labels keeps
# each class's proportion the same in both splits, and a fixed random_state
# makes the split reproducible across kernel restarts.
X_train, X_test, y_train, y_test = train_test_split(
    X_data, y_data,
    test_size = 0.1,
    stratify = y_data,
    random_state = 42,
)

In [None]:
# Small CNN: one convolutional block followed by a fully connected classifier.
model = Sequential()

# Block 1: two 3x3 convolutions ("same" padding keeps the spatial size), then 2x2 max pooling.
model.add(Conv2D(filters=64, kernel_size=(3,3), padding="same", input_shape = (input_length_h, input_length_w, 3), activation="relu", name = 'conv1'))
model.add(Conv2D(filters=64, kernel_size=(3,3), padding="same", activation="relu", name = 'conv2'))
model.add(MaxPool2D((2, 2), name = 'pool1'))

# Classifier head.
# NOTE(review): Flatten yields (H/2)*(W/2)*64 features (65,536 at 64x64), so
# dense1 alone has ~268M weights; more pooling or global pooling would shrink it.
model.add(Flatten())
model.add(Dense(4096, activation="relu", name = 'dense1'))
model.add(Dropout(0.6))
model.add(Dense(2048, activation="relu", name = 'dense2'))
model.add(Dropout(0.6))
# One softmax unit per class, taken from the one-hot label width rather than a
# hard-coded count, so the model follows whatever classes ./data/ contains.
model.add(Dense(y_train.shape[1], activation="softmax", name = 'output'))

# Compile with categorical cross-entropy to match the one-hot labels.
model.compile(loss = 'categorical_crossentropy', optimizer='adam', metrics = ['accuracy'])

In [None]:
# Train for 15 epochs with Keras' default batch size. The held-out split is
# passed as validation data, so the per-epoch val_* metrics are computed on the
# same images used for the confusion matrix; there is no separate validation set.
# NOTE(review): EarlyStopping / ReduceLROnPlateau are imported but not used here.
history = model.fit(
    x=X_train,
    y=y_train,
    validation_data=(X_test, y_test),
    epochs=15,
)

In [None]:
# Predicted and true class indices for the held-out split (labels are one-hot).
y_pred = np.argmax(model.predict(X_test), axis=-1)
y = np.argmax(y_test, axis = 1)

# Display names in the index order flow_from_directory assigned to the classes
# (presumably alphabetical subdirectory names — confirm against class_indices).
class_names = ['Bishop', 'King', 'Knight', 'Pawn', 'Queen', 'Rook']
assert len(class_names) == y_test.shape[1], "class_names must list one name per class"

# num_classes pins the matrix to len(class_names) x len(class_names) even when a
# class never appears in y or y_pred, so it always lines up with the tick labels.
conf = tf.math.confusion_matrix(labels=y, predictions=y_pred, num_classes=len(class_names))

fig, ax = plt.subplots(figsize=(10, 8))
# fmt='d' prints raw integer counts (the default '.2g' would show e.g. 120 as 1.2e+02).
sns.heatmap(conf, annot=True, fmt='d', cmap='Blues',
            xticklabels=class_names, yticklabels=class_names, ax=ax)
ax.set(title='Confusion matrix (held-out split)', xlabel='Predicted label', ylabel='True label')
plt.show()