Skip to content

khainb/SROT

Folders and files

Name
Last commit message
Last commit date

Latest commit

 

History

3 Commits
 
 
 
 
 
 
 
 

Repository files navigation

Sliced-Regularized Optimal Transport

Official PyTorch implementation of the paper: Sliced-Regularized Optimal Transport

Details of the model architecture and experimental results can be found in our paper.

@article{nguyen2026srot,
  title={Sliced-Regularized Optimal Transport},
  author={Khai Nguyen},
  journal={arXiv preprint arXiv:2604.23944},
  year={2026},
  pdf={https://arxiv.org/pdf/2604.23944.pdf}
}

Please cite our paper whenever this repository is used to help produce published results or is incorporated into other software.

This implementation was written by Khai Nguyen.

Implementation of SROT

import numpy as np 
import ot 
import torch

def build_sot_plan(X, Y, a, b, L=50, delta=1e-8, rng=None):
    """
    Uniform-average SOT reference plan from L random 1D projections.

    L directions are drawn from a standard normal and normalized to unit
    length. Both point clouds are projected onto each direction, and the exact
    1D optimal coupling between the projected measures (``ot.emd_1d``) is
    computed. The L couplings are averaged with equal weight.

    Returns (1 - delta) pi_SOT + delta pi_ind, where pi_ind = a outer b is the
    product (independent) coupling. When a and b each sum to 1, pi_ind has the
    same marginals as pi_SOT. For delta > 0, every entry is strictly positive
    when a, b > 0, which helps the support of the Sinkhorn kernel.

    Parameters
    ----------
    X : array_like, shape (n, d)
        Source support points.
    Y : array_like, shape (m, d)
        Target support points. Must have the same number of columns d as X.
    a : array_like, shape (n,)
        Source weights.
    b : array_like, shape (m,)
        Target weights.
    L : int
        Number of random projection directions. Must be positive.
    delta : float
        Mixing weight of the independent coupling, in [0, 1).
    rng : numpy.random.Generator, int, or None
        Source of the random directions.
        - None: uses ``np.random.default_rng(0)``, so the default is
          reproducible.
        - Object with ``standard_normal``: used as given.
        - Anything else: treated as a seed for ``np.random.default_rng``.

    Returns
    -------
    numpy.ndarray, shape (n, m), float64
        The averaged sliced coupling, mixed with a outer b when delta > 0.

    Raises
    ------
    ValueError
        If L <= 0, if delta is outside [0, 1), or if X and Y are not 2-D
        arrays with the same number of columns.
    """
    if L <= 0:
        raise ValueError("L must be a positive integer.")
    if not (0.0 <= delta < 1.0):
        raise ValueError("delta must satisfy 0 <= delta < 1.")

    X = np.asarray(X, dtype=np.float64)
    Y = np.asarray(Y, dtype=np.float64)
    if X.ndim != 2 or Y.ndim != 2 or X.shape[1] != Y.shape[1]:
        raise ValueError("X and Y must be 2-D arrays with the same number of columns.")

    # Do not use `rng or ...`: an int seed such as 5 has no `standard_normal`,
    # and the seed 0 is falsy. Pass generators through; treat other values as seeds.
    if rng is None:
        rng = np.random.default_rng(0)
    elif not hasattr(rng, "standard_normal"):
        rng = np.random.default_rng(rng)

    n, d = X.shape
    m = Y.shape[0]

    thetas = rng.standard_normal((L, d))
    # Floor the norms so a (practically impossible) all-zero draw cannot cause
    # a division by zero.
    theta_norms = np.maximum(np.linalg.norm(thetas, axis=1, keepdims=True), 1e-300)
    thetas /= theta_norms

    # Project every point onto every direction at once; column ell is direction ell.
    px_all = X @ thetas.T  # (n, L)
    py_all = Y @ thetas.T  # (m, L)

    pi_sot = np.zeros((n, m), dtype=np.float64)
    for ell in range(L):
        # Exact 1D OT coupling between the projected measures. POT's default
        # dense=True is assumed, so the result is a dense (n, m) array.
        pi_sot += ot.emd_1d(px_all[:, ell], py_all[:, ell], a, b)
    pi_sot *= 1.0 / L

    if delta > 0.0:
        # Mix in the independent coupling so no entry is exactly zero when a, b > 0.
        pi_ind = np.outer(a, b)
        pi_sot = (1.0 - delta) * pi_sot + delta * pi_ind
    return pi_sot
    
def sinkhorn_sot(a, b, C, pi_sot, eps, max_iter=2000, tol=1e-9, log_freq=1):
    """Sinkhorn with SOT-guided kernel pi_sot * exp(-C/eps).

    Runs Sinkhorn matrix scaling on K = pi_sot * exp(-C / eps). This regularizes
    toward the reference plan ``pi_sot`` rather than toward the product measure.

    Parameters
    ----------
    a : array_like, shape (n,)
        Target row marginal.
    b : array_like, shape (m,)
        Target column marginal.
    C : ndarray, shape (n, m)
        Cost matrix.
    pi_sot : ndarray, shape (n, m)
        Reference plan that multiplies the Gibbs kernel.
    eps : float
        Regularization strength. Must be > 0.
    max_iter : int
        Maximum number of scaling iterations.
    tol : float
        Stop at a logged iteration once the max-abs error of both marginals is
        below ``tol``.
    log_freq : int
        Log, and check convergence, every ``log_freq`` iterations. Must be >= 1.

    Returns
    -------
    list of (int, float, ndarray)
        One ``(iteration, cost, plan)`` entry per logged iteration. ``cost`` is
        sum(plan * C). The list is empty if ``max_iter < log_freq``.

    Raises
    ------
    ValueError
        If ``eps`` is not positive or ``log_freq`` < 1.
    """
    # eps <= 0 (or NaN) would silently make K all-NaN; log_freq == 0 would
    # divide by zero in the modulo below.
    if not eps > 0:
        raise ValueError("eps must be positive.")
    if log_freq < 1:
        raise ValueError("log_freq must be a positive integer.")

    # Clamp so that exp(-C/eps) underflow, or zeros in pi_sot, leave no exact
    # zero entries in K.
    K = np.maximum(pi_sot * np.exp(-C / eps), 1e-300)
    u = np.ones(len(a))
    v = np.ones(len(b))
    log = []

    for it in range(1, max_iter + 1):
        # Alternate the scalings: first fit the row sums to a, then the column sums to b.
        u = a / np.maximum(K @ v, 1e-300)
        v = b / np.maximum(K.T @ u, 1e-300)

        if it % log_freq == 0:
            pi = u[:, None] * K * v[None, :]
            log.append((it, float(np.sum(pi * C)), pi))
            err = max(
                np.max(np.abs(pi.sum(axis=1) - a)),
                np.max(np.abs(pi.sum(axis=0) - b)),
            )
            if err < tol:
                break
    return log
    
def sinkhorn_sot_torch(a, b, C, pi_sot, eps, max_iter=2000, tol=1e-9, log_freq=1, device=None):
    """Torch implementation of SOT-guided Sinkhorn (runs on CUDA if available).

    Same algorithm as the NumPy version. Computation uses float64 tensors on
    ``device``; each logged plan is copied back to the host as a NumPy array.

    Parameters
    ----------
    a, b : array_like
        Target row and column marginals, shapes (n,) and (m,).
    C : array_like, shape (n, m)
        Cost matrix.
    pi_sot : array_like, shape (n, m)
        Reference plan that multiplies the Gibbs kernel.
    eps : float
        Regularization strength. Must be > 0.
    max_iter, tol, log_freq
        Iteration cap, marginal-error stopping threshold, and logging interval
        (``log_freq`` >= 1). Convergence is checked only at logged iterations.
    device : str, torch.device, or None
        None (the default) selects "cuda" when ``torch.cuda.is_available()``
        and "cpu" otherwise. An explicit value is used unchanged.

    Returns
    -------
    list of (int, float, numpy.ndarray)
        One ``(iteration, cost, plan)`` entry per logged iteration.

    Raises
    ------
    ValueError
        If ``eps`` is not positive or ``log_freq`` < 1.
    """
    if not eps > 0:
        raise ValueError("eps must be positive.")
    if log_freq < 1:
        raise ValueError("log_freq must be a positive integer.")
    # The previous hard-coded "cuda" default crashed on CPU-only machines,
    # despite the docstring promising a fallback.
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    dev = torch.device(device)

    a_t, b_t, C_t, pi_sot_t = (
        torch.as_tensor(x, dtype=torch.float64, device=dev) for x in (a, b, C, pi_sot)
    )
    # Clamp so that exp underflow, or zeros in pi_sot, leave no exact zeros in K.
    K = (pi_sot_t * torch.exp(-C_t / eps)).clamp_min(1e-300)
    u = torch.ones_like(a_t)
    v = torch.ones_like(b_t)
    log = []
    for it in range(1, max_iter + 1):
        u = a_t / (K @ v).clamp_min(1e-300)
        v = b_t / (K.t() @ u).clamp_min(1e-300)
        if it % log_freq == 0:
            pi = (u[:, None] * K) * v[None, :]
            cost = (pi * C_t).sum().item()
            log.append((it, cost, pi.detach().cpu().numpy()))
            # Reduce both marginal errors on-device so there is only one host sync.
            err = torch.maximum(
                (pi.sum(dim=1) - a_t).abs().max(),
                (pi.sum(dim=0) - b_t).abs().max(),
            ).item()
            if err < tol:
                break
    return log

simulation

python simulation/analysis_time_half_moon_vs_L.py
python simulation/analysis_time_half_moon_vs_n.py
python simulation/analysis_tv_vs_L_sot_init_sr_ot.py
python simulation/analysis_sr_ot_vs_L.py
python simulation/analysis_l1_vs_epsilon.py
python simulation/analysis_l1_vs_iterations.py
python simulation/plot_matching.py

Color_Transfer

python Color_Transfer/batch_transfer_all_pairs.py

GradientFlow

python GradientFlow/main.py

About

No description, website, or topics provided.

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

 
 
 

Contributors

Languages