In [1]:
import numpy as np

In [2]:
class DualGradientDescent:
    def __init__(self, gamma, epsilon, n):
        self.gamma   = gamma
        self.epsilon = epsilon
        self.n       = n
        
        self.lam = np.zeros(n)
        self.mu     = np.zeros(n)
        self.x_sum  = 0
        self.x_0    = np.ones([n, n]) / (n**2)
    
        
        p = np.random.dirichlet(np.ones(n), size=1).ravel()
        q = np.random.dirichlet(np.ones(n), size=1).ravel()
        
    def f(self, x):
        return (self.c * x).sum() + self.gamma * (x * np.log(x / self.x_0)).sum()
    
    def phi(self, lam, mu, n):
        return (lam * self.p).sum() + (mu * self.q).sum() + \
                self.gamma * np.log(1/np.e * (self.x_0 * np.exp(
                    -(self.gamma + self.c + lam.repeat(n).reshape(-1, n) + mu.repeat(n).reshape(-1, n).T) / self.gamma
                )).sum())
        
    def x_hat(self, lam, mu, n):
        x_hat = self.x_0 * np.exp(
            -(self.gamma + self.c + lam.repeat(n).reshape(-1, n) + mu.repeat(n).reshape(-1, n).T) / self.gamma
        )
        return x_hat / x_hat.sum()
    
    def _new_lm(self, p, q):
        x_hat = self.x_hat(self.lam, self.mu, self.n)
        return self.lam - self.gamma * (p - x_hat.sum(1)),\
               self.mu - self.gamma * (q - x_hat.sum(0))
        
    def x_sum_update(self):
        self.x_sum += self.x_hat(self.lam, self.mu, self.n)
        
    def _new_x_wave(self, k):
        return self.x_sum * 1/k
    
    def deviation_p_q(self, x, p, q):
        return np.sqrt(np.sum((x.sum(1) - p)**2) + np.sum((x.sum(0) - q)**2))
    
    def fit(self, c, p, q):
        self.c, self.p, self.q = c, p, q
        
        k = 1
        while True:
            self.lam, self.mu = self._new_lm(self.p, self.q)
            self.x_sum_update()
            self.x_wave = self._new_x_wave(k)
            R = np.sqrt(np.linalg.norm(self.lam) + np.linalg.norm(self.mu))
            epsilon_wave = self.epsilon / R
            
            criteria_a = self.deviation_p_q(self.x_wave, self.p, self.q) < epsilon_wave
            criteria_b = self.f(self.x_wave) + self.phi(self.lam, self.mu, self.n) < self.epsilon
            
            if k % 100 == 0:
                print(f'iteration {k}:   criteria 1 = {round(self.deviation_p_q(self.x_wave, self.p, self.q), 7)}, ' + \
                                     f'criteria 2 = {round(self.f(self.x_wave) + self.phi(self.lam, self.mu, self.n), 7)}')
            
            if criteria_a and criteria_b:
                return self.x_wave, k
            
            k += 1

In [3]:
def sample_batch(n):
    c = np.random.uniform(0, 10, size=[n, n])
    p = np.random.dirichlet(np.ones(n), size=1).ravel()
    q = np.random.dirichlet(np.ones(n), size=1).ravel()
    return c, p, q

In [4]:
n = 4
c, p, q = sample_batch(n)
x, k = DualGradientDescent(gamma=0.1, epsilon=1e-2, n=n).fit(c, p, q)

iteration 100:   criteria 1 = 0.2776446, criteria 2 = -0.9580941
iteration 200:   criteria 1 = 0.1571035, criteria 2 = -0.6667158
iteration 300:   criteria 1 = 0.1143277, criteria 2 = -0.5631482
iteration 400:   criteria 1 = 0.093478, criteria 2 = -0.5187149
iteration 500:   criteria 1 = 0.0813728, criteria 2 = -0.4987456
iteration 600:   criteria 1 = 0.0735698, criteria 2 = -0.4911919
iteration 700:   criteria 1 = 0.0681764, criteria 2 = -0.4908029
iteration 800:   criteria 1 = 0.0642562, criteria 2 = -0.4949251
iteration 900:   criteria 1 = 0.061296, criteria 2 = -0.5020726
iteration 1000:   criteria 1 = 0.0589924, criteria 2 = -0.5113475
iteration 1100:   criteria 1 = 0.0571536, criteria 2 = -0.5221459
iteration 1200:   criteria 1 = 0.055598, criteria 2 = -0.5331957
iteration 1300:   criteria 1 = 0.0535068, criteria 2 = -0.5333029
iteration 1400:   criteria 1 = 0.0503838, criteria 2 = -0.5174254
iteration 1500:   criteria 1 = 0.0472949, criteria 2 = -0.4993001
iteration 1600:   crit

iteration 14100:   criteria 1 = 0.005057, criteria 2 = -0.2328186
iteration 14200:   criteria 1 = 0.0050214, criteria 2 = -0.2325886
iteration 14300:   criteria 1 = 0.0049863, criteria 2 = -0.2323618
iteration 14400:   criteria 1 = 0.0049517, criteria 2 = -0.2321382
iteration 14500:   criteria 1 = 0.0049175, criteria 2 = -0.2319177
iteration 14600:   criteria 1 = 0.0048839, criteria 2 = -0.2317001
iteration 14700:   criteria 1 = 0.0048506, criteria 2 = -0.2314855
iteration 14800:   criteria 1 = 0.0048179, criteria 2 = -0.2312738
iteration 14900:   criteria 1 = 0.0047855, criteria 2 = -0.231065
iteration 15000:   criteria 1 = 0.0047536, criteria 2 = -0.2308589
iteration 15100:   criteria 1 = 0.0047221, criteria 2 = -0.2306555
iteration 15200:   criteria 1 = 0.0046911, criteria 2 = -0.2304548
iteration 15300:   criteria 1 = 0.0046604, criteria 2 = -0.2302567
iteration 15400:   criteria 1 = 0.0046301, criteria 2 = -0.2300612
iteration 15500:   criteria 1 = 0.0046003, criteria 2 = -0.22986

In [5]:
x, k

(array([[  1.58094222e-06,   5.16658885e-33,   4.56296538e-02,
           6.03921059e-04],
        [  6.24949303e-26,   3.77502695e-03,   4.87886504e-02,
           3.13048681e-01],
        [  9.39614887e-03,   3.20481079e-01,   7.97871231e-05,
           2.17367717e-02],
        [  1.92459938e-14,   2.36442617e-01,   1.60819380e-05,
           5.07543573e-31]]), 22800)