# In [1]:
import ot

# In [19]:
import torch
import ot

def compute_pairwise_distances(x, y):
    '''
    Squared Euclidean distances between every row of ``x`` and every row of ``y``.

    Args:
        x: 2-D tensor of shape (n, d).
        y: 2-D tensor of shape (m, d).
    Returns:
        Tensor of shape (m, n) whose entry ``[j, i]`` is ``||x[i] - y[j]||^2``.
        Note the transposed layout (rows index ``y``) and that the values are
        squared distances, not plain Euclidean distances.
    Raises:
        ValueError: if either input is not 2-D, or the feature counts differ.
    '''
    if x.dim() != 2 or y.dim() != 2:
        raise ValueError('Both inputs should be matrices.')
    if x.shape[1] != y.shape[1]:
        raise ValueError('The number of features should be the same.')

    # Broadcast (n, d, 1) against (d, m) -> (n, d, m); memory grows as n * d * m.
    diffs = x.unsqueeze(2) - y.t()
    squared = diffs.pow(2).sum(dim=1)  # reduce the feature axis -> (n, m)
    return squared.t()

def optimal_transport_distance(x, y, reg=0.1, *, metric='euclidean', method='sinkhorn'):
    '''
    Calculate the entropic-regularized OT (Sinkhorn) cost between two sample sets.

    Each row of ``x`` and ``y`` is one sample with uniform weight. The ground
    cost matrix is ``ot.dist(x, y, metric=metric)``. The cost is computed by
    ``ot.sinkhorn2``. The computation runs in NumPy on detached CPU copies of
    the inputs, so the returned tensor carries NO gradient back to ``x`` or ``y``.

    Args:
        x: source domain embeddings, 2-D tensor of shape (n, d)
        y: target domain embeddings, 2-D tensor of shape (m, d)
        reg: entropic regularization strength for the Sinkhorn algorithm.
            Values that are small relative to the cost-matrix entries can
            underflow with the plain ``'sinkhorn'`` solver. A stabilized or
            log-domain solver is then safer; availability depends on the
            installed POT version.
        metric: ground metric forwarded to ``ot.dist`` (default: Euclidean)
        method: Sinkhorn solver variant forwarded to ``ot.sinkhorn2``
    Returns:
        0-d float32 tensor on ``x``'s device holding the Sinkhorn cost
    Raises:
        ValueError: if an input is not 2-D, is empty, or the feature counts differ
    '''
    if x.dim() != 2 or y.dim() != 2:
        raise ValueError('Both inputs should be matrices.')
    if x.size(1) != y.size(1):
        raise ValueError('The number of features should be the same.')
    if x.size(0) == 0 or y.size(0) == 0:
        # ot.unif(0) would divide by zero and the solver has nothing to transport.
        raise ValueError('Both inputs must contain at least one sample.')

    # POT's NumPy path needs host arrays; detach() means no gradient flows back.
    x_np = x.detach().cpu().numpy()
    y_np = y.detach().cpu().numpy()

    # Ground-cost matrix M[i, j] = metric(x_i, y_j), shape (n, m).
    M = ot.dist(x_np, y_np, metric=metric)

    # Uniform weights for source and target distributions.
    a = ot.unif(x_np.shape[0])
    b = ot.unif(y_np.shape[0])

    sinkhorn_distance = ot.sinkhorn2(a, b, M, reg, method=method)

    # Depending on the POT version, sinkhorn2 returns a scalar or a 1-element
    # array; reshape(()) yields a 0-d tensor either way. Placing it on x's
    # device lets callers combine it with other (possibly GPU) loss terms.
    return torch.as_tensor(
        sinkhorn_distance, dtype=torch.float32, device=x.device
    ).reshape(())

def ot_distance(hs, ht, reg=0.1):
    '''
    Sinkhorn OT cost between source and target embeddings, floored at 1e-4.

    Thin wrapper around ``optimal_transport_distance``. The floor keeps the
    value strictly positive. NOTE(review): presumably this protects a
    downstream log/division; confirm with callers.

    Args:
        hs: source domain embeddings
        ht: target domain embeddings
        reg: regularization parameter for the Sinkhorn algorithm
    Returns:
        The Sinkhorn cost tensor with every element clamped to at least 1e-4
    '''
    raw_cost = optimal_transport_distance(hs, ht, reg)
    floored = raw_cost.clamp(min=1e-4)
    return floored

# Example usage. Guarded so that importing this module does not run it;
# scripts and notebook kernels execute with __name__ == '__main__', so it
# still runs there.
if __name__ == '__main__':
    torch.manual_seed(0)  # make the random example reproducible
    hs = torch.randn(100, 50)  # example source embeddings: 100 samples, 50 features
    ht = torch.randn(100, 50)  # example target embeddings: 100 samples, 50 features
    distance = ot_distance(hs, ht, reg=0.1)
    print('Sinkhorn Distance:', distance)


# Output (from the original, unseeded notebook run): Sinkhorn Distance: tensor(8.1880)
