In [4]:
using Pkg

# Activate a fresh project in a temporary directory, refresh it, then install
# the packages that the following cells load with `using`.
Pkg.activate(mktempdir())
Pkg.update()
Pkg.add(["Flux", "LinearAlgebra", "Statistics"])

[32m[1m  Activating[22m[39m new project at `C:\Users\79021\AppData\Local\Temp\jl_tmsXnd`
[32m[1m    Updating[22m[39m registry at `C:\Users\79021\.julia\registries\General.toml`
[36m[1m     Project[22m[39m No packages added to or removed from `C:\Users\79021\AppData\Local\Temp\jl_tmsXnd\Project.toml`
[36m[1m    Manifest[22m[39m No packages added to or removed from `C:\Users\79021\AppData\Local\Temp\jl_tmsXnd\Manifest.toml`
[32m[1m   Resolving[22m[39m package versions...
[32m[1m    Updating[22m[39m `C:\Users\79021\AppData\Local\Temp\jl_tmsXnd\Project.toml`
  [90m[587475ba] [39m[92m+ Flux v0.16.7[39m
  [90m[10745b16] [39m[92m+ Statistics v1.11.1[39m
  [90m[37e2e46d] [39m[92m+ LinearAlgebra v1.12.0[39m
[32m[1m    Updating[22m[39m `C:\Users\79021\AppData\Local\Temp\jl_tmsXnd\Manifest.toml`
  [90m[621f4979] [39m[92m+ AbstractFFTs v1.5.0[39m
  [90m[7d9f7c33] [39m[92m+ Accessors v0.1.43[39m
  [90m[79e6a3ab] [39m[92m+ Adapt v4.4.0[39m
  [90m[

In [6]:
using Flux
using LinearAlgebra
using Statistics

# =============================================================================
# PART A: Core Building Blocks for Manifold-Constrained Hyper-Connections
# =============================================================================

# -----------------------------------------------------------------------------
# 1. SINKHORN-KNOPP ALGORITHM
# Projects a matrix onto the Birkhoff polytope (doubly stochastic matrices)
# 
# Mathematical intuition:
# - Start with exp(M) to ensure positivity
# - Alternately normalise rows and columns
# - Converges to matrix where all rows AND columns sum to 1
# - This is an entropy-regularised optimal transport projection
# -----------------------------------------------------------------------------

"""
    sinkhorn_knopp(M::AbstractMatrix; max_iters=20, ϵ=1f-8)

Map `M` to an (approximately) doubly stochastic matrix: exponentiate the entries,
then alternately normalise rows and columns `max_iters` times.

# Keywords
- `max_iters::Int=20`: number of row-then-column normalisation sweeps.
- `ϵ::Real=1f-8`: constant added to every row/column sum to guard against division
  by zero. It is converted to the element type of the working matrix so that it
  never changes the precision of the result.

# Returns
A matrix with the same size as `M` and non-negative entries. An empty `M` is
returned as an empty matrix.
"""
function sinkhorn_knopp(M::AbstractMatrix; max_iters::Int=20, ϵ::Real=1f-8)
    # Nothing to normalise (and `maximum` would throw on an empty collection).
    isempty(M) && return exp.(M)

    # Step 1: make all entries positive via exponentiation.
    # Subtracting the largest entry first keeps every exponent ≤ 0, so `exp` cannot
    # overflow to Inf (which would turn the normalisation below into Inf/Inf = NaN).
    # The shift multiplies every entry by the same positive factor, which the first
    # normalisation removes again (up to the tiny ϵ term).
    P = exp.(M .- maximum(M))
    δ = convert(eltype(P), ϵ)

    # Step 2: alternating normalisation.
    for _ in 1:max_iters
        # Row normalisation: each row sums to 1
        P = P ./ (sum(P, dims=2) .+ δ)
        # Column normalisation: each column sums to 1
        P = P ./ (sum(P, dims=1) .+ δ)
    end

    return P
end

# Verify doubly stochastic property
"""
    check_doubly_stochastic(M::AbstractMatrix; tol=1e-3)

Print the row sums and column sums of `M` (rounded to 4 digits), whether all
entries are non-negative, and whether `M` is doubly stochastic, i.e. every row
and column sum is within `tol` of 1 and no entry is negative.

The report goes to standard output; the function returns `nothing`.
"""
function check_doubly_stochastic(M::AbstractMatrix; tol::Float64=1e-3)
    row_totals = vec(sum(M, dims=2))
    col_totals = vec(sum(M, dims=1))

    # A total passes when it differs from 1 by less than the tolerance.
    near_one(total) = abs(total - 1.0) < tol

    rows_sum_to_one = all(near_one, row_totals)
    cols_sum_to_one = all(near_one, col_totals)
    entries_non_negative = all(>=(0), M)
    is_doubly_stochastic = rows_sum_to_one && cols_sum_to_one && entries_non_negative

    println("Row sums: ", round.(row_totals, digits=4))
    println("Col sums: ", round.(col_totals, digits=4))
    println("All non-negative: ", entries_non_negative)
    println("Doubly stochastic: ", is_doubly_stochastic)
end

# -----------------------------------------------------------------------------
# 2. STANDARD RESIDUAL CONNECTION (Baseline)
# x_{l+1} = x_l + F(x_l)
# 
# The identity mapping property: gradient flows directly through addition
# ∂L/∂x_l = ∂L/∂x_{l+1} · (1 + ∂F/∂x_l)
# Even if ∂F/∂x_l vanishes, gradient still flows through the "1"
# -----------------------------------------------------------------------------

"""
    ResidualBlock(layer)

Standard residual connection around `layer`: calling the block on `x` returns
`x .+ layer(x)`.

The field type is a type parameter (rather than `Any`) so that the concrete type
of the wrapped layer is part of the block's type and the call `rb.layer(x)` can
be specialised. Construction is unchanged: `ResidualBlock(layer)`.
"""
struct ResidualBlock{L}
    layer::L  # The residual function F
end

Flux.@layer ResidualBlock

# Apply the block: identity path plus the residual function's output.
function (rb::ResidualBlock)(x)
    return x .+ rb.layer(x)
end

# -----------------------------------------------------------------------------
# 3. HYPER-CONNECTIONS (HC) LAYER
# Expands residual stream width by factor n
# 
# x_{l+1} = H_res · x_l + H_post^T · F(H_pre · x_l)
#
# Where:
# - x_l ∈ ℝ^{n×C} is the expanded hidden state (n streams of C features)
# - H_pre ∈ ℝ^{1×n} aggregates n streams → 1 stream for layer input
# - H_post ∈ ℝ^{1×n} distributes layer output → n streams
# - H_res ∈ ℝ^{n×n} mixes information between streams
#
# Problem: H_res is unconstrained → eigenvalue magnitudes can be >1 or <1
# Across many layers: ∏ H_res_i explodes or vanishes
# -----------------------------------------------------------------------------

# State of a hyper-connection (HC) layer: the wrapped computation plus the
# parameters from which the forward pass builds the pre-, post- and residual
# stream mappings.
# NOTE(review): every non-integer field is typed `Any`; parametric field types
# may be preferable for performance — confirm before changing, since the
# positional constructor depends on this field order.
struct HyperConnectionLayer
    n::Int          # Expansion factor (number of streams)
    C::Int          # Feature dimension per stream
    layer::Any      # The actual computation (attention, MLP, etc.)

    # Learnable parameters for dynamic mappings.
    # Each ϕ_* is called on the flattened (n*C, batch) input in the forward pass.
    ϕ_pre::Any      # Projects flattened input to H_pre
    ϕ_post::Any     # Projects flattened input to H_post
    ϕ_res::Any      # Projects flattened input to H_res

    # Static bias terms, added by broadcasting to the scaled ϕ_* outputs
    b_pre::Any      # Added to the (batch, n) pre-mapping
    b_post::Any     # Added to the (batch, n) post-mapping
    b_res::Any      # Added to the (n, n) residual mixing matrix

    # Gating factors (initialised small for stability).
    # The forward pass reads element [1] of each, so they are one-element containers.
    α_pre::Any
    α_post::Any
    α_res::Any
end

# Register the type as a Flux layer (see the Flux `@layer` documentation for
# what this enables).
Flux.@layer HyperConnectionLayer

"""
    HyperConnectionLayer(n, C, layer; α_init=0.01f0)

Build a hyper-connection layer with `n` streams of `C` features around `layer`.

# Arguments
- `n::Int`: expansion factor (number of residual streams).
- `C::Int`: feature dimension per stream.
- `layer`: the wrapped computation, stored as given.

# Keywords
- `α_init::Float32=0.01f0`: initial value of each of the three gating factors.

All biases start at zero; each gating factor is stored in its own one-element
vector.
"""
function HyperConnectionLayer(n::Int, C::Int, layer; α_init::Float32=0.01f0)
    nC = n * C  # Flattened dimension (n streams × C features)

    return HyperConnectionLayer(
        n, C, layer,
        Dense(nC => n),             # ϕ_pre: ℝ^{nC} → ℝ^n
        Dense(nC => n),             # ϕ_post: ℝ^{nC} → ℝ^n
        Dense(nC => n * n),         # ϕ_res: ℝ^{nC} → ℝ^{n²}
        zeros(Float32, 1, n),       # b_pre
        zeros(Float32, 1, n),       # b_post
        # b_res biases the full (n, n) mixing matrix, so it needs one entry per
        # matrix element; a (1, n) bias could only shift whole columns.
        zeros(Float32, n, n),       # b_res
        [α_init],                   # α_pre
        [α_init],                   # α_post
        [α_init]                    # α_res
    )
end

# Forward pass of the (unconstrained) hyper-connection layer.
#
# Input:  x_expanded of size (n, C, batch) — n streams of C features.
# Output: array of size (n, C, batch) with
#   x_next[i, :, b] = Σ_j H_res[i, j] * x[j, :, b] + H_post[b, i] * F(h_in)[:, b]
# where h_in[:, b] = Σ_i H_pre[b, i] * x[i, :, b] and F is `hc.layer`.
#
# The computation is written with non-mutating array operations (no writes into
# preallocated buffers) so that it can be differentiated by Zygote, and so that
# the element type follows the inputs instead of being fixed to Float32.
function (hc::HyperConnectionLayer)(x_expanded)
    # x_expanded: (n, C, batch) - n streams of C features
    n, C, batch = size(x_expanded)

    # Flatten for computing dynamic mappings: (nC, batch)
    x_flat = reshape(x_expanded, n * C, batch)

    # Compute dynamic mappings (simplified - full version uses RMSNorm)
    H_pre = hc.α_pre[1] .* hc.ϕ_pre(x_flat)' .+ hc.b_pre      # (batch, n)
    H_post = hc.α_post[1] .* hc.ϕ_post(x_flat)' .+ hc.b_post  # (batch, n)

    # H_res needs reshaping: (n, n) per batch element
    H_res_flat = hc.α_res[1] .* hc.ϕ_res(x_flat)    # (n², batch)

    # For simplicity, use batch-averaged H_res (full impl is per-sample)
    H_res_mean = reshape(mean(H_res_flat, dims=2), n, n) .+ hc.b_res .+ I(n)

    # === KEY ISSUE: H_res is UNCONSTRAINED ===
    # Eigenvalues can be anything → instability across layers

    # Pre-mapping: aggregate n streams to 1 for layer input.
    # Each stream i of sample b is weighted by its own coefficient H_pre[b, i].
    pre_weights = reshape(permutedims(H_pre), n, 1, batch)               # (n, 1, batch)
    h_in = dropdims(sum(pre_weights .* x_expanded; dims=1); dims=1)      # (C, batch)

    # Apply the actual layer F
    h_out = hc.layer(h_in)  # (C, batch)

    # Residual mapping: mix existing streams. Viewing x as an (n, C*batch) matrix
    # lets one matrix product apply H_res to every feature of every sample.
    mixed = reshape(H_res_mean * reshape(x_expanded, n, C * batch), n, C, batch)

    # Post-mapping: distribute the layer output to the n streams.
    post_weights = reshape(permutedims(H_post), n, 1, batch)             # (n, 1, batch)
    x_next = mixed .+ post_weights .* reshape(h_out, 1, C, batch)

    return x_next
end

# -----------------------------------------------------------------------------
# 4. MANIFOLD-CONSTRAINED HYPER-CONNECTIONS (mHC)
# Same structure as HC, but H_res is projected onto Birkhoff polytope
#
# Key changes:
# - H_res = Sinkhorn-Knopp(H̃_res) → doubly stochastic
# - H_pre = σ(H̃_pre) → non-negative (prevents signal cancellation)
# - H_post = 2σ(H̃_post) → non-negative, scaled
#
# Why doubly stochastic works:
# 1. Spectral norm ≤ 1: ||H_res||_2 ≤ 1 → non-expansive
# 2. Closure: product of doubly stochastic is doubly stochastic
# 3. Convex combination: H_res · x is weighted average of streams
# -----------------------------------------------------------------------------

# State of a manifold-constrained hyper-connection (mHC) layer. It has the same
# fields as HyperConnectionLayer plus the number of Sinkhorn-Knopp iterations
# used when the forward pass constrains the residual mixing matrix.
# NOTE(review): every non-integer field is typed `Any`; parametric field types
# may be preferable for performance — confirm before changing, since the
# positional constructor depends on this field order.
struct ManifoldHCLayer
    n::Int          # Expansion factor (number of streams)
    C::Int          # Feature dimension per stream
    layer::Any      # The wrapped computation, called on the aggregated input

    # Projections applied to the flattened (n*C, batch) input in the forward pass
    ϕ_pre::Any
    ϕ_post::Any
    ϕ_res::Any

    # Static bias terms, added by broadcasting to the scaled ϕ_* outputs
    b_pre::Any
    b_post::Any
    b_res::Any

    # Gating factors; the forward pass reads element [1] of each
    α_pre::Any
    α_post::Any
    α_res::Any
    
    sk_iters::Int # Sinkhorn-Knopp iterations
end

# Register the type as a Flux layer (see the Flux `@layer` documentation for
# what this enables).
Flux.@layer ManifoldHCLayer

"""
    ManifoldHCLayer(n, C, layer; α_init=0.01f0, sk_iters=20)

Build a manifold-constrained hyper-connection layer with `n` streams of `C`
features around `layer`.

# Arguments
- `n::Int`: expansion factor (number of residual streams).
- `C::Int`: feature dimension per stream.
- `layer`: the wrapped computation, stored as given.

# Keywords
- `α_init::Float32=0.01f0`: initial value of each of the three gating factors.
- `sk_iters::Int=20`: number of Sinkhorn-Knopp iterations stored in the layer.

All biases start at zero; each gating factor is stored in its own one-element
vector.
"""
function ManifoldHCLayer(n::Int, C::Int, layer; α_init::Float32=0.01f0, sk_iters::Int=20)
    nC = n * C  # Flattened dimension (n streams × C features)

    return ManifoldHCLayer(
        n, C, layer,
        Dense(nC => n),             # ϕ_pre: ℝ^{nC} → ℝ^n
        Dense(nC => n),             # ϕ_post: ℝ^{nC} → ℝ^n
        Dense(nC => n * n),         # ϕ_res: ℝ^{nC} → ℝ^{n²}
        zeros(Float32, 1, n),       # b_pre
        zeros(Float32, 1, n),       # b_post
        # b_res is added to the (n, n) matrix before it is exponentiated and
        # row/column-normalised. A (1, n) bias would add one constant per column,
        # i.e. scale whole columns after exp — which column normalisation cancels —
        # so the bias needs one entry per matrix element to have any effect.
        zeros(Float32, n, n),       # b_res
        [α_init],                   # α_pre
        [α_init],                   # α_post
        [α_init],                   # α_res
        sk_iters
    )
end

# Forward pass of the manifold-constrained hyper-connection layer.
#
# Input:  x_expanded of size (n, C, batch) — n streams of C features.
# Output: array of size (n, C, batch) with
#   x_next[i, :, b] = Σ_j H_res[i, j] * x[j, :, b] + H_post[b, i] * F(h_in)[:, b]
# where h_in[:, b] = Σ_i H_pre[b, i] * x[i, :, b] and F is `mhc.layer`.
#
# The computation is written with non-mutating array operations (no writes into
# preallocated buffers) so that it can be differentiated by Zygote, and so that
# the element type follows the inputs instead of being fixed to Float32.
function (mhc::ManifoldHCLayer)(x_expanded)
    n, C, batch = size(x_expanded)

    # Flatten for computing dynamic mappings: (nC, batch)
    x_flat = reshape(x_expanded, n * C, batch)

    # Compute raw mappings
    H_pre_raw = mhc.α_pre[1] .* mhc.ϕ_pre(x_flat)' .+ mhc.b_pre      # (batch, n)
    H_post_raw = mhc.α_post[1] .* mhc.ϕ_post(x_flat)' .+ mhc.b_post  # (batch, n)
    H_res_flat = mhc.α_res[1] .* mhc.ϕ_res(x_flat)                   # (n², batch)

    # === KEY DIFFERENCE: Apply manifold constraints ===

    # Non-negativity for pre/post via sigmoid
    H_pre = sigmoid.(H_pre_raw)             # ∈ [0,1]
    H_post = 2f0 .* sigmoid.(H_post_raw)    # ∈ [0,2] for expressivity

    # Doubly stochastic constraint for residual via Sinkhorn-Knopp
    # (batch-averaged raw matrix, as in the unconstrained layer)
    H_res_mean = reshape(mean(H_res_flat, dims=2), n, n) .+ mhc.b_res
    H_res = sinkhorn_knopp(H_res_mean; max_iters=mhc.sk_iters)

    # Now H_res has:
    # - All entries ≥ 0
    # - Each row sums to 1 → output is convex combination
    # - Each column sums to 1 → gradients are bounded
    # - Spectral norm ≤ 1 → non-expansive

    # Pre-mapping with constrained coefficients: weight stream i of sample b
    # by H_pre[b, i] and sum over the streams.
    pre_weights = reshape(permutedims(H_pre), n, 1, batch)               # (n, 1, batch)
    h_in = dropdims(sum(pre_weights .* x_expanded; dims=1); dims=1)      # (C, batch)

    h_out = mhc.layer(h_in)  # (C, batch)

    # Constrained residual mixing. Viewing x as an (n, C*batch) matrix lets one
    # matrix product apply H_res to every feature of every sample.
    mixed = reshape(H_res * reshape(x_expanded, n, C * batch), n, C, batch)

    # Post-mapping: distribute the layer output to the n streams.
    post_weights = reshape(permutedims(H_post), n, 1, batch)             # (n, 1, batch)
    x_next = mixed .+ post_weights .* reshape(h_out, 1, C, batch)

    return x_next
end

# -----------------------------------------------------------------------------
# 5. UTILITY: Measure Signal/Gradient Gain
# This is what the paper measures in Fig. 3 and Fig. 7
#
# Forward signal gain: max absolute row sum of H_res (how much signal amplifies)
# Backward gradient gain: max absolute column sum (how much gradient amplifies)
#
# For stable training, both should be close to 1
# -----------------------------------------------------------------------------

"""
    measure_gain(H::AbstractMatrix)

Return a named tuple `(forward, backward)` where `forward` is the largest
absolute row sum of `H` and `backward` is the largest absolute column sum.
"""
function measure_gain(H::AbstractMatrix)
    magnitudes = abs.(H)

    # Summing along dims=2 collapses columns → one total per row (signal x → H⋅x).
    row_totals = sum(magnitudes, dims=2)
    # Summing along dims=1 collapses rows → one total per column (gradient g → H^T⋅g).
    col_totals = sum(magnitudes, dims=1)

    return (forward=maximum(row_totals), backward=maximum(col_totals))
end

"""
    mesaure_composite_gain(matrices::Vector{<:AbstractMatrix})

Multiply the matrices together, each later matrix on the left of the running
product (`matrices[end] * … * matrices[2] * matrices[1]`), and return
`measure_gain` of that product.

NOTE(review): the function name is misspelled ("mesaure"); it is kept as is
because callers in this file use this spelling.
"""
function mesaure_composite_gain(matrices::Vector{<:AbstractMatrix})
    # Composite mapping across layers: Π H_i, seeded with the first matrix.
    first_matrix = matrices[1]
    later_matrices = Iterators.drop(matrices, 1)
    composite = foldl((acc, H) -> H * acc, later_matrices; init=first_matrix)
    return measure_gain(composite)
end

# -----------------------------------------------------------------------------
# 6. DEMONSTRATION: Compare stability of random vs doubly stochastic
# -----------------------------------------------------------------------------

"""
    demo_stability_comparison(; n=4, num_layers=30)

Print the single-layer and composite gains of `num_layers` random `n`×`n`
matrices, first unconstrained (standard-normal entries) and then after passing
each through `sinkhorn_knopp`, followed by a doubly-stochastic check of the first
constrained matrix.

# Keywords
- `n::Int=4`: matrix size (expansion factor).
- `num_layers::Int=30`: number of matrices in each composite product; also used
  in the printed "-layer composite" label.

# Returns
The tuple `(hc_matrices, mhc_matrices)` of the two generated matrix vectors.

# Throws
- `ArgumentError` if `n` or `num_layers` is smaller than 1.
"""
function demo_stability_comparison(; n::Int=4, num_layers::Int=30)
    n >= 1 || throw(ArgumentError("n must be at least 1, got " * string(n)))
    num_layers >= 1 ||
        throw(ArgumentError("num_layers must be at least 1, got " * string(num_layers)))

    two_digits(x) = round(x, digits=2)

    # Print the gain of the first matrix and of the product of all matrices.
    function report_gains(matrices)
        single_gain = measure_gain(matrices[1])
        println("Single layer gain: forward=$(two_digits(single_gain.forward)), backward=$(two_digits(single_gain.backward))")

        composite_gain = mesaure_composite_gain(matrices)
        println("$(num_layers)-layer composite: forward=$(two_digits(composite_gain.forward)), backward=$(two_digits(composite_gain.backward))")
        return nothing
    end

    println("\n" * "="^60)
    println("STABILITY COMPARISON: Random vs Doubly Stchastic")
    println("="^60)

    # Generate random unconstrained matrices (like HC)
    println("\n--- unconstrained (HC-style) ---")
    hc_matrices = [randn(Float32, n, n) for _ in 1:num_layers]
    report_gains(hc_matrices)

    # Generate doubly stochastic matrices (like mHC)
    println("\n--- Doubly stochastic (mHC-style) ---")
    mhc_matrices = [sinkhorn_knopp(randn(Float32, n, n)) for _ in 1:num_layers]
    report_gains(mhc_matrices)

    # Verify doubly stochastic property
    println("\n--- Verifying doubly stochastic property ---")
    check_doubly_stochastic(mhc_matrices[1])

    return hc_matrices, mhc_matrices
end

demo_stability_comparison()


STABILITY COMPARISON: Random vs Doubly Stchastic

--- unconstrained (HC-style) ---
Single layer gain: forward=4.14, backward=5.4
30-layer composite: forward=1.4998611e8, backward=8.664067e7

--- Doubly stochastic (mHC-style) ---
Single layer gain: forward=1.0, backward=1.0
30-layer composite: forward=1.0, backward=1.0

--- Verifying doubly stochastic property ---
Row sums: Float32[1.0, 1.0, 1.0, 1.0]
Col sums: Float32[1.0, 1.0, 1.0, 1.0]
All non-negative: true
Doubly stochastic: true


(Matrix{Float32}[[-0.8730346 -1.3453312 -1.7622908 0.15638407; 1.7752821 0.017110074 0.994798 0.7747467; 1.3025111 -0.45908275 0.11434036 -1.9253013; 1.4473243 0.75216186 1.21052 -0.5488745], [1.3129697 1.7265815 2.1018589 -0.950145; -0.0031080034 1.4747337 -0.052296303 -1.8700589; 2.1753619 -1.0943469 1.4916992 -0.030708117; -1.6087803 0.12747055 -1.528769 -0.3932202], [0.872154 0.72785246 -0.22621523 -0.22506519; 0.44933677 0.64087397 0.30742508 -0.08029394; 0.23011298 -3.1592748 1.1446458 0.25878134; -0.22762509 0.20757219 0.85346437 1.4071758], [-2.0113664 -0.51709193 -0.74163526 -1.584318; -1.7815436 -0.075148694 1.9731529 0.17927128; 2.293176 0.17538112 -0.778651 -0.008514785; 0.121242985 -0.44671312 0.013792434 -2.36872], [-0.02950247 -0.058637068 -1.223052 1.1240518; -0.6855741 -0.85295117 -0.45500576 -0.54721713; -0.3468072 1.3532914 -0.054172315 0.06630569; 0.86786246 -0.2813954 1.9008653 0.43101868], [-0.9414279 0.25615162 0.63966703 0.40906414; -2.2534947 -0.4334738 0.29875