In [None]:
import math
from copy import deepcopy
from typing import List, Tuple

import numpy as np
from matplotlib import pyplot as plt
from tensorflow import keras

from src.models.geometric_figure import GeometricFigure
from src.services.geometric_figure import (get_geometric_figures,
                                           plot_geometric_figures,
                                           prediction_to_category,
                                           get_input_and_output,
                                           plot_geometric_figures_processed,
                                           geometric_figure_to_category_number)

In [None]:
# Resolution handed to get_geometric_figures; presumably it must match the
# resolution the model was trained on — TODO confirm.
IMAGE_SIZE = (128, 128)
# Dataset snapshot to evaluate. It also selects the model directory below, so the
# evaluated model cannot silently come from a different data version.
DATA_VERSION = '2023-04-03'
# Model architecture folder and the timestamped training run inside it.
MODEL_ARCHITECTURE = 'CNN2'
MODEL_RUN = '2023-04-10 09-49-03'
MODEL_PATH = f'data/models/{DATA_VERSION}/{MODEL_ARCHITECTURE}/{MODEL_RUN}.h5'

In [None]:
# Load every figure of the selected data version from data/<DATA_VERSION>,
# passing IMAGE_SIZE to the loader (presumably the target resize — verify).
geometric_figures: List[GeometricFigure] = get_geometric_figures(
    f'data/{DATA_VERSION}',
    IMAGE_SIZE,
)
print(f'Loaded {len(geometric_figures)} geometric figures')

In [None]:
# Restore the trained Keras model saved at MODEL_PATH (HDF5 file).
model = keras.models.load_model(filepath=MODEL_PATH)

In [None]:
# Split each figure's (input, expected output) pair into two stacked arrays,
# x for the model inputs and y for the targets.
x, y = map(
    np.array,
    zip(*(get_input_and_output(gf) for gf in geometric_figures)),
)
# The unpacking assumes the model was compiled with exactly one metric
# (accuracy), so evaluate returns [loss, accuracy] — confirm if metrics change.
loss, accuracy = model.evaluate(x, y)
print(f'Loss: {loss:.2f}')
print(f'Accuracy: {accuracy:.2%}')

In [None]:
# Run inference once over the whole set. Keep only the figures whose highest-
# scoring class differs from their true category number, each paired with the
# category returned by prediction_to_category for the model's output.
predictions = model.predict(x)
geometric_figures_error_predictions: List[Tuple[GeometricFigure, str]] = [
    (figure, prediction_to_category(prediction))
    for figure, prediction in zip(geometric_figures, predictions)
    if np.argmax(prediction) != geometric_figure_to_category_number(figure)
]
print(f'Predicted {len(geometric_figures_error_predictions)} geometric figures with error')

In [None]:
# Layout values passed positionally to both plot helpers; presumably the
# per-subplot size and the number of grid columns — confirm against the helpers.
plot_size = 3
columns = 7
# Copy each misclassified figure and overwrite its 'category' with the model's
# prediction, so the plots below are labelled with what the model predicted.
# deepcopy keeps the original geometric_figures list untouched.
error_geometric_figures: List[GeometricFigure] = []
for gf, predicted_category in geometric_figures_error_predictions:
    new_gf = deepcopy(gf)
    new_gf['category'] = predicted_category
    error_geometric_figures.append(new_gf)
if error_geometric_figures:
    # First the figures as loaded, then via plot_geometric_figures_processed
    # (presumably the preprocessed model inputs), rendered in grayscale.
    plot_geometric_figures(error_geometric_figures, columns, plot_size)
    plt.show()
    plot_geometric_figures_processed(error_geometric_figures, columns, plot_size, cmap='gray')
    plt.show()