In [None]:
import numpy as np

In [None]:
class FastGradientDescent:
    """Accelerated gradient method on the dual of an entropy-regularised
    transport problem.

    The dual variables are a pair of length-``n`` vectors ``(lambda, mu)``,
    one entry per row / column of the ``n x n`` plan. For a given pair the
    primal plan is

        x[i, j]  proportional to  x_0[i, j] * exp(-(c[i, j] + lambda[i] + mu[j]) / gamma)

    normalised to sum to one. The gradient of the dual objective with respect
    to ``(lambda, mu)`` is the marginal mismatch ``(p - x.sum(1), q - x.sum(0))``.

    Three dual sequences are kept, each as a ``*_lambda`` / ``*_mu`` pair:
    ``y`` (the point where the gradient is evaluated), ``u`` (the
    gradient-step sequence) and ``x`` (the weighted average that is reported).

    Parameters
    ----------
    gamma : float
        Regularisation strength; ``1 / gamma`` is used as the smoothness
        constant in the step-size recursion.
    n : int
        Side length of the cost matrix / length of the marginals.
    epsilon : float
        Tolerance used by the stopping rule in :meth:`fit`.
    """

    def __init__(self, gamma, n=5, epsilon=0.001):
        # Step-size state: `alpha` is the most recent step, `a` the running
        # sum of all steps taken so far.
        self.alpha = 0
        self.a = 0

        # constants
        self.epsilon = epsilon
        self.gamma = gamma
        self.l = 1 / gamma
        self.n = n

        self.u_lambda = np.ones(n) / n**2
        self.u_mu = np.ones(n) / n**2

        self.y_lambda = np.zeros(n)
        self.y_mu = np.zeros(n)

        self.x_lambda = np.zeros([n])
        self.x_mu = np.zeros([n])

        # Reference plan that weights the exponential in the primal map.
        self.x_0 = np.ones([n, n]) / (n**2)

    def __primal(self, c, lam, mu):
        """Return the normalised primal plan for the dual point ``(lam, mu)``."""
        c = np.asarray(c, dtype=float)
        logits = -(c + lam[:, None] + mu[None, :]) / self.gamma
        # Subtracting the maximum leaves the normalised result unchanged but
        # keeps exp() from underflowing everywhere when gamma is small.
        logits = logits - logits.max()
        weights = self.x_0 * np.exp(logits)
        return weights / weights.sum()

    def __get_x(self, c):
        """Return the primal plan at the current averaged dual point ``x``."""
        return self.__primal(c, self.x_lambda, self.x_mu)

    def __new_alpha(self):
        """Return the next step size, computed from the current ``self.alpha``."""
        return 1 / (2 * self.l) + np.sqrt(1 / (4 * (self.l**2)) + self.alpha**2)

    def __new_a(self):
        """Return the running step sum after adding the next step size."""
        return self.a + self.__new_alpha()

    def __new_y(self):
        """Return the next ``y`` pair: a convex combination of ``u`` and ``x``."""
        alpha_next = self.__new_alpha()
        a_next = self.a + alpha_next
        return ((alpha_next * self.u_lambda + self.a * self.x_lambda) / a_next,
                (alpha_next * self.u_mu + self.a * self.x_mu) / a_next)

    def __new_u(self, c, p, q):
        """Return the next ``u`` pair: a gradient step taken from ``u``.

        The gradient is evaluated at the current ``y`` point and scaled by the
        *next* step size, so ``y`` must be refreshed before calling this.
        """
        alpha_next = self.__new_alpha()
        plan_at_y = self.__primal(c, self.y_lambda, self.y_mu)
        return (self.u_lambda - alpha_next * (p - plan_at_y.sum(1)),
                self.u_mu - alpha_next * (q - plan_at_y.sum(0)))

    def __new_x(self):
        """Return the next ``x`` pair: the step-weighted average of ``u`` and ``x``.

        Uses the same weights as ``__new_y`` so they sum to one; ``u`` must
        already hold its updated value.
        """
        alpha_next = self.__new_alpha()
        a_next = self.a + alpha_next
        return ((alpha_next * self.u_lambda + self.a * self.x_lambda) / a_next,
                (alpha_next * self.u_mu + self.a * self.x_mu) / a_next)

    def __get_phi(self, c, p, q):
        """Return a dual objective value at the current ``x`` point.

        Computes ``-<x_lambda, p> - <x_mu, q> + gamma * log(sum(exp(...)) / e)``
        with the exponent ``-(gamma + c + lambda_i + mu_j) / gamma``.
        Not used by :meth:`fit`.
        """
        c = np.asarray(c, dtype=float)
        exponent = -(self.gamma + c + self.x_lambda[:, None]
                     + self.x_mu[None, :]) / self.gamma

        # TODO: x, y or u
        # NOTE(review): the linear terms here carry the opposite sign to the
        # gradient direction used in the u-update -- confirm the intended
        # convention before relying on this value.
        return (-(self.x_lambda * p).sum() - (self.x_mu * q).sum()
                + self.gamma * np.log(1 / np.exp(1) * np.exp(exponent).sum()))

    def fit(self, c, p, q, max_iter=10000):
        """Run the accelerated dual iterations until the stopping rule holds.

        The solver state is carried on the instance, so calling ``fit`` again
        continues from where the previous call stopped.

        Parameters
        ----------
        c : array-like, shape (n, n)
            Cost matrix.
        p, q : array-like, shape (n,)
            Target row and column marginals.
        max_iter : int, optional
            Upper bound on the number of iterations performed by this call.

        Returns
        -------
        x : numpy.ndarray, shape (n, n)
            Primal plan at the final averaged dual point.
        k : int
            Number of iterations performed. ``k == max_iter`` means the cap
            was reached, possibly without meeting the tolerance.

        Raises
        ------
        ValueError
            If the input shapes do not match ``n`` or ``max_iter < 1``.
        """
        c = np.asarray(c, dtype=float)
        p = np.asarray(p, dtype=float)
        q = np.asarray(q, dtype=float)
        if c.shape != (self.n, self.n) or p.shape != (self.n,) or q.shape != (self.n,):
            raise ValueError(
                "expected c of shape (%d, %d) and p, q of shape (%d,); got %s, %s, %s"
                % (self.n, self.n, self.n, c.shape, p.shape, q.shape)
            )
        if max_iter < 1:
            raise ValueError("max_iter must be at least 1")

        for k in range(1, max_iter + 1):
            # All three updates read the not-yet-advanced (alpha, a), so the
            # step-size state is only moved forward after them.
            self.y_lambda, self.y_mu = self.__new_y()
            self.u_lambda, self.u_mu = self.__new_u(c, p, q)
            self.x_lambda, self.x_mu = self.__new_x()

            alpha_next = self.__new_alpha()
            self.a = self.a + alpha_next
            self.alpha = alpha_next

            x_hat = self.__get_x(c)

            residual = np.sqrt(((p - x_hat.sum(1))**2).sum()
                               + ((q - x_hat.sum(0))**2).sum())
            dual_norm = np.sqrt((self.x_lambda**2).sum() + (self.x_mu**2).sum())

            # Same test as residual < epsilon / dual_norm, written as a product
            # so a zero dual norm cannot trigger a division by zero.
            if residual * dual_norm < self.epsilon:
                break

        return x_hat, k

In [None]:
# Stand-alone solver instance with regularisation strength 0.3 and the
# class defaults for every other setting.
fgrad = FastGradientDescent(gamma=0.3)

In [None]:
def sample_batch(n):
    """Draw one random problem instance of size ``n``.

    Parameters
    ----------
    n : int
        Side length of the cost matrix and length of both marginals.

    Returns
    -------
    C : numpy.ndarray, shape (n, n)
        Costs drawn uniformly from [0, 10).
    p, q : numpy.ndarray, shape (n,)
        Two independent Dirichlet(1, ..., 1) samples; each is non-negative
        and sums to one.
    """
    C = np.random.uniform(0, 10, size=[n, n])
    # The marginals must have the same length as the cost matrix side.
    p = np.random.dirichlet(np.ones(n), size=1).ravel()
    q = np.random.dirichlet(np.ones(n), size=1).ravel()
    return C, p, q

In [None]:
def fastgrad(c, p, q, gamma):
    """Solve one problem instance with a freshly constructed solver.

    Parameters
    ----------
    c : array-like, shape (n, n)
        Cost matrix.
    p, q : array-like, shape (n,)
        Target marginals.
    gamma : float
        Regularisation strength passed to ``FastGradientDescent``.

    Returns
    -------
    Whatever ``FastGradientDescent.fit`` returns (the plan and the
    iteration count).
    """
    # Size the solver from the cost matrix instead of relying on its default n.
    solver = FastGradientDescent(gamma, n=len(c))
    return solver.fit(c, p, q)

In [None]:
# Sweep the regularisation strength: for every gamma solve `n_exp` freshly
# sampled 5x5 problems and record how many iterations each solve reported.
n_exp = 5
gammas = np.linspace(0.001, 0.5, num=20)

n_iterations = []
for gamma in gammas:
    for _ in range(n_exp):
        c, p, q = sample_batch(5)
        result = fastgrad(c, p, q, gamma)
        # The iteration count is the last element of the solver's result.
        n_iterations.append(result[-1])

# One row per gamma, one column per repetition; collapse the repetitions
# into a single total per gamma.
n_iterations = np.asarray(n_iterations).reshape(len(gammas), n_exp).sum(axis=1)

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

# Plot the per-gamma iteration totals against the gamma values themselves
# rather than against the position in the sweep.
plt.plot(gammas, n_iterations)
plt.xlabel("gamma")
plt.ylabel("total iterations over n_exp runs")