In [None]:
from copy import deepcopy
from typing import List

import numpy as np
from matplotlib import pyplot as plt
from tensorflow import keras

from src.models.geometric_figure import GeometricFigure
from src.services.geometric_figure import (get_geometric_figures,
                                           get_input_and_output,
                                           plot_geometric_figures,
                                           prediction_to_category,
                                           get_train_test_validation_split)

In [None]:
# Run configuration: every value a reader might want to tune lives in this cell.
IMAGE_SIZE = (128, 128)        # passed to the figure loader; also defines the model input shape
TEST_RATIO = 0.2               # forwarded as `test_ratio` to the split helper
VALIDATION_RATIO = 0.1         # forwarded as `validation_ratio` to the split helper
DATA_VERSION = '2023-04-02'    # sub-directory of data/ that holds the figures to load
EPOCHS = 50                    # number of training epochs passed to model.fit
SEED = 42                      # single seed for every random number generator used below

# Seed Python's `random`, NumPy and TensorFlow in one call so that weight
# initialisation, augmentation and the random preview sample are repeatable
# after Restart Kernel -> Run All.
# NOTE(review): the project split helper is only covered if it draws from one of
# these global generators - confirm.
keras.utils.set_random_seed(SEED)

In [None]:
# Load the figures of the configured data version at the configured image size.
# `memorize=True` is forwarded to the loader as-is.
# NOTE(review): presumably it enables some form of caching - confirm in
# src.services.geometric_figure.
geometric_figures: List[GeometricFigure] = get_geometric_figures(
    f'data/{DATA_VERSION}',
    IMAGE_SIZE,
    memorize=True,
)
print(f'Loaded {len(geometric_figures)} geometric figures')

In [None]:
# Fully connected classifier: the single-channel image is flattened, passed
# through three ReLU layers of decreasing width and mapped to a 3-way softmax.
model = keras.Sequential([
    keras.layers.InputLayer(input_shape=(*IMAGE_SIZE, 1)),
    keras.layers.Flatten(),
    *(keras.layers.Dense(units, activation='relu') for units in (128, 64, 32)),
    keras.layers.Dense(3, activation='softmax'),
])

# The sparse variant of categorical cross-entropy takes integer class labels
# rather than one-hot vectors.
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

In [None]:
# Print the layer-by-layer overview of the model built above (output shapes and
# parameter counts) as a quick check before training.
model.summary()

In [None]:
# Partition the loaded figures into training, test and validation arrays using
# the ratios from the configuration cell.
(x_train, y_train,
 x_test, y_test,
 x_validation, y_validation) = get_train_test_validation_split(
    geometric_figures, test_ratio=TEST_RATIO, validation_ratio=VALIDATION_RATIO)

# Report how many samples ended up in each partition (first axis = sample count).
print(
    f'x_train size: {x_train.shape[0]}',
    f'x_test size: {x_test.shape[0]}',
    f'x_validation size: {x_validation.shape[0]}',
    sep='\n',
)

In [None]:
# Augment the training images on the fly: arbitrary rotation, shifts of up to
# 10% of the width/height, and random horizontal/vertical flips.
data_generator = keras.preprocessing.image.ImageDataGenerator(
    rotation_range=360,
    width_shift_range=0.1,
    height_shift_range=0.1,
    horizontal_flip=True,
    vertical_flip=True,
)
train_generator = data_generator.flow(x_train, y_train)

# Only the training stream is augmented; the validation arrays are passed
# through unchanged. The returned History is kept so the per-epoch loss and
# accuracy curves can be inspected afterwards, and so the cell does not end
# with a bare object repr.
history = model.fit(
    train_generator,
    epochs=EPOCHS,
    validation_data=(x_validation, y_validation),
)

In [None]:
# Score the trained model on the held-out test split; the cell output is the
# value returned by evaluate (the loss plus the metrics set in model.compile).
model.evaluate(x_test, y_test)

In [None]:
# Visual sanity check: show a random sample of figures with the predicted category.
number_of_plots = 15

# Draw distinct indices so no figure appears twice in the grid, and cap the
# sample size so a dataset smaller than the grid does not make the draw fail.
sample_size = min(number_of_plots, len(geometric_figures))
sample_indices = np.random.choice(len(geometric_figures), size=sample_size, replace=False)

# Deep-copy the sampled figures: their 'category' is overwritten below and the
# loaded dataset must stay untouched.
random_geometric_figures = deepcopy([geometric_figures[i] for i in sample_indices])

# Build one batch and predict once instead of once per figure; verbose=0 keeps
# the cell output free of progress bars. Only the model input is needed here,
# so the expected output returned by get_input_and_output is discarded.
x_batch = np.stack([get_input_and_output(gf)[0] for gf in random_geometric_figures])
y_pred_batch = model.predict(x_batch, verbose=0)

for gf, y_pred in zip(random_geometric_figures, y_pred_batch):
    # Restore the leading batch axis so prediction_to_category receives the
    # same (1, n_classes) shape it got from a single-sample predict call.
    gf['category'] = prediction_to_category(np.expand_dims(y_pred, axis=0))

plot_geometric_figures(random_geometric_figures, columns=5, plot_size=3)
plt.show()

In [None]:
# Persist the trained model to a single file under data/ (the .h5 suffix makes
# Keras write its HDF5 format).
# NOTE(review): the file name is not tied to DATA_VERSION, so each run
# overwrites the previous export - confirm that is intended.
model.save('data/geometric_figure_classifier.h5')