In [1]:
import numpy as np
np.seterr(all='raise')

{'divide': 'warn', 'over': 'warn', 'under': 'ignore', 'invalid': 'warn'}

In [13]:
class FastGradMethod:
    def __init__(self, gamma, n, epsilon):
        self.alpha = 0
        self.a     = 0
        
        # constants
        self.epsilon = epsilon
        self.gamma = gamma
        self.l     = 1 / gamma
        self.n     = nk
        
        self.u_lambda = np.ones([n])
        self.u_mu     = np.ones([n])
         
        self.y_lambda = np.zeros(n)
        self.y_mu     = np.zeros(n)
        
        self.x_lambda   = np.ones([n])
        self.x_mu       = np.ones([n])
        
        self.alpha_x_sum = 0
        
        self.x_0 = np.ones([n, n]) / (n**2)
        
    def x_hat(self, c):
        n = self.n
        
        _c = np.min(self.gamma + c + self.y_lambda.repeat(n).reshape(-1, n) + self.y_mu.repeat(n).reshape(-1, n).T)
        
        try:
            b = self.x_0 * np.exp(
                -(self.gamma + c + self.y_lambda.repeat(n).reshape(-1, n) + self.y_mu.repeat(n).reshape(-1, n).T - _c) / self.gamma
            )
        except:
            b = np.ones([n, n]) * 1e-32
            
        return b / b.sum()
    
    def x_wave(self, c):
        # return self.alpha_x_sum * 1 / self.a
        return self.x_hat(c)
    
    def __new_alpha(self):
        return 1 / (2 * self.l) + np.sqrt(1 / (4 * (self.l**2)) + self.alpha**2)
    
    def __new_a(self, new_alpha):
        return self.a + self.__new_alpha()
    
    def __new_y(self):
        new_alpha = self.__new_alpha()
        new_a = self.__new_a(new_alpha)
        return (new_alpha * self.u_lambda + self.a * self.x_lambda) / new_a,\
               (new_alpha * self.u_mu + self.a * self.x_mu) / new_a
    
    def __new_u(self, c, p, q):
        x_hat = self.x_hat(c)
        return self.u_lambda - self.__new_alpha() * (p - x_hat.sum(1)),\
               self.u_mu - self.__new_alpha() * (q - x_hat.sum(0))
        
    def __new_x(self):
        new_alpha = self.__new_alpha()
        new_a = self.__new_a(new_alpha)
        return (new_alpha * self.u_lambda + self.a * self.x_lambda) / new_a,\
               (new_alpha * self.u_mu + self.a * self.x_mu) / new_a
    
    def f(self, c, x_wave):
        return np.sum(c * x_wave) + self.gamma * np.sum(x_wave * np.log(x_wave / self.x_0))
        
    def phi(self, c, lambda_x, mu_x):
        try:
            n = self.x_0 * np.exp(
                -(self.gamma + c + lambda_x.repeat(n).reshape(-1, n) + mu_x.repeat(n).reshape(-1, n).T) / self.gamma
            )
        except:
            n = np.ones([self.n, self.n])
            
        return -np.sum(lambda_x * p) - np.sum(mu_x * q) - self.gamma * np.log(1 / np.exp(1) * np.sum(n))
    
    def fit(self, c, p, q):
        k = 0

        while True:
            self.y_lambda, self.y_mu = self.__new_y()  
            
            self.u_lambda, self.u_mu = self.__new_u(c, p, q)            
            self.x_lambda, self.x_mu = self.__new_x()
            
            self.alpha = self.__new_alpha()
            self.a     = self.__new_a(self.alpha)
            
            self.alpha_x_sum += self.alpha * self.x_hat(c)
            
            x_wave = self.x_wave(c)
            r = np.sqrt((self.x_lambda**2).sum() + (self.x_mu**2).sum())
            epsilon_wave = self.epsilon / r
            
            criteria_a = np.sqrt(np.sum((x_wave.sum(1) - p)**2) + np.sum((x_wave.sum(0) - q)**2)) <= epsilon_wave
            criteria_b = self.f(c, x_wave) - self.phi(c, p, q) <= self.epsilon
            
            if k % 100 == 0:
                print(f"k = {k}, criteria a: {np.sqrt(np.sum((x_wave.sum(1) - p)**2) + np.sum((x_wave.sum(0) - q)**2))}, criteria b: {self.f(c, x_wave) - self.phi(c, p, q)}")
            
            k+=1
              
            if criteria_a and criteria_b or k == 2000:
                return k
            

In [26]:
fgrad = FastGradMethod(0.1, 8, 0.01)

In [27]:
def sample_batch(n):
    C = np.random.uniform(0, 10, size=[n, n])
    p = np.random.dirichlet(np.ones(n), size=1).ravel()
    q = np.random.dirichlet(np.ones(n), size=1).ravel()
    return C, p, q

In [28]:
c, p, q = sample_batch(8)
fgrad.fit(c, p, q)

k = 0, criteria a: 1.238166031161233, criteria b: 1.1911012725243264
k = 100, criteria a: 0.008625550116393172, criteria b: 2.4766502888439783
k = 200, criteria a: 0.0033471772521312428, criteria b: 2.472398638510912
k = 300, criteria a: 0.007370596381021022, criteria b: 2.498597322957443
k = 400, criteria a: 0.0035537234922677343, criteria b: 2.4781662932720847
k = 500, criteria a: 0.0021923977160132366, criteria b: 2.486959272392048
k = 600, criteria a: 0.0014405156928363888, criteria b: 2.4782432671883
k = 700, criteria a: 0.0008353711662790052, criteria b: 2.4836187649128942
k = 800, criteria a: 0.0003109707516923592, criteria b: 2.480425975816598
k = 900, criteria a: 0.00032528354154451897, criteria b: 2.4808382208774065
k = 1000, criteria a: 0.0005037664654451969, criteria b: 2.4827214524304377
k = 1100, criteria a: 0.0004015897237978254, criteria b: 2.4803703405334057
k = 1200, criteria a: 0.0002198742129105957, criteria b: 2.481738116895079
k = 1300, criteria a: 5.4987464689552

2000

In [6]:
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
4c, p, q = sample_batch(5)
gammas = np.linspace(0.25, 2, 50)
plt.plot(gammas, [FastGradMethod(gamma, 5, 0.01).fit(c, p, q) for gamma in gammas])