In [8]:
import jax
import jax.numpy as jnp
import jax.scipy.special
import numpy as np
from typing import Tuple, Callable

import os
# only use gpu 4
os.environ["CUDA_VISIBLE_DEVICES"] = "4" 

# Define the L2 (Squared Euclidean) ground cost function
def l2_ground_cost(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
  """Computes the squared Euclidean distance between two sets of vectors.

  Args:
    x: Array of shape (n, d)
    y: Array of shape (m, d)

  Returns:
    Cost matrix of shape (n, m) where C[i, j] = ||x[i] - y[j]||^2
  """
  return jnp.sum((x[:, None, :] - y[None, :, :])**2, axis=-1)

# Define the soft-min operator using logsumexp
def _softmin(x: jnp.ndarray, gamma: float, axis: int = -1) -> jnp.ndarray:
  """Computes the soft-minimum of an array along a given axis.

  Args:
    x: Input array.
    gamma: Smoothing parameter > 0.
    axis: Axis along which to compute the soft-minimum.

  Returns:
    The soft-minimum of x along the specified axis.
  """
  return -gamma * jax.scipy.special.logsumexp(-x / gamma, axis=axis)

# Core SoftDTW computation (internal helper function)
def _compute_soft_dtw(
    t1: jnp.ndarray,
    t2: jnp.ndarray,
    gamma: float,
    ground_cost_fn: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]
) -> float:
  """Computes the core Soft-DTW value between two time series.

  Args:
    t1: First time series, shape (n,) or (n, d).
    t2: Second time series, shape (m,) or (m, d).
    gamma: Smoothing parameter > 0.
    ground_cost_fn: Function to compute the pairwise cost matrix.

  Returns:
    The Soft-DTW value.
  """

  # --- Scan Body Definition ---
  def _body_fn(
      carry: Tuple[jnp.ndarray, jnp.ndarray],
      current_antidiagonal: jnp.ndarray
  ) -> Tuple[Tuple[jnp.ndarray, jnp.ndarray], jnp.ndarray]:
    """Body function for the jax.lax.scan over anti-diagonals."""
    two_ago, one_ago = carry # Values from anti-diagonals k-2 and k-1

    # Values needed for the current DP update step (indices relative to next_row)
    # diagonal corresponds to V(i-1, j-1) -> element two_ago[k] for next_row[k]
    # right corresponds to V(i, j-1)    -> element one_ago[k] for next_row[k]
    # down corresponds to V(i-1, j)     -> element one_ago[k+1] for next_row[k]
    diagonal = two_ago[:-1] # Shape (k,)
    right = one_ago[:-1]    # Shape (k,)
    down = one_ago[1:]      # Shape (k,)

    # Stack costs from possible previous states for soft-min calculation
    # Shape (k, 3)
    stacked_costs = jnp.stack([diagonal, right, down], axis=-1)

    # Apply soft-min to find the best path cost so far
    # Shape (k,)
    best_prev_cost = _softmin(stacked_costs, gamma, axis=-1)

    # Add the ground cost from the current anti-diagonal
    # Shape (k,)
    next_row_core = best_prev_cost + current_antidiagonal

    # Pad with infinity for boundary conditions in the next iteration
    # Shape (k+1+1,) -> (k+2,) which matches the expected length of next_row
    next_row = jnp.pad(next_row_core, (1, 0), constant_values=jnp.inf)

    return (one_ago, next_row), next_row # Pass previous and current row for next step

  # --- Preprocessing ---
  # Ensure inputs are at least 2D (n_timepoints, n_features)
  t1 = jnp.atleast_2d(t1)
  t2 = jnp.atleast_2d(t2)
  if t1.ndim == 1:
      t1 = t1[:, None]
  if t2.ndim == 1:
      t2 = t2[:, None]

  # Calculate the pairwise ground cost matrix
  dist = ground_cost_fn(t1, t2)
  n, m = dist.shape

  # Ensure n >= m for the anti-diagonal scan setup
  swapped = False
  if n < m:
    dist = dist.T
    n, m = m, n
    swapped = True # Keep track if we swapped

  # --- Anti-diagonal Matrix Construction ---
  # Create a matrix where anti-diagonals of 'dist' become rows
  # Size: (n + m - 1) anti-diagonals, max length n
  model_matrix = jnp.full((n + m - 1, n), fill_value=jnp.inf)

  # Use np.tri to create boolean masks for placing elements
  # Note: Using np here is acceptable as it generates static masks based on shape
  mask_bottom_right = np.tri(n + m - 1, n, k=0, dtype=bool)
  mask_top_left = np.tri(n + m - 1, n, k=(n + m - 1) - m, dtype=bool) # Equivalent to mask_bottom_right[::-1, ::-1] but direct
  mask = mask_bottom_right & mask_top_left # Valid entries for anti-diagonals

  # Place the flattened distance matrix onto the valid spots in the model_matrix
  # Transpose mask to match assignment pattern if assigning to model_matrix.T
  model_matrix = model_matrix.T.at[mask.T].set(dist.ravel()).T

  # --- Scan Initialization ---
  # Initial 'carry' state for the scan: (V_{k=-1}, V_{k=0}) effectively
  # Requires careful padding based on how the scan body uses indices.
  # init_two_ago corresponds to the 'zeroth' anti-diagonal (padded)
  # init_one_ago corresponds to the first anti-diagonal V(i,j) where i+j=1 (plus costs)
  init_two_ago = jnp.pad(model_matrix[0], (1, 0), constant_values=jnp.inf)
  # The first step combines model_matrix[0,0] and model_matrix[1]
  init_one_ago = jnp.pad(
      model_matrix[1] + model_matrix[0, 0], # Cost V(1,0) and V(0,1)
      (1, 0),
      constant_values=jnp.inf
  )
  init = (init_two_ago, init_one_ago)

  # --- Run the Scan ---
  # Iterate over the remaining anti-diagonals (from k=2 upwards)
  (_, final_row), _ = jax.lax.scan(
      _body_fn, init, model_matrix[2:], # Pass remaining anti-diagonal costs
  )

  # The final value is the last element of the last computed row (V[n, m])
  # Adjust index if we swapped n and m earlier (though the value should be the same)
  result = final_row[-1]

  return result


# Main public function for SoftDTW
def soft_dtw(
    x: jnp.ndarray,
    y: jnp.ndarray,
    gamma: float = 1.0,
    debiased: bool = False
) -> float:
  """Computes the Soft-DTW divergence between two time series.

  Args:
    x: First time series, shape (n,) or (n, d).
    y: Second time series, shape (m,) or (m, d).
    gamma: Smoothing parameter > 0. Defaults to 1.0.
    debiased: Whether to compute the debiased Soft-DTW. Defaults to False.

  Returns:
    The Soft-DTW value (potentially debiased).
  """
  # Gamma must be positive.

  # Use the L2 ground cost
  ground_cost_fn = l2_ground_cost

  # Compute main Soft-DTW(x, y)
  val_xy = _compute_soft_dtw(x, y, gamma, ground_cost_fn)

  if debiased:
    # Compute Soft-DTW(x, x)
    val_xx = _compute_soft_dtw(x, x, gamma, ground_cost_fn)
    # Compute Soft-DTW(y, y)
    val_yy = _compute_soft_dtw(y, y, gamma, ground_cost_fn)
    # Return debiased value
    return val_xy - 0.5 * (val_xx + val_yy)
  else:
    # Return standard value
    return val_xy

# Example Usage:
if __name__ == "__main__":
    x = jnp.linspace(0, 1, 100)
    y1 = jnp.sin(x)
    y2 = jnp.sin(x + 0.5)
    y3 = jnp.sin(x + 0.7)
    y4 = jnp.sin(0.7 * x)
    y5 = jnp.sin(0.9 * x)	

    dtw_1 = soft_dtw(y1, y2, gamma=0.1)
    dtw_2 = soft_dtw(y1, y3, gamma=0.1)
    dtw_3 = soft_dtw(y1, y4, gamma=0.1)
    dtw_4 = soft_dtw(y1, y5, gamma=0.1)

    l2_1 = jnp.sum((y1 - y2)**2)
    l2_2 = jnp.sum((y1 - y3)**2)
    l2_3 = jnp.sum((y1 - y4)**2)
    l2_4 = jnp.sum((y1 - y5)**2)

    print(f"Soft-DTW(y1, y2) = {dtw_1}")
    print(f"Soft-DTW(y1, y3) = {dtw_2}")
    print(f"Soft-DTW(y1, y4) = {dtw_3}")
    print(f"Soft-DTW(y1, y5) = {dtw_4}")

    print(f"L2(y1, y2) = {l2_1}")
    print(f"L2(y1, y3) = {l2_2}")
    print(f"L2(y1, y4) = {l2_3}")
    print(f"L2(y1, y5) = {l2_4}")
    


Soft-DTW(y1, y2) = 25.935562133789062
Soft-DTW(y1, y3) = 41.95087432861328
Soft-DTW(y1, y4) = 3.7997639179229736
Soft-DTW(y1, y5) = 0.37580013275146484
L2(y1, y2) = 12.967781066894531
L2(y1, y3) = 20.97543716430664
L2(y1, y4) = 1.8998819589614868
L2(y1, y5) = 0.18790006637573242
