In [1]:
import numpy as np
from numpy.testing import assert_allclose
from scipy.special import gamma
from math import comb
import matplotlib.pyplot as plt

In [83]:
# Define a function that computes all moments for a general dilation matrix.
# The implementation favors conceptual simplicity at the expense of memory
# efficiency.
def compute_moments_inefficient_implementation(A, a, maxdeg):
    """
    Compute all moments of an (unnormalized) trivariate Gaussian up to a given degree.

    Parameters:
    - A: symmetric 3x3 matrix (array-like of shape (3,3))
        Dilation matrix of the Gaussian that determines its shape.
        It can be written as A = R D R^T, where R is a rotation matrix that
        specifies the orientation of the three principal axes, and D is a
        diagonal matrix containing the eigenvalues of A, i.e. the inverse
        variances along the principal axes. The covariance matrix is A^{-1}.
    - a: array-like of shape (3,)
        Position vector of the center of the trivariate Gaussian.
    - maxdeg: int
        Maximum total degree n0 + n1 + n2 for which the moments are computed.
        Must be at least 1.

    Returns:
    - np.ndarray of shape (maxdeg+1, maxdeg+1, maxdeg+1) whose entry
      [n0, n1, n2] is
        <x^n0 * y^n1 * z^n2> = integral (x^n0 * y^n1 * z^n2) * exp(-0.5*(r-a).T@A@(r-a)) dxdydz
      for n0 + n1 + n2 <= maxdeg; all remaining entries are zero.
      Note that in probability theory, moments are defined for normalized
      distributions. Here, the Gaussian is NOT normalized, so every moment
      carries the factor (2*pi)^(3/2) / sqrt(det(A)), the integral of the
      Gaussian itself.

    Raises:
    - AssertionError if A is not a symmetric 3x3 matrix, a is not a
      3-vector, or maxdeg < 1.
    """
    A = np.asarray(A, dtype=float)
    a = np.asarray(a, dtype=float)

    # Make sure that the provided arrays have the correct dimensions and properties
    assert A.shape == (3,3), "Dilation matrix needs to be 3x3"
    assert np.sum((A-A.T)**2) < 1e-14, "Dilation matrix needs to be symmetric"
    assert a.shape == (3,), "Center of Gaussian has to be given by a 3-dim. vector"
    assert maxdeg > 0, "The maximum degree needs to be at least 1"
    cov = np.linalg.inv(A) # the covariance matrix is the inverse of the matrix A
    global_factor = (2*np.pi)**1.5 / np.sqrt(np.linalg.det(A)) # integral of the unnormalized Gaussian

    # moments[n0, n1, n2] will be set to <x^n0 * y^n1 * z^n2> / global_factor.
    # Only the (maxdeg+1)(maxdeg+2)(maxdeg+3)/6 entries with n0 + n1 + n2 <= maxdeg
    # are relevant (about 1/6 of the array for large maxdeg); the dense layout
    # trades memory for simple indexing in later use.
    moments = np.zeros((maxdeg+1, maxdeg+1, maxdeg+1))

    def raise_degree(i, n0, n1, n2):
        """Return <x_i * x^n0 y^n1 z^n2> from already computed lower moments.

        Integration by parts against the Gaussian gives
            <x_i f> = a_i <f> + sum_j cov[i, j] <d f / d x_j>,
        and for f = x^n0 y^n1 z^n2 the derivative lowers exponent n_j by one.
        """
        exponents = (n0, n1, n2)
        value = a[i] * moments[n0, n1, n2]
        for j, n_j in enumerate(exponents):
            # Terms with n_j == 0 vanish; skipping them also avoids reading a
            # wrapped-around entry via a negative index.
            if n_j > 0:
                lowered = list(exponents)
                lowered[j] -= 1
                value += cov[i, j] * n_j * moments[tuple(lowered)]
        return value

    # Degree 0 and 1
    moments[0,0,0] = 1.
    moments[1,0,0] = a[0] # <x>
    moments[0,1,0] = a[1] # <y>
    moments[0,0,1] = a[2] # <z>
    if maxdeg == 1:
        return global_factor * moments

    # Degree 2: <x_i x_j> = cov[i, j] + a_i a_j
    moments[2,0,0] = cov[0,0] + a[0]**2
    moments[0,2,0] = cov[1,1] + a[1]**2
    moments[0,0,2] = cov[2,2] + a[2]**2
    moments[1,1,0] = cov[0,1] + a[0]*a[1]
    moments[0,1,1] = cov[1,2] + a[1]*a[2]
    moments[1,0,1] = cov[2,0] + a[2]*a[0]
    if maxdeg == 2:
        return global_factor * moments

    # Generate all moments of degree deg+1 from those of degree deg.
    # Iterating over the total degree (instead of n0, n1, n2 separately)
    # guarantees that all lower-degree moments are available.
    for deg in range(2, maxdeg):
        for n0 in range(deg+1):
            for n1 in range(deg+1-n0):
                n2 = deg - n0 - n1
                # Every degree-(deg+1) monomial with a factor x is x times a
                # degree-deg monomial.
                moments[n0+1, n1, n2] = raise_degree(0, n0, n1, n2)
                if n0 == 0:
                    # Monomials without x are obtained by multiplying with y ...
                    moments[0, n1+1, n2] = raise_degree(1, 0, n1, n2)
                    if n1 == 0:
                        # ... and the pure z-monomial by multiplying with z.
                        moments[0, 0, n2+1] = raise_degree(2, 0, 0, n2)

    return global_factor * moments

# Comparing the general code with a special-purpose, diagonal-only version

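We now implement a special-purpose version of the same computation that only works for dilation matrices that are already diagonal. In that case the Gaussian factorizes, so every three-dimensional moment is the product of three one-dimensional moments.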

In [5]:
def compute_moments_single_variable(A, a, maxdeg):
    """
    One-dimensional moments of an unnormalized Gaussian.

    Parameters:
    - A: inverse of the variance
    - a: center of the Gaussian
    - maxdeg: int
        Maximum degree for which the moments are computed (must be >= 1).

    Returns:
    - A numpy array of shape (maxdeg+1,) whose entry n is
        <x^n> = integral x^n exp(-A(x-a)^2/2) dx
    """
    assert maxdeg > 0
    result = np.zeros(maxdeg + 1)
    result[0] = np.sqrt(2*np.pi/A)
    result[1] = a * result[0]

    # Two-term recurrence <x^n> = a <x^(n-1)> + (n-1)/A <x^(n-2)>.
    # For maxdeg == 1 the range is empty and nothing else is computed.
    for n in range(2, maxdeg + 1):
        result[n] = a*result[n-1] + (n-1)*result[n-2] / A

    return result

In [6]:
# Define a function that computes all moments for a diagonal dilation matrix.
# The implementation favors conceptual simplicity at the expense of memory
# efficiency.
def compute_moments_diagonal_inefficient_implementation(principcal_components, a, maxdeg):
    """
    Compute all moments of an (unnormalized) trivariate Gaussian with a
    diagonal dilation matrix, up to a given total degree.

    For a diagonal dilation matrix A = diag(principal_components) the Gaussian
    factorizes into three one-dimensional Gaussians, so every moment is the
    product of three one-dimensional moments.

    Parameters:
    - principcal_components: array-like of shape (3,)
        The three diagonal entries of the dilation matrix A, i.e. the inverse
        variances along x, y and z. (The parameter keeps its historical
        spelling so that existing keyword callers keep working.)
    - a: array-like of shape (3,)
        Position vector of the center of the trivariate Gaussian.
    - maxdeg: int
        Maximum total degree n0 + n1 + n2 for which the moments are computed.
        Must be at least 1.

    Returns:
    - np.ndarray of shape (maxdeg+1, maxdeg+1, maxdeg+1) whose entry [n0, n1, n2] is
        <x^n0 * y^n1 * z^n2> = integral (x^n0 * y^n1 * z^n2) * exp(-0.5*(r-a).T@A@(r-a)) dxdydz
      for n0 + n1 + n2 <= maxdeg; all remaining entries are zero.
      The Gaussian is not normalized, so every moment includes the factor
      (2*pi)^(3/2) / sqrt(prod(principal_components)).
    """
    # Use the argument itself (not a same-named global variable).
    principal_components = np.asarray(principcal_components, dtype=float)
    a = np.asarray(a, dtype=float)
    assert principal_components.shape == (3,), "Three principal components are required"
    assert a.shape == (3,), "Center of Gaussian has to be given by a 3-dim. vector"

    # Precompute the single variable moments in x-, y- and z-directions
    moments_x = compute_moments_single_variable(principal_components[0], a[0], maxdeg)
    moments_y = compute_moments_single_variable(principal_components[1], a[1], maxdeg)
    moments_z = compute_moments_single_variable(principal_components[2], a[2], maxdeg)

    # Dense outer product moments_x[n0] * moments_y[n1] * moments_z[n2].
    # Only entries with n0 + n1 + n2 <= maxdeg (about 1/6 of the array) are
    # relevant; all others are set to zero. The dense layout trades memory
    # for simple indexing in later use.
    products = moments_x[:, None, None] * moments_y[None, :, None] * moments_z[None, None, :]
    n0, n1, n2 = np.indices(products.shape)
    return np.where(n0 + n1 + n2 <= maxdeg, products, 0.0)

### Test the single variable case

Centered Gaussian

In [7]:
# Centered 1D Gaussian (a = 0) with standard deviation sigma, i.e. A = 1/sigma^2.
maxdeg = 5
sigma = 0.32
A = 1/sigma**2
a = 0.
moments_single = compute_moments_single_variable(A, a, maxdeg)

In [8]:
# Reference values for a centered Gaussian: odd moments vanish and even
# moments equal (2 sigma^2)^((n+1)/2) * Gamma((n+1)/2).
exact_values = np.zeros(maxdeg + 1)
for deg in range(maxdeg + 1):
    if deg % 2 != 0:
        exact_value = 0
    else:
        neff = (deg + 1) / 2
        exact_value = (2 * sigma**2)**neff * gamma(neff)
    exact_values[deg] = exact_value

In [9]:
# The recursion must reproduce the closed-form centered moments.
np.testing.assert_allclose(exact_values, moments_single, atol=1e-15)

Non-centered Gaussian

In [10]:
# Non-centered 1D Gaussian with standard deviation sigma and center a.
maxdeg = 5
a = 0.5
sigma = 0.23
A = 1/sigma**2
moments_single = compute_moments_single_variable(A, a, maxdeg)

In [11]:
def test_non_centered_moments(A, a, maxdeg):
    """
    Check `compute_moments_single_variable` for a non-centered Gaussian.

    The reference values substitute x = a + u. This turns the weight
    exp(-A (x-a)^2 / 2) into a centered Gaussian in u with variance
    sigma^2 = 1/A. Expanding with the binomial theorem then gives
        <x^n> = sum_k comb(n, k) * a^(n-k) * <u^k>,
    where the centered moments <u^k> vanish for odd k and equal
    (2 sigma^2)^((k+1)/2) * Gamma((k+1)/2) for even k.

    Parameters:
    - A: inverse variance of the Gaussian (must be positive)
    - a: center of the Gaussian
    - maxdeg: maximum degree to check (>= 1)

    Raises:
    - AssertionError if the computed moments disagree with the reference.

    Note: the check only runs when this function is called,
    e.g. test_non_centered_moments(A, a, maxdeg).
    """
    # Derive the variance from the argument A instead of a global `sigma`.
    variance = 1 / A

    # Centered moments <u^k>; odd ones stay zero.
    centered_moments = np.zeros((maxdeg+1,))
    for k in range(0, maxdeg+1, 2):
        neff = (k + 1) / 2
        centered_moments[k] = (2 * variance)**neff * gamma(neff)

    # Non-centered moments from the binomial expansion of (a + u)^n.
    moments = np.array([
        sum(comb(deg, k) * a**(deg-k) * centered_moments[k] for k in range(deg+1))
        for deg in range(maxdeg+1)
    ])

    moments_from_code = compute_moments_single_variable(A, a, maxdeg)

    assert_allclose(moments_from_code, moments)

### Test the diagonal implementation 

Check agreement with the closed-form expressions for degrees 0, 1, 2 and 3.

In [90]:
# Anisotropic, non-centered test case with a diagonal dilation matrix.
principal_components = np.array([2.8, 0.4, 1.1])
a = np.array([3.1, -2.3, 5.92])
maxdeg = 3
A = np.diag(principal_components)

# Evaluate both implementations on the same Gaussian.
moments_diagonal = compute_moments_diagonal_inefficient_implementation(principal_components, a, maxdeg)
moments_general = compute_moments_inefficient_implementation(A, a, maxdeg)

In [89]:
def get_exact_moments(A, a, maxdeg=3):
    """
    Closed-form moments of an unnormalized trivariate Gaussian up to degree 3.

    These serve as an independent reference for the recursive implementations.
    Write r = a + u, where u is a centered Gaussian with covariance
    cov = A^{-1}. Each moment then follows from expanding the monomial and using
    <u_i> = 0, <u_i u_j> = cov[i, j] and <u_i u_j u_k> = 0.

    Parameters:
    - A: symmetric 3x3 dilation matrix (array-like)
    - a: center of the Gaussian, array-like of shape (3,)
    - maxdeg: maximum total degree; must be 1, 2 or 3

    Returns:
    - np.ndarray of shape (maxdeg+1, maxdeg+1, maxdeg+1) whose entry [n0, n1, n2] is
        integral x^n0 y^n1 z^n2 exp(-0.5*(r-a).T@A@(r-a)) dxdydz
      for n0 + n1 + n2 <= maxdeg and zero otherwise.

    Raises:
    - AssertionError if maxdeg is not 1, 2 or 3.
    """
    assert maxdeg in [1,2,3]
    A = np.asarray(A, dtype=float)
    a = np.asarray(a, dtype=float)
    # Covariance matrix computed locally (no dependence on a global `cov`).
    cov = np.linalg.inv(A)
    global_factor = (2*np.pi)**1.5 / np.sqrt(np.linalg.det(A))

    moments_exact = np.zeros((maxdeg+1, maxdeg+1, maxdeg+1))
    moments_exact[0,0,0] = 1.
    # Exact expressions for degree 1: <x_i> = a_i
    moments_exact[1,0,0] = a[0]
    moments_exact[0,1,0] = a[1]
    moments_exact[0,0,1] = a[2]
    if maxdeg == 1:
        return global_factor * moments_exact

    # Exact expressions for degree 2: <x_i x_j> = cov[i, j] + a_i a_j
    moments_exact[2,0,0] = cov[0,0] + a[0]**2
    moments_exact[0,2,0] = cov[1,1] + a[1]**2
    moments_exact[0,0,2] = cov[2,2] + a[2]**2
    moments_exact[1,1,0] = cov[0,1] + a[0]*a[1]
    moments_exact[0,1,1] = cov[1,2] + a[1]*a[2]
    moments_exact[1,0,1] = cov[0,2] + a[0]*a[2]
    if maxdeg == 2:
        return global_factor * moments_exact

    # Exact expressions for degree 3 (third-order centered moments vanish)
    moments_exact[3,0,0] = 3*a[0]*cov[0,0] + a[0]**3
    moments_exact[0,3,0] = 3*a[1]*cov[1,1] + a[1]**3
    moments_exact[0,0,3] = 3*a[2]*cov[2,2] + a[2]**3
    moments_exact[2,1,0] = a[1]*(cov[0,0] + a[0]**2) +  2*a[0]*cov[0,1]
    moments_exact[2,0,1] = a[2]*(cov[0,0] + a[0]**2) +  2*a[0]*cov[0,2]
    moments_exact[1,2,0] = a[0]*(cov[1,1] + a[1]**2) +  2*a[1]*cov[1,0]
    moments_exact[0,2,1] = a[2]*(cov[1,1] + a[1]**2) +  2*a[1]*cov[1,2]
    moments_exact[1,0,2] = a[0]*(cov[2,2] + a[2]**2) +  2*a[2]*cov[2,0]
    moments_exact[0,1,2] = a[1]*(cov[2,2] + a[2]**2) +  2*a[2]*cov[2,1]
    moments_exact[1,1,1] = a[0]*a[1]*a[2] + a[0]*cov[1,2] + a[1]*cov[0,2] + a[2]*cov[0,1]
    return global_factor * moments_exact

In [59]:
# Covariance matrix of the current test Gaussian (inverse of the dilation matrix).
cov = np.linalg.inv(A)

# Reference moments up to degree 3 from the closed-form expressions, instead of
# a copy-pasted version of the formulas in `get_exact_moments`.
moments_exact = get_exact_moments(A, a, maxdeg=3)

In [92]:
# Compare both implementations against the closed-form reference values.
moments_exact = get_exact_moments(A, a, maxdeg=3)
tolerances = dict(rtol=1e-14, atol=1e-14)
assert_allclose(moments_exact, moments_diagonal, **tolerances)
assert_allclose(moments_exact, moments_general, **tolerances)

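Test correctness for the isotropic special case: all three principal components are equal and the center lies on the diagonal $x = y = z$, so every moment must be invariant under permutations of the exponents $(n_0, n_1, n_2)$.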

In [78]:
# Isotropic Gaussian: equal principal components and a center on the diagonal x = y = z.
sigma = 0.32
maxdeg = 8
a0 = 3.2
a = np.full((3,), a0)
principal_components = np.ones((3,)) / sigma**2
A = np.diag(principal_components)
moments_diagonal = compute_moments_diagonal_inefficient_implementation(principal_components, a, maxdeg)

In [79]:
# In the isotropic case all moments must be invariant under permutations of the
# exponents (n0, n1, n2), and entries with n0 + n1 + n2 > maxdeg must be zero.
# The moments grow quickly with the degree, so a purely absolute tolerance can
# become comparable to the floating-point rounding of the largest entries;
# a relative tolerance of the same size is therefore allowed as well.
eps = 1e-12
n0, n1, n2 = np.indices(moments_diagonal.shape)
assert np.all(moments_diagonal[n0 + n1 + n2 > maxdeg] == 0)
for permutation in [(1, 0, 2), (2, 1, 0), (0, 2, 1)]:
    # transpose(permutation)[n0, n1, n2] is the moment with permuted exponents
    assert_allclose(moments_diagonal, moments_diagonal.transpose(permutation), rtol=eps, atol=eps)

## General vs diagonal implementation

Now that the diagonal implementation has been tested fairly thoroughly, we compare it with the general implementation. For a diagonal dilation matrix, both must agree up to floating-point rounding.

In [101]:
# Anisotropic, non-centered Gaussian with a diagonal dilation matrix; both
# implementations must agree up to floating-point rounding.
principal_components = np.array([2.8,0.4,1.1])
A = np.diag(principal_components)
a = np.array([3.1, -2.3, 5.92])
maxdeg = 3
moments_general = compute_moments_inefficient_implementation(A, a, maxdeg)
moments_diagonal = compute_moments_diagonal_inefficient_implementation(principal_components, a, maxdeg)
# Same tolerances as the comparison with the exact moments. An rtol of 1e-15 is
# only a few ulps and may depend on how np.linalg.inv rounds 1/A_ii.
assert_allclose(moments_general, moments_diagonal, rtol=1e-14, atol=1e-14)

# The closed-form reference only reaches degree 3, so additionally compare the
# two recursions at a higher degree.
maxdeg_high = 8
moments_general_high = compute_moments_inefficient_implementation(A, a, maxdeg_high)
moments_diagonal_high = compute_moments_diagonal_inefficient_implementation(principal_components, a, maxdeg_high)
assert_allclose(moments_general_high, moments_diagonal_high, rtol=1e-12, atol=1e-12)

In [102]:
# Display the moments from the general implementation; entries with
# n0 + n1 + n2 > maxdeg are never filled and therefore stay zero.
moments_general

array([[[   14.18941362,    84.00132865,   510.18733253,  3173.038697  ],
        [  -32.63565133,  -193.20305589, -1173.43086481,     0.        ],
        [  110.53553212,   654.37035017,     0.        ,     0.        ],
        [ -417.40998055,     0.        ,     0.        ,     0.        ]],

       [[   43.98718223,   260.40411881,  1581.58073083,     0.        ],
        [ -101.17051913,  -598.92947326,     0.        ,     0.        ],
        [  342.66014958,     0.        ,     0.        ,     0.        ],
        [    0.        ,     0.        ,     0.        ,     0.        ]],

       [[  141.42791264,   837.25324282,     0.        ,     0.        ],
        [ -325.28419907,     0.        ,     0.        ,     0.        ],
        [    0.        ,     0.        ,     0.        ,     0.        ],
        [    0.        ,     0.        ,     0.        ,     0.        ]],

       [[  469.84594506,     0.        ,     0.        ,     0.        ],
        [    0.        ,     0. 

In [103]:
# Display the moments from the diagonal-only implementation for comparison;
# entries with n0 + n1 + n2 > maxdeg are zero.
moments_diagonal

array([[[   14.18941362,    84.00132865,   510.18733253,  3173.038697  ],
        [  -32.63565133,  -193.20305589, -1173.43086481,     0.        ],
        [  110.53553212,   654.37035017,     0.        ,     0.        ],
        [ -417.40998055,     0.        ,     0.        ,     0.        ]],

       [[   43.98718223,   260.40411881,  1581.58073083,     0.        ],
        [ -101.17051913,  -598.92947326,     0.        ,     0.        ],
        [  342.66014958,     0.        ,     0.        ,     0.        ],
        [    0.        ,     0.        ,     0.        ,     0.        ]],

       [[  141.42791264,   837.25324282,     0.        ,     0.        ],
        [ -325.28419907,     0.        ,     0.        ,     0.        ],
        [    0.        ,     0.        ,     0.        ,     0.        ],
        [    0.        ,     0.        ,     0.        ,     0.        ]],

       [[  469.84594506,     0.        ,     0.        ,     0.        ],
        [    0.        ,     0. 