# In [1]:
import numpy as np

# In [6]:
class SinkhornMethod:
    """Proximal Sinkhorn-style solver for a square transport problem.

    The solver looks for a plan ``X`` (n x n) whose row sums equal ``p`` and
    whose column sums equal ``q``, tracking the linear cost ``f = sum(C * X)``.
    Each outer iteration freezes the current plan as ``X_k``. It then
    alternately updates the dual variables ``l`` (one per row) and ``m``
    (one per column), and forms the plan

        X = X_k * exp(-(gamma + C + l[:, None] + m[None, :]) / gamma).

    ``fit`` sets the attributes ``X``, ``X_k``, ``l``, ``m``, ``f`` and ``fi``.

    Parameters
    ----------
    gamma : float
        Temperature that divides every exponent above (presumably positive).
    n : int, default 5
        Side length of the square cost matrix; also the length of ``p``/``q``.
    epsilon : float, default 1e-3
        The inner loop stops once the total L1 violation of the row- and
        column-sum constraints is at most this value.
    proxy_epsilon : float, default 1e-7
        The outer loop stops once ``f - fi`` is at most this value.
    """

    def __init__(self, gamma, n=5, epsilon=1e-3, proxy_epsilon=1e-7):
        # Dual variables for the row (l) and column (m) constraints.
        self.l = np.ones(n)
        self.m = np.ones(n)
        # Initial plan: uniform matrix whose entries sum to 1.
        self.X = 1/n**2 * np.ones((n, n))
        # Plan frozen at the start of each outer iteration; assigned in fit().
        self.X_k = 0

        #---------------------#
        # constants
        self.n = n
        self.epsilon = epsilon
        self.proxy_epsilon = proxy_epsilon
        self.gamma = gamma

    def _new_lag(self, C, p, q):
        """Perform one alternating update of the dual variables.

        ``l`` is recomputed from the current ``m`` so that the resulting plan
        has row sums ``p``. ``m`` is then recomputed from the updated ``l``
        so that the plan has column sums ``q``.
        """
        C = np.asarray(C, dtype=float)
        g = self.gamma
        # Row masses of X_k * exp(-(g + C + m_j) / g); l rescales them to p.
        row_mass = np.sum(self.X_k * np.exp(-(g + C + self.m[None, :]) / g), axis=1)
        self.l = g * np.log(row_mass / np.asarray(p, dtype=float))
        # Column masses use the freshly updated l; m rescales them to q.
        col_mass = np.sum(self.X_k * np.exp(-(g + C + self.l[:, None]) / g), axis=0)
        self.m = g * np.log(col_mass / np.asarray(q, dtype=float))

    def _new_X(self, C, p, q):
        """Recompute the plan from ``X_k`` and the current duals.

        ``p`` and ``q`` are accepted for signature symmetry but are unused.
        """
        g = self.gamma
        exponent = -(g + np.asarray(C, dtype=float) + self.l[:, None] + self.m[None, :]) / g
        self.X = self.X_k * np.exp(exponent)

    def _cond_error(self, p, q):
        """Return the total L1 violation of the row- and column-sum constraints."""
        # A 1-norm along an axis is the sum of absolute values, i.e. the
        # row/column sums of a nonnegative plan.
        row_err = np.abs(np.linalg.norm(self.X, 1, axis=1) - p)
        col_err = np.abs(np.linalg.norm(self.X, 1, axis=0) - q)
        return np.sum(row_err) + np.sum(col_err)

    def _new_fi(self, C, p, q):
        """Return ``l.p + m.q + gamma * sum(exp(-(C + l_i + m_j + gamma) / gamma))``.

        ``fit`` compares this value against ``f`` in its outer stopping test.

        NOTE(review): ``_new_lag`` and ``_new_X`` weight every exponential by
        ``X_k``, but this sum does not. If ``fi`` is meant to be the dual
        objective of the same subproblem, it probably needs ``self.X_k *``
        inside the sum. The sign of the ``f - fi`` test in ``fit`` should be
        rechecked alongside it. Left unchanged because the outer stopping
        rule depends on it -- TODO confirm against the derivation.
        """
        g = self.gamma
        exponent = -(np.asarray(C, dtype=float) + self.l[:, None] + self.m[None, :] + g) / g
        xx = np.sum(np.exp(exponent))
        return np.dot(self.l, p) + np.dot(self.m, q) + g * xx

    def _new_f(self, C):
        """Return the linear cost ``sum(C * X)`` of the current plan."""
        return np.sum(np.asarray(C, dtype=float) * self.X)

    def fit(self, C, p, q, *, max_inner_iter=None, max_outer_iter=None):
        """Run outer (proximal) and inner (dual-update) iterations.

        Parameters
        ----------
        C : array_like, shape (n, n)
            Cost matrix.
        p : array_like, shape (n,)
            Target row sums of the plan.
        q : array_like, shape (n,)
            Target column sums of the plan.
        max_inner_iter, max_outer_iter : int or None, keyword-only
            Optional caps on the number of inner sweeps per outer iteration
            and on the number of outer iterations. At least one iteration
            of each loop always runs. ``None`` (default) loops until the
            tolerances are met, which may never happen.

        Returns
        -------
        X : ndarray, shape (n, n)
            The final plan (also stored as ``self.X``).
        t : int
            Number of inner sweeps performed in the last outer iteration.
        tt : int
            Number of outer iterations performed.
        """
        C = np.asarray(C, dtype=float)
        p = np.asarray(p, dtype=float)
        q = np.asarray(q, dtype=float)
        tt = 0
        while True:
            # Freeze the current plan as the reference for this outer step.
            self.X_k = self.X.copy()
            tt += 1
            t = 0
            while True:
                t += 1
                self._new_lag(C, p, q)
                self._new_X(C, p, q)
                if self._cond_error(p, q) <= self.epsilon:
                    break
                if max_inner_iter is not None and t >= max_inner_iter:
                    break
            self.fi = self._new_fi(C, p, q)
            self.f = self._new_f(C)
            if self.f - self.fi <= self.proxy_epsilon:
                return self.X, t, tt
            if max_outer_iter is not None and tt >= max_outer_iter:
                return self.X, t, tt