### Flow Matching Ergodic Coverage Tutorial
#### Coverage on surfaces using the heat kernel on Riemannian manifolds
This tutorial uses [`lqrax`](https://github.com/MaxMSun/lqrax/tree/main) to solve the continuous time Riccati equation for the LQ flow matching problem and uses heat kernel smoothed gradient vector to guide the agents. The heat kernel on the manifold is computed using the Laplacian evals/evecs.

#### Parameters

**diffusion_coefficient**: diffusion coefficient for the heat kernel control the global/local coverage trade-off. Increasing the diffusion coefficient leads to more global coverage. Another interpretation is that increasing it results in larger agent footprint decreasing the diffusion coefficient requires more iterations to fully cover the target.

**kernel_eps**: bandwith of the Gaussian kernel when projecting off-manifold points to 
Laplacian eigenfunctions. Higher values enable projecting points far from the
surface smoothly but reduce the accuracy.

In [1]:
# Parameters
# ==============================================================================
object_name = "plate_shapes"
x0 = [
    0.4, 0.3, 0.1, 0.02, 0.0, 0.00
]
tsteps = 200 
step_size = 0.5
num_iters = 500
diffusion_coefficient = 1e-4
kernel_eps = 1e-5


# object_name = "spot" 
# x0 = [
#     0.0, 0.0, 0.07, 0.0, 0.0, 0.01
# ]
# tsteps = 200
# diffusion_coefficient = 5e-4
# kernel_eps = 1e-5
# step_size = 0.1 
# num_iters = 400


# object_name = "bunny" 
# x0 = [
#     0.0, 0.0, -0.25, 0.0, 0.0, 0.01
# ]
# tsteps = 200
# diffusion_coefficient = 1e-2
# kernel_eps = 5e-4
# step_size = 0.1 
# num_iters = 400

In [2]:
import time
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import jax.numpy as jnp
import os
import numpy as np

from jax import jit, grad, vmap, jacfwd
from jax.scipy.stats import gaussian_kde as kde
from jax.scipy.stats import multivariate_normal as mvn
import jax
cpu = jax.devices("cpu")[0]
try:
    gpu = jax.devices("cuda")[0]
except:
    gpu = cpu
jnp.set_printoptions(precision=4)

try:
    from lqrax import LQR
except:
    %pip install lqrax
    from lqrax import LQR

try:
    import ott 
except:
    %pip install ott-jax 
    import ott 
from ott.geometry import pointcloud
from ott.geometry.costs import PNormP
from ott.tools.sinkhorn_divergence import sinkhorn_divergence

In [3]:
import robust_laplacian
from plyfile import PlyData
import numpy as np
import scipy.sparse.linalg as sla
import os 
print(os.getcwd())

# Read input
plydata = PlyData.read(f"./test_objects/3d/{object_name}.ply")
# Extract vertex data
vertex = plydata['vertex'].data

# Extract 3D coordinates
points = np.vstack((vertex['x'], vertex['y'], vertex['z'])).T

# Check if color attributes exist
if all(c in vertex.dtype.names for c in ('red', 'green', 'blue')):
    target_density = np.array(vertex['red'])
elif all(c in vertex.dtype.names for c in ('r', 'g', 'b')):
    target_density = np.vstack((vertex['r']))
else:
    target_density = np.ones(len(points))  # or set a default color

L, M = robust_laplacian.point_cloud_laplacian(points)

num_eigen = 200 # increasing further have minimal effect but slows down the computation
evals, evecs = sla.eigsh(L, num_eigen, M, sigma=1e-8)

/Users/cembilaloglu/repos/third_party/lqr-flow-matching/tutorials


In [4]:
def visualize_3d(x0, x_traj, tgt_samples, v_traj=None, sample_colors=None):
    import numpy as np
    import plotly.graph_objects as go

    traj_color = "black"
    default_sample_color = "#9467bd"
    arrow_color = "blue"
    arrow_scale = 0.02  # controls visual length of each arrow

    fig = go.Figure()

    # Target samples
    fig.add_trace(go.Scatter3d(
        x=tgt_samples[:, 0],
        y=tgt_samples[:, 1],
        z=tgt_samples[:, 2],
        mode='markers',
        marker=dict(
            color=sample_colors if sample_colors is not None else default_sample_color,
            size=5,
            opacity=1.0,
            colorscale="bluered" if sample_colors is not None else None
        )
    ))

    # Trajectory path
    fig.add_trace(go.Scatter3d(
        x=x_traj[:, 0],
        y=x_traj[:, 1],
        z=x_traj[:, 2],
        mode='lines+markers',
        line=dict(color=traj_color,width=5),
        marker=dict(color=traj_color, size=5)
    ))

    # Optional velocity arrows along trajectory
    if v_traj is not None:
        for x, v in zip(x_traj, v_traj):
            x_end = x + arrow_scale * v
            fig.add_trace(go.Scatter3d(
                x=[x[0], x_end[0]],
                y=[x[1], x_end[1]],
                z=[x[2], x_end[2]],
                mode='lines',
                line=dict(color=arrow_color, width=3),
                showlegend=False
            ))

    # Starting point
    fig.add_trace(go.Scatter3d(
        x=[x0[0]],
        y=[x0[1]],
        z=[x0[2]],
        mode='markers',
        marker=dict(color='black', size=10)
    ))

    fig.update_layout(
        scene=dict(
            camera=dict(eye=dict(x=0.7, y=0.7, z=0.7)),
            # xaxis=dict(range=[-0.6, 0.6], showgrid=False, visible=False),
            # yaxis=dict(range=[-0.6, 0.6], showgrid=False, visible=False),
            # zaxis=dict(range=[-0.6, 0.6], showgrid=False, visible=False),
            aspectmode='data'
        ),
        paper_bgcolor='white',
        scene_bgcolor='white',
        margin=dict(l=0, r=0, t=0, b=0),
        showlegend=False
    )

    fig.show()


In [5]:
class PointMassLQR(LQR):
    def __init__(self, dt, x_dim, u_dim, Q, R):
        super().__init__(dt, x_dim, u_dim, Q, R)


    def dyn(self, xt, ut):
        return jnp.array([
            xt[3],  # dx = vx
            xt[4],  # dy = vy
            xt[5],  # dz = vz
            ut[0],  # ax
            ut[1],  # ay
            ut[2],  # az
        ])

In [6]:
Q = jnp.diag(jnp.array([
    1.0, 1.0, 1.0, 1e-03, 1e-03, 1e-03,
]))
R = jnp.diag(jnp.array([0.01, 0.01, 0.1]))

pointmass_lqr = PointMassLQR(dt=0.05, x_dim=6, u_dim=3, Q=Q, R=R)

# lqr solving on CPU is faster
linearize_dyn = jit(pointmass_lqr.linearize_dyn, device=cpu)
solve_lqr = jit(pointmass_lqr.solve, device=cpu)

In [7]:
x0 = jnp.array(x0)
u_traj = jnp.zeros((tsteps, 3))
x_traj, A_traj, B_traj = linearize_dyn(x0, u_traj)

In [8]:
visualize_3d(x0, x_traj, points, sample_colors=target_density)

In [9]:
# Heat kernel replacing the Sinkhron Divergence
# ==============================================================================
def nystrom_extension(x, tgt, eigvecs, kernel_eps):
    """
    Approximate eigenfunctions at off-manifold points x via weighted kernel interpolation.
    Returns phi_x: (N_x, K)
    """
    sq_dists = jnp.sum((x[:, None, :] - tgt[None, :, :]) ** 2, axis=2)  # (N_x, N_tgt)
    weights = jnp.exp(-sq_dists / kernel_eps)  # (N_x, N_tgt)
    weights /= jnp.sum(weights, axis=1, keepdims=True)  # normalize

    # interpolate each eigenvector: phi_x[:, k] = sum_j weights[:, j] * eigvecs[j, k]
    return weights @ eigvecs  # (N_x, K)


def spectral_heat_kernel(x, tgt, eigvecs, eigvals, diffusion_coefficient, kernel_eps):
    """
    Compute heat kernel K_xt between off-manifold x and tgt using spectral heat kernel
    with Nystrom extension of eigenfunctions.
    """
    phi_x = nystrom_extension(x, tgt, eigvecs, kernel_eps)  # (N_x, K)
    phi_y = eigvecs  # (N_tgt, K)

    heat_coeffs = jnp.exp(-eigvals * diffusion_coefficient)  # (K,)
    return phi_x @ (phi_y.T * heat_coeffs[:, None])


@jax.jit
def compute_heat_kernel_grad_on_manifold_weighted(
    x_samples,
    tgt_samples,
    eigvecs,
    eigvals,
    target_weights,
    diffusion_coefficient=1e-4,
    kernel_eps=1e-5,
):
    """
    Compute gradient flow using spectral heat kernel on a manifold
    for both attraction and repulsion.
    """
    # Spectral heat kernel from x -> tgt
    K_xt = spectral_heat_kernel(
        x_samples, tgt_samples, eigvecs, eigvals, diffusion_coefficient, kernel_eps
    )

    # Weight by discrete target distribution
    K_xt *= target_weights[None, :]  # broadcast multiplication
    K_xt /= jnp.sum(K_xt, axis=1, keepdims=True)  # normalize rows

    # Spectral heat kernel from x -> x (repulsion)
    phi_x = nystrom_extension(x_samples, tgt_samples, eigvecs, kernel_eps)
    heat_coeffs = jnp.exp(-eigvals * diffusion_coefficient)
    K_xx = phi_x @ (phi_x.T * heat_coeffs[:, None])
    K_xx /= jnp.sum(K_xx, axis=1, keepdims=True)

    # Attraction term
    delta_xt = tgt_samples[None, :, :] - x_samples[:, None, :]  # (N_x, N_tgt, 3)
    grad_xt = jnp.sum(K_xt[:, :, None] * delta_xt, axis=1)

    # Repulsion term
    delta_xx = x_samples[None, :, :] - x_samples[:, None, :]  # (N_x, N_x, 3)
    grad_xx = jnp.sum(K_xx[:, :, None] * delta_xx, axis=1)

    return grad_xt - grad_xx

In [10]:
heat_kernel_dx_traj = compute_heat_kernel_grad_on_manifold_weighted(
        x_traj[:,:3], points, evecs, evals, target_density,
        diffusion_coefficient=diffusion_coefficient,  kernel_eps=kernel_eps
    )
heat_kernel_dx_traj = np.array(heat_kernel_dx_traj)
# note that the dimension of the descent direction matches the trajectory
print(
    f'heat_kernel_dx_traj .shape: {heat_kernel_dx_traj.shape} == x_traj.shape: {x_traj.shape}')

heat_kernel_dx_traj .shape: (200, 3) == x_traj.shape: (200, 6)


In [11]:
# Solve the flow matching ergodic coverage problem
# (use a smaller step size here for smoother animation)
z0 = jnp.zeros(6)
x_traj_list = []
zeros = np.zeros((x_traj.shape[0], 3))
for i in tqdm(range(num_iters)):
    x_traj, A_traj, B_traj = linearize_dyn(x0, u_traj)

    heat_kernel_dx_traj = compute_heat_kernel_grad_on_manifold_weighted(
        x_traj[:,:3], points, evecs, evals, target_density,
        diffusion_coefficient=diffusion_coefficient,  kernel_eps=kernel_eps
    )
    heat_kernel_dx_traj = np.hstack((heat_kernel_dx_traj, zeros))

    v_traj, z_traj = solve_lqr(z0, A_traj, B_traj, heat_kernel_dx_traj)
    u_traj += step_size * v_traj
    x_traj_list.append(np.array(x_traj))
x_traj_list = np.array(x_traj_list)

  0%|          | 0/500 [00:00<?, ?it/s]

In [12]:
visualize_3d(x0, x_traj_list[-1,...], points, sample_colors=target_density)

In [13]:
def visualize_3d_animation(x0, x_traj_list, tgt_samples, v_traj_list=None, sample_colors_list=None):
    import plotly.graph_objects as go
    import numpy as np

    # traj_color = "#ff7f0e"
    traj_color = "black"
    default_sample_color = "#9467bd"
    arrow_color = "blue"
    arrow_scale = 0.02

    T = x_traj_list[0].shape[0]
    num_frames = len(x_traj_list)

    fig = go.Figure()

    # Trace 0: animated target samples (markers only, with changing color)
    fig.add_trace(go.Scatter3d(
        x=tgt_samples[:, 0], y=tgt_samples[:, 1], z=tgt_samples[:, 2],
        mode='markers',
        marker=dict(size=3, opacity=0.1,
                    color= default_sample_color,
                                colorscale="bluered",),
        showlegend=False
    ))

    # Trace 1: Trajectory (lines + markers)
    fig.add_trace(go.Scatter3d(
        x=[], y=[], z=[],
        mode='lines+markers',
        line=dict(color=traj_color),
        marker=dict(size=3, color=traj_color),
        showlegend=False
    ))

    # Trace 2 to 1+T: Velocity arrows
    for _ in range(T):
        fig.add_trace(go.Scatter3d(
            x=[], y=[], z=[],
            mode='lines',
            line=dict(color=arrow_color, width=2),
            showlegend=False
        ))

    # Build frames
    frames = []
    skip = 1
    for i in range(0, num_frames, skip):
        x_traj = x_traj_list[i]
        frame_data = []

        # Updated target samples
        sample_colors = sample_colors_list[i]
        frame_data.append(go.Scatter3d(
            x=tgt_samples[:, 0],
            y=tgt_samples[:, 1],
            z=tgt_samples[:, 2],
            mode='markers',
            marker=dict(size=3, opacity=0.5, color=sample_colors),
            showlegend=False
        ))

        # Trajectory
        frame_data.append(go.Scatter3d(
            x=x_traj[:, 0],
            y=x_traj[:, 1],
            z=x_traj[:, 2],
            mode='lines+markers',
            line=dict(color=traj_color,width=5),
            marker=dict(size=5, color=traj_color),
            showlegend=False
        ))

        # Velocity arrows
        if v_traj_list is not None:
            v_traj = v_traj_list[i]
            for x, v in zip(x_traj, v_traj):
                x_end = x + arrow_scale * v
                frame_data.append(go.Scatter3d(
                    x=[x[0], x_end[0]],
                    y=[x[1], x_end[1]],
                    z=[x[2], x_end[2]],
                    mode='lines',
                    line=dict(color=arrow_color, width=2),
                    showlegend=False
                ))
        else:
            for _ in range(T):
                frame_data.append(go.Scatter3d(
                    x=[None], y=[None], z=[None],
                    mode='lines',
                    line=dict(color=arrow_color, width=2),
                    showlegend=False
                ))

        trace_indices = list(range(len(frame_data)))
        frames.append(go.Frame(name=str(i),
                               data=frame_data,
                               traces=trace_indices))

    fig.frames = frames

    # Slider
    steps = [dict(method='animate',
                  args=[[f.name], dict(mode='immediate',
                                       frame=dict(duration=0, redraw=True),
                                       transition=dict(duration=0))],
                  label=f.name) for f in frames]

    fig.update_layout(
        sliders=[dict(
            active=len(frames) - 1,
            y=-0.05,
            x=0.5,
            xanchor="center",
            pad=dict(t=10),
            len=0.5,
            steps=steps,
            currentvalue=dict(prefix="Frame: ", font=dict(size=12))
        )],
        updatemenus=[dict(
            type="buttons",
            direction="left",
            x=0.5,
            y=-0.25,
            xanchor="center",
            yanchor="top",
            pad=dict(r=10, t=10),
            buttons=[
                dict(label="Start",
                     method="animate",
                     args=[None, dict(frame=dict(duration=50, redraw=True),
                                      transition=dict(duration=0),
                                      fromcurrent=True,
                                      mode="immediate")]),
                dict(label="Pause",
                     method="animate",
                     args=[[None], dict(frame=dict(duration=0, redraw=False),
                                        transition=dict(duration=0),
                                        mode="immediate")]),
                dict(label="Reset",
                     method="animate",
                     args=[[frames[0].name], dict(frame=dict(duration=0, redraw=True),
                                                  transition=dict(duration=0),
                                                  mode="immediate")])
            ]
        )],
        scene=dict(
            camera=dict(eye=dict(x=0.7, y=0.7, z=0.7)),
            # xaxis=dict(range=[-0.6, 0.6], showgrid=False, visible=False),
            # yaxis=dict(range=[-0.6, 0.6], showgrid=False, visible=False),
            # zaxis=dict(range=[-0.6, 0.6], showgrid=False, visible=False),
            aspectmode='data',
            # showgrid=False
        ),
        paper_bgcolor='white',
        scene_bgcolor='white',
        showlegend=False
    )
    # fig.write_html(
    # f"hedac_3d.html",
    # include_plotlyjs="cdn", 
    # full_html=True,
    # auto_play=False
    # )
    fig.show()

In [14]:
target_density = np.broadcast_to(target_density, (num_iters, len(points))).copy()
visualize_3d_animation(x0, x_traj_list[::20],v_traj_list=None, tgt_samples=points,sample_colors_list=target_density[::20])