In [3]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

# Use jax with float64
from jax.config import config
config.update("jax_enable_x64", True)

from util   import *
from basics import *
from simulate_data import *
from estimators    import *
from config import *
from scipy.special import *
configure_pylab()   

## Load Jax

In [4]:
# Shadow all pylab functions and numpy with the Jax versions
# Keep numpy around as np0 for easier RNG, array assignments
import jax
import jax.numpy as np
import numpy.random as npr
from jax               import jit, grad, vmap
from jax.config        import config
from jax.scipy.special import logsumexp
from jax.numpy         import *
from jax import jacfwd, jacrev
from jax import lax
from jax.numpy.fft import *
from jax.numpy.linalg import *

logdet = lambda A:jax.numpy.linalg.slogdet(A)[1]

def hess(f,inparam):
    return jacfwd(jacrev(f,inparam),inparam)

def hvp(f, x, v):
    return grad(lambda x: vdot(grad(f)(x), v))(x)

import numpy as np0

# redefine these with Jax env so it can be traced
def conv(x,K):
    return real(ifft2(fft2(x.reshape(L,L))*K))

def slog(x,minrate = 1e-10):
    return log(maximum(minrate,x))

def sexp(x,bound = 10):
    return exp(clip(x,-bound,bound))

from numpy.linalg import cholesky as chol

def vec(X):
    return X.ravel()

## Direct parameterization of Σ

In [5]:
T  = 10
μ  = randn(T)
μ0 = randn(T)-3
X0 = randn(T,T)*0.1
Σ0 = X0@X0.T + eye(T)*1e-2
X  = randn(T,T)*0.1
Σ  = X@X.T + eye(T)*1e-2
n  = np0.random.poisson(0.2,T)
λ0 = sexp(μ0)
y  = np0.random.poisson(λ0)
Λ0 = inv(Σ0)
u  = randn(T)
M  = randn(T,T)

def loss(μ,Σ):
    λbar  = exp(μ+μ0+diag(Σ)/2)
    ε     = λbar - y*μ
    trΛ0Σ = trace(Λ0@Σ)
    μΛ0μ  = μ.T@Λ0@μ
    lndΣ0 = logdet(Σ0)
    lndΣ  = logdet(Σ)
    return n@ε + trΛ0Σ/2 + μΛ0μ/2 + lndΣ0/2 - lndΣ/2

def Jμ1(μ,Σ):
    λbar  = sexp(μ+μ0+diag(Σ)/2)
    return n*(λbar-y) + Λ0@μ

def JΣ1(μ,Σ):
    λbar = sexp(μ+μ0+diag(Σ)/2)
    nλ   = λbar*n
    Λ    = inv(Σ)
    return diag(nλ)/2 + Λ0/2 - Λ/2

def Hμ1(μ,Σ):
    λbar = sexp(μ+μ0+diag(Σ)/2)
    nλ   = λbar*n
    return diag(nλ) + Λ0

def Hvμ1(μ,Σ,u):
    λbar = sexp(μ+μ0+diag(Σ)/2)
    nλ   = λbar*n
    return nλ*u + Λ0@u

def HvΣ1(μ,Σ,M):
    M    = M.reshape(T,T)
    λbar = sexp(μ+μ0+diag(Σ)/2)
    nλ   = λbar*n
    Λ    = inv(Σ)
    return diag(nλ*diag(M))/4 + Λ@M.T@Λ/2

j1 = Jμ1(μ,Σ) 
j2 = grad(loss,0)(μ,Σ)
print(mean(abs(j1-j2)))

j1 = JΣ1(μ,Σ) 
j2 = grad(loss,1)(μ,Σ)
print(mean(abs(j1-j2)))

h1 = Hμ1(μ,Σ)
h2 = hess(loss,0)(μ,Σ)
print(mean(abs(h1-h2)))

Hvμ = lambda u: grad(lambda μ: vdot(grad(loss,0)(μ,Σ), u))(μ)
u1 = Hvμ1(μ,Σ,u)
u2 = Hvμ(u)
print(mean(abs(u1-u2)))

HvΣ = lambda M: grad(lambda Σ: vdot(grad(loss,1)(μ,Σ),M))(Σ)
u1 = HvΣ1(μ,Σ,M)
u2 = HvΣ(M)
print(mean(abs(u1-u2)))



8.659739592076222e-16
1.8252066524837575e-15
5.584421813864538e-16
2.0761170560490427e-15
2.1644908088092053e-13


## Parameterize at Σ = A diag[v] A'

In [6]:
A = randn(T,T)
G = A*A
v = exp(randn(T)-2)

def loss2(μ,v):
    Σ     = A@diag(v)@A.T
    λbar  = exp(μ+μ0+diag(Σ)/2)
    ε     = λbar - y*μ
    trΛ0Σ = trace(Λ0@Σ)
    μΛ0μ  = μ.T@Λ0@μ
    lndΣ0 = logdet(Σ0)
    lndΣ  = logdet(Σ)
    return n@ε + trΛ0Σ/2 + μΛ0μ/2 + lndΣ0/2 - lndΣ/2

def Jμ2(μ,v):
    Σ     = A@diag(v)@A.T
    λbar  = sexp(μ+μ0+diag(Σ)/2)
    return n*(λbar-y) + Λ0@μ

def Jv2(μ,v):
    Σ    = A@diag(v)@A.T
    λbar = sexp(μ+μ0+diag(Σ)/2)
    nλ   = λbar*n
    Λ    = inv(Σ)
    return G.T@nλ/2 - 1/v/2 + diag(A.T@Λ0@A)/2

def Hμ2(μ,v):
    Σ    = A@diag(v)@A.T
    λbar = sexp(μ+μ0+diag(Σ)/2)
    nλ   = λbar*n
    return diag(nλ) + Λ0

def Hvμ2(μ,v,u):
    Σ    = A@diag(v)@A.T
    λbar = sexp(μ+μ0+diag(Σ)/2)
    nλ   = λbar*n
    return nλ*u + Λ0@u

def Hv2(μ,v):
    Σ    = A@diag(v)@A.T
    λbar = sexp(μ+μ0+diag(Σ)/2)
    nλ   = λbar*n
    return G.T@diag(nλ)@G/4 + diag(v**-2)/2

def Hvv2(μ,v,u):
    Σ    = A@diag(v)@A.T
    λbar = sexp(μ+μ0+diag(Σ)/2)
    nλ   = λbar*n
    Λ    = inv(Σ)
    return (G.T@diag(nλ)@G/4 + diag(v**-2)/2)@u

j1 = Jv2(μ,v) 
j2 = grad(loss2,1)(μ,v)
print(mean(abs(j1-j2)))

j1 = Jμ2(μ,v) 
j2 = grad(loss2,0)(μ,v)
print(mean(abs(j1-j2)))

h1 = Hμ2(μ,v)
h2 = hess(loss2,0)(μ,v)
print(mean(abs(h1-h2)))

h1 = Hv2(μ,v)
h2 = hess(loss2,1)(μ,v)
print(mean(abs(h1-h2)))

Hvμ = lambda u: grad(lambda μ: vdot(grad(loss2,0)(μ,v), u))(μ)
u1 = Hvμ2(μ,v,u)
u2 = Hvμ(u)
print(mean(abs(u1-u2)))

HvΣ = lambda u: grad(lambda v: vdot(grad(loss2,1)(μ,v),u))(v)
u1 = Hvv2(μ,v,u)
u2 = HvΣ(u)
print(mean(abs(u1-u2)))


2.700062395888381e-14
8.659739592076222e-16
5.584421813864538e-16
1.5350388410335707e-13
2.0761170560490427e-15
3.8329339702158905e-13


## Parameterize as Σ = XX'

In [14]:
R = T//2
X = randn(T,R)
M = randn(T,R)

def loss3(μ,X):
    Σ     = X@X.T
    λbar  = exp(μ+μ0+diag(Σ)/2)
    ε     = λbar - y*μ
    trΛ0Σ = trace(Λ0@Σ)
    μΛ0μ  = μ.T@Λ0@μ
    lndΣ0 = logdet(Σ0)
    lndΣ  = logdet(X.T@X)
    return n@ε + trΛ0Σ/2 + μΛ0μ/2 + lndΣ0/2 - lndΣ/2

def Jμ3(μ,X):
    Σ     = X@X.T
    λbar  = sexp(μ+μ0+diag(Σ)/2)
    return n*(λbar-y) + Λ0@μ

j1 = Jμ3(μ,X) 
j2 = grad(loss3,0)(μ,X)
print(mean(abs(j1-j2)))

def JX3(μ,X):
    Σ    = X@X.T
    λbar = sexp(μ+μ0+diag(Σ)/2)
    nλ   = λbar*n
    P    = pinv(X)
    return (diag(nλ) + Λ0) @ X - P.T

j1 = JX3(μ,X) 
j2 = grad(loss3,1)(μ,X)
print(mean(abs(j1-j2)))

def Hμ3(μ,X):
    Σ    = X@X.T
    λbar = sexp(μ+μ0+diag(Σ)/2)
    nλ   = λbar*n
    return diag(nλ) + Λ0

h1 = Hμ3(μ,X)
h2 = hess(loss3,0)(μ,X)
print(mean(abs(h1-h2)))

def Hvμ3(μ,X,u):
    Σ    = X@X.T
    λbar = sexp(μ+μ0+diag(Σ)/2)
    nλ   = λbar*n
    return nλ*u + Λ0@u

Hvμ = lambda u: grad(lambda μ: vdot(grad(loss3,0)(μ,X), u))(μ)
u1 = Hvμ3(μ,X,u)
u2 = Hvμ(u)
print(mean(abs(u1-u2)))

def HvX3(μ,X,M):
    Σ    = X@X.T
    λbar = sexp(μ+μ0+diag(Σ)/2)
    λb   = diag(λbar*n)
    dxm  = diag(diag(X@M.T))
    P    = pinv(X)
    Pt   = P.T
    PPt  = P@Pt
    dpnv = Pt@(M.T@Pt+(X.T@M)@PPt) - M@PPt
    return λb@(M + dxm@X) + Λ0@M + dpnv

HvX = lambda M: grad(lambda X: vdot(grad(loss3,1)(μ,X),M))(X)
u1  = HvX3(μ,X,M)
u2  = HvX(M)
print(mean(abs(u1-u2)))

8.659739592076222e-16
3.907985046680551e-15
5.584421813864538e-16
2.0761170560490427e-15
3.4439118223872355e-15


## Parameterize as Σ = F'QQ'F where F is a clipped orthonormal 

In [15]:
from scipy.stats import special_ortho_group

R = T//4
F = special_ortho_group.rvs(T)[:R]
Q = randn(R,R)
M = randn(R,R)

def loss4(μ,Q):
    Σ     = F.T@Q@Q.T@F
    λbar  = exp(μ+μ0+diag(Σ)/2)
    ε     = λbar - y*μ
    trΛ0Σ = trace(Λ0@Σ)
    μΛ0μ  = μ.T@Λ0@μ
    lndΣ0 = logdet(Σ0)
    lndΣ  = logdet(Q@Q.T)
    return n@ε + trΛ0Σ/2 + μΛ0μ/2 + lndΣ0/2 - lndΣ/2

def Jμ4(μ,Q):
    Σ     = F.T@Q@Q.T@F
    λbar  = sexp(μ+μ0+diag(Σ)/2)
    return n*(λbar-y) + Λ0@μ

j1 = Jμ4(μ,Q) 
j2 = grad(loss4,0)(μ,Q)
print(mean(abs(j1-j2)))

def JQ4(μ,Q):
    Σ    = F.T@Q@Q.T@F
    λbar = sexp(μ+μ0+diag(Σ)/2)
    nλ   = λbar*n
    P    = pinv(Q)
    return F@(Λ0 + diag(nλ))@F.T@Q - P.T

j1 = JQ4(μ,Q) 
j2 = grad(loss4,1)(μ,Q)
print(mean(abs(j1-j2)))

def Hμ4(μ,Q):
    Σ    = F.T@Q@Q.T@F
    λbar = sexp(μ+μ0+diag(Σ)/2)
    nλ   = λbar*n
    return diag(nλ) + Λ0

h1 = Hμ4(μ,Q)
h2 = hess(loss4,0)(μ,Q)
print(mean(abs(h1-h2)))

def Hvμ4(μ,Q,u):
    Σ    = F.T@Q@Q.T@F
    λbar = sexp(μ+μ0+diag(Σ)/2)
    nλ   = λbar*n
    return nλ*u + Λ0@u

Hvμ = lambda u: grad(lambda μ: vdot(grad(loss4,0)(μ,Q), u))(μ)
u1 = Hvμ4(μ,Q,u)
u2 = Hvμ(u)
print(mean(abs(u1-u2)))

def HvQ4(μ,Q,M):
    Σ    = F.T@Q@Q.T@F
    λbar = sexp(μ+μ0+diag(Σ)/2)
    nλ   = λbar*n
    d    = diag(nλ)
    r    = nλ * diag(F.T@M@Q.T@F)
    c    = F@d@F.T@M + F@diag(r)@F.T@Q
    Pt   = inv(Q).T
    return c + F@Λ0@F.T@M + Pt@M.T@Pt

HvQ = lambda M: grad(lambda Q:vdot(grad(loss4,1)(μ,Q),M))(Q)
u1  = HvQ4(μ,Q,M)
u2  = HvQ(M)
print(mean(abs(u1-u2)))

8.659739592076222e-16
3.3306690738754696e-15
5.584421813864538e-16
2.0761170560490427e-15
3.2862601528904634e-14


## Parameterize as Σ = inv( Λ + diag[p] )

In [8]:
p = randn(T)
u = randn(T)

def loss5(μ,p):
    Σ = inv(Λ0 + diag(p))
    λ = exp(μ+μ0+diag(Σ)/2)
    ε = λ - y*μ
    μΛ0μ  = μ.T@Λ0@μ
    return n@ε + (trace(Λ0@Σ) - logdet(Σ) + logdet(Σ0) + μΛ0μ)/2

def Jμ5(μ,p):
    Σ    = pinv(Λ0 + diag(p))
    λbar = sexp(μ+μ0+diag(Σ)/2)
    return n*(λbar-y) + Λ0@μ

j1 = Jμ5(μ,p) 
j2 = grad(loss5,0)(μ,p)
print(mean(abs(j1-j2)))

def Jp5(μ,p):
    Σ = pinv(Λ0 + diag(p))
    λ = exp(μ+μ0+diag(Σ)/2)
    D = diag(n*λ)
    Λ = Λ0 + D
    return diag(Σ - Σ@Λ@Σ)/2
        
j1 = Jp5(μ,p) 
j2 = grad(loss5,1)(μ,p)
print(mean(abs(j1-j2)))

def Hvp5(μ,p,u):
    Σ = inv(Λ0 + diag(p))
    λ = exp(μ+μ0+diag(Σ)/2)
    D = diag(n*λ)
    Λ = Λ0 + D
    return -diag(Σ@diag(u)@Σ)/2 + diag(Σ@Λ@Σ@diag(u)@Σ)

Hvp = lambda u:grad(lambda p:vdot(grad(loss5,1)(μ,p),u))(p)
u1  = Hvp5(μ,p,u)
u2  = Hvp(u)
print(mean(abs(u1-u2)))

3.9523939676655574e-15
7.278384709166752e-17
5.987475806652163e-06
