In [None]:
import jax.numpy as jnp
from typing import Tuple
import jax

import numpy as np

import matplotlib.pyplot as plt


def _softmin(v: jnp.ndarray, gamma: float, axis: int = -1) -> jnp.ndarray:
  """Smoothed minimum of ``v`` along ``axis``.

  Computes ``-gamma * log(sum(exp(-v / gamma)))``. The sum is routed through
  ``jax.nn.logsumexp``, which subtracts the maximum before exponentiating, so
  large ``|v| / gamma`` does not overflow. As ``gamma`` shrinks the result
  approaches ``min(v)``.

  Args:
    v: Values to reduce.
    gamma: Positive smoothing strength.
    axis: Axis to reduce over (default: last).

  Returns:
    ``v`` with ``axis`` reduced away.
  """
  neg_scaled = -v / gamma
  lse = jax.nn.logsumexp(neg_scaled, axis=axis)
  return -gamma * lse

def _soft_dtw(t1: jnp.ndarray, t2: jnp.ndarray, gamma: float) -> float:
  """Soft-DTW discrepancy between two 1-D series.

  Evaluates the soft-DTW recursion

      R[i, j] = d(i, j) + softmin_gamma(R[i-1, j-1], R[i-1, j], R[i, j-1])

  with ``d(i, j) = (t1[i] - t2[j]) ** 2``. The recursion advances one
  anti-diagonal (cells with equal ``i + j``) per ``jax.lax.scan`` step, so each
  step is a fixed-size vectorised update usable under ``jax.jit``/``jax.vmap``.

  Args:
    t1: First series, shape ``(n,)``.
    t2: Second series, shape ``(m,)``.
    gamma: Smoothing strength forwarded to ``_softmin``.

  Returns:
    ``R[n-1, m-1]`` as a 0-d JAX array.
  """

  def body(
      carry: Tuple[jnp.ndarray, jnp.ndarray],
      current_antidiagonal: jnp.ndarray
  ) -> Tuple[Tuple[jnp.ndarray, jnp.ndarray], jnp.ndarray]:
    # modified from: https://github.com/khdlr/softdtw_jax
    # carry holds the cumulative costs of the two previous anti-diagonals,
    # each left-padded with +inf, so padded index k is anti-diagonal cell k-1.
    two_ago, one_ago = carry

    # Predecessors of cell k on the new anti-diagonal: the diagonal move comes
    # from two anti-diagonals back, the two axis-aligned moves from one back.
    diagonal, right, down = two_ago[:-1], one_ago[:-1], one_ago[1:]
    best = _softmin(
        jnp.stack([diagonal, right, down], axis=-1), gamma, axis=-1
    )

    # Slots outside the cost matrix are +inf in current_antidiagonal and stay
    # +inf here, so they never contribute to a later softmin.
    next_row = best + current_antidiagonal
    next_row = jnp.pad(next_row, (1, 0), constant_values=jnp.inf)

    return (one_ago, next_row), next_row

  # Pointwise squared-difference cost matrix, shape (n, m).
  dist = (t1[:, None] - t2[None, :]) ** 2

  n, m = dist.shape
  # Soft-DTW is unchanged by transposing the cost matrix; make n the longer
  # side so every anti-diagonal fits in a row of length n.
  if n < m:
    dist = dist.T
    n, m = m, n

  if n == 1 and m == 1:
    # Only one cell, so it is the answer. The scan below needs at least two
    # anti-diagonals: with a single row, model_matrix[1] would be clamped back
    # to row 0 by JAX's out-of-bounds indexing and d(0, 0) counted twice.
    return dist[0, 0]

  # Re-lay the costs by anti-diagonal: model_matrix[k, i] = dist[i, k - i]
  # for valid cells, +inf elsewhere. The mask keeps, in row k, the columns i
  # with k - m + 1 <= i <= k (exactly n * m entries).
  model_matrix = jnp.full((n + m - 1, n), fill_value=jnp.inf)
  mask = np.tri(n + m - 1, n, k=0, dtype=bool)
  mask = mask & mask[::-1, ::-1]
  # Work on the transposes so the boolean selection visits cells in the same
  # row-major order as dist.ravel().
  model_matrix = model_matrix.T.at[mask.T].set(dist.ravel()).T

  # Seed with the first two anti-diagonals. Their cells have a single
  # predecessor each: R[0, 0] = d(0, 0), and the cells of the second
  # anti-diagonal add R[0, 0] to their own cost.
  init = (
      jnp.pad(model_matrix[0], (1, 0), constant_values=jnp.inf),
      jnp.pad(
          model_matrix[1] + model_matrix[0, 0], (1, 0),
          constant_values=jnp.inf
      )
  )

  (_, carry), _ = jax.lax.scan(body, init, model_matrix[2:])
  # The final anti-diagonal holds only cell (n-1, m-1), at padded index n.
  return carry[-1]

def _debiased_soft_dtw(t1: jnp.ndarray, t2: jnp.ndarray, gamma: float) -> float:
  """Soft-DTW divergence between ``t1`` and ``t2``.

  Returns ``sdtw(t1, t2) - (sdtw(t1, t1) + sdtw(t2, t2)) / 2``. Subtracting
  the mean self-similarity term makes the divergence of a series with itself
  exactly zero.
  """
  cross = _soft_dtw(t1, t2, gamma)
  self_sum = _soft_dtw(t1, t1, gamma) + _soft_dtw(t2, t2, gamma)
  return cross - 0.5 * self_sum

def get_chirped_signal(time, phase_velocity, phase_acceleration):
  """Sample a linear chirp ``sin(v * t + 0.5 * a * t**2)`` at ``time``.

  Args:
    time: Sample times.
    phase_velocity: Linear phase coefficient ``v``.
    phase_acceleration: Quadratic phase coefficient ``a``.
  """
  linear_phase = time * phase_velocity
  quadratic_phase = 0.5 * phase_acceleration * time ** 2
  return jnp.sin(linear_phase + quadratic_phase)

# --- Baseline pointwise loss (for comparison with soft-DTW) ---
def l2_loss(t1: jnp.ndarray, t2: jnp.ndarray) -> float:
  """Mean squared error between two equally shaped arrays.

  Unlike soft-DTW this compares samples strictly index by index, with no
  temporal alignment.
  """
  residual = t1 - t2
  return jnp.mean(residual ** 2)

# --- Parameters ---
# Ground-truth chirp that every candidate signal is compared against.
time = jnp.linspace(0.0, 6.0, num=100)  # kept short so soft-DTW stays cheap
true_velocity = 4.0
true_acceleration = -0.1
true_signal = get_chirped_signal(time, true_velocity, true_acceleration)

# 50 x 50 search grid over (velocity, acceleration). With "xy" indexing (the
# meshgrid default) V and A have shape (len(accelerations), len(velocities)).
velocities = jnp.linspace(3.0, 5.0, num=50)
accelerations = jnp.linspace(-0.2, 0.2, num=50)
V, A = jnp.meshgrid(velocities, accelerations, indexing="xy")

# Soft-DTW smoothing strength; smaller values approach hard-min DTW.
gamma = 0.5

# --- Loss Calculation (Vectorized) ---

# Objective for a single grid point. jit captures the module-level `time`,
# `true_signal` and `gamma` as constants when the function is first traced.
@jax.jit
def compute_losses_for_params(v, a):
    """Return ``(debiased soft-DTW, MSE)`` between ``true_signal`` and the
    chirp generated with phase velocity ``v`` and phase acceleration ``a``.
    """
    candidate = get_chirped_signal(time, v, a)
    return (
        _debiased_soft_dtw(true_signal, candidate, gamma),
        l2_loss(true_signal, candidate),
    )

# Vectorize over the whole 2-D grid: the inner vmap maps across the columns of
# a row and the outer vmap across the rows, so every (V[i, j], A[i, j]) pair is
# evaluated and both outputs come back with the grid's shape.
vectorized_compute_losses = jax.vmap(
    jax.vmap(compute_losses_for_params, in_axes=(0, 0)), in_axes=(0, 0)
)

# Calculate losses over the entire grid. V and A already come from
# jnp.meshgrid, so they are passed directly without an extra jnp.array copy.
sdtw_losses_grid, l2_losses_grid = vectorized_compute_losses(V, A)

# --- Visualization ---
fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# Plot L2 Loss
cf1 = axes[0].contourf(V, A, l2_losses_grid, levels=50, cmap='viridis')
fig.colorbar(cf1, ax=axes[0], label='Loss Value')
axes[0].contour(V, A, l2_losses_grid, levels=10, colors='white', alpha=0.5, linewidths=0.5)
axes[0].plot(true_velocity, true_acceleration, 'ro', markersize=8, label='True Parameters')
axes[0].set_title('L2 (MSE) Loss Landscape')
axes[0].set_xlabel('Phase Velocity')
axes[0].set_ylabel('Phase Acceleration')
axes[0].legend()
axes[0].grid(True, linestyle='--', alpha=0.6)


# Plot Soft-DTW Loss
# Cap the colour scale at the 99th percentile so a few large outliers do not
# flatten the contrast around the minimum. The capped grid is built once and
# shared by contourf and contour. jnp.minimum(x, cap) equals
# jnp.clip(x, a_max=cap) but avoids the a_max keyword, which recent JAX
# releases deprecate in favour of max=.
max_sdtw_plot = jnp.percentile(sdtw_losses_grid, 99)
sdtw_plot_grid = jnp.minimum(sdtw_losses_grid, max_sdtw_plot)
cf2 = axes[1].contourf(V, A, sdtw_plot_grid, levels=50, cmap='viridis')
fig.colorbar(cf2, ax=axes[1], label='Loss Value')
axes[1].contour(V, A, sdtw_plot_grid, levels=10, colors='white', alpha=0.5, linewidths=0.5)
axes[1].plot(true_velocity, true_acceleration, 'ro', markersize=8, label='True Parameters')
axes[1].set_title(f'Debiased Soft-DTW Loss Landscape (gamma={gamma})')
axes[1].set_xlabel('Phase Velocity')
axes[1].set_ylabel('Phase Acceleration')
axes[1].legend()
axes[1].grid(True, linestyle='--', alpha=0.6)


plt.tight_layout()
plt.show()

# --- Comparison Summary ---

def _argmin_2d(grid):
    """Return the (row, col) index tuple of the smallest entry of ``grid``."""
    return jnp.unravel_index(jnp.argmin(grid), grid.shape)

# Locate each landscape's minimum on the (V, A) grid.
min_l2_idx = _argmin_2d(l2_losses_grid)
min_sdtw_idx = _argmin_2d(sdtw_losses_grid)

print("\nComparison Summary:")
print("-" * 20)
print(f"True Parameters: Velocity={true_velocity}, Acceleration={true_acceleration}")
print(f"L2 Loss Minimum Found at: Velocity={V[min_l2_idx]:.3f}, Acceleration={A[min_l2_idx]:.3f} (Value: {l2_losses_grid[min_l2_idx]:.4f})")
print(f"Soft-DTW Loss Minimum Found at: Velocity={V[min_sdtw_idx]:.3f}, Acceleration={A[min_sdtw_idx]:.3f} (Value: {sdtw_losses_grid[min_sdtw_idx]:.4f})")

Soft-DTW(y1, y2) = 25.935562133789062
Soft-DTW(y1, y3) = 41.95087432861328
Soft-DTW(y1, y4) = 3.7997639179229736
Soft-DTW(y1, y5) = 0.37580013275146484
L2(y1, y2) = 12.967781066894531
L2(y1, y3) = 20.97543716430664
L2(y1, y4) = 1.8998819589614868
L2(y1, y5) = 0.18790006637573242
