# In [None]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

from ffnn import FFNN  # Import FFNN implementation

# Load MNIST dataset
def load_mnist_dataset(test_size=0.2, random_state=42, num_classes=10):
    """Fetch MNIST from OpenML and return a normalized train/test split.

    Pixel values are cast to float32 and scaled by 1/255. Labels are cast
    to ints. One-hot label matrices are built for both splits.

    Args:
        test_size: Fraction of samples held out for the test split
            (forwarded to ``train_test_split``).
        random_state: Seed forwarded to ``train_test_split`` so the split
            is reproducible.
        num_classes: Width of the one-hot encoding. The default of 10
            matches the ten MNIST digit classes.

    Returns:
        Tuple ``(X_train, X_test, y_train, y_test, y_train_onehot,
        y_test_onehot)``. ``y_train``/``y_test`` hold integer labels;
        the ``*_onehot`` arrays are rows of ``np.eye(num_classes)``
        selected by those labels, i.e. shape ``(n_samples, num_classes)``.

    Raises:
        IndexError: if any label is >= ``num_classes``.
    """
    mnist = fetch_openml('mnist_784', version=1, as_frame=False, parser='auto')
    X = mnist.data.astype('float32') / 255.0  # Normalize pixels to [0, 1]
    # Targets may come back from OpenML as strings; cast so they can be
    # used as integer row indices into the identity matrix below.
    y = mnist.target.astype('int')

    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=test_size, random_state=random_state
    )

    # Indexing the identity matrix by label yields the one-hot rows.
    identity = np.eye(num_classes)
    y_train_onehot = identity[y_train]
    y_test_onehot = identity[y_test]
    return X_train, X_test, y_train, y_test, y_train_onehot, y_test_onehot

# Visualize FFNN structure
def visualize_ffnn_structure(model):
    """Draw the model's layer/neuron structure with weights emphasized.

    Announces the step on stdout, then delegates to
    ``model.visualize_network_structure`` with weight highlighting turned
    on and gradient highlighting turned off.
    """
    print("Visualizing Network Structure...")
    render_options = {
        "highlight_weights": True,
        "highlight_gradients": False,
    }
    model.visualize_network_structure(**render_options)

# Visualize weight distribution
def visualize_weight_distribution(model):
    """Plot weight histograms for every layer of ``model``.

    Announces the step on stdout, then passes the index range
    ``range(len(model.layers))`` to ``model.plot_weight_distribution`` so
    that all layers are included.
    """
    print("Visualizing Weight Distribution...")
    layer_count = len(model.layers)
    every_layer = range(layer_count)
    model.plot_weight_distribution(layers=every_layer)

# Visualize gradient distribution
def visualize_gradient_distribution(model):
    """Plot weight-gradient histograms for every layer of ``model``.

    Announces the step on stdout, then hands the full layer index range
    to ``model.plot_weight_gradient_distribution``.
    """
    print("Visualizing Gradient Distribution...")
    model.plot_weight_gradient_distribution(
        layers=range(len(model.layers)),  # indices 0..n-1: all layers
    )

# Main function to test visualizations
def main():
    """Train an FFNN on MNIST, then run the three visualization helpers.

    Steps:
      1. Load and split MNIST via ``load_mnist_dataset``.
      2. Build an ``FFNN`` with layer sizes 784-128-64-10, activations
         relu/relu/softmax, ``loss_func='cce'``, ``weight_init='uniform'``
         and learning rate 0.01.
      3. Train it for 10 epochs with batch size 64. The test split is
         also passed to ``train`` (presumably as validation data; confirm
         against ``FFNN.train``).
      4. Visualize the network structure, then the weight distribution,
         then the weight-gradient distribution.
    """
    # Only the one-hot targets are needed for training, so the integer
    # label arrays (3rd and 4th return values) are discarded.
    X_train, X_test, _, _, y_train_onehot, y_test_onehot = load_mnist_dataset()

    # Initialize FFNN: one activation per weight layer (3 for 4 layer sizes).
    model = FFNN(
        layer_sizes=[784, 128, 64, 10],
        activation_func=['relu', 'relu', 'softmax'],
        loss_func='cce',
        weight_init='uniform',
        learning_rate=0.01
    )

    # Train FFNN
    print("Training FFNN...")
    model.train(X_train, y_train_onehot, X_test, y_test_onehot, epochs=10, batch_size=64, verbose=1)

    # Visualize network structure
    visualize_ffnn_structure(model)

    # Visualize weight distribution
    visualize_weight_distribution(model)

    # Visualize gradient distribution
    visualize_gradient_distribution(model)

# Entry point: run the training/visualization demo only when this module is
# executed directly, not when it is imported.
if __name__ == "__main__":
    main()