In [1]:
import torch

def sinkhorn_loss(x, y, epsilon=0.1, max_iters=100):
    """
    Calculate the Sinkhorn loss between two sets of vectors, x and y.
    :param x: torch.Tensor, shape=(n, d)
    :param y: torch.Tensor, shape=(m, d)
    :param epsilon: regularization parameter
    :param max_iters: maximum number of Sinkhorn iterations
    :return: Optimal Transport distance
    """
    
    # Calculate pairwise distances
    C = torch.cdist(x, y, p=2)  # Cost matrix
    
    # Initialize variables
    n, m = x.shape[0], y.shape[0]
    a, b = torch.ones(n) / n, torch.ones(m) / m  # Uniform distribution over bins
    
    # Sinkhorn iterations
    u = torch.zeros(n)
    v = torch.zeros(m)
    K = torch.exp(-C / epsilon)
    
    for _ in range(max_iters):
        u = torch.log(a / torch.matmul(K, torch.exp(v)))  # Update u
        v = torch.log(b / torch.matmul(K.t(), torch.exp(u)))  # Update v
    
    # Compute Sinkhorn loss
    loss = torch.sum(torch.exp(u) * torch.matmul(K, torch.exp(v)) * C)
    return loss

Sinkhorn loss (distance) between x and y: 220.2919464111328


In [6]:
import torch

def sinkhorn_normalized(x, y, epsilon=0.1, max_iters=100):
    """
    Perform Sinkhorn iterations to compute the normalized transport plan.
    :param x: torch.Tensor, shape=(n, d)
    :param y: torch.Tensor, shape=(m, d)
    :param epsilon: regularization parameter
    :param max_iters: maximum number of Sinkhorn iterations
    :return: torch.Tensor, shape=(n, m), the normalized transport plan
    """
    
    # Calculate pairwise distances
    C = torch.cdist(x, y, p=2)  # Cost matrix
    
    # Initialize variables
    n, m = x.shape[0], y.shape[0]
    a, b = torch.ones(n) / n, torch.ones(m) / m  # Uniform distribution over bins
    
    # Sinkhorn iterations
    u = torch.zeros(n)
    v = torch.zeros(m)
    K = torch.exp(-C / epsilon)
    
    for _ in range(max_iters):
        u = torch.log(a / torch.matmul(K, torch.exp(v)))  # Update u
        v = torch.log(b / torch.matmul(K.t(), torch.exp(u)))  # Update v
    
    # Compute transport plan
    T = torch.exp(u[:, None] + v[None, :] - C / epsilon)
    
    return T

In [8]:
torch.manual_seed(0)  # for reproducibility

n, m, d = 100, 100, 10  # Number of samples and dimensions

# Generate random binary vectors
x = torch.randint(2, (n, d)).float()
y = torch.randint(2, (m, d)).float()

T = sinkhorn_normalized(x, y, epsilon=0.1, max_iters=100)

# Sample from T to obtain correspondences between x and y
sampled_indices = torch.multinomial(T[0], 5, replacement=True)
sampled_y = y[sampled_indices]

print(f"Sampled y corresponding to the first x: {sampled_y}")

Sampled y corresponding to the first x: tensor([[0., 0., 1., 0., 1., 1., 1., 1., 1., 1.],
        [0., 1., 1., 0., 1., 1., 1., 0., 1., 1.],
        [0., 0., 1., 0., 1., 1., 1., 1., 1., 1.],
        [0., 1., 1., 0., 1., 1., 1., 0., 1., 1.],
        [0., 0., 1., 0., 1., 1., 1., 1., 1., 1.]])


In [2]:
torch.manual_seed(0)  # for reproducibility

n, m, d = 100, 100, 10  # Number of samples and dimensions

# Generate random binary vectors
x = torch.randint(2, (n, d)).float()
y = torch.randint(2, (m, d)).float()

loss = sinkhorn_loss(x, y, epsilon=0.1, max_iters=100)
print(f"Sinkhorn loss (distance) between x and y: {loss.item()}")

Sinkhorn loss (distance) between x and y: 220.2919464111328
