# Velocity Verlet Integrator for a Gaussian Target

This notebook demonstrates how to use the velocity Verlet integrator from Blackjax to simulate Hamiltonian dynamics for a Gaussian target distribution.

In [None]:
import matplotlib.pyplot as plt
import numpy as np

import jax
import jax.numpy as jnp
import jax.scipy.stats as stats

import blackjax
from blackjax.mcmc import integrators
from blackjax.mcmc import metrics
from blackjax.types import ArrayTree

# Set random seed for reproducibility
# NOTE(review): nothing in the visible notebook consumes `rng_key` (the
# integrator runs below take no random key) — confirm whether it is needed.
rng_key = jax.random.key(42)

## Define a Gaussian Target Distribution

We'll use a 2D Gaussian distribution as our target. The log-density function is defined as:

In [None]:
# Parameters of the 2D Gaussian target: a zero mean and a covariance with
# unequal variances (1.0 and 2.0) and a positive off-diagonal term (0.5).
mean = jnp.array([0.0, 0.0])
cov = jnp.array(
    [
        [1.0, 0.5],
        [0.5, 2.0],
    ]
)

# Log-density of the Gaussian target
def logdensity_fn(position):
    """Return the log-density of the Gaussian target at ``position``.

    Evaluates the multivariate normal log-pdf using the module-level
    ``mean`` and ``cov``.
    """
    log_prob = stats.multivariate_normal.logpdf(position, mean, cov)
    return log_prob

# Visualize the target distribution
def plot_gaussian():
    """Draw a contour plot of the target density on [-4, 4] x [-4, 4].

    The density is ``exp(logdensity_fn(point))`` evaluated on a 100 x 100
    grid. The figure is shown with ``plt.show()``; nothing is returned.
    """
    x = np.linspace(-4, 4, 100)
    y = np.linspace(-4, 4, 100)
    X, Y = np.meshgrid(x, y)

    # Evaluate the log-density at every grid point with a single vectorized
    # call instead of 10,000 separate JAX calls issued from a Python loop.
    grid_points = jnp.stack([X.ravel(), Y.ravel()], axis=-1)
    log_density = jax.vmap(logdensity_fn)(grid_points)
    Z = np.exp(np.asarray(log_density)).reshape(X.shape)

    plt.figure(figsize=(10, 8))
    plt.contour(X, Y, Z, levels=20)
    plt.colorbar(label='Density')
    plt.xlabel('x')
    plt.ylabel('y')
    plt.title('2D Gaussian Target Distribution')
    plt.axis('equal')
    plt.show()

plot_gaussian()

## Set Up the Velocity Verlet Integrator

Now, let's set up the velocity verlet integrator. We need to:
1. Define a kinetic energy function
2. Create an integrator state
3. Choose the step size and the number of steps (the integrator itself is run in the next section)

In [None]:
# Integration settings: step size and number of integration steps
step_size = 0.1
num_steps = 50

# Starting point of the trajectory in phase space
initial_position = jnp.array([2.0, 2.0])
initial_momentum = jnp.array([0.5, -0.3])

# Identity inverse mass matrix, wrapped in a Blackjax metric object
inverse_mass_matrix = jnp.eye(2)
metric = metrics.default_metric(inverse_mass_matrix)

# The metric supplies the kinetic energy function used by the integrator
kinetic_energy_fn = metric.kinetic_energy

# Velocity Verlet integrator for the target log-density
integrator = integrators.velocity_verlet(logdensity_fn, kinetic_energy_fn)

# Integrator state built from the initial position and momentum
initial_state = integrators.new_integrator_state(
    logdensity_fn, initial_position, initial_momentum
)

## Run the Integrator and Visualize the Trajectory

Now, let's run the integrator for a few steps and visualize the trajectory:

In [None]:
# Function to run the integrator for multiple steps
def run_integrator(initial_state, integrator, step_size, num_steps):
    """Apply ``integrator`` ``num_steps`` times starting from ``initial_state``.

    Args:
        initial_state: state passed to the first integrator call.
        integrator: callable ``(state, step_size) -> new_state``.
        step_size: step size forwarded to every integrator call.
        num_steps: number of integrator applications.

    Returns:
        The per-step outputs collected by ``jax.lax.scan``: the state after
        each step, stacked along a new leading axis of length ``num_steps``.
        The initial state itself is not part of the output.
    """

    def advance(carry, _unused):
        nxt = integrator(carry, step_size)
        # The new state is both the next carry and this step's recorded output.
        return nxt, nxt

    trajectory = jax.lax.scan(advance, initial_state, xs=None, length=num_steps)[1]
    return trajectory

# Run the integrator
states = run_integrator(initial_state, integrator, step_size, num_steps)

# Extract positions and momenta.
# `jax.lax.scan` returns a single state whose fields are stacked over the
# steps, so the per-step arrays are read directly from its attributes.
# (Iterating over `states` would walk the fields of the state tuple, not the
# time steps, and fail on `.position`.)
positions = states.position
momenta = states.momentum

# Visualize the trajectory
def plot_trajectory(positions, momenta):
    """Plot a trajectory in position space, momentum space, and over time.

    Args:
        positions: array indexed as ``[step, coordinate]`` with at least two
            coordinates, one row per recorded state.
        momenta: array indexed like ``positions``.

    The time axes use the module-level ``step_size``. The figure is shown with
    ``plt.show()``; nothing is returned.
    """
    x = np.linspace(-4, 4, 100)
    y = np.linspace(-4, 4, 100)
    X, Y = np.meshgrid(x, y)

    # Evaluate the log-density at every grid point with a single vectorized
    # call instead of 10,000 separate JAX calls issued from a Python loop.
    grid_points = jnp.stack([X.ravel(), Y.ravel()], axis=-1)
    log_density = jax.vmap(logdensity_fn)(grid_points)
    Z = np.exp(np.asarray(log_density)).reshape(X.shape)

    # The time axis is sized from the data rather than a global step count.
    # Assumes row k holds the state after k + 1 integrator steps (the scan in
    # run_integrator records the state after each step), so time starts at
    # step_size rather than 0.
    time = np.arange(1, len(positions) + 1) * step_size

    plt.figure(figsize=(12, 10))

    # Plot the target distribution with the position trajectory on top
    plt.subplot(2, 2, 1)
    plt.contour(X, Y, Z, levels=20)
    plt.colorbar(label='Density')
    plt.plot(positions[:, 0], positions[:, 1], 'r-', label='Trajectory')
    plt.plot(positions[0, 0], positions[0, 1], 'go', label='Start')
    plt.plot(positions[-1, 0], positions[-1, 1], 'bo', label='End')
    plt.xlabel('x')
    plt.ylabel('y')
    plt.title('Position Trajectory')
    plt.legend()
    plt.axis('equal')

    # Plot the momentum trajectory
    plt.subplot(2, 2, 2)
    plt.plot(momenta[:, 0], momenta[:, 1], 'b-', label='Trajectory')
    plt.plot(momenta[0, 0], momenta[0, 1], 'go', label='Start')
    plt.plot(momenta[-1, 0], momenta[-1, 1], 'ro', label='End')
    plt.xlabel('p_x')
    plt.ylabel('p_y')
    plt.title('Momentum Trajectory')
    plt.legend()
    plt.axis('equal')

    # Plot position vs time
    plt.subplot(2, 2, 3)
    plt.plot(time, positions[:, 0], 'r-', label='x')
    plt.plot(time, positions[:, 1], 'b-', label='y')
    plt.xlabel('Time')
    plt.ylabel('Position')
    plt.title('Position vs Time')
    plt.legend()

    # Plot momentum vs time
    plt.subplot(2, 2, 4)
    plt.plot(time, momenta[:, 0], 'r-', label='p_x')
    plt.plot(time, momenta[:, 1], 'b-', label='p_y')
    plt.xlabel('Time')
    plt.ylabel('Momentum')
    plt.title('Momentum vs Time')
    plt.legend()

    plt.tight_layout()
    plt.show()

plot_trajectory(positions, momenta)

## Energy Conservation

One of the key properties of Hamiltonian dynamics is energy conservation. Let's check if our integrator preserves energy:

In [None]:
# Calculate the total energy (potential + kinetic) at each step
def calculate_energy(states):
    """Return the potential, kinetic and total energy of every recorded state.

    Args:
        states: integrator states as returned by ``run_integrator``, i.e. one
            state whose ``logdensity`` and ``momentum`` fields are stacked
            along the leading (step) axis.

    Returns:
        A tuple ``(potential_energy, kinetic_energy, total_energy)`` of arrays
        with one entry per step.
    """
    # Potential energy is the negative log-density. The stacked field is read
    # directly: iterating over `states` would walk the fields of the state
    # tuple, not the time steps.
    potential_energy = -states.logdensity

    # Kinetic energy, evaluated for every step's momentum in one call
    kinetic_energy = jax.vmap(kinetic_energy_fn)(states.momentum)

    # Total energy
    total_energy = potential_energy + kinetic_energy

    return potential_energy, kinetic_energy, total_energy

potential_energy, kinetic_energy, total_energy = calculate_energy(states)

# Plot the energy components
plt.figure(figsize=(12, 6))
# The k-th recorded state is the state after k + 1 integrator steps, so the
# time axis starts at step_size rather than 0.
time = np.arange(1, num_steps + 1) * step_size
plt.plot(time, potential_energy, 'r-', label='Potential Energy')
plt.plot(time, kinetic_energy, 'b-', label='Kinetic Energy')
plt.plot(time, total_energy, 'g-', label='Total Energy')
plt.xlabel('Time')
plt.ylabel('Energy')
plt.title('Energy Conservation')
plt.legend()
plt.grid(True)
plt.show()

# Calculate the relative energy error with respect to the energy of the true
# initial state. `total_energy[0]` is already one step into the trajectory, so
# using it as the reference would hide the error made by the first step.
initial_total_energy = -initial_state.logdensity + kinetic_energy_fn(initial_state.momentum)
relative_energy_error = jnp.abs(total_energy - initial_total_energy) / jnp.abs(initial_total_energy)
max_relative_error = jnp.max(relative_energy_error)
print(f"Maximum relative energy error: {max_relative_error:.6f}")

## Time Reversibility

Another important property of Hamiltonian dynamics is time reversibility. Let's check if our integrator is time-reversible by running it forward and then backward:

In [None]:
# Function to run the integrator backward (by negating the step size)
def run_integrator_backward(final_state, integrator, step_size, num_steps):
    """Apply ``integrator`` ``num_steps`` times with the step size negated.

    Args:
        final_state: state passed to the first integrator call.
        integrator: callable ``(state, step_size) -> new_state``.
        step_size: positive step size; each call receives ``-step_size``.
        num_steps: number of integrator applications.

    Returns:
        The per-step outputs collected by ``jax.lax.scan``: the state after
        each backward step, stacked along a new leading axis of length
        ``num_steps``. ``final_state`` itself is not part of the output.
    """

    def rewind(carry, _unused):
        # Negative step size for backward integration
        previous = integrator(carry, -step_size)
        return previous, previous

    trajectory = jax.lax.scan(rewind, final_state, xs=None, length=num_steps)[1]
    return trajectory

# Run the integrator forward
forward_states = run_integrator(initial_state, integrator, step_size, num_steps)
# The scan output is a single state whose fields are stacked over the steps,
# so `forward_states[-1]` would return its last *field*. Slice every field to
# obtain the state at the last *step* instead.
final_state = jax.tree_util.tree_map(lambda leaf: leaf[-1], forward_states)

# Run the integrator backward
backward_states = run_integrator_backward(final_state, integrator, step_size, num_steps)
reversed_state = jax.tree_util.tree_map(lambda leaf: leaf[-1], backward_states)

# Check if we've returned to the initial state
position_error = jnp.linalg.norm(initial_state.position - reversed_state.position)
momentum_error = jnp.linalg.norm(initial_state.momentum - reversed_state.momentum)

print(f"Position error: {position_error:.10f}")
print(f"Momentum error: {momentum_error:.10f}")

# Visualize the forward and backward trajectories
def plot_reversibility(forward_states, backward_states):
    """Overlay the forward and backward position trajectories on the target.

    Args:
        forward_states: integrator states from the forward run; a state whose
            ``position`` field is stacked along the leading (step) axis.
        backward_states: integrator states from the backward run, same layout.

    The figure is shown with ``plt.show()``; nothing is returned.
    """
    # The stacked position arrays are read directly: iterating over the states
    # would walk the fields of the state tuple, not the time steps.
    forward_positions = forward_states.position
    backward_positions = backward_states.position

    x = np.linspace(-4, 4, 100)
    y = np.linspace(-4, 4, 100)
    X, Y = np.meshgrid(x, y)

    # Evaluate the log-density at every grid point with a single vectorized
    # call instead of 10,000 separate JAX calls issued from a Python loop.
    grid_points = jnp.stack([X.ravel(), Y.ravel()], axis=-1)
    log_density = jax.vmap(logdensity_fn)(grid_points)
    Z = np.exp(np.asarray(log_density)).reshape(X.shape)

    plt.figure(figsize=(10, 8))
    plt.contour(X, Y, Z, levels=20)
    plt.colorbar(label='Density')
    plt.plot(forward_positions[:, 0], forward_positions[:, 1], 'r-', label='Forward')
    plt.plot(backward_positions[:, 0], backward_positions[:, 1], 'b--', label='Backward')
    plt.plot(forward_positions[0, 0], forward_positions[0, 1], 'go', label='Start')
    plt.plot(forward_positions[-1, 0], forward_positions[-1, 1], 'bo', label='End')
    plt.plot(backward_positions[-1, 0], backward_positions[-1, 1], 'ro', label='Reversed')
    plt.xlabel('x')
    plt.ylabel('y')
    plt.title('Time Reversibility')
    plt.legend()
    plt.axis('equal')
    plt.show()

plot_reversibility(forward_states, backward_states)

## Equivalence of Mass Matrix and Covariance Matrix

Now, let's demonstrate an interesting property of Hamiltonian dynamics: using a mass matrix that is the inverse of the covariance matrix is equivalent to using a unit mass matrix on a Gaussian with unit covariance.

This is a fundamental insight in Hamiltonian Monte Carlo, as it allows us to transform the problem into a simpler one with isotropic dynamics.

### Mathematical Explanation

For a Gaussian target with covariance matrix $\Sigma$, the log-density is:

$$\log p(x) = -\frac{1}{2}(x-\mu)^T \Sigma^{-1} (x-\mu) + \text{const}$$

If we use a mass matrix $M = \Sigma^{-1}$ (that is, an inverse mass matrix $M^{-1} = \Sigma$), the Hamiltonian becomes, up to an additive constant:

$$H(x, p) = -\log p(x) + \frac{1}{2}p^T M^{-1} p = \frac{1}{2}(x-\mu)^T \Sigma^{-1} (x-\mu) + \frac{1}{2}p^T \Sigma p$$

Now, let's transform the variables:

$$x' = \Sigma^{-1/2}(x-\mu)$$
$$p' = \Sigma^{1/2}p$$

The Hamiltonian in the transformed variables becomes:

$$H(x', p') = \frac{1}{2}x'^T x' + \frac{1}{2}p'^T p'$$

This is the Hamiltonian for a standard Gaussian with unit covariance and unit mass matrix.

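Let's demonstrate this equivalence numerically. In the code, the Cholesky factor $L$ of the covariance ($\Sigma = L L^T$) plays the role of $\Sigma^{1/2}$: we use $x' = L^{-1}(x-\mu)$ and $p' = L^T p$, which gives the same transformed Hamiltonian.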

In [None]:
# Map a phase-space point into the whitened coordinate system
def transform_coordinates(position, momentum, cov):
    """Transform ``position`` and ``momentum`` using the Cholesky factor of ``cov``.

    With ``L = cholesky(cov)``, returns ``inv(L) @ (position - mean)`` and
    ``L.T @ momentum``, where ``mean`` is the module-level target mean.
    """
    chol = jnp.linalg.cholesky(cov)
    chol_inv = jnp.linalg.inv(chol)

    # Centre the position on the target mean before rescaling it
    centered = position - mean
    whitened_position = chol_inv @ centered
    whitened_momentum = chol.T @ momentum

    return whitened_position, whitened_momentum

# Map a whitened phase-space point back to the original coordinate system
def inverse_transform_coordinates(position_transformed, momentum_transformed, cov):
    """Undo the Cholesky-based coordinate transform.

    With ``L = cholesky(cov)``, returns ``L @ position_transformed + mean``
    and ``inv(L).T @ momentum_transformed``, where ``mean`` is the
    module-level target mean.
    """
    chol = jnp.linalg.cholesky(cov)
    chol_inv = jnp.linalg.inv(chol)

    # Rescale, then shift back onto the target mean
    restored_position = chol @ position_transformed + mean
    restored_momentum = chol_inv.T @ momentum_transformed

    return restored_position, restored_momentum

# Log-density of a standard 2D Gaussian (zero mean, identity covariance)
def standard_logdensity_fn(position):
    """Return the log-density of the 2D standard normal at ``position``."""
    zero_mean = jnp.zeros(2)
    unit_cov = jnp.eye(2)
    log_prob = stats.multivariate_normal.logpdf(position, zero_mean, unit_cov)
    return log_prob

# Shared settings for the three runs
step_size = 0.1
num_steps = 50
initial_position = jnp.array([2.0, 2.0])
initial_momentum = jnp.array([0.5, -0.3])

# Initial conditions expressed in the coordinates of the standard Gaussian
transformed_position, transformed_momentum = transform_coordinates(initial_position, initial_momentum, cov)

# 1. Original Gaussian with an identity (inverse) mass matrix
original_metric = metrics.default_metric(jnp.eye(2))
original_integrator = integrators.velocity_verlet(logdensity_fn, original_metric.kinetic_energy)
original_state = integrators.new_integrator_state(logdensity_fn, initial_position, initial_momentum)

# 2. Original Gaussian with inverse mass matrix = covariance matrix.
#    `default_metric` interprets its argument as the *inverse* mass matrix, so
#    passing `cov` corresponds to a mass matrix equal to the inverse covariance.
cov_mass_metric = metrics.default_metric(cov)
cov_mass_integrator = integrators.velocity_verlet(logdensity_fn, cov_mass_metric.kinetic_energy)
cov_mass_state = integrators.new_integrator_state(logdensity_fn, initial_position, initial_momentum)

# 3. Standard Gaussian (unit covariance) with an identity (inverse) mass matrix,
#    started from the transformed initial conditions
standard_metric = metrics.default_metric(jnp.eye(2))
standard_integrator = integrators.velocity_verlet(standard_logdensity_fn, standard_metric.kinetic_energy)
standard_state = integrators.new_integrator_state(standard_logdensity_fn, transformed_position, transformed_momentum)

# Run the integrators
original_states = run_integrator(original_state, original_integrator, step_size, num_steps)
cov_mass_states = run_integrator(cov_mass_state, cov_mass_integrator, step_size, num_steps)
standard_states = run_integrator(standard_state, standard_integrator, step_size, num_steps)

# Extract positions.
# Each scan output is a single state whose fields are stacked over the steps,
# so the per-step positions are read directly from the `position` attribute.
# (Iterating over the states would walk the fields of the state tuple, not the
# time steps.)
original_positions = original_states.position
cov_mass_positions = cov_mass_states.position
standard_positions = standard_states.position

# Transform standard positions back to original space. Only the position
# output is needed, so a zero momentum is passed as a placeholder.
transformed_standard_positions = jax.vmap(
    lambda pos: inverse_transform_coordinates(pos, jnp.zeros_like(pos), cov)[0]
)(standard_positions)

# Compare the trajectories
def plot_equivalent_trajectories():
    """Plot the three position trajectories side by side over the target.

    Reads the module-level ``original_positions``, ``cov_mass_positions`` and
    ``transformed_standard_positions``. The figure is shown with
    ``plt.show()``; nothing is returned.
    """
    x = np.linspace(-4, 4, 100)
    y = np.linspace(-4, 4, 100)
    X, Y = np.meshgrid(x, y)

    # Evaluate the log-density at every grid point with a single vectorized
    # call instead of 10,000 separate JAX calls issued from a Python loop.
    grid_points = jnp.stack([X.ravel(), Y.ravel()], axis=-1)
    log_density = jax.vmap(logdensity_fn)(grid_points)
    Z = np.exp(np.asarray(log_density)).reshape(X.shape)

    # One panel per run. The second run passes the covariance to the metric as
    # the *inverse* mass matrix, and its title says so.
    panels = [
        (original_positions, 'Original Gaussian\nUnit Mass Matrix'),
        (cov_mass_positions, 'Original Gaussian\nInverse Mass Matrix = Covariance'),
        (transformed_standard_positions, 'Standard Gaussian (Transformed)\nUnit Mass Matrix'),
    ]

    plt.figure(figsize=(15, 5))

    for panel_index, (trajectory, title) in enumerate(panels, start=1):
        plt.subplot(1, 3, panel_index)
        plt.contour(X, Y, Z, levels=20)
        plt.plot(trajectory[:, 0], trajectory[:, 1], 'r-', label='Trajectory')
        plt.plot(trajectory[0, 0], trajectory[0, 1], 'go', label='Start')
        plt.plot(trajectory[-1, 0], trajectory[-1, 1], 'bo', label='End')
        plt.xlabel('x')
        plt.ylabel('y')
        plt.title(title)
        plt.legend()
        plt.axis('equal')

    plt.tight_layout()
    plt.show()

plot_equivalent_trajectories()

# Calculate the differences between trajectories
original_cov_mass_diff = jnp.mean(jnp.abs(original_positions - cov_mass_positions))
original_standard_diff = jnp.mean(jnp.abs(original_positions - transformed_standard_positions))
cov_mass_standard_diff = jnp.mean(jnp.abs(cov_mass_positions - transformed_standard_positions))

print(f"Average difference between original and cov_mass trajectories: {original_cov_mass_diff:.6f}")
print(f"Average difference between original and standard trajectories: {original_standard_diff:.6f}")
print(f"Average difference between cov_mass and standard trajectories: {cov_mass_standard_diff:.6f}")

# Plot the differences over time
plt.figure(figsize=(12, 6))
# The k-th recorded position belongs to the state after k + 1 integrator
# steps, so the time axis starts at step_size rather than 0.
time = np.arange(1, num_steps + 1) * step_size
plt.plot(time, jnp.abs(original_positions - cov_mass_positions).mean(axis=1), 'r-', label='Original vs Cov Mass')
plt.plot(time, jnp.abs(original_positions - transformed_standard_positions).mean(axis=1), 'b-', label='Original vs Standard')
plt.plot(time, jnp.abs(cov_mass_positions - transformed_standard_positions).mean(axis=1), 'g-', label='Cov Mass vs Standard')
plt.xlabel('Time')
plt.ylabel('Average Absolute Difference')
plt.title('Trajectory Differences Over Time')
plt.legend()
plt.grid(True)
plt.show()

## Conclusion

In this notebook, we've demonstrated how to use the velocity verlet integrator from Blackjax to simulate Hamiltonian dynamics for a Gaussian target distribution. We've shown that:

1. The integrator can be used to simulate the trajectory of a particle in the potential energy landscape defined by the negative log-density of the target distribution.
2. The integrator approximately conserves energy, with small errors due to the numerical approximation.
3. The integrator is time-reversible, meaning that running it forward and then backward returns to the initial state (up to numerical errors).
4. Using a mass matrix that is the inverse of the covariance matrix is equivalent to using a unit mass matrix on a Gaussian with unit covariance.

These properties make the velocity verlet integrator a good choice for Hamiltonian Monte Carlo, where we want to simulate Hamiltonian dynamics to propose new states in the Markov chain. The equivalence between mass matrices and covariance matrices is particularly useful for designing efficient samplers, as it allows us to transform complex target distributions into simpler ones with isotropic dynamics.