Pure PyTorch | Triton GPU Kernels | Differentiable | Up to 175x faster than POT
TorchGW aligns two point clouds by matching their internal distance structures -- even when they live in different dimensions. Instead of the full O(NK(N+K)) GW cost, it samples M anchor pairs each iteration and approximates the cost in O(NKM), enabling GPU-accelerated alignment at scales where standard solvers are impractical.
Use cases: single-cell multi-omics integration, cross-domain graph matching, shape correspondence, manifold alignment.
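The O(NKM) approximation described above can be sketched in a few lines: draw M anchor pairs from the current plan and average squared distance mismatches against them. This is an illustrative sketch of the idea only, not TorchGW's internal code; `sampled_gw_cost` is a hypothetical helper name.

```python
import torch

def sampled_gw_cost(D_x, D_y, T, M=50, generator=None):
    """Approximate the GW cost matrix from M anchor pairs drawn from T.

    D_x: (N, N) source distances; D_y: (K, K) target distances;
    T: (N, K) current transport plan. Returns an (N, K) cost estimate
    in O(N*K*M) time instead of the full O(N*K*(N+K)).
    """
    N, K = T.shape
    # Sample M anchor index pairs (k_m, l_m) with probability T[k, l]
    flat = torch.multinomial(T.flatten(), M, replacement=True, generator=generator)
    rows, cols = flat // K, flat % K
    # |D_x[i, k_m] - D_y[j, l_m]|^2, averaged over the M anchors
    diff = D_x[:, rows].unsqueeze(1) - D_y[:, cols].unsqueeze(0)  # (N, K, M)
    return (diff ** 2).mean(dim=-1)
```

Larger M gives a lower-variance cost estimate at linearly higher cost per iteration.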
Performance | Features
v0.4.0 (2026-04-07) -- Triton fused Sinkhorn (2-5x GPU speedup), mixed precision, smart early stopping, Sinkhorn warm-start, Dijkstra caching, and 15 numerical stability fixes. See CHANGELOG.md.
pip install -e .

Requirements: numpy, scipy, scikit-learn, torch>=2.0, joblib.
Triton ships with PyTorch and enables GPU kernel fusion automatically. No POT needed.
from torchgw import sampled_gw
# Basic usage
T = sampled_gw(X_source, X_target)
# Recommended for large-scale (fastest)
T = sampled_gw(X_source, X_target, distance_mode="landmark", mixed_precision=True)

Minimal working example:
import torch
from torchgw import sampled_gw
X = torch.randn(500, 3) # source: 500 points in 3D
Y = torch.randn(600, 5) # target: 600 points in 5D (dimensions may differ)
T = sampled_gw(X, Y, epsilon=0.005, M=80, max_iter=200)
# T is a (500, 600) transport plan: T[i,j] = coupling weight between X[i] and Y[j]
print(f"Transport plan: {T.shape}, total mass: {T.sum():.4f}")

Spiral (2D) to Swiss roll (3D) alignment on NVIDIA L40S:
| Scale | Method | Time | Spearman rho |
|---|---|---|---|
| 400 vs 500 | POT ot.gromov_wasserstein | 1.6 s | 0.999 |
| 400 vs 500 | TorchGW | 0.46 s | 0.999 |
| 4000 vs 5000 | POT ot.gromov_wasserstein | 183 s | 0.999 |
| 4000 vs 5000 | TorchGW precomputed | 5.1 s | 0.998 |
| 4000 vs 5000 | TorchGW landmark | 1.0 s | 0.999 |
At 4000x5000 with landmark distances, TorchGW is up to ~175x faster than POT with equal quality.
Choose based on your data scale:
| Mode | Best for | Per-iteration | Memory | Notes |
|---|---|---|---|---|
| "precomputed" | N < 5k | O(NM) lookup | O(N^2) | All-pairs Dijkstra upfront |
| "dijkstra" | 5k-50k | O(MN log N) | O(NM) | On-the-fly with caching |
| "landmark" | any scale | O(NMd) GPU | O(Nd) | Recommended default |
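The memory claim for landmark mode can be illustrated with a common landmark trick: represent each point by its distances to a few random landmarks. This is an assumption for intuition only; `landmark_signatures` is a hypothetical name and TorchGW's exact scheme may differ.

```python
import torch

def landmark_signatures(X, n_landmarks=50, generator=None):
    """Distances from every point to a few randomly chosen landmarks.

    With signatures S of shape (N, n_landmarks), pairwise distances can be
    approximated from S[i] and S[j] (e.g. via the triangle-inequality bound
    min_l S[i, l] + S[j, l]), keeping memory at O(N * n_landmarks)
    instead of O(N^2).
    """
    idx = torch.randperm(X.shape[0], generator=generator)[:n_landmarks]
    return torch.cdist(X, X[idx])  # (N, n_landmarks)
```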
# Small scale: precompute all distances once
T = sampled_gw(X, Y, distance_mode="precomputed")
# Bring your own distance matrices
T = sampled_gw(dist_source=D_X, dist_target=D_Y, distance_mode="precomputed")
# Large scale (recommended)
T = sampled_gw(X, Y, distance_mode="landmark", n_landmarks=50)

Best performance settings
T = sampled_gw(
X, Y,
distance_mode="landmark", # avoids expensive all-pairs Dijkstra
mixed_precision=True, # float32 Sinkhorn (2x faster on GPU)
M=80, # more samples = better cost estimate
epsilon=0.005, # moderate regularization
)

Fused Gromov-Wasserstein
Blend structural (graph distance) and feature (linear) costs:
C_feat = torch.cdist(features_src, features_tgt)
T = sampled_gw(X, Y, fgw_alpha=0.5, C_linear=C_feat)
# Pure Wasserstein (no graph distances needed)
T = sampled_gw(fgw_alpha=1.0, C_linear=C_feat)

Semi-relaxed transport
For unbalanced datasets (e.g., cell types present in one sample but not the other):
T = sampled_gw(X, Y, semi_relaxed=True, rho=1.0)
# Source marginal enforced, target marginal relaxed via KL penalty

Multi-scale warm start
Speeds up convergence by solving a coarse problem first:
T = sampled_gw(X, Y, multiscale=True, n_coarse=200)

Note: GW has symmetric local optima; the multiscale warm start works best on data without strong symmetries.
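A coarse-to-fine warm start can be sketched as follows. This is illustrative only: the coarse solve is replaced by a uniform placeholder, and nearest-representative upsampling is an assumption, not necessarily TorchGW's interpolation.

```python
import torch

def coarse_warmstart(X, Y, n_coarse=200, generator=None):
    """Build a full-size warm-start plan from a coarse subproblem.

    Subsample n_coarse points on each side, solve GW there, then expand
    the coarse plan by assigning every point to its nearest coarse
    representative and renormalizing to unit mass.
    """
    ix = torch.randperm(X.shape[0], generator=generator)[:n_coarse]
    iy = torch.randperm(Y.shape[0], generator=generator)[:n_coarse]
    # ... solve GW on X[ix], Y[iy] to get T_coarse (n_coarse, n_coarse) ...
    T_coarse = torch.full((n_coarse, n_coarse), 1.0 / n_coarse**2)  # placeholder
    ax = torch.cdist(X, X[ix]).argmin(dim=1)  # nearest coarse source point
    ay = torch.cdist(Y, Y[iy]).argmin(dim=1)  # nearest coarse target point
    T0 = T_coarse[ax][:, ay]                  # (N, K) expanded plan
    return T0 / T0.sum()                      # renormalize to unit mass
```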
Differentiable mode
Use GW cost as a training loss (gradients via envelope theorem):
C_feat = torch.cdist(encoder(X), encoder(Y))
T = sampled_gw(fgw_alpha=1.0, C_linear=C_feat, differentiable=True)
loss = (C_feat.detach() * T).sum()
loss.backward()  # gradients flow to encoder parameters

Low-rank Sinkhorn (N, K > 50k)
For very large problems where the N*K transport plan does not fit in memory:
from torchgw import sampled_lowrank_gw
T = sampled_lowrank_gw(X, Y, rank=30, distance_mode="landmark", n_landmarks=50)

Memory: O((N+K)*rank) instead of O(NK).
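Why a low-rank plan saves memory: a factorization T = Q diag(1/g) R^T can be queried row by row without ever materializing the (N, K) matrix. A sketch with made-up factors (not the solver's actual output; `plan_row` is a hypothetical helper):

```python
import torch

# Hypothetical low-rank factors at a scale where the dense (N, K) plan
# (~16 GB in float32) would not fit in memory
N, K, r = 60_000, 70_000, 30
Q, R = torch.rand(N, r) / N, torch.rand(K, r) / K
g = torch.full((r,), 1.0 / r)

def plan_row(i):
    """Row i of T = Q @ diag(1/g) @ R.T, in O(K * r) time and memory."""
    return (Q[i] / g) @ R.T  # (K,)

row = plan_row(0)  # one row of the implicit transport plan
```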
sampled_gw(
X_source, X_target, # (N, D) and (K, D') feature matrices
*,
distance_mode="dijkstra", # "precomputed" | "dijkstra" | "landmark"
fgw_alpha=0.0, # 0 = pure GW, 1 = Wasserstein, (0,1) = Fused GW
C_linear=None, # (N, K) feature cost matrix for FGW
M=50, # anchor pairs per iteration
epsilon=0.001, # entropic regularization
max_iter=500, tol=1e-5, # convergence control
mixed_precision=False, # float32 Sinkhorn for GPU speed
semi_relaxed=False, # relax target marginal
differentiable=False, # keep autograd graph
multiscale=False, # coarse-to-fine warm start
log=False, # return (T, info_dict)
... # see docs for full parameter list
) -> Tensor  # (N, K) transport plan

Same interface plus rank, lr_max_iter, lr_dykstra_max_iter.
Uses the Scetbon, Cuturi & Peyré (2021) factorization.
When to use: only when N*K exceeds GPU memory. At smaller scales, sampled_gw is faster.
Full API documentation: chansigit.github.io/torchgw
┌─────────────────────────────────────────────┐
│ GW Main Loop │
│ │
T_init ──────────►│ 1. GPU multinomial sampling (M anchors) │
│ 2. Distance computation (Dijkstra/landmark)│
│ 3. GW cost matrix assembly │
│ 4. Triton fused Sinkhorn projection │
│ 5. Momentum update + warm-start │
│ 6. Cost plateau convergence check │
│ │
│ ↺ repeat until converged │
└──────────────────────┬──────────────────────┘
│
▼
T* (N × K)
Acceleration stack:
- Triton kernels -- fused row/column logsumexp, fused T materialization, fused marginal check
- Warm-start -- reuse Sinkhorn potentials across iterations
- Mixed precision -- float32 log-domain + float64 output
- Dijkstra cache -- avoid redundant SSSP across iterations
- Cost plateau early stopping -- stop when converged, not at max_iter
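The warm-start item in the list above can be illustrated with a plain log-domain Sinkhorn that accepts previous potentials; this is a minimal sketch, not the Triton-fused kernel, and `sinkhorn_log` is a hypothetical name.

```python
import math
import torch

def sinkhorn_log(C, epsilon=0.01, max_iter=500, f=None, g=None, tol=1e-6):
    """Log-domain Sinkhorn with optional warm-start potentials (f, g).

    Reusing (f, g) from the previous outer GW iteration typically cuts
    the number of inner iterations sharply, since the cost matrix changes
    only slightly between outer steps.
    """
    N, K = C.shape
    log_a = torch.full((N,), -math.log(N))  # uniform source marginal
    log_b = torch.full((K,), -math.log(K))  # uniform target marginal
    f = torch.zeros(N) if f is None else f
    g = torch.zeros(K) if g is None else g
    for _ in range(max_iter):
        # Alternate marginal projections, entirely in the log domain
        f_new = epsilon * (log_a - torch.logsumexp((g - C) / epsilon, dim=1))
        g = epsilon * (log_b - torch.logsumexp((f_new.unsqueeze(1) - C) / epsilon, dim=0))
        done = bool((f_new - f).abs().max() < tol)
        f = f_new
        if done:
            break
    T = torch.exp((f.unsqueeze(1) + g.unsqueeze(0) - C) / epsilon)
    return T, f, g
```

A second call with the returned (f, g) converges in a handful of iterations, which is the essence of the warm-start trick.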
git clone https://github.com/chansigit/torchgw.git
cd torchgw
pip install -e ".[dev]"
pytest tests/ -v  # 72 tests, ~18s

If you use TorchGW in your research, please cite:
@software{torchgw,
author = {Sijie Chen},
title = {TorchGW: Fast Sampled Gromov-Wasserstein Optimal Transport},
url = {https://github.com/chansigit/torchgw},
version = {0.4.0},
year = {2026},
}

This project is source-available.
It is free for academic and other non-commercial research and educational use under the terms of the included LICENSE.
Any commercial use — including any use by or on behalf of a for-profit entity, internal commercial research, product development, consulting, paid services, or deployment in commercial settings — requires a separate paid commercial license.
For commercial licensing inquiries, please contact: chansigit@gmail.com
See COMMERCIAL_LICENSE.md for details.

