<a href="https://colab.research.google.com/github/dnguyend/StiefelGeodesic/blob/main/colab/StiefelLogJulia.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# <img src="https://github.com/JuliaLang/julia-logo-graphics/raw/master/images/julia-logo-color.png" height="100" /> _Colab Notebook For Stiefel Logarithm_

## Workbook showing the algorithms for Riemannian logarithms on Stiefel manifolds with a family of metrics. The algorithm works for both the embedded and canonical metrics, and beyond.

## Including: an implementation of Fréchet derivatives of the matrix exponential (previously not available in Julia; the function is expm_frechet_algo_64)

## Numerical verification of the Fréchet derivative as a directional derivative, and of the trace formula for Fréchet derivatives

## A simple implementation of the exponential map for the Stiefel manifold, in both versions from our paper

## Detailed implementation and verification of all steps in the paper

## Use rlog_descent for the gradient descent algorithm and rlog_lbfgs for the L-BFGS algorithm


## Instructions
1. Work on a copy of this notebook: _File_ > _Save a copy in Drive_ (you will need a Google account). Alternatively, you can download the notebook using _File_ > _Download .ipynb_, then upload it to [Colab](https://colab.research.google.com/).
2. If you need a GPU: _Runtime_ > _Change runtime type_ > _Hardware accelerator_ = _GPU_.
3. Execute the following cell (click on it and press Ctrl+Enter) to install Julia, IJulia and other packages (if needed, update `JULIA_VERSION` and the other parameters). **This takes a couple of minutes.**
4. **Reload this page (press Ctrl+R, or ⌘+R, or the F5 key) and continue to the next section.**

_Notes_:
* If your Colab Runtime gets reset (e.g., due to inactivity), repeat steps 2, 3 and 4.
* After installation, if you want to change the Julia version or activate/deactivate the GPU, you will need to reset the Runtime: _Runtime_ > _Factory reset runtime_ and repeat steps 3 and 4.

In [None]:
%%shell
set -e

#---------------------------------------------------#
JULIA_VERSION="1.6.0" # any version ≥ 0.7.0
JULIA_PACKAGES="IJulia BenchmarkTools Plots"
JULIA_PACKAGES_IF_GPU="CUDA" # or CuArrays for older Julia versions
JULIA_NUM_THREADS=2
#---------------------------------------------------#

if [ -n "$COLAB_GPU" ] && [ -z `which julia` ]; then
  # Install Julia
  JULIA_VER=`cut -d '.' -f -2 <<< "$JULIA_VERSION"`
  echo "Installing Julia $JULIA_VERSION on the current Colab Runtime..."
  BASE_URL="https://julialang-s3.julialang.org/bin/linux/x64"
  URL="$BASE_URL/$JULIA_VER/julia-$JULIA_VERSION-linux-x86_64.tar.gz"
  wget -nv $URL -O /tmp/julia.tar.gz # -nv means "not verbose"
  tar -x -f /tmp/julia.tar.gz -C /usr/local --strip-components 1
  rm /tmp/julia.tar.gz

  # Install Packages
  if [ "$COLAB_GPU" = "1" ]; then
      JULIA_PACKAGES="$JULIA_PACKAGES $JULIA_PACKAGES_IF_GPU"
  fi
  for PKG in `echo $JULIA_PACKAGES`; do
    echo "Installing Julia package $PKG..."
    julia -e 'using Pkg; pkg"add '$PKG'; precompile;"' &> /dev/null
  done

  # Install kernel and rename it to "julia"
  echo "Installing IJulia kernel..."
  julia -e 'using IJulia; IJulia.installkernel("julia", env=Dict(
      "JULIA_NUM_THREADS"=>"'"$JULIA_NUM_THREADS"'"))'
  KERNEL_DIR=`julia -e "using IJulia; print(IJulia.kerneldir())"`
  KERNEL_NAME=`ls -d "$KERNEL_DIR"/julia*`
  mv -f $KERNEL_NAME "$KERNEL_DIR"/julia  

  echo ''
  echo "Success! Please reload this page and jump to the next section."
fi

Installing Julia 1.6.0 on the current Colab Runtime...
2021-08-24 11:39:21 URL:https://storage.googleapis.com/julialang2/bin/linux/x64/1.6/julia-1.6.0-linux-x86_64.tar.gz [112838927/112838927] -> "/tmp/julia.tar.gz" [1]
Installing Julia package IJulia...
Installing Julia package BenchmarkTools...
Installing Julia package Plots...
Installing IJulia kernel...
[ Info: Installing julia kernelspec in /root/.local/share/jupyter/kernels/julia-1.6

Success! Please reload this page and jump to the next section.




# Checking the Installation
**If the following command does not work, reload the page (press F5) and try again.**

The `versioninfo()` function should print your Julia version and some other info about the system:

In [1]:
versioninfo()

Julia Version 1.6.0
Commit f9720dc2eb (2021-03-24 12:55 UTC)
Platform Info:
  OS: Linux (x86_64-pc-linux-gnu)
  CPU: Intel(R) Xeon(R) CPU @ 2.30GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-11.0.1 (ORCJIT, haswell)
Environment:
  JULIA_NUM_THREADS = 2


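# CHECK EXECUTION OF EXPM - Julia does not have EXPM_FRECHET.
Julia's `exp` handles the matrix exponential itself; the Fréchet derivative will be ported from SciPy (`scipy.linalg.expm_frechet`).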
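
Before porting the SciPy routine, it helps to have an independent reference for the Fréchet derivative $L(A, E)$ of the matrix exponential. A minimal sketch, assuming only the standard block-matrix identity $\exp\begin{pmatrix} A & E \\ 0 & A\end{pmatrix} = \begin{pmatrix} e^A & L(A, E) \\ 0 & e^A\end{pmatrix}$; the function name `expm_frechet_block` is hypothetical and not part of this notebook. It costs one exponential of a $2n \times 2n$ matrix, so it is meant as a check rather than a replacement.

```julia
using LinearAlgebra

# Reference Fréchet derivative of exp at A in direction E,
# read off from the upper-right block of exp([A E; 0 A]).
function expm_frechet_block(A::AbstractMatrix, E::AbstractMatrix)
    n = size(A, 1)
    blk = exp([A E; zeros(n, n) A])
    return blk[1:n, 1:n], blk[1:n, n+1:end]
end

A = [0.1 0.2; -0.3 0.4]
E = [1.0 0.0; 0.5 -1.0]
expA, L = expm_frechet_block(A, E)

# Compare with a central finite difference of exp along E
h = 1e-6
fd = (exp(A + h*E) - exp(A - h*E)) / (2h)
println(maximum(abs.(L - fd)))
```

The printed value should be small (limited by the finite-difference error), which is the same kind of agreement checked for `expm_frechet_algo_64` later in the notebook.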

In [2]:
using BenchmarkTools

using LinearAlgebra
# using Base.LinAlg
A = rand(1000, 1000)
A = A - A'
@benchmark exp(A)

BenchmarkTools.Trial: 5 samples with 1 evaluation.
 Range (min … max):  1.046 s … 1.332 s  ┊ GC (min … max):  2.26% … 21.07%
 Time  (median):     1.151 s            ┊ GC (median):    12.10%
 Time  (mean ± σ):   1.172 s ± 129.641 ms  ┊ GC (mean ± σ):  12.14% ±  9.61%

In [31]:
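# linf is needed below; it is defined again in a later cell
linf(mat) = maximum(abs.(mat))

function logom(U)
    # matrix logarithm of an orthogonal matrix via its eigendecomposition
    v, V = eigen(U)
    return real(V*broadcast(*, log.(v), V'))
end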
for i in 1:10
  n = 1000
  A = rand(n, n)
  A = A - A'
  U = exp(A)
  X = logom(U)
  println(linf(exp(X) - U))
end

2.6706529349107555e-11
7.66508384542064e-12
1.3862403412667756e-11
1.3088582301312712e-11
4.374465720213827e-11
3.155618561595519e-11
4.4705732171745893e-11
2.878959570740136e-11
7.443969052278732e-12
8.389938049857548e-11
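
To see what `logom` computes on a case with a known answer, here is a small self-contained sketch on a $2\times 2$ rotation. It assumes the eigenvalues are distinct and none equals $-1$, so that the eigenvectors returned by `eigen` are orthonormal and the principal logarithm applies; `logom_sketch` is a hypothetical name that mirrors `logom` above.

```julia
using LinearAlgebra

# Logarithm of an orthogonal matrix as V * diag(log λ) * V'
function logom_sketch(U)
    F = eigen(U)
    return real(F.vectors * Diagonal(log.(F.values)) * F.vectors')
end

θ = 0.7
U = [cos(θ) -sin(θ); sin(θ) cos(θ)]   # rotation by θ
X = logom_sketch(U)

# The logarithm of a rotation by θ is the skew-symmetric generator [0 -θ; θ 0]
println(maximum(abs.(X - [0 -θ; θ 0])))
println(maximum(abs.(exp(X) - U)))
```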


In [33]:
@benchmark logom(U)

BenchmarkTools.Trial: 2 samples with 1 evaluation.
 Range (min … max):  3.690 s … 3.714 s  ┊ GC (min … max): 0.07% … 0.12%
 Time  (median):     3.702 s            ┊ GC (median):    0.10%
 Time  (mean ± σ):   3.702 s ± 17.040 ms  ┊ GC (mean ± σ):  0.10% ± 0.03%

In [20]:
ell_table_61 = (
        nothing,
        # 1
        2.11e-8,
        3.56e-4,
        1.08e-2,
        6.49e-2,
        2.00e-1,
        4.37e-1,
        7.83e-1,
        1.23e0,
        1.78e0,
        2.42e0,
        # 11
        3.13e0,
        3.90e0,
        4.74e0,
        5.63e0,
        6.56e0,
        7.52e0,
        8.53e0,
        9.56e0,
        1.06e1,
        1.17e1,
        )
function _diff_pade3(A, E)
    b = (120., 60., 12., 1.)
    A2 = A * A
    M2 = A * E + E*A
    U = A * (b[4]*A2 + UniformScaling(b[2]))
    V = b[3]*A2 + UniformScaling(b[1])
    Lu = A * (b[4]*M2) + E * (b[4]*A2 + UniformScaling(b[2]))  # derivative of U by the product rule
    Lv = b[3] .* M2
    return U, V, Lu, Lv
end        

function _diff_pade5(A, E)
    b = (30240., 15120., 3360., 420., 30., 1.)
    A2 = A * A
    M2 = A * E + E * A
    A4 = A2 * A2
    M4 = A2 * M2 + M2 * A2
    U = A * (b[6]*A4 + b[4]*A2 + UniformScaling(b[2]))
    V = b[5]*A4 + b[3]*A2 + UniformScaling(b[1])
    Lu = (A * (b[6]*M4 + b[4]*M2) +
            E * (b[6]*A4 + b[4]*A2 + UniformScaling(b[2])))
    Lv = b[5]*M4 + b[3]*M2
    return U, V, Lu, Lv
end

function _diff_pade7(A, E)
    b = (17297280., 8648640., 1995840., 277200., 25200., 1512., 56., 1.)
    A2 = A * A
    M2 = A * E + E * A
    A4 = A2 * A2
    M4 = A2 * M2 + M2 * A2
    A6 = A2 * A4
    M6 = A4 * M2 + M4 * A2
    U = A * (b[8]*A6 + b[6]*A4 + b[4]*A2 + UniformScaling(b[2]))
    V = b[7]*A6 + b[5]*A4 + b[3]*A2 + UniformScaling(b[1])
    Lu = (A*(b[8]*M6 + b[6]*M4 + b[4]*M2) +
            E*(b[8]*A6 + b[6]*A4 + b[4]*A2 + UniformScaling(b[2])))
    Lv = b[7]*M6 + b[5]*M4 + b[3]*M2
    return U, V, Lu, Lv
end

function _diff_pade9(A, E)
    b = (17643225600., 8821612800., 2075673600., 302702400., 30270240.,
            2162160., 110880., 3960., 90., 1.)
    A2 = A * A
    M2 = A * E + E * A
    A4 = A2 * A2
    M4 = A2 * M2 + M2 * A2
    A6 = A2 * A4
    M6 = A4 * M2 + M4 * A2
    A8 = A4 * A4
    M8 = A4 * M4 + M4 * A4
    U = A * (b[10]*A8 + b[8]*A6 + b[6]*A4 + b[4]*A2 + UniformScaling(b[2]))
    V = b[9]*A8 + b[7]*A6 + b[5]*A4 + b[3]*A2 + UniformScaling(b[1])
    Lu = (A *(b[10]*M8 + b[8]*M6 + b[6]*M4 + b[4]*M2) +
            E * (b[10]*A8 + b[8]*A6 + b[6]*A4 + b[4]*A2 + UniformScaling(b[2])))
    Lv = b[9]*M8 + b[7]*M6 + b[5]*M4 + b[3]*M2
    return U, V, Lu, Lv
end


_diff_pade9 (generic function with 1 method)

In [21]:
function norm_axes(A, axes)
  return sqrt.(sum!(Vector{Float64}(undef, size(A, 1)), A .* A))
end

function expm_frechet_algo_64(A, E)
    n = size(A, 1)
    s = nothing    
    A_norm_1 = opnorm(A, 1)  # induced 1-norm (max column sum); norm(A, 1) would be the entrywise sum
    m_pade_pairs = (
            (3, _diff_pade3),
            (5, _diff_pade5),
            (7, _diff_pade7),
            (9, _diff_pade9))
    for m_pade in m_pade_pairs
        m, pade = m_pade
        if A_norm_1 <= ell_table_61[m + 1]  # +1: the table starts with `nothing` and Julia is 1-based
            U, V, Lu, Lv = pade(A, E)
            s = 0
            break
        end            
    end
    if isnothing(s)
        # scaling
        s = max(0, Int(ceil(log2(A_norm_1 / ell_table_61[14]))))  # the order-13 entry sits at index 14
        A = (2.0^-s) * A
        E = (2.0^-s) * E 
        # pade order 13
        A2 = A * A
        M2 = A * E + E * A
        A4 = A2 * A2
        M4 = A2 * M2 + M2 * A2
        A6 = A2 * A4
        M6 = A4 * M2 + M4 * A2
        b = (64764752532480000., 32382376266240000., 7771770303897600.,
                1187353796428800., 129060195264000., 10559470521600.,
                670442572800., 33522128640., 1323241920., 40840800., 960960.,
                16380., 182., 1.)
        W1 = b[14]*A6 + b[12]*A4 + b[10]*A2
        W2 = b[8]*A6 + b[6]*A4 + b[4]*A2 + UniformScaling(b[2])
        Z1 = b[13]*A6 + b[11]*A4 + b[9]*A2
        Z2 = b[7]*A6 + b[5]*A4 + b[3]*A2 + UniformScaling(b[1])
        W = A6 * W1 + W2
        U = A * W
        V = A6 * Z1 + Z2
        Lw1 = b[14]*M6 + b[12]*M4 + b[10]*M2
        Lw2 = b[8]*M6 + b[6]*M4 + b[4]*M2
        Lz1 = b[13]*M6 + b[11]*M4 + b[9]*M2
        Lz2 = b[7]*M6 + b[5]*M4 + b[3]*M2
        Lw = A6 * Lw1 + M6 * W1 + Lw2
        Lu = A * Lw + E * W
        Lv = A6 * Lz1 + M6 * Z1 + Lz2
    end        
    # factor once and solve twice
    lu_piv = lu(-U + V)
    R = lu_piv \ (U + V)
    L = lu_piv \ (Lu + Lv + (Lu - Lv)* R)
    # squaring
    for k in 1:s
        L = R * L + L * R
        R = R * R
    end
    return R, L
end

expm_frechet_algo_64 (generic function with 1 method)
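
The last two steps of `expm_frechet_algo_64` can be read off from the following relations, written here as a sketch with the same names as the code ($U, V$ are the odd and even parts of the Padé approximant, $L_u, L_v$ their derivatives in the direction $E$, $R \approx e^{A}$ and $L \approx L(A, E)$ after scaling):

$$
(V - U)\,R = U + V, \qquad (V - U)\,L = L_u + L_v + (L_u - L_v)\,R .
$$

The second equation follows by differentiating the first in the direction $E$. The squaring loop then undoes the scaling by $2^{-s}$ using $e^{2X} = (e^{X})^2$ and the product rule:

$$
L \leftarrow R\,L + L\,R, \qquad R \leftarrow R^2 \qquad (\text{repeated } s \text{ times}).
$$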

## NOW TEST expm_frechet_algo_64

In [22]:
n = 5
A = randn(n, n)
E = randn(n, n)
pp = expm_frechet_algo_64(A, E)
println(pp[1])

[1.0060109754371573 -2.537360492724837 1.5102593862800624 2.7007006389200146 -0.9159775192490337; 2.7483322334494105 2.9954445501467486 2.2277889088533573 2.55987326251544 0.780954247371168; 0.08231076935810337 -1.0801238471596377 0.5108818696333359 1.1516535604769988 -0.4948580414097984; -0.7432634706397881 -3.1498608935805144 -0.03368432980302584 1.105916827935206 -1.1994727437587096; -0.8959621069409777 2.2463198874402726 -0.826468311258394 -0.4092355089758429 1.6584545012602825]


In [24]:
n = 100
A = 0.1*(reshape(0:(n*n-1), n, n)' .% 7) + UniformScaling(0.5)
E = reshape(0:(n*n-1), n, n)'
# E = E .* E
E = (E .* E) .% 23

println(A)
println(E)
if false
  @benchmark expm_frechet_algo_64(A, E)
end
# println(_diff_pade9(A, E))

[0.5 0.1 0.2 0.30000000000000004 0.4 0.5 0.6000000000000001 0.0 0.1 0.2 … (output of the 100×100 matrices truncated)

In [25]:
@benchmark expm_frechet_algo_64(A, E)

BenchmarkTools.Trial: 619 samples with 1 evaluation.
 Range (min … max):  6.306 ms … 27.280 ms  ┊ GC (min … max):  0.00% … 16.65%
 Time  (median):     6.822 ms              ┊ GC (median):     0.00%
 Time  (mean ± σ):   8.029 ms ±  2.744 ms  ┊ GC (mean ± σ):  12.59% ± 16.37%

# VERIFY FRECHET DERIVATIVE AND THE TRACE FORMULA

In [26]:
function linf(mat)
  return maximum(abs.(mat))
end  

n = 5
A = 0.1*(reshape(0:(n*n-1), n, n)' .% 7) + UniformScaling(0.5)
E = reshape(0:(n*n-1), n, n)'
# E = E .* E
E = (E .* E) .% 23

e1 = exp(A)
dlt = 1e-8
e2 = exp(A + dlt*E)
println("VERIFYING FRECHET DERIVATIVE")
(e2-e1)/dlt
# expm_frechet_algo_64(A, E)[2]
println(linf((e2-e1)/dlt - expm_frechet_algo_64(A, E)[2]))

println("VERIFYING THE TRACE FORMULA")
C = randn(n, n)
D = randn(n, n)
println(tr(C*expm_frechet_algo_64(A, E)[2]*D))
println(tr(expm_frechet_algo_64(A, D*C)[2]*E))

VERIFYING FRECHET DERIVATIVE
1.4476886747161188e-5
VERIFYING THE TRACE FORMULA
48.066842252764786
48.066842252764786
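
The trace formula checked above, $\mathrm{Tr}(C\,L(A, E)\,D) = \mathrm{Tr}(L(A, DC)\,E)$, can be justified in a few lines. A sketch, assuming the standard integral representation of the Fréchet derivative of the exponential:

$$
L(A, E) = \int_0^1 e^{(1-s)A}\, E\, e^{sA}\, ds .
$$

By cyclic invariance of the trace and the substitution $s \mapsto 1 - s$,

$$
\mathrm{Tr}\big(C\,L(A, E)\,D\big)
= \int_0^1 \mathrm{Tr}\big(e^{sA}\, DC\, e^{(1-s)A}\, E\big)\, ds
= \mathrm{Tr}\Big(\Big[\int_0^1 e^{(1-s)A}\, DC\, e^{sA}\, ds\Big] E\Big)
= \mathrm{Tr}\big(L(A, DC)\,E\big).
$$

This is what lets the gradient be computed with one Fréchet derivative call per exponential instead of one per direction.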


# Code for a simple Stiefel manifold: random point, random tangent vector, geodesic, and exponential map (both versions)

In [34]:
struct Stf
  n::Int64
  p::Int64
  alpha::Vector{Float64}
end

function stf_rand(M)
  ## Random point on the manifold M
  QR = qr(randn(M.n, M.p))
  return QR.Q * vcat(I, zeros((M.n-M.p, M.p)))
end

function stf_inner(M, Y, eta, xi)
  # inner product
  return M.alpha[1]*sum(eta .* xi) + (M.alpha[2] - M.alpha[1])*sum((eta' * Y) .* (xi' * Y))
end

function linf(mat)
  return maximum(abs.(mat))
end  

function logom(U)
  # log for orthogonal matrices
  # cost only 3 times exp
   v, V = eigen(U)
    return real(V*broadcast(*, log.(v), V'))
end

function sym(mat)
  return 0.5*(mat + mat')
end

function asym(mat)
  return 0.5*(mat - mat')
end

function stf_proj(M, Y, omg)
  # projection of an ambient vector omg to the tangent space at Y  
  return omg - Y*sym(Y'*omg)
end

function stf_randvec(M, Y)
  # random tangent vector at Y
  r = stf_proj(M, Y, randn(size(Y)))
  return r ./ sqrt(stf_inner(M, Y, r, r))
end

function get_Q(Y, Y1)
    # Algorithm: find an orthonormal basis, orthogonal to Y,
    # in the linear span of Y and Y1.
    n , p = size(Y)
    F = svd([Y Y1])
    k = sum(F.S .> 1e-14)
    good = F.U[:, 1:k]*F.Vt[1:k, 1:k]
    qs = nullspace(Y'*good)
    QR = qr(good*qs)
    return QR.Q * vcat(I, zeros((n-k+p, k - p)))
end

function sexp(M, Y, Q, A, R )
  # exponential map, given Y, Q
  alf = M.alpha[2]/M.alpha[1]
  # println(alf)
  p, k  = size(Y, 2), size(Q, 2)
  ex1 = exp((1-2*alf)*A)
  ex2 = exp(
      vcat([2*alf*A -R'], [R zeros((k, k))]))
  return Y*ex2[1:p, 1:p]*ex1 + Q*ex2[(p+1):end, 1:p]*ex1  
end

function stf_exp(M, Y, eta)
  # exponential map
  p = size(Y, 2)
  A = Y' * eta
  QR = qr(eta - Y * A)
  return sexp(M, Y, QR.Q* vcat(I, zeros((M.n-M.p, M.p))), A, QR.R)
end 


function Pi0(Y, a)
  return a - Y*(Y'*a)
end  

function stf_gamma(M, Y, xi, eta)
    # the Christoffel term in the geodesic equation
    al = M.alpha[2]/M.alpha[1]    
    return Y*sym(xi' * eta) + (1-al)*Pi0(Y, xi*(eta'*Y) + eta*(xi'*Y))
end

function stf_dot_exp(M, X, eta, t)
   # return the other exp formula, and also
   # time derivative of the exponential map
   alf = M.alpha[2]/M.alpha[1]
   p = M.p
   A = X' * eta

   e_mat = zeros(2*p, 2*p)
   e_mat[1:p, 1:p] = (2*alf-1)*A 
   e_mat[1:p, p+1:end] = -eta'*eta - 2*(1-alf)*A*A
   e_mat[p+1:end, 1:p] = I(p)
   e_mat[p+1:end, p+1:end] = A
   eE = exp(t*e_mat)
   eA = exp((1-2*alf)*t*A)
   ex = ([X eta] * eE)[1:end, 1:p] * eA
   dot_ex = (vcat([X eta]) * e_mat*eE)[1:end, 1:p] * eA +
            (vcat([X eta]) * eE)[1:end, 1:p] * ((1-2*alf)*A*eA)
   return ex, dot_ex
end


stf_dot_exp (generic function with 1 method)
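
As a quick sanity check of the projection used in `stf_proj`, the sketch below verifies that a projected vector $\eta$ satisfies the tangent condition $Y^T\eta + \eta^T Y = 0$. It is self-contained and redefines what it needs; `proj_tangent` is a hypothetical name, and the first `p` columns of the QR factor are taken explicitly so the code does not depend on whether `Matrix(Q)` returns the thin or the full factor.

```julia
using LinearAlgebra

sym(M) = 0.5 * (M + M')

# Same formula as stf_proj: remove the symmetric part of Y' * omg
proj_tangent(Y, omg) = omg - Y * sym(Y' * omg)

n, p = 7, 3
Y = Matrix(qr(randn(n, p)).Q)[:, 1:p]   # random point with orthonormal columns
eta = proj_tangent(Y, randn(n, p))

println(maximum(abs.(Y' * Y - I)))            # Y is on the manifold
println(maximum(abs.(Y' * eta + eta' * Y)))   # eta is tangent at Y
```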

## TEST THE EXPONENTIAL MAP. THREE CONDITIONS: $Y(t)$ is on the manifold, $d/dt Y(0) = \eta$, and $Y(t)$ satisfies the geodesic equation. We also verify that both formulas agree and check the time derivative of the geodesic.

In [None]:
n = 7
p = 3
alpha = [1, .8]
M2 = Stf(n, p, alpha)
Y = stf_rand(M2)
Y1 = stf_rand(M2)
eta = stf_randvec(M2, Y)

Yt = stf_exp(M2, Y, eta)

println("check Yt is on the manifold")
println(linf(Yt'*Yt - I(p)))
dlt = 1e-8
e2 = stf_exp(M2, Y, dlt*eta)
println("check d/dt Yt is eta")
println(linf(eta - (e2-Y)/dlt))

# check the geodesic equation Y
dlt = 1e-5
t = 1.5
Yt = stf_exp(M2, Y, t*eta)
Ytp = stf_exp(M2, Y, (t+dlt)*eta)
Ytm = stf_exp(M2, Y, (t-dlt)*eta)

println("VERIFYING THE GEODESIC EQUATION")
Ydt = (Ytp - Ytm)/dlt/2
Yddt = (Ytp + Ytm - 2*Yt)/dlt/dlt
println(linf(Yddt + stf_gamma(M2, Yt, Ydt, Ydt)))

Yt1, Ytd1 = stf_dot_exp(M2, Y, eta, t)
println("VERIFYING FORMULA 3.3 and also the time derivative of geodesic")
println(linf(Yt1 - Yt))
println(linf(Ytd1 - Ydt))

check Yt is on the manifold
2.220446049250313e-16
check d/dt Yt is eta
1.1344022082804273e-8
VERIFYING THE GEODESIC EQUATION
5.058922321360404e-6
VERIFYING FORMULA 3.3 and also the time derivative of geodesic
2.393918396847994e-16
1.2182581332620401e-11
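
For reference, the geodesic equation tested in this cell, written out from `stf_gamma` with $\alpha = \alpha_2/\alpha_1$ (the ratio `M.alpha[2]/M.alpha[1]`) and $\mathrm{sym}(X) = \tfrac12(X + X^T)$:

$$
\ddot{Y} + \Gamma(Y; \dot{Y}, \dot{Y}) = 0, \qquad
\Gamma(Y; \xi, \eta) = Y\,\mathrm{sym}(\xi^T\eta) + (1-\alpha)\,(I - YY^T)\big(\xi\,\eta^T Y + \eta\,\xi^T Y\big).
$$

The derivatives are approximated by central differences with step $\delta$ (`dlt` in the code):

$$
\dot{Y}(t) \approx \frac{Y(t+\delta) - Y(t-\delta)}{2\delta}, \qquad
\ddot{Y}(t) \approx \frac{Y(t+\delta) + Y(t-\delta) - 2Y(t)}{\delta^2},
$$

so a residual of order $10^{-6}$ with $\delta = 10^{-5}$ is consistent with finite-difference error rather than an error in the formula.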


In [None]:
@benchmark stf_exp(M2, Y, eta)

BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  27.444 μs …  6.083 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     33.587 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   38.908 μs ± 89.711 μs  ┊ GC (mean ± σ):  2.50% ± 1.91%

# CHECK FORMULA FOR THE COST FUNCTION. VERIFY 

 $\frac{1}{2}\|Y(t) - Z\|^2 - p = -\mathrm{Tr}\left(Z^T[Y\ Q] \exp(\hat{A})\, I_{p+k, p}\exp((1-2\alpha)A)\right)$
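
The sign and the constant $p$ in this identity can be checked from a short expansion. A sketch, assuming only that both $Y(t)$ and $Z$ have orthonormal columns, so that $\mathrm{Tr}(Y(t)^TY(t)) = \mathrm{Tr}(Z^TZ) = p$:

$$
\tfrac12\|Y(t) - Z\|_F^2
= \tfrac12\mathrm{Tr}\big(Y(t)^TY(t)\big) + \tfrac12\mathrm{Tr}\big(Z^TZ\big) - \mathrm{Tr}\big(Z^TY(t)\big)
= p - \mathrm{Tr}\big(Z^TY(t)\big).
$$

Substituting $Y(t) = [Y\ Q]\,\exp(\hat{A})\,I_{p+k,p}\,\exp((1-2\alpha)A)$, which is what `sexp` evaluates, gives the trace expression that the code below computes as `cost1`, with a leading minus sign.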

In [None]:
xi = stf_randvec(M2, Y)
Z = stf_exp(M2, Y, xi)
Q = get_Q(Y, Z)

k = size(Q)[2]
# now take a guess

A = asym(randn(p, p))
R = randn(k, p)
eta = Y*A + Q*R

Yt = stf_exp(M2, Y, eta)
println("VERIFYING THE COST FUNCTION FORMULA")

cost0 = 0.5*sum((Yt - Z).*(Yt - Z)) - p
println(cost0)

fix = (Z'*[Y Q])'
alf = M2.alpha[2]/M2.alpha[1]

Ahat = vcat([2*alf*A -R'], [R zeros((k, k))])
cost1 = -sum(fix .* (exp(Ahat)[1:end, 1:p]* exp((1-2*alf)*A)))
println(cost1)

ZTY = Z'*Y
ZTQ = Z'*Q

function fun(A, R)
  # to make this cheap, ie p^3 cost only evaluate ZTY, ZTQ outside
  # and evaluate both function and gradient
    ex1 = exp((1-2*alf)*A)

    mat = vcat([(2*alf*A) -R'], [R zeros(k, k)])
    E = vcat([(ex1 * ZTY) ex1*ZTQ], zeros(k, p+k))

    ex2, fe2 = expm_frechet_algo_64(mat, E)
    M = ex2[1:p, 1:p]
    N = ex2[p+1:end, 1:p]
    ZYMQN = ZTY*M+ZTQ*N

    partA = asym(
          (1-2*alf)*expm_frechet_algo_64((1-2*alf)*A, ZYMQN)[2])

    partA += 2*alf*asym(fe2[1:p, 1:p])
    partR = -(fe2[1:p, p+1:end]' - fe2[p+1:end, 1:p])

    return -sum(ZYMQN' .* ex1), partA, partR
end

f, g1, g2 = fun(A, R)
println(f)
DA = asym(randn(size(A)))
DR = randn(size(R))

dlt = 1e-8

Ytnew = sexp(M2, Y, Q, A + dlt*DA, R + dlt*DR)
costnew = 0.5*sum((Ytnew - Z).*(Ytnew - Z)) - p
println("VERIFYING THE GRADIENT FORMULA")
println((costnew - f)/dlt)
println((costnew - cost0)/dlt)
println(sum(g1 .* DA) + sum(g2 .* DR))



VERIFYING THE COST FUNCTION FORMULA
-0.09625865207398787
-0.09625865207398798
-0.0962586520739892
VERIFYING THE GRADIENT FORMULA
-0.08771081638769829
-0.08771094961446124
-0.08771078811993924


In [None]:
#@title NOT USED - JACT
#=
# A slightly different presentation of the theorem on the gradient of the cost function

Here, consider the map $(A, R) \mapsto \exp(Y, YA +QR)$. This map is implemented as the function $sexp$ earlier. The Jacobian of $sexp$ is a map from the space of $(A, R)$s to $R^{n\times p}$. Its adjoint is a map from $R^{n\times p}$ back to the space of $(A, R)$s, called $JacT$ below, with signature $JacT(M, Y, Q, A, R, \omega)$, or $JacT(A, R, \omega)$ for short.

The gradient of the cost function $1/2\| sexp(A, R) -Z\|^2_2 -p$ is $JacT(A, R, sexp(A, R) -Z)$. It is given by Fréchet derivatives, and we could reduce it to the form in the theorem. Below we verify numerically that $JacT$ is the adjoint, i.e., that for random directions $\Delta_A, \Delta_R$ and random $\omega$ it satisfies
$$\lim_{\delta\to 0} \frac{1}{\delta}\mathrm{Tr}\,\omega^T(sexp(A+\delta \Delta_A, R + \delta \Delta_R) - sexp(A, R)) = \langle JacT(A, R, \omega), (\Delta_A, \Delta_R)\rangle$$
where, for $JacT(A, R, \omega) = (P_A, P_R)$, the pairing on the right is $\mathrm{Tr}(P_A^T\Delta_A) + \mathrm{Tr}(P_R^T\Delta_R)$.
=#
function JacT(M, Y, Q, A, R, omg)
    p, k  = size(Y, 2), size(Q, 2)
    alf = M.alpha[2]/M.alpha[1]
    ex1 = exp((1-2*alf)*A)
    K14 = vcat(hcat(ex1*omg'*Y, ex1*omg'*Q), zeros((k, p+k)))
    
    Q14 = expm_frechet_algo_64(vcat([2*alf*A -R'], [R zeros((k, k))]), K14)
    
    K23 = omg' * (Y * Q14[1][1:p, 1:p] + Q * Q14[1][(p+1):end, 1:p])
    P23 = expm_frechet_algo_64((1-2*alf)*A, K23)[2]
    PA = asym(-(1-2*alf)*P23 - 2*alf*Q14[2][1:p, 1:p])
    PR =  Q14[2][1:p, (p+1):end]' - Q14[2][(p+1):end, 1:p]
    return PA, PR
end
M2 = Stf(n, p, alpha)

omg = randn(n, p)
# println(n, p, size(omg))
c1, c2 = JacT(M2, Y, Q, A, R, omg)
ee = Y * A + Q * R
DA = asym(randn(size(A)))
DR = randn(size(R))

dlt = 1e-8

println(sum(omg .*(sexp(M2, Y, Q, A + dlt*DA, R + dlt*DR) - sexp(M2, Y, Q, A, R) )/dlt))
println(sum(c1.*DA) + sum(c2.*DR))


-7.285244080794444
-7.285244007682631


# Implementing a simple gradient descent

In [None]:
function rlog_descent(stf, Y, Y1, tol=1e-10)
  alf = stf.alpha[2]/stf.alpha[1]
  n, p = stf.n, stf.p

  Q = get_Q(Y, Y1)
  k = size(Q, 2)

  eta0 = stf_proj(stf, Y, Y1-Y)
  A = asym(Y' * eta0)
  R = Q' * eta0 - (Q' * Y) * (Y' * eta0)

  ZTY = Y1'*Y
  ZTQ = Y1'*Q
  function fun(A, R)
    # to make this cheap, ie p^3 cost only evaluate ZTY, ZTQ outside
    # and evaluate both function and gradient
      ex1 = exp((1-2*alf)*A)

      mat = vcat([(2*alf*A) -R'], [R zeros(k, k)])
      E = vcat([(ex1 * ZTY) ex1*ZTQ], zeros(k, p+k))

      ex2, fe2 = expm_frechet_algo_64(mat, E)
      M = ex2[1:p, 1:p]
      N = ex2[p+1:end, 1:p]
      ZYMQN = ZTY*M+ZTQ*N

      partA = asym(
          (1-2*alf)*expm_frechet_algo_64((1-2*alf)*A, ZYMQN)[2])

      partA += 2*alf*asym(fe2[1:p, 1:p])
      partR = -(fe2[1:p, p+1:end]' - fe2[p+1:end, 1:p])

      return -sum(ZYMQN' .* ex1), partA, partR
  end

   max_itr = 120
   done = false
   itr = 0
   fjacs = 0
   fvals = 0
   scl = sqrt(n*p)

   while (!done) && (itr < max_itr)
        f, dA, dR = fun(A, R)
        fjacs += 1
        itr  += 1
        # println(itr, f)
        if max(0, f + p) < tol
            done = true
            break
        else
            dnorm = sqrt(sum(dA .* dA) + sum(dR .* dR))
            if dnorm == 0
                break
            end
            A -= dA
            R -= dR
        end
  end
  return Y*A +Q*R, itr, done, A, R, Q
end


rlog_descent (generic function with 2 methods)
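
`rlog_descent` takes a full step (`A -= dA`, `R -= dR`) at every iteration and has no step-size control. One common safeguard, not used in this notebook, is a backtracking (Armijo) line search. The sketch below shows the idea on a toy quadratic where a unit step would diverge; the function name, the constant `c = 0.1` and the toy problem are all assumptions made for illustration.

```julia
using LinearAlgebra

# fg(x) must return (value, gradient). Halve the step until the
# Armijo sufficient-decrease condition holds.
function descent_backtracking(fg, x0; tol=1e-8, max_itr=10_000, c=0.1)
    x = copy(x0)
    f, g = fg(x)
    for itr in 1:max_itr
        gnorm2 = sum(abs2, g)
        if sqrt(gnorm2) < tol
            return x, itr, true
        end
        t = 1.0
        xnew = x - t * g
        fnew, gnew = fg(xnew)
        while fnew > f - c * t * gnorm2 && t > 1e-12
            t *= 0.5
            xnew = x - t * g
            fnew, gnew = fg(xnew)
        end
        x, f, g = xnew, fnew, gnew
    end
    return x, max_itr, false
end

# Toy problem: ill-scaled quadratic 0.5 * x' * Hq * x with minimizer 0
Hq = Diagonal([1.0, 10.0])
quad(x) = (0.5 * dot(x, Hq * x), Hq * x)
xmin, itr, done = descent_backtracking(quad, [1.0, 1.0])
println((norm(xmin), itr, done))
```

In `rlog_descent` the same loop would wrap the update of `A` and `R`, at the price of extra evaluations of `fun` per iteration.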

In [None]:
n = 7
p = 3
alpha = [1, 0.5]
M2 = Stf(n, p, alpha)
Y = stf_rand(M2)
xi = stf_randvec(M2, Y)*.8*pi


Y1 = stf_exp(M2, Y, xi)
xi1, cnt, done, A, R, Q = rlog_descent(M2, Y, Y1, 1e-8)
size(xi1)
# Y'* xi1
println(linf(stf_exp(M2, Y, xi1) - Y1))
println(cnt)
if false
  @benchmark rlog_descent(M2, Y, Y1, 1e-10)
end


0.0008431637114450843
120


In [None]:
n = 7
p = 3
alpha = [1, 0.5]
M2 = Stf(n, p, alpha)
Y = vcat(I, zeros((n-p, p)))
k = p-1
Q = vcat(zeros(p, k), I,  zeros((n-p-k, k)))
A = asym(reshape(1:(p*p), (p, p)))
R = reshape((1:(p*k)) .* (1:(p*k)), (k, p)) .% 10
xi = Y * A + Q * R
xi = xi ./ sqrt(stf_inner(M2, Y, xi, xi))
Y1 = stf_exp(M2, Y, xi)
xi1, cnt, done, A, R, Q = rlog_descent(M2, Y, Y1, 1e-8)
if true
  # size(xi1)
  # Y'* xi1
  println(linf(stf_exp(M2, Y, xi1) - Y1))
  println(cnt)
  if true
    @benchmark rlog_descent(M2, Y, Y1, 1e-8)
  end
end  

5.829997139611409e-5
7


BenchmarkTools.Trial: 8515 samples with 1 evaluation.
 Range (min … max):  416.692 μs …  11.526 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     542.816 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   579.181 μs ± 368.115 μs  ┊ GC (mean ± σ):  2.82% ± 6.08%

In [None]:
n = 1000
p = 200
alpha = [1, 0.5]
M2 = Stf(n, p, alpha)
Y = vcat(I, zeros((n-p, p)))
k = p-1
Q = vcat(zeros(p, k), I,  zeros((n-p-k, k)))
A = asym(reshape(1:(p*p), (p, p)))
R = reshape((1:(p*k)) .* (1:(p*k)), (k, p))
xi = Y * A + Q * R
xi = xi ./ sqrt(stf_inner(M2, Y, xi, xi))
Y1 = stf_exp(M2, Y, xi)
xi1, cnt, done, A, R, Q = rlog_descent(M2, Y, Y1, 1e-8)
if true
  # size(xi1)
  # Y'* xi1
  println(linf(stf_exp(M2, Y, xi1) - Y1))
  println(cnt)
  if true
    @benchmark rlog_descent(M2, Y, Y1, 1e-8)
  end
end  

6.256545872804419e-9
3


BenchmarkTools.Trial: 19 samples with 1 evaluation.
 Range (min … max):  256.675 ms … 285.644 ms  ┊ GC (min … max): 3.75% … 10.06%
 Time  (median):     267.267 ms               ┊ GC (median):    6.62%
 Time  (mean ± σ):   268.741 ms ±   9.286 ms  ┊ GC (mean ± σ):  6.23% ±  1.95%

In [None]:
n = 1500
p = 1000
alpha = [1, 0.5]
M2 = Stf(n, p, alpha)
Y = stf_rand(M2)
xi = stf_randvec(M2, Y)*.5*pi

if true
  Y1 = stf_exp(M2, Y, xi)
  xi1, cnt, done, A, R, Q = rlog_descent(M2, Y, Y1, 1e-8)
  println(linf(stf_exp(M2, Y, xi1) - Y1))
  println(cnt)
  if false
    @benchmark rlog_descent(M2, Y, Y1, 1e-8) seconds=180
  end
  if true
    @benchmark rlog_descent(M2, Y, Y1, 1e-8)
  end
end  

1.0023729652948801e-7
3


BenchmarkTools.Trial: 1 sample with 1 evaluation.
 Single result which took 43.241 s (2.28% GC) to evaluate,
 with a memory estimate of 7.53 GiB, over 1442 allocations.

# Now do a simple L-BFGS using a library (Optim). The custom L-BFGS is in Python
The Moré-Thuente line search seems to work best.
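
The L-BFGS driver below relies on Optim's `only_fg!` convention, in which one function computes the value and the gradient together and either output may be skipped by passing `nothing`. A minimal sketch of that convention on a toy quadratic (the objective and the name `toy_fg!` are made up for illustration; it assumes Optim has been installed, as in the next cell):

```julia
using Optim

# f(x) = (x1 - 1)^2 + 10 (x2 + 2)^2, with value and gradient computed together
function toy_fg!(F, G, x)
    r1 = x[1] - 1.0
    r2 = x[2] + 2.0
    if G !== nothing          # gradient requested: write it in place
        G[1] = 2.0 * r1
        G[2] = 20.0 * r2
    end
    if F !== nothing          # value requested: return it
        return r1^2 + 10.0 * r2^2
    end
    return nothing
end

method = LBFGS(linesearch = Optim.LineSearches.MoreThuente(), m = 5)
res = optimize(Optim.only_fg!(toy_fg!), [0.0, 0.0], method,
               Optim.Options(g_tol = 1e-8))
println(Optim.minimizer(res))
```

Sharing the work between value and gradient matters here because both come out of the same Fréchet derivative call.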

In [None]:
using Pkg; Pkg.add("Optim")

    Updating registry at `~/.julia/registries/General`
   Resolving package versions...
   Installed DiffRules ────────────── v1.3.0
   Installed FiniteDiff ───────────── v2.8.1
   Installed OpenSpecFun_jll ──────── v0.5.5+0
   Installed Static ───────────────── v0.3.0
   Installed IfElse ───────────────── v0.1.0
   Installed ArrayInterface ───────── v3.1.24
   Installed LogExpFunctions ──────── v0.3.0
   Installed Optim ────────────────── v1.4.1
   Installed PositiveFactorizations ─ v0.2.4
   Installed FillArrays ───────────── v0.12.2
   Installed Parameters ───────────── v0.12.2
   Installed IrrationalConstants ──── v0.1.0
   Installed NLSolversBase ────────── v7.8.1
   Installed Forwa

In [None]:
using Optim
function veca(mat)
  # vectorize antisymmetric matrices
  sz = size(mat)[1]
  ret = zeros(div((sz*(sz-1)), 2))
  start = 1
  for i in 1:sz-1
    # println(size(ret[start:start+sz-i-1]), size(mat[i+1:end, i]))
    ret[start:start+sz-i-1] = mat[i+1:end, i]
    start += sz-i
  end
  return ret
end

function unveca(v)
   sz = .5 * (1 + sqrt(1 + 8 * size(v)[1]))
   sz = Int(round(sz))
   mat = zeros(sz, sz)
   start = 1
   for i in 1:(sz-1)
     mat[i+1:end, i] = v[start:start+sz-i-1]
     mat[i, i+1:end] = - v[start:start+sz-i-1]
    start += sz-i
  end
  return mat
end

function rlog_lbfgs(stf, Y, Z, tol)
  alf = stf.alpha[2]/stf.alpha[1]
  n, p = stf.n, stf.p

  Q = get_Q(Y, Z)
  k = size(Q, 2)

  # use the argument Z, not the global Y1, as the target point
  eta0 = stf_proj(stf, Y, Z-Y)
  A = asym(Y' * eta0)
  R = Q' * eta0 - (Q' * Y) * (Y' * eta0)
  Adim = div(p*(p-1), 2)

  ZTY = Z'*Y
  ZTQ = Z'*Q
  function  ARunvec(v)
    return  unveca(v[1:Adim]), reshape(v[Adim+1:end], k, p)
  end
  
  function fun!(F, G, v)
    # to make this cheap, ie p^3 cost only evaluate ZTY, ZTQ outside
    # and evaluate both function and gradient
      A, R = ARunvec(v)
      ex1 = exp((1-2*alf)*A)

      mat = vcat([(2*alf*A) -R'], [R zeros(k, k)])
      E = vcat([(ex1 * ZTY) ex1*ZTQ], zeros(k, p+k))

      if G === nothing
        ex2 =  exp(mat)
        M = ex2[1:p, 1:p]
        N = ex2[p+1:end, 1:p]
        
        return -sum(( ZTY*M+ZTQ*N)' .* ex1)
      end
      ex2, fe2 = expm_frechet_algo_64(mat, E)
      M = ex2[1:p, 1:p]
      N = ex2[p+1:end, 1:p]
      ZYMQN = ZTY*M+ZTQ*N

      partA = asym(
          (1-2*alf)*expm_frechet_algo_64((1-2*alf)*A, ZYMQN)[2])

      partA += 2*alf*asym(fe2[1:p, 1:p])
      partR = -(fe2[1:p, p+1:end]' - fe2[p+1:end, 1:p])
      
      G[1:Adim] = veca(partA)
      G[1+Adim:end] = vec(partR)       
      return -sum(ZYMQN' .* ex1)
  end
  v0 = vcat(veca(A), vec(R))
  optzer = LBFGS(linesearch = Optim.LineSearches.MoreThuente(), m=5)
  # pass tol as the gradient tolerance so the argument is actually used
  ret = optimize(Optim.only_fg!(fun!), v0, optzer, Optim.Options(g_tol = tol))
  A, R = ARunvec(Optim.minimizer(ret))  
  
  return Y * A + Q*R, ret
end

rlog_lbfgs (generic function with 1 method)
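
The packing used by `veca` and `unveca` can be checked in isolation. The following is a minimal sketch, assuming only Julia's standard library; `pack_skew` and `unpack_skew` are hypothetical names that restate the two helpers compactly so the block runs on its own.

```julia
# Hypothetical compact restatements of veca/unveca, for a round-trip check only.
function pack_skew(mat)
    sz = size(mat, 1)
    # strict lower triangle, column by column (same order as veca)
    return [mat[i, j] for j in 1:sz-1 for i in j+1:sz]
end

function unpack_skew(v)
    # recover the matrix size from length(v) = sz*(sz-1)/2
    sz = Int(round(0.5 * (1 + sqrt(1 + 8 * length(v)))))
    mat = zeros(sz, sz)
    idx = 1
    for j in 1:sz-1, i in j+1:sz
        mat[i, j] = v[idx]
        mat[j, i] = -v[idx]
        idx += 1
    end
    return mat
end

B = randn(4, 4)
S = B - B'                      # antisymmetric test matrix
v = pack_skew(S)
println(length(v))              # 6, i.e. 4*3/2 free entries
println(maximum(abs.(unpack_skew(v) - S)))
```

The vector length `p*(p-1)/2` is the quantity called `Adim` in `rlog_lbfgs`, which is why the optimizer's variable splits into an `A` part of that length followed by the `k*p` entries of `R`.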

In [None]:
n = 7
p = 3
alpha = [1, 0.5]
M2 = Stf(n, p, alpha)
Y = stf_rand(M2)
xi = stf_randvec(M2, Y)*.5*pi
Y1 = stf_exp(M2, Y, xi)

eta1, ret = rlog_lbfgs(M2, Y, Y1, 1e-8)
println(linf(stf_exp(M2, Y, eta1) - Y1))
print(ret)

if true
  @benchmark rlog_lbfgs(M2, Y, Y1, 1e-8)
end

7.905688242937359e-9
 * Status: success

 * Candidate solution
    Final objective value:     -3.000000e+00

 * Found with
    Algorithm:     L-BFGS

 * Convergence measures
    |x - x'|               = 6.15e-08 ≰ 0.0e+00
    |x - x'|/|x'|          = 6.48e-08 ≰ 0.0e+00
    |f(x) - f(x')|         = 5.33e-15 ≰ 0.0e+00
    |f(x) - f(x')|/|f(x')| = 1.78e-15 ≰ 0.0e+00
    |g(x)|                 = 5.36e-09 ≤ 1.0e-08

 * Work counters
    Seconds run:   0  (vs limit Inf)
    Iterations:    13
    f(x) calls:    14
    ∇f(x) calls:   14


BenchmarkTools.Trial: 3595 samples with 1 evaluation.
 Range (min … max):  1.025 ms …  14.461 ms  ┊ GC (min … max): 0.00% … 67.49%
 Time  (median):     1.291 ms               ┊ GC (median):    0.00%
 Time  (mean ± σ):   1.378 ms ± 649.229 μs  ┊ GC (mean ± σ):  3.94% ±  8.25%

In [None]:
n = 1000
p = 50
alpha = [1, 0.5]
M2 = Stf(n, p, alpha)
Y = stf_rand(M2)
xi = stf_randvec(M2, Y)*.5*pi
Y1 = stf_exp(M2, Y, xi)

eta1, ret = rlog_lbfgs(M2, Y, Y1, 1e-8)
println(linf(stf_exp(M2, Y, eta1) - Y1))
print(ret)

if true
  @benchmark rlog_lbfgs(M2, Y, Y1, 1e-8)
end

2.142334018517822e-9
 * Status: success

 * Candidate solution
    Final objective value:     -5.000000e+01

 * Found with
    Algorithm:     L-BFGS

 * Convergence measures
    |x - x'|               = 5.98e-08 ≰ 0.0e+00
    |x - x'|/|x'|          = 2.59e-07 ≰ 0.0e+00
    |f(x) - f(x')|         = 5.19e-13 ≰ 0.0e+00
    |f(x) - f(x')|/|f(x')| = 1.04e-14 ≰ 0.0e+00
    |g(x)|                 = 5.48e-09 ≤ 1.0e-08

 * Work counters
    Seconds run:   0  (vs limit Inf)
    Iterations:    5
    f(x) calls:    6
    ∇f(x) calls:   6


BenchmarkTools.Trial: 82 samples with 1 evaluation.
 Range (min … max):  56.198 ms … 94.033 ms  ┊ GC (min … max): 8.62% … 17.05%
 Time  (median):     60.122 ms              ┊ GC (median):    7.99%
 Time  (mean ± σ):   61.211 ms ±  5.128 ms  ┊ GC (mean ± σ):  9.47% ±  2.74%

In [None]:
n = 1500
p = 500
alpha = [1, 0.5]
M2 = Stf(n, p, alpha)
Y = stf_rand(M2)
xi = stf_randvec(M2, Y)*.5*pi
Y1 = stf_exp(M2, Y, xi)

eta1, ret = rlog_lbfgs(M2, Y, Y1, 1e-8)
println(linf(stf_exp(M2, Y, eta1) - Y1))
print(ret)

if true
  @benchmark rlog_lbfgs(M2, Y, Y1, 1e-8)
end

5.3959612021647896e-9
 * Status: success

 * Candidate solution
    Final objective value:     -5.000000e+02

 * Found with
    Algorithm:     L-BFGS

 * Convergence measures
    |x - x'|               = 2.16e-07 ≰ 0.0e+00
    |x - x'|/|x'|          = 3.20e-06 ≰ 0.0e+00
    |f(x) - f(x')|         = 2.37e-10 ≰ 0.0e+00
    |f(x) - f(x')|/|f(x')| = 4.75e-13 ≰ 0.0e+00
    |g(x)|                 = 7.59e-09 ≤ 1.0e-08

 * Work counters
    Seconds run:   10  (vs limit Inf)
    Iterations:    3
    f(x) calls:    4
    ∇f(x) calls:   4


BenchmarkTools.Trial: 1 sample with 1 evaluation.
 Single result which took 15.367 s (5.11% GC) to evaluate,
 with a memory estimate of 4.02 GiB, over 11943 allocations.
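
The final objective values reported in these runs equal `-p` (for example `-5.000000e+02` for `p = 500`). One reading of the cost in `fun!`, inferred from the code rather than stated in the notebook, is that it equals minus the trace inner product between `Y1` and the endpoint `(Y*M + Q*N)*ex1`, which reaches `-p` when the two coincide. The sketch below checks this on random data using only the standard library; the orthonormal pair `Y`, `Q` is built with a QR factorization as a stand-in for `stf_rand` and `get_Q`, and the sizes and scalings are illustrative assumptions.

```julia
using LinearAlgebra

n, p, alf = 7, 3, 0.8           # alf stands in for stf.alpha[2]/stf.alpha[1]
YQ = Matrix(qr(randn(n, 2p)).Q) # n x 2p matrix with orthonormal columns (assumed stand-in)
Y, Q = YQ[:, 1:p], YQ[:, p+1:end]
k = size(Q, 2)

B = randn(p, p)
A = 0.3 * (B - B')              # antisymmetric p x p block
R = 0.3 * randn(k, p)

ex1 = exp((1 - 2*alf) * A)
ex2 = exp([2*alf*A -R'; R zeros(k, k)])   # same block matrix as `mat` in fun!
M, N = ex2[1:p, 1:p], ex2[p+1:end, 1:p]
Y1 = (Y*M + Q*N) * ex1          # endpoint as read off the cost function

ZYMQN = (Y1'*Y)*M + (Y1'*Q)*N
cost = -sum(ZYMQN' .* ex1)      # same expression as the value returned by fun!

println(opnorm(Y1'*Y1 - I))     # Y1 has orthonormal columns up to rounding
println(cost)                   # close to -p
```

Because the block matrix is antisymmetric, its exponential is orthogonal, so the endpoint stays on the Stiefel manifold for any `A` and `R`; the optimizer only has to match it to the target.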

In [None]:
n = 1500
p = 1000
alpha = [1, 0.5]
M2 = Stf(n, p, alpha)
Y = stf_rand(M2)
xi = stf_randvec(M2, Y)*.5*pi
Y1 = stf_exp(M2, Y, xi)

eta1, ret = rlog_lbfgs(M2, Y, Y1, 1e-8)
println(linf(stf_exp(M2, Y, eta1) - Y1))
print(ret)

if true
  @benchmark rlog_lbfgs(M2, Y, Y1, 1e-8)
end

2.3021776221487933e-9
 * Status: success

 * Candidate solution
    Final objective value:     -1.000000e+03

 * Found with
    Algorithm:     L-BFGS

 * Convergence measures
    |x - x'|               = 9.74e-08 ≰ 0.0e+00
    |x - x'|/|x'|          = 2.52e-06 ≰ 0.0e+00
    |f(x) - f(x')|         = 1.19e-10 ≰ 0.0e+00
    |f(x) - f(x')|/|f(x')| = 1.19e-13 ≰ 0.0e+00
    |g(x)|                 = 2.85e-09 ≤ 1.0e-08

 * Work counters
    Seconds run:   37  (vs limit Inf)
    Iterations:    3
    f(x) calls:    4
    ∇f(x) calls:   4


BenchmarkTools.Trial: 1 sample with 1 evaluation.
 Single result which took 55.512 s (2.15% GC) to evaluate,
 with a memory estimate of 10.24 GiB, over 22017 allocations.

In [None]:
@benchmark rlog_descent(M2, Y, Y1, 1e-8)

BenchmarkTools.Trial: 1 sample with 1 evaluation.
 Single result which took 42.793 s (2.62% GC) to evaluate,
 with a memory estimate of 7.53 GiB, over 1442 allocations.

In [None]:
n = 1500
p = 200
alpha = [1, 0.8]
M2 = Stf(n, p, alpha)
Y = stf_rand(M2)
xi = stf_randvec(M2, Y)*1.3*pi
Y1 = stf_exp(M2, Y, xi)

eta1, ret = rlog_lbfgs(M2, Y, Y1, 1e-8)
println(linf(stf_exp(M2, Y, eta1) - Y1))
print(ret)

if true
  @benchmark rlog_lbfgs(M2, Y, Y1, 1e-8)  
end

2.0925045723929614e-10
 * Status: success

 * Candidate solution
    Final objective value:     -2.000000e+02

 * Found with
    Algorithm:     L-BFGS

 * Convergence measures
    |x - x'|               = 1.99e-08 ≰ 0.0e+00
    |x - x'|/|x'|          = 6.89e-08 ≰ 0.0e+00
    |f(x) - f(x')|         = 7.39e-13 ≰ 0.0e+00
    |f(x) - f(x')|/|f(x')| = 3.69e-15 ≰ 0.0e+00
    |g(x)|                 = 3.23e-10 ≤ 1.0e-08

 * Work counters
    Seconds run:   1  (vs limit Inf)
    Iterations:    5
    f(x) calls:    6
    ∇f(x) calls:   6


BenchmarkTools.Trial: 3 samples with 1 evaluation.
 Range (min … max):  1.863 s …   1.882 s  ┊ GC (min … max): 3.86% … 3.87%
 Time  (median):     1.879 s              ┊ GC (median):    3.88%
 Time  (mean ± σ):   1.875 s ± 10.098 ms  ┊ GC (mean ± σ):  3.95% ± 0.15%

In [None]:
@benchmark rlog_descent(M2, Y, Y1, 1e-8)

BenchmarkTools.Trial: 4 samples with 1 evaluation.
 Range (min … max):  1.312 s …   1.372 s  ┊ GC (min … max): 3.73% … 3.38%
 Time  (median):     1.337 s              ┊ GC (median):    3.46%
 Time  (mean ± σ):   1.339 s ± 24.858 ms  ┊ GC (mean ± σ):  3.50% ± 0.16%

In [None]:
n = 7
p = 3
alpha = [1, 0.8]
M2 = Stf(n, p, alpha)
Y = stf_rand(M2)
xi = stf_randvec(M2, Y)*1.3*pi
Y1 = stf_exp(M2, Y, xi)

eta1, ret = rlog_lbfgs(M2, Y, Y1, 1e-8)
println(linf(stf_exp(M2, Y, eta1) - Y1))
print(ret)

if true
  @benchmark rlog_lbfgs(M2, Y, Y1, 1e-8)  
end


0.08462418124997292
 * Status: success

 * Candidate solution
    Final objective value:     -2.994179e+00

 * Found with
    Algorithm:     L-BFGS

 * Convergence measures
    |x - x'|               = 0.00e+00 ≤ 0.0e+00
    |x - x'|/|x'|          = 0.00e+00 ≤ 0.0e+00
    |f(x) - f(x')|         = 0.00e+00 ≤ 0.0e+00
    |f(x) - f(x')|/|f(x')| = 0.00e+00 ≤ 0.0e+00
    |g(x)|                 = 1.50e-02 ≰ 1.0e-08

 * Work counters
    Seconds run:   0  (vs limit Inf)
    Iterations:    18
    f(x) calls:    64
    ∇f(x) calls:   64


BenchmarkTools.Trial: 714 samples with 1 evaluation.
 Range (min … max):  5.700 ms … 24.562 ms  ┊ GC (min … max): 0.00% … 59.43%
 Time  (median):     6.375 ms              ┊ GC (median):    0.00%
 Time  (mean ± σ):   6.991 ms ±  2.044 ms  ┊ GC (mean ± σ):  5.13% ± 10.17%

In [None]:
eta1, ret = rlog_descent(M2, Y, Y1, 1e-8)
println(ret)
println(linf(stf_exp(M2, Y, eta1) - Y1))
@benchmark rlog_descent(M2, Y, Y1, 1e-8)

120
0.07811664045523886


BenchmarkTools.Trial: 394 samples with 1 evaluation.
 Range (min … max):  10.171 ms … 33.486 ms  ┊ GC (min … max): 0.00% … 16.78%
 Time  (median):     11.498 ms              ┊ GC (median):    0.00%
 Time  (mean ± σ):   12.675 ms ±  3.106 ms  ┊ GC (mean ± σ):  5.32% ±  9.64%

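# Show that the gradient works on the sphere
On the sphere the cost reduces to `-zty*cos(r) - ztq*sin(r)`, so the gradient obtained from the Fréchet derivative can be checked against plain trigonometry.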

In [None]:
function randsphere(n)
  x = randn(n)
  return x/norm(x, 2)
end  


randsphere (generic function with 1 method)

In [None]:
n = 5
y = randsphere(n)
z = randsphere(n)
q = z - y*sum(y .* z)
q = q / norm(q, 2)


5-element Vector{Float64}:
  0.42801599283627284
  0.12918024834403197
 -0.5777275486806889
 -0.40168210721849473
  0.5522654593128796
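
The construction of `q` is a single Gram-Schmidt step: it is the unit vector orthogonal to `y` in the plane spanned by `y` and `z`. A minimal sketch of what that implies, assuming only the standard library (the names `theta` and `fcost_demo` are illustrative, not from the notebook): `z` decomposes as `y*cos(theta) + q*sin(theta)`, and a cost of the form `-z'y*cos(r) - z'q*sin(r)` is smallest at `r = theta`, the geodesic distance from `y` to `z`.

```julia
using LinearAlgebra

y = normalize(randn(5))
z = normalize(randn(5))
q = normalize(z - y*dot(y, z))            # part of z orthogonal to y, rescaled to unit length

theta = acos(clamp(dot(y, z), -1.0, 1.0)) # angle between y and z
fcost_demo(r) = -dot(z, y)*cos(r) - dot(z, q)*sin(r)

println(abs(dot(y, q)))                          # orthogonality of y and q
println(norm(y*cos(theta) + q*sin(theta) - z))   # z lies on the great circle through y and q
println(fcost_demo(theta))                       # minimum value, close to -1
```

Since `dot(z, y) = cos(theta)` and `dot(z, q) = sin(theta)`, the cost is `-cos(r - theta)`, which makes the minimizer and the minimum value easy to read off.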

In [None]:
r = .6
zty = sum(z .* y)
ztq = sum(z .* q)
function fcost(r)
  return - zty*cos(r) - ztq*sin(r)
end  

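# A = [0 -r; r 0], so exp(A) is the rotation by angle r
A = reshape([0, r, -r, 0], (2, 2))
E = reshape([zty, ztq, 0, 0], (2, 2))'
println(E)
# Frechet derivative of exp at A in direction E
s = expm_frechet_algo_64(A, E)[2]
hh = 1e-7
# derivative of fcost three ways: finite difference, closed form, trace formula
println((fcost(r+hh) - fcost(r))/hh)
println(zty*sin(r) - ztq*cos(r))
println(s[2, 1] - s[1, 2])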

[-0.7555460671575899 0.6550955200600131; 0.0 0.0]
-0.9672870771026965
-0.9672870639970595
-0.9672870639970592
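
The same check can be reproduced without `expm_frechet_algo_64`. The sketch below swaps in a different technique, the standard block-matrix identity in which the upper-right block of `exp([A E; 0 A])` is the Fréchet derivative of the exponential at `A` in direction `E`. The helper name `frechet_block` and the values of `zty` and `ztq` are illustrative assumptions, not the notebook's random draw.

```julia
using LinearAlgebra

# Hypothetical helper: Frechet derivative of exp at A in direction E,
# read off the upper-right block of exp([A E; 0 A]).
function frechet_block(A, E)
    n = size(A, 1)
    blk = exp([A E; zeros(n, n) A])
    return blk[1:n, n+1:end]
end

r = 0.6
zty, ztq = -0.75, 0.65          # illustrative values only
A = [0.0 -r; r 0.0]             # generator of the rotation by angle r
E = [zty ztq; 0.0 0.0]
s = frechet_block(A, E)

fcost_demo(t) = -zty*cos(t) - ztq*sin(t)
hh = 1e-7
finite_diff  = (fcost_demo(r + hh) - fcost_demo(r)) / hh
closed_form  = zty*sin(r) - ztq*cos(r)
from_frechet = s[2, 1] - s[1, 2]   # trace formula: -tr(s * J) with J = [0 -1; 1 0]

println(finite_diff)
println(closed_form)
println(from_frechet)
```

The block identity needs a matrix exponential of twice the size, so it is convenient for small checks like this one rather than as a replacement for a dedicated Fréchet-derivative routine.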
