In [None]:
from typing import List

import numpy as np
import seaborn as sns
import tensorflow as tf
from matplotlib import pyplot as plt
from sklearn.metrics import (accuracy_score, confusion_matrix, f1_score,
                             precision_score, recall_score)
from tensorflow import keras

from src.models.geometric_figure import GeometricFigure
from src.services.geometric_figure import (get_geometric_figures,
                                           get_train_test_validation_split)
from src.utils.files import create_directory_if_not_exists

In [None]:
# Image size handed to the dataset loader (exact width/height semantics are
# defined by get_geometric_figures — confirm there).
IMAGE_SIZE = (128, 128)
# Split ratios for the deterministic train/test/validation split.
# NOTE(review): presumably these must match the ratios used when the model was
# trained, otherwise "test" samples may have been seen during training — confirm.
TEST_RATIO = 0.2
VALIDATION_RATIO = 0.1
# Dataset snapshot under evaluation; it also selects the model directory and
# names the confusion-matrix output file.
DATA_VERSION = '2023-04-24'
# Model architecture directory and training-run timestamp of the saved model.
MODEL_NAME = 'CNN1'
MODEL_RUN = '2023-04-25 03-11-23'
# Derived from DATA_VERSION so the evaluated model can never silently point at a
# different dataset snapshot than the one being loaded.
MODEL_PATH = f'data/models/{DATA_VERSION}/{MODEL_NAME}/{MODEL_RUN}.h5'

In [None]:
# Load every figure of the configured dataset snapshot from data/<DATA_VERSION>,
# passing IMAGE_SIZE through to the loader.
geometric_figures: List[GeometricFigure] = get_geometric_figures(
    f'data/{DATA_VERSION}',
    IMAGE_SIZE,
)
print(f'Loaded {len(geometric_figures)} geometric figures')

In [None]:
# Restore the trained network saved at MODEL_PATH (a .h5 file); this notebook
# only uses it for inference on the test split.
model = keras.models.load_model(filepath=MODEL_PATH)

In [None]:
# shuffle=False keeps the split deterministic across runs, so the same samples
# land in the test set every time this notebook is executed.
# NOTE(review): only x_test / y_test are used below; presumably the split must
# reproduce the one used at training time — confirm.
(x_train, y_train,
 x_test, y_test,
 x_validation, y_validation) = get_train_test_validation_split(
    geometric_figures,
    shuffle=False,
    validation_ratio=VALIDATION_RATIO,
    test_ratio=TEST_RATIO,
)

In [None]:
# Raw model outputs for the test inputs; axis 1 is treated as the class axis.
predictions = model.predict(x_test)
# Index of the highest-scoring class for each sample, as a NumPy array so it
# can be fed directly to the sklearn metrics below.
y_test_predicted = tf.math.argmax(predictions, axis=1).numpy()

In [None]:
# Overall accuracy plus support-weighted precision, recall and F1.
# Note: for single-label multiclass data, weighted-average recall is
# mathematically identical to accuracy, so "Sensibilidade" will always equal
# "Acurácia"; average='macro' would give a class-balanced view instead.
accuracy = accuracy_score(y_test, y_test_predicted)
precision, recall, f1 = (
    scorer(y_test, y_test_predicted, average='weighted')
    for scorer in (precision_score, recall_score, f1_score)
)

print(
    f'Acurácia: {accuracy:.2%}',
    f'Precisão: {precision:.2%}',
    f'Sensibilidade: {recall:.2%}',
    f'F1: {f1:.2%}',
    sep='\n',
)

In [None]:
# Human-readable class names, indexed by the integer class id.
# NOTE(review): assumes the dataset encodes classes as 0..5 in exactly this
# order (matching the model's output units) — confirm in get_geometric_figures.
labels = ['circle', 'square', 'triangle', 'failed circle', 'failed square', 'failed triangle']
# Pin the matrix to every class id. Without `labels=`, sklearn only includes
# classes that occur in y_test or the predictions, so a missing class would
# shrink the matrix and shift/misassign the tick labels below.
confusion = confusion_matrix(y_test, y_test_predicted, labels=list(range(len(labels))))

fig, ax = plt.subplots(figsize=(8, 6))
sns.heatmap(confusion, annot=True, fmt='d', cmap='Blues', xticklabels=labels, yticklabels=labels, ax=ax)
ax.set_title(f'Matriz de confusão - {DATA_VERSION}')
ax.set_xlabel('Classe Prevista')
ax.set_ylabel('Classe Verdadeira')

path = f'data/confusion_matrix/{DATA_VERSION}.svg'
# NOTE(review): receives the full file path; presumably the helper creates the
# parent directory rather than a directory named after the file — confirm.
create_directory_if_not_exists(path)
fig.savefig(path, bbox_inches='tight', dpi=300)
plt.show()