# Gradient checks 2

Extend `gradient checks 1` to bring it closer to the convolutional LGCP use-case. 

## Load Jax

In [686]:
# Shadow all pylab functions and numpy with the Jax versions
# Keep numpy around as np0 for easier RNG, array assignments
import jax
import jax.numpy as np
import numpy.random as npr
from jax               import jit, grad, vmap
from jax.config        import config
from jax.scipy.special import logsumexp
from jax.numpy         import *
from jax import jacfwd, jacrev
from jax import lax
from jax.numpy.fft import *
from jax.numpy.linalg import *

logdet = lambda A:jax.numpy.linalg.slogdet(A)[1]

def hess(f,inparam):
    return jacfwd(jacrev(f,inparam),inparam)

def hvp(f, x, v):
    return grad(lambda x: vdot(grad(f)(x), v))(x)

def slog(x,minrate = 1e-10):
    return log(maximum(minrate,x))

def sexp(x,bound = 10):
    return exp(clip(x,-bound,bound))

from numpy.linalg import cholesky as chol

def vec(X):
    return X.ravel()

# Convolve with vector
def conv(K,v):
    v = v.reshape(L,L)
    return real(ifft2(fft2(v)*K)).ravel()

# Convolve with matrix
def conm(K,M):
    N  = M.shape[-1]
    M  = M.reshape(L,L,N)
    Mf = fft2(M,axes=(0,1))
    Cf = K[:,:,None]*Mf
    C  = real(ifft2(Cf,axes=(0,1)))
    return C.reshape(T,N)

## Direct parameterization of Σ

- The domain is now 2D
- The prior is now a convolution

No real point in optimizing this one, since parameterization in Σ is never practical

In [687]:
L  = 10
T  = L*L
μ  = randn(L,L)
μ0 = randn(L,L)-3

# Prior is now a convolution
# Make kernel
w  = arange(L)-L/2
K0 = fftshift(exp(-abs(w[:,None]+1j*w[None,:])**2/2))
Kf = maximum(1e-6,real(fft2(K0,norm='ortho')))
Λf = 1/Kf
K0 = real(ifft2(Kf))
Λ0 = real(ifft2(Λf))

# simulated visit and spike counts
n  = np0.random.poisson(0.2,(L,L))
λ0 = sexp(μ0)
y  = np0.random.poisson(λ0)

# guess for posterior covariance
X  = randn(T,T)*0.1
Σ  = X@X.T + eye(T)*1e-2

# Flatten 2D
μ0 = μ0.ravel()
μ  = μ.ravel()
n  = n.ravel()
y  = y.ravel()

# random vectors for checking gradients
u  = randn(T)
M  = randn(T,T)

# Inefficient but for the hessian we need the full prior
Λ0matrix = conm(Λf,eye(T))

def loss(μ,Σ):
    λ     = exp(μ+μ0+diag(Σ)/2)
    ε     = λ - y*μ
    trΛ0Σ = trace(conm(Λf,Σ)) # TODO fix this one
    μΛ0μ  = μ.T @ conv(Λf,μ)
    lndΣ  = logdet(Σ)
    lndΣ0 = sum(log(Kf))
    return n@ε + (trΛ0Σ + μΛ0μ - lndΣ + lndΣ0)/2

# Gradient in μ
def Jμ1(μ,Σ):
    λ = sexp(μ+μ0+diag(Σ)/2)
    return n*(λ-y) + conv(Λf,μ)
print(mean(abs(Jμ1(μ,Σ) -grad(loss,0)(μ,Σ))))

# Gradient in Σ
def JΣ1(μ,Σ):
    nλ = n*sexp(μ+μ0+diag(Σ)/2)
    D  = diag(nλ)
    Λ  = inv(Σ)
    return (D + Λ0matrix - Λ)/2
print(mean(abs(JΣ1(μ,Σ)-grad(loss,1)(μ,Σ))))

# Hessian in μ
def Hμ1(μ,Σ):
    nλ = n*sexp(μ+μ0+diag(Σ)/2)
    D  = diag(nλ)
    return D + Λ0matrix
print(mean(abs(Hμ1(μ,Σ)-hess(loss,0)(μ,Σ))))

# Hessian-vector product for μ
def Hvμ1(μ,Σ,u):
    nλ = n*sexp(μ+μ0+diag(Σ)/2)
    D  = diag(nλ)
    return D@u + conv(Λf,u)
Hvμ = lambda u: grad(lambda μ: vdot(grad(loss,0)(μ,Σ), u))(μ)
print(mean(abs(Hvμ1(μ,Σ,u)-Hvμ(u))))

# Hessian-vector product for Σ
def HvΣ1(μ,Σ,M):
    nλ = n*sexp(μ+μ0+diag(Σ)/2)
    Λ  = pinv(Σ)
    return diag(nλ*diag(M))/4 + Λ@M.T@Λ/2
HvΣ = lambda M: grad(lambda Σ: vdot(grad(loss,1)(μ,Σ),M))(Σ)
print(mean(abs(HvΣ1(μ,Σ,M)-HvΣ(M))))

8.07887090559234e-14
8.01598787347757e-15
6.778932970519236e-15
5.1656456889759287e-14
4.629323169247357e-12


## Parameterize at Σ = A diag[v] A'

- A is now a convolution

This one is worth optimizing

In [688]:
A  = randn(L,L)*0.1
Af = real(fft2(A))
G  = A*A
Gf = real(fft2(G))
v  = exp(randn(T)-2)

# This term enters into the variance optimization
diagΛ0  = mean(Λf)*ones(L*L)
diagAΛA = mean(Λf*Gf)*ones(L*L)

def loss2(μ,v):
    ν = conv(Gf,v)
    λ = exp(μ+μ0+ν/2)
    ε = λ - y*μ
    trΛ0Σ = diagAΛA@v
    μΛ0μ  = μ.T @ conv(Λf,μ)
    lndΣ  = sum(slog(v))
    lndΣ0 = sum(log(Kf))
    return n@ε + trΛ0Σ/2 + μΛ0μ/2 + lndΣ0/2 - lndΣ/2

# Gradient in μ
def Jμ2(μ,v):
    ν = conv(Gf,v)
    λ = sexp(μ+μ0+ν/2)
    return n*(λ-y) + conv(Λf,μ)
print(mean(abs(Jμ2(μ,v) -grad(loss2,0)(μ,v))))

# Gradient in v
def Jv2(μ,v):
    ν  = conv(Gf,v)
    nλ = n*sexp(μ+μ0+ν/2)
    return conv(Gf,nλ)/2 + diagAΛA/2 - 1/v/2
print(mean(abs(grad(loss2,1)(μ,v)-Jv2(μ,v) )))

# Hessian-vector product for μ
def Hvμ2(μ,v,u):
    ν  = conv(Gf,v)
    nλ = n*sexp(μ+μ0+ν/2)
    return nλ*u + conv(Λf,u)
Hvμ = lambda u: grad(lambda μ: vdot(grad(loss2,0)(μ,v), u))(μ)
print(mean(abs(Hvμ2(μ,v,u)-Hvμ(u))))

# Hessian-vector product for v
def Hvv2(μ,v,u):
    ν  = conv(Gf,v)
    nλ = n*sexp(μ+μ0+ν/2)
    return 0.25*conv(Gf,nλ*conv(Gf,u)) + 0.5*v**-2*u

HvΣ = lambda u: grad(lambda v:vdot(grad(loss2,1)(μ,v),u))(v)
print(mean(abs(Hvv2(μ,v,u)-HvΣ(u))))

8.07887090559234e-14
4.263256414560601e-16
5.755396159656811e-14
0.0


## Parameterize as Σ = XX'

In [689]:
from opt_einsum import contract

In [690]:
R = 4
X = randn(T,R)
M = randn(T,R)

lndΣ0 = sum(log(Kf))

def loss3(μ,X):
    λ = exp(μ+μ0+sum(X**2,1)/2)
    ε = λ-y*μ
    trΛ0Σ = sum(X*conm(Λf,X))
    μΛ0μ  = μ.T @ conv(Λf,μ)
    lndΣ  = logdet(X.T@X)
    return n@ε + μΛ0μ/2 + trΛ0Σ/2 + lndΣ0/2 - lndΣ/2

# Gradient in μ
def Jμ3(μ,X):
    λ = exp(μ+μ0+sum(X**2,1)/2)
    return n*(λ-y) + conv(Λf,μ) 
print(mean(abs(Jμ3(μ,X) -grad(loss3,0)(μ,X))))

# Hessian-vector-product for μ
def Hvμ3(μ,X,u):
    nλ = n*exp(μ+μ0+sum(X**2,1)/2)
    return nλ*u + conv(Λf,u)
Hvμ = lambda u: grad(lambda μ: vdot(grad(loss3,0)(μ,X), u))(μ)
print(mean(abs(Hvμ3(μ,X,u)-Hvμ(u))))

# Gradient in X
def JX3(μ,X):
    nλ = n*exp(μ+μ0+sum(X**2,1)/2)
    P  = pinv(X)
    return nλ[:,None]*X + conm(Λf,X)  - P.T
print(mean(abs(JX3(μ,X)-grad(loss3,1)(μ,X))))

# Hessian-vector-product for X
def HvX3(μ,X,M):
    nλ   = n*exp(μ+μ0+sum(X**2,1)/2)
    xm   = sum(X*M,1)
    P    = pinv(X)
    Pt   = P.T
    PPt  = P@Pt
    #dpnv = Pt@(M.T@Pt+(X.T@M)@PPt) - M@PPt
    dpnv = contract('ji,kj,lk->il',P,M,P)+\
        contract('ji,kj,km,mn,ln->il',P,X,M,P,P)-\
        contract('ij,jk,lk->il',M,P,P)
    return nλ[:,None]*(M + xm[:,None]*X) + conm(Λf,M) + dpnv

HvX = lambda M: grad(lambda X: vdot(grad(loss3,1)(μ,X),M))(X)
print(mean(abs(HvX3(μ,X,M)-HvX(M))))

8.07887090559234e-14
5.4143356464919635e-14
7.889300324137594e-14
8.995137967815481e-14


## Parameterize as Σ = XX', quadratic instead of exponential observation

In [691]:
R = 4
X = randn(T,R)
M = randn(T,R)

lndΣ0 = sum(log(Kf))

def loss3(μ,X):
    λ = exp(μ+μ0)*(1+sum(X**2,1)/2)
    ε = λ-y*μ
    trΛ0Σ = sum(X*conm(Λf,X))
    μΛ0μ  = μ.T @ conv(Λf,μ)
    lndΣ  = logdet(X.T@X)
    return n@ε + μΛ0μ/2 + trΛ0Σ/2 + lndΣ0/2 - lndΣ/2

# Gradient in μ
def Jμ3(μ,X):
    λ = exp(μ+μ0)*(1+sum(X**2,1)/2)
    return n*(λ-y) + conv(Λf,μ) 
print(mean(abs(Jμ3(μ,X) -grad(loss3,0)(μ,X))))

# Hessian-vector-product for μ
def Hvμ3(μ,X,u):
    nλ = n*exp(μ+μ0)*(1+sum(X**2,1)/2)
    return nλ*u + conv(Λf,u)
Hvμ = lambda u: grad(lambda μ: vdot(grad(loss3,0)(μ,X), u))(μ)
print(mean(abs(Hvμ3(μ,X,u)-Hvμ(u))))

# Gradient in X
def JX3(μ,X):
    nλ = n*exp(μ+μ0)
    P  = pinv(X)
    return nλ[:,None]*X + conm(Λf,X)  - P.T
print(mean(abs(JX3(μ,X)-grad(loss3,1)(μ,X))))

# Hessian-vector-product for X
def HvX3(μ,X,M):
    nλ   = n*exp(μ+μ0)
    P    = pinv(X)
    Pt   = P.T
    PPt  = P@Pt
    #dpnv = Pt@(M.T@Pt+(X.T@M)@PPt) - M@PPt
    dpnv = einsum('ji,kj,lk->il',P,M,P)+\
        einsum('ji,kj,km,mn,ln->il',P,X,M,P,P)-\
        einsum('ij,jk,lk->il',M,P,P)
    return nλ[:,None]*M + conm(Λf,M) + dpnv

HvX = lambda M: grad(lambda X: vdot(grad(loss3,1)(μ,X),M))(X)
print(mean(abs(HvX3(μ,X,M)-HvX(M))))

8.07887090559234e-14
5.4996007747831756e-14
5.717204487609707e-14
1.1760370455249358e-13


## Parameterize as Σ = F'QQ'F where F is a clipped orthonormal basis and Q is a small square matrix

- Will come back to this at the end, replacing F with a discrete cosine transform

In [692]:
from scipy.stats import special_ortho_group

F = special_ortho_group.rvs(T)
R = T#//2
F = F[:R]
print(mean(abs(F@conj(F.T))-eye(R)))

Q = (randn(R,R) + eye(R))*1e-4
M = (randn(R,R) + eye(R))*1e-4

# Used for grad Q
Λ0Ft = conm(Λf,F.T)

def loss4(μ,Q):
    FQ = F.T@Q
    Σ  = FQ@FQ.T
    λ  = exp(μ+μ0+diag(Σ)/2)
    ε  = λ-y*μ
    trΛ0Σ = sum(FQ*conm(Λf,FQ))
    μΛ0μ  = μ.T @ conv(Λf,μ)
    lndΣ0 = sum(log(Kf))
    lndΣ  = logdet(Q@Q.T)
    return n@ε + trΛ0Σ/2 + μΛ0μ/2 + lndΣ0/2 - lndΣ/2

# Gradient in μ
def Jμ4(μ,Q):
    FQ = F.T@Q
    Σ  = FQ@FQ.T
    λ  = exp(μ+μ0+diag(Σ)/2)
    return n*(λ-y) + conv(Λf,μ)
print('μ jac',mean(abs(Jμ4(μ,Q)-grad(loss4,0)(μ,Q))))

def Hvμ4(μ,Q,u):
    FQ = F.T@Q
    Σ  = FQ@FQ.T
    nλ = n*exp(μ+μ0+diag(Σ)/2)
    return nλ*u + conv(Λf,u)
Hvμ = lambda u: grad(lambda μ: vdot(grad(loss4,0)(μ,Q), u))(μ)
print('μ hvp',mean(abs(Hvμ4(μ,Q,u)-Hvμ(u))))

# Gradient in Q
def JQ4(μ,Q):
    FQ = F.T@Q
    Σ  = FQ@FQ.T
    nλ = n*exp(μ+μ0+diag(Σ)/2)
    P  = pinv(Q)
    return F  @ (Λ0Ft + nλ[:,None] * F.T) @ Q - P.T
print('Q jac',mean(abs(JQ4(μ,Q) -grad(loss4,1)(μ,Q))))

def HvQ4(μ,Q,M):
    P    = pinv(Q)
    Pt   = P.T
    PPt  = P@Pt
    FQ = F.T@Q
    Σ  = FQ@FQ.T
    nλ = n*exp(μ+μ0+diag(Σ)/2)
    D  = diag(nλ)
    r  = nλ * diag(F.T@(M@Q.T)@F)
    c  = ((F*nλ)@F.T)@M + ((F*r)@F.T)@Q
    return c + (F@Λ0Ft)@M + Pt@M.T@Pt
HvQ = lambda M: grad(lambda Q:vdot(grad(loss4,1)(μ,Q),M))(Q)
print('Q hvp',mean(abs(HvQ4(μ,Q,M)-HvQ(M))))

6.361885082189481e-17
μ jac 8.07887090559234e-14
μ hvp 5.158540261618327e-14
Q jac 4.307481613663455e-09
Q hvp 1.013903475170741e-06


## Parameterize as Σ = inv( Λ + diag[p] )

Skipping this one since I have no idea how to quickly get the marginal variances. 

# Hartley transform

Replace F with the unitary Fourier transform. Double the domain size so that things remain real. Previously we had some difficulty with this. We will try to avoid that by working the whole problem on size (2L)².

In [693]:
F = array([fft2(x.reshape(L,L),norm='ortho').ravel() for x in eye(T)]).T
F = real(F) + imag(F)
print('Check unitary',max(abs(F@F.T-eye(T))))
print('Check involution',max(abs(F@F-eye(T))))

R = T
F = F[:R,:]
Q = (randn(R,R) + eye(R))*1e-4
M = (randn(R,R) + eye(R))*1e-4

# Used for grad Q
Λ0Ft = conm(Λf,F.T)

def loss4(μ,Q):
    FQ = F.T@Q
    Σ  = FQ@FQ.T
    λ  = exp(μ+μ0+diag(Σ)/2)
    ε  = λ-y*μ
    trΛ0Σ = sum(FQ*conm(Λf,FQ))
    μΛ0μ  = μ.T @ conv(Λf,μ)
    lndΣ0 = sum(log(Kf))
    lndΣ  = logdet(Q@Q.T)
    return n@ε + trΛ0Σ/2 + μΛ0μ/2 + lndΣ0/2 - lndΣ/2

def Jμ4(μ,Q):
    FQ = F.T@Q
    Σ  = FQ@FQ.T
    λ  = exp(μ+μ0+diag(Σ)/2)
    return n*(λ-y) + conv(Λf,μ)
print('μ jac',mean(abs(Jμ4(μ,Q)-grad(loss4,0)(μ,Q))))

def Hvμ4(μ,Q,u):
    FQ = F.T@Q
    Σ  = FQ@FQ.T
    nλ = n*exp(μ+μ0+diag(Σ)/2)
    return nλ*u + conv(Λf,u)
Hvμ = lambda u: grad(lambda μ: vdot(grad(loss4,0)(μ,Q), u))(μ)
print('μ hvp',mean(abs(Hvμ4(μ,Q,u)-Hvμ(u))))

def JQ4(μ,Q):
    FQ = F.T@Q
    Σ  = FQ@FQ.T
    nλ = n*exp(μ+μ0+diag(Σ)/2)
    P  = inv(Q)
    return F  @ (Λ0Ft + nλ[:,None] * F.T) @ Q - P.T
print('Q jac',mean(abs(JQ4(μ,Q) -grad(loss4,1)(μ,Q))))

def HvQ4(μ,Q,M):
    P   = inv(Q)
    Pt  = P.T
    PPt = P@Pt
    FQ = F.T@Q
    Σ  = FQ@FQ.T
    nλ = n*exp(μ+μ0+diag(Σ)/2)
    D  = diag(nλ)
    r  = nλ * diag(F.T@(M@Q.T)@F)
    c  = ((F*nλ)@F.T)@M + ((F*r)@F.T)@Q
    return c + (F@Λ0Ft)@M + Pt@M.T@Pt
HvQ = lambda M: grad(lambda Q:vdot(grad(loss4,1)(μ,Q),M))(Q)
print('Q hvp',mean(abs(HvQ4(μ,Q,M)-HvQ(M))))

Check unitary 6.661338147750939e-16
Check involution 6.661338147750939e-16
μ jac 8.07887090559234e-14
μ hvp 5.755396159656811e-14
Q jac 1.5730532407817744e-09
Q hvp 3.660276452085221e-07


## Fast Hartley transform 

In [694]:
def fht2(x):
    f = fft2(x.reshape(L,L),norm='ortho')
    return real(f)+imag(f)

def fht(*args):
    f = fft(*args,norm='ortho')
    return real(f)+imag(f)

In [705]:
Kh = fht2(K0)
K1 = fht2(Kh)
print('check FHT is involution',mean(abs(K1-K0)))

#  Pick which components to keep 
thr    = percentile(abs(Kh),90)
keep2d = abs(Kh)>thr
keep   = where(keep2d.ravel())[0]
R      = len(keep)
print('Keeping R=%d out of %d components'%(R,T))
print('Check approximation accuracy',max(abs(fht2(Kh*keep2d)-K0)))

# Jax can't do assignment so we build these projection matrices
down = eye(T)[keep]
up   = down.T

# TODO later
# Slightly faster to separate
keep1d = any(keep2d,1)
keeprc = where(keep2d[keep1d,:][:,keep1d].ravel())[0]

def collapse(v):
    v = v.reshape(L,L)
    v = fht2(v)
    v = v.ravel()[keep]
    return v

def expand(x):
    v = (up@x).reshape(L,L)
    v = fht2(v)
    return v.ravel()

def collapseAleft(A):
    ''' Collapse L²×L² matrix to subspace on left (todo: optimize me)'''
    A = A.reshape(L*L,L,L)
    A = array([collapse(a) for a in A]).T
    return A

def collapseAleft(A):
    ''' Collapse L²×L² matrix to subspace on left (todo: optimize me)'''
    print(A.shape)
    n = A.shape[0]
    A = A.reshape(n,L,L)
    A = array([collapse(a) for a in A]).T
    return A

def collapseAright(A):
    ''' Collapse L²×L² matrix to subspace on left (todo: optimize me)'''
    print(A.shape)
    n = A.shape[1]
    A = A.reshape(L,L,n).T
    A = array([collapse(a) for a in A]).T
    return A

def collapseA(A):
    ''' Collapse L²×L² matrix to subspace (todo: optimize me)'''
    A = A.reshape(L*L,L,L)
    A = array([collapse(a) for a in A]).T
    A = A.reshape(R,L,L)
    A = array([collapse(a) for a in A]).T
    return A

def expandAleft(A):
    '''Expand compressed representation on the left'''
    return array([expand(a) for a in A.T]).T

def expandAright(A):
    '''Expand compressed representation on the right'''
    return array([expand(a) for a in A])

def expandA(A):
    return expandAleft(expandAright(A))

# Sanity check : construct as matrix and verify fast routines match 
F = array([fft2(x.reshape(L,L),norm='ortho').ravel() for x in eye(T)]).T
F = real(F) + imag(F)
F = F[keep,:]
print('check F orthonormal',max(abs(F@F.T-eye(R))))
P = F.T@F

H = circulant(K0.ravel())
print('check expand/collapse matrix',max(abs(expandA(collapseA(H))-P@H@P)))
print('check expand/collapse vectors',max(abs(expand(collapse(K0))-P@K0.ravel())))

A = randn(T,T)
print('check collapse matrix',max(abs(F@A@F.T - collapseA(A))))

Λf    = 1/Kf
Λ0Ft   = conm(Λf,F.T)
lndΣ0 = sum(log(Kf))

Q = tril(randn(R,R) + eye(R))*1e-1
M = tril(randn(R,R) + eye(R))*1e-1

print('check FQ',max(abs(F.T@Q-expandAleft(Q))))
print('check FQF\'',max(abs(F.T@Q@Q.T@F-expandA(Q@Q.T))))

from scipy.linalg.lapack import dtrtri
def ltinv(ch):
    q,info = scipy.linalg.lapack.dtrtri(ch,lower=True)
    if info!=0:
        raise ValueError(
            'lapack.dtrtri encountered illegal argument in position %d'%-info
            if info<0 else
            'lapack.dtrtri encountered zero diagonal element at %d'%info)
    return q

print('check lower triangular inverse',max(abs(inv(Q) - ltinv(Q))))

check FHT is involution 3.1579208913678184e-18
Keeping R=11 out of 100 components
Check approximation accuracy 0.05044533473812045
check F orthonormal 6.661338147750939e-16
check expand/collapse matrix 5.551115123125783e-17
check expand/collapse vectors 1.3877787807814457e-17
check collapse matrix 2.6645352591003757e-15
check FQ 2.0816681711721685e-17
check FQF' 5.204170427930421e-18
check lower triangular inverse 2.0463630789890885e-12


In [716]:
def loss4(μ,Q):
    FQ = F.T@Q
    Σ  = FQ@FQ.T
    λ  = exp(μ+μ0+diag(Σ)/2)
    ε  = λ-y*μ
    trΛ0Σ = sum(FQ*conm(Λf,FQ))
    μΛ0μ  = μ.T @ conv(Λf,μ)
    lndΣ  = logdet(Q@Q.T)
    return n@ε + trΛ0Σ/2 + μΛ0μ/2 + lndΣ0/2 - lndΣ/2

def loss4fht(μ,Q):
    FQ = expandAleft(Q)
    Σ  = expandA(Q)
    λ  = exp(μ+μ0+diag(Σ)/2)
    ε  = λ-y*μ
    trΛ0Σ = sum(FQ*conm(Λf,FQ))
    μΛ0μ  = μ.T @ conv(Λf,μ)
    lndΣ  = logdet(Q@Q.T)
    return n@ε + trΛ0Σ/2 + μΛ0μ/2 + lndΣ0/2 - lndΣ/2
print('loss FHT check',max(abs(loss4(μ,Q)-loss4fht(μ,Q))))

# Gradient in μ
def Jμ4(μ,Q):
    FQ = expandAleft(Q)
    Σ  = expandA(Q)
    λ  = exp(μ+μ0+diag(Σ)/2)
    return n*(λ-y) + conv(Λf,μ)
print('μ jac',mean(abs(Jμ4(μ,Q)-grad(loss4,0)(μ,Q))))

def Hvμ4(μ,Q,u):
    FQ = expandAleft(Q)
    nλ = n*exp(μ+μ0+sum(FQ**2,1)/2)
    return nλ*u + conv(Λf,u)
Hvμ = lambda u: grad(lambda μ: vdot(grad(loss4,0)(μ,Q), u))(μ)
print('μ hvp',mean(abs(Hvμ4(μ,Q,u)-Hvμ(u))))

# Gradient in Q
def JQ4(μ,Q):
    FQ = expandAleft(Q)
    nλ = n*exp(μ+μ0+sum(FQ**2,1)/2)
    P  = ltinv(Q)
    
    λFQ = nλ[:,None]*FQ
    x = λFQ.reshape(L,L,R)
    x = fft2(x,axes=(0,1),norm='ortho')
    x = real(x) + imag(x)
    x = x.reshape(L*L,R)
    x = x[keep,:]
    FdλFtQ = x #(F*nλ)@F.T@Q
    
    return F@Λ0Ft@Q + FdλFtQ - P.T
print('Q jac',mean(abs(JQ4(μ,Q) -grad(loss4,1)(μ,Q))))

def HvQ4(μ,Q,M):
    FQ = expandAleft(Q)
    nλ = n*exp(μ+μ0+sum(FQ**2,1)/2)
    P   = ltinv(Q)
    Pt  = P.T
    D   = diag(nλ)
    r   = nλ * diag(F.T@(M@Q.T)@F)
    return (F*nλ)@F.T@M + (F*r)@F.T@Q + F@Λ0Ft@M + Pt@M.T@Pt
HvQ = lambda M: grad(lambda Q:vdot(grad(loss4,1)(μ,Q),M))(Q)
print('Q hvp',mean(abs(HvQ4(μ,Q,M)-HvQ(M))))

loss FHT check 0.0011935656220884994
μ jac 2.626321395197806e-05
μ hvp 5.6274984672199934e-14
Q jac 7.023459478434121e-11
Q hvp 1.539867610117542e-08


11