In [22]:
%cd /Users/kostastsampourakis/Desktop/code/Python/projects/BayesianFiltering
import codebase.utils as utils
import codebase.gaussfilt as gf
import numpy as np
from jax import numpy as jnp
from jax import random as jrandom
from jax import jacfwd, jacrev, vmap
from numpy import random
import matplotlib.pyplot as plt
import codebase.gausssumfilt as gsf
import pandas as pd
import codebase.particlefilt as pf
import time

/Users/kostastsampourakis/Desktop/code/Python/projects/BayesianFiltering


In [35]:
def project_to_psd(Delta):
    """Project a square matrix onto the cone of symmetric positive semidefinite matrices.

    The input is first symmetrized. Then every negative eigenvalue of the symmetric
    part is set to zero, which gives the Frobenius-norm-nearest PSD matrix.

    Parameters
    ----------
    Delta : array_like, shape (n, n)
        Square matrix. Any antisymmetric part is discarded.

    Returns
    -------
    numpy.ndarray, shape (n, n)
        Real, symmetric, positive semidefinite matrix.
    """
    Delta = np.asarray(Delta)
    sym = (Delta + Delta.T) / 2
    # Use eigh rather than eig. For a symmetric matrix, eigh guarantees real
    # eigenvalues and orthonormal eigenvectors, so V diag(w) V^T is an exact
    # reconstruction. eig can return complex values, or non-orthogonal
    # eigenvectors for repeated or defective spectra.
    evals, evecs = np.linalg.eigh(sym)
    clipped = np.clip(evals, 0, None)
    # (V * w) @ V.T equals V @ diag(w) @ V.T without building the diagonal matrix.
    projected = (evecs * clipped) @ evecs.T
    # Re-symmetrize to remove floating-point asymmetry from the product.
    return (projected + projected.T) / 2


def gradient_descent(dim, N, L, X0, P, H, Nsteps, eta):
    """Apply ``Nsteps`` fixed-step gradient updates to ``X0``.

    The gradient evaluated at a matrix M is

        -(2 L^2 / N) * I_dim + (1/2) * tr(H M) * H

    ``P`` is accepted so the signature matches ``sdp_opt``. It is not used here.
    Returns the final iterate (``X0`` itself when ``Nsteps`` is 0).
    """
    def grad(M):
        # Constant shrink term plus a curvature term scaled by tr(H M).
        return -(2 * L ** 2 / N) * np.eye(dim) + (1 / 2) * np.trace(np.matmul(H, M)) * H

    current = X0
    for _ in range(Nsteps):
        current = current - eta * grad(current)
    return current


def sdp_opt(dim, N, L, X0, P, H, Nsteps, eta):
    """Alternate single gradient steps with PSD projections, for ``Nsteps`` iterations.

    Each iteration takes one ``gradient_descent`` step with step size ``eta ** i``.
    It then pushes the iterate toward {X : 0 <= X <= P} in the Loewner order, using
    one pass of the projections PSD -> (P - X is PSD) -> PSD. That single pass is
    not guaranteed to be the exact projection onto the intersection.

    Returns the final iterate reshaped to (dim, dim).
    """
    def _approx_project(M):
        # Clip negative eigenvalues, enforce M <= P, then re-clip.
        M = project_to_psd(M)
        M = P - project_to_psd(P - M)
        return project_to_psd(M)

    iterate = X0
    for step in range(Nsteps):
        # NOTE(review): eta ** step is 1 on the first iteration regardless of eta,
        # then decays geometrically. Confirm this schedule is intended
        # (vs. e.g. eta ** (step + 1) or eta / (step + 1)).
        iterate = gradient_descent(dim, N, L, iterate, P, H, 1, eta ** step)
        iterate = _approx_project(iterate)
    return iterate.reshape(dim, dim)

def f(x, sigma=10, rho=28, beta=2.667, dt=0.01):
    """One explicit-Euler step of size ``dt`` of a 3-D Lorenz-like vector field.

    Only the first three entries of ``x`` are read. The returned ``jnp`` array is

        x + dt * [sigma * (y^3 - x*y*z),  x*rho - y - x*z^2,  x*y - beta*z]

    NOTE(review): the first two components differ from the classical Lorenz-63
    equations (which are sigma*(y - x) and x*(rho - z) - y). Presumably this is
    deliberate, so the Hessian varies with x. Confirm.
    """
    u, v, w = x[0], x[1], x[2]
    increments = (
        dt * sigma * (v**3 - u*v*w),
        dt * (u * rho - v - u * w**2),
        dt * (u * v - beta * w),
    )
    return jnp.array([increments[0] + u, increments[1] + v, increments[2] + w])


# JAX autodiff derivatives of the one-step map f:
#   jacobian(x)[i, j]   = d f_i / d x_j              (forward mode)
#   hessian(x)[i, j, k] = d^2 f_i / (d x_j d x_k)    (reverse mode over the forward-mode Jacobian)
jacobian = jacfwd(f)
hessian = jacrev(jacfwd(f))

In [79]:
# Draw sample points around mu and evaluate the Hessian of f at each of them.
mu = jnp.array([1.0, 1.0, 1.0])
Sigma = jnp.eye(3)
sample = jrandom.multivariate_normal(jrandom.PRNGKey(0), mu, Sigma, (10,))
vhessian = vmap(hessian)
# f maps R^3 -> R^3, so each per-sample Hessian is (dim, dim, dim) and the stack is (N, dim, dim, dim).
hess_array = vhessian(sample)

## Gradient descent
eta = 0.01        # fixed step size
L = 0.1
n_steps = 100
# Derive N and dim from the drawn sample, so changing the sample size or the
# state dimension above cannot silently mis-scale the gradient below.
N, dim = sample.shape
sum_hess = jnp.sum(hess_array, axis=0)                 # Hessians summed over samples, (dim, dim, dim)
identity_term = -(2 * L ** 2 / N) * np.eye(dim)         # loop-invariant part of the gradient
X = Sigma
for i in range(n_steps):
    # coeffs[j] = sum over samples s of tr(X @ H_j(s))
    coeffs = jnp.sum(jnp.trace(jnp.matmul(X, hess_array), axis1=2, axis2=3), axis=0)
    # term_two = sum_j coeffs[j] * sum_hess[j]
    term_two = jnp.tensordot(coeffs, sum_hess, axes=1)
    X = X - eta * (identity_term + (1 / 2 / N**2) * term_two)

# Projections are applied once, after all steps: PSD, then enforce X <= Sigma, then PSD again.
# NOTE(review): sdp_opt projects after every step instead. Confirm which is intended here.
X = project_to_psd(X)
X = Sigma - project_to_psd(Sigma - X)
X = project_to_psd(X)
print(X)




[[1.0020027  0.02212034 0.02255214]
 [0.02212034 0.86550784 0.02731085]
 [0.02255214 0.02731085 1.001757  ]]
[[0.9871601  0.0176468  0.00733725]
 [0.0176468  0.86415946 0.02272514]
 [0.00733725 0.02272514 0.9861605 ]]
