# Difference in tangent approximation results
Ref: https://www.biorxiv.org/content/10.1101/2023.06.23.546237v1?med=mas.

As this is a minimal example, we will simply concern ourselves with taking the derivative of
$$
f(W)=\sum\sum e^{W}
$$

In [7]:
using Polynomials
using ForwardDiff
using ExponentialUtilities
using LinearAlgebra
using BenchmarkTools

include("example_imports.jl");

W = [0.0 0.8 0.0;
     0.8 0.0 0.2; 
     0.0 0.2 0.0]
     
demo_edge = [CartesianIndex(1, 2)];

## Tangent approximation (as I think is implemented)
Results will change for example with resolution, see below

In [8]:
function paper_tangent_approx(W, edges::Vector{CartesianIndex{2}}, 
    resolution = 0.01, steps = 5)::Vector{Float64}
    results = zeros(length(edges))
    rep_vec = collect(range(-steps*resolution, steps*resolution, step=resolution))

    for (i_edge, edge_idx) in enumerate(edges)
        
        # Points for evaluation
        edge_val = W[edge_idx]
        reps = [edge_val * (1 + i) for i in rep_vec]

        # For each nudge save difference in communicability 
        sum_comm = zeros(length(reps))
        for (i_rep, rep) in enumerate(reps)
            W_copy = copy(W)
            W_copy[edge_idx] = W_copy[edge_idx[2], edge_idx[1]] = rep
            comm = exp(W_copy)
            sum_comm[i_rep] = sum(comm)
        end

        # Line 197 in MATLAB code
        x = 1:length(reps)
        results[i_edge] = fit(x, sum_comm, 1)[1]
    end

    return results
end

paper_tangent_approx(W, demo_edge, 0.01), paper_tangent_approx(W, demo_edge, 0.1)

([0.03861909629029788], [0.3932448105709426])

## Adapted tangent approximation <a name="adapted-tangent"></a>

In [9]:
function tangent_approx(
    f::Function, 
    W::Matrix{Float64},
    edges::Vector{CartesianIndex{2}}, 
    resolution = 0.01, steps = 5)::Vector{Float64}
    """
    Compute the derivative of any function f with respec to the vector of edges provided
    Will use n = steps number of nudges of current w_{edge idx} by (1+resolution)
    """
    results = zeros(length(edges))
    rep_vec = collect(range(-steps*resolution, steps*resolution, step=resolution))

    for (i_edge, edge_idx) in enumerate(edges)     
        # Points for evaluation
        edge_val = W[edge_idx]
        sign_edge = sign(edge_val) == 0 ? 1 : sign(edge_val)
        x = [edge_val + sign_edge * (max(abs(edge_val), 1e-3) * i) for i in rep_vec]
        
        # For each nudge save difference in communicability 
        y = zeros(length(x))
        for (i_rep, rep) in enumerate(x)
            W_copy = copy(W)
            W_copy[edge_idx] = rep 
            y[i_rep] = f(W_copy, edge_idx)
        end

        results[i_edge] = fit(x, y, 1)[1]
    end

    return results
end;

f = (W, _) -> sum(exp(W))
tangent_approx(f, W, demo_edge, 0.01), tangent_approx(f, W, demo_edge, 0.1)

([2.4132596676512077], [2.4141043556312676])

## Other ways to compute the derivative fast (not directly relevant to our discussion)

### Finite difference method <a name="finite-diff"></a>
Compute the approximate derivative using the finite difference method.
Only implemented to be potentially extended to complex step approximation.

References:<br>
https://en.wikipedia.org/wiki/Finite_difference_method

In [10]:
function finite_diff(f, W, edges::Vector{CartesianIndex{2}}, delta::Float64)::Vector{Float64}
    results = zeros(length(edges))
    for (i_edge, edge_idx) in enumerate(edges) 
        # Evaluate the function at two nearby points
        W_copy = copy(W)
        W_copy[edge_idx] = W_copy[edge_idx] + delta
        f_plus_delta = sum(exp(W_copy))
        
        W_copy[edge_idx]  = W_copy[edge_idx] - 2*delta
        f_minus_delta = sum(exp(W_copy))

        # Calculate the derivative approximation
        results[i_edge] = (f_plus_delta - f_minus_delta) / (2 * delta)
    end

    return results
end

f = W -> sum(exp(W))
finite_diff(f, W, demo_edge, 0.01), finite_diff(f, W, demo_edge,0.1)

([2.413251884607348], [2.4133260303517945])

### Forward Differentiation
Results are derived in the numerator layout, which means that the dimensionality of the input is added to the right

In [11]:
function forward_diff_j(W::Matrix{Float64}, edges::Vector{CartesianIndex{2}})::Vector{Float64}
    # Column indices for retrieval
    indices = collect(CartesianIndices(W))
    index_vec = sort(vec(indices), by = x -> x[1])

    diff_exp(W) = exponential!(copyto!(similar(W), W), ExpMethodGeneric())
    J = ForwardDiff.jacobian(diff_exp, W)

    results = zeros(length(edges))
    tangent = vec(permutedims(exp(W), [2, 1]))
    for (i_edge, edge) in enumerate(edges)
        # we get all partial derivative positions that are non-zero
        Jₓ = J[:, findfirst(x -> x == edge, index_vec)]
        results[i_edge] = sum(Jₓ)
    end

    return results
end

forward_diff_j(W, demo_edge)

1-element Vector{Float64}:
 2.4132511356618336

### Forward Differentiation with Jacobian vector product

In [12]:
function forward_diff_jvp(W::Matrix{Float64}, edges::Vector{CartesianIndex{2}})::Vector{Float64}
    # note, that this is only possible because we have symmetry
    tangent = ones(size(W))
    diff_exp(W) = exponential!(copyto!(similar(W), W), ExpMethodGeneric())
    g(t) = diff_exp(W + t * tangent)
    JVP = ForwardDiff.derivative(g, 0.0)
    return JVP[edges]
end

forward_diff_jvp(W, demo_edge)

1-element Vector{Float64}:
 2.413251135661834

### Fréchet: Block enlarge

In [13]:
function frechet_block_enlarge(W::Matrix{Float64}, edges::Vector{CartesianIndex{2}})::Vector{Float64}
    E = ones(size(W))
    n = size(W, 1)
    M = [W E; zeros(size(W)) W]
    expm_M = exp(M) 
    frechet_AE = expm_M[1:n, n+1:end]
    return frechet_AE[edges]
end

frechet_block_enlarge(W, demo_edge)

1-element Vector{Float64}:
 2.413251135661833

### Fréchet: Scaling-Pade-Squaring

The algorithm:
$$
\begin{array}{l}
\text{for } m = [3, 5, 7, 9] \\
\quad \quad \text{if } ||A||_{1} \leq \theta_{m} \\
\quad \quad \quad \quad r_{m}(A) \quad \text{ \textit{form Pade approximant to A}} \\
\quad \quad \quad \quad U, V     \quad \text{ \textit{Evaluate U and V, see below for }} r_{13} \\
\quad \quad \quad \quad s=0 \\
\quad \quad \quad \quad   \text{break} \\
\quad \quad \text{end} \\
\text{end} \\

\text{if } ||A||_{1} \geq  \theta_{m[-1]} \\
\quad \quad s = \text{ceil}(\log_{2}(||A||_{1}/ \theta_{13})) \\
\quad \quad A = A/2^{s} \\
\quad \quad A_{2} = A^{2}, A_{4} = A_{2}^{2}, A_{6} = A_{2}A_{4}\\ 
\quad \quad U = A[A_{6}(b_{13} A_{6} + b_{11} A_{4} + b_{9} A_{2}) + b_{7} A_{6} + b_{5} A_{4} + b_{3} A_{2} + b_{1} I]\\
\quad \quad V = A_{6}(b_{12}A_{6} + b_{10}A_{4} + b_{8}A_{2}) + b_{6}A_{6} + b_{4}A_{4} + b_{2}A_{2} + b_{0}I\\
\text{end} \\

\text{Solve } (-U+V)r_{m}(A) = U + V \text{ for } r_{m} \\
\text{for } k = 1:s \\
\quad \quad r_{m} = r_{m}*r_{m} \\
\text{end} \\
\text{return } r_{m} \\
\end{array}
$$

References:<br>
Higham (2008), Functions of Matrices - Theory and Computation", Chapter 10, Algorithm 10.20 <br>
https://docs.scipy.org/doc/scipy/reference/generated/scipy.linalg.expm_frechet.html#scipy.linalg.expm_frechet <br>
https://rdrr.io/cran/expm/man/expmFrechet.html

In [14]:
function frechet_algo(A::Matrix{Float64}, edges::Vector{CartesianIndex{2}})::Vector{Float64}
    E = ones(size(A))
    n = size(A, 1)
    s = nothing
    ident = Matrix{Float64}(I, n, n)
    A_norm_1 = norm(A, 1)
    m_pade_pairs = [(3, _diff_pade3),(5, _diff_pade5),(7, _diff_pade7),(9, _diff_pade9)]
    for (m, pade) in m_pade_pairs
        if A_norm_1 <= ell_table_61[m]
            U, V, Lu, Lv = pade(A, E, ident)
            s = 0
            break
        end
    end
    if s == nothing
        # scaling
        s = max(0, ceil(Int, log2(A_norm_1 / ell_table_61[13])))
        A *= 2.0^-s
        E *= 2.0^-s
        U, V, Lu, Lv = _diff_pade13(A, E, ident)
    end
    
    # factor once and solve twice
    lu_piv = lu(-U + V)
    R = lu_piv \ (U + V)
    L = lu_piv \ (Lu + Lv + ((Lu - Lv) * R))
    
    # repeated squaring
    for k in 1:s
        L = R * L + L * R
        R= R * R
    end
    return L[edges]
end

frechet_algo(W, demo_edge)

1-element Vector{Float64}:
 2.413251135661833

### Benchmarking <a name="benchmarking"></a>

In [15]:
function init_sparse_matrix(n = 100, density = 0.2)
    # Initialize a sparse matrix
    W = zeros(n, n)
    for i in 1:n, j in 1:n
        if rand() < (density / 2)
            W[i, j] = W[j,i] = rand()
        end
    end
    return W
end

W_bench = init_sparse_matrix(50, 0.10)
edges_bench = findall(x -> x != 0, W_bench)
f = (W, _) -> sum(exp(W))

@btime tangent_approx(f, $W_bench, $edges_bench, 0.01) 
@btime finite_diff(f, $W_bench, $edges_bench, 0.1) 
@btime forward_diff_j($W_bench, $edges_bench) 
@btime forward_diff_jvp($W_bench, $edges_bench) 
@btime frechet_block_enlarge($W_bench, $edges_bench) 
@btime frechet_algo($W_bench, $edges_bench);

  466.782 ms (63587 allocations: 358.38 MiB)
  80.744 ms (7337 allocations: 57.29 MiB)
  1.723 s (5884 allocations: 311.91 MiB)
  1.503 ms (20 allocations: 315.22 KiB)
  1.969 ms (24 allocations: 609.55 KiB)
  396.000 μs (211 allocations: 1.78 MiB)
