In [228]:
from lie_learn.spectral.SE2FFT import SE2_FFT, shift_fft, shift_ifft

In [174]:
import numpy as np
from scipy.stats import multivariate_normal
from typing import Tuple

In [251]:
from numpy.fft import ifftshift, fftshift
from einops import rearrange

In [258]:
# Covariance of the Gaussian over (x, y, theta): variance 0.1 on each axis and
# no cross-correlation between axes.
# A much tighter alternative kept for experiments: cov = np.eye(3) * 1e-4
cov = np.diag(np.full(3, 0.1))

In [259]:
# Mean pose (x, y, theta) of the Gaussian.
# Alternative mean kept for experiments: mu = np.array([0.1, 0.1, np.pi / 2])
mu = np.array([0.0, -0.15, 0.0])

In [260]:
def se2_grid_samples(size: Tuple[int, int, int] = (5, 5, 5),
                     lower_bound: float = -0.5,
                     upper_bound: float = 0.5
                     ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Build a regular grid of (x, y, theta) poses.

    x and y are sampled on [lower_bound, upper_bound) and theta on [0, 2*pi);
    the right end point is excluded on every axis.

    Args:
        size: Number of samples along x, y and theta, in that order.
        lower_bound: Left end of the x and y ranges.
        upper_bound: Right (excluded) end of the x and y ranges.

    Returns:
        A tuple ``(poses, X, Y, T)`` where ``poses`` has shape
        ``(size[0] * size[1] * size[2], 3)`` with one ``(x, y, theta)`` row per
        grid node, and ``X``, ``Y``, ``T`` are the ``indexing='ij'`` mesh grids
        of shape ``(size[0], size[1], size[2])``. Row ``i`` of ``poses``
        corresponds to the ``i``-th element of the flattened (C-order) grids.
    """
    xs = np.linspace(lower_bound, upper_bound, size[0], endpoint=False)
    ys = np.linspace(lower_bound, upper_bound, size[1], endpoint=False)
    ts = np.linspace(0., 2. * np.pi, size[2], endpoint=False)
    X, Y, T = np.meshgrid(xs, ys, ts, indexing='ij')
    # One row per grid node; ravel() avoids the extra copy flatten() makes
    # before vstack copies the data anyway.
    poses = np.vstack((X.ravel(), Y.ravel(), T.ravel())).T
    return poses, X, Y, T

In [261]:
# Number of grid samples along x, y and theta respectively.
grid_size = (50, 50, 32)

In [262]:
# Sample the pose grid: `poses` holds one (x, y, theta) row per grid node and
# X, Y, T are the matching 3-D mesh grids.
poses, X, Y, T = se2_grid_samples(grid_size)

In [263]:
# Offset of every grid pose from the mean, one (dx, dy, dtheta) row per pose.
diff = poses - mu
# Fold the angular offset onto [-pi, pi) so the shortest rotation is used.
diff[:, 2] = np.mod(diff[:, 2] + np.pi, 2 * np.pi) - np.pi
# Zero-mean Gaussian log-density of each offset, laid back out on the 3-D grid.
energy = multivariate_normal.logpdf(
    diff, mean=np.zeros(3), cov=cov
).reshape(grid_size)

In [None]:


def se2_gaussian(x, y, theta, mean, cov):
    """
    Compute the value of a Gaussian function on SE(2).

    The density is the product of a 2-D Gaussian over the translation and a
    1-D Gaussian over the wrapped angular difference. Only ``cov[:2, :2]`` and
    ``cov[2, 2]`` are used: any translation/rotation cross-covariance in
    ``cov`` is ignored. The rotation factor evaluates a single Gaussian at the
    wrapped difference; it does not sum over 2*pi-periodic images.

    Parameters:
    - x, y: Translation components. Scalars or arrays; arrays are broadcast
      against each other and against ``theta``.
    - theta: Rotation component (in radians), scalar or array.
    - mean: A tuple (mean_x, mean_y, mean_theta) representing the mean of the Gaussian.
    - cov: A 3x3 covariance matrix for the Gaussian.

    Returns:
    - The value of the Gaussian function at the given (x, y, theta): a scalar
      for scalar inputs, otherwise an array with the broadcast input shape.
    """
    mean_x, mean_y, mean_theta = mean
    cov = np.asarray(cov, dtype=float)
    cov_xy = cov[:2, :2]  # Covariance for translation
    var_theta = cov[2, 2]  # Variance for rotation

    # Translation part (2D Gaussian). Offsets are stacked on a trailing axis so
    # the quadratic form is evaluated independently for every input element.
    dx, dy = np.broadcast_arrays(np.asarray(x, dtype=float) - mean_x,
                                 np.asarray(y, dtype=float) - mean_y)
    offsets = np.stack((dx, dy), axis=-1)  # shape (..., 2)
    # Solve cov_xy @ s = offset instead of forming the explicit inverse.
    solved = np.linalg.solve(cov_xy, offsets[..., np.newaxis])[..., 0]
    translation_exponent = -0.5 * np.sum(offsets * solved, axis=-1)
    translation_gaussian = np.exp(translation_exponent) / (
        2 * np.pi * np.sqrt(np.linalg.det(cov_xy))
    )

    # Rotation part: Gaussian of the angular difference wrapped to [-pi, pi)
    delta_theta = (np.asarray(theta, dtype=float) - mean_theta + np.pi) % (2 * np.pi) - np.pi
    rotation_exponent = -0.5 * (delta_theta ** 2) / var_theta
    rotation_gaussian = np.exp(rotation_exponent) / np.sqrt(2 * np.pi * var_theta)

    # Combine translation and rotation parts
    return translation_gaussian * rotation_gaussian


In [264]:
# SE(2) FFT helper for the sampling grid, configured with order-2 spline
# interpolation and an oversampling factor of 3.
fft = SE2_FFT(
    spatial_grid_size=grid_size,
    interpolation_method="spline",
    spline_order=2,
    oversampling_factor=3,
)

In [265]:
# Forward SE(2) transform of the gridded log-density. The last output, `fh`,
# is what the later cells feed back into `synthesize`.
f, f1c, f1p, f2, f2f, fh = fft.analyze(energy)

In [266]:
# Inverse transform of the coefficients `fh`; the first output `fi` is compared
# against `f` in the next cell.
fi, f1ci, f1pi, f2i, f2fi, fhi = fft.synthesize(fh)

In [267]:
# Round-trip check: mean absolute difference between the first output of
# `analyze` and the first output of `synthesize`.
np.mean(np.abs(f - fi))

0.5146681235751328

In [288]:
# Single query pose for the point-wise evaluation below: row 3000 of the
# flattened pose grid.
poses_ = poses[3000]

In [292]:
# Sanity check: a single pose is a length-3 vector (x, y, theta).
poses_.shape

(3,)

In [293]:
# Offset of the query pose from the mean, kept 2-D (one row) so the angle
# column can be addressed the same way as for the full grid.
diff_ = poses_[np.newaxis, :] - mu
# Fold the angular offset onto [-pi, pi).
diff_[:, 2] = np.mod(diff_[:, 2] + np.pi, 2 * np.pi) - np.pi
# Reference value: exact Gaussian log-density at the query pose.
energy_true = multivariate_normal.logpdf(diff_, mean=np.zeros(3), cov=cov)

In [246]:
# Unpack the number of grid samples along x, y and theta.
# NOTE(review): the next cell repeats this exact statement — one of the two is redundant.
b_x, b_y, b_t = grid_size

In [236]:
# Inline point-wise evaluation of the synthesized signal at the single pose `poses_`.
# NOTE(review): this follows the same steps as `neg_log_likelihood` and appears
# to be an earlier draft of it. `se2_fft` and `M` are not defined in any cell
# visible here (the transform object is called `fft` and the coefficients `fh`
# elsewhere), so this cell presumably depends on leftover kernel state — confirm
# or remove.
b_x, b_y, b_t = grid_size
# NOTE(review): p and n come from the Cartesian `energy` grid but are used below
# to reshape the array handed to the polar-to-Cartesian resampler — verify the
# sizes agree.
p, n, m = energy.shape
dx, dy, d_theta = poses_
# Synthesize from the coefficients and keep only the fourth returned array
_, _, _, f_p_psi_m, _, _ = se2_fft.synthesize(M)
# Shift the signal to the origin along the third axis
# NOTE(review): the sum further down reads the unshifted `f_p_psi_m`; the
# shifted copy is only used for its shape — confirm which one was intended.
f_p_psi_m_ = ifftshift(f_p_psi_m, axes=2)
# Theta ranges from 0 to 2pi, thus the period is 2 * np.pi
t_theta = 2 * np.pi
# NOTE(review): `n_theta` is not used again in this cell.
n_theta = f_p_psi_m_.shape[2]
# Angular frequencies along theta: indices 0 .. b_t - 1 scaled by 2*pi / t_theta
omega_n = 2 * np.pi * (1 / t_theta) * np.arange(b_t)
# Evaluate the Fourier sum over the theta axis at d_theta
# NOTE(review): the exponent sign is -1j here but +1j in `neg_log_likelihood` — confirm.
f_p_psi = np.sum(f_p_psi_m * np.exp(-1j * omega_n.reshape(1, 1, -1) * d_theta), axis=2)
# Map from polar to cartesian grid
f_p_p = se2_fft.resample_p2c_3d(f_p_psi.reshape(p, n, 1)).squeeze()
# Apply lie_learn's `shift_ifft`
# NOTE(review): `neg_log_likelihood` only applies `ifftshift` at this step and
# then evaluates the Fourier sum directly; applying an inverse FFT here and
# summing complex exponentials afterwards looks inconsistent — confirm.
f_p_p = shift_ifft(f_p_p)
# Set domain of X and Y, recall X and Y range from [-0.5, 0.5]
t_x, t_y = 1., 1.
n_x, n_y = f_p_p.shape[:2]
# Compute complex term
omega_nx = 2 * np.pi * (1 / t_x) * np.arange(n_x)  # Angular frequencies in X
omega_ny = 2 * np.pi * (1 / t_y) * np.arange(n_y)  # Angular frequencies in Y
# Evaluate the 2-D Fourier sum at (dx, dy) and keep the real part
# NOTE(review): the offset is (dx - 0.5) here but (dx + 0.5) in
# `neg_log_likelihood`; also, this rebinds `f`, the name that receives the
# first output of `fft.analyze` in another cell.
f = np.sum(f_p_p * np.exp(1j * omega_nx.reshape(-1, 1) * (dx-0.5) + 1j * omega_ny.reshape(1, -1) * (dy-0.5))).real

In [294]:
# Display the reference log-density at the query pose for comparison with the
# Fourier-based point-wise evaluation.
energy_true

-14.092443461484649

In [285]:
def neg_log_likelihood(
    eta: np.ndarray, l_n_z: float, pose: np.ndarray, se2_fft: "SE2_FFT"
) -> np.ndarray:
    """
    Point-wise synthesis of an SE(2) Fourier transform at one or more poses.

    The coefficients are partially synthesized with ``se2_fft.synthesize``; the
    remaining Fourier sums over theta and over (x, y) are then evaluated
    directly at the requested poses instead of on the whole grid.

    Note that, despite its name, the function returns the synthesized value
    as is: it is not negated and ``l_n_z`` is not subtracted.

    All grid sizes are read from the arrays returned by ``se2_fft``, so the
    function does not depend on any notebook-level variable.

    Args:
        eta (np.array): Fourier coefficients, passed unchanged to
            ``se2_fft.synthesize``.
        l_n_z (float): Log of the normalization constant. Currently unused;
            kept so existing callers do not break.
        pose (np.array): Pose(s) at which to evaluate, shape ``(3,)`` or
            ``(b, 3)``, ordered ``(x, y, theta)``.
        se2_fft (SE2_FFT): Object providing ``synthesize`` and
            ``resample_p2c_3d``.

    Returns:
        np.ndarray of shape ``(b,)``: real part of the synthesized signal at
        each pose.
    """
    pose = np.asarray(pose, dtype=float)
    # Promote a single pose to a batch of one
    if pose.ndim < 2:
        pose = pose[np.newaxis, :]
    # Pose components as (1, b) rows so they broadcast against frequency columns
    dx = pose[:, 0][np.newaxis, :]
    dy = pose[:, 1][np.newaxis, :]
    d_theta = pose[:, 2][np.newaxis, :]

    # Partial synthesis: only the fourth returned array is needed
    _, _, _, f_p_psi_m, _, _ = se2_fft.synthesize(eta)
    # Shift the signal to the origin along theta; add a trailing batch axis
    f_p_psi_m = ifftshift(f_p_psi_m, axes=2)[..., np.newaxis]  # (p, n, m, 1)

    # Theta ranges from 0 to 2pi, thus the period is 2 * np.pi
    t_theta = 2 * np.pi
    n_theta = f_p_psi_m.shape[2]
    # NOTE(review): frequencies are indexed 0 .. N-1 (here and for x/y below).
    # That reproduces the inverse DFT on grid nodes; for poses between nodes a
    # signed frequency range (cf. np.fft.fftfreq) may be what is wanted — confirm.
    omega_n = 2 * np.pi * (1 / t_theta) * np.arange(n_theta).reshape(1, 1, -1, 1)
    # Evaluate the Fourier sum over theta at every requested angle -> (p, n, b)
    f_p_psi = np.sum(f_p_psi_m * np.exp(1j * omega_n * d_theta), axis=2)

    # Map from polar to cartesian grid
    f_p_p = se2_fft.resample_p2c_3d(f_p_psi)
    # Undo the centering of the two spatial frequency axes
    f_p_p = ifftshift(f_p_p, axes=(0, 1))
    n_x, n_y = f_p_p.shape[:2]

    # X and Y range over [-0.5, 0.5), i.e. period 1; the +0.5 moves the domain
    # origin to the first grid sample
    t_x, t_y = 1.0, 1.0
    angle_x = (
        1j * 2 * np.pi * (1 / t_x) * np.arange(n_x)[:, np.newaxis] * (dx + 0.5)
    )  # (n_x, b)
    angle_y = (
        1j * 2 * np.pi * (1 / t_y) * np.arange(n_y)[:, np.newaxis] * (dy + 0.5)
    )  # (n_y, b)
    angle = angle_x[:, np.newaxis, :] + angle_y[np.newaxis, :, :]  # (n_x, n_y, b)

    # Evaluate the 2-D Fourier sum at every pose and keep the real part
    return np.sum(f_p_p * np.exp(angle), axis=(0, 1)).real

In [295]:
# Evaluate the synthesized signal at the query pose from the coefficients `fh`.
# `None` is passed for `l_n_z`, which the function body does not use. Compare
# the result with `energy_true`.
neg_log_likelihood(fh, None, poses_, fft)

array([-12.25566903])