Official PyTorch implementation for paper: Sliced-Regularized Optimal Transport
Details of the model architecture and experimental results can be found in our paper.
@article{nguyen2026srot,
title={Sliced-Regularized Optimal Transport},
author={Khai Nguyen},
journal={arXiv preprint arXiv:2604.23944},
year={2026},
pdf={https://arxiv.org/pdf/2604.23944.pdf}
}
Please CITE our paper whenever this repository is used to help produce published results or incorporated into other software.
This implementation is made by Khai Nguyen.
import numpy as np
import ot
import torch
def build_sot_plan(X, Y, a, b, L=50, delta=1e-8, rng=None):
    """
    Build the uniform-average SOT reference plan from L random 1D projections.

    Averages the exact 1D optimal transport plans obtained by projecting X and
    Y onto L random unit directions, then mixes in the independent coupling:

        (1 - delta) * pi_SOT + delta * (a outer b)

    The mixture has the same marginals as pi_SOT; for delta > 0 every entry is
    strictly positive when a, b > 0 (helps Sinkhorn kernel support).

    Parameters
    ----------
    X : (n, d) array-like of source samples.
    Y : (m, d) array-like of target samples.
    a : (n,) source marginal weights.
    b : (m,) target marginal weights.
    L : int, number of random projections; must be positive.
    delta : float in [0, 1), weight of the independent coupling.
    rng : optional numpy Generator; defaults to np.random.default_rng(0)
        so results are reproducible when no generator is supplied.

    Returns
    -------
    (n, m) float64 transport plan.

    Raises
    ------
    ValueError
        If L <= 0 or delta is outside [0, 1).
    """
    if L <= 0:
        raise ValueError("L must be a positive integer.")
    if not (0.0 <= delta < 1.0):
        raise ValueError("delta must satisfy 0 <= delta < 1.")
    # Explicit `is None` check (not truthiness) so any caller-supplied
    # Generator is honored unconditionally.
    if rng is None:
        rng = np.random.default_rng(0)
    # Coerce inputs to float64 arrays so the projections and the returned
    # plan are in the dtype the docstring promises.
    X = np.asarray(X, dtype=np.float64)
    Y = np.asarray(Y, dtype=np.float64)
    n, d = X.shape
    m = Y.shape[0]
    pi_sot = np.zeros((n, m), dtype=np.float64)
    # Random directions on the unit sphere: normalize Gaussian draws.
    thetas = rng.standard_normal((L, d))
    theta_norms = np.linalg.norm(thetas, axis=1, keepdims=True)
    theta_norms = np.maximum(theta_norms, 1e-300)  # guard against a zero draw
    thetas /= theta_norms
    # Project every point onto every direction at once: (n, L) and (m, L).
    px_all = X @ thetas.T
    py_all = Y @ thetas.T
    emd_1d = ot.emd_1d  # hoist the attribute lookup out of the loop
    for ell in range(L):
        # Exact 1D OT plan for projection ell, accumulated into the average.
        pi_sot += emd_1d(px_all[:, ell], py_all[:, ell], a, b)
    pi_sot /= L
    if delta > 0.0:
        # Mix in the independent (product) coupling to make the plan dense.
        pi_ind = np.outer(a, b)
        pi_sot = (1.0 - delta) * pi_sot + delta * pi_ind
    return pi_sot
def sinkhorn_sot(a, b, C, pi_sot, eps, max_iter=2000, tol=1e-9, log_freq=1):
    """Sinkhorn iterations driven by the SOT-guided kernel pi_sot * exp(-C/eps).

    Alternately rescales the kernel so that the coupling u * K * v matches the
    marginals a and b. Every `log_freq` iterations the current coupling, its
    transport cost, and the iteration index are recorded, and iteration stops
    early once the worst marginal violation drops below `tol`.

    Returns a list of (iteration, cost, coupling) tuples.
    """
    floor = 1e-300  # clamp to keep every division finite
    K = np.maximum(pi_sot * np.exp(-C / eps), floor)
    u = np.ones(len(a))
    v = np.ones(len(b))
    history = []
    for step in range(1, max_iter + 1):
        u = a / np.maximum(K @ v, floor)
        v = b / np.maximum(K.T @ u, floor)
        # Convergence is only assessed on logging iterations, so off-cycle
        # steps skip the (comparatively expensive) coupling reconstruction.
        if step % log_freq != 0:
            continue
        pi = u[:, None] * K * v[None, :]
        history.append((step, float(np.sum(pi * C)), pi))
        row_err = np.max(np.abs(pi.sum(axis=1) - a))
        col_err = np.max(np.abs(pi.sum(axis=0) - b))
        if max(row_err, col_err) < tol:
            break
    return history
def sinkhorn_sot_torch(a, b, C, pi_sot, eps, max_iter=2000, tol=1e-9, log_freq=1, device=None):
    """Torch implementation of SOT-guided Sinkhorn.

    Same algorithm as `sinkhorn_sot`, run in float64 on the selected torch
    device. Logged couplings are returned as numpy arrays.

    Parameters
    ----------
    a, b : source/target marginal weights (array-like).
    C : cost matrix (array-like).
    pi_sot : SOT reference plan shaping the kernel pi_sot * exp(-C / eps).
    eps : entropic regularization strength.
    max_iter : maximum number of Sinkhorn iterations.
    tol : stop once the worst marginal violation is below this value
        (checked only on logging iterations).
    log_freq : record (iteration, cost, coupling) every `log_freq` iterations.
    device : torch device (string or torch.device). When None, CUDA is used
        if available, otherwise CPU. Previously the default was the literal
        "cuda", which crashed on CPU-only machines.

    Returns a list of (iteration, cost, coupling-as-numpy) tuples.
    """
    # Fall back gracefully when no device is specified, matching the
    # documented "runs on CUDA if available" behavior.
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    dev = torch.device(device)
    a_t = torch.as_tensor(a, dtype=torch.float64, device=dev)
    b_t = torch.as_tensor(b, dtype=torch.float64, device=dev)
    C_t = torch.as_tensor(C, dtype=torch.float64, device=dev)
    pi_sot_t = torch.as_tensor(pi_sot, dtype=torch.float64, device=dev)
    # Clamp keeps divisions finite even where the kernel underflows to zero.
    K = (pi_sot_t * torch.exp(-C_t / eps)).clamp_min(1e-300)
    u = torch.ones_like(a_t)
    v = torch.ones_like(b_t)
    log = []
    for it in range(1, max_iter + 1):
        u = a_t / (K @ v).clamp_min(1e-300)
        v = b_t / (K.t() @ u).clamp_min(1e-300)
        if it % log_freq == 0:
            # Reconstruct the full coupling only on logging iterations.
            pi = (u[:, None] * K) * v[None, :]
            cost = float((pi * C_t).sum().item())
            log.append((it, cost, pi.detach().cpu().numpy()))
            err = max(
                torch.max(torch.abs(pi.sum(dim=1) - a_t)).item(),
                torch.max(torch.abs(pi.sum(dim=0) - b_t)).item(),
            )
            if err < tol:
                break
    return log
python simulation/analysis_time_half_moon_vs_L.py
python simulation/analysis_time_half_moon_vs_n.py
python simulation/analysis_tv_vs_L_sot_init_sr_ot.py
python simulation/analysis_sr_ot_vs_L.py
python simulation/analysis_l1_vs_epsilon.py
python simulation/analysis_l1_vs_iterations.py
python simulation/plot_matching.py
python Color_Transfer/batch_transfer_all_pairs.py
python GradientFlow/main.py