# Differentiation of the objective function for weighted GNMs
From: https://www.biorxiv.org/content/10.1101/2023.06.23.546237v1?med=mas.

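This notebook is mainly concerned with finding the fastest way to compute the derivative of the matrix exponential. The objective function for an edge $(i,j)$ is
$$
f(w_{i,j}) = (c_{i,j} \cdot d_{i,j})^{\omega}
$$
where $d_{i,j}$ is the distance between nodes $i$ and $j$, and $c_{i,j}$ is an entry of the communicability matrix $C$. Communicability captures the proportion of signals propagating randomly from node $i$ that would reach node $j$ over an infinite time horizon. It is defined as
$$
C = \exp\left(S^{-1/2} W S^{-1/2}\right), \qquad S = \operatorname{diag}(s_1, \dots, s_n), \quad s_i = \sum_j w_{i,j}
$$
The code cells below work with the simpler scalar $\sum_{k,l} [\exp(W)]_{k,l}$, i.e. without the strength normalization and with $d_{k,l} = 1$ and $\omega = 1$.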
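A minimal Julia sketch of the two formulas above, assuming a symmetric weight matrix `W` and a distance matrix `D` of the same size (`D` is a hypothetical input here). Giving zero-strength nodes a scaling factor of 0, instead of dividing by zero, is our own assumption and not taken from the paper; `communicability` and `edge_costs` are illustrative names.

```julia
using LinearAlgebra

# Communicability C = exp(S^{-1/2} W S^{-1/2}), with S the diagonal matrix of node strengths.
# Assumption: nodes with zero strength get a scaling factor of 0 instead of Inf.
function communicability(W::AbstractMatrix{<:Real})
    s = vec(sum(W; dims = 2))                      # node strengths s_i = Σ_j w_ij
    d = [si > 0 ? 1 / sqrt(si) : 0.0 for si in s]  # diagonal of S^{-1/2}
    S_inv_sqrt = Diagonal(d)
    return exp(S_inv_sqrt * W * S_inv_sqrt)
end

# Per-pair values (c_ij * d_ij)^ω; D is a hypothetical distance matrix of the same size as W
edge_costs(W, D, ω) = (communicability(W) .* D) .^ ω
```

For example, `edge_costs(W, D, 1.0)` returns the matrix of $c_{i,j} d_{i,j}$ values with $\omega = 1$.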

## Setup

In [1]:
using Polynomials
using ForwardDiff
using ExponentialAction
using ExponentialUtilities
using LinearAlgebra
using BenchmarkTools

In [2]:
W = [0.0 0.8 0.0;
     0.8 0.0 0.2; 
     0.0 0.2 0.0]
demo_edge = [CartesianIndex(1, 2)]
W

3×3 Matrix{Float64}:
 0.0  0.8  0.0
 0.8  0.0  0.2
 0.0  0.2  0.0

## Tangent approximation
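Here we compute the derivative with respect to a single edge using a tangent approximation: the edge weight is scaled over a range of values around its current value, the objective is evaluated at each, and the slope of a first-order polynomial fit is taken as the derivative.
This is the approach used in the Akarca paper.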

### Problem 1
As far as I understand, there is a small but significant error in line 197 of the implementation at https://github.com/DanAkarca/weighted_generative_models/blob/main/weighted_generative_model.m.
When fitting the first-order polynomial, the independent variable should be the actual values the edge takes on, not the range of differences to the edge that is used there.
Otherwise the fitted slope is rescaled by the spacing between consecutive nudges, so the estimate changes with the resolution, as the two results below show.
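A minimal sketch of the effect, assuming the scalar $\sum_{k,l}[\exp(W)]_{k,l}$ and the demo matrix from the setup. A closed-form least-squares slope stands in for Polynomials' `fit`, and `ols_slope`, `sum_comm` and `spacing` are illustrative names. Fitting against the nudge index instead of the edge value multiplies the slope by the spacing between consecutive edge values.

```julia
using LinearAlgebra

# Closed-form least-squares slope of y against x (same fit as a first-order polynomial)
ols_slope(x, y) = sum((x .- sum(x) / length(x)) .* (y .- sum(y) / length(y))) /
                  sum((x .- sum(x) / length(x)) .^ 2)

W = [0.0 0.8 0.0; 0.8 0.0 0.2; 0.0 0.2 0.0]
resolution = 0.01
edge_val = W[1, 2]
reps = [edge_val * (1 + i) for i in -0.25:resolution:0.25]  # edge values, as in the cells below

# Total communicability sum(exp(W)) with the symmetric edge (1,2)/(2,1) set to w
function sum_comm(w)
    Wc = copy(W)
    Wc[1, 2] = Wc[2, 1] = w
    return sum(exp(Wc))
end
y = sum_comm.(reps)

slope_by_value = ols_slope(reps, y)            # slope with respect to the edge value
slope_by_index = ols_slope(1:length(reps), y)  # slope with respect to the nudge index
spacing = edge_val * resolution                # distance between consecutive edge values
```

Because least-squares slopes scale with any affine change of the independent variable, `slope_by_index` equals `slope_by_value * spacing` up to rounding.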

### Problem 2
In the same implementation, the derivative is approximated with the tangent approximation after nudging both (i,j) and (j,i).
In effect, we are computing the derivative with respect to both (i,j) and (j,i) at once.
First, this creates computational overhead, as the derivatives (i.e., the Jacobian) are always symmetric anyway.
Second, it causes a problem when computing the gradient, which we currently simply take to be this derivative.
For an objective that is unchanged when W is transposed, such as the sum(exp(W)) used in the code below, the two partial derivatives are equal at a symmetric W, so the derivative for the joint nudge is twice the partial derivative with respect to (i,j) alone.
Updating both elements of W with this value therefore applies twice the single-entry gradient to each, and the update to the objective function overshoots.

See line 208 and 208.
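A small numerical check of the factor of two, assuming the scalar $\sum_{k,l}[\exp(W)]_{k,l}$ and the demo matrix from the setup; `central_diff`, `E12` and `E21` are illustrative names. The joint nudge `d_sym` equals `d12 + d21`, and since `d12 ≈ d21` it is twice either single-entry partial.

```julia
using LinearAlgebra

# Central difference of f at W along the direction E (a matrix of the same size as W)
central_diff(f, W, E; δ = 1e-5) = (f(W + δ * E) - f(W - δ * E)) / (2δ)

f(W) = sum(exp(W))  # the scalar used by the code cells below
W = [0.0 0.8 0.0; 0.8 0.0 0.2; 0.0 0.2 0.0]

E12 = zeros(3, 3); E12[1, 2] = 1.0  # nudge only entry (1,2)
E21 = zeros(3, 3); E21[2, 1] = 1.0  # nudge only entry (2,1)

d12 = central_diff(f, W, E12)
d21 = central_diff(f, W, E21)
d_sym = central_diff(f, W, E12 + E21)  # joint nudge of (1,2) and (2,1)
```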

In [3]:
function paper_tangent_approx(W, edges::Vector{CartesianIndex{2}}, resolution = 0.01, min = -0.25, max = 0.25)
    results = zeros(length(edges))
    rep_vec = collect(min:resolution:max)

    for (i_edge, edge_idx) in enumerate(edges)
        
        # Points for evaluation: the edge value scaled by (1 + i) for each relative nudge i
        edge_val = W[edge_idx]
        reps = [edge_val * (1 + i) for i in rep_vec]

        # For each nudge, save the total communicability sum(exp(W))
        sum_comm = zeros(length(reps))
        for (i_rep, rep) in enumerate(reps)
            W_copy = copy(W)
            W_copy[edge_idx] = W_copy[edge_idx[2], edge_idx[1]] = rep
            comm = exp(W_copy)
            sum_comm[i_rep] = sum(comm)
        end

        # Line 197 in MATLAB code: the slope is fitted against the nudge index, not the edge value
        x = 1:length(reps)
        results[i_edge] = fit(x, sum_comm, 1)[1]  # coefficient of x^1, i.e. the slope
    end

    return results
end

paper_tangent_approx(W, demo_edge, 0.01), paper_tangent_approx(W, demo_edge, 0.2)

([0.038767237882607476], [0.7456781939183535])

## Correct tangent approximation

In [4]:
function tangent_approx(W, edges::Vector{CartesianIndex{2}}, 
    resolution = 0.01, min = -0.25, max = 0.25)::Vector{Float64}
    results = zeros(length(edges))
    rep_vec = collect(min:resolution:max)

    for (i_edge, edge_idx) in enumerate(edges)     
        # Points for evaluation
        edge_val = W[edge_idx]
        reps = [edge_val * (1 + i) for i in rep_vec]

        # For each nudge, save the total communicability sum(exp(W))
        sum_comm = zeros(length(reps))
        for (i_rep, rep) in enumerate(reps)
            W_copy = copy(W)
            W_copy[edge_idx] = W_copy[edge_idx[2], edge_idx[1]] = rep
            comm = exp(W_copy)
            sum_comm[i_rep] = sum(comm)
        end

        # Fit against the actual edge values; the coefficient of x^1 is the slope
        results[i_edge] = fit(reps, sum_comm, 1)[1]
    end

    return results
end

tangent_approx(W, demo_edge, 0.01), tangent_approx(W, demo_edge, 0.1)

([4.8459047353259335], [4.851647952441173])
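As a worked check of Problem 1: consecutive edge values in the first `paper_tangent_approx` call are $0.8 \times 0.01 = 0.008$ apart, and dividing its result by that spacing recovers the corrected estimate at the same resolution:
$$
\frac{0.038767237882607476}{0.8 \times 0.01} \approx 4.8459047, \qquad \texttt{tangent\_approx}(W, \text{demo\_edge}, 0.01) = 4.8459047353259335
$$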

## Finite difference method
Compute the approximate derivative using a central finite difference.
https://en.wikipedia.org/wiki/Finite_difference_method
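The central difference used below nudges a single entry $(i,j)$ by $\pm\delta$, with $E_{i,j}$ the matrix that is 1 at $(i,j)$ and 0 elsewhere; its truncation error is $O(\delta^2)$:
$$
\frac{\partial f}{\partial w_{i,j}} \approx \frac{f(W + \delta E_{i,j}) - f(W - \delta E_{i,j})}{2\delta}, \qquad f(W) = \sum_{k,l} \big[\exp(W)\big]_{k,l}
$$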

In [5]:
function finite_diff(W, edges::Vector{CartesianIndex{2}}, delta::Float64)::Vector{Float64}
    results = zeros(length(edges))
    for (i_edge, edge_idx) in enumerate(edges) 
        # Evaluate the function at two nearby points, nudging only entry (i,j)
        W_copy = copy(W)
        W_copy[edge_idx] = W_copy[edge_idx] + delta
        f_plus_delta = sum(exp(W_copy))

        W_copy[edge_idx] = W_copy[edge_idx] - 2*delta
        f_minus_delta = sum(exp(W_copy))

        # Calculate the derivative approximation
        results[i_edge] = (f_plus_delta - f_minus_delta) / (2 * delta)
    end

    return results
end


# finite_diff nudges a single entry; doubling it matches the joint (i,j)/(j,i) nudge above
finite_diff(W, demo_edge, 0.01)*2, finite_diff(W, demo_edge, 0.2)*2

([4.826503769214696], [4.827101432180703])

## Forward Differentiation
Results are in numerator layout: the Jacobian has one row per entry of the output and one column per entry of the input, i.e. the input dimension is appended on the right.
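Concretely, `ForwardDiff.jacobian` returns an $n^2 \times n^2$ matrix for an $n \times n$ input, and `forward_diff_j` reduces one column of it per edge, with $l(i,j)$ denoting the column that belongs to entry $(i,j)$:
$$
J_{k,l} = \frac{\partial \operatorname{vec}\big(\exp(W)\big)_k}{\partial \operatorname{vec}(W)_l}, \qquad
\text{result}_{(i,j)} = 2 \sum_{k=1}^{n^2} J_{k,\,l(i,j)} \operatorname{vec}\big(\exp(W)\big)_k
$$
For the $50 \times 50$ benchmark matrix below, this is a $2500 \times 2500$ Jacobian.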

In [6]:
function forward_diff_j(W::Matrix{Float64}, edges::Vector{CartesianIndex{2}})::Vector{Float64}
    # Generic matrix exponential that accepts ForwardDiff dual numbers
    diff_exp(W) = exponential!(copyto!(similar(W), W), ExpMethodGeneric())
    # n²×n² Jacobian of vec(exp(W)) w.r.t. vec(W), both in column-major order
    J = ForwardDiff.jacobian(diff_exp, W)

    results = zeros(length(edges))
    # Flatten exp(W) in the same column-major order as the rows of J
    tangent = vec(exp(W))
    lin = LinearIndices(W)
    for (i_edge, edge) in enumerate(edges)
        # Column of J holding the partial derivatives w.r.t. this entry
        Jₓ = J[:, lin[edge]]
        results[i_edge] = dot(Jₓ, tangent)
    end

    # Double the single-entry value to match the joint (i,j)/(j,i) nudge
    return results .* 2
end

forward_diff_j(W, demo_edge)

1-element Vector{Float64}:
 4.8612290825962985

## Forward Differentiation with Jacobian-vector product

In [7]:
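function forward_diff_jvp(W::Matrix{Float64}, edges::Vector{CartesianIndex{2}})::Vector{Float64}
    # Direction of the Jacobian-vector product
    tangent = exp(W)
    # Generic matrix exponential that accepts ForwardDiff dual numbers
    diff_exp(W) = exponential!(copyto!(similar(W), W), ExpMethodGeneric())
    # d/dt exp(W + t * tangent) at t = 0, i.e. the Fréchet derivative of exp at W in direction tangent
    g(t) = diff_exp(W + t * tangent)
    JVP = ForwardDiff.derivative(g, 0.0)
    # For symmetric W, these entries equal the Jacobian-column contraction in forward_diff_j
    return JVP[edges] .* 2
end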


forward_diff_jvp(W, demo_edge)

1-element Vector{Float64}:
 4.861229082596298
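Both forward-mode cells above contract with $\exp(W)$: `forward_diff_j` takes the dot product of each Jacobian column with `exp(W)`, and `forward_diff_jvp` uses `exp(W)` as the direction. If the target is instead the derivative of $\sum_{k,l}[\exp(W)]_{k,l}$, the scalar that `tangent_approx` and `finite_diff` estimate, the contraction is with the all-ones matrix, and the full gradient is $L(W^\top, \mathbf{1}\mathbf{1}^\top)$, where $L(A, E)$ is the Fréchet derivative of the exponential. A minimal sketch under that assumption follows; it swaps ForwardDiff for the block-triangular identity $\exp\begin{pmatrix} A & E \\ 0 & A\end{pmatrix} = \begin{pmatrix} e^A & L(A,E) \\ 0 & e^A\end{pmatrix}$, and `frechet_exp` and `sum_exp_gradient` are illustrative names.

```julia
using LinearAlgebra

# Fréchet derivative L(A, E) of the matrix exponential, read off the block-triangular identity
#   exp([A E; 0 A]) = [exp(A) L(A, E); 0 exp(A)]
function frechet_exp(A::AbstractMatrix{Float64}, E::AbstractMatrix{Float64})
    n = size(A, 1)
    B = exp([A E; zeros(n, n) A])  # a single dense 2n × 2n exponential
    return B[1:n, n+1:2n]
end

# Gradient of sum(exp(W)) with respect to every entry of W: L(Wᵀ, ones(n, n)).
# For symmetric W this is the same as L(W, ones(n, n)).
sum_exp_gradient(W) = frechet_exp(Matrix(transpose(W)), ones(size(W)))

W = [0.0 0.8 0.0; 0.8 0.0 0.2; 0.0 0.2 0.0]
G = sum_exp_gradient(W)
edges = findall(!iszero, W)
G[edges] .* 2  # doubled for the joint (i,j)/(j,i) nudge, as in the cells above
```

For the demo matrix, the doubled entry (1,2) should agree with the finite-difference estimate of about 4.8265. One $2n \times 2n$ exponential yields all $n^2$ partial derivatives at once, rather than an $n^2 \times n^2$ Jacobian.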

## Benchmarking

In [8]:
function init_sparse_matrix(n = 100, density = 0.2)
    # Random symmetric weight matrix (dense storage) with roughly a `density` fraction of off-diagonal entries non-zero
    W = zeros(n, n)
    for i in 1:n, j in 1:n
        if rand() < (density / 2)
            W[i, j] = W[j,i] = rand()
        end
    end
    return W
end

W_bench = init_sparse_matrix(50, 0.10)
edges_bench = findall(x -> x != 0, W_bench)

@btime tangent_approx($W_bench, $edges_bench, 0.01)
@btime finite_diff($W_bench, $edges_bench, 0.1)
@btime forward_diff_j($W_bench, $edges_bench)
@btime forward_diff_jvp($W_bench, $edges_bench)

  2.608 s (279080 allocations: 1.83 GiB)


  88.504 ms (7743 allocations: 60.52 MiB)


  1.728 s (5961 allocations: 312.16 MiB)


  1.700 ms (39 allocations: 432.30 KiB)


241-element Vector{Float64}:
  9.435189136678718
  5.195743270503177
 13.467930447408666
  4.729657826331669
 37.97108049898049
 35.63246316466708
 40.14988627804099
 52.20936212065478
 61.064680835548295
 22.35336492281464
  ⋮
 14.609734107238188
 17.08779448253476
 34.43686560573628
 11.032408188686084
 22.008071742700885
 25.426604501682263
  6.7603142647465155
  7.540743730991197
 20.12146474452763

## Old

In [9]:
# Hypothetical setup (assumption): A is a weight matrix such as W, and
# index_vec_r / index_vec_c list its entries in row-major / column-major order
A = W
cartesian_indices = vec(CartesianIndices(A))
index_vec_r = sort(cartesian_indices, by = x -> x[1])
index_vec_c = cartesian_indices

all_indices = [[[] for _ in 1:length(cartesian_indices)] for _ in 1:length(cartesian_indices)]
for i in 1:size(A, 1)^2
    for j in 1:size(A, 1)^2
        push!(all_indices[i][j], index_vec_r[i], index_vec_c[j])
    end
end

# Mark edge pairs to remove later: either edge is a self-loop, both are the same edge,
# or one is the reverse of the other
cycle_indices = CartesianIndex[]
for i in 1:size(A, 1)^2
    for j in 1:size(A, 1)^2
        if ((all_indices[i][j][1][1] == all_indices[i][j][1][2])||
            (all_indices[i][j][2][1] == all_indices[i][j][2][2])||
            (all_indices[i][j][1] == all_indices[i][j][2]) ||
            ((all_indices[i][j][1][1] == all_indices[i][j][2][2]) && 
             (all_indices[i][j][1][2] == all_indices[i][j][2][1]))
            )
            push!(cycle_indices, CartesianIndex(i, j))
        end
    end
end