# Table of contents
1. [Introduction](#intro)
    1. [Setup](#setup)
2. [Methods](#methods)
    1. [Original tangent approximation](#org-tangent)
    2. [Adapted tangent approximation](#adapted-tangent)
    3. [Finite difference method](#finite-diff)
3. [Benchmarking](#benchmarking)

# Introduction: Differentiation of the matrix exponential for weighted GNMs <a name="intro"></a>
From: https://www.biorxiv.org/content/10.1101/2023.06.23.546237v1?med=mas.

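This notebook is mainly concerned with finding the fastest way to compute the derivative of the matrix exponential, which appears in the objective of weighted generative network models (GNMs):
$$
f(w_{i,j}) = (c_{i,j} \cdot d_{i,j})^{\omega}
$$
Here $c_{i,j}$ is the communicability between nodes $i$ and $j$, $d_{i,j}$ is the distance between them, and $\omega$ is an exponent. Communicability captures the proportion of signals propagating randomly from node $i$ that would reach node $j$ over an infinite time horizon. For a weighted network it is defined through the strength-normalised matrix exponential
$$
C = e^{S^{-1/2} W S^{-1/2}}, \qquad S = \operatorname{diag}(s_1, \dots, s_n), \quad s_i = \sum_{j} w_{i,j},
$$
with $c_{i,j} = [C]_{i,j}$.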
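The strength-normalised communicability can be evaluated directly with LinearAlgebra's matrix `exp`. This is a minimal sketch, assuming every node has positive strength $s_i > 0$; the function `communicability`, the distance matrix `dist` and the exponent `ω` are hypothetical names and values chosen only to illustrate the formulas, not values from the paper.

```julia
using LinearAlgebra

# Strength-normalised communicability C = exp(S^{-1/2} W S^{-1/2}),
# with S = diag(s) and node strengths s_i = Σ_j w_ij.
# Assumes every node has positive strength (no isolated nodes).
function communicability(W::AbstractMatrix{<:Real})
    s = vec(sum(W, dims = 2))         # node strengths
    D = Diagonal(1 ./ sqrt.(s))       # S^{-1/2}
    return exp(Matrix(D * W * D))
end

W = [0.0 0.8 0.0;
     0.8 0.0 0.2;
     0.0 0.2 0.0]
C = communicability(W)

# Hypothetical distances and exponent, only to illustrate f(w_ij) = (c_ij * d_ij)^ω
dist = [0.0 1.0 2.0;
        1.0 0.0 1.0;
        2.0 1.0 0.0]
ω = 0.9
F = (C .* dist) .^ ω                  # element-wise, one value per node pair
```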

In this demonstration, we drop the strength normalisation and work directly with the matrix exponential:
$$
f(W)=e^{W}
$$

## Setup <a name="setup"></a>

In [42]:
using Polynomials
using ForwardDiff
using ExponentialUtilities
using LinearAlgebra
using BenchmarkTools

In [43]:
W = [0.0 0.8 0.0;
     0.8 0.0 0.2; 
     0.0 0.2 0.0]
demo_edge = [CartesianIndex(1, 2)]
W

3×3 Matrix{Float64}:
 0.0  0.8  0.0
 0.8  0.0  0.2
 0.0  0.2  0.0

# Methods <a name="methods"></a>

## Original tangent approximation <a name="org-tangent"></a>

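Here we compute the derivative for a single edge with the tangent approximation, which is the approach used in the Akarca paper:
the edge weight is nudged over a range of values, the total communicability (the sum of all entries of $e^{W}$) is recorded for each nudge, and the slope of a first-order polynomial fit is taken as the derivative.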

**Problem 1**
As far as I understand, there is a small but significant error in line 197 of the implementation at https://github.com/DanAkarca/weighted_generative_models/blob/main/weighted_generative_model.m.
When fitting the first-order polynomial, the independent variable should be the actual values the edge takes on, not the range of nudges that is used there (reproduced below as `x = 1:length(reps)`).

**Problem 2**
In the same implementation, the derivative is approximated with the tangent approximation after nudging both edge (i,j) and edge (j,i).
In effect, we are computing the derivative with respect to (i,j) and (j,i) at once.
First, this creates computational overhead, since the derivatives (i.e., the Jacobian) are always symmetric anyway.
Second, it causes a problem when we compute the gradient, which we currently simply take to be this derivative:
we then update both elements of W with the same gradient, namely the derivative with respect to nudging (i,j) and (j,i) together, so our update to the objective function overshoots.

See line 208 and 208.
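Problem 2 can be made concrete. For a symmetric $W$, the gradient of $\sum_{k,l}[e^{W}]_{k,l}$ with respect to $W$ is itself symmetric, so the derivative for a simultaneous nudge of $(i,j)$ and $(j,i)$ is exactly twice the derivative with respect to $w_{i,j}$ alone. That factor of 2 is what the `* 2` and `.* 2` in the cells below account for. A small numerical check with central differences, using only LinearAlgebra; the helpers `total_comm`, `d_single` and `d_symmetric` are hypothetical:

```julia
using LinearAlgebra

# Total communicability: sum of all entries of exp(W)
total_comm(W) = sum(exp(W))

# Central difference w.r.t. the single entry W[i, j] (hypothetical helper)
function d_single(W, i, j; h = 1e-5)
    Wp = copy(W); Wp[i, j] += h
    Wm = copy(W); Wm[i, j] -= h
    return (total_comm(Wp) - total_comm(Wm)) / (2h)
end

# Central difference w.r.t. nudging W[i, j] and W[j, i] together (hypothetical helper)
function d_symmetric(W, i, j; h = 1e-5)
    Wp = copy(W); Wp[i, j] += h; Wp[j, i] += h
    Wm = copy(W); Wm[i, j] -= h; Wm[j, i] -= h
    return (total_comm(Wp) - total_comm(Wm)) / (2h)
end

W = [0.0 0.8 0.0;
     0.8 0.0 0.2;
     0.0 0.2 0.0]

d_sym = d_symmetric(W, 1, 2)
d_one = d_single(W, 1, 2)
println((d_sym, 2 * d_one))   # the two values agree up to finite-difference error
```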

In [44]:
function paper_tangent_approx(W, edges::Vector{CartesianIndex{2}}, resolution = 0.01, min = -0.25, max = 0.25)
    results = zeros(length(edges))
    rep_vec = collect(min:resolution:max)

    for (i_edge, edge_idx) in enumerate(edges)
        
        # Points for evaluation
        edge_val = W[edge_idx]
        reps = [edge_val * (1 + i) for i in rep_vec]

        # For each nudge, record the total communicability sum(exp(W))
        sum_comm = zeros(length(reps))
        for (i_rep, rep) in enumerate(reps)
            W_copy = copy(W)
            # Nudge (i,j) and (j,i) together (Problem 2)
            W_copy[edge_idx] = W_copy[edge_idx[2], edge_idx[1]] = rep
            comm = exp(W_copy)
            sum_comm[i_rep] = sum(comm)
        end

        # Line 197 in the MATLAB code: the fit uses the nudge index as x (Problem 1)
        x = 1:length(reps)
        # fit(...)[1] is the coefficient of x^1, i.e. the slope
        results[i_edge] = fit(x, sum_comm, 1)[1]
    end

    return results
end

paper_tangent_approx(W, demo_edge, 0.01), paper_tangent_approx(W, demo_edge, 0.2)

([0.038767237882607476], [0.7456781939183535])
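The effect of Problem 1 is a pure rescaling. Consecutive nudges are `edge_val * resolution` apart, so the slope fitted against the nudge index equals the slope fitted against the edge values times that spacing: for the demo edge the spacing is $0.8 \times 0.01 = 0.008$, and $4.8459 \times 0.008 \approx 0.03877$, matching the first value above (4.8459 is the result of the adapted tangent approximation below). The following sketch reproduces both fits with only LinearAlgebra; `ls_slope` is a hypothetical least-squares stand-in for `fit(x, y, 1)` from Polynomials.

```julia
using LinearAlgebra

# Least-squares slope b of the first-order fit y ≈ a + b * x
ls_slope(x, y) = (hcat(ones(length(x)), x) \ y)[2]

W = [0.0 0.8 0.0;
     0.8 0.0 0.2;
     0.0 0.2 0.0]
edge_val = W[1, 2]
resolution = 0.01
reps = [edge_val * (1 + r) for r in -0.25:resolution:0.25]

# Total communicability for each symmetric nudge of the (1, 2) edge
sum_comm = map(reps) do rep
    Wc = copy(W)
    Wc[1, 2] = Wc[2, 1] = rep
    sum(exp(Wc))
end

slope_index = ls_slope(1:length(reps), sum_comm)  # x = nudge index
slope_value = ls_slope(reps, sum_comm)            # x = edge value
dx = edge_val * resolution                        # spacing between consecutive nudges
println((slope_index, slope_value, slope_index / dx))
```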

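## Adapted tangent approximation <a name="adapted-tangent"></a>

Same as the original tangent approximation, except that the first-order polynomial is fitted against the values the edge actually takes on (`reps`) instead of the nudge index, which addresses Problem 1. The simultaneous nudge of (i,j) and (j,i) from Problem 2 is kept.
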
In [45]:
function tangent_approx(W, edges::Vector{CartesianIndex{2}}, 
    resolution = 0.01, min = -0.25, max = 0.25)::Vector{Float64}
    results = zeros(length(edges))
    rep_vec = collect(min:resolution:max)

    for (i_edge, edge_idx) in enumerate(edges)     
        # Points for evaluation
        edge_val = W[edge_idx]
        reps = [edge_val * (1 + i) for i in rep_vec]

        # For each nudge, record the total communicability sum(exp(W))
        sum_comm = zeros(length(reps))
        for (i_rep, rep) in enumerate(reps)
            W_copy = copy(W)
            W_copy[edge_idx] = W_copy[edge_idx[2], edge_idx[1]] = rep
            comm = exp(W_copy)
            sum_comm[i_rep] = sum(comm)
        end

        # Fit against the edge values themselves; the slope is the coefficient of x^1
        results[i_edge] = fit(reps, sum_comm, 1)[1]
    end

    return results
end

tangent_approx(W, demo_edge, 0.01), tangent_approx(W, demo_edge, 0.1)

([4.8459047353259335], [4.851647952441173])

## Finite difference method <a name="finite-diff"></a>
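Compute the approximate derivative of the total communicability $\sum_{k,l}[e^{W}]_{k,l}$ with respect to a single entry $w_{i,j}$ using a central finite difference with step $\delta$, whose truncation error is $O(\delta^2)$.
At the moment, this is significantly faster than the tangent approximation.
It is mainly implemented as a starting point for a possible extension to the complex-step approximation.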

References:<br>
https://en.wikipedia.org/wiki/Finite_difference_method

In [46]:
function finite_diff(W, edges::Vector{CartesianIndex{2}}, delta::Float64)::Vector{Float64}
    results = zeros(length(edges))
    for (i_edge, edge_idx) in enumerate(edges) 
        # Evaluate the function at two nearby points
        W_copy = copy(W)
        W_copy[edge_idx]  = W_copy[edge_idx] + delta
        f_plus_delta = sum(exp(W_copy))
        
        W_copy[edge_idx]  = W_copy[edge_idx] - 2*delta
        f_minus_delta = sum(exp(W_copy))

        # Calculate the derivative approximation
        results[i_edge] = (f_plus_delta - f_minus_delta) / (2 * delta)
    end

    return results
end


# ×2: derivative for a symmetric nudge of (i,j) and (j,i), matching the tangent approximation
finite_diff(W, demo_edge, 0.01) * 2, finite_diff(W, demo_edge, 0.2) * 2

([4.826503769214696], [4.827101432180703])
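The complex-step extension mentioned above could look like the following minimal sketch. It perturbs one entry by $ih$ and reads the derivative from the imaginary part, $\operatorname{Im} f(w_{i,j} + ih)/h = f'(w_{i,j}) + O(h^2)$, so no nearly equal values are subtracted and $h$ can be tiny. It assumes that `exp` on a complex matrix is evaluated by an analytic formula such as Padé-based scaling and squaring (the standard requirement for the complex-step method); `complex_step` and `central_diff` are hypothetical helpers.

```julia
using LinearAlgebra

# Complex-step derivative of sum(exp(W)) w.r.t. the single entry W[i, j]
# (hypothetical helper). The derivative is read from the imaginary part,
# so there is no subtractive cancellation and h can be very small.
function complex_step(W::Matrix{Float64}, i::Int, j::Int; h::Float64 = 1e-20)
    Wc = complex.(W)       # ComplexF64 copy of W
    Wc[i, j] += im * h     # purely imaginary nudge of one entry
    return imag(sum(exp(Wc))) / h
end

# Central difference for comparison (truncation error O(δ^2))
function central_diff(W::Matrix{Float64}, i::Int, j::Int; δ::Float64 = 1e-5)
    Wp = copy(W); Wp[i, j] += δ
    Wm = copy(W); Wm[i, j] -= δ
    return (sum(exp(Wp)) - sum(exp(Wm))) / (2δ)
end

W = [0.0 0.8 0.0;
     0.8 0.0 0.2;
     0.0 0.2 0.0]

d_cs = complex_step(W, 1, 2)
d_cd = central_diff(W, 1, 2)
println((2 * d_cs, 2 * d_cd))   # doubled, as in the finite-difference cell
```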

## Forward Differentiation
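`ForwardDiff.jacobian` returns the Jacobian in numerator layout: rows index the entries of the vectorised output and columns index the entries of the vectorised input, i.e., the input dimension is added on the right.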

In [47]:
function forward_diff_j(W::Matrix{Float64}, edges::Vector{CartesianIndex{2}})::Vector{Float64}
    # Jacobian columns follow the column-major order of vec(W)
    lin_idx = LinearIndices(W)

    diff_exp(W) = exponential!(copyto!(similar(W), W), ExpMethodGeneric())
    J = ForwardDiff.jacobian(diff_exp, W)

    results = zeros(length(edges))
    tangent = vec(permutedims(exp(W), [2, 1]))
    for (i_edge, edge) in enumerate(edges)
        # Column of partial derivatives of vec(exp(W)) with respect to W[edge]
        Jₓ = J[:, lin_idx[edge]]
        results[i_edge] = dot(Jₓ, tangent)
    end

    return results .* 2
end

forward_diff_j(W, demo_edge)

1-element Vector{Float64}:
 4.8612290825962985
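`ForwardDiff.jacobian` flattens its input with `vec`, which is column-major in Julia, so the Jacobian column for edge (i, j) sits at the linear index $i + (j-1)n$, i.e. `LinearIndices(W)[i, j]`. Sorting `CartesianIndices` by row produces row-major positions instead; for (1, 2) that is position 2, which as a column-major index refers to (2, 1). The sketch below swaps ForwardDiff for a hypothetical central-difference helper `fd_jacobian` so that it only needs LinearAlgebra. It shows both positions and checks that, for a symmetric W and the symmetric tangent exp(W), the two columns differ but give the same dot product, so the demo value is unaffected either way.

```julia
using LinearAlgebra

W = [0.0 0.8 0.0;
     0.8 0.0 0.2;
     0.0 0.2 0.0]
edge = CartesianIndex(1, 2)

# Column-major position of W[1, 2] in vec(W): i + (j - 1) * n = 1 + 1 * 3
col_major = LinearIndices(W)[edge]                   # 4

# Position of the same index after sorting CartesianIndices by row (row-major order)
row_sorted = sort(vec(collect(CartesianIndices(W))), by = x -> x[1])
row_major = findfirst(==(edge), row_sorted)          # 2, i.e. vec position of W[2, 1]

# Central-difference Jacobian of vec(f(W)) w.r.t. vec(W) in numerator layout:
# rows = outputs, columns = inputs (hypothetical helper standing in for ForwardDiff)
function fd_jacobian(f, W; h = 1e-6)
    J = zeros(length(f(W)), length(W))
    for k in eachindex(W)                # linear, column-major index into W
        Wp = copy(W); Wp[k] += h
        Wm = copy(W); Wm[k] -= h
        J[:, k] = vec(f(Wp) - f(Wm)) ./ (2h)
    end
    return J
end

J = fd_jacobian(exp, W)
tangent = vec(exp(W))                    # exp(W) is symmetric, so this equals vec(exp(W)')
println((dot(J[:, col_major], tangent), dot(J[:, row_major], tangent)))
```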

## Forward Differentiation with Jacobian vector product

In [48]:
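function forward_diff_jvp(W::Matrix{Float64}, edges::Vector{CartesianIndex{2}})::Vector{Float64}
    # Direction in which exp is differentiated
    tangent = exp(W)
    diff_exp(W) = exponential!(copyto!(similar(W), W), ExpMethodGeneric())
    # d/dt exp(W + t * tangent) at t = 0 is the Jacobian-vector product, i.e. the
    # Fréchet derivative L(W, tangent), obtained without forming the n² × n² Jacobian
    g(t) = diff_exp(W + t * tangent)
    JVP = ForwardDiff.derivative(g, 0.0)
    return JVP[edges] .* 2
end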

forward_diff_jvp(W, demo_edge)

1-element Vector{Float64}:
 4.861229082596298

## Fréchet: Block enlarge

In [49]:
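function frechet_block_enlarge(W::Matrix{Float64}, edges::Vector{CartesianIndex{2}})::Vector{Float64}
    E = exp(W)  # direction E of the Fréchet derivative L(W, E)
    n = size(W, 1)
    # exp([W E; 0 W]) = [exp(W) L(W, E); 0 exp(W)], so L(W, E) is the top-right block
    M = [W E; zeros(size(W)) W]
    expm_M = exp(M)
    frechet_AE = expm_M[1:n, n+1:end]
    return frechet_AE[edges] .* 2
end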

frechet_block_enlarge(W, demo_edge)

1-element Vector{Float64}:
 4.861229082596298
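A remark on what the Fréchet and forward-mode cells return: they all use the direction $E = e^{W}$, which commutes with $W$. For commuting $W$ and $E$ the Fréchet derivative reduces to $L(W, E) = E\,e^{W}$, so here $L(W, e^{W}) = e^{2W}$ and the value above equals $2\,[e^{2W}]_{1,2}$. The derivative of the total communicability $\sum_{k,l}[e^{W}]_{k,l}$ with respect to $w_{i,j}$ corresponds instead to the all-ones direction, $[L(W, \mathbf{1}\mathbf{1}^\top)]_{i,j}$ for symmetric $W$, which is what the finite-difference cell estimates (about 4.8265 after doubling). The minimal sketch below checks both identities on the demo matrix with only LinearAlgebra's `exp`; `frechet_block` is a hypothetical helper using the same block-triangular identity as `frechet_block_enlarge`.

```julia
using LinearAlgebra

# Fréchet derivative L(A, E) of exp via the block-triangular identity
# exp([A E; 0 A]) = [exp(A) L(A, E); 0 exp(A)]  (hypothetical helper)
function frechet_block(A::Matrix{Float64}, E::Matrix{Float64})
    n = size(A, 1)
    M = [A E; zeros(n, n) A]
    return exp(M)[1:n, n+1:end]
end

W = [0.0 0.8 0.0;
     0.8 0.0 0.2;
     0.0 0.2 0.0]
n = size(W, 1)

# Direction used in the cells above: E = exp(W), which commutes with W
L_expW = frechet_block(W, exp(W))

# All-ones direction: entry (i, j) is ∂ sum(exp(W)) / ∂ W[i, j] for symmetric W
L_ones = frechet_block(W, ones(n, n))

# Central difference of sum(exp(W)) w.r.t. the single entry W[1, 2]
h = 1e-5
Wp = copy(W); Wp[1, 2] += h
Wm = copy(W); Wm[1, 2] -= h
fd = (sum(exp(Wp)) - sum(exp(Wm))) / (2h)

println((2 * L_expW[1, 2], 2 * exp(2W)[1, 2]))  # both ≈ 4.8612
println((2 * L_ones[1, 2], 2 * fd))              # both ≈ 4.8265
```

Which direction is appropriate depends on the quantity being differentiated; for the total communicability used by the tangent and finite-difference methods it would be the all-ones matrix.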

## Fréchet: Scaling-Padé-Squaring


References:<br>
Higham (2008), *Functions of Matrices: Theory and Computation*, Chapter 10, Algorithm 10.20 <br>
https://docs.scipy.org/doc/scipy/reference/generated/scipy.linalg.expm_frechet.html#scipy.linalg.expm_frechet <br>
https://rdrr.io/cran/expm/man/expmFrechet.html

In [50]:
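# Degree-m Padé helpers: return U, V with r_m(A) = (V - U) \ (V + U) ≈ exp(A),
# plus their Fréchet derivatives Lu, Lv in direction E (same structure as SciPy's expm_frechet)
function _diff_pade3(A, E, ident)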
    b = (120.0, 60.0, 12.0, 1.0)
    A2 = A * A
    M2 = A * E + E * A
    U = A * (b[4] * A2 + b[2] * ident)
    V = b[3] * A2 + b[1] * ident
    Lu = A * (b[4] * M2) + E * (b[4] * A2 + b[2] * ident)
    Lv = b[3] * M2
    return U, V, Lu, Lv
end
    
function _diff_pade5(A, E, ident)
    b = (30240.0, 15120.0, 3360.0, 420.0, 30.0, 1.0)
    A2 = A * A
    M2 = A * E + E * A
    A4 = A2 * A2
    M4 = A2 * M2 + M2 * A2
    U = A * (b[6] * A4 + b[4] * A2 + b[2] * ident)
    V = b[5] * A4 + b[3] * A2 + b[1] * ident
    Lu = A * (b[6] * M4 + b[4] * M2) + E * (b[6] * A4 + b[4] * A2 + b[2] * ident)
    Lv = b[5] * M4 + b[3] * M2
    return U, V, Lu, Lv
end

function _diff_pade7(A, E, ident)
    b = (17297280.0, 8648640.0, 1995840.0, 277200.0, 25200.0, 1512.0, 56.0, 1.0)
    A2 = A * A
    M2 = A * E + E * A
    A4 = A2 * A2
    M4 = A2 * M2 + M2 * A2
    A6 = A2 * A4
    M6 = A4 * M2 + M4 * A2
    U = A * (b[8] * A6 + b[6] * A4 + b[4] * A2 + b[2] * ident)
    V = b[7] * A6 + b[5] * A4 + b[3] * A2 + b[1] * ident
    Lu = A * (b[8] * M6 + b[6] * M4 + b[4] * M2) + E * (b[8] * A6 + b[6] * A4 + b[4] * A2 + b[2] * ident)
    Lv = b[7] * M6 + b[5] * M4 + b[3] * M2
    return U, V, Lu, Lv
end

function _diff_pade9(A, E, ident)
    b = (17643225600.0, 8821612800.0, 2075673600.0, 302702400.0, 30270240.0, 2162160.0, 110880.0, 3960.0, 90.0, 1.0)
    A2 = A * A
    M2 = A * E + E * A
    A4 = A2 * A2
    M4 = A2 * M2 + M2 * A2
    A6 = A2 * A4
    M6 = A4 * M2 + M4 * A2
    A8 = A4 * A4
    M8 = A4 * M4 + M4 * A4
    U = A * (b[10] * A8 + b[8] * A6 + b[6] * A4 + b[4] * A2 + b[2] * ident)
    V = b[9] * A8 + b[7] * A6 + b[5] * A4 + b[3] * A2 + b[1] * ident
    Lu = A * (b[10] * M8 + b[8] * M6 + b[6] * M4 + b[4] * M2) + E * (b[10] * A8 + b[8] * A6 + b[6] * A4 + b[4] * A2 + b[2] * ident)
    Lv = b[9] * M8 + b[7] * M6 + b[5] * M4 + b[3] * M2
    return U, V, Lu, Lv
end

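# 1-norm thresholds for each Padé degree m, kept in the 0-based layout of the
# reference table: the bound for degree m is at Julia position m + 1
ell_table_61 = (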
        nothing,
        2.11e-8,
        3.56e-4,
        1.08e-2,
        6.49e-2,
        2.00e-1,
        4.37e-1,
        7.83e-1,
        1.23e0,
        1.78e0,
        2.42e0,
        # 11
        3.13e0,
        3.90e0,
        4.74e0,
        5.63e0,
        6.56e0,
        7.52e0,
        8.53e0,
        9.56e0,
        1.06e1,
        1.17e1)

(nothing, 2.11e-8, 0.000356, 0.0108, 0.0649, 0.2, 0.437, 0.783, 1.23, 1.78, 2.42, 3.13, 3.9, 4.74, 5.63, 6.56, 7.52, 8.53, 9.56, 10.6, 11.7)

In [51]:
function frechet_algo(A::Matrix{Float64}, edges::Vector{CartesianIndex{2}})::Vector{Float64}
    E = exp(A)
    n = size(A, 1)
    s = nothing
    ident = Matrix{Float64}(I, n, n)
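    # Induced 1-norm (maximum absolute column sum), as in the reference algorithm;
    # norm(A, 1) would instead sum the absolute values of all entries
    A_norm_1 = opnorm(A, 1)
    m_pade_pairs = [
        (3, _diff_pade3),
        (5, _diff_pade5),
        (7, _diff_pade7),
        (9, _diff_pade9)
    ]
    # Declared outside the loop so that values assigned inside it remain visible afterwards
    local U, V, Lu, Lv
    for (m, pade) in m_pade_pairs
        # ell_table_61 keeps the 0-based layout of the reference table: degree m is at position m + 1
        if A_norm_1 <= ell_table_61[m + 1]
            U, V, Lu, Lv = pade(A, E, ident)
            s = 0
            break
        end
    end
    if s === nothing
        # scaling (the threshold for degree 13 is at position 14)
        s = max(0, ceil(Int, log2(A_norm_1 / ell_table_61[14])))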
        A *= 2.0^-s
        E *= 2.0^-s
        # pade order 13
        A2 = A * A
        M2 = A * E + E * A
        A4 = A2 * A2
        M4 = A2 * M2 + M2 * A2
        A6 = A2 * A4
        M6 = A4 * M2 + M4 * A2
        b = (64764752532480000., 32382376266240000., 7771770303897600.,
                1187353796428800., 129060195264000., 10559470521600.,
                670442572800., 33522128640., 1323241920., 40840800., 960960.,
                16380., 182., 1.)
        W1 = b[14] * A6 + b[12] * A4 + b[10] * A2
        W2 = b[8] * A6 + b[6] * A4 + b[4] * A2 + b[2] * ident
        Z1 = b[13] * A6 + b[11] * A4 + b[9] * A2
        Z2 = b[7] * A6 + b[5] * A4 + b[3] * A2 + b[1] * ident
        W = A6 * W1 + W2
        U = A * W
        V = A6 * Z1 + Z2
        Lw1 = b[14] * M6 + b[12] * M4 + b[10] * M2
        Lw2 = b[8] * M6 + b[6] * M4 + b[4] * M2
        Lz1 = b[13] * M6 + b[11] * M4 + b[9] * M2
        Lz2 = b[7] * M6 + b[5] * M4 + b[3] * M2
        Lw = A6 * Lw1 + M6 * W1 + Lw2
        Lu = A * Lw + E * W
        Lv = A6 * Lz1 + M6 * Z1 + Lz2
    end
    # factor once and solve twice
    lu_piv = lu(-U + V)
    R = lu_piv \ (U + V)
    L = lu_piv \ (Lu + Lv + ((Lu - Lv) * R))
    # squaring: L(2X, 2E) = R L + L R with R = exp(X), then R <- R^2 = exp(2X)
    for k in 1:s
        L = R * L + L * R
        R = R * R
    end
    return L[edges] .* 2
end

frechet_algo(W, demo_edge)

1-element Vector{Float64}:
 4.861229082596298
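The squaring phase rests on differentiating $e^{2X} = (e^{X})^2$, which gives $L(2X, 2E) = e^{X} L(X, E) + L(X, E)\, e^{X}$; each squaring step therefore has to update both $L$ and $R = e^{X}$. As a sanity check of this recurrence only (not of the Padé approximants), the sketch below replaces the Padé step by the block-triangular identity on the scaled matrices and compares the result with the unscaled block identity. `frechet_ref` and `frechet_scaled` are hypothetical helpers; the test matrix is the demo W scaled by 4, and the number of squarings s = 3 is chosen by hand.

```julia
using LinearAlgebra

# Reference Fréchet derivative of exp via the block-triangular identity (hypothetical helper)
function frechet_ref(A::Matrix{Float64}, E::Matrix{Float64})
    n = size(A, 1)
    return exp([A E; zeros(n, n) A])[1:n, n+1:end]
end

# Scaling and squaring for (exp(A), L(A, E)). For illustration, the scaled
# problem is solved with frechet_ref instead of a Padé approximant.
function frechet_scaled(A::Matrix{Float64}, E::Matrix{Float64}, s::Int)
    X, F = A / 2^s, E / 2^s
    R = exp(X)               # exp(A / 2^s)
    L = frechet_ref(X, F)    # L(A / 2^s, E / 2^s)
    for _ in 1:s
        L = R * L + L * R    # L(2X, 2F) = exp(X) L(X, F) + L(X, F) exp(X)
        R = R * R            # exp(2X) = exp(X)^2
    end
    return R, L
end

A = 4 .* [0.0 0.8 0.0;
          0.8 0.0 0.2;
          0.0 0.2 0.0]
E = ones(3, 3)
R, L = frechet_scaled(A, E, 3)
```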

# Benchmarking <a name="benchmarking"></a>

In [58]:
function init_sparse_matrix(n = 100, density = 0.2)
    # Random symmetric weights in dense storage; roughly a `density` fraction of entries is non-zero
    W = zeros(n, n)
    for i in 1:n, j in 1:n
        if rand() < (density / 2)
            W[i, j] = W[j,i] = rand()
        end
    end
    return W
end

W_bench = init_sparse_matrix(50, 0.10)
edges_bench = findall(x -> x != 0, W_bench)

@btime tangent_approx($W_bench, $edges_bench, 0.01) 
@btime finite_diff($W_bench, $edges_bench, 0.1) 
@btime forward_diff_j($W_bench, $edges_bench) 
@btime forward_diff_jvp($W_bench, $edges_bench) 
@btime frechet_block_enlarge($W_bench, $edges_bench) 
@btime frechet_algo($W_bench, $edges_bench) ;

  2.605 s (264026 allocations: 1.73 GiB)


  84.730 ms (7317 allocations: 57.23 MiB)


  1.717 s (5867 allocations: 311.91 MiB)


  1.711 ms (39 allocations: 432.30 KiB)


  3.319 ms (43 allocations: 726.62 KiB)


  575.458 μs (221 allocations: 1.80 MiB)


# Extension to full objective function