# Introduction: Differentiation of matrix exp for weighted GNMs <a name="intro"></a>
From: https://www.biorxiv.org/content/10.1101/2023.06.23.546237v1?med=mas.

This is mostly concerned with finding the fastest way to compute the derivative of the matrix exponential.
$$
f(w_{i,j}) = (c_{i,k} * d_{i,j})^{\omega}
$$
Where, the communicability matrix is defined as follows. This captures the proportion of signals, that propagate randomly from node i would reach node j over an infinite time horizon.
$$
c_{i,j} = e^{s^{-1/2}w_{i,j}s^{-1/2}}
$$

Finally, the update rule is given as:
$$
w_{i,j} = w_{i,j} - \alpha f'(w_{i,j})
$$

**Question**<br>
In the code implementation this seems to be bound to zero in Line 110, but there is no upper bound, should it be normalized to range [0,1]?

## Setup <a name="setup"></a>

In [1]:
using Polynomials
using ForwardDiff
using ExponentialUtilities
using LinearAlgebra
using BenchmarkTools
using StatsBase: sample
include("test_data.jl")

load_weight_test_data (generic function with 1 method)

In [2]:
W = [0.0 0.8 0.0;
     0.8 0.0 0.2; 
     0.0 0.2 0.0]
demo_edge = [CartesianIndex(1, 2)]
W

3×3 Matrix{Float64}:
 0.0  0.8  0.0
 0.8  0.0  0.2
 0.0  0.2  0.0

# Methods <a name="methods"></a>
Here we will work on different approaches as to compute the derivative.
In this demonstration, we will simply work on the matrix exponential differentiation, as the rest of the operation is trivial and computationally cheap.

For:
$$
f(W)=e^{W}
$$
Find
$$
\frac{\delta f}{\delta W_{i,j}}
$$
Or more specifically 
$$
f(W)=\sum_{i,j} e^{W}
$$

But this also simply boils down to the evaluation of the derivative at one (a matrix of ones at the same size as W for the jacobian vector product).

## Original tangent approximation <a name="org-tangent"></a>

Here we are computing the derivative for a single edge. We will use the tangent approximation to compute the derivative. 
This is the approach used in the Akarca paper and available in the repo, as can be seen below this will not produce the correct result.

**Problem 1**<br>
As far as I understand, there is a small but significant error in the implementation provided at https://github.com/DanAkarca/weighted_generative_models/blob/main/weighted_generative_model.m in line 197.
When fitting the first-order polynomial, the independent variable should not be the range of the differences to the edge, which have been used, but should be the actual values, the edge is taking on.

**Problem 2**<br>
The second problem is that we will often need to compute the derivative with respect to a value in the matrix, which is currently zero. (This happens easily as the value for any weight has a lower bound of zero which is often hit).
Now the approach implemented in the code builds a range of x value by using a value for the number of repititions and resolutions for the tangent approximations, but it multiples the resolution with the current x-value, which will produce a vector of zeros whenever the current value is zero, which will always produce a zero slope and zero derivative approximations - which is wrong.

**Problem 2**
In the same implementation the derivative is approximated from the tangent approximation after nudging the edge (i,j) as well as (j,i).
In a sense we are computing the derivative with respect to both (i,j) and (j,i) at once.
First of all, this will create computational overhead, as the derivatives (i.e., the jacobian) will also always be symmetric.
Secondly, this will create a problem when we are trying to compute the gradient, which we currently simply take as the the derivative.
But then we update both elements in the W matrix with the same gradient which is the derivative with respect to nudging both (i,j) and (j,i) - thus our update to the objective function is overshooting.

See line 208 and 208.

In [3]:
function paper_tangent_approx(W, edges::Vector{CartesianIndex{2}}, 
    resolution = 0.01, steps = 5)::Vector{Float64}
    results = zeros(length(edges))
    rep_vec = collect(range(-steps*resolution, steps*resolution, step=resolution))

    for (i_edge, edge_idx) in enumerate(edges)
        
        # Points for evaluation
        edge_val = W[edge_idx]
        reps = [edge_val * (1 + i) for i in rep_vec]

        # For each nudge save difference in communicability 
        sum_comm = zeros(length(reps))
        for (i_rep, rep) in enumerate(reps)
            W_copy = copy(W)
            W_copy[edge_idx] = W_copy[edge_idx[2], edge_idx[1]] = rep
            comm = exp(W_copy)
            sum_comm[i_rep] = sum(comm)
        end

        # Line 197 in MATLAB code
        x = 1:length(reps)
        results[i_edge] = fit(x, sum_comm, 1)[1]
    end

    return results
end

paper_tangent_approx(W, demo_edge, 0.01), paper_tangent_approx(W, demo_edge, 0.2)

([0.03861909629029788], [0.830380354061412])

## Adapted tangent approximation <a name="adapted-tangent"></a>

In [4]:
function tangent_approx(f::Function, W::Matrix{Float64}, edges::Vector{CartesianIndex{2}}, 
    resolution = 0.01, steps = 5)::Vector{Float64}
    results = zeros(length(edges))
    rep_vec = collect(range(-steps*resolution, steps*resolution, step=resolution))

    for (i_edge, edge_idx) in enumerate(edges)     
        # Points for evaluation
        edge_val = W[edge_idx]
        sign_edge = sign(edge_val) == 0 ? 1 : sign(edge_val)
        reps = [edge_val + sign_edge * (max(abs(edge_val), 1e-3) * i) for i in rep_vec]
        

        # For each nudge save difference in communicability 
        sum_comm = zeros(length(reps))
        for (i_rep, rep) in enumerate(reps)
            W_copy = copy(W)
            W_copy[edge_idx] = rep 
            sum_comm[i_rep] = f(W_copy, edge_idx)
        end

        results[i_edge] = fit(reps, sum_comm, 1)[1]
    end

    return results
end

f = (W, _) -> sum(exp(W))
tangent_approx(f, W, demo_edge, 0.01), tangent_approx(f, W, demo_edge, 0.1)

([2.4132596676512077], [2.4141043556312676])

## Finite difference method <a name="finite-diff"></a>
Compute the approximate derivative using the finite difference method.
At the moment, this is a significantly faster but less accurate method than the tangent approximation.
Only implemented to be potentially extended to complex step approximation.

References:<br>
https://en.wikipedia.org/wiki/Finite_difference_method

In [5]:
function finite_diff(f, W, edges::Vector{CartesianIndex{2}}, delta::Float64)::Vector{Float64}
    results = zeros(length(edges))
    for (i_edge, edge_idx) in enumerate(edges) 
        # Evaluate the function at two nearby points
        W_copy = copy(W)
        W_copy[edge_idx] = W_copy[edge_idx] + delta
        f_plus_delta = sum(exp(W_copy))
        
        W_copy[edge_idx]  = W_copy[edge_idx] - 2*delta
        f_minus_delta = sum(exp(W_copy))

        # Calculate the derivative approximation
        results[i_edge] = (f_plus_delta - f_minus_delta) / (2 * delta)
    end

    return results
end

f = W -> sum(exp(W))
finite_diff(f, W, demo_edge, 0.01), finite_diff(f, W, demo_edge,0.2)

([2.413251884607348], [2.4135507160903513])

## Forward Differentiation
Results are derived in the numerator layout, which means that the dimensionality of the input is added to the right

In [6]:
function forward_diff_j(W::Matrix{Float64}, edges::Vector{CartesianIndex{2}})::Vector{Float64}
    # Column indices for retrieval
    indices = collect(CartesianIndices(W))
    index_vec = sort(vec(indices), by = x -> x[1])

    diff_exp(W) = exponential!(copyto!(similar(W), W), ExpMethodGeneric())
    J = ForwardDiff.jacobian(diff_exp, W)

    results = zeros(length(edges))
    tangent = vec(permutedims(exp(W), [2, 1]))
    for (i_edge, edge) in enumerate(edges)
        # we get all partial derivative positions that are non-zero
        Jₓ = J[:, findfirst(x -> x == edge, index_vec)]
        results[i_edge] = sum(Jₓ)
    end

    return results
end

forward_diff_j(W, demo_edge)

1-element Vector{Float64}:
 2.4132511356618336

## Forward Differentiation with Jacobian vector product

In [7]:
function forward_diff_jvp(W::Matrix{Float64}, edges::Vector{CartesianIndex{2}})::Vector{Float64}
    tangent = ones(size(W))
    diff_exp(W) = exponential!(copyto!(similar(W), W), ExpMethodGeneric())
    g(t) = diff_exp(W + t * tangent)
    JVP = ForwardDiff.derivative(g, 0.0)
    return JVP[edges]
end

forward_diff_jvp(W, demo_edge)

1-element Vector{Float64}:
 2.413251135661834

## Fréchet: Block enlarge

In [8]:
function frechet_block_enlarge(W::Matrix{Float64}, edges::Vector{CartesianIndex{2}})::Vector{Float64}
    E = ones(size(W))
    n = size(W, 1)
    M = [W E; zeros(size(W)) W]
    expm_M = exp(M) 
    frechet_AE = expm_M[1:n, n+1:end]
    return frechet_AE[edges]
end

frechet_block_enlarge(W, demo_edge)

1-element Vector{Float64}:
 2.413251135661833

## Fréchet: Scaling-Pade-Squaring

The algorithm:
$$
\begin{array}{l}
\text{for } m = [3, 5, 7, 9] \\
\quad \quad \text{if } ||A||_{1} \leq \theta_{m} \\
\quad \quad \quad \quad r_{m}(A) \quad \text{ \textit{form Pade approximant to A}} \\
\quad \quad \quad \quad U, V     \quad \text{ \textit{Evaluate U and V, see below for }} r_{13} \\
\quad \quad \quad \quad s=0 \\
\quad \quad \quad \quad   \text{break} \\
\quad \quad \text{end} \\
\text{end} \\

\text{if } ||A||_{1} \geq  \theta_{m[-1]} \\
\quad \quad s = \text{ceil}(\log_{2}(||A||_{1}/ \theta_{13})) \\
\quad \quad A = A/2^{s} \\
\quad \quad A_{2} = A^{2}, A_{4} = A_{2}^{2}, A_{6} = A_{2}A_{4}\\ 
\quad \quad U = A[A_{6}(b_{13} A_{6} + b_{11} A_{4} + b_{9} A_{2}) + b_{7} A_{6} + b_{5} A_{4} + b_{3} A_{2} + b_{1} I]\\
\quad \quad V = A_{6}(b_{12}A_{6} + b_{10}A_{4} + b_{8}A_{2}) + b_{6}A_{6} + b_{4}A_{4} + b_{2}A_{2} + b_{0}I\\
\text{end} \\

\text{Solve } (-U+V)r_{m}(A) = U + V \text{ for } r_{m} \\
\text{for } k = 1:s \\
\quad \quad r_{m} = r_{m}*r_{m} \\
\text{end} \\
\text{return } r_{m} \\
\end{array}
$$

References:<br>
Higham (2008), Functions of Matrices - Theory and Computation", Chapter 10, Algorithm 10.20 <br>
https://docs.scipy.org/doc/scipy/reference/generated/scipy.linalg.expm_frechet.html#scipy.linalg.expm_frechet <br>
https://rdrr.io/cran/expm/man/expmFrechet.html

In [9]:
function _diff_pade3(A, E, ident)
    b = (120.0, 60.0, 12.0, 1.0)
    A2 = A * A
    M2 = A * E + E * A
    U = A * (b[4] * A2 + b[2] * ident)
    V = b[3] * A2 + b[1] * ident
    Lu = A * (b[4] * M2) + E * (b[4] * A2 + b[2] * ident)
    Lv = b[3] * M2
    return U, V, Lu, Lv
end
    
function _diff_pade5(A, E, ident)
    b = (30240.0, 15120.0, 3360.0, 420.0, 30.0, 1.0)
    A2 = A * A
    M2 = A * E + E * A
    A4 = A2 * A2
    M4 = A2 * M2 + M2 * A2
    U = A * (b[6] * A4 + b[4] * A2 + b[2] * ident)
    V = b[5] * A4 + b[3] * A2 + b[1] * ident
    Lu = A * (b[6] * M4 + b[4] * M2) + E * (b[6] * A4 + b[4] * A2 + b[2] * ident)
    Lv = b[5] * M4 + b[3] * M2
    return U, V, Lu, Lv
end

function _diff_pade7(A, E, ident)
    b = (17297280.0, 8648640.0, 1995840.0, 277200.0, 25200.0, 1512.0, 56.0, 1.0)
    A2 = A * A
    M2 = A * E + E * A
    A4 = A2 * A2
    M4 = A2 * M2 + M2 * A2
    A6 = A2 * A4
    M6 = A4 * M2 + M4 * A2
    U = A * (b[8] * A6 + b[6] * A4 + b[4] * A2 + b[2] * ident)
    V = b[7] * A6 + b[5] * A4 + b[3] * A2 + b[1] * ident
    Lu = A * (b[8] * M6 + b[6] * M4 + b[4] * M2) + E * (b[8] * A6 + b[6] * A4 + b[4] * A2 + b[2] * ident)
    Lv = b[7] * M6 + b[5] * M4 + b[3] * M2
    return U, V, Lu, Lv
end

function _diff_pade9(A, E, ident)
    b = (17643225600.0, 8821612800.0, 2075673600.0, 302702400.0, 30270240.0, 
         2162160.0, 110880.0, 3960.0, 90.0, 1.0)
    A2 = A * A
    M2 = A * E + E * A
    A4 = A2 * A2
    M4 = A2 * M2 + M2 * A2
    A6 = A2 * A4
    M6 = A4 * M2 + M4 * A2
    A8 = A4 * A4
    M8 = A4 * M4 + M4 * A4
    U = A * (b[10] * A8 + b[8] * A6 + b[6] * A4 + b[4] * A2 + b[2] * ident)
    V = b[9] * A8 + b[7] * A6 + b[5] * A4 + b[3] * A2 + b[1] * ident
    Lu = A * (b[10] * M8 + b[8] * M6 + b[6] * M4 + b[4] * M2) + E * (b[10] * A8 + b[8] * A6 + b[6] * A4 + b[4] * A2 + b[2] * ident)
    Lv = b[9] * M8 + b[7] * M6 + b[5] * M4 + b[3] * M2
    return U, V, Lu, Lv
end

function _diff_pade13(A, E, ident)
    # pade order 13
    A2 = A * A
    M2 = A * E + E * A
    A4 = A2 * A2
    M4 = A2 * M2 + M2 * A2
    A6 = A2 * A4
    M6 = A4 * M2 + M4 * A2
    b = (64764752532480000., 32382376266240000., 7771770303897600.,
            1187353796428800., 129060195264000., 10559470521600.,
            670442572800., 33522128640., 1323241920., 40840800., 960960.,
            16380., 182., 1.)
    W1 = b[14] * A6 + b[12] * A4 + b[10] * A2
    W2 = b[8] * A6 + b[6] * A4 + b[4] * A2 + b[2] * ident
    Z1 = b[13] * A6 + b[11] * A4 + b[9] * A2
    Z2 = b[7] * A6 + b[5] * A4 + b[3] * A2 + b[1] * ident
    W = A6 * W1 + W2
    U = A * W
    V = A6 * Z1 + Z2
    Lw1 = b[14] * M6 + b[12] * M4 + b[10] * M2
    Lw2 = b[8] * M6 + b[6] * M4 + b[4] * M2
    Lz1 = b[13] * M6 + b[11] * M4 + b[9] * M2
    Lz2 = b[7] * M6 + b[5] * M4 + b[3] * M2
    Lw = A6 * Lw1 + M6 * W1 + Lw2
    Lu = A * Lw + E * W
    Lv = A6 * Lz1 + M6 * Z1 + Lz2
    return U, V, Lu, Lv
end

ell_table_61 = (nothing, 2.11e-8, 3.56e-4, 1.08e-2, 6.49e-2, 2.00e-1, 4.37e-1,
        7.83e-1, 1.23e0, 1.78e0, 2.42e0, 3.13e0, 3.90e0, 4.74e0, 5.63e0,
        6.56e0, 7.52e0, 8.53e0, 9.56e0, 1.06e1, 1.17e1);

In [10]:
function frechet_algo(A::Matrix{Float64}, edges::Vector{CartesianIndex{2}})::Vector{Float64}
    E = ones(size(A))
    n = size(A, 1)
    s = nothing
    ident = Matrix{Float64}(I, n, n)
    A_norm_1 = norm(A, 1)
    m_pade_pairs = [(3, _diff_pade3),(5, _diff_pade5),(7, _diff_pade7),(9, _diff_pade9)]
    for (m, pade) in m_pade_pairs
        if A_norm_1 <= ell_table_61[m]
            U, V, Lu, Lv = pade(A, E, ident)
            s = 0
            break
        end
    end
    if s == nothing
        # scaling
        s = max(0, ceil(Int, log2(A_norm_1 / ell_table_61[13])))
        A *= 2.0^-s
        E *= 2.0^-s
        U, V, Lu, Lv = _diff_pade13(A, E, ident)
    end
    
    # factor once and solve twice
    lu_piv = lu(-U + V)
    R = lu_piv \ (U + V)
    L = lu_piv \ (Lu + Lv + ((Lu - Lv) * R))
    
    # repeated squaring
    for k in 1:s
        L = R * L + L * R
        R= R * R
    end
    return L[edges]
end

frechet_algo(W, demo_edge)

1-element Vector{Float64}:
 2.413251135661833

# Benchmarking <a name="benchmarking"></a>

In [11]:
function init_sparse_matrix(n = 100, density = 0.2)
    # Initialize a sparse matrix
    W = zeros(n, n)
    for i in 1:n, j in 1:n
        if rand() < (density / 2)
            W[i, j] = W[j,i] = rand()
        end
    end
    return W
end

W_bench = init_sparse_matrix(50, 0.10)
edges_bench = findall(x -> x != 0, W_bench)
f = (W, _) -> sum(exp(W))

@btime tangent_approx(f, $W_bench, $edges_bench, 0.01) 
@btime finite_diff(f, $W_bench, $edges_bench, 0.1) 
@btime forward_diff_j($W_bench, $edges_bench) 
@btime forward_diff_jvp($W_bench, $edges_bench) 
@btime frechet_block_enlarge($W_bench, $edges_bench) 
@btime frechet_algo($W_bench, $edges_bench);

  490.936 ms (64371 allocations: 362.92 MiB)
  85.713 ms (7423 allocations: 58.01 MiB)
  1.721 s (5884 allocations: 311.96 MiB)
  1.503 ms (20 allocations: 315.22 KiB)
  2.591 ms (24 allocations: 609.55 KiB)
  385.708 μs (211 allocations: 1.78 MiB)


# Extension to full objective function

Currently I can successfully compute the objective function derivative without normalization.
This brings a tremendous speedup, demonstrated below on the dataset provided by Akarca in the demo repo.
Taking the derivative of the normalized version turns out to be more tricky than expected.

In [24]:
# load synthetic data
W_Y, D, A_init = load_weight_test_data()
A_Y = Float64.(W_Y .> 0);
α = 0.01
ω = 0.9
ϵ = 1e-5
m_seed = Int(sum(A_init))
m_all = Int(sum(A_Y))
resolution = 0.01
steps = 5

zero_indices = (findall(==(1), triu(abs.(A_init .- 1), 1)))
edges_to_add = sample(zero_indices, m_all-m_seed; replace = false);

function obj_func_auto(W::Matrix{Float64})
    return sum(exponential!(copyto!(similar(W), W), ExpMethodGeneric()))
end;

function norm_obj_func_auto(W::Matrix{Float64})
    node_strengths = dropdims(sum(W, dims=2), dims=2)
    node_strengths[node_strengths.==0] .= 1e-5
    S = sqrt(inv(Diagonal(node_strengths)))
    return sum(exponential!(copyto!(similar(W), S * W * S), ExpMethodGeneric()))
end;

function obj_func_tangent(W::Matrix{Float64}, edge_idx)
    return (sum(exp(W))  * D[edge_idx])^ω
end;

function norm_obj_func_tangent(W::Matrix{Float64}, edge_idx)
    node_strengths = dropdims(sum(W, dims=2), dims=2)
    node_strengths[node_strengths.==0] .= 1e-5
    S = sqrt(inv(Diagonal(node_strengths)))
    return (sum(exp(S * W * S)) * D[edge_idx])^ω
end;

In [22]:
function test_model(
    model::String,
    obj_func::Union{Function, Nothing},
    A_init::Matrix{Float64}, 
    m_max::Int, 
    verbose::Bool, 
    normalize::Bool)
    A_current = copy(A_init)
    W_current = copy(A_init)

    for m in 1:m_max
        # Get the edge, order of added edges is fixed
        edge_idx = edges_to_add[m-m_seed]
        rev_idx = CartesianIndex(edge_idx[2], edge_idx[1])
        A_current[edge_idx] = W_current[edge_idx] =  1
        A_current[rev_idx] = W_current[rev_idx] =  1    
        edge_indices = findall(!=(0), triu(A_current, 1))
        
        # Compute the derivative
        if model == "tangent"
            derivative = tangent_approx(obj_func, W_current, edge_indices, resolution, steps)
        elseif model == "auto-diff"
            println(sum(W_current))
            println(typeof(W_current))
            auto_d = ForwardDiff.gradient(obj_func, W_current)
            prinln(size(auto_d))
            println(round.(auto_d, digits=6))
            derivative = [ω * D[edge]^ω  * auto_d[i_edge] * current_Y^(ω-1) for (i_edge, edge) in enumerate(edge_indices)]
            throw(ArgumentError("Model not supported"))
        elseif model == "frechet-algo"
            current_Y = sum(exp(W_current))
            frechet_d = frechet_algo(W_current, edge_indices)
            derivative = [ω * D[edge]^ω  * frechet_d[i_edge] * current_Y^(ω-1) for (i_edge, edge) in enumerate(edge_indices)]
        else
            throw(ArgumentError("Model not supported"))
        end

        # Print the derivative for the at the current iteration, all edges added so far
        if verbose
            println(round.(derivative, digits=6))
        end

        # Update W matrix
        for (i_edge, edge) in enumerate(edge_indices)
            W_current[edge] -= (α * derivative[i_edge])
            W_current[edge] = max(0, W_current[edge])
            W_current[CartesianIndex(edge[2], edge[1])] = W_current[edge]
        end
    end
    return W_current
end;

In [25]:
@time W_res_auto = test_model("auto-diff", obj_fu, A_init, 10, true, false);

2.0
Matrix{Float64}


LoadError: MethodError: no method matching (::var"#20#21")(::Matrix{ForwardDiff.Dual{ForwardDiff.Tag{var"#20#21", Float64}, Float64, 12}})

[0mClosest candidates are:
[0m  (::var"#20#21")(::Any, [91m::Any[39m)
[0m[90m   @[39m [36mMain[39m [90m[4mIn[11]:14[24m[39m


In [15]:
@time W_res_tangent = test_model("tangent", obj_func_tangent, A_init, 10, true, false);
@time W_res_frechet = test_model("frechet-algo", nothing, A_init, 10, true, false);
sum(abs.(W_res_tangent .- W_res_frechet))

[101.454088]
[117.231767, 37.322597]
[43.126838, 64.13074, 71.883459]
[43.087831, 43.058883, 34.998296, 133.995615]
[43.126838, 37.322597, 26.444251, 66.617135, 49.338492]
[43.046904, 37.253421, 26.395238, 56.717627, 49.247045, 38.198427]
[43.024912, 37.234389, 26.381753, 33.833792, 49.221885, 129.767734, 22.844578]
[43.031256, 37.239879, 26.385643, 32.027406, 49.229144, 47.745567, 48.704336, 30.02135]
[43.032361, 37.240835, 26.38632, 25.792979, 49.230407, 89.030037, 47.746792, 26.629555, 17.492573]
[43.079351, 37.281501, 26.415133, 24.479908, 84.684085, 80.841006, 36.589357, 47.79893, 19.655044, 13.982765]
  3.177863 seconds (303.10 k allocations: 213.664 MiB, 0.74% gc time, 5.16% compilation time)
[101.453338]
[117.230899, 37.322597]
[43.126838, 64.13074, 71.882927]
[43.08783, 43.059002, 34.998479, 133.99462]
[43.126838, 37.322597, 26.444251, 66.616642, 49.338492]
[43.046903, 37.25342, 26.395237, 56.717872, 49.247044, 38.198245]
[43.024911, 37.234389, 26.381752, 33.833825, 49.221885,

1.5537775412743926e-5

In [16]:
sum(W_res_frechet)

0.4834167279944077