# Benchmarking Classical vs Genetic Neural Networks

This notebook benchmarks classical fully connected feed-forward neural networks against networks built with `GeneticLayer` on several scikit-learn classification datasets. These include harder ones such as the Olivetti faces and a 10,000-sample subset of the forest cover types (covtype) dataset.

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from sklearn.datasets import (
    fetch_covtype,
    fetch_olivetti_faces,
    load_breast_cancer,
    load_digits,
    load_iris,
    load_wine,
    make_classification,
)
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader, TensorDataset

from examples.utils.network_factory import (
    create_classical_network,
    create_genetic_network,
)
from examples.utils.trainer_utils import train_and_evaluate

In [None]:
# Load and preprocess datasets.
# Every loader below returns a scikit-learn Bunch exposing `.data` / `.target`,
# except make_classification, which returns an `(X, y)` tuple.
datasets = {
    "digits": load_digits(),
    "iris": load_iris(),
    "wine": load_wine(),
    "breast_cancer": load_breast_cancer(),
    "synthetic": make_classification(
        n_samples=1500,
        n_features=50,
        n_informative=30,
        n_redundant=10,
        n_classes=7,
        random_state=42,
    ),
    "covtype": fetch_covtype(),
    "olivetti": fetch_olivetti_faces(),
}

SPLIT_SEED = 42  # seeds both the covtype subsample and the train/test split
TEST_SIZE = 0.2
BATCH_SIZE = 32
NUM_WORKERS = 4  # set to 0 to load batches in the main process
COVTYPE_SUBSAMPLE = 10_000  # covtype is large; subsample for manageability

processed_datasets = {}
for name, data in datasets.items():
    if isinstance(data, tuple):  # make_classification returns (X, y)
        X, y = data
    else:
        X, y = data.data, data.target

    if name == "covtype":
        # Draw a seeded random subsample instead of the first N rows, so the
        # subset does not depend on the row order of the source file.
        rng = np.random.default_rng(SPLIT_SEED)
        keep = rng.choice(len(X), size=min(COVTYPE_SUBSAMPLE, len(X)), replace=False)
        X, y = X[keep], y[keep]

    # Remap labels to contiguous indices 0..K-1. Covtype targets are 1..7, and
    # class-index losses (e.g. CrossEntropyLoss, presumably what
    # train_and_evaluate uses -- TODO confirm) require labels in [0, K).
    # Datasets whose labels are already 0..K-1 are unchanged by this.
    classes, y = np.unique(y, return_inverse=True)

    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=TEST_SIZE, random_state=SPLIT_SEED
    )
    # Fit the scaler on the training split only, so test-set statistics do
    # not leak into preprocessing; then apply the same transform to both.
    scaler = StandardScaler().fit(X_train)
    X_train = torch.tensor(scaler.transform(X_train), dtype=torch.float32)
    X_test = torch.tensor(scaler.transform(X_test), dtype=torch.float32)
    y_train = torch.tensor(y_train, dtype=torch.long)
    y_test = torch.tensor(y_test, dtype=torch.long)

    # persistent_workers=True is only valid when worker processes are used.
    loader_kwargs = dict(
        batch_size=BATCH_SIZE,
        num_workers=NUM_WORKERS,
        persistent_workers=NUM_WORKERS > 0,
    )
    train_loader = DataLoader(TensorDataset(X_train, y_train), shuffle=True, **loader_kwargs)
    test_loader = DataLoader(TensorDataset(X_test, y_test), **loader_kwargs)

    processed_datasets[name] = {
        "input_size": X.shape[1],
        "output_size": len(classes),
        "train_loader": train_loader,
        "val_loader": test_loader,
    }

In [None]:
# Sizes to sweep per architecture: the depth argument passed to
# create_classical_network, and the head count passed to create_genetic_network.
classical_depths = list(range(2, 10, 2))
genetic_heads = list(range(1, 5))

# Epoch budget handed to train_and_evaluate for every run.
max_epochs = 64

# Run metrics, keyed as f"{dataset}_{architecture}_{size}" by the training loop.
results = dict()

In [None]:
# Train every (dataset, architecture, size) combination.
# Results are keyed as f"{dataset}_{architecture}_{size}" so the plotting and
# summary cells can look them up.
#
# The global torch RNG is reseeded before each model is built. This makes
# initial weights and shuffle order (the DataLoaders were created without an
# explicit generator) reproducible, and independent of how many models were
# trained earlier in the loop. This assumes train_and_evaluate does not reseed
# on its own -- TODO confirm.
TRAIN_SEED = 42

# (architecture name, network factory, sizes to sweep).
# Each factory is called as factory(input_size, size, output_size).
architecture_sweeps = [
    ("classical", create_classical_network, classical_depths),
    ("genetic", create_genetic_network, genetic_heads),
]

for dataset_name, data in processed_datasets.items():
    for arch, build_network, sizes in architecture_sweeps:
        for size in sizes:
            torch.manual_seed(TRAIN_SEED)
            model = build_network(data["input_size"], size, data["output_size"])
            results[f"{dataset_name}_{arch}_{size}"] = train_and_evaluate(
                model,
                data["train_loader"],
                data["val_loader"],
                max_epochs=max_epochs,
            )

In [None]:
# Plot comparisons. For each dataset there are two panels:
# - left: loss curves (solid = training, dashed = validation);
# - right: validation balanced accuracy.
# Colour encodes the architecture. Marker shape encodes the size's position
# within that architecture's sweep, so "classical #1" and "genetic #1" share
# a shape.
datasets_list = list(processed_datasets.keys())
models = [("classical", depth) for depth in classical_depths] + [
    ("genetic", n_heads) for n_heads in genetic_heads
]

ARCH_COLORS = {"classical": "green", "genetic": "blue"}
SIZE_LABELS = {"classical": "depth", "genetic": "heads"}
MARKER_CYCLE = ["o", "s", "^", "D"]
MARK_EVERY = 5  # one marker every 5 epochs keeps dense curves readable

for dataset in datasets_list:
    fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(20, 5))
    title = dataset.capitalize()

    ax_loss.set_title(f"{title} - Training and Validation Loss")
    ax_loss.set_xlabel("Epoch")
    ax_loss.set_ylabel("Loss")

    ax_acc.set_title(f"{title} - Balanced Accuracy")
    ax_acc.set_xlabel("Epoch")
    ax_acc.set_ylabel("Balanced Accuracy")

    position_in_arch = {}
    for arch, size in models:
        # Index of this size within its own architecture's sweep.
        rank = position_in_arch.get(arch, 0)
        position_in_arch[arch] = rank + 1

        metrics = results[f"{dataset}_{arch}_{size}"]
        train_loss = metrics["train_loss_history"]
        n_epochs = len(train_loss)
        # Only plot validation values that have a matching training epoch.
        val_loss = metrics["val_loss_history"][:n_epochs]
        val_bal_acc = metrics["val_bal_acc_history"][:n_epochs]

        label = f"{arch.capitalize()} {SIZE_LABELS.get(arch, 'size')}={size}"
        style = dict(
            color=ARCH_COLORS[arch],
            marker=MARKER_CYCLE[rank % len(MARKER_CYCLE)],
            markevery=MARK_EVERY,
        )

        # Each curve is plotted against its own length, so a shorter
        # validation history cannot cause an x/y length mismatch.
        ax_loss.plot(
            range(1, len(train_loss) + 1), train_loss,
            linestyle="-", label=f"{label} - Train", **style,
        )
        ax_loss.plot(
            range(1, len(val_loss) + 1), val_loss,
            linestyle="--", label=f"{label} - Val", **style,
        )
        ax_acc.plot(
            range(1, len(val_bal_acc) + 1), val_bal_acc,
            linestyle="-", label=label, **style,
        )

    ax_loss.legend(ncol=2, fontsize="small")
    ax_acc.legend(fontsize="small")

    fig.tight_layout()
    plt.show()
    plt.close(fig)  # release the figure; many are created in this loop

In [None]:
# Summary table of results: one row per (dataset, architecture, size) run.


def _round_sig(value, digits=4):
    """Round ``value`` to ``digits`` significant figures.

    Used for ratios such as accuracy-per-parameter, whose magnitude is tiny
    and would collapse to 0.0 (or 0.0001) under ``round(value, 4)``.
    """
    return float(f"{value:.{digits}g}")


def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator``, or NaN when the denominator is 0."""
    return numerator / denominator if denominator else float("nan")


def _mean(values):
    """Arithmetic mean of ``values``; NaN for an empty sequence."""
    return sum(values) / len(values) if len(values) else float("nan")


summary_data = []
for dataset in datasets_list:
    for arch, depth in models:
        metrics = results[f"{dataset}_{arch}_{depth}"]
        train_losses = metrics["train_loss_history"]
        # Only use validation values that have a matching training epoch.
        val_losses = metrics["val_loss_history"][: len(train_losses)]
        final_bal_acc = metrics["val_bal_acc_history"][-1]
        summary_data.append(
            {
                "Dataset": dataset.capitalize(),
                "Architecture": arch.capitalize(),
                "Depth": depth,  # for genetic runs this is the head count
                "Num_Layers": depth,
                "Num_Params": metrics["num_params"],
                "Training_Time (s)": round(metrics["training_time"], 4),
                "Final_Val_Loss": round(metrics["final_val_loss"], 4),
                "Final_Val_Acc": round(metrics["final_val_acc"], 4),
                "Final_Bal_Acc": round(final_bal_acc, 4),
                "Param_Time_Ratio": round(
                    _safe_ratio(metrics["num_params"], metrics["training_time"]), 4
                ),
                "Acc_Param_Efficiency": _round_sig(
                    _safe_ratio(final_bal_acc, metrics["num_params"])
                ),
                "Avg_Train_Loss": round(_mean(train_losses), 4),
                # Average over the validation values actually summed, not over
                # the training-history length.
                "Avg_Val_Loss": round(_mean(val_losses), 4),
            }
        )

# Keep only the headline columns (balanced accuracy first), indexed by run and
# sorted for readability.
SUMMARY_COLUMNS = ["Final_Bal_Acc", "Param_Time_Ratio", "Acc_Param_Efficiency"]
summary_df = (
    pd.DataFrame(summary_data)
    .set_index(["Dataset", "Architecture", "Depth"])
    .sort_index()[SUMMARY_COLUMNS]
)


def highlight_max_per_group(df):
    # Color rows by dataset with dark mode friendly colors
    def color_rows(row):
        dataset = row.name[0]
        colors = {
            "Breast_cancer": "#4C566A",
            "Covtype": "#BF616A",
            "Digits": "#5E81AC",
            "Iris": "#81A1C1",
            "Olivetti": "#D08770",
            "Synthetic": "#88C0D0",
            "Wine": "#8FBCBB",
        }
        color = colors.get(dataset, "#2E3440")
        return [
            f"background-color: {color}; color: #ECEFF4" for _ in row
        ]  # Light text on dark background

    styled = df.style.apply(color_rows, axis=1)

    # Highlight max balanced accuracy per dataset-architecture group
    def highlight(s):
        is_max = s.groupby(level=[0, 1]).transform("max") == s
        return [
            "background-color: #A3BE8C; color: #2E3440" if v else "" for v in is_max
        ]

    return styled.apply(highlight, subset=["Final_Bal_Acc"])


# Render the styled summary table (rich HTML display).
styled_summary = highlight_max_per_group(summary_df)
display(styled_summary)

In [None]:
# Raw per-run metrics with every column, including those dropped from the styled summary.
pd.DataFrame.from_records(summary_data)