In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"]="1"  # specify which GPU(s) to be used

import jax.numpy as np
from jax import vmap
import numpy as onp
import matplotlib.pyplot as plt
import scipy.stats
import scipy.stats as stats
import jax.scipy.optimize as jopt
import jax
from jax.config import config 
config.update("jax_enable_x64",True)
import time as time

In [2]:
from jax.lib import xla_bridge
print("jax backend {}".format(xla_bridge.get_backend().platform))

jax backend cpu


In [3]:
print(xla_bridge.get_backend().devices())

[CpuDevice(id=0)]


In [4]:
#Use Jax to express the distance objective and its derivative
def distance_obj2(diagPbar,P,logdetP1,n):
    iPbar = np.diag(1.0/np.abs(diagPbar))
    
    #(sig1^-1 - sigbar^-1)sig0
    M = np.matmul(iPbar,P)
    TrM = np.trace(M)
    t1 = 1.0/2.0 * (n - 2*TrM + np.trace(np.matmul(M,M)))

    logdetPbar = np.sum(np.log(diagPbar))
    t3 = 1.0/4.0*(logdetPbar - logdetP1 - (n-TrM))**2

                                                                              
    return t1+t3

distance_obj2 = jax.jit(distance_obj2)
dist_grad2 =  jax.jit(jax.grad(distance_obj2, argnums=(0)))
dist_hess2 = jax.jit(jax.jacfwd(jax.jacrev(distance_obj2, argnums=(0)), argnums=(0)))

In [5]:
#Use Jax to express the distance objective and its derivative
def theta_obj(theta, Pvec, Pd_inter,x,R):
    

    Pd = np.diag(Pvec/(1.0-np.abs(theta)*Pvec))

    #optimize overal max different in k
    ninf = np.matmul(Pd,x)
    dinf = np.add(np.matmul(np.matmul(x,Pd),x),R)
    
    ninter = np.matmul(Pd_inter,x)
    dinter = np.add(np.matmul(np.matmul(x,Pd_inter),x),R)

    return np.max(np.abs(ninf/dinf - ninter/dinter)/np.sqrt(Pvec))
    
theta_obj = jax.jit(theta_obj)
theta_grad =  jax.jit(jax.grad(theta_obj, argnums=(0)))
theta_hess = jax.jit(jax.jacfwd(jax.jacrev(theta_obj, argnums=(0)), argnums=(0)))

In [6]:
def theta_solve(Pd,Pd_inter,x,R):
    #compute the upper bound on theta
    ub1= onp.min(0.99/np.diag(Pd)) #This makes Pd < Infinity really Pd*100
    ub2= onp.max((1.0-np.diag(Pd))/np.diag(Pd)) #This makes the smallest Pd<1? is that ok? Basically trying to say we shoudl never be totally worse than the prior...
    ub = onp.minimum(ub1,ub2)
    optbuf = scipy.optimize.minimize(theta_obj, ub/2.0, args=(np.diag(Pd),Pd_inter,x,R), method="SLSQP", jac=theta_grad, bounds=((0.0,ub),),options={'maxiter':1000})
    return optbuf.success, np.abs(optbuf.x)

In [7]:
TestID = 1

ndims = [2,4,8,16,32,64,128,256,512]
Ntest = 8
rng = onp.random.default_rng(31+TestID)

T = 1000
t0 = time.time()
for idim in range(len(ndims)):
    n = ndims[idim]
    
    #Data Matrices
    #save true_theta
    #save mu and diag of P for all tests to analyze later
    
    true_theta_data = onp.zeros((Ntest,n))
    mu_data = onp.zeros((Ntest,n,5))
    P_data = onp.zeros((Ntest,n,5))
    
    for itest in range(Ntest):
        print(str(n) + " " +str(itest) + " " + str(time.time()-t0))
        #Generate Data
        var_noise = 0.1
        true_theta = onp.array(rng.standard_normal(n))
        var_x = 0.5*onp.eye(n)
        mean_x = onp.array(rng.standard_normal(n))

        x = rng.multivariate_normal(mean_x,var_x,T)
        ymu = np.matmul(x,true_theta)
        y = ymu+np.sqrt(var_noise)*rng.standard_normal(T)

        x = np.array(x)
        y = np.array(y)

        #Initialize
        cov_theta = 1.0*np.identity(n)# TO-DO, remove constants
        mean_theta = np.zeros((n,1))

        P_kalman = cov_theta
        mu_kalman = mean_theta

        P_vi = cov_theta
        mu_vi = mean_theta

        P_dist = cov_theta
        mu_dist = mean_theta

        P_dist_inf = cov_theta
        P_dist_infr1 = cov_theta
        mu_dist_inf = mean_theta

        P_vi_inf = cov_theta
        P_vi_infr1 = cov_theta
        mu_vi_inf = mean_theta

        #####NEED TO Handle when things fail
        Fail = 0
        Fail_vi = 0
        P_Fail =cov_theta
        P_Fail_vi =cov_theta
        for i in range(T):
            #print(i, end=" ")

            # Kalman
            #t0 = time.time()
            S_kalman = np.add(np.matmul(np.matmul(x[i,:],P_kalman),x[i,:]),var_noise)
            K_kalman = np.multiply(np.matmul(P_kalman,x[i,:].reshape(n,1)),1.0/S_kalman.item())
            mu_kalman = np.add(mu_kalman ,np.multiply(K_kalman,np.subtract(y[i,],np.matmul(x[i,:],mu_kalman))))
            P_kalman = np.matmul(np.subtract(np.identity(n),np.matmul(K_kalman,x[i,:].reshape(1,n))),P_kalman)

            # Pure VI
            S_vi = np.add(np.matmul(np.matmul(x[i,:],P_vi),x[i,:]),var_noise)
            K_vi = np.multiply(np.matmul(P_vi,x[i,:].reshape(n,1)),1.0/S_vi.item())
            mu_vi = np.add(mu_vi ,np.multiply(K_vi,np.subtract(y[i,],np.matmul(x[i,:],mu_vi))))
            P_vi_inter = np.matmul(np.subtract(np.identity(n),np.matmul(K_vi,x[i,:].reshape(1,n))),P_vi)
            P_vi = np.diag(np.diag(P_vi_inter)) #Only keep diagonal elements

            # Pure VI + Hinfnity
            thetabool, theta = theta_solve(P_vi_inf,P_vi_infr1,x[i,:],var_noise) #find theta to make errors work...
            P_old = P_vi_inf
            P_vi_inf = np.matmul(P_vi_inf,np.linalg.inv(np.eye(n)-theta*P_vi_inf))
            if (np.min(np.diag(P_vi_inf)) < 0) | (thetabool==False):
                Fail_vi = Fail_vi + 1
                print("Fail VI: " + str(Fail_vi))
                P_vi_inf = P_Fail_vi
            P_Fail_vi = P_vi_inf
            S_vi_inf = np.add(np.matmul(np.matmul(x[i,:],P_vi_inf),x[i,:]),var_noise)
            K_vi_inf = np.multiply(np.matmul(P_vi_inf,x[i,:].reshape(n,1)),1/S_vi_inf.item())
            mu_vi_inf = np.add(mu_vi_inf ,np.multiply(K_vi_inf,np.subtract(y[i,],np.matmul(x[i,:],mu_vi_inf))))
            P_vi_infr1 = np.matmul(np.subtract(np.identity(n),np.matmul(K_vi_inf,x[i,:].reshape(1,n))),P_vi_inf) # TO-DO remove constants
            P_vi_inf = np.diag(np.diag(P_vi_infr1))


            # Distance VI
            S_dist = np.add(np.matmul(np.matmul(x[i,:],P_dist),x[i,:]),var_noise)
            K_dist = np.multiply(np.matmul(P_dist,x[i,:].reshape(n,1)),1.0/S_dist.item())
            mu_dist = np.add(mu_dist ,np.multiply(K_dist,np.subtract(y[i,],np.matmul(x[i,:],mu_dist))))
            P_dist_inter = np.matmul(np.subtract(np.identity(n),np.matmul(K_dist,x[i,:].reshape(1,n))),P_dist)
            sign, logdetP1 = np.linalg.slogdet(P_dist_inter)
            optbuf = scipy.optimize.minimize(distance_obj2, onp.diag(P_dist_inter), args=(P_dist_inter,logdetP1,n), method="Newton-CG", hess=dist_hess2, jac=dist_grad2,options={'maxiter':1000})
            P_dist = np.diag(np.abs(optbuf.x)) #Solve Optimization


            # Distance VI + Hinfnity
            thetabool, theta = theta_solve(P_dist_inf,P_dist_infr1,x[i,:],var_noise) #find theta to make errors work...
            P_old = P_dist_inf
            P_dist_inf = np.matmul(P_dist_inf,np.linalg.inv(np.eye(n)-theta*P_dist_inf))
            if (np.min(np.diag(P_dist_inf)) < 0) | (thetabool==False):
                Fail = Fail + 1
                print("Fail: " + str(Fail))
                P_dist_inf = P_Fail #Reject past update

            P_Fail = P_dist_inf
            S_dist_inf = np.add(np.matmul(np.matmul(x[i,:],P_dist_inf),x[i,:]),var_noise)
            K_dist_inf = np.multiply(np.matmul(P_dist_inf,x[i,:].reshape(n,1)),1/S_dist_inf.item())
            mu_dist_inf = np.add(mu_dist_inf ,np.multiply(K_dist_inf,np.subtract(y[i,],np.matmul(x[i,:],mu_dist_inf))))
            P_dist_infr1 = np.matmul(np.subtract(np.identity(n),np.matmul(K_dist_inf,x[i,:].reshape(1,n))),P_dist_inf) # TO-DO remove constants

            sign, logdetP1 = np.linalg.slogdet(P_dist_infr1)
            optbuf = scipy.optimize.minimize(distance_obj2, onp.diag(P_dist_infr1), args=(P_dist_infr1,logdetP1,n), method="Newton-CG", jac=dist_grad2, hess=dist_hess2, options={'maxiter':1000})
            P_dist_inf = np.diag(np.abs(optbuf.x))
            #t2 = time.time()
            #print((t2-t0))
            
            
        #Save the outputs
        true_theta_data[itest,:] = true_theta
        mu_data[itest,:,0] = mu_kalman[:,-1]
        mu_data[itest,:,1] = mu_vi[:,-1]
        mu_data[itest,:,2] = mu_vi_inf[:,-1]
        mu_data[itest,:,3] = mu_dist[:,-1]
        mu_data[itest,:,4] = mu_dist_inf[:,-1]
        
        P_data[itest,:,0] = np.diag(P_kalman)
        P_data[itest,:,1] = np.diag(P_vi)
        P_data[itest,:,2] = np.diag(P_vi_inf)
        P_data[itest,:,3] = np.diag(P_dist)
        P_data[itest,:,4] = np.diag(P_dist_inf)
    np.savez("/global/cfs/cdirs/m3876/blta2/VI_Filtering/VI_Filters_" + str(TestID) + "_" + str(n)+".npz",T=T,n=n,true_theta_data=true_theta_data,mu_data=mu_data,P_data=P_data)

2 0 0.0005431175231933594




2 1 10.969022989273071




2 2 20.924451112747192




2 3 31.182199239730835




2 4 41.188714027404785




2 5 50.69083094596863




2 6 60.02604913711548




2 7 69.48416495323181




FileNotFoundError: [Errno 2] No such file or directory: '/global/cfs/cdirs/m3876/blta2/VI_Filtering/VI_Filters_1_2.npz'