In [None]:
from copy import deepcopy
from datetime import datetime
from typing import Callable, Dict, List

import numpy as np
import tensorflow as tf
from tensorflow import keras
from termcolor import colored

from src.models.geometric_figure import GeometricFigure
from src.services.geometric_figure import (get_geometric_figures,
                                           get_train_test_validation_split)
from src.services.result import save_model_results

In [None]:
# Experiment configuration: every tunable value used by the cells below.

# --- Data -------------------------------------------------------------------
# Sub-directories of data/ that are loaded and trained on, one after another.
DATA_VERSIONS: list = ['2023-03-22', '2023-04-02', '2023-04-03']
# Passed to get_geometric_figures and used for the models' input height/width.
IMAGE_SIZE: tuple = (128, 128)

# --- Split (forwarded to get_train_test_validation_split) -------------------
TEST_RATIO: float = 0.2
VALIDATION_RATIO: float = 0.1

# --- Training ---------------------------------------------------------------
# How many independent split/train/evaluate runs per data version.
NUMBER_OF_REPETITIONS: int = 10
EPOCHS: int = 100

# --- Output -----------------------------------------------------------------
# File that save_model_results writes each run's accuracy to.
MODEL_RESULTS_PATH: str = 'data/results.json'

In [None]:
# Input shape shared by every architecture below. The trailing 1 is a single channel
# (Conv2D/MaxPooling2D default to channels-last).
MODEL_INPUT_SHAPE = (IMAGE_SIZE[0], IMAGE_SIZE[1], 1)
# Units in each model's final softmax layer, i.e. the number of predicted classes.
NUMBER_OF_CLASSES = 3

# Candidate architectures, keyed by the short name used in the saved-model path and
# in the results file.
#
# Each value is a zero-argument factory, NOT a model instance. Every call returns a
# freshly initialised, uncompiled model, so each training run starts from scratch
# instead of continuing to fit an already-trained model.
#
#   P1, P2     - multilayer perceptrons on the flattened image (P2 drops the 128-unit layer).
#   CNN1, CNN2 - two Conv2D + MaxPooling2D stages followed by a small dense head.
#                CNN2 differs from CNN1 only by using 64 filters in its second conv.
ai_models: Dict[str, Callable[[], keras.Model]] = {
    'P1': lambda: keras.Sequential([
        keras.layers.InputLayer(input_shape=MODEL_INPUT_SHAPE),
        keras.layers.Flatten(),
        keras.layers.Dense(128, activation='relu'),
        keras.layers.Dense(64, activation='relu'),
        keras.layers.Dense(32, activation='relu'),
        keras.layers.Dense(NUMBER_OF_CLASSES, activation='softmax')
    ]),
    'P2': lambda: keras.Sequential([
        keras.layers.InputLayer(input_shape=MODEL_INPUT_SHAPE),
        keras.layers.Flatten(),
        keras.layers.Dense(64, activation='relu'),
        keras.layers.Dense(32, activation='relu'),
        keras.layers.Dense(NUMBER_OF_CLASSES, activation='softmax')
    ]),
    'CNN1': lambda: keras.Sequential([
        keras.layers.InputLayer(input_shape=MODEL_INPUT_SHAPE),
        keras.layers.Conv2D(32, (3, 3), activation='relu'),
        keras.layers.MaxPooling2D((4, 4)),
        keras.layers.Conv2D(32, (3, 3), activation='relu'),
        keras.layers.MaxPooling2D((4, 4)),
        keras.layers.Flatten(),
        keras.layers.Dense(8, activation='relu'),
        keras.layers.Dense(NUMBER_OF_CLASSES, activation='softmax')
    ]),
    'CNN2': lambda: keras.Sequential([
        keras.layers.InputLayer(input_shape=MODEL_INPUT_SHAPE),
        keras.layers.Conv2D(32, (3, 3), activation='relu'),
        keras.layers.MaxPooling2D((4, 4)),
        keras.layers.Conv2D(64, (3, 3), activation='relu'),
        keras.layers.MaxPooling2D((4, 4)),
        keras.layers.Flatten(),
        keras.layers.Dense(8, activation='relu'),
        keras.layers.Dense(NUMBER_OF_CLASSES, activation='softmax')
    ]),
}

In [None]:
# Read every configured data version from disk once. The training cell looks the
# figures up by version instead of reloading them for each repetition.
geometric_figures_data_versions_dict: Dict[str, List[GeometricFigure]] = dict(
    (version, get_geometric_figures(f'data/{version}', IMAGE_SIZE))
    for version in DATA_VERSIONS
)

# Report how many figures were loaded per version, as a quick sanity check of the load.
for version, geometric_figures in geometric_figures_data_versions_dict.items():
    figure_count = colored(len(geometric_figures), "green")
    version_label = colored(version, "green")
    print(f'Loaded {figure_count} geometric figures for version {version_label}')

In [None]:
# The augmentation settings do not depend on any loop variable, so the generator is
# configured once. Only the iterator over each repetition's training split changes.
data_generator = keras.preprocessing.image.ImageDataGenerator(
    rotation_range=360,
    width_shift_range=0.1,
    height_shift_range=0.1,
    horizontal_flip=True,
    vertical_flip=True,
)


def _train_evaluate_and_save(
    name: str,
    create_model: Callable[[], keras.Model],
    train_generator,
    validation_data,
    x_test,
    y_test,
    data_version: str,
) -> None:
    """Build a fresh model, train it, evaluate it on the test split and persist it.

    Args:
        name: Architecture key from ``ai_models``. Used in log lines, the saved-model
            path and the results entry.
        create_model: Zero-argument factory returning an uncompiled Keras model.
        train_generator: Augmented training iterator passed straight to ``model.fit``.
        validation_data: ``(x, y)`` pair passed to ``model.fit`` for per-epoch validation.
        x_test, y_test: Held-out test split passed to ``model.evaluate``.
        data_version: Name of the data version being trained on. Used in the save
            path and the results entry.

    Side effects:
        Writes ``data/models/<data_version>/<name>/<timestamp>.h5`` and records the
        test accuracy via ``save_model_results`` into ``MODEL_RESULTS_PATH``.
    """
    # Drop the global Keras state left by the previous model. Otherwise building many
    # models in one kernel steadily increases memory use (per the clear_session docs).
    keras.backend.clear_session()
    model = create_model()
    model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

    print(f'Training {name}')
    model.fit(train_generator, epochs=EPOCHS, validation_data=validation_data)

    print(f'Evaluating {name}')
    _, accuracy = model.evaluate(x_test, y_test)  # test loss is not recorded

    print(f'Saving {name}')
    # The timestamp is taken after training, so it marks when this model finished.
    # It is also the key that links the saved file to its results entry.
    datetime_now = datetime.now().strftime('%Y-%m-%d %H-%M-%S')
    model.save(f'data/models/{data_version}/{name}/{datetime_now}.h5')
    save_model_results(name, accuracy, datetime_now, data_version, MODEL_RESULTS_PATH)


# For each data version, draw NUMBER_OF_REPETITIONS fresh splits and train every
# architecture on the same split, so architectures are compared on identical data.
# NOTE(review): no random seeds are set, so splits and results differ between runs
# of this notebook.
for data_version in DATA_VERSIONS:
    geometric_figures = geometric_figures_data_versions_dict[data_version]
    for i in range(NUMBER_OF_REPETITIONS):
        x_train, y_train, x_test, y_test, x_validation, y_validation = get_train_test_validation_split(
            geometric_figures,
            test_ratio=TEST_RATIO,
            validation_ratio=VALIDATION_RATIO
        )
        train_generator = data_generator.flow(x_train, y_train)
        for name, create_model in ai_models.items():
            _train_evaluate_and_save(
                name,
                create_model,
                train_generator,
                (x_validation, y_validation),
                x_test,
                y_test,
                data_version,
            )