In [1]:
%matplotlib inline

import numpy as np
import torch

import ot

# Work in float64 throughout (Sinkhorn scalings are sensitive to round-off).
# torch.set_default_tensor_type is deprecated since PyTorch 2.1.
torch.set_default_dtype(torch.float64)

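## Sinkhorn convolutional unbalanced barycenters (in torch)

Entropic unbalanced Wasserstein barycenters of 2D histograms, computed with Sinkhorn iterations that exploit the separable (per-axis) structure of the squared-Euclidean cost on a pixel grid. At the end of the notebook, the result is compared with POT's `ot.barycenter_unbalanced`.

TODO:
1. Track the convergence error, and stop once it falls below the threshold `stopThr`.
2. Add the balanced barycenter.
3. Implement barycenter projection (gradient descent on the barycentric weights).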

In [2]:

def sep_kernel_prod(u, Cx, Cy, logspace=True):
    """
        2D batch kernel/vector products, with a separable kernel.

        For every histogram index h this computes Kx @ u[:, :, h] @ Ky.T, i.e.
        the product of u_h with the 2D kernel K[(i, j), (a, b)] = Kx[i, a] * Ky[j, b].

        Parameters
        ----------
        u : torch.Tensor, shape (dimx, dimy, n_hists)
            Batch of 2D fields (log-fields when ``logspace`` is True).
        Cx, Cy : torch.Tensor, shapes (dimx, dimx) and (dimy, dimy)
            If ``logspace``: scaled costs C / reg, so the kernels are exp(-Cx)
            and exp(-Cy) and the product is evaluated stably with logsumexp.
            Otherwise: the kernels Kx and Ky themselves.
        logspace : bool
            Select the log-domain or the linear-domain computation.

        Returns
        -------
        torch.Tensor, shape (dimx, dimy, n_hists)
            log(exp(-Cx) @ exp(u_h) @ exp(-Cy).T) if ``logspace``,
            else Cx @ u_h @ Cy.T, stacked over h.
    """
    if logspace:
        # x[a, j, k, h] = -Cy[j, k] + u[a, k, h]; reduce over k (the y axis).
        x = -Cy[:, :, None] + u[:, None, :, :]
        R = torch.logsumexp(x, dim=2)
        # Same reduction along the x axis, working on R transposed to (dimy, dimx, n).
        y = -Cx[:, :, None] + R.permute(1, 0, 2)[:, None, :, :]
        return torch.logsumexp(y, dim=2).permute(1, 0, 2)
    else:
        # Explicit contractions keep every axis: unlike a bare .squeeze(), they
        # do not drop legitimate size-1 axes (n_hists == 1, 1-pixel-wide images).
        R = torch.einsum('jk,akh->ajh', Cy, u)
        return torch.einsum('ia,ajh->ijh', Cx, R)

    
def sk_barycenter_2D(A, Cx, Cy, reg, reg_m, weights=None, numItermax=1000,
                     stopThr=1e-6, logspace=True, balanced=False):
    """
        Entropic (un)balanced Wasserstein barycenter of 2D histograms, computed
        with separable (convolutional) Sinkhorn iterations.

        Parameters
        ----------
        A : torch.Tensor, shape (dimx, dimy, n_hists)
            Non-negative input histograms stacked along the last axis.
        Cx, Cy : torch.Tensor, shapes (dimx, dimx) and (dimy, dimy)
            Per-axis ground costs; the 2D cost is Cx[i, a] + Cy[j, b].
        reg : float
            Entropic regularization.
        reg_m : float
            Marginal relaxation (KL penalty) of the unbalanced problem.
            Ignored when ``balanced`` is True.
        weights : array-like of length n_hists, optional
            Barycentric weights (uniform if None). Presumably meant to sum to 1;
            this is not checked.
        numItermax : int
            Maximum number of Sinkhorn iterations (must be >= 1).
        stopThr : float
            Stop once the relative sup-norm change of the barycenter between
            two consecutive iterations is below this value.
        logspace : bool
            Iterate on log-scalings (numerically stable for small ``reg``).
        balanced : bool
            If True, compute the balanced barycenter (the reg_m -> inf limit):
            the weighted power mean becomes a weighted geometric mean. The
            inputs should then all have the same total mass.

        Returns
        -------
        torch.Tensor, shape (dimx, dimy)
            The barycenter histogram.
    """
    if numItermax < 1:
        raise ValueError("numItermax must be >= 1")

    dimx, dimy, n_hists = A.shape

    # In log space the "kernels" are the scaled costs; sep_kernel_prod negates them.
    if logspace:
        Kx = Cx / reg
        Ky = Cy / reg
    else:
        Kx = torch.exp(-Cx / reg)
        Ky = torch.exp(-Cy / reg)

    if weights is None:
        weights = torch.ones(n_hists, dtype=A.dtype, device=A.device) / n_hists
    else:
        # Accept lists / numpy arrays as well as tensors.
        weights = torch.as_tensor(weights, dtype=A.dtype, device=A.device)
        assert len(weights) == n_hists, "need exactly one weight per histogram"

    # Exponent of the scaling updates; fi = 1 is the balanced (reg_m -> inf) limit.
    fi = 1. if balanced else reg_m / (reg + reg_m)

    # Loop-invariant quantities.
    if logspace:
        logA = torch.log(A)
        log_w = weights.log()
        v = torch.zeros_like(A)
    else:
        v = torch.ones_like(A)

    bary_prev = None
    for _ in range(numItermax):

        ### Update u (scalings matched to the input histograms)
        Kv = sep_kernel_prod(v, Kx, Ky, logspace)
        if logspace:
            u = fi * (logA - Kv)
        else:
            u = (A / Kv) ** fi

        ### Form barycenter
        Ktu = sep_kernel_prod(u, Kx, Ky, logspace)
        if balanced:
            # Weighted geometric mean of the K^T u_h.
            if logspace:
                q = torch.sum(weights * Ktu, dim=2)
            else:
                q = torch.exp(torch.sum(weights * torch.log(Ktu), dim=2))
        elif logspace:
            # Weighted power mean with exponent (1 - fi), evaluated stably.
            q = (1 / (1 - fi)) * torch.logsumexp((1 - fi) * Ktu + log_w, dim=2)
        else:
            q = torch.sum(weights * Ktu ** (1 - fi), dim=2) ** (1 / (1. - fi))

        ### Update v
        if logspace:
            v = fi * (q[:, :, None] - Ktu)
        else:
            v = (q[:, :, None] / Ktu) ** fi

        ### Convergence check on the barycenter itself (linear scale)
        bary = torch.exp(q) if logspace else q
        if bary_prev is not None:
            scale = max(bary.abs().max().item(), bary_prev.abs().max().item())
            err = (bary - bary_prev).abs().max().item() / scale if scale > 0 else 0.
            if err < stopThr:
                break
        bary_prev = bary

    return bary

            
            

## Test: compare the torch implementation with POT

In [3]:
## Generate data (images)
# Fixed seed so that Restart & Run All reproduces the outputs below.
torch.manual_seed(0)
# n_hists = 4 random 3x3 histograms with entries in [0, 0.5)
A = torch.rand(3, 3, 4) / 2


In [4]:
## Build distance matrices
# Squared-Euclidean cost on the pixel grid is separable: the cost between
# pixels (i, j) and (a, b) is Cx[i, a] + Cy[j, b].
dimx, dimy = A.shape[:2]

x_coords = torch.arange(dimx).double()
y_coords = torch.arange(dimy).double()

Cx = (x_coords[:, None] - x_coords) ** 2
Cy = (y_coords[:, None] - y_coords) ** 2

# Dense (dimx*dimy, dimx*dimy) cost for POT, built from the same per-axis costs.
# Pixels are flattened row-major (index i * dimy + j) to match A.reshape(dimx * dimy, -1).
M = (Cx[:, None, :, None] + Cy[None, :, None, :]).reshape(dimx * dimy, dimx * dimy).numpy()

# Normalise every cost by the same constant so both formulations stay consistent.
m = float(M.max())
Cx = Cx / m
Cy = Cy / m
M = M / m

reg = 0.1     # entropic regularisation
reg_m = 0.5   # marginal relaxation (KL penalty) of the unbalanced problem


In [5]:
## Barycenter with log-domain Sinkhorn iterations
sk_barycenter_2D(A, Cx, Cy, reg, reg_m,
                 numItermax=200, logspace=True)

tensor([[0.3041, 0.3375, 0.2776],
        [0.3276, 0.4086, 0.3002],
        [0.3284, 0.4307, 0.2495]])

In [6]:
## Barycenter with linear-domain (kernel) Sinkhorn iterations
sk_barycenter_2D(A, Cx, Cy, reg, reg_m,
                 numItermax=200, logspace=False)

tensor([[0.3041, 0.3375, 0.2776],
        [0.3276, 0.4086, 0.3002],
        [0.3284, 0.4307, 0.2495]])

In [7]:
## Reference: POT's dense unbalanced barycenter (uniform weights by default)
bary_pot = ot.barycenter_unbalanced(A.reshape(dimx * dimy, -1).numpy(), M, reg, reg_m,
                                    method='sinkhorn', stopThr=1e-12).reshape(dimx, dimy)

# Check programmatically (rather than by eye) that both torch code paths agree with POT.
for use_log in (True, False):
    bary_torch = sk_barycenter_2D(A, Cx, Cy, reg, reg_m, logspace=use_log, numItermax=200)
    np.testing.assert_allclose(bary_torch.numpy(), bary_pot, rtol=1e-4)

bary_pot

array([[0.30405808, 0.33746302, 0.27757692],
       [0.32755265, 0.40861036, 0.3001976 ],
       [0.32841249, 0.43068563, 0.24945079]])