
flatironinstitute/GOTO-SWAP


GOTO-SWAP

Gradient-based Optimal Transport Optimization -- Sliced Wasserstein API for Pipelines

A unified, lightweight API for using sliced Wasserstein distances as optimization targets (losses) on images, with multi-backend support (JAX and PyTorch). The package makes it easy to integrate sliced Wasserstein distance computations into machine learning pipelines.
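For reference, here is the standard definition of the squared (p = 2) sliced Wasserstein distance between probability measures μ and ν on ℝ^d. Both measures are projected onto a direction θ, the resulting 1D distributions are compared through their quantile functions F⁻¹, and the 1D costs are averaged over directions (σ is the uniform measure on the unit sphere). For 2D images a half circle of angles suffices, because θ and -θ give mirror-image projections with the same 1D cost. Reading `num_rotations` as the number of directions L is an assumption; this README does not state which exponent or normalization goto_swap uses.

```latex
\begin{aligned}
\mathrm{SW}_2^2(\mu, \nu) &= \int_{\mathbb{S}^{d-1}} W_2^2\big(P_{\theta\#}\mu,\, P_{\theta\#}\nu\big)\, d\sigma(\theta),
  \qquad P_\theta(x) = \langle \theta, x \rangle, \\
W_2^2(\alpha, \beta) &= \int_0^1 \big| F_\alpha^{-1}(t) - F_\beta^{-1}(t) \big|^2 \, dt, \\
\mathrm{SW}_2^2(\mu, \nu) &\approx \frac{1}{L} \sum_{l=0}^{L-1} W_2^2\big(P_{\theta_l\#}\mu,\, P_{\theta_l\#}\nu\big),
  \qquad \theta_l = \big(\cos\tfrac{\pi l}{L},\, \sin\tfrac{\pi l}{L}\big) \ \text{for } d = 2.
\end{aligned}
```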

Features

  • 🚀 Multi-Backend Support: Choose between JAX and PyTorch backends
  • High Performance: JIT compilation with JAX, optimized PyTorch operations
  • 🔧 Easy Integration: Simple API that works with existing ML workflows
  • 📦 Lightweight: Optional dependencies; install only the backend you need
  • 🧮 Differentiable: Full automatic differentiation support
  • 🎯 Specialized: 1D Earth Mover's Distance (EMD) implementations

Installation

Minimal Installation

pip install goto-swap

Backend-Specific Installation

JAX Backend:

pip install "goto-swap[jax]"

PyTorch Backend:

pip install "goto-swap[pytorch]"

All Backends:

pip install "goto-swap[all]"

Quick Start

Basic Usage

import numpy as np
from goto_swap import wasserstein_distance_loss_images

# Prepare your data
x = np.random.randn(2, 32, 32)  # Predicted images
x_true = np.random.randn(2, 32, 32)  # True images

# JAX backend (default)
loss = wasserstein_distance_loss_images(x, x_true)

# PyTorch backend
loss = wasserstein_distance_loss_images(x, x_true, backend="pytorch")

1D Earth Mover's Distance

import jax.numpy as jnp
from goto_swap import emd_1d_sorted_differentiable

# Sorted positions and weights
u = jnp.array([1.0, 2.0, 3.0])  # Source positions
v = jnp.array([1.5, 2.5, 3.5])  # Target positions
u_weights = jnp.array([0.3, 0.4, 0.3])  # Source weights
v_weights = jnp.array([0.2, 0.5, 0.3])  # Target weights

# Compute EMD
emd_loss = emd_1d_sorted_differentiable(u, v, u_weights, v_weights)
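To sanity-check results, the 1D problem can also be solved directly. For points sorted in ascending order, the optimal transport plan is given by the north-west corner rule, which greedily moves mass from the leftmost remaining source point to the leftmost remaining target point. Below is a minimal NumPy sketch; `emd_1d_sorted_reference` is a hypothetical helper, not part of goto_swap, and it assumes ascending inputs and equal total mass on both sides.

```python
import numpy as np


def emd_1d_sorted_reference(u, v, u_weights, v_weights, tol=1e-12):
    """Hypothetical NumPy cross-check (not part of goto_swap).

    Builds the optimal 1D transport plan between sorted, weighted point sets with
    the north-west corner rule and returns (cost, plan), where
    cost = sum_ij plan[i, j] * (u[i] - v[j])**2 (squared-Euclidean ground cost).
    Assumes u and v are sorted ascending and both weight vectors have equal sums.
    """
    u, v = np.asarray(u, dtype=float), np.asarray(v, dtype=float)
    a = np.array(u_weights, dtype=float)  # copies: the loop consumes remaining mass
    b = np.array(v_weights, dtype=float)
    plan = np.zeros((len(u), len(v)))
    i = j = 0
    while i < len(u) and j < len(v):
        mass = min(a[i], b[j])
        plan[i, j] = mass
        a[i] -= mass
        b[j] -= mass
        # Advance past whichever point(s) have no mass left (up to round-off).
        if a[i] <= tol:
            i += 1
        if b[j] <= tol:
            j += 1
    cost = float(np.sum(plan * (u[:, None] - v[None, :]) ** 2))
    return cost, plan


cost, plan = emd_1d_sorted_reference(
    [1.0, 2.0, 3.0], [1.5, 2.5, 3.5], [0.3, 0.4, 0.3], [0.2, 0.5, 0.3]
)
```

For the example above this gives a cost of 0.45 (0.2·0.5² + 0.1·1.5² + 0.4·0.5² + 0.3·0.5²). Whether `emd_1d_sorted_differentiable` returns exactly this quantity (for instance, without taking a square root) is an assumption worth checking against your installed version.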

Advanced Usage

from goto_swap import get_backend

# Get specific backend
jax_backend = get_backend("jax")
pytorch_backend = get_backend("pytorch")

# Use backend directly
loss_jax = jax_backend.wasserstein_distance_loss_images(x, x_true)
loss_pytorch = pytorch_backend.wasserstein_distance_loss_images(x, x_true)

# Direct function imports
from goto_swap.backends.jax import emd_1d_sorted_jax_differentiable
from goto_swap.backends.pytorch import emd_1d_sorted_pytorch_differentiable

API Reference

Main Functions

  • wasserstein_distance_loss_images(x, x_true, num_rotations=100, backend="jax"): Compute sliced Wasserstein distance between images
  • emd_1d_sorted_differentiable(u, v, u_weights, v_weights, metric='sqeuclidean', tol=1e-12, backend="jax"): Compute 1D Earth Mover's Distance
  • get_backend(backend_name): Get backend instance for advanced usage
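To make the role of `num_rotations` concrete, here is a minimal NumPy sketch of one common way to build a sliced Wasserstein loss between two images: treat each image as a non-negative mass distribution on its pixel grid, project pixel centres onto `num_rotations` evenly spaced directions in [0, π), and average the squared 1D Wasserstein cost over directions. The names `sliced_w2_images` and `_w2_squared_1d` are hypothetical. The sketch does not reproduce how goto_swap itself projects images, whether it sums or averages over rotations, or how it handles signed inputs such as the `randn` arrays above.

```python
import numpy as np


def _w2_squared_1d(pos_a, w_a, pos_b, w_b):
    """Squared-Euclidean 1D OT cost between weighted samples, via quantile functions."""
    ia, ib = np.argsort(pos_a), np.argsort(pos_b)
    pa, pb = pos_a[ia], pos_b[ib]
    cdf_a = np.cumsum(w_a[ia])
    cdf_a /= cdf_a[-1]  # normalize intensities to unit mass
    cdf_b = np.cumsum(w_b[ib])
    cdf_b /= cdf_b[-1]
    # Both quantile functions are piecewise constant between these breakpoints.
    t = np.unique(np.concatenate(([0.0], cdf_a, cdf_b)))
    mid = 0.5 * (t[:-1] + t[1:])
    # Quantile lookup: first sorted sample whose CDF reaches each midpoint.
    qa = pa[np.minimum(np.searchsorted(cdf_a, mid), len(pa) - 1)]
    qb = pb[np.minimum(np.searchsorted(cdf_b, mid), len(pb) - 1)]
    return float(np.sum(np.diff(t) * (qa - qb) ** 2))


def sliced_w2_images(img_a, img_b, num_rotations=100):
    """Hypothetical sketch: mean squared 1D W2 over `num_rotations` angles in [0, pi)."""
    img_a, img_b = np.asarray(img_a, dtype=float), np.asarray(img_b, dtype=float)
    if img_a.shape != img_b.shape or min(img_a.min(), img_b.min()) < 0:
        raise ValueError("expects two same-shape, non-negative 2D images")
    if img_a.sum() <= 0 or img_b.sum() <= 0:
        raise ValueError("each image needs positive total mass")
    h, w = img_a.shape
    yy, xx = np.mgrid[0:h, 0:w]
    coords = np.stack([xx.ravel(), yy.ravel()], axis=1).astype(float)
    angles = np.linspace(0.0, np.pi, num_rotations, endpoint=False)
    costs = []
    for theta in angles:
        # Project every pixel centre onto the unit direction (cos theta, sin theta).
        proj = coords @ np.array([np.cos(theta), np.sin(theta)])
        costs.append(_w2_squared_1d(proj, img_a.ravel(), proj, img_b.ravel()))
    return float(np.mean(costs))
```

A useful check on this construction: shifting an image by one pixel gives a 1D cost of cos²θ at angle θ, so the average over evenly spaced angles is 0.5.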

Backend Selection

Backend   JIT Compilation   GPU Support   TPU Support   Best For
JAX       ✅ Native                                      Research, high-performance computing
PyTorch   ❌ No                                          Integration with PyTorch workflows

Examples

JAX with JIT Compilation

import jax
import jax.numpy as jnp
from goto_swap import get_backend

# Get backend and JIT compile for performance
backend = get_backend("jax")
jit_loss = jax.jit(backend.wasserstein_distance_loss_images, static_argnames=['num_rotations'])

# Use compiled function
x = jax.random.normal(jax.random.PRNGKey(0), (2, 32, 32))
x_true = jax.random.normal(jax.random.PRNGKey(1), (2, 32, 32))
loss = jit_loss(x, x_true, num_rotations=100)

PyTorch with Gradients

import torch
from goto_swap import emd_1d_sorted_differentiable

u = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
v = torch.tensor([1.5, 2.5, 3.5])
u_weights = torch.tensor([0.3, 0.4, 0.3])
v_weights = torch.tensor([0.2, 0.5, 0.3])

loss = emd_1d_sorted_differentiable(u, v, u_weights, v_weights, backend="pytorch")
loss.backward()  # Populates u.grad (u is the only input with requires_grad=True)
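As a sanity check on what `loss.backward()` should produce, here is the gradient for this example worked out by hand. It assumes the PyTorch backend returns the squared-Euclidean transport cost C = Σ P_ij (u_i - v_j)² without a square root, which is an assumption rather than something stated above. For points sorted as in the example, the optimal plan P follows from the north-west corner rule, and where P is unique the gradient can be taken with P held fixed:

```latex
\begin{aligned}
C &= \sum_{i,j} P_{ij}\,(u_i - v_j)^2, \qquad
P = \begin{pmatrix} 0.2 & 0.1 & 0 \\ 0 & 0.4 & 0 \\ 0 & 0 & 0.3 \end{pmatrix}, \qquad C = 0.45, \\
\frac{\partial C}{\partial u_i} &= 2 \sum_j P_{ij}\,(u_i - v_j), \\
\frac{\partial C}{\partial u_1} &= 2\,\big[\,0.2\,(1 - 1.5) + 0.1\,(1 - 2.5)\,\big] = -0.5, \\
\frac{\partial C}{\partial u_2} &= 2 \cdot 0.4\,(2 - 2.5) = -0.4, \\
\frac{\partial C}{\partial u_3} &= 2 \cdot 0.3\,(3 - 3.5) = -0.3.
\end{aligned}
```

Under that assumption, `u.grad` should be close to `tensor([-0.5, -0.4, -0.3])`.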

Batch Processing

import jax
from goto_swap import wasserstein_distance_loss_images

# Process multiple image pairs
key_x, key_true = jax.random.split(jax.random.PRNGKey(0))
batch_x = jax.random.normal(key_x, (10, 32, 32))  # 10 images
batch_x_true = jax.random.normal(key_true, (10, 32, 32))

# Vectorized computation: vmap maps over the leading axis,
# so each call receives one (32, 32) image pair
losses = jax.vmap(wasserstein_distance_loss_images)(batch_x, batch_x_true)

Architecture

GOTO-SWAP uses a modular backend system that allows you to choose between JAX and PyTorch implementations:

src/goto_swap/
├── __init__.py                 # Main package exports
├── base.py                     # Legacy compatibility (deprecated)
└── backends/                   # Backend system
    ├── __init__.py             # Backend factory and interface
    ├── jax/                    # JAX backend implementation
    │   ├── backend.py          # JAXBackend class
    │   └── emd.py              # JAX EMD implementation
    └── pytorch/                # PyTorch backend implementation
        ├── backend.py          # PyTorchBackend class
        └── emd.py              # PyTorch EMD implementation

Backend Features

Feature                     JAX Backend   PyTorch Backend
JIT Compilation             ✅ Native      ❌ No
Automatic Differentiation   ✅ Native      ✅ Native
GPU Support
TPU Support
Memory Efficiency
Ecosystem Integration       JAX/Flax      PyTorch

Testing

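GOTO-SWAP includes comprehensive tests for both of its main features: the sliced Wasserstein image loss and the 1D EMD. Tests are organized by backend so that each backend can be tested on its own, for example in an environment where only that backend is installed.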

Quick Testing

Test specific backend:

# Test JAX backend only
python run_tests.py --backend jax

# Test PyTorch backend only
python run_tests.py --backend pytorch

# Test all available backends
python run_tests.py

Using pytest directly:

# Test JAX backend
pytest tests/ --backend jax

# Test PyTorch backend
pytest tests/ --backend pytorch

# Test specific functionality
pytest tests/test_wasserstein_distance.py --backend jax
pytest tests/test_emd_1d.py --backend pytorch

Test Categories

Systematic Tests:

  • test_wasserstein_distance.py - Wasserstein distance functionality
  • test_emd_1d.py - EMD 1D functionality
  • test_emd_comparison.py - Comparison with reference implementations
  • test_backends.py - Backend system tests

Test Features:

  • ✅ Backend-specific filtering
  • ✅ Reference implementation validation
  • ✅ Gradient computation testing
  • ✅ JIT compilation testing
  • ✅ Cross-backend consistency
  • ✅ Edge case coverage

Environment-Specific Testing

JAX Environment:

mamba activate ./envs/goto-swap-jax
python run_tests.py --backend jax

PyTorch Environment:

mamba activate ./envs/goto-swap-pytorch
python run_tests.py --backend pytorch

Advanced Testing Options

With coverage:

python run_tests.py --backend jax --coverage

Verbose output:

python run_tests.py --backend pytorch --verbose

Specific test files:

pytest tests/test_wasserstein_distance.py --backend jax -v
pytest tests/test_emd_1d.py --backend pytorch -v

Error Handling

If you try to use a backend that isn't installed:

# This will raise ImportError with helpful message
loss = wasserstein_distance_loss_images(x, x_true, backend="pytorch")
# ImportError: PyTorch backend requested but PyTorch is not installed. 
# Install with: pip install torch torchvision

Migration from Legacy Code

If you're using the legacy goto_swap.base module:

# Old (deprecated)
from goto_swap.base import wasserstein_distance_loss_images

# New (recommended)
from goto_swap import wasserstein_distance_loss_images


Development

Adding New Backends

To add a new backend (e.g., TensorFlow):

  1. Create a new directory under backends/
  2. Implement the BackendInterface abstract base class
  3. Add dependencies to pyproject.toml
  4. Update the factory function in backends/__init__.py
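Steps 2 and 4 can be pictured with the simplified sketch below. It is hypothetical: the `BackendInterface` here is a stand-in whose abstract methods are guessed from the API reference, and `register_backend` and `create_backend` are illustrative names, not the actual contents of `backends/__init__.py`.

```python
from abc import ABC, abstractmethod


class BackendInterface(ABC):
    """Hypothetical stand-in for goto_swap's BackendInterface; real methods may differ."""

    @abstractmethod
    def wasserstein_distance_loss_images(self, x, x_true, num_rotations=100):
        """Sliced Wasserstein loss between predicted and true images."""

    @abstractmethod
    def emd_1d_sorted_differentiable(self, u, v, u_weights, v_weights,
                                     metric="sqeuclidean", tol=1e-12):
        """1D EMD between sorted, weighted point sets."""


# Maps a backend name to a zero-argument constructor, so heavy imports stay lazy.
_BACKEND_FACTORIES = {}


def register_backend(name, factory):
    _BACKEND_FACTORIES[name] = factory


def create_backend(name):
    """Simplified, hypothetical analogue of the backend factory."""
    if name not in _BACKEND_FACTORIES:
        raise ValueError(f"Unknown backend {name!r}; available: {sorted(_BACKEND_FACTORIES)}")
    backend = _BACKEND_FACTORIES[name]()
    if not isinstance(backend, BackendInterface):
        raise TypeError(f"Backend {name!r} does not implement BackendInterface")
    return backend


# Example: a toy backend that satisfies the interface.
class ZeroBackend(BackendInterface):
    def wasserstein_distance_loss_images(self, x, x_true, num_rotations=100):
        return 0.0

    def emd_1d_sorted_differentiable(self, u, v, u_weights, v_weights,
                                     metric="sqeuclidean", tol=1e-12):
        return 0.0


register_backend("zero", ZeroBackend)
```

Because `ABC` refuses to instantiate a subclass that leaves an abstract method unimplemented, an incomplete backend fails as soon as it is constructed rather than on first use.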

Example structure:

backends/tensorflow/
├── __init__.py
├── backend.py      # TensorFlowBackend class
└── emd.py          # TensorFlow EMD implementation

Testing

The test suite includes:

  • Backend interface compliance tests
  • Backend-specific functionality tests
  • Cross-backend consistency tests
  • Error handling tests

Run tests with:

python run_tests.py --backend jax
python run_tests.py --backend pytorch

Contributing

We welcome contributions! Please see our testing guide for development details.

License

BSD 3-Clause License - see LICENSE for details.

Citation

If you use GOTO-SWAP in your research, please cite:

@software{goto_swap,
  title={GOTO-SWAP: Gradient-based Optimal Transport Optimization -- Sliced Wasserstein API for Pipelines},
  author={Li, Minhuan and Woollard, Geoffrey and Herreros, David},
  year={2024},
  url={https://github.com/flatironinstitute/GOTO-SWAP}
}


Generated from rs-station/rs-template