In [150]:
import numpy as np
import math as m
import scipy as sp

In [104]:
class NormalRV:
    def __init__(self, mu, Sigma):
        self.mu = mu
        self.Sigma = Sigma
        self.S = np.linalg.inv(Sigma)
        self.m = self.S@self.mu
    
    def pdf(self, x):
        # x a numpy array
        # should work if x is a matrix, applying the pdf to each column
        mu = self.mu
        Sigma = self.Sigma
        d = len(mu)
        x = np.asarray(x)
        if x.ndim == 1:
            x = x[:, np.newaxis]
        n = x.shape[1]
        pdf = np.zeros(n)
        for i in range(n):
            pdf[i] = 1/np.sqrt((2*np.pi)**d * np.linalg.det(Sigma)) * \
                np.exp(-1/2 * (x[:,i] - mu).T @ np.linalg.inv(Sigma) @ (x[:,i] - mu))
        return pdf

    def sample(self, n):
        # n: number of samples
        mu = self.mu
        Sigma = self.Sigma
        return np.random.multivariate_normal(mu, Sigma, n).T
    
    def update_vals(self, m, S):
        self.m = m
        self.S = S
        self.Sigma = np.linalg.inv(S)
        self.mu = self.Sigma@m
        

In [105]:
q = NormalRV(np.array([0,0]), np.array([[1,0],[0,1]]))
pi = NormalRV(np.array([1,1]), np.array([[1,0],[0,1]]))




In [106]:
def partial_lq_m(x, p):
    return x - np.reshape(p.mu, (len(p.mu), 1))

def partial_lq_S(x, p):
    # x can be a matrix, each column is a sample
    # in this case, return a 3D array, each slice is a matrix  
    # (i.e. a 2D array) corresponding to a sample
    x = np.asarray(x)
    if x.ndim == 1:
        x = x[:, np.newaxis]
    n = x.shape[1]
    d = len(p.mu)
    res = np.zeros((d,d,n))
    for i in range(n):
        xi = x[:,i]
        res[:,:,i] = -0.5*(np.outer(xi, xi) - np.outer(p.mu, p.mu)- p.Sigma)
    return res


In [107]:
def is_pd(M):

    # efficient check if a matrix is positive-definite
    # from https://stackoverflow.com/questions/16266720/find-out-if-matrix-is-positive-definite-with-numpy

    try:
        _ = np.linalg.cholesky(M)
        return True
    except np.linalg.LinAlgError:
        return False

In [108]:
def project_pd(M, eps=1e-6):
    # project a matrix to the positive definite cone
    # M: a symmetric matrix
    # return: a symmetric matrix
    if is_pd(M):
        return M

    # eigen decomposition
    eig_vals, eig_vecs = np.linalg.eig(M)
    # set negative eigenvalues to eps
    eig_vals[eig_vals <= 0] = eps
    # reconstruct matrix
    return eig_vecs @ np.diag(eig_vals) @ eig_vecs.T


In [155]:
def OAIS(phi, pi, q0, lr, nsamples, niter):
    results = []
    for _ in range(niter):
        # compute inner product
        samples = q0.sample(nsamples)
        w = pi.pdf(samples) / q0.pdf(samples) # compute w as we have access to pi
        w2 = w**2
        phi_samples = np.apply_along_axis(phi, 0, samples)
        integral = np.mean(w*phi_samples)/np.mean(w)

        # update q0
        results.append(integral)
        partial_m = partial_lq_m(samples, q0)
        update_m = -np.mean(w2 * partial_m, axis=1)
        
        partial_S = partial_lq_S(samples, q0)
        update_S = -np.mean(w2*partial_S, axis=2)
        
        new_S = project_pd(q0.S - lr*update_S)
        new_m = q0.m - lr*update_m
        q0.update_vals(new_m, new_S)

    return results
        
        
        
q = NormalRV(np.array([1,1]), np.array([[1,0],[0,1]]))
pi = NormalRV(np.array([0,0]), np.array([[1,0],[0,1]]))

def phi(x):
    # indicator function for square [-0.5, 0.5]^2
    if np.abs(x[0]) < 0.5 and np.abs(x[1]) < 0.5:
        return 1
    else:
        return 0
e = OAIS(phi, pi, q, 0.01, 1000, 100)

v = sp.stats.norm.cdf(0.5) - sp.stats.norm.cdf(-0.5)
np.abs(e - v**2)


array([1.87241968e-02, 1.00668832e-02, 2.00480191e-02, 3.78777418e-02,
       1.49069918e-02, 4.69667040e-03, 8.23680531e-04, 2.48697595e-02,
       1.31624154e-02, 1.25976904e-02, 1.40364605e-03, 2.76616049e-03,
       1.18906455e-02, 1.71286431e-02, 3.84074106e-03, 6.99387456e-03,
       1.17330653e-02, 3.31187344e-03, 5.12995674e-03, 1.09078225e-02,
       1.42560943e-02, 4.40111412e-03, 4.99023919e-03, 1.37956929e-02,
       2.27256693e-03, 9.14862037e-03, 8.64406586e-03, 6.84244680e-03,
       8.81585141e-03, 7.59101649e-03, 2.15738793e-02, 1.04114275e-02,
       1.30551365e-02, 2.70749850e-03, 1.88390741e-03, 1.06892914e-03,
       6.78827805e-03, 5.96634455e-03, 3.43927460e-03, 8.14629908e-03,
       3.93164844e-03, 9.32632079e-03, 4.95224502e-03, 3.64558964e-03,
       2.81640355e-03, 7.31097219e-04, 4.68116986e-03, 1.22401830e-02,
       1.05756227e-02, 1.83059997e-02, 2.97963286e-03, 1.53645565e-02,
       4.53451248e-03, 9.18877762e-03, 3.84887453e-03, 7.73389551e-03,
      

In [146]:
m.erf(-0.5)

-0.5204998778130465