### Flow Matching Ergodic Coverage Tutorial

#### Coverage on surfaces using the heat kernel of Euclidean space
This tutorial uses [`lqrax`](https://github.com/MaxMSun/lqrax/tree/main) to solve the continuous-time Riccati equation for the LQ flow matching problem, and uses a heat-kernel-smoothed gradient vector field to guide the agents.

#### Parameters

**diffusion_coefficient**: the diffusion coefficient of the heat kernel, which controls the global/local coverage trade-off. Increasing the diffusion coefficient leads to more global coverage; equivalently, it results in a larger agent footprint. Decreasing the diffusion coefficient gives more local coverage but requires more iterations to fully cover the target.

In [1]:
# Parameters
# ==============================================================================
# Scale of each control update in the solver loop (u_traj += step_size * v_traj).
step_size = 2
# Number of flow-matching iterations run by the solver loop.
num_iters = 600
# Heat-kernel diffusion coefficient: larger values give more global coverage.
diffusion_coefficient = 0.005

In [2]:
import time
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import jax.numpy as jnp
import os
import numpy as np

from jax import jit, grad, vmap, jacfwd
from jax.scipy.stats import gaussian_kde as kde
from jax.scipy.stats import multivariate_normal as mvn
import jax
cpu = jax.devices("cpu")[0]
try:
    gpu = jax.devices("cuda")[0]
except RuntimeError:
    # jax.devices raises RuntimeError when the requested backend is unknown or
    # failed to initialise (e.g. CPU-only jaxlib); fall back to the CPU device.
    # A narrow except keeps KeyboardInterrupt/SystemExit and real bugs visible.
    gpu = cpu
jnp.set_printoptions(precision=4)

try:
    from lqrax import LQR
except:
    %pip install lqrax
    from lqrax import LQR

In [3]:
def visualize_3d(x0, x_traj, tgt_samples):
    """
    Render a target point cloud, a trajectory and a quadcopter glyph with Plotly.

    Args:
        x0: initial state; only entries 0-2 (x, y, z) are used, to place the
            quadcopter glyph.
        x_traj: trajectory array; columns 0-2 are plotted as x, y, z.
        tgt_samples: target point cloud; columns 0-2 are plotted as x, y, z.
    """
    import numpy as np
    import plotly.graph_objects as go

    traj_color = "#ff7f0e"
    sample_color = "#9467bd"
    quad_color = "black"
    arm_length = 0.08
    px, py, pz = x0[0], x0[1], x0[2]

    fig = go.Figure()

    def scatter(**trace_kwargs):
        # Append one Scatter3d trace; trace order matters for rendering.
        fig.add_trace(go.Scatter3d(**trace_kwargs))

    # Faint target cloud first, trajectory drawn over it.
    scatter(
        x=tgt_samples[:, 0], y=tgt_samples[:, 1], z=tgt_samples[:, 2],
        mode='markers',
        marker=dict(color=sample_color, size=3, opacity=0.1),
    )
    scatter(
        x=x_traj[:, 0], y=x_traj[:, 1], z=x_traj[:, 2],
        mode='lines+markers',
        line=dict(color=traj_color),
        marker=dict(color=traj_color, size=3),
    )

    # Quadcopter body: two crossing arms through the start position.
    scatter(
        x=[px - arm_length, px + arm_length], y=[py, py], z=[pz, pz],
        mode='lines', line=dict(color=quad_color, width=8),
    )
    scatter(
        x=[px, px], y=[py - arm_length, py + arm_length], z=[pz, pz],
        mode='lines', line=dict(color=quad_color, width=8),
    )

    # Motor hubs at the four arm tips.
    hub_positions = (
        (px, py - arm_length),
        (px, py + arm_length),
        (px - arm_length, py),
        (px + arm_length, py),
    )
    for hub_x, hub_y in hub_positions:
        scatter(x=[hub_x], y=[hub_y], z=[pz], mode='markers',
                marker=dict(color='black', size=5))

    # Propeller rings (radius = half an arm) centred on each motor.
    prop_radius = arm_length * 0.5
    theta = np.linspace(0, 2*np.pi, 30)
    motor_offsets = ((arm_length, 0), (-arm_length, 0),
                     (0, arm_length), (0, -arm_length))
    for off_x, off_y in motor_offsets:
        ring_x, ring_y = x0[0] + off_x, x0[1] + off_y
        scatter(
            x=ring_x + prop_radius * np.cos(theta),
            y=ring_y + prop_radius * np.sin(theta),
            z=pz + np.zeros_like(theta),
            mode='lines',
            line=dict(color=quad_color, width=4),
        )

    # Centre of the quadcopter.
    scatter(x=[px], y=[py], z=[pz], mode='markers',
            marker=dict(color='black', size=10))

    axis_style = dict(range=[-0.6, 0.6], showgrid=False, visible=False)
    fig.update_layout(
        scene=dict(
            camera=dict(eye=dict(x=0.7, y=0.7, z=0.7)),
            xaxis=dict(axis_style),
            yaxis=dict(axis_style),
            zaxis=dict(axis_style),
            aspectmode='cube'
        ),
        paper_bgcolor='white',
        scene_bgcolor='white',
        margin=dict(l=0, r=0, t=0, b=0),
        showlegend=False
    )
    fig.show()

In [4]:
import urllib.request

object_list = ["bunny", "airplane", "armadillo", "chair"]
obj_idx = 0

# NOTE: `object` shadows the Python builtin. The name is kept because it is the
# notebook-level handle for the selected mesh (the commented-out HTML export
# in the animation helper formats it into a filename).
object = object_list[obj_idx]
print(f'object: {object}')

url = (
    "https://raw.githubusercontent.com/MurpheyLab/lqr-flow-matching/refs/heads/main/"
    f"tutorials/test_objects/3d/{object}.txt"
)
# Use the response as a context manager so the HTTP connection is closed as
# soon as the points have been parsed.
with urllib.request.urlopen(url) as resp:
    # Keep only the first three columns (x, y, z) of each point.
    tgt_samples_dense = np.loadtxt(resp)[:, :3]

# The first `num_target_samples` points are the heat-kernel targets; the dense
# cloud is only passed to the plotting helpers in this notebook.
num_target_samples = 3000
tgt_samples = jnp.array(tgt_samples_dense[:num_target_samples])
num_samples = tgt_samples.shape[0]
print(f'samples.shape: {tgt_samples.shape}')

object: bunny
samples.shape: (3000, 3)


In [5]:
class AircraftLQR(LQR):
    """
    Kinematic aircraft model plugged into lqrax's LQR solver.

    State ``xt = (x, y, z, psi, phi, v)``: position, the angle ``psi`` of the
    velocity in the x-y plane, the climb angle ``phi`` above that plane, and
    the speed ``v``. The control ``ut`` directly sets the rates of
    ``psi``, ``phi`` and ``v`` (in that order).
    """

    def __init__(self, dt, x_dim, u_dim, Q, R):
        super().__init__(dt, x_dim, u_dim, Q, R)

    def dyn(self, xt, ut):
        """Return d(xt)/dt for state ``xt`` under control ``ut``."""
        _, _, _, heading, climb, speed = xt
        # Speed projected onto the horizontal plane.
        horizontal_speed = speed * jnp.cos(climb)
        return jnp.array([
            horizontal_speed * jnp.cos(heading),
            horizontal_speed * jnp.sin(heading),
            speed * jnp.sin(climb),
            ut[0],  # d(psi)/dt
            ut[1],  # d(phi)/dt
            ut[2],  # d(v)/dt
        ])

In [6]:
# Quadratic LQR weights. State order is (x, y, z, psi, phi, v): position is
# weighted at 1.0, the angle/speed states only lightly at 1e-3.
Q = jnp.diag(jnp.array([1.0] * 3 + [1e-03] * 3))
# Control weights for (d psi/dt, d phi/dt, d v/dt).
R = jnp.diag(jnp.array([0.01, 0.01, 0.1]))

aircraft_lqr = AircraftLQR(dt=0.05, x_dim=6, u_dim=3, Q=Q, R=R)

# Per the original note, LQR solves are faster on CPU, so both entry points are
# compiled for the CPU device.
# NOTE(review): the `device=` argument of jax.jit is deprecated in recent JAX
# releases — confirm against the installed version.
linearize_dyn, solve_lqr = [
    jit(method, device=cpu)
    for method in (aircraft_lqr.linearize_dyn, aircraft_lqr.solve)
]

In [7]:
# Horizon length (number of control steps).
tsteps = 600

# Start at (-0.2, -0.2, 0) with psi = pi/4, phi = 0 and speed 0.02.
x0 = jnp.array([-0.2, -0.2, 0.0, 0.25 * jnp.pi, 0.0, 0.02])

# Initial guess: zero controls, rolled out and linearised about x0.
u_traj = jnp.zeros((tsteps, 3))
x_traj, A_traj, B_traj = linearize_dyn(x0, u_traj)

In [8]:
# Initial (zero-control) trajectory against the dense target cloud.
visualize_3d(x0, x_traj=x_traj, tgt_samples=tgt_samples_dense)

In [9]:
# Heat kernel (replacing the Sinkhorn divergence)
# ==============================================================================
def gaussian_kernel(x, y, diffusion_coefficient):
    """
    Compute the (unnormalised) Gaussian heat-kernel matrix between x and y.

    K[i, j] = exp(-||x_i - y_j||^2 / (4 * diffusion_coefficient))

    Args:
        x: (N, d) array of points.
        y: (M, d) array of points.
        diffusion_coefficient: heat-kernel diffusion coefficient (> 0).

    Returns:
        (N, M) kernel matrix with entries in [0, 1].
    """
    x_norm = jnp.sum(x**2, axis=1, keepdims=True)
    y_norm = jnp.sum(y**2, axis=1, keepdims=True)
    # The expansion ||x||^2 - 2 x.y + ||y||^2 can come out slightly negative
    # through float cancellation (e.g. x_i == y_j far from the origin), which
    # would make K exceed 1; squared distances are clamped at zero.
    dists = jnp.maximum(x_norm - 2 * x @ y.T + y_norm.T, 0.0)
    return jnp.exp(-dists / (4 * diffusion_coefficient))


@jax.jit
def compute_heat_kernel_dx_traj(x_samples, tgt_samples, diffusion_coefficient=1e-3, alpha=0.5):
    """
    Compute heat-kernel-smoothed descent directions for trajectory samples.

    Row i of the result is
        alpha * sum_j w_xt[i, j] (t_j - x_i) - (1 - alpha) * sum_k w_xx[i, k] (x_k - x_i),
    i.e. attraction toward the target samples minus attraction toward the other
    trajectory samples (self-repulsion). The weights are heat kernels
    exp(-||a - b||^2 / (4 * diffusion_coefficient)), normalised over each row.

    Args:
        x_samples: (N, d) trajectory positions.
        tgt_samples: (M, d) target samples.
        diffusion_coefficient: heat-kernel diffusion coefficient (> 0).
        alpha: weight of target attraction versus self-repulsion (default 0.5).

    Returns:
        (N, d) array of descent directions.
    """
    # Pairwise displacements: delta_xt[i, j] = t_j - x_i, delta_xx[i, k] = x_k - x_i.
    delta_xt = tgt_samples[None, :, :] - x_samples[:, None, :]
    delta_xx = x_samples[None, :, :] - x_samples[:, None, :]

    # Row-normalised heat kernels, computed as a softmax over log-kernel values.
    # Normalising exp(-d^2 / (4D)) directly underflows to 0 / 0 = NaN when a
    # sample is far from every target relative to sqrt(D). The softmax subtracts
    # the row maximum first, so it yields the same weights without underflow.
    scale = 4.0 * diffusion_coefficient
    w_xt = jax.nn.softmax(-jnp.sum(delta_xt**2, axis=-1) / scale, axis=1)
    w_xx = jax.nn.softmax(-jnp.sum(delta_xx**2, axis=-1) / scale, axis=1)

    # Kernel-weighted mean displacements.
    grad_xt = jnp.sum(w_xt[:, :, None] * delta_xt, axis=1)
    grad_xx = jnp.sum(w_xx[:, :, None] * delta_xx, axis=1)
    return alpha * grad_xt - (1 - alpha) * grad_xx

In [10]:
# Sanity check on the planar (x, y) projection: there is one descent direction
# per trajectory time step, and each direction has the dimension of the inputs
# (2 here), not the full 6-D state.
heat_kernel_dx_traj = compute_heat_kernel_dx_traj(x_traj[:, :2], tgt_samples[:, :2])
heat_kernel_dx_traj = np.array(heat_kernel_dx_traj)
assert heat_kernel_dx_traj.shape[0] == x_traj.shape[0]
print(
    f'heat_kernel_dx_traj.shape: {heat_kernel_dx_traj.shape}, x_traj.shape: {x_traj.shape}')

heat_kernel_dx_traj .shape: (600, 2) == x_traj.shape: (600, 6)


In [11]:
# Solve the flow matching ergodic coverage problem by iterative refinement:
# roll out the dynamics, compute heat-kernel descent directions on the
# positions, solve the LQ subproblem for a control update, and step along it.
z0 = jnp.zeros(6)
num_pos_dims = 3  # heat-kernel directions act only on the (x, y, z) states
x_traj_list = []
for i in tqdm(range(num_iters)):
    x_traj, A_traj, B_traj = linearize_dyn(x0, u_traj)
    pos_dx_traj = np.array(compute_heat_kernel_dx_traj(
        x_traj[:, :num_pos_dims], tgt_samples[:, :num_pos_dims],
        diffusion_coefficient=diffusion_coefficient
    ))
    # Zero-pad the remaining state dimensions (psi, phi, v) so the direction
    # has the full state dimension; sized from the current trajectory rather
    # than from another cell's variables.
    zeros = np.zeros((pos_dx_traj.shape[0], x_traj.shape[1] - num_pos_dims))
    heat_kernel_dx_traj = np.hstack((pos_dx_traj, zeros))
    v_traj, z_traj = solve_lqr(z0, A_traj, B_traj, heat_kernel_dx_traj)
    u_traj += step_size * v_traj
    # Store the trajectory *before* this iteration's update (for the animation).
    x_traj_list.append(np.array(x_traj))
# Roll out the final controls so the last stored trajectory reflects them.
final_x_traj = aircraft_lqr.traj_sim(x0, u_traj)
x_traj_list.append(final_x_traj)
x_traj_list = np.array(x_traj_list)

  0%|          | 0/600 [00:00<?, ?it/s]

In [12]:
# Trajectory obtained from the final control sequence.
visualize_3d(x0, x_traj=final_x_traj, tgt_samples=tgt_samples_dense)

In [13]:
def visualize_3d_animation(x0, x_traj_list, tgt_samples):
    """
    Animate the trajectory across solver iterations with Plotly.

    The figure shows the target samples, a quadcopter glyph at ``x0[:3]`` and
    one trajectory per animation frame. A slider and Start / Pause / Reset
    buttons step through the frames; the figure opens on the last trajectory.

    Args:
        x0: initial state; only entries 0-2 (x, y, z) are drawn.
        x_traj_list: sequence of trajectories (one per stored iteration);
            columns 0-2 of each are plotted as x, y, z.
        tgt_samples: target point cloud; columns 0-2 are plotted as x, y, z.
    """
    import numpy as np
    import plotly.graph_objects as go

    traj_color = "#ff7f0e"
    sample_color = "#9467bd"
    quad_color = "black"
    arm_length = 0.08
    px, py, pz = x0[0], x0[1], x0[2]

    fig = go.Figure()
    fig.update_layout(
        margin=dict(l=10, r=10, t=30, b=60),
        plot_bgcolor="white",
        width=800,    # 800 x 600 canvas
        height=600
    )

    # Trace 0: target samples. Trace 1: trajectory, the only trace the
    # animation frames replace (see `traces=[1]` below).
    fig.add_trace(go.Scatter3d(
        x=tgt_samples[:, 0], y=tgt_samples[:, 1], z=tgt_samples[:, 2],
        mode='markers',
        marker=dict(size=3, color=sample_color, opacity=0.1),
        showlegend=False
    ))
    fig.add_trace(go.Scatter3d(
        x=x_traj_list[-1][:, 0], y=x_traj_list[-1][:, 1], z=x_traj_list[-1][:, 2],
        mode='lines+markers',
        line=dict(color=traj_color),
        marker=dict(size=3, color=traj_color),
        showlegend=False
    ))

    # Quadcopter glyph at the start position: two crossing arms ...
    fig.add_trace(go.Scatter3d(
        x=[px - arm_length, px + arm_length], y=[py, py], z=[pz, pz],
        mode='lines', line=dict(color=quad_color, width=8)
    ))
    fig.add_trace(go.Scatter3d(
        x=[px, px], y=[py - arm_length, py + arm_length], z=[pz, pz],
        mode='lines', line=dict(color=quad_color, width=8)
    ))
    # ... motor hubs at the arm tips ...
    for hub_x, hub_y in ((px, py - arm_length), (px, py + arm_length),
                         (px - arm_length, py), (px + arm_length, py)):
        fig.add_trace(go.Scatter3d(
            x=[hub_x], y=[hub_y], z=[pz],
            mode='markers', marker=dict(color=quad_color, size=5)
        ))
    # ... propeller rings (radius = half an arm) around each motor ...
    prop_radius = arm_length * 0.5
    theta = np.linspace(0, 2*np.pi, 30)
    for dx, dy in ((arm_length, 0), (-arm_length, 0),
                   (0, arm_length), (0, -arm_length)):
        fig.add_trace(go.Scatter3d(
            x=px + dx + prop_radius * np.cos(theta),
            y=py + dy + prop_radius * np.sin(theta),
            z=pz + np.zeros_like(theta),
            mode='lines', line=dict(color=quad_color, width=4)
        ))
    # ... and the body centre.
    fig.add_trace(go.Scatter3d(
        x=[px], y=[py], z=[pz],
        mode='markers', marker=dict(color=quad_color, size=10)
    ))

    # One frame per kept trajectory. The slider and buttons address frames by
    # name, so all of them are built from the same list of names; this keeps
    # them consistent for any `skip`.
    skip = 1
    frame_names = [str(i) for i in range(0, len(x_traj_list), skip)]
    fig.frames = [
        go.Frame(
            name=name,
            data=[go.Scatter3d(
                x=traj[:, 0], y=traj[:, 1], z=traj[:, 2],
                mode='lines+markers',
                line=dict(color=traj_color),
                marker=dict(size=3, color=traj_color),
                showlegend=False
            )],
            traces=[1]
        )
        for name, traj in zip(frame_names, x_traj_list[::skip])
    ]

    # Slider: one step per frame, narrow and centred below the scene.
    steps = [
        dict(method="animate",
             args=[[name], dict(mode="immediate",
                                frame=dict(duration=0, redraw=True),
                                transition=dict(duration=0))],
             label=name)
        for name in frame_names
    ]

    fig.update_layout(
        sliders=[dict(
            active=len(frame_names) - 1,
            y=-0.05,
            x=0.5,
            xanchor="center",
            pad=dict(t=10),
            len=0.5,
            steps=steps,
            currentvalue=dict(prefix="Iteration: ", font=dict(size=12))
        )],
        # Start / Pause / Reset buttons in a centred row below the slider.
        updatemenus=[dict(
            type="buttons",
            direction="left",
            x=0.5,
            y=-0.25,
            xanchor="center",
            yanchor="top",
            pad=dict(r=10, t=10),
            buttons=[
                dict(label="Start",
                     method="animate",
                     args=[None, dict(frame=dict(duration=10, redraw=True),
                                      transition=dict(duration=0),
                                      fromcurrent=True,
                                      mode="immediate")]),
                dict(label="Pause",
                     method="animate",
                     args=[[None], dict(frame=dict(duration=0, redraw=False),
                                        transition=dict(duration=0),
                                        mode="immediate")]),
                dict(label="Reset",
                     method="animate",
                     args=[[frame_names[0]], dict(frame=dict(duration=0, redraw=True),
                                                  transition=dict(duration=0),
                                                  mode="immediate")])
            ]
        )],
        scene=dict(
            camera=dict(eye=dict(x=0.7, y=0.7, z=0.7)),
            xaxis=dict(range=[-0.6, 0.6], showgrid=False, visible=False),
            yaxis=dict(range=[-0.6, 0.6], showgrid=False, visible=False),
            zaxis=dict(range=[-0.6, 0.6], showgrid=False, visible=False),
            aspectmode='cube'
        ),
        paper_bgcolor='white',
        scene_bgcolor='white',
        showlegend=False
    )

    fig.show()
    # To save a standalone HTML file as well, e.g.:
    # fig.write_html(
    #     f"sinkhorn3d_{object}.html",
    #     include_plotlyjs="cdn",
    #     full_html=True,
    #     auto_play=False
    # )

In [14]:
# Animate every second stored iteration to keep the figure light.
visualize_3d_animation(x0, x_traj_list=x_traj_list[::2], tgt_samples=tgt_samples_dense)