In [37]:
import numpy as np
np.seterr(all='raise')

{'divide': 'raise', 'invalid': 'raise', 'over': 'raise', 'under': 'raise'}

In [185]:
class FastGradMethod:
    def __init__(self, gamma, n, epsilon):
        self.alpha = 0
        self.a     = 0
        
        # constants
        self.epsilon = epsilon
        self.gamma = gamma
        self.l     = 1 / gamma
        self.n     = n
        
        self.u_lambda = np.ones([n])
        self.u_mu     = np.ones([n])
         
        self.y_lambda = np.zeros(n)
        self.y_mu     = np.zeros(n)
        
        self.x_lambda   = np.ones([n])
        self.x_mu       = np.ones([n])
        
        self.alpha_x_sum = 0
        
        self.x_0 = np.ones([n, n]) / (n**2)
        
    def x_hat(self, c):
        n = self.n
        b = self.x_0 * np.exp(
            -(self.gamma + c + self.y_lambda.repeat(n).reshape(-1, n) + self.y_mu.repeat(n).reshape(-1, n).T) / self.gamma
        )  
        return b / b.sum()
    
    def x_wave(self, c):
        return self.alpha_x_sum * 1 / self.a
    
    def __new_alpha(self):
        return 1 / (2 * self.l) + np.sqrt(1 / (4 * (self.l**2)) + self.alpha**2)
    
    def __new_a(self, new_alpha):
        return self.a + self.__new_alpha()
    
    def __new_y(self):
        new_a = self.__new_a(self.__new_alpha())
        return (self.__new_alpha() * self.u_lambda + self.a * self.x_lambda) / new_a,\
               (self.__new_alpha() * self.u_mu + self.a * self.x_mu) / new_a
    
    def __new_u(self, c, p, q):
        x_hat = self.x_hat(c)
        return self.u_lambda - self.__new_alpha() * (p - x_hat.sum(1)),\
               self.u_mu - self.__new_alpha() * (q - x_hat.sum(0))
        
    def __new_x(self):
        new_alpha = self.__new_alpha()
        new_a = self.__new_a(new_alpha)
        return (new_alpha * self.u_lambda + self.a * self.x_lambda) / new_a,\
               (new_alpha * self.u_mu + self.a * self.x_mu) / new_a
    
    def f(self, c, x_wave):
        return np.sum(c * x_wave) + self.gamma * np.sum(x_wave * np.log(x_wave / self.x_0))
        
    def phi(self, c, lambda_x, mu_x):
        n = self.n
        return -np.sum(lambda_x * p) - np.sum(mu_x * q) +\
        self.gamma * np.log(1 / np.exp(1) * np.sum(
            self.x_0 * np.exp(
                -(self.gamma + c + lambda_x.repeat(n).reshape(-1, n) + mu_x.repeat(n).reshape(-1, n).T) / self.gamma
            )
        ))
    
    def fit(self, c, p, q):
        k = 0
        while True:
            k+=1
            self.y_lambda, self.y_mu = self.__new_y()          
            self.u_lambda, self.u_mu = self.__new_u(c, p, q)            
            self.x_lambda, self.x_mu = self.__new_x()
            
            self.alpha = self.__new_alpha()
            self.a     = self.__new_a(self.alpha)
            
            self.alpha_x_sum += self.alpha * self.x_hat(c)
            
            x_wave = self.x_wave(c)
            r = np.sqrt((self.x_lambda**2).sum() + (self.x_mu**2).sum())
            epsilon_wave = self.epsilon / r
            
            criteria_a = np.sqrt(np.sum((x_wave.sum(1) - p)**2) + np.sum((x_wave.sum(0) - q)**2)) <= epsilon_wave
            criteria_b = self.f(c, x_wave) - self.phi(c, p, q) <= self.epsilon
            
            if k % 100 == 0:
                print(f"k = {k}, criteria a: {np.sqrt(np.sum((x_wave.sum(1) - p)**2) + np.sum((x_wave.sum(0) - q)**2))}, criteria b: {self.f(c, x_wave) - self.phi(c, p, q)}")
            
            if criteria_a and criteria_b or k == 5000:
                return k

In [186]:
fgrad = FastGradMethod(10, 8, 0.01)

In [187]:
def sample_batch(n):
    C = np.random.uniform(0, 10, size=[n, n])
    p = np.random.dirichlet(np.ones(n), size=1).ravel()
    q = np.random.dirichlet(np.ones(n), size=1).ravel()
    return C, p, q

In [188]:
c, p, q = sample_batch(8)

In [189]:
fgrad.fit(c,p,q)

k = 100, criteria a: 0.012027513330847521, criteria b: 33.044464154267146
k = 200, criteria a: 0.005978269909690961, criteria b: 33.255868815104506
k = 300, criteria a: 0.003977919371002246, criteria b: 33.32328368074158
k = 400, criteria a: 0.0029821784743357693, criteria b: 33.35608903853376
k = 500, criteria a: 0.0023855564143676817, criteria b: 33.37551748876457
k = 600, criteria a: 0.001987761688306545, criteria b: 33.38840158482401
k = 700, criteria a: 0.0017037001446300335, criteria b: 33.39755675113966
k = 800, criteria a: 0.0014907486344428002, criteria b: 33.404389483745284
k = 900, criteria a: 0.0013251281605666662, criteria b: 33.40968901401862
k = 1000, criteria a: 0.0011926238068406534, criteria b: 33.413920628107
k = 1100, criteria a: 0.001084219442177453, criteria b: 33.417375742948146
k = 1200, criteria a: 0.0009938877847337072, criteria b: 33.42025005274214
k = 1300, criteria a: 0.0009174512707286908, criteria b: 33.42267929041586
k = 1400, criteria a: 0.0008519332677

5000