# Sinkhorn loss module


Implements `SinkhornLoss` module for

$loss = H[target, P(M)]$ where $P(M)$ is the solution of the regularised OT problem with affinity matrix M.

It overrides the backward pass in the associated `SinkhornLossFunc`.

Drawback of this approach: we do not have access to P.

**Notation**

OT convention for entropy
$H[\pi] = -\langle \pi, \log \pi -1 \rangle$

Usual definition for cross entropy
$H[\sigma, \pi] = - \langle \sigma, \log \pi \rangle$

$\text{loss} = H[\sigma, \pi] = \frac{1}{\epsilon} [L_\epsilon(M) - \langle \sigma , M\rangle ] - \langle \sigma, 1\rangle$

$\nabla_M \text{loss} = \frac{1}{\epsilon} [\pi - \sigma ]$



**TODO**

- add convergence criterion in sinkhorn (tol)
- maybe use external solver (python OT)
- give access to P in module

## Implementation

In [1]:
import torch
from torch.autograd import Function
import torch.nn as nn

def sinkhorn(M, a, b, epsilon, n_iter):
    r"""Basic Sinkhorn algorithm.

    Solves the regularised OT problem:
        max <M, P> + \epsilon H[P] s.t. \sum_j P_ij = a_i and \sum_i P_ij = b_j
    with entropy H[P] = - \sum_ij P_ij [log(P_ij) - 1]

    Args:
        M (torch.Tensor): affinity matrix (2-D, indexed as M[i, j])
        a (torch.Tensor): user capacities (prescribed row sums of P)
        b (torch.Tensor): item capacities (prescribed column sums of P)
        epsilon (float): regularisation parameter
        n_iter (int): number of Sinkhorn iterations, at least 1

    Returns:
        P (torch.Tensor): coupling matrix

    Raises:
        ValueError: if ``n_iter`` is smaller than 1 (the scalings would never
            be computed).
    """
    if n_iter < 1:
        raise ValueError(f"n_iter must be at least 1, got {n_iter}")

    # Subtracting the largest affinity only rescales the kernel by a constant,
    # which the scaling vectors absorb, so P is unchanged mathematically.
    # It keeps exp() from overflowing to inf (and then nan) when M / epsilon
    # is large.
    K = torch.exp((M - M.max()) / epsilon)
    K_t = K.T  # loop-invariant, hoisted out of the iteration

    v = torch.ones_like(b)
    for _ in range(n_iter):
        # Alternately rescale rows to match a and columns to match b.
        u = a / torch.matmul(K, v)
        v = b / torch.matmul(K_t, u)

    # P = diag(u) K diag(v)
    return K * torch.outer(u, v)


class SinkhornLossFunc(Function):
    """Cross entropy between a target and the coupling returned by an OT solver.

    The solver is called on a detached copy of M, so autograd never traces
    through it; the gradient with respect to M is supplied by hand in
    `backward` as (P - target) / epsilon.
    """

    @staticmethod
    def forward(ctx, M, target, a, b, epsilon, solver, solver_options):
        """Return -<target, log P> with P = solver(M, a, b, epsilon, **solver_options).

        Args:
            ctx: autograd context used to stash the gradient for `backward`
            M (torch.Tensor): affinity matrix
            target (torch.Tensor): target coupling
            a: first marginal, forwarded to the solver
            b: second marginal, forwarded to the solver
            epsilon (float): regularisation parameter
            solver (function): OT solver
            solver_options (dict): keyword options forwarded to the solver

        Returns:
            torch.Tensor: scalar cross entropy
        """
        P = solver(M.detach(), a, b, epsilon, **solver_options)
        # xlogy defines 0 * log(0) as 0, so entries of P that underflow to 0
        # where target is also 0 no longer turn the whole loss into nan.
        cross_entropy = -torch.xlogy(target, P).sum()
        # The gradient only needs (P - target) / epsilon, so store that
        # instead of P and target separately.
        ctx.save_for_backward((P - target) / epsilon)

        return cross_entropy

    @staticmethod
    def backward(ctx, grad_output):
        """Return grad_output * (P - target) / epsilon for M, None for other inputs."""
        delta_P, = ctx.saved_tensors
        grad_M = delta_P * grad_output

        # One entry per forward input; only M is differentiable.
        return grad_M, None, None, None, None, None, None


class SinkhornLoss(nn.Module):
    """Sinkhorn loss.

    Computes loss = H[target, P(M)] where P(M) is the solution of the regularised OT problem
    with affinity matrix M.

    Args:
        a (torch.Tensor): user capacities
        b (torch.Tensor): item capacities
        epsilon (float): regularisation parameter
        solver (function): OT solver
        **solver_options: keyword options to pass to the solver
    """
    def __init__(self, a, b, epsilon, solver, **solver_options):
        super().__init__()
        # Plain tensors are registered as non-persistent buffers: they then
        # follow the module through .to() / .double() / .cuda() while staying
        # out of state_dict(). Parameters and non-tensor values are stored
        # exactly as given.
        for name, value in (("a", a), ("b", b)):
            if isinstance(value, torch.Tensor) and not isinstance(value, nn.Parameter):
                self.register_buffer(name, value, persistent=False)
            else:
                setattr(self, name, value)
        self.epsilon = epsilon
        self.solver = solver
        self.solver_options = solver_options

    def forward(self, M, target):
        """Return the loss for affinity matrix M against the target coupling."""
        return SinkhornLossFunc.apply(M, target, self.a, self.b, self.epsilon, self.solver, self.solver_options)

    def extra_repr(self):
        """Describe the marginals, epsilon and solver in the module repr."""
        return (
            f"a={self.a},\nb={self.b},\n"
            f"epsilon={self.epsilon:.2e}, solver={self.solver}, solver_options={self.solver_options}"
        )

## Test

In [2]:
# Seed the RNG so the random draws below are reproducible.
torch.manual_seed(42)

# Random affinity matrix; it tracks gradients so d(loss)/dM can be inspected.
M = torch.rand(3, 6, dtype=torch.double, requires_grad=True)

# Random target coupling with the same shape as M; its row and column sums
# are used as the marginals a and b.
target = torch.rand(M.shape, dtype=torch.double)
a = target.sum(dim=1)
b = target.sum(dim=0)

In [3]:
# Loss module using the basic sinkhorn solver with 100 iterations.
SL = SinkhornLoss(a=a, b=b, epsilon=0.01, solver=sinkhorn, n_iter=100)
SL

SinkhornLoss(
  a=tensor([2.4401, 1.5512, 2.2407], dtype=torch.float64),
  b=tensor([0.9066, 0.7111, 1.2323, 1.1531, 1.1917, 1.0373], dtype=torch.float64),
  epsilon=1.00e-02, solver=<function sinkhorn at 0x13cba4820>, solver_options={'n_iter': 100}
)

In [4]:
# Drop any gradient already stored on M so the value displayed below comes
# from this backward pass only.
M.grad = None
loss = SL(M, target)
loss.backward()
M.grad

tensor([[-62.3403, -43.7279,  26.2985, -68.3430,  87.9453,  60.2457],
        [ 87.4850,  60.3409, -93.1909, -15.2126, -26.5045, -13.0433],
        [-25.1447, -16.6130,  66.8924,  83.5556, -61.4408, -47.2023]],
       dtype=torch.float64)

## Gradient checking

In [5]:
# alias
sinkhorn_loss = SinkhornLossFunc.apply
# gradcheck do not pass for too small epsilon
from torch.autograd import gradcheck
import numpy as np
for epsilon in np.linspace(0.1,0,11):
    inputs = (M, target, a, b, epsilon, sinkhorn, dict(n_iter=100))
    try:
        gradcheck(sinkhorn_loss, inputs)
    except:
        print(f"gradcheck error for epsilon={epsilon:.2f}")
        break
    print(f"gradcheck pass for epsilon={epsilon:.2f}")

gradcheck pass for epsilon=0.10
gradcheck pass for epsilon=0.09
gradcheck pass for epsilon=0.08
gradcheck pass for epsilon=0.07
gradcheck pass for epsilon=0.06
gradcheck pass for epsilon=0.05
gradcheck error for epsilon=0.04


# Sinkhorn value module


Implements `SinkhornValue` module that takes the affinity matrix $M$ as input and returns the value function of the regularised OT problem:

$L(M) = L_M^\epsilon(a,b) = \max_{\pi \in U(a,b)} \langle \pi, M \rangle + \epsilon H[\pi]$ 

The prescribed marginals $a$ and $b$ are given parameters, 
we take the OT convention for entropy

$H[\pi] = -\langle \pi, \log \pi -1 \rangle$

It overrides the backward pass in the associated `SinkhornValueFunc`.

The gradient wrt $M$ is simply (see Prop 9.2 in Computational OT, Peyré and Cuturi):

$\nabla_M L_M^\epsilon(a,b) = \pi^*$

where $\pi^*$ is the optimal coupling:

$\pi^* = \arg\max_{\pi \in U(a,b)} \langle \pi, M \rangle + \epsilon H[\pi]$ 

In [6]:
class SinkhornValueFunc(Function):
    """Value <P, M> + epsilon * H[P] of the regularised OT problem.

    P comes from the solver applied to a detached M; `backward` returns P
    (scaled by the incoming gradient) as the gradient with respect to M.
    """

    @staticmethod
    def forward(ctx, M, a, b, epsilon, solver, solver_options):
        """Solve for the coupling, keep it for backward, and return the value."""
        coupling = solver(M.detach(), a, b, epsilon, **solver_options)
        ctx.save_for_backward(coupling)
        # log(0) = -inf is floored at -100 so that 0 * log(0) gives 0, not nan
        floored_log = coupling.log().clamp(min=-100)
        entropy = (coupling * (1 - floored_log)).sum()
        transport = (coupling * M).sum()
        return transport + epsilon * entropy

    @staticmethod
    def backward(ctx, grad_output):
        """Return P * grad_output for M and None for the five other inputs."""
        (coupling,) = ctx.saved_tensors
        return (coupling * grad_output,) + (None,) * 5


class SinkhornValue(nn.Module):
    r"""Sinkhorn value.

    Returns optimal value for the regularised OT problem:
        L(M) = max <M, P> + \epsilon H[P] s.t. \sum_j P_ij = a_i and \sum_i P_ij = b_j
    with entropy H[P] = - \sum_ij P_ij [log(P_ij) - 1]

    Args:
        a (torch.Tensor): user capacities
        b (torch.Tensor): item capacities
        epsilon (float): regularisation parameter
        solver (function): OT solver
        **solver_options: keyword options to pass to the solver
    """
    def __init__(self, a, b, epsilon, solver, **solver_options):
        super().__init__()
        # Plain tensors are registered as non-persistent buffers: they then
        # follow the module through .to() / .double() / .cuda() while staying
        # out of state_dict(). Parameters and non-tensor values are stored
        # exactly as given.
        for name, value in (("a", a), ("b", b)):
            if isinstance(value, torch.Tensor) and not isinstance(value, nn.Parameter):
                self.register_buffer(name, value, persistent=False)
            else:
                setattr(self, name, value)
        self.epsilon = epsilon
        self.solver = solver
        self.solver_options = solver_options

    def forward(self, M):
        """Return the regularised OT value for affinity matrix M."""
        return SinkhornValueFunc.apply(M, self.a, self.b, self.epsilon, self.solver, self.solver_options)

    def extra_repr(self):
        """Describe the marginals, epsilon and solver in the module repr."""
        return (
            f"a={self.a},\nb={self.b},\n"
            f"epsilon={self.epsilon:.2e}, solver={self.solver}, solver_options={self.solver_options}"
        )

## Test

In [7]:
# Show the shapes of the two marginals and of the affinity matrix.
tuple(t.shape for t in (a, b, M))

(torch.Size([3]), torch.Size([6]), torch.Size([3, 6]))

In [8]:
# Value module using the basic sinkhorn solver with 100 iterations.
SV = SinkhornValue(a=a, b=b, epsilon=0.01, solver=sinkhorn, n_iter=100)
SV

SinkhornValue(
  a=tensor([2.4401, 1.5512, 2.2407], dtype=torch.float64),
  b=tensor([0.9066, 0.7111, 1.2323, 1.1531, 1.1917, 1.0373], dtype=torch.float64),
  epsilon=1.00e-02, solver=<function sinkhorn at 0x13cba4820>, solver_options={'n_iter': 100}
)

In [9]:
# Drop any gradient already stored on M so the value displayed below comes
# from this backward pass only.
M.grad = None
loss = SV(M)
loss.backward()
M.grad

tensor([[1.4196e-27, 2.2927e-26, 3.3772e-01, 2.2559e-14, 1.1917e+00, 9.1152e-01],
        [9.0619e-01, 6.4378e-01, 4.3136e-15, 4.7384e-23, 7.7781e-29, 2.5152e-17],
        [4.0832e-04, 6.7278e-02, 8.9459e-01, 1.1531e+00, 2.7904e-30, 1.2575e-01]],
       dtype=torch.float64)

In [10]:
# Display the current value of `loss`.
loss

tensor(3.8302, dtype=torch.float64, grad_fn=<SinkhornValueFuncBackward>)

## Gradient checking

In [11]:
# alias
sinkhorn_value = SinkhornValueFunc.apply
# gradcheck do not pass for too small epsilon
from torch.autograd import gradcheck
import numpy as np
for epsilon in np.linspace(0.1,0,11):
    inputs = (M, a, b, epsilon, sinkhorn, dict(n_iter=100))
    try:
        gradcheck(sinkhorn_value, inputs)
    except:
        print(f"gradcheck error for epsilon={epsilon:.2f}")
        break
    print(f"gradcheck pass for epsilon={epsilon:.2f}")

gradcheck pass for epsilon=0.10
gradcheck pass for epsilon=0.09
gradcheck pass for epsilon=0.08
gradcheck pass for epsilon=0.07
gradcheck pass for epsilon=0.06
gradcheck error for epsilon=0.05
