In [None]:
from copy import deepcopy
from typing import List

import numpy as np
from matplotlib import pyplot as plt
from tensorflow import keras

from src.models.geometric_figure import GeometricFigure
from src.services.geometric_figure import (get_geometric_figures,
                                           get_input_and_output,
                                           plot_geometric_figures,
                                           prediction_to_category)

In [None]:
# --- Notebook configuration -------------------------------------------------
# Image size handed to `get_geometric_figures`; also used as the first two
# dimensions of the model's input shape.
IMAGE_SIZE = (512, 512)
# Fraction of the shuffled samples held out as the test split.
TEST_RATIO = 0.2
# Name of the dataset snapshot directory under `data/`.
DATA_VERSION = '2023-04-02'
# Seed for every random number generator used in this notebook.
SEED = 42

# Seed Python's `random`, NumPy and TensorFlow in one call so that the shuffle,
# the weight initialisation and the sampled preview are reproducible under
# "Restart Kernel -> Run All".
keras.utils.set_random_seed(SEED)

In [None]:
# Load all figures of the pinned dataset snapshot at the configured image size.
# NOTE(review): `memorize=True` presumably enables some caching inside the
# loader -- confirm against `src.services.geometric_figure`.
geometric_figures: List[GeometricFigure] = get_geometric_figures(
    f'data/{DATA_VERSION}',
    IMAGE_SIZE,
    memorize=True,
)
print(f'Loaded {len(geometric_figures)} geometric figures')

In [None]:
# Shuffle a copy so the order of `geometric_figures` (used again for the
# prediction preview) is left untouched.
shuffled_geometric_figures: List[GeometricFigure] = deepcopy(geometric_figures)
np.random.shuffle(shuffled_geometric_figures)

# Turn every figure into an (input, output) pair and stack both sides into arrays.
x, y = zip(*[get_input_and_output(gf) for gf in shuffled_geometric_figures])
x = np.array(x)
y = np.array(y)

# Hold out the last `TEST_RATIO` share of the shuffled samples for testing.
# Slice on an explicit split index rather than on `-test_size`: when
# `test_size` rounds down to 0, `x[:-0]` is empty and `x[-0:]` is the whole
# array, which would silently move every sample into the test set.
test_size = int(len(x) * TEST_RATIO)
split_index = len(x) - test_size
x_train, x_test = x[:split_index], x[split_index:]
y_train, y_test = y[:split_index], y[split_index:]

# Sanity check: the two splits partition the dataset.
assert len(x_train) + len(x_test) == len(x)
assert len(y_train) + len(y_test) == len(y)

In [None]:
# Fully connected classifier: the single-channel image is flattened and passed
# through two ReLU hidden layers into a 3-way softmax output.
model = keras.Sequential(
    layers=[
        keras.layers.InputLayer(input_shape=(*IMAGE_SIZE, 1)),
        keras.layers.Flatten(),
        keras.layers.Dense(units=64, activation='relu'),
        keras.layers.Dense(units=32, activation='relu'),
        keras.layers.Dense(units=3, activation='softmax'),
    ]
)

# The sparse categorical loss expects integer class labels rather than one-hot
# vectors. NOTE(review): assumes `get_input_and_output` yields integer labels
# in {0, 1, 2} -- confirm.
model.compile(
    loss='sparse_categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)

In [None]:
# Print Keras' textual summary of `model` so the architecture can be inspected
# before training.
model.summary()

In [None]:
# Train for 20 epochs on the training split; Keras sets aside 20% of the
# training samples for validation.
model.fit(
    x=x_train,
    y=y_train,
    validation_split=0.2,
    epochs=20,
)

In [None]:
# Score the trained model on the held-out test split; the returned values are
# shown as the cell output.
model.evaluate(x=x_test, y=y_test)

In [None]:
# Number of randomly chosen figures to classify and display.
number_of_plots = 15

# Draw distinct indices so the grid never shows the same figure twice; the
# sample size is capped at the dataset size because sampling without
# replacement cannot return more items than exist.
sample_indices = np.random.choice(
    len(geometric_figures),
    size=min(number_of_plots, len(geometric_figures)),
    replace=False,
)

# Deep-copy the selection: the loop below overwrites each figure's 'category'
# with the model's prediction, and the loaded dataset must keep its own values.
random_geometric_figures = deepcopy([geometric_figures[i] for i in sample_indices])

for gf in random_geometric_figures:
    # Use dedicated names instead of `x` / `y` so the dataset arrays created
    # by the train/test split cell are not clobbered by a single sample.
    sample_input, _true_output = get_input_and_output(gf)
    # The model expects a batch dimension, so wrap the sample in a batch of one.
    sample_batch = np.expand_dims(sample_input, axis=0)
    y_pred = model.predict(sample_batch)
    category = prediction_to_category(y_pred)
    gf['category'] = category

plot_geometric_figures(random_geometric_figures, columns=5, plot_size=3)
plt.show()

In [None]:
# Persist the trained model to an HDF5 file under `data/`.
model.save(filepath='data/geometric_figure_classifier.h5')