# --- Introduction ---
# This notebook demonstrates the usage of the MnistClassifier class
# For more information please check `Readme.md`

In [1]:
# --- Imports
import tensorflow as tf
import numpy as np
from sklearn.metrics import accuracy_score
from mnist_classifier import MnistClassifier
import matplotlib.pyplot as plt

In [2]:
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.mnist.load_data()
X_train = X_train.astype('float32') / 255.0
X_test = X_test.astype('float32') / 255.0

In [3]:
# --- Algorithm Loop ---
algorithms = ['rf', 'nn', 'cnn']
epochs_settings = {'rf': 2, 'nn': None, 'cnn': 2} # Let's use more epochs for better visualization

for algorithm in algorithms:
    print(f"\n### Training and Evaluating {algorithm.upper()} Model ###")
    mnist_classifier = MnistClassifier(algorithm=algorithm)
    epochs_value = epochs_settings[algorithm]

    print(f"Training {algorithm.upper()} model...")
    history = mnist_classifier.train(X_train, y_train, epochs=epochs_value) # Capture history

    print(f"Making predictions with {algorithm.upper()} model...")
    predictions = mnist_classifier.predict(X_test)
    accuracy = accuracy_score(y_test, predictions)

    print(f"{algorithm.upper()} Model Accuracy: {accuracy:.4f}")
    if epochs_value is not None and algorithm != 'rf': # Indicate epochs for NN/CNN
        print(f"(Trained for {epochs_value} epochs)")

        # Plotting Loss and Accuracy (for NN and CNN only)
        if history is not None: # Check if history is available (it will be for NN/CNN)
            plt.figure(figsize=(10, 4)) # Adjust figure size as needed

            # Plot Loss
            plt.subplot(1, 2, 1) # 1 row, 2 columns, first subplot
            plt.plot(history.history['loss'])
            plt.title(f'{algorithm.upper()} Model Loss')
            plt.xlabel('Epochs')
            plt.ylabel('Loss')

            # Plot Accuracy (if accuracy is in history - it usually is)
            if 'accuracy' in history.history: # 'accuracy' for tf>=2, 'acc' for older versions
                plt_accuracy_metric = 'accuracy' if 'accuracy' in history.history else 'acc'
                plt.subplot(1, 2, 2) # 1 row, 2 columns, second subplot
                plt.plot(history.history[plt_accuracy_metric])
                plt.title(f'{algorithm.upper()} Model Accuracy')
                plt.xlabel('Epochs')
                plt.ylabel('Accuracy')

            plt.tight_layout() # Adjust layout to prevent overlapping titles/labels
            plt.show() # Display the plot


### Training and Evaluating RF Model ###
Training RF model...
Making predictions with RF model...
RF Model Accuracy: 0.9704

### Training and Evaluating NN Model ###
Training NN model...


  super().__init__(**kwargs)


ValueError: Invalid epochs value: None. Epochs must be a positive integer.

# --- Conclusion ---
# Naturally: RF < NN < CNN 
# CNN 