In [1]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten, Conv2D, MaxPooling2D
from tensorflow.keras.utils import to_categorical
from sklearn.metrics import confusion_matrix, classification_report

In [6]:
def ottieni_dati():
    """Load the MNIST dataset through ``mnist.load_data()``.

    Returns:
        tuple: ``(X_train, y_train, X_test, y_test)`` in that order, i.e. the
        train/test pairs returned by ``mnist.load_data()`` flattened into a
        single 4-tuple.
    """
    train_split, test_split = mnist.load_data()
    train_images, train_labels = train_split
    test_images, test_labels = test_split
    return train_images, train_labels, test_images, test_labels


(60000, 28, 28)

In [10]:
def target_handling(y_train, y_test, num_classes=10):
    """One-hot encode the training and test labels.

    The labels are taken as explicit arguments: reading them from the
    enclosing scope while also assigning to the same names inside the
    function makes Python treat them as locals and raise
    ``UnboundLocalError`` on every call.

    Args:
        y_train: Integer class labels for the training set.
        y_test: Integer class labels for the test set.
        num_classes: Total number of classes; defaults to 10.

    Returns:
        tuple: ``(y_train_encoded, y_test_encoded)`` as produced by
        ``to_categorical``.
    """
    y_train_encoded = to_categorical(y_train, num_classes=num_classes)
    y_test_encoded = to_categorical(y_test, num_classes=num_classes)
    return y_train_encoded, y_test_encoded


In [13]:
def model_construction():
    """Build and compile a small fully connected classifier.

    Architecture: ``Flatten`` -> ``Dense(64, relu)`` -> ``Dense(10, softmax)``,
    compiled with the Adam optimizer, categorical cross-entropy loss and the
    accuracy metric.

    Returns:
        Sequential: The compiled model, ready to be passed to ``fit``.
    """
    model = Sequential()
    model.add(Flatten())
    # Hidden layer
    model.add(Dense(64, activation='relu'))
    # Output layer: one unit per class, softmax yields class probabilities
    model.add(Dense(10, activation='softmax'))
    # Compile the model
    model.compile(optimizer='adam',
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])
    # Hand the compiled model back to the caller; without this the function
    # returned None and the model was lost.
    return model

In [14]:
def model_fit_predict_visualization(model,X_train,y_train,X_test,y_test):
    """Train the model, evaluate it on the test set and report the results.

    Steps performed:
      1. Fit ``model`` for 10 epochs (batch size 32, 10% validation split).
      2. Evaluate on the test set and print loss and accuracy.
      3. Predict on the test set and convert probabilities to class labels.
      4. Plot the confusion matrix as a heatmap.
      5. Print the classification report.

    Args:
        model: A compiled model exposing ``fit``, ``evaluate`` and ``predict``.
        X_train: Training inputs.
        y_train: Training targets (one-hot encoded, to match the
            categorical cross-entropy loss the model is compiled with in
            this notebook).
        X_test: Test inputs.
        y_test: Test targets; must be 2-D one-hot rows because they are
            reduced with ``argmax`` along axis 1.

    Returns:
        None. Results are printed and plotted.
    """
    model.fit(X_train, y_train,
              epochs=10,
              batch_size=32,
              validation_split=0.1)
    # Evaluate on the held-out test set
    test_loss, test_accuracy = model.evaluate(X_test, y_test)
    print(f'Perdita sul test set: {test_loss:.4f}')
    print(f'Accuratezza sul test set: {test_accuracy:.4f}')
    # Predict class probabilities
    predictions = model.predict(X_test)
    # Convert to labels: argmax must run per sample (axis=1); without an axis
    # it flattens the array and returns a single scalar index.
    predicted_classes = np.argmax(predictions, axis=1)
    true_classes = np.argmax(y_test, axis=1)
    # Use a distinct local name: assigning to `confusion_matrix` here would
    # shadow the imported function and raise UnboundLocalError.
    cm = confusion_matrix(true_classes, predicted_classes)
    # Confusion matrix heatmap
    fig, ax = plt.subplots(figsize=(10,8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax)
    ax.set_title('Matrice di Confusione')
    ax.set_xlabel('Predizione')
    ax.set_ylabel('Vero Valore')
    plt.show()
    # Per-class precision / recall / F1 report
    report = classification_report(true_classes, predicted_classes)
    print('Report di Classificazione:')
    print(report)