In [1]:
from pybsd import VariationalLaplaceEstimator
import numpy as np
from jax import numpy as jnp 
from jax import grad, hessian, jit

In [2]:
class Test(VariationalLaplaceEstimator):
    def forward(self, p):
        return p**2 - 4

In [3]:
N = 100
pE = jnp.zeros(N)-1.2
pC = jnp.ones((N,))*1e3
hE = jnp.ones(1)+10
hC = jnp.ones((1,))*1/128
Q  = jnp.eye(N).reshape((1,N,N))*3*N
t = Test(pE, pC, hE, hC, Q, rtol=1e-16)

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


In [4]:
y = jnp.zeros((1,N))+np.random.normal(0,0.001,(3*N,N))

In [5]:
res = t.fit(None, y)

EM: (+) 0: F-F0: 0.00e+00 dF (predicted): 3.99e-07 (actual): 1.61e+09 (31.54 ms)
EM: (+) 1: F-F0: 1.61e+09 dF (predicted): 1.72e+04 (actual): 5.33e+05 (27.05 ms)
EM: (+) 2: F-F0: 1.61e+09 dF (predicted): 2.94e+08 (actual): -3.88e+10 (30.63 ms)
EM: (-) 3: F-F0: 1.61e+09 dF (predicted): 4.12e+07 (actual): -1.67e+07 (32.35 ms)
EM: (-) 4: F-F0: 1.61e+09 dF (predicted): 5.61e+06 (actual): -3.22e+04 (31.41 ms)
EM: (-) 5: F-F0: 1.61e+09 dF (predicted): 7.59e+05 (actual): -5.58e+03 (24.69 ms)
EM: (-) 6: F-F0: 1.61e+09 dF (predicted): 1.03e+05 (actual): 2.66e+03 (16.63 ms)
EM: (+) 7: F-F0: 1.61e+09 dF (predicted): 2.04e+03 (actual): 3.67e+02 (13.29 ms)
EM: (+) 8: F-F0: 1.61e+09 dF (predicted): 1.79e+03 (actual): 4.43e+02 (13.71 ms)
EM: (+) 9: F-F0: 1.61e+09 dF (predicted): 1.37e+03 (actual): 4.98e+02 (12.53 ms)
EM: (+) 10: F-F0: 1.61e+09 dF (predicted): 9.42e+02 (actual): 5.29e+02 (13.02 ms)
EM: (+) 11: F-F0: 1.61e+09 dF (predicted): 6.07e+02 (actual): 5.44e+02 (17.36 ms)
EM: (+) 12: F-F0: 1.61

In [6]:
t.forward(res['Ep'])

Array([[-0.0227623 , -0.02279067, -0.02280784, -0.02283168, -0.02283454,
        -0.0228498 , -0.02290487, -0.02275753, -0.02267194, -0.02284122,
        -0.02276993, -0.02277708, -0.02279544, -0.02276516, -0.02276707,
        -0.02283502, -0.02285266, -0.02278686, -0.0228436 , -0.02286315,
        -0.02281404, -0.02275085, -0.0230124 , -0.02282214, -0.02272987,
        -0.02282119, -0.02291822, -0.02272558, -0.0228126 , -0.02278328,
        -0.02283692, -0.02276468, -0.02283835, -0.02282977, -0.02276325,
        -0.0227623 , -0.02284551, -0.0227499 , -0.02284026, -0.02280498,
        -0.02278233, -0.02280402, -0.02273846, -0.0228231 , -0.02293968,
        -0.02277088, -0.02284598, -0.02277088, -0.02281928, -0.02277279,
        -0.02280402, -0.02282214, -0.02288342, -0.02280021, -0.0229373 ,
        -0.02284408, -0.02280879, -0.02280879, -0.02287054, -0.02293825,
        -0.02272272, -0.02277899, -0.02286744, -0.02284122, -0.02290201,
        -0.0229435 , -0.0228622 , -0.02275181, -0.0

In [7]:
res['L']

array([-1.20580029e+03, -3.45419312e+02, -8.41641211e+03, -4.00050171e+02,
       -3.70290184e+00, -9.28127899e+01])

In [8]:
import time
def format_time(ns):
    if ns < 1e3:  # Less than microsecond
        return f"{ns:.2f} ns"
    elif ns < 1e6:  # Less than millisecond
        return f"{ns * 1e-3:.2f} µs"
    elif ns < 1e9:  # Less than second
        return f"{ns * 1e-6:.2f} ms"
    else:  
        return f"{ns * 1e-9:.2f} s"

qE = pE
gE = hE
for i in range(500):
    tstart = time.perf_counter_ns()
    dq = grad(t.free_energy, argnums=1)(y, qE, pC, gE, hC)
#     H = hessian(t.free_energy, argnums=1)(y, pE, pC, hE, hC)
    dg = grad(t.free_energy, argnums=3)(y, qE, pC, gE, hC)
    qE += 0.000001*dq
#     print(qE)
    gE += 0.00001*dg
    print(format_time(time.perf_counter_ns() - tstart))
print(qE)

ValueError: Argument to slogdet() must have shape [..., n, n], got {a_shape}