GOTO-SWAP: Gradient-based Optimal Transport Optimization -- Sliced Wasserstein API for Pipelines
A unified, lightweight API for sliced Wasserstein losses on images, with interchangeable JAX and PyTorch backends. The package makes it easy to use sliced Wasserstein distances as differentiable loss terms in your machine learning pipelines.
- 🚀 Multi-Backend Support: Choose between JAX and PyTorch backends
- ⚡ High Performance: JIT compilation with JAX, optimized PyTorch operations
- 🔧 Easy Integration: Simple API that works with existing ML workflows
- 📦 Lightweight: Optional dependencies - only install what you need
- 🧮 Differentiable: Full automatic differentiation support
- 🎯 Specialized: 1D Earth Mover's Distance (EMD) implementations
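To make the idea of a sliced Wasserstein image loss concrete, here is a minimal NumPy sketch of one common formulation. It is not the package's implementation. It assumes each image is a non-negative intensity map normalized to unit mass, uses evenly spaced projection angles as a stand-in for `num_rotations`, bins the projected mass into a 1D histogram (a discretized Radon projection), and averages the resulting 1D W1 distances. The helper name `sliced_w1_images_sketch` and its `num_bins` parameter are hypothetical.

```python
import numpy as np


def sliced_w1_images_sketch(img_a, img_b, num_rotations=100, num_bins=64):
    """Sliced 1-Wasserstein distance between two non-negative 2D images (hypothetical helper)."""
    a = np.asarray(img_a, dtype=float)
    b = np.asarray(img_b, dtype=float)
    if a.shape != b.shape or a.ndim != 2:
        raise ValueError("expected two 2D images of the same shape")
    a = a / a.sum()  # treat each image as a unit-mass distribution over its pixels
    b = b / b.sum()
    h, w = a.shape
    ys, xs = np.mgrid[0:h, 0:w].astype(float)
    xs -= (w - 1) / 2.0  # center the pixel grid on the origin
    ys -= (h - 1) / 2.0
    # Half-pixel margin so no projected pixel falls outside the histogram range
    radius = np.hypot((w - 1) / 2.0, (h - 1) / 2.0) + 0.5
    edges = np.linspace(-radius, radius, num_bins + 1)
    dx = edges[1] - edges[0]
    # Directions in [0, pi) suffice: theta and theta + pi give mirrored projections
    angles = np.linspace(0.0, np.pi, num_rotations, endpoint=False)
    total = 0.0
    for theta in angles:
        t = xs * np.cos(theta) + ys * np.sin(theta)  # projected coordinate of every pixel
        pa, _ = np.histogram(t, bins=edges, weights=a)  # 1D projected mass of each image
        pb, _ = np.histogram(t, bins=edges, weights=b)
        # In 1D, W1 is the integral of |CDF_a - CDF_b|
        total += np.sum(np.abs(np.cumsum(pa) - np.cumsum(pb))) * dx
    return total / num_rotations


# Usage: a single bright pixel versus the same pixel shifted 8 pixels along x
a = np.zeros((32, 32))
a[16, 8] = 1.0
b = np.zeros((32, 32))
b[16, 16] = 1.0
print(sliced_w1_images_sketch(a, b))
```

For a pure translation by 8 pixels the continuous sliced W1 is 8 times the average of |cos θ|, about 5.09; the bin width bounds the error of each 1D distance, so in this sketch both `num_bins` and the number of angles trade accuracy for speed.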
Installation:
```bash
pip install goto-swap
```
JAX Backend:
```bash
pip install "goto-swap[jax]"
```
PyTorch Backend:
```bash
pip install "goto-swap[pytorch]"
```
All Backends:
```bash
pip install "goto-swap[all]"
```
Quick start:
```python
import numpy as np
from goto_swap import wasserstein_distance_loss_images

# Prepare your data
x = np.random.randn(2, 32, 32)       # Predicted images
x_true = np.random.randn(2, 32, 32)  # True images

# JAX backend (default)
loss = wasserstein_distance_loss_images(x, x_true)

# PyTorch backend
loss = wasserstein_distance_loss_images(x, x_true, backend="pytorch")
```
1D Earth Mover's Distance:
```python
import jax.numpy as jnp
from goto_swap import emd_1d_sorted_differentiable

# Sorted positions and their weights (each weight vector sums to 1)
u = jnp.array([1.0, 2.0, 3.0])          # Source positions
v = jnp.array([1.5, 2.5, 3.5])          # Target positions
u_weights = jnp.array([0.3, 0.4, 0.3])  # Source weights
v_weights = jnp.array([0.2, 0.5, 0.3])  # Target weights

# Compute the EMD (default backend: JAX)
emd_loss = emd_1d_sorted_differentiable(u, v, u_weights, v_weights)
```
Backend selection:
```python
import numpy as np
from goto_swap import get_backend

x = np.random.randn(2, 32, 32)       # Predicted images
x_true = np.random.randn(2, 32, 32)  # True images

# Get a specific backend
jax_backend = get_backend("jax")
pytorch_backend = get_backend("pytorch")

# Use the backend directly
loss_jax = jax_backend.wasserstein_distance_loss_images(x, x_true)
loss_pytorch = pytorch_backend.wasserstein_distance_loss_images(x, x_true)

# Direct function imports
from goto_swap.backends.jax import emd_1d_sorted_jax_differentiable
from goto_swap.backends.pytorch import emd_1d_sorted_pytorch_differentiable
```
API reference:
- `wasserstein_distance_loss_images(x, x_true, num_rotations=100, backend="jax")`: compute the sliced Wasserstein distance between images.
- `emd_1d_sorted_differentiable(u, v, u_weights, v_weights, metric='sqeuclidean', tol=1e-12, backend="jax")`: compute the 1D Earth Mover's Distance between sorted, weighted samples.
- `get_backend(backend_name)`: get a backend instance for advanced usage.
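The 1D EMD between sorted, weighted samples has a closed-form solution: for convex costs such as the squared distance, the optimal plan pairs the two distributions by quantile (a monotone coupling). The NumPy sketch below illustrates that idea for a `sqeuclidean` and a `euclidean` metric. It is an illustration, not the package's implementation; `emd_1d_sorted_sketch` is a hypothetical name, and the sketch assumes non-negative weights, normalizing each weight vector to sum to 1.

```python
import numpy as np


def emd_1d_sorted_sketch(u, v, u_weights, v_weights, metric="sqeuclidean"):
    """Exact 1D optimal transport cost between two weighted, sorted point sets (hypothetical helper)."""
    u = np.asarray(u, dtype=float)
    v = np.asarray(v, dtype=float)
    a = np.asarray(u_weights, dtype=float)
    b = np.asarray(v_weights, dtype=float)
    a = a / a.sum()  # normalize each side to unit mass
    b = b / b.sum()
    cu = np.cumsum(a)
    cv = np.cumsum(b)
    cu[-1] = cv[-1] = 1.0  # guard against round-off in the last cumulative weight
    # Merge the cumulative-weight breakpoints of both distributions
    levels = np.unique(np.concatenate(([0.0], cu, cv)))
    widths = np.diff(levels)  # mass transported on each quantile interval
    mids = 0.5 * (levels[:-1] + levels[1:])
    # Atom of u (resp. v) that owns each quantile interval
    iu = np.minimum(np.searchsorted(cu, mids), len(u) - 1)
    iv = np.minimum(np.searchsorted(cv, mids), len(v) - 1)
    diff = u[iu] - v[iv]
    if metric == "sqeuclidean":
        cost = diff ** 2
    elif metric == "euclidean":
        cost = np.abs(diff)
    else:
        raise ValueError(f"unsupported metric: {metric!r}")
    return float(np.sum(widths * cost))


cost = emd_1d_sorted_sketch([1.0, 2.0, 3.0], [1.5, 2.5, 3.5], [0.3, 0.4, 0.3], [0.2, 0.5, 0.3])
```

With the same inputs as the 1D EMD example above, this sketch returns 0.45: the merged quantile intervals carry masses 0.2, 0.1, 0.4 and 0.3 with squared gaps 0.25, 2.25, 0.25 and 0.25. Whether the package reports exactly this quantity (for example, with or without a square root) depends on its own conventions.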
| Backend | JIT Compilation | GPU Support | TPU Support | Best For |
|---|---|---|---|---|
| JAX | ✅ Native | ✅ | ✅ | Research, high-performance computing |
| PyTorch | ❌ No | ✅ | ❌ | Integration with PyTorch workflows |
JIT compilation with JAX:
```python
import jax
from goto_swap import get_backend

# Get the backend and JIT-compile for performance
backend = get_backend("jax")
jit_loss = jax.jit(backend.wasserstein_distance_loss_images, static_argnames=["num_rotations"])

# Use the compiled function
x = jax.random.normal(jax.random.PRNGKey(0), (2, 32, 32))
x_true = jax.random.normal(jax.random.PRNGKey(1), (2, 32, 32))
loss = jit_loss(x, x_true, num_rotations=100)
```
Gradients with PyTorch:
```python
import torch
from goto_swap import emd_1d_sorted_differentiable

u = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
v = torch.tensor([1.5, 2.5, 3.5])
u_weights = torch.tensor([0.3, 0.4, 0.3])
v_weights = torch.tensor([0.2, 0.5, 0.3])

loss = emd_1d_sorted_differentiable(u, v, u_weights, v_weights, backend="pytorch")
loss.backward()  # Compute gradients
print(u.grad)    # Gradient of the loss with respect to the source positions
```
Batch processing with `jax.vmap`:
```python
import jax
from goto_swap import wasserstein_distance_loss_images

# Process multiple image pairs
batch_x = jax.random.normal(jax.random.PRNGKey(0), (10, 32, 32))  # 10 images
batch_x_true = jax.random.normal(jax.random.PRNGKey(1), (10, 32, 32))

# Vectorized computation: one loss per image pair.
# Assumes the function expects a leading batch axis (N, H, W), as in the quick start,
# so each (32, 32) slice is given a batch axis of size 1 before the call.
per_pair_loss = lambda a, b: wasserstein_distance_loss_images(a[None], b[None])
losses = jax.vmap(per_pair_loss)(batch_x, batch_x_true)
```
GOTO-SWAP uses a modular backend system that lets you choose between the JAX and PyTorch implementations:
```
src/goto_swap/
├── __init__.py          # Main package exports
├── base.py              # Legacy compatibility (deprecated)
└── backends/            # Backend system
    ├── __init__.py      # Backend factory and interface
    ├── jax/             # JAX backend implementation
    │   ├── backend.py   # JAXBackend class
    │   └── emd.py       # JAX EMD implementation
    └── pytorch/         # PyTorch backend implementation
        ├── backend.py   # PyTorchBackend class
        └── emd.py       # PyTorch EMD implementation
```
| Feature | JAX Backend | PyTorch Backend |
|---|---|---|
| JIT Compilation | ✅ Native | ❌ No |
| Automatic Differentiation | ✅ Native | ✅ Native |
| GPU Support | ✅ | ✅ |
| TPU Support | ✅ | ❌ |
| Memory Efficiency | ✅ | ✅ |
| Ecosystem Integration | JAX/Flax | PyTorch |
GOTO-SWAP includes tests for both core features: the sliced Wasserstein image loss and the 1D EMD. Tests are organized by backend, so each backend can be tested in an environment where only its framework is installed.
Test specific backend:
```bash
# Test JAX backend only
python run_tests.py --backend jax

# Test PyTorch backend only
python run_tests.py --backend pytorch

# Test all available backends
python run_tests.py
```
Using pytest directly:
```bash
# Test JAX backend
pytest tests/ --backend jax

# Test PyTorch backend
pytest tests/ --backend pytorch

# Test specific functionality
pytest tests/test_wasserstein_distance.py --backend jax
pytest tests/test_emd_1d.py --backend pytorch
```
Systematic Tests:
- `test_wasserstein_distance.py`: Wasserstein distance functionality
- `test_emd_1d.py`: 1D EMD functionality
- `test_emd_comparison.py`: comparison with reference implementations
- `test_backends.py`: backend system tests
Test Features:
- ✅ Backend-specific filtering
- ✅ Reference implementation validation
- ✅ Gradient computation testing
- ✅ JIT compilation testing
- ✅ Cross-backend consistency
- ✅ Edge case coverage
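A common way to test gradients is to compare an analytic or autodiff gradient against central finite differences. The NumPy sketch below shows that pattern on a case with a known closed form: the squared 2-Wasserstein distance between two equal-size, uniformly weighted 1D samples, where the optimal coupling matches points by rank. The helper names are hypothetical and the code is not taken from the GOTO-SWAP test suite.

```python
import numpy as np


def w2_sq_uniform_1d(u, v):
    """Squared 2-Wasserstein distance between equal-size, uniformly weighted 1D samples."""
    return float(np.mean((np.sort(u) - np.sort(v)) ** 2))


def w2_sq_uniform_1d_grad_u(u, v):
    """Analytic gradient w.r.t. u: each u_i is pulled toward the v value of the same rank."""
    u = np.asarray(u, dtype=float)
    matched_v = np.empty_like(u)
    matched_v[np.argsort(u)] = np.sort(v)  # rank-matched partner of each u_i
    return 2.0 * (u - matched_v) / u.size


def central_difference_grad(f, x, eps=1e-6):
    """Numerical gradient of a scalar function f at x by central differences."""
    x = np.asarray(x, dtype=float)
    grad = np.zeros_like(x)
    for i in range(x.size):
        step = np.zeros_like(x)
        step[i] = eps
        grad[i] = (f(x + step) - f(x - step)) / (2.0 * eps)
    return grad


rng = np.random.default_rng(0)
u = rng.normal(size=8)
v = rng.normal(size=8)
analytic = w2_sq_uniform_1d_grad_u(u, v)
numeric = central_difference_grad(lambda z: w2_sq_uniform_1d(z, v), u)
np.testing.assert_allclose(analytic, numeric, atol=1e-6)
```

The step size must stay smaller than the gap between neighboring samples; otherwise a perturbation can change the sort order and the finite difference no longer measures the same coupling.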
JAX Environment:
```bash
mamba activate ./envs/goto-swap-jax
python run_tests.py --backend jax
```
PyTorch Environment:
```bash
mamba activate ./envs/goto-swap-pytorch
python run_tests.py --backend pytorch
```
With coverage:
```bash
python run_tests.py --backend jax --coverage
```
Verbose output:
```bash
python run_tests.py --backend pytorch --verbose
```
Specific test files:
```bash
pytest tests/test_wasserstein_distance.py --backend jax -v
pytest tests/test_emd_1d.py --backend pytorch -v
```
If you try to use a backend that isn't installed:
```python
# This will raise ImportError with a helpful message
loss = wasserstein_distance_loss_images(x, x_true, backend="pytorch")
# ImportError: PyTorch backend requested but PyTorch is not installed.
# Install with: pip install torch torchvision
```
If you're using the legacy `goto_swap.base` module:
```python
# Old (deprecated)
from goto_swap.base import wasserstein_distance_loss_images

# New (recommended)
from goto_swap import wasserstein_distance_loss_images
```
Related project:
- GOTO (https://github.com/minhuanli/GOTO): main project repository with experimental code
To add a new backend (e.g., TensorFlow):
- Create a new directory under `backends/`
- Implement the `BackendInterface` abstract base class
- Add dependencies to `pyproject.toml`
- Update the factory function in `backends/__init__.py`
Example structure:
```
backends/tensorflow/
├── __init__.py
├── backend.py    # TensorFlowBackend class
└── emd.py        # TensorFlow EMD implementation
```
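Implementing the `BackendInterface` abstract base class, as listed above, amounts to subclassing it and filling in each abstract method. As an illustration only, the sketch below defines a stand-in `BackendInterface` declaring the two public functions documented in this README (the real interface may declare a different method set), a hypothetical `NumpyBackend` stub, and a hypothetical `make_backend` factory table. It also shows why interface-compliance tests are cheap: Python refuses to instantiate a subclass that leaves an abstract method unimplemented.

```python
from abc import ABC, abstractmethod


class BackendInterface(ABC):
    """Stand-in for GOTO-SWAP's BackendInterface; the real method set may differ."""

    @abstractmethod
    def wasserstein_distance_loss_images(self, x, x_true, num_rotations=100):
        """Sliced Wasserstein loss between two image stacks."""

    @abstractmethod
    def emd_1d_sorted_differentiable(self, u, v, u_weights, v_weights, metric="sqeuclidean", tol=1e-12):
        """1D EMD between sorted, weighted samples."""


class NumpyBackend(BackendInterface):
    """Hypothetical backend stub: each method would call framework-specific code."""

    name = "numpy"

    def wasserstein_distance_loss_images(self, x, x_true, num_rotations=100):
        raise NotImplementedError("fill in with a NumPy implementation")

    def emd_1d_sorted_differentiable(self, u, v, u_weights, v_weights, metric="sqeuclidean", tol=1e-12):
        raise NotImplementedError("fill in with a NumPy implementation")


# Hypothetical factory table; in the package, the factory lives in backends/__init__.py
_BACKEND_CLASSES = {"numpy": NumpyBackend}


def make_backend(name):
    """Instantiate a registered backend by name."""
    try:
        cls = _BACKEND_CLASSES[name]
    except KeyError:
        raise ValueError(f"Unknown backend {name!r}") from None
    return cls()


# Interface compliance: a subclass missing a method cannot be instantiated
class IncompleteBackend(BackendInterface):
    def wasserstein_distance_loss_images(self, x, x_true, num_rotations=100):
        return 0.0


try:
    IncompleteBackend()
except TypeError as err:
    print("rejected:", err)
```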
The test suite includes:
- Backend interface compliance tests
- Backend-specific functionality tests
- Cross-backend consistency tests
- Error handling tests
Run tests with:
```bash
python run_tests.py --backend jax
python run_tests.py --backend pytorch
```
We welcome contributions! Please see our testing guide for development details.
BSD 3-Clause License - see LICENSE for details.
If you use GOTO-SWAP in your research, please cite:
```bibtex
@software{goto_swap,
  title={GOTO-SWAP: Gradient-based Optimal Transport Optimization -- Sliced Wasserstein API for Pipelines},
  author={Li, Minhuan and Woollard, Geoffrey and Herreros, David},
  year={2024},
  url={https://github.com/flatironinstitute/GOTO-SWAP}
}
```