In [1]:
## Uncomment the following to install necessary pip packages into colab
# !pip install -q ipympl jaxtyping

In [2]:
import matplotlib.pyplot as plt
import numpy as np
import jax
import jax.numpy as jnp
import jax.random as jr
from jaxtyping import Float, Array
import ipywidgets as widgets
from IPython.display import display

# Enable custom (third-party) widget rendering, which ipympl figures need, when
# running in Google Colab. Outside Colab the import fails and nothing is done.
try:
    from google.colab import output
except ImportError:
    # Not running in Colab; nothing to enable.
    pass
else:
    # Only the "not in Colab" case is silenced: a failure here means widgets
    # will not render in Colab, which should be visible rather than swallowed.
    output.enable_custom_widget_manager()


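## Define custom energy functions

Both energies score a query $q$ against stored memories $\xi_\mu$. The Gaussian energy is $-\frac{1}{\beta}\log\sum_\mu \exp\!\left(-\frac{\beta}{2}\|q-\xi_\mu\|^2\right)$. The EPA (Epanechnikov) energy replaces the exponential with the truncated quadratic $\left[1-\frac{\beta}{2}\|q-\xi_\mu\|^2\right]_+$, so each memory only influences queries within radius $\sqrt{2/\beta}$.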

In [3]:
def epa_energy(q: Float[Array, "D"],  # query
               Xi: Float[Array, "M D"],  # memories
               beta: float = 0.5,  # beta scaling param
               eps: float = 1e-3,  # numerical stability in the log
               lmda: float = 1e-6  # L2 regularization on the query
               ):
    """Energy of query ``q`` under an Epanechnikov (EPA) kernel over memories ``Xi``.

    Computes::

        E(q) = -(1/beta) * log( sum_mu relu(1 - beta/2 * ||q - Xi_mu||^2) + eps )
               + lmda * ||q||^2

    A memory contributes only while ``||q - Xi_mu||^2 < 2 / beta``; beyond that
    radius the ReLU clips its kernel value to zero.

    Args:
        q: query vector, shape (D,).
        Xi: stored memories, shape (M, D); ``q`` is broadcast against every row.
        beta: scaling parameter; larger values shrink each memory's support.
        eps: added inside the log so the energy stays finite when no memory is
            within range (all kernel values are zero).
        lmda: weight of the L2 penalty on ``q``.

    Returns:
        The scalar energy (0-d array); all kernel values are summed together.
    """
    sq_dists = jnp.sum((q - Xi) ** 2, axis=-1)  # squared distance to each memory
    kernel_vals = jax.nn.relu(1 - 0.5 * beta * sq_dists)
    return -(1 / beta * jnp.log(jnp.sum(kernel_vals) + eps)) + lmda * jnp.sum(q ** 2)


def gaussian_energy(q: Float[Array, "D"], Xi: Float[Array, "M D"], beta: float = 0.5):
    """Energy of query ``q`` under a Gaussian kernel over memories ``Xi``.

    Computes the scaled negative log-sum-exp::

        E(q) = -(1/beta) * logsumexp_mu( -beta/2 * ||q - Xi_mu||^2 )

    Args:
        q: query vector, shape (D,).
        Xi: stored memories, shape (M, D); ``q`` is broadcast against every row.
        beta: scaling parameter; larger values sharpen each memory's basin.

    Returns:
        The scalar energy (0-d array).
    """
    sq_dists = jnp.sum((q - Xi) ** 2, axis=-1)  # squared distance to each memory
    return -1 / beta * jax.nn.logsumexp(-0.5 * beta * sq_dists)

Simple! Let's quickly test this idea.

## Interactive widget: EPA vs. Gaussian kernel energies for memories on the unit circle

In [None]:
%matplotlib widget

# Create initial plot setup with two subplots
fig = plt.figure(figsize=(6, 14))
fig.tight_layout()  # Adjust subplot spacing
fig.subplots_adjust(top=0.85)  # Make room for the suptitle

ax1 = fig.add_subplot(311, projection='3d')  # 3D surface plot
ax2 = fig.add_subplot(312)  # Contour plot
ax3 = fig.add_subplot(313)  # Radial energy plot

# Create grid of points
x = np.linspace(-1.5, 1.5, 200)
y = np.linspace(-1.5, 1.5, 200)
X, Y = np.meshgrid(x, y)
points = jnp.stack([X.flatten(), Y.flatten()], axis=-1)

# Create points on unit circle for radial plot
phis = jnp.linspace(-np.pi, np.pi, 300)
xs_circle = jnp.stack([jnp.cos(phis), jnp.sin(phis)], axis=-1)

beta_init = 7.
M_max = 100
M_init = 8 
SEED = 4
all_memories = jr.normal(jr.PRNGKey(SEED), shape=(M_max, 2))
all_memories = all_memories / jnp.linalg.norm(all_memories, axis=-1, keepdims=True)

# Update function
def _batched_energy(energy_fn, queries, memories, beta):
    """Evaluate ``energy_fn(query, memories, beta)`` for every row of ``queries``.

    Returns a NumPy array with one energy per query, ready for matplotlib.
    """
    energies = jax.vmap(energy_fn, in_axes=(0, None, None))(queries, memories, beta)
    return np.asarray(energies)


def _draw_surface(ax, Z, mems, mem_energies):
    """Draw the energy over the (X, Y) grid as a 3D surface, with memories in red."""
    ax.plot_surface(X, Y, Z, cmap='viridis', alpha=0.8)
    ax.scatter(mems[:, 0], mems[:, 1], mem_energies, color='red', s=100)
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Energy')
    ax.set_title('3D Surface Energy')


def _draw_contour(ax, Z, mems):
    """Draw energy contours with memories, the unit circle, and the phi=0 reference ray."""
    ax.contour(X, Y, Z, levels=30, cmap='viridis')
    ax.scatter(mems[:, 0], mems[:, 1], color='red', s=100)
    # Mark the origin and the phi=0 direction that the radial panel is measured from.
    ax.scatter([0], [0], color='black', s=20, zorder=3.)
    ax.plot([0, 1.5], [0, 0], '--', color='black', alpha=0.5)
    ax.annotate('φ=0', xy=(1.5, 0), xytext=(1.6, 0.1),
                arrowprops=dict(facecolor='black', shrink=0.05, headwidth=4, headlength=4),
                bbox=dict(boxstyle='round,pad=0.5', fc='white', alpha=0.8))
    ax.axis('equal')
    # The memories lie on the unit circle; draw it faintly for reference.
    ax.add_patch(plt.Circle((0, 0), 1, fill=False, color='gray', alpha=0.3))
    ax.set_xlim(x.min(), x.max())
    ax.set_ylim(y.min(), y.max())
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_title('Contour Energy')


def _draw_radial(ax, circle_energies, mem_phis, mem_energies):
    """Plot energy against angle phi along the unit circle, with memories in red."""
    ax.plot(phis, circle_energies)
    ax.scatter(mem_phis, mem_energies, color='red', zorder=3)
    ax.axvline(x=0, color='black', linestyle='--', alpha=0.5)  # phi=0 reference line
    ax.set_xlabel('ϕ (radians)')
    ax.set_ylabel('Energy')
    ax.set_title('Radial Energy')
    # Fit the y-axis to the curve, with a 5% margin of the range on each side.
    lo = float(circle_energies.min())
    hi = float(circle_energies.max())
    margin = 0.05 * (hi - lo)
    ax.set_ylim(lo - margin, hi + margin)


def update_plot(beta, num_mems, energy_fn):
    """Redraw all three panels for the given beta, memory count, and energy function.

    Args:
        beta: scaling parameter forwarded to ``energy_fn``.
        num_mems: how many memories to use, taken from the front of ``all_memories``.
        energy_fn: callable ``(q, Xi, beta) -> scalar energy``; the title reads
            'EPA Energy' when it is ``epa_energy`` and 'Gaussian Energy' otherwise.
    """
    for ax in (ax1, ax2, ax3):
        ax.clear()

    current_mems = all_memories[:num_mems]

    Z = _batched_energy(energy_fn, points, current_mems, beta).reshape(X.shape)
    circle_energies = _batched_energy(energy_fn, xs_circle, current_mems, beta)
    mem_energies = _batched_energy(energy_fn, current_mems, current_mems, beta)
    # Angle of each memory, on the same [-pi, pi] scale as `phis`.
    mem_phis = jnp.arctan2(current_mems[:, 1], current_mems[:, 0])

    _draw_surface(ax1, Z, current_mems, mem_energies)
    _draw_contour(ax2, Z, current_mems)
    _draw_radial(ax3, circle_energies, mem_phis, mem_energies)

    energy_name = 'EPA Energy' if energy_fn == epa_energy else 'Gaussian Energy'
    fig.suptitle(f'{energy_name} (beta={beta:.1f}, memories={num_mems})', y=0.97)

    fig.canvas.draw_idle()

# Controls: an energy-function selector plus sliders for beta and the number of memories.
# The sliders use continuous_update=False, so they report a new value only when
# released; each redraw re-evaluates the energy over the full grid.
energy_dropdown = widgets.Dropdown(
    options=[
        ('EPA Energy', epa_energy),
        ('Gaussian Energy', gaussian_energy),
    ],
    value=epa_energy,
    description='Energy:',
    style=dict(description_width='initial'),
)

beta_slider = widgets.FloatSlider(
    description='Beta:',
    value=beta_init,
    min=0.1, max=40.0, step=0.1,
    continuous_update=False,
)

num_mem_slider = widgets.IntSlider(
    description='Nmems:',
    value=M_init,
    min=1, max=M_max, step=1,
    continuous_update=False,
)


def _redraw_from_widgets(change=None):
    """Redraw the figure from the current values of all three controls.

    Registered as the ``'value'`` observer on every widget. It reads each
    widget's value directly rather than mirroring changes into module-level
    globals, so the plot always matches what the controls show.

    Args:
        change: the traitlets change dict passed by ``observe``. It is unused
            because the observed widget's ``.value`` is already updated when
            observers are notified.
    """
    update_plot(beta_slider.value, num_mem_slider.value, energy_dropdown.value)


# Initial draw. The widgets were created with beta_init, M_init and epa_energy,
# so this matches the starting control values.
_redraw_from_widgets()

beta_slider.observe(_redraw_from_widgets, names='value')
num_mem_slider.observe(_redraw_from_widgets, names='value')
energy_dropdown.observe(_redraw_from_widgets, names='value')

# Display widgets and plot
display(energy_dropdown, beta_slider, num_mem_slider)
plt.show()