# Differentiation of objective function for weighted GNMs
From: https://www.biorxiv.org/content/10.1101/2023.06.23.546237v1?med=mas.

This is mostly concerned with finding the fastest way to compute the derivative of the matrix exponential.
$$
f(w_{i,j}) = (c_{i,j} \cdot d_{i,j})^{\omega}
$$
where the communicability matrix is defined as follows. It captures the proportion of signals propagating randomly from node i that would reach node j over an infinite time horizon.
$$
c_{i,j} = \left(e^{S^{-1/2} W S^{-1/2}}\right)_{i,j}
$$
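
As a concrete illustration of the definition above, here is a minimal sketch of the normalised communicability for a small symmetric weight matrix. It assumes that the normalising matrix in the exponent is the diagonal matrix of node strengths (row sums of the weights), which this note does not state explicitly. The names `communicability`, `W_demo` and `C_demo` are mine and do not come from the paper.

```julia
using LinearAlgebra

# Assumption: the normalising matrix is diagonal and holds the node strengths
# (row sums of W); every node is assumed to have at least one edge.
function communicability(W::AbstractMatrix)
    strengths = vec(sum(W, dims = 2))
    S_inv_sqrt = Diagonal(1 ./ sqrt.(strengths))
    # Dense matrix exponential of the strength-normalised weights
    return exp(Matrix(S_inv_sqrt * W * S_inv_sqrt))
end

W_demo = [0.0 0.8 0.0;
          0.8 0.0 0.2;
          0.0 0.2 0.0]
C_demo = communicability(W_demo)
```

Note that the cells below work with the unnormalised `exp(W)`, so this sketch only illustrates the formula and is not used further.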

## Setup

In [1]:
using Polynomials
using ForwardDiff
using ExponentialAction
using ExponentialUtilities
using LinearAlgebra

In [3]:
W = [0.0 0.8 0.0;
     0.8 0.0 0.2; 
     0.0 0.2 0.0]
edge_idx = CartesianIndex(1, 2)
edge_val = W[edge_idx]
W

3×3 Matrix{Float64}:
 0.0  0.8  0.0
 0.8  0.0  0.2
 0.0  0.2  0.0

## Tangent approximation
Here we compute the derivative for a single edge using the tangent approximation, which is the approach used in the Akarca paper.
As far as I understand, there is a small but significant error in line 197 of the implementation provided at https://github.com/DanAkarca/weighted_generative_models/blob/main/weighted_generative_model.m.
When fitting the first-order polynomial, the independent variable should be the actual values the edge takes on, not the index range over the nudges that is used there.

In [13]:
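function paper_tangent_approx(W, edge_idx, resolution = 0.01, lo = -0.25, hi = 0.25)
    # Read the edge weight from W rather than from the global edge_val
    edge_val = W[edge_idx]

    # Relative nudges and the resulting edge values
    rep_vec = collect(lo:resolution:hi)
    reps = [edge_val * (1 + i) for i in rep_vec]

    # For each nudged value, store the total communicability
    sum_comm = zeros(length(rep_vec))
    for (i_rep, rep) in enumerate(reps)
        W_copy = copy(W)
        # Set both symmetric entries so the matrix stays symmetric
        W_copy[edge_idx] = W_copy[edge_idx[2], edge_idx[1]] = rep
        comm = exp(W_copy)
        sum_comm[i_rep] = sum(comm)
    end

    # Line 197 in MATLAB code: the fit uses the index, not the edge value
    x = 1:length(reps)

    # Slope (first-order coefficient) of the fitted line
    return fit(x, sum_comm, 1)[1]
end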

paper_tangent_approx(W, edge_idx, 0.01), paper_tangent_approx(W, edge_idx, 0.2)

(0.038767237882607476, 0.7456781939183535)

## Correct tangent approximation

In [14]:
function tangent_approx(W, edge_idx, resolution = 0.01, lo = -0.25, hi = 0.25)
    # Read the edge weight from W rather than from the global edge_val
    edge_val = W[edge_idx]

    # Relative nudges and the resulting edge values
    rep_vec = collect(lo:resolution:hi)
    reps = [edge_val * (1 + i) for i in rep_vec]

    # For each nudged value, store the total communicability
    sum_comm = zeros(length(rep_vec))
    for (i_rep, rep) in enumerate(reps)
        W_copy = copy(W)
        # Set both symmetric entries so the matrix stays symmetric
        W_copy[edge_idx] = W_copy[edge_idx[2], edge_idx[1]] = rep
        comm = exp(W_copy)
        sum_comm[i_rep] = sum(comm)
    end

    # Fit against the actual edge values and return the slope
    return fit(reps, sum_comm, 1)[1]
end

tangent_approx(W, edge_idx, 0.01), tangent_approx(W, edge_idx, 0.2)

(4.8459047353259335, 4.660488711989713)
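
The gap between the two results can be traced to the choice of independent variable. The nudged values are evenly spaced with step `edge_val * resolution`, so a straight-line fit against the index `1:n` returns the slope against the edge value multiplied by that step. The sketch below checks this with a plain least-squares fit from `LinearAlgebra` instead of `Polynomials.fit`. The helper `ls_slope` and the `_demo` names are hypothetical.

```julia
using LinearAlgebra

# Least-squares slope of y against x (first-order polynomial fit)
function ls_slope(x, y)
    X = hcat(ones(length(x)), collect(float.(x)))
    return (X \ y)[2]
end

W_demo = [0.0 0.8 0.0; 0.8 0.0 0.2; 0.0 0.2 0.0]
edge_val_demo = W_demo[1, 2]
resolution = 0.01

# Same nudges as in the cells above: scale the edge by 0.75 ... 1.25
reps = [edge_val_demo * (1 + r) for r in -0.25:resolution:0.25]
sum_comm = map(reps) do rep
    W_copy = copy(W_demo)
    W_copy[1, 2] = W_copy[2, 1] = rep
    sum(exp(W_copy))
end

slope_vs_index = ls_slope(1:length(reps), sum_comm)
slope_vs_value = ls_slope(reps, sum_comm)

# The two slopes should differ by the spacing of the nudged values
scale = edge_val_demo * resolution
```

With the outputs above, 4.8459 × 0.8 × 0.01 ≈ 0.0388 and 4.6605 × 0.8 × 0.2 ≈ 0.7457, which reproduces the two values returned by `paper_tangent_approx`.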

## Finite difference method
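Approximate the derivative with a central finite difference: evaluate the summed matrix function with the edge weight shifted by plus and minus a small delta, then divide the difference by twice the delta.
https://en.wikipedia.org/wiki/Finite_difference_method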

In [21]:
function finite_diff(f, W, delta, edge_idx)
    # Evaluate the function at two nearby points.
    # Only W[edge_idx] is perturbed, not its symmetric counterpart.
    W_plus = copy(W)
    W_plus[edge_idx] += delta

    W_minus = copy(W)
    W_minus[edge_idx] -= delta

    f_plus_delta = sum(f(W_plus))
    f_minus_delta = sum(f(W_minus))

    # Central difference approximation of the derivative
    derivative = (f_plus_delta - f_minus_delta) / (2 * delta)

    return derivative
end

# The factor 2 accounts for the symmetric entry of the undirected edge
finite_diff(W -> exp(W), W, 0.01, edge_idx)*2, finite_diff(W -> exp(W), W, 0.2, edge_idx)*2

(4.826503769214696, 4.827101432180703)
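
The factor of 2 in the call above compensates for `finite_diff` perturbing only one of the two symmetric entries. As a hedged alternative, the sketch below perturbs `W[i, j]` and `W[j, i]` together, so that no correction factor is needed. `finite_diff_sym` is a hypothetical name that is not part of the notebook, and the function assumes an off-diagonal edge.

```julia
using LinearAlgebra

# Central difference of sum(f(W)) with respect to an undirected edge:
# both W[i, j] and W[j, i] are nudged by the same amount (assumes i != j).
function finite_diff_sym(f, W, delta, edge_idx)
    i, j = Tuple(edge_idx)
    W_plus = copy(W)
    W_plus[i, j] += delta
    W_plus[j, i] += delta
    W_minus = copy(W)
    W_minus[i, j] -= delta
    W_minus[j, i] -= delta
    return (sum(f(W_plus)) - sum(f(W_minus))) / (2 * delta)
end

W_demo = [0.0 0.8 0.0; 0.8 0.0 0.2; 0.0 0.2 0.0]
d_sym = finite_diff_sym(exp, W_demo, 0.01, CartesianIndex(1, 2))
```

For a symmetric `W` this should agree closely with the doubled one-entry result shown above.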

## Forward Differentiation
Results are given in numerator layout: the output dimensions come first and the dimensions of the input are appended on the right.
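
To make the layout concrete for a matrix-valued function of a matrix: `ForwardDiff.jacobian` flattens both the output and the input in column-major order, so for an n×n input the result is an n²×n² matrix with one row per output entry and one column per input entry. In index form (my notation, not the notebook's):

$$
J_{\,i + (j-1)n,\; k + (l-1)n} = \frac{\partial \exp(A)_{i,j}}{\partial A_{k,l}}, \qquad i, j, k, l \in \{1, \dots, n\}
$$

For n = 3 this gives the 9×9 matrix shown below, and the column belonging to entry (1, 2) is column 1 + (2 - 1) · 3 = 4.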

In [368]:
# Assumption: A is the weight matrix from the setup (it is not defined in the cells shown)
A = W

# construct n x n matrices of edge combinations
indices = collect(CartesianIndices(A))
index_vec = vec(indices)
index_vec_c = sort(index_vec, by = x -> x[1])
index_vec_r = sort(index_vec, by = x -> x[2])

# all_indices[i][j] holds the pair of edges (index_vec_r[i], index_vec_c[j])
all_indices = [[CartesianIndex{2}[] for _ in 1:length(index_vec)] for _ in 1:length(index_vec)]
for i in 1:size(A, 1)^2
    for j in 1:size(A, 1)^2
        push!(all_indices[i][j], index_vec_r[i], index_vec_c[j])
    end
end

# Flag every pair of edges that involves a self-loop, repeats the same edge,
# or pairs an edge with its reverse; these need to be removed later
cycle_indices = CartesianIndex[]
for i in 1:size(A, 1)^2
    for j in 1:size(A, 1)^2
        if ((all_indices[i][j][1][1] == all_indices[i][j][1][2])||
            (all_indices[i][j][2][1] == all_indices[i][j][2][2])||
            (all_indices[i][j][1] == all_indices[i][j][2]) ||
            ((all_indices[i][j][1][1] == all_indices[i][j][2][2]) && 
             (all_indices[i][j][1][2] == all_indices[i][j][2][1]))
            )
            push!(cycle_indices, CartesianIndex(i, j))
        end
    end
end

In [390]:
myexp(A) = exponential!(copyto!(similar(A), A), ExpMethodGeneric());
d = ForwardDiff.jacobian(myexp, A)

# we remove all partial derivatives that include self-loops (currently disabled)
# for idx in cycle_indices 
#     d[idx] = 0
# end

d

9×9 Matrix{Float64}:
 1.22423     0.445505    0.0284695   …  0.0284695   0.00557956  0.00022034
 0.445505    1.23135     0.111376       0.00557956  0.0285246   0.00139489
 0.0284695   0.111376    1.11747        0.00022034  0.00139489  0.0276432
 0.445505    0.114098    0.00557956     0.111376    0.0285246   0.00139489
 0.114098    0.4469      0.0285246      0.0285246   0.111725    0.00713115
 0.00557956  0.0285246   0.424582    …  0.00139489  0.00713115  0.106145
 0.0284695   0.00557956  0.00022034     1.11747     0.424582    0.0276432
 0.00557956  0.0285246   0.00139489     0.424582    1.12438     0.106145
 0.00022034  0.00139489  0.0276432      0.0276432   0.106145    1.01381

In [391]:
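# Select the Jacobian column for the edge and contract it with the transposed exponential.
# Jacobian columns follow column-major order while index_vec_c is row-major, so this
# picks the column of the transposed entry, which gives the same result for symmetric A.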
dx = d[:, findfirst(x -> x == edge_idx, index_vec_c)]
dot(dx, vec(permutedims(exp(A), [2, 1])))

2.4306145412981492
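
Written out, the dot product above contracts one Jacobian column with the transposed exponential. Using my own notation $L(A, E)$ for the directional derivative of the matrix exponential at $A$ in direction $E$, and $E_{kl}$ for the matrix that is 1 at entry $(k, l)$ and 0 elsewhere, the quantity is:

$$
\sum_{i,j} \frac{\partial \exp(A)_{i,j}}{\partial A_{k,l}} \, \exp(A)_{j,i}
= \operatorname{tr}\big(L(A, E_{kl}) \, \exp(A)\big)
= \frac{1}{2} \frac{\partial}{\partial A_{k,l}} \operatorname{tr}\big(\exp(A)^2\big)
$$

So, as far as I can tell, this value is the sensitivity of $\tfrac{1}{2}\operatorname{tr}(\exp(A)^2)$ to a single entry rather than the sensitivity of `sum(exp(A))`, which would explain why it does not equal half of the finite difference result.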

## Forward Differentiation with Jacobian vector product

In [375]:
function jvp(func, primal, tangent)
    # Directional derivative of func at primal along tangent
    g(t) = func(primal + t * tangent)
    jvp_result = ForwardDiff.derivative(g, 0.0)
    return jvp_result
end

jvp(myexp, A, exp(A))

3×3 Matrix{Float64}:
 2.59776   2.43061   0.399439
 2.43061   2.69762   0.607654
 0.399439  0.607654  1.09986
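
As a cross-check that needs no automatic differentiation, the directional derivative of the matrix exponential can also be read off the upper-right block of the exponential of a block-triangular matrix. This is a standard identity and is not taken from the notebook. The sketch below uses it with two tangents: `exp(A)`, as in the call above, and a matrix of ones, which I assume is the tangent that corresponds to differentiating `sum(exp(A))` for a symmetric `A`. The names `frechet_exp`, `A_demo`, `L_exp`, `L_ones` and `edge_sensitivity` are hypothetical.

```julia
using LinearAlgebra

# Directional (Frechet) derivative of the matrix exponential at A in
# direction E, read off the upper-right block of exp([A E; 0 A]).
function frechet_exp(A::AbstractMatrix, E::AbstractMatrix)
    n = size(A, 1)
    M = [A E; zero(A) A]
    return exp(M)[1:n, n+1:2n]
end

A_demo = [0.0 0.8 0.0; 0.8 0.0 0.2; 0.0 0.2 0.0]

# The tangent exp(A) commutes with A, so the derivative collapses to exp(2A)
L_exp = frechet_exp(A_demo, exp(A_demo))

# Tangent of all ones: for symmetric A, entry (i, j) is d sum(exp(A)) / d A[i, j]
L_ones = frechet_exp(A_demo, ones(3, 3))

# Both symmetric entries of the undirected edge move together, hence the factor 2
edge_sensitivity = 2 * L_ones[1, 2]
```

In this sketch `L_exp[1, 2]` should reproduce the 2.4306 obtained above, while `edge_sensitivity` should agree closely with the finite difference value of about 4.8265.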