In [6]:
using Pkg
Pkg.activate(mktempdir())
Pkg.update()
Pkg.add([
    "Flux",
    "LinearAlgebra",
    "Statistics",
    "Plots"
])

[32m[1m  Activating[22m[39m new project at `C:\Users\79021\AppData\Local\Temp\jl_fEUt6v`
[32m[1m    Updating[22m[39m registry at `C:\Users\79021\.julia\registries\General.toml`
[36m[1m     Project[22m[39m No packages added to or removed from `C:\Users\79021\AppData\Local\Temp\jl_fEUt6v\Project.toml`
[36m[1m    Manifest[22m[39m No packages added to or removed from `C:\Users\79021\AppData\Local\Temp\jl_fEUt6v\Manifest.toml`
[32m[1m   Resolving[22m[39m package versions...
[32m[1m   Installed[22m[39m Xorg_xkbcomp_jll ───────────── v1.4.7+0
[32m[1m   Installed[22m[39m libdecor_jll ───────────────── v0.2.2+0
[32m[1m   Installed[22m[39m GR_jll ─────────────────────── v0.73.19+1
[32m[1m   Installed[22m[39m Xorg_xcb_util_wm_jll ───────── v0.4.2+0
[32m[1m   Installed[22m[39m Measures ───────────────────── v0.3.3
[32m[1m   Installed[22m[39m ConcurrentUtilities ────────── v2.5.0
[32m[1m   Installed[22m[39m LoggingExtras ──────────────── v1.2.0
[32m

In [14]:
using Flux
using LinearAlgebra
using Statistics

# =============================================================================
# PART A: Core Building Blocks for Manifold-Constrained Hyper-Connections
# =============================================================================

# -----------------------------------------------------------------------------
# 1. SINKHORN-KNOPP ALGORITHM
# Projects a matrix onto the Birkhoff polytope (doubly stochastic matrices)
# 
# Mathematical intuition:
# - Start with exp(M) to ensure positivity
# - Alternately normalise rows and columns
# - Converges to matrix where all rows AND columns sum to 1
# - This is an entropy-regularised optimal transport projection
# -----------------------------------------------------------------------------

"""
    sinkhorn_knopp(M::AbstractMatrix; max_iters::Int=20, ϵ::Real=1f-8)

Map `M` to an (approximately) doubly stochastic matrix by exponentiating it
entry-wise and then alternately normalising rows and columns `max_iters` times.

# Keywords
- `max_iters`: number of row/column normalisation sweeps.
- `ϵ`: small constant added to every row/column sum to avoid division by zero.
  It is converted to the element type of the exponentiated matrix so that it
  never widens the result (e.g. `Float32` input stays `Float32`).

# Returns
A matrix of the same size as `M` with strictly positive entries whose columns
(last normalised) and rows sum to ≈ 1.
"""
function sinkhorn_knopp(M::AbstractMatrix; max_iters::Int=20, ϵ::Real=1f-8)
    # Shift by the largest entry before exponentiating. Row/column normalisation
    # cancels any global scale factor, so the result is unchanged (up to ϵ), but
    # exp can no longer overflow to Inf and poison the matrix with NaNs.
    shift = isempty(M) ? zero(eltype(M)) : maximum(M)
    P = exp.(M .- shift)

    tol = convert(eltype(P), ϵ)

    for _ in 1:max_iters
        # Row normalisation: each row sums to 1
        P = P ./ (sum(P, dims=2) .+ tol)
        # Column normalisation: each column sums to 1
        P = P ./ (sum(P, dims=1) .+ tol)
    end

    return P
end

# Verify doubly stochastic property
"""
    check_doubly_stochastic(M::AbstractMatrix; tol::Real=1e-3) -> Bool

Print the row sums, column sums and non-negativity of `M`, and return `true`
when every row and column sum is within `tol` of 1 and no entry is negative.
"""
function check_doubly_stochastic(M::AbstractMatrix; tol::Real=1e-3)
    row_sums = vec(sum(M, dims=2))
    col_sums = vec(sum(M, dims=1))

    near_one(s) = abs(s - 1.0) < tol
    row_ok = all(near_one, row_sums)
    col_ok = all(near_one, col_sums)
    non_neg = all(>=(0), M)
    is_doubly_stochastic = row_ok && col_ok && non_neg

    println("Row sums: ", round.(row_sums, digits=4))
    println("Col sums: ", round.(col_sums, digits=4))
    println("All non-negative: ", non_neg)
    println("Doubly stochastic: ", is_doubly_stochastic)

    # Return the verdict so callers can act on it programmatically.
    return is_doubly_stochastic
end

# -----------------------------------------------------------------------------
# 2. STANDARD RESIDUAL CONNECTION (Baseline)
# x_{l+1} = x_l + F(x_l)
# 
# The identity mapping property: gradient flows directly through addition
# ∂L/∂x_l = ∂L/∂x_{l+1} · (1 + ∂F/∂x_l)
# Even if ∂F/∂x_l vanishes, gradient still flows through the "1"
# -----------------------------------------------------------------------------

"""
    ResidualBlock(layer)

Wrapper holding the residual function `F` used in `x + F(x)`.

The field is a type parameter rather than `Any`, so the concrete type of the
wrapped layer is known to the compiler when the block is called.
"""
struct ResidualBlock{L}
    layer::L  # The residual function F
end

# Register ResidualBlock as a Flux layer (presumably so Flux's training
# utilities can traverse its fields — see the Flux.@layer documentation).
Flux.@layer ResidualBlock

# Forward pass: add the wrapped layer's output to its own input (skip connection).
function (rb::ResidualBlock)(x)
    residual = rb.layer(x)
    return x .+ residual
end

# -----------------------------------------------------------------------------
# 3. HYPER-CONNECTIONS (HC) LAYER
# Expands residual stream width by factor n
# 
# x_{l+1} = H_res · x_l + H_post^T · F(H_pre · x_l)
#
# Where:
# - x_l ∈ ℝ^{n×C} is the expanded hidden state (n streams of C features)
# - H_pre ∈ ℝ^{1×n} aggregates n streams → 1 stream for layer input
# - H_post ∈ ℝ^{1×n} distributes layer output → n streams
# - H_res ∈ ℝ^{n×n} mixes information between streams
#
# Problem: H_res is unconstrained → eigenvalues can be >1 or <1
# Across many layers: ∏ H_res_i explodes or vanishes
# -----------------------------------------------------------------------------

"""
    HyperConnectionLayer

State of one hyper-connection layer: the stream layout (`n`, `C`), the wrapped
computation, and the projections, biases and gates used to build the
pre-, post- and residual mappings.
"""
struct HyperConnectionLayer
    # --- Stream layout ---
    n::Int      # expansion factor (number of streams)
    C::Int      # feature dimension per stream
    layer       # wrapped computation (attention, MLP, etc.)

    # --- Learnable projections for the dynamic mappings ---
    ϕ_pre       # flattened input → H_pre
    ϕ_post      # flattened input → H_post
    ϕ_res       # flattened input → H_res

    # --- Static bias terms ---
    b_pre
    b_post
    b_res

    # --- Gating factors (initialised small for stability) ---
    α_pre
    α_post
    α_res
end

# Register HyperConnectionLayer as a Flux layer (presumably so Flux's training
# utilities can traverse its fields — see the Flux.@layer documentation).
Flux.@layer HyperConnectionLayer

"""
    HyperConnectionLayer(n::Int, C::Int, layer; α_init::Real=0.01f0)

Build a hyper-connection layer with `n` streams of `C` features around `layer`.

The three projections map the flattened `n*C` input to `n`, `n` and `n²`
coefficients. All biases start at zero; the residual bias is a full `n×n`
matrix so that every entry of `H_res` has its own learnable offset. Each gate
is a separate one-element `Float32` vector initialised to `α_init`.
"""
function HyperConnectionLayer(n::Int, C::Int, layer; α_init::Real=0.01f0)
    nC = n * C # Flattened dimension
    α = Float32(α_init)

    return HyperConnectionLayer(
        n, C, layer,
        Dense(nC => n),             # ϕ_pre: ℝ^{nC} → ℝ^n
        Dense(nC => n),             # ϕ_post: ℝ^{nC} → ℝ^n
        Dense(nC => n * n),         # ϕ_res: ℝ^{nC} → ℝ^{n²}
        zeros(Float32, 1, n),       # b_pre
        zeros(Float32, 1, n),       # b_post
        zeros(Float32, n, n),       # b_res: one bias per entry of the n×n H_res
        [α],                        # α_pre
        [α],                        # α_post
        [α]                         # α_res
    )
end

# Forward pass of an (unconstrained) hyper-connection layer.
#
# Input:  x_expanded of size (n, C, batch) — n streams of C features.
# Output: array of the same size, H_res ⋅ x + H_postᵀ ⋅ F(H_pre ⋅ x).
#
# The whole computation is expressed with broadcasts, reductions and one matrix
# product instead of writing into pre-allocated buffers: Zygote cannot
# differentiate through in-place array mutation, and the result then also
# follows the element type of the input instead of being pinned to Float32.
function (hc::HyperConnectionLayer)(x_expanded)
    n, C, batch = size(x_expanded)

    # Flatten for computing dynamic mappings: (nC, batch)
    x_flat = reshape(x_expanded, n * C, batch)

    # Dynamic mappings (simplified - full version uses RMSNorm), kept as
    # (n, batch); the biases are reshaped to a column so they broadcast per stream.
    H_pre = hc.α_pre[1] .* hc.ϕ_pre(x_flat) .+ reshape(hc.b_pre, n, 1)
    H_post = hc.α_post[1] .* hc.ϕ_post(x_flat) .+ reshape(hc.b_post, n, 1)

    # For simplicity, use batch-averaged H_res (full impl is per-sample)
    H_res_flat = hc.α_res[1] .* hc.ϕ_res(x_flat)    # (n², batch)
    H_res_mean = reshape(mean(H_res_flat, dims=2), n, n) .+ hc.b_res .+ I(n)

    # === KEY ISSUE: H_res is UNCONSTRAINED ===
    # Eigenvalues can be anything → instability across layers

    # Pre-mapping: h_in[c, b] = Σ_i H_pre[i, b] * x[i, c, b]  → (C, batch)
    h_in = dropdims(sum(reshape(H_pre, n, 1, batch) .* x_expanded; dims=1); dims=1)

    # Apply the actual layer F
    h_out = hc.layer(h_in) # (C, batch)

    # Residual mapping: x_res[i, c, b] = Σ_j H_res[i, j] * x[j, c, b]
    x_res = reshape(H_res_mean * reshape(x_expanded, n, C * batch), n, C, batch)

    # Post-mapping: x_post[i, c, b] = H_post[i, b] * h_out[c, b]
    x_post = reshape(H_post, n, 1, batch) .* reshape(h_out, 1, C, batch)

    return x_res .+ x_post
end

# -----------------------------------------------------------------------------
# 4. MANIFOLD-CONSTRAINED HYPER-CONNECTIONS (mHC)
# Same structure as HC, but H_res is projected onto Birkhoff polytope
#
# Key changes:
# - H_res = Sinkhorn-Knopp(H̃_res) → doubly stochastic
# - H_pre = σ(H̃_pre) → non-negative (prevents signal cancellation)
# - H_post = 2σ(H̃_post) → non-negative, scaled
#
# Why doubly stochastic works:
# 1. Spectral norm ≤ 1: ||H_res||_2 ≤ 1 → non-expansive
# 2. Closure: product of doubly stochastic is doubly stochastic
# 3. Convex combination: H_res · x is weighted average of streams
# -----------------------------------------------------------------------------

"""
    ManifoldHCLayer

State of one manifold-constrained hyper-connection layer. It has the same
fields as the unconstrained layer plus `sk_iters`, the number of
Sinkhorn-Knopp iterations.
"""
struct ManifoldHCLayer
    # --- Stream layout ---
    n::Int          # number of streams
    C::Int          # feature dimension per stream
    layer           # wrapped computation

    # --- Learnable projections for the raw mappings ---
    ϕ_pre
    ϕ_post
    ϕ_res

    # --- Static bias terms ---
    b_pre
    b_post
    b_res

    # --- Gating factors ---
    α_pre
    α_post
    α_res

    sk_iters::Int   # Sinkhorn-Knopp iterations
end

# Register ManifoldHCLayer as a Flux layer (presumably so Flux's training
# utilities can traverse its fields — see the Flux.@layer documentation).
Flux.@layer ManifoldHCLayer

"""
    ManifoldHCLayer(n::Int, C::Int, layer; α_init::Real=0.01f0, sk_iters::Int=20)

Build a manifold-constrained hyper-connection layer with `n` streams of `C`
features around `layer`.

The three projections map the flattened `n*C` input to `n`, `n` and `n²`
coefficients. All biases start at zero; the residual bias is a full `n×n`
matrix so that every entry of the raw `H_res` has its own learnable offset.
Each gate is a separate one-element `Float32` vector initialised to `α_init`.
`sk_iters` is stored as the number of Sinkhorn-Knopp iterations.
"""
function ManifoldHCLayer(n::Int, C::Int, layer; α_init::Real=0.01f0, sk_iters::Int=20)
    nC = n * C # Flattened dimension
    α = Float32(α_init)

    return ManifoldHCLayer(
        n, C, layer,
        Dense(nC => n),             # ϕ_pre
        Dense(nC => n),             # ϕ_post
        Dense(nC => n * n),         # ϕ_res
        zeros(Float32, 1, n),       # b_pre
        zeros(Float32, 1, n),       # b_post
        zeros(Float32, n, n),       # b_res: one bias per entry of the n×n H_res
        [α],                        # α_pre
        [α],                        # α_post
        [α],                        # α_res
        sk_iters
    )
end

# Forward pass of a manifold-constrained hyper-connection layer.
#
# Input:  x_expanded of size (n, C, batch) — n streams of C features.
# Output: array of the same size, H_res ⋅ x + H_postᵀ ⋅ F(H_pre ⋅ x), where
#         H_pre/H_post pass through a sigmoid and H_res through Sinkhorn-Knopp.
#
# The computation uses broadcasts, reductions and one matrix product instead of
# writing into pre-allocated buffers: Zygote cannot differentiate through
# in-place array mutation, and the result then also follows the element type
# of the input instead of being pinned to Float32.
function (mhc::ManifoldHCLayer)(x_expanded)
    n, C, batch = size(x_expanded)

    x_flat = reshape(x_expanded, n * C, batch)

    # Raw mappings, kept as (n, batch); the biases are reshaped to a column so
    # they broadcast per stream.
    H_pre_raw = mhc.α_pre[1] .* mhc.ϕ_pre(x_flat) .+ reshape(mhc.b_pre, n, 1)
    H_post_raw = mhc.α_post[1] .* mhc.ϕ_post(x_flat) .+ reshape(mhc.b_post, n, 1)
    H_res_flat = mhc.α_res[1] .* mhc.ϕ_res(x_flat)

    # === KEY DIFFERENCE: Apply manifold constraints ===

    # Non-negativity for pre/post via sigmoid
    H_pre = sigmoid.(H_pre_raw)             # ∈ [0,1]
    H_post = 2f0 .* sigmoid.(H_post_raw)    # ∈ [0,2] for expressivity

    # Doubly stochastic constraint for the (batch-averaged) residual mapping
    H_res_mean = reshape(mean(H_res_flat, dims=2), n, n) .+ mhc.b_res
    H_res = sinkhorn_knopp(H_res_mean; max_iters=mhc.sk_iters)

    # Pre-mapping: h_in[c, b] = Σ_i H_pre[i, b] * x[i, c, b]  → (C, batch)
    h_in = dropdims(sum(reshape(H_pre, n, 1, batch) .* x_expanded; dims=1); dims=1)

    h_out = mhc.layer(h_in)

    # Constrained residual mixing: x_res[i, c, b] = Σ_j H_res[i, j] * x[j, c, b]
    x_res = reshape(H_res * reshape(x_expanded, n, C * batch), n, C, batch)

    # Post-mapping: x_post[i, c, b] = H_post[i, b] * h_out[c, b]
    x_post = reshape(H_post, n, 1, batch) .* reshape(h_out, 1, C, batch)

    return x_res .+ x_post
end

# -----------------------------------------------------------------------------
# 5. UTILITY: Measure Signal/Gradient Gain
# This is what the paper measures in Fig. 3 and Fig. 7
#
# Forward signal gain: max row sum of H_res (how much signal amplifies)
# Backward gradient gain: max column sum (how much gradient amplifies)
#
# For stable training, both should be close to 1
# -----------------------------------------------------------------------------

"""
    measure_gain(H::AbstractMatrix)

Return `(forward=…, backward=…)`, where `forward` is the largest absolute row
sum of `H` and `backward` is the largest absolute column sum.
"""
function measure_gain(H::AbstractMatrix)
    abs_row_totals = sum(abs, H; dims=2)   # one total per row
    abs_col_totals = sum(abs, H; dims=1)   # one total per column

    return (forward=maximum(abs_row_totals), backward=maximum(abs_col_totals))
end

"""
    measure_composite_gain(matrices::AbstractVector{<:AbstractMatrix})

Multiply the matrices in layer order (`matrices[end] * … * matrices[1]`) and
return `measure_gain` of the product.

Throws an `ArgumentError` when `matrices` is empty.
"""
function measure_composite_gain(matrices::AbstractVector{<:AbstractMatrix})
    isempty(matrices) &&
        throw(ArgumentError("matrices must contain at least one matrix"))

    # Each later layer multiplies from the left: H_L ⋯ H_2 H_1.
    composite = foldl((acc, H) -> H * acc, matrices)
    return measure_gain(composite)
end

# Original (misspelled) name, kept as an alias so existing callers keep working.
mesaure_composite_gain(matrices::AbstractVector{<:AbstractMatrix}) =
    measure_composite_gain(matrices)

# -----------------------------------------------------------------------------
# 6. DEMONSTRATION: Compare stability of random vs doubly stochastic
# -----------------------------------------------------------------------------

"""
    demo_stability_comparison(; n::Int=4, num_layers::Int=30)

Print single-layer and composite gains for `num_layers` random `n×n` matrices
and for `num_layers` Sinkhorn-Knopp-projected matrices, then print the
doubly-stochastic check for the first projected matrix.

Returns `(hc_matrices, mhc_matrices)`.
"""
function demo_stability_comparison(; n::Int=4, num_layers::Int=30)
    num_layers >= 1 ||
        throw(ArgumentError("num_layers must be at least 1, got $num_layers"))

    println("\n" * "="^60)
    println("STABILITY COMPARISON: Random vs Doubly Stochastic")
    println("="^60)

    # Shared reporting for both matrix families.
    function report_gains(matrices)
        single = measure_gain(matrices[1])
        println("Single layer gain: forward=$(round(single.forward, digits=2)), backward=$(round(single.backward, digits=2))")

        composite = mesaure_composite_gain(matrices)
        println("$(num_layers)-layer composite: forward=$(round(composite.forward, digits=2)), backward=$(round(composite.backward, digits=2))")
        return nothing
    end

    # Random unconstrained matrices (like HC)
    println("\n--- unconstrained (HC-style) ---")
    hc_matrices = [randn(Float32, n, n) for _ in 1:num_layers]
    report_gains(hc_matrices)

    # Doubly stochastic matrices (like mHC)
    println("\n--- Doubly stochastic (mHC-style) ---")
    mhc_matrices = [sinkhorn_knopp(randn(Float32, n, n)) for _ in 1:num_layers]
    report_gains(mhc_matrices)

    # Verify doubly stochastic property
    println("\n--- Verifying doubly stochastic property ---")
    check_doubly_stochastic(mhc_matrices[1])

    return hc_matrices, mhc_matrices
end

demo_stability_comparison()


STABILITY COMPARISON: Random vs Doubly Stchastic

--- unconstrained (HC-style) ---
Single layer gain: forward=4.58, backward=4.06
30-layer composite: forward=1.1157204e9, backward=2.802629e9

--- Doubly stochastic (mHC-style) ---
Single layer gain: forward=1.0, backward=1.0
30-layer composite: forward=1.0, backward=1.0

--- Verifying doubly stochastic property ---
Row sums: Float32[1.0, 1.0, 1.0, 1.0]
Col sums: Float32[1.0, 1.0, 1.0, 1.0]
All non-negative: true
Doubly stochastic: true


(Matrix{Float32}[[0.11284312 0.34406877 -0.68308604 0.61050475; 0.69754755 -0.94521415 0.7930082 -1.330823; 1.8120838 -1.6061634 -0.62386453 0.5416502; 1.0643485 -1.1655983 1.5447383 -0.6329996], [-0.45682418 1.6256608 -0.7404751 -0.4840161; -1.7511808 -0.69520056 -2.6625447 1.015289; -0.83744967 -0.09121479 -0.16222441 0.39838046; -0.7068127 0.13429482 -0.554819 -0.46567833], [-2.2253518 0.67969984 0.7834992 0.50639355; 1.614896 0.27227393 -1.2053431 1.6054468; 0.9186541 -1.7763978 -0.26800478 0.17776069; 1.2163626 0.22726877 0.48391786 1.9790481], [-1.5200093 0.9575117 1.4383426 0.97984815; -1.9807007 1.9034126 0.70416915 0.9489625; 0.20598412 -0.8418422 -1.0052196 -0.66530526; 1.9061929 0.73347324 2.4500709 -0.14837994], [-0.77188253 1.4073231 -1.702729 -0.0060323817; -0.5140237 1.0368892 0.055263437 0.868878; 0.5888396 0.39154175 0.9200625 -0.1971379; 1.8708401 1.4196786 0.7532068 -1.695753], [-0.41828302 -0.67458814 0.7085174 0.34073806; -1.5079988 -2.8482811 1.1855991 -1.4713229;

In [13]:
using Flux
using Flux: onehotbatch, onecold
using Statistics
using Random

# -----------------------------------------------------------------------------
# 1. STANDARD RESIDUAL MLP
# Simple baseline: stack of residual blocks
# x_{l+1} = x_l + MLP(x_l)
# -----------------------------------------------------------------------------

"""
    ResidualMLP

Baseline model: an input projection, a collection of residual blocks, a final
normalisation and an output projection.
"""
struct ResidualMLP
    input_proj      # projects the input to the hidden dimension
    blocks          # vector of residual blocks
    output_proj     # projects to the output classes
    norm            # final normalisation
end

# Register ResidualMLP as a Flux layer (presumably so Flux's training
# utilities can traverse its fields — see the Flux.@layer documentation).
Flux.@layer ResidualMLP

"""
    ResidualMLP(input_dim, hidden_dim, output_dim, num_layers)

Assemble the baseline network: `Dense(input_dim => hidden_dim)`, then
`num_layers` residual blocks (LayerNorm → Dense with 4× expansion and ReLU →
Dense back to `hidden_dim`), a final LayerNorm and a Dense classifier head.
"""
function ResidualMLP(input_dim::Int, hidden_dim::Int, output_dim::Int, num_layers::Int)
    wide_dim = hidden_dim * 4

    # Layers are created in the same order as they are used, so the sequence
    # of random initialisations is: stem, blocks, head.
    stem = Dense(input_dim => hidden_dim)
    residual_blocks = [
        ResidualBlock(
            Chain(
                LayerNorm(hidden_dim),
                Dense(hidden_dim => wide_dim, relu),
                Dense(wide_dim => hidden_dim)
            )
        )
        for _ in 1:num_layers
    ]
    head = Dense(hidden_dim => output_dim)
    final_norm = LayerNorm(hidden_dim)

    return ResidualMLP(stem, residual_blocks, head, final_norm)
end

# Forward pass: x is (input_dim, batch). Projects the input, applies every
# residual block in order, normalises, and projects to the output.
function (m::ResidualMLP)(x)
    hidden = m.input_proj(x)

    for residual_block in m.blocks
        hidden = residual_block(hidden)
    end

    return m.output_proj(m.norm(hidden))
end

# -----------------------------------------------------------------------------
# 2. HYPER-CONNECTIONS MLP
# Expands residual stream by factor n
# Demonstrates the instability problem
# -----------------------------------------------------------------------------

"""
    HCMLP

Hyper-connections model: the expansion factor `n`, an input projection to the
expanded hidden state, a collection of HC layers, an output stage and a
normalisation layer.
"""
struct HCMLP
    n::Int          # expansion factor
    input_proj      # projects to the expanded hidden state
    blocks          # vector of HC layers
    output_proj     # collapses the streams and projects to the output
    norm
end

# Register HCMLP as a Flux layer (presumably so Flux's training
# utilities can traverse its fields — see the Flux.@layer documentation).
Flux.@layer HCMLP

"""
    HCMLP(input_dim, hidden_dim, output_dim, num_layers; n=4)

Assemble the hyper-connections network: a Dense expansion to `n` streams of
`hidden_dim` features, `num_layers` `HyperConnectionLayer`s each wrapping its
own inner MLP, and an output stage that averages the streams, normalises and
classifies.
"""
function HCMLP(input_dim::Int, hidden_dim::Int, output_dim::Int, num_layers::Int; n::Int=4)
    wide_dim = hidden_dim * 4

    # Layers are created in usage order (expansion, HC layers, output stage) so
    # the sequence of random initialisations is unchanged.
    expand = Dense(input_dim => hidden_dim * n)     # Expand to n streams

    hc_layers = [
        HyperConnectionLayer(
            n,
            hidden_dim,
            Chain(
                LayerNorm(hidden_dim),
                Dense(hidden_dim => wide_dim, relu),
                Dense(wide_dim => hidden_dim)
            )
        )
        for _ in 1:num_layers
    ]

    collapse_streams = x -> dropdims(mean(x, dims=1); dims=1)   # Average across streams
    readout = Chain(
        collapse_streams,
        LayerNorm(hidden_dim),
        Dense(hidden_dim => output_dim)
    )

    return HCMLP(n, expand, hc_layers, readout, LayerNorm(hidden_dim))
end

# Forward pass: project x to (n*C, batch), view it as (n, C, batch), run every
# HC layer in order, then apply the output stage.
function (m::HCMLP)(x)
    projected = m.input_proj(x)    # (n*C, batch)
    features_per_stream = size(projected, 1) ÷ m.n
    streams = reshape(projected, m.n, features_per_stream, size(x, 2))

    for hc_layer in m.blocks
        streams = hc_layer(streams)
    end

    return m.output_proj(streams)
end

# -----------------------------------------------------------------------------
# 3. MANIFOLD-CONSTRAINED HC MLP
# Same structure as HC but with stability guarantees
# This is the main contribution of the mHC paper
# -----------------------------------------------------------------------------

"""
    mHCMLP

Manifold-constrained hyper-connections model: the expansion factor `n`, an
input projection, a collection of layers, an output stage and a normalisation
layer.
"""
struct mHCMLP
    n::Int          # expansion factor
    input_proj      # projects to the expanded hidden state
    blocks          # vector of mHC layers
    output_proj     # collapses the streams and projects to the output
    norm
end

# Register mHCMLP as a Flux layer (presumably so Flux's training
# utilities can traverse its fields — see the Flux.@layer documentation).
Flux.@layer mHCMLP

"""
    mHCMLP(input_dim, hidden_dim, output_dim, num_layers; n=4, sk_iters=20)

Assemble the manifold-constrained network: a Dense expansion to `n` streams of
`hidden_dim` features, `num_layers` `ManifoldHCLayer`s (each wrapping its own
inner MLP and using `sk_iters` Sinkhorn-Knopp iterations), and an output stage
that averages the streams, normalises and classifies.
"""
function mHCMLP(input_dim::Int, hidden_dim::Int, output_dim::Int, num_layers::Int;
    n::Int=4, sk_iters::Int=20)
    wide_dim = hidden_dim * 4

    # Layers are created in usage order (expansion, mHC layers, output stage)
    # so the sequence of random initialisations is unchanged.
    expand = Dense(input_dim => hidden_dim * n)

    mhc_layers = [
        ManifoldHCLayer(
            n,
            hidden_dim,
            Chain(
                LayerNorm(hidden_dim),
                Dense(hidden_dim => wide_dim, relu),
                Dense(wide_dim => hidden_dim)
            );
            sk_iters=sk_iters
        )
        for _ in 1:num_layers
    ]

    collapse_streams = x -> dropdims(mean(x, dims=1); dims=1)
    readout = Chain(
        collapse_streams,
        LayerNorm(hidden_dim),
        Dense(hidden_dim => output_dim)
    )

    return mHCMLP(n, expand, mhc_layers, readout, LayerNorm(hidden_dim))
end

# Forward pass: project x to (n*C, batch), view it as (n, C, batch), run every
# mHC layer in order, then apply the output stage.
function (m::mHCMLP)(x)
    projected = m.input_proj(x)
    features_per_stream = size(projected, 1) ÷ m.n
    streams = reshape(projected, m.n, features_per_stream, size(x, 2))

    for mhc_layer in m.blocks
        streams = mhc_layer(streams)
    end

    return m.output_proj(streams)
end

# -----------------------------------------------------------------------------
# 4. TRAINING UTILITIES
# -----------------------------------------------------------------------------

# Cross-entropy loss with softmax
# Run `model` on `x` and score the resulting logits against `y` with
# `Flux.logitcrossentropy`.
ce_loss(model, x, y) = Flux.logitcrossentropy(model(x), y)

# Accuracy computation
# Fraction of columns for which the `onecold` class of the model output
# matches the `onecold` class of `y`.
function accuracy(model, x, y)
    predicted_classes = onecold(model(x))
    true_classes = onecold(y)
    hits = predicted_classes .== true_classes
    return mean(hits)
end

# Gradient norm for monitoring stability
# Sum of squared magnitudes over every numeric leaf of a nested gradient
# structure (NamedTuples, Tuples and arrays of those). `nothing` and any
# unrecognised leaf contribute zero.
_grad_sqnorm(::Nothing) = 0.0f0
_grad_sqnorm(::Any) = 0.0f0
_grad_sqnorm(g::Number) = abs2(g)
_grad_sqnorm(g::AbstractArray{<:Number}) = isempty(g) ? 0.0f0 : sum(abs2, g)
function _grad_sqnorm(g::Union{Tuple,NamedTuple,AbstractArray})
    total = 0.0f0
    for child in g
        total += _grad_sqnorm(child)
    end
    return total
end

"""
    compute_grad_norm(grads)

Return the global L2 norm of a nested gradient structure: the square root of
the sum of squares of all numeric entries found anywhere inside `grads`.

Squares are accumulated across all nesting levels and the square root is taken
once at the end, so nested containers contribute their squared norm rather
than their norm. Arrays whose elements are themselves containers (for example
one gradient entry per layer of a vector of layers) are traversed instead of
being treated as numeric arrays.
"""
function compute_grad_norm(grads)
    return sqrt(_grad_sqnorm(grads))
end

# Extract H_res matrices for analysis
# Collect one Sinkhorn-Knopp-projected H_res per layer of an mHCMLP, each
# evaluated on an all-zero dummy input of size (n*C, 1).
function extract_hres_matrices(model::mHCMLP)
    function projected_hres(layer)
        n, C = layer.n, layer.C
        probe = zeros(Float32, n * C, 1)
        raw = reshape(layer.α_res[1] * layer.ϕ_res(probe), n, n) .+ layer.b_res
        return sinkhorn_knopp(raw; max_iters=layer.sk_iters)
    end

    return Matrix{Float32}[projected_hres(layer) for layer in model.blocks]
end

# Collect one unconstrained H_res (gated projection + bias + identity) per
# layer of an HCMLP, each evaluated on an all-zero dummy input of size (n*C, 1).
function extract_hres_matrices(model::HCMLP)
    function raw_hres(layer)
        n, C = layer.n, layer.C
        probe = zeros(Float32, n * C, 1)
        gated = layer.α_res[1] .* layer.ϕ_res(probe)
        return reshape(gated, n, n) .+ layer.b_res .+ I(n)
    end

    return Matrix{Float32}[raw_hres(layer) for layer in model.blocks]
end

# -----------------------------------------------------------------------------
# 5. TRAINING LOOP WITH METRICS COLLECTION
# Returns history for plotting
# -----------------------------------------------------------------------------

"""
    TrainingHistory

Per-epoch metric series collected during training. `TrainingHistory()` creates
a history whose five series are all empty, independent `Float32` vectors.
"""
struct TrainingHistory
    losses::Vector{Float32}
    accuracies::Vector{Float32}
    grad_norms::Vector{Float32}
    forward_gains::Vector{Float32}   # signal propagation stability
    backward_gains::Vector{Float32}  # gradient propagation stability
end

function TrainingHistory()
    # A fresh vector per field, so the series never alias each other.
    empty_series = ntuple(_ -> Float32[], 5)
    return TrainingHistory(empty_series...)
end

"""
    train_epoch!(model, opt_state, train_x, train_y, history; batch_size=64)

Run one epoch of mini-batch training over the columns of `train_x`/`train_y`
and append the epoch's metrics to `history`.

Batches are consecutive column ranges of `batch_size`; trailing samples that
do not fill a whole batch are skipped. After the epoch the function records:

- mean batch loss and mean batch gradient norm,
- accuracy of the updated model on the full training data,
- composite forward/backward gain of the extracted H_res matrices for
  `HCMLP`/`mHCMLP` models, or `1` for any other model.

Throws an `ArgumentError` when `batch_size` is not positive or when there are
fewer samples than one batch (the means would otherwise be `0/0`).

Returns `history`.
"""
function train_epoch!(model, opt_state, train_x, train_y, history::TrainingHistory;
    batch_size::Int=64)
    batch_size > 0 ||
        throw(ArgumentError("batch_size must be positive, got $batch_size"))

    n_samples = size(train_x, 2)
    n_batches = n_samples ÷ batch_size
    n_batches > 0 || throw(ArgumentError(
        "need at least $batch_size samples for one batch, got $n_samples"))

    epoch_loss = 0.0f0
    epoch_grad_norm = 0.0f0

    for i in 1:n_batches
        idx = ((i - 1) * batch_size + 1):(i * batch_size)
        x_batch = train_x[:, idx]
        y_batch = train_y[:, idx]

        # Compute loss and gradients
        loss, grads = Flux.withgradient(model) do m
            ce_loss(m, x_batch, y_batch)
        end

        # Record the metrics before handing the gradient to the optimiser.
        epoch_loss += loss
        epoch_grad_norm += compute_grad_norm(grads[1])

        # Update parameters
        Flux.update!(opt_state, model, grads[1])
    end

    # Record metrics
    push!(history.losses, epoch_loss / n_batches)
    push!(history.grad_norms, epoch_grad_norm / n_batches)
    push!(history.accuracies, accuracy(model, train_x, train_y))

    # Compute propagation stability (if HC or mHC)
    if model isa HCMLP || model isa mHCMLP
        matrices = extract_hres_matrices(model)
        if !isempty(matrices)
            composite_gain = measure_composite_gain(matrices)
            push!(history.forward_gains, composite_gain.forward)
            push!(history.backward_gains, composite_gain.backward)
        end
    else
        # For standard residual, gain is always 1
        push!(history.forward_gains, 1.0f0)
        push!(history.backward_gains, 1.0f0)
    end

    return history
end

# -----------------------------------------------------------------------------
# 6. SYNTHETIC DATASET GENERATOR
# For controlled experiments
# -----------------------------------------------------------------------------

"""
    generate_synthetic_data(n_samples, input_dim, n_classes; seed=42)

Generate a shuffled Gaussian-cluster classification dataset.

One cluster centre per class is drawn, then samples are drawn around each
centre. Every one of the `n_samples` columns receives exactly one class: when
`n_samples` is not a multiple of `n_classes`, the first `n_samples % n_classes`
classes get one extra sample each, so no column is left as an unlabelled
all-zero sample.

The global RNG is seeded with `seed`, so repeated calls with the same
arguments return the same data.

# Returns
`(X, Y)` with `X::Matrix{Float32}` of size `(input_dim, n_samples)` and the
one-hot labels `Y::Matrix{Float32}` of size `(n_classes, n_samples)`.
"""
function generate_synthetic_data(n_samples::Int, input_dim::Int, n_classes::Int; seed::Int=42)
    Random.seed!(seed)

    # Generate cluster centers
    centers = randn(Float32, input_dim, n_classes) .* 3

    # Split the samples over the classes, spreading any remainder
    base_count, n_extra = divrem(n_samples, n_classes)

    X = zeros(Float32, input_dim, n_samples)
    Y = zeros(Float32, n_classes, n_samples)

    idx_start = 1
    for c in 1:n_classes
        class_count = base_count + (c <= n_extra ? 1 : 0)
        cols = idx_start:(idx_start + class_count - 1)

        # Float32 literal keeps the noise in Float32 instead of promoting
        X[:, cols] = centers[:, c] .+ randn(Float32, input_dim, class_count) .* 0.5f0
        Y[c, cols] .= 1.0f0

        idx_start += class_count
    end

    # Shuffle
    perm = randperm(n_samples)
    return X[:, perm], Y[:, perm]
end

# -----------------------------------------------------------------------------
# 7. MODEL FACTORY
# Creates models with comparable parameter counts
# -----------------------------------------------------------------------------

"""
    create_models(input_dim, hidden_dim, output_dim, num_layers; n=4)

Construct the three models under comparison — `ResidualMLP`, `HCMLP` and
`mHCMLP` (the latter two with expansion factor `n`) — printing the trainable
parameter count of each as it is built.

Returns `(residual, hc, mhc)`.
"""
function create_models(input_dim::Int, hidden_dim::Int, output_dim::Int, num_layers::Int; n::Int=4)
    println("\nCreating models...")

    # Count trainable parameters through the explicit-parameter API: the
    # length of the flat parameter vector returned by Flux.destructure.
    # (Flux.params belongs to the deprecated implicit-parameter style.)
    count_params(model) = length(first(Flux.destructure(model)))

    # Standard residual
    residual = ResidualMLP(input_dim, hidden_dim, output_dim, num_layers)
    println("ResidualMLP: $(count_params(residual)) parameters")

    # Hyper-Connections
    hc = HCMLP(input_dim, hidden_dim, output_dim, num_layers; n=n)
    println("HCMLP: $(count_params(hc)) parameters")

    # Manifold-Constrained HC
    mhc = mHCMLP(input_dim, hidden_dim, output_dim, num_layers; n=n)
    println("mHCMLP: $(count_params(mhc)) parameters")

    return residual, hc, mhc
end

# -----------------------------------------------------------------------------
# 8. FULL TRAINING COMPARISON
# -----------------------------------------------------------------------------

"""
    run_training_comparison(; input_dim=64, hidden_dim=128, output_dim=10,
                              num_layers=8, n_samples=5000, n_epochs=30,
                              lr=1e-3, expansion=4)

Train the residual, HC and mHC models side by side on the same synthetic
dataset for `n_epochs` epochs with Adam at learning rate `lr`, printing loss
and accuracy for the first epoch and every fifth epoch.

# Returns
A named tuple `(models=(residual, hc, mhc), histories=(residual, hc, mhc))`.
"""
function run_training_comparison(;
    input_dim::Int=64,
    hidden_dim::Int=128,
    output_dim::Int=10,
    num_layers::Int=8,
    n_samples::Int=5000,
    n_epochs::Int=30,
    lr::Real=1e-3,
    expansion::Int=4
    )

    println("\n" * "="^60)
    println("Training comparison: Residual vs HC vs mHC")
    println("="^60)
    println("Layers: $num_layers, Hidden: $hidden_dim, Expansion: $expansion")

    # Generate Data
    train_x, train_y = generate_synthetic_data(n_samples, input_dim, output_dim)
    println("Data shape: $(size(train_x)), $(size(train_y))")

    # Create models
    residual, hc, mhc = create_models(input_dim, hidden_dim, output_dim, num_layers; n=expansion)

    # Training histories
    hist_res = TrainingHistory()
    hist_hc = TrainingHistory()
    hist_mhc = TrainingHistory()

    # Optimisers
    opt_res = Flux.setup(Adam(lr), residual)
    opt_hc = Flux.setup(Adam(lr), hc)
    opt_mhc = Flux.setup(Adam(lr), mhc)

    # One progress line per model. A metric that has not been recorded is
    # shown as NaN instead of raising a BoundsError on an empty series.
    latest(series) = isempty(series) ? NaN32 : series[end]
    function report(label, hist)
        loss = round(latest(hist.losses), digits=4)
        acc = round(latest(hist.accuracies) * 100, digits=1)
        println("   $label - Loss: $loss, Acc: $acc%")
        return nothing
    end

    println("\nTraining...")
    for epoch in 1:n_epochs
        train_epoch!(residual, opt_res, train_x, train_y, hist_res)
        train_epoch!(hc, opt_hc, train_x, train_y, hist_hc)
        train_epoch!(mhc, opt_mhc, train_x, train_y, hist_mhc)

        if epoch % 5 == 0 || epoch == 1
            println("Epoch $epoch:")
            report("Residual", hist_res)
            report("HC", hist_hc)
            report("mHC", hist_mhc)
        end
    end

    return (
        models = (residual=residual, hc=hc, mhc=mhc),
        histories = (residual=hist_res, hc=hist_hc, mhc=hist_mhc)
    )
end

run_training_comparison()


Training comparison: Residual vs HC vs mHC
Layers: 8, Hidden: 128, Expansion: 4
Data shape: (64, 5000), (10, 5000)

Creating models...
ResidualMLP: 1065610 parameters


UndefVarError: UndefVarError: `nC` not defined in `Main`
Suggestion: check for spelling errors or missing imports.

In [None]:
using Plots
using Statistics
using Printf

# =============================================================================
# PART C: Experiments and Visualisation
# Replicates key analyses from the mHC paper
# =============================================================================

# -----------------------------------------------------------------------------
# 1. EXPERIMENT: Matrix Propagation Stability
# Demonstrates why doubly stochastic constraint matters
# Similar to Fig. 3 vs Fig. 7 in the paper
# -----------------------------------------------------------------------------

"""
    experiment_propagation_stability(; n::Int=4, num_layers::Int=60)

Compare how the composite gain evolves with depth for unconstrained
(identity plus noise) matrices and for Sinkhorn-Knopp-projected matrices.

For every depth `l` the composite mapping `H_l ⋯ H_2 H_1` is updated
incrementally (one matrix product per layer) and its forward/backward gain is
recorded. Both gain curves are plotted on a log scale with a reference line at
1.0, the figure is saved to `propagation_stability.png`, and the final-layer
gains are printed.

# Returns
`(hc=(forward, backward), mhc=(forward, backward))`, each a `Vector{Float32}`
of length `num_layers`.
"""
function experiment_propagation_stability(; n::Int=4, num_layers::Int=60)
    num_layers >= 1 ||
        throw(ArgumentError("num_layers must be at least 1, got $num_layers"))

    println("\n" * "="^60)
    println("Experiment 1: Propagation stability analysis")
    println("="^60)

    # Generate unconstrained matrices (HC-style)
    hc_matrices = [randn(Float32, n, n) .* 0.5f0 .+ I(n) for _ in 1:num_layers]

    # Generate doubly stochastic matrices (mHC-style)
    mhc_matrices = [sinkhorn_knopp(randn(Float32, n, n)) for _ in 1:num_layers]

    # Compute gains at each layer depth
    hc_forward = Float32[]
    hc_backward = Float32[]
    mhc_forward = Float32[]
    mhc_backward = Float32[]

    # Running composites, started at the identity. Each new layer multiplies
    # from the left, the same order the composite-gain helper uses.
    hc_composite = Matrix{Float32}(I, n, n)
    mhc_composite = Matrix{Float32}(I, n, n)

    for l in 1:num_layers
        hc_composite = hc_matrices[l] * hc_composite
        mhc_composite = mhc_matrices[l] * mhc_composite

        hc_gain = measure_gain(hc_composite)
        mhc_gain = measure_gain(mhc_composite)

        push!(hc_forward, hc_gain.forward)
        push!(hc_backward, hc_gain.backward)
        push!(mhc_forward, mhc_gain.forward)
        push!(mhc_backward, mhc_gain.backward)
    end

    # Create Visualisation
    p1 = plot(1:num_layers, hc_forward,
        label="HC Forward",
        ylabel="Gain magnitude",
        xlabel="Layer Depth",
        title="Signal Propagation gain",
        yscale=:log10,
        linewidth=2,
        color=:red
    )
    plot!(p1, 1:num_layers, mhc_forward,
        label="mHC Forward",
        linewidth=2,
        color=:blue
    )
    # Reference line at the ideal gain of 1.0 (0 cannot be drawn on a log axis)
    hline!(p1, [1.0], label="Ideal (1.0)", linestyle=:dash, color=:black)

    p2 = plot(1:num_layers, hc_backward,
        label="HC Backward",
        ylabel="Gain magnitude",
        xlabel="Layer depth",
        title="Gradient propagation gain",
        yscale=:log10,
        linewidth=2,
        color=:red
    )
    plot!(p2, 1:num_layers, mhc_backward,
        label="mHC Backward",
        linewidth=2,
        color=:blue
    )
    hline!(p2, [1.0], label="Ideal (1.0)", linestyle=:dash, color=:black)

    p = plot(p1, p2, layout=(1, 2), size=(900, 400))
    savefig(p, "propagation_stability.png")
    println("Saved: propagation_stability.png")

    # Print summary statistics
    println("\nFinal layer gains (layer $num_layers):")
    @printf("   HC - Forward: %.2e, Backward: %.2e\n", hc_forward[end], hc_backward[end])
    @printf("   mHC - Forward: %.4f, Backward: %.4f\n", mhc_forward[end], mhc_backward[end])

    return (hc=(forward=hc_forward, backward=hc_backward),
            mhc=(forward=mhc_forward, backward=mhc_backward))
end

# -----------------------------------------------------------------------------
# 2. EXPERIMENT: Training Dynamics Comparison
# Compares loss curves, gradient norms, and stability metrics
# Similar to Fig. 5 in the paper
# -----------------------------------------------------------------------------

