In [1]:
import numpy as np
import torch
from PIL import Image
import matplotlib.pyplot as plt
def sinkhorn_divergence_unbalanced(α, x, β, y, epsilon=0.01, rho=0.1, nits=200, p=2):
    """
    Entropic unbalanced optimal transport between two weighted point clouds.

    Runs `nits` Sinkhorn scaling iterations on the Gibbs kernel
    K = exp(-C / epsilon), where the marginal constraints are softened through
    the exponent tau = rho / (rho + epsilon), then evaluates the regularized
    objective on the resulting plan (reference cited by the original author:
    Feydy et al., 2019).

    Note: despite the function name, only the single regularized transport
    objective between (α, x) and (β, y) is evaluated here; no self-transport
    (debiasing) terms are subtracted.

    NOTE(review): the kernel is formed explicitly, so for small `epsilon`
    relative to the costs, exp(-C / epsilon) can underflow to zero; a
    log-domain formulation would be more stable in that regime.

    Args:
        α: 1-D tensor of source weights, one entry per row of `x`.
        x: source points, one point per row.
        β: 1-D tensor of target weights, one entry per row of `y`.
        y: target points, one point per row.
        epsilon: entropic regularization strength.
        rho: weight of the KL penalties on the two marginals.
        nits: number of scaling iterations.
        p: exponent applied to the Euclidean distance to build the cost.

    Returns:
        (loss, pi): the scalar objective
        cost + epsilon * KL(pi | α⊗β) + rho * (KL(pi·1 | α) + KL(piᵀ·1 | β))
        and the transport plan of shape (len(α), len(β)).
    """
    # Ground cost: Euclidean distance raised to the power p.
    C = torch.cdist(x, y, p=2).pow(p)
    K = torch.exp(-C / epsilon)
    Kt = K.t()  # hoisted: the transpose is loop-invariant

    # Scaling vectors start at one.
    u = torch.ones_like(α)
    v = torch.ones_like(β)

    # Exponent for marginal relaxation (tau -> 1 recovers balanced Sinkhorn).
    tau = rho / (rho + epsilon)

    for _ in range(nits):
        # The 1e-16 offsets guard against division by zero.
        u = (α / (K @ v + 1e-16)) ** tau
        v = (β / (Kt @ u + 1e-16)) ** tau

    # Transport plan diag(u) K diag(v), built by broadcasting so that no dense
    # n×n / m×m diagonal matrices (and the associated matmuls) are needed.
    pi = u[:, None] * K * v[None, :]

    # Transport cost <pi, C>.
    cost = torch.sum(pi * C)

    # Marginals of the plan.
    a_hat = pi.sum(dim=1)
    b_hat = pi.sum(dim=0)

    def _generalized_kl(mass, ref):
        # sum(mass * log(mass / ref) - mass + ref), with offsets to avoid log(0).
        return torch.sum(mass * (torch.log((mass + 1e-16) / (ref + 1e-16)) - 1.0) + ref)

    kl_a = _generalized_kl(a_hat, α)
    kl_b = _generalized_kl(b_hat, β)
    kl_pi = _generalized_kl(pi.reshape(-1), torch.outer(α, β).reshape(-1))

    loss = cost + epsilon * kl_pi + rho * (kl_a + kl_b)
    return loss, pi

In [None]:


# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)


def load_image(path, max_size=None):
    """Load an image as an RGB float32 array scaled to [0, 1].

    Args:
        path: location of the image file, handed to PIL's ``Image.open``.
        max_size: optional size bound passed to ``Image.thumbnail`` (with
            bilinear resampling); when falsy, the image keeps its size.

    Returns:
        A NumPy float32 array built from the RGB image, with the 8-bit
        channel values divided by 255.
    """
    # Open inside a context manager so the underlying file is closed
    # deterministically (Pillow may otherwise keep it open, e.g. for
    # multi-frame formats). convert() yields an independent in-memory image.
    with Image.open(path) as handle:
        img = handle.convert("RGB")
    if max_size:
        # thumbnail() resizes the image in place.
        img.thumbnail(max_size, Image.BILINEAR)
    img = np.array(img, dtype=np.float32) / 255.0
    return img


def image_to_measure_5d(img, device=device):
    """
    Turn an image into a uniform empirical measure on R^5: (x, y, r, g, b).

    Pixel positions are placed on a regular grid spanning [0, 1] along each
    axis (x along the width, y along the height) and concatenated with the
    pixel's three colour channels.

    Returns:
        (X, a, (H, W)): the float32 point tensor with one row per pixel, the
        weight tensor holding 1 / (number of pixels) for every point, and the
        image height and width.
    """
    height, width, _ = img.shape

    # Row-major grids: grid_y changes down the rows, grid_x across the columns.
    rows = np.linspace(0, 1, height)
    cols = np.linspace(0, 1, width)
    grid_y, grid_x = np.meshgrid(rows, cols, indexing="ij")

    # Columns are ordered (x, y, r, g, b).
    features = np.column_stack([grid_x.ravel(), grid_y.ravel(), img.reshape(-1, 3)])
    points = torch.tensor(features, dtype=torch.float32, device=device)

    n_points = points.shape[0]
    weights = torch.full((n_points,), 1.0 / n_points, device=device)
    return points, weights, (height, width)

In [None]:
# Load two images (possibly of different sizes); each is shrunk to fit the
# given max_size bound by load_image.
img_src = load_image("../data/cifar10/automobile/0000.jpg", max_size=(64, 64))
img_tgt = load_image("../data/pixelart/images/image_241.JPEG", max_size=(32, 32))

# Convert each image to a point cloud in R^5 (x, y, r, g, b) with uniform
# weights, keeping the image height/width for later reshaping.
X_src, a_src, (H_src, W_src) = image_to_measure_5d(img_src)
X_tgt, b_tgt, (H_tgt, W_tgt) = image_to_measure_5d(img_tgt)

# Report the (number of points, 5) shape of both clouds.
print("Source :", X_src.shape, "| Cible :", X_tgt.shape)

In [None]:
# Solver settings: entropic regularization, marginal-relaxation weight,
# and number of Sinkhorn iterations.
eps = 1e-3
rho = 1e-3
nits = 200

# NOTE(review): sinkhorn_divergence_unbalanced forms exp(-C / epsilon)
# explicitly; with eps this small the kernel may underflow in float32 —
# confirm the resulting plan is numerically meaningful.
loss, pi = sinkhorn_divergence_unbalanced(
    a_src, X_src, b_tgt, X_tgt, epsilon=eps, rho=rho, nits=nits, p=2
)

# Report the scalar objective and the (source points, target points) plan shape.
print(f"Unbalanced Sinkhorn loss: {loss.item():.6f}")
print("π shape:", pi.shape)

In [None]:
def transport_colors(X_src, X_tgt, pi, H_src, W_src):
    """Recolor the source pixels with plan-weighted averages of target colors.

    Every source point receives the average of the target RGB values
    (columns 2..4 of `X_tgt`), weighted by its row of `pi` and normalized by
    that row's total mass. `X_src` is accepted but not used.

    Returns:
        A NumPy array of shape (H_src, W_src, 3) with values clipped to [0, 1].
    """
    target_rgb = X_tgt[:, 2:5]

    # The small offset keeps the division defined for rows carrying no mass.
    row_mass = pi.sum(dim=1, keepdim=True) + 1e-16
    blended = torch.matmul(pi, target_rgb).div(row_mass)

    image = blended.detach().cpu().numpy().reshape(H_src, W_src, 3)
    return image.clip(0, 1)


if pi is not None:
    # Recolor the source image with the plan computed above.
    img_transported = transport_colors(X_src, X_tgt, pi, H_src, W_src)

    # Show source, target and recolored result side by side.
    plt.figure(figsize=(12, 4))
    panels = (
        (img_src, "Source"),
        (img_tgt, "Cible"),
        (img_transported, "Transportée (Sinkhorn explicite)"),
    )
    for slot, (picture, caption) in enumerate(panels, start=1):
        plt.subplot(1, 3, slot)
        plt.imshow(picture)
        plt.title(caption)

    # Hide ticks and frames on every panel.
    for ax in plt.gcf().axes:
        ax.axis("off")
    plt.tight_layout()
    plt.show()