In [1]:
import numpy as np

In [2]:
class DualGradientDescent:
    def __init__(self, gamma, epsilon, n):
        self.gamma   = gamma
        self.epsilon = epsilon
        self.n       = n
        
        self.lam = np.zeros(n)
        self.mu     = np.zeros(n)
        self.x_sum  = 0
        self.x_0    = np.ones([n, n]) / (n**2)
    
    def f(self, x):
        return (self.c * x).sum() + self.gamma * (x * np.log(x / self.x_0)).sum()
    
    def phi(self, lam, mu, n):
        return (lam * self.p).sum() + (mu * self.q).sum() + \
                self.gamma * np.log(1/np.e * (self.x_0 * np.exp(
                    -(self.gamma + self.c + lam.repeat(n).reshape(-1, n) + mu.repeat(n).reshape(-1, n).T) / self.gamma
                )).sum())
        
    def x_hat(self, lam, mu, n):
        x_hat = self.x_0 * np.exp(
            -(self.gamma + self.c + lam.repeat(n).reshape(-1, n) + mu.repeat(n).reshape(-1, n).T) / self.gamma
        )
        return x_hat / x_hat.sum()
    
    def _new_lm(self, p, q):
        x_hat = self.x_hat(self.lam, self.mu, self.n)
        return self.lam - self.gamma * (p - x_hat.sum(1)),\
               self.mu - self.gamma * (q - x_hat.sum(0))
        
    def x_sum_update(self):
        self.x_sum += self.x_hat(self.lam, self.mu, self.n)
        
    def _new_x_wave(self, k):
        return self.x_sum * 1/k
    
    def deviation_p_q(self, x, p, q):
        return np.sqrt(np.sum((x.sum(1) - p)**2) + np.sum((x.sum(0) - q)**2))
    
    def fit(self, c, p, q):
        self.c, self.p, self.q = c, p, q
        
        k = 1
        while True:
            self.lam, self.mu = self._new_lm(self.p, self.q)
            self.x_sum_update()
            self.x_wave = self._new_x_wave(k)
            R = np.sqrt(np.linalg.norm(self.lam) + np.linalg.norm(self.mu))
            epsilon_wave = self.epsilon / R
            
            criteria_a = self.deviation_p_q(self.x_wave, self.p, self.q) < epsilon_wave
            criteria_b = self.f(self.x_wave) + self.phi(self.lam, self.mu, self.n) < self.epsilon
            
            if k % 100 == 0:
                print(f'iteration {k}:   criteria 1 = {round(self.deviation_p_q(self.x_wave, self.p, self.q), 7)}, ' + \
                                     f'criteria 2 = {round(self.f(self.x_wave) + self.phi(self.lam, self.mu, self.n), 7)}')
            
            if criteria_a and criteria_b:
                return self.x_wave, k
            
            k += 1

In [3]:
def sample_batch(n):
    c = np.random.uniform(0, 10, size=[n, n])
    p = np.random.dirichlet(np.ones(n), size=1).ravel()
    q = np.random.dirichlet(np.ones(n), size=1).ravel()
    return c, p, q

In [4]:
n = 3
c, p, q = sample_batch(n)
x, k = DualGradientDescent(gamma=0.1, epsilon=1e-2, n=n).fit(c, p, q)

iteration 100:   criteria 1 = 0.2178951, criteria 2 = -0.6836377
iteration 200:   criteria 1 = 0.1076421, criteria 2 = -0.4427657
iteration 300:   criteria 1 = 0.0717497, criteria 2 = -0.3624984
iteration 400:   criteria 1 = 0.0538121, criteria 2 = -0.3221349
iteration 500:   criteria 1 = 0.0430497, criteria 2 = -0.2978355
iteration 600:   criteria 1 = 0.0358747, criteria 2 = -0.2816013
iteration 700:   criteria 1 = 0.0307498, criteria 2 = -0.2699883
iteration 800:   criteria 1 = 0.026906, criteria 2 = -0.261269
iteration 900:   criteria 1 = 0.0239165, criteria 2 = -0.2544818
iteration 1000:   criteria 1 = 0.0215248, criteria 2 = -0.2490483
iteration 1100:   criteria 1 = 0.019568, criteria 2 = -0.2446004
iteration 1200:   criteria 1 = 0.0179374, criteria 2 = -0.2408922
iteration 1300:   criteria 1 = 0.0165576, criteria 2 = -0.2377532
iteration 1400:   criteria 1 = 0.0153749, criteria 2 = -0.2350619
iteration 1500:   criteria 1 = 0.0143499, criteria 2 = -0.2327287
iteration 1600:   crit

In [5]:
x, k

(array([[  5.08155564e-16,   3.01431401e-01,   3.57091856e-02],
        [  7.17533040e-02,   2.33849364e-01,   1.14722261e-24],
        [  3.72051173e-28,   3.57212495e-01,   4.42501530e-05]]), 3869)