In [1]:
import numpy as np

In [2]:
class DualGradientDescent:
    """Dual gradient descent for an entropy-regularised linear problem.

    The primal objective evaluated by :meth:`f` is
    ``sum(c * x) + gamma * sum(x * log(x / x_0))``; the marginal targets are
    the row sums ``p`` and the column sums ``q``.  The dual variables ``lam``
    (one per row) and ``mu`` (one per column) are updated by gradient steps
    of size ``gamma``, and the primal answer is the running average of the
    normalised plans produced by :meth:`x_hat`.

    Parameters
    ----------
    gamma : float
        Weight of the entropy term; also used as the dual step size.
    epsilon : float
        Tolerance used by both stopping criteria in :meth:`fit`.
    n : int
        Problem size: ``c`` is ``n x n`` and ``p``, ``q`` have length ``n``.

    Attributes
    ----------
    converged : bool
        ``True`` when the last :meth:`fit` call met both stopping criteria,
        ``False`` when it stopped because ``max_iter`` was reached.
    """

    def __init__(self, gamma, epsilon, n):
        self.gamma = gamma
        self.epsilon = epsilon
        self.n = n
        # Tiny offset that keeps log() finite when an entry of x is exactly 0.
        self.small = 1e-20
        # Uniform reference plan used inside the entropy term.
        self.x_0 = np.ones([n, n]) / (n**2)
        self.converged = False
        self._reset_state()

    def _reset_state(self):
        """Zero the dual variables and the running sum of primal plans."""
        self.lam = np.zeros(self.n)
        self.mu = np.zeros(self.n)
        self.x_sum = np.zeros([self.n, self.n])

    def _exponent(self, lam, mu):
        """Return ``-(gamma + c[i, j] + lam[i] + mu[j]) / gamma`` as a matrix."""
        lam = np.asarray(lam)
        mu = np.asarray(mu)
        # Broadcasting puts lam along rows and mu along columns.
        return -(self.gamma + self.c + lam[:, None] + mu[None, :]) / self.gamma

    def _dual_radius(self):
        """Euclidean norm of the stacked dual vector ``(lam, mu)``."""
        return np.sqrt(np.sum(self.lam**2) + np.sum(self.mu**2))

    def f(self, x):
        """Primal objective: linear cost plus ``gamma`` times the entropy term."""
        shifted = x + self.small
        entropy = (shifted * np.log(shifted / self.x_0)).sum()
        return (self.c * x).sum() + self.gamma * entropy

    def phi(self, lam, mu, n):
        """Dual objective at ``(lam, mu)``.

        ``n`` is accepted for backward compatibility; the shapes are taken
        from ``lam`` and ``mu`` themselves.

        The log of the sum is evaluated with a max-shift so that very
        negative exponents (small ``gamma``) do not underflow to ``log(0)``.
        NOTE(review): the constants of the original expression (the extra
        ``1/e`` factor and the ``gamma`` inside the exponent) are kept
        unchanged -- confirm them against the derivation.
        """
        exponent = self._exponent(lam, mu)
        shift = exponent.max()
        log_sum = shift + np.log((self.x_0 * np.exp(exponent - shift)).sum())
        # log(1/e * S) == log(S) - 1
        return (lam * self.p).sum() + (mu * self.q).sum() + self.gamma * (log_sum - 1.0)

    def x_hat(self, lam, mu, n):
        """Normalised primal plan associated with the duals ``(lam, mu)``.

        ``n`` is accepted for backward compatibility and is not needed.
        Subtracting the largest exponent before ``exp`` leaves the normalised
        result unchanged but guarantees a non-zero denominator.
        """
        exponent = self._exponent(lam, mu)
        weights = self.x_0 * np.exp(exponent - exponent.max())
        return weights / weights.sum()

    def _new_lm(self, p, q):
        """Return the duals after one gradient step of size ``gamma``."""
        plan = self.x_hat(self.lam, self.mu, self.n)
        new_lam = self.lam - self.gamma * (p - plan.sum(1))
        new_mu = self.mu - self.gamma * (q - plan.sum(0))
        return new_lam, new_mu

    def x_sum_update(self):
        """Add the plan for the current duals to the running sum."""
        self.x_sum += self.x_hat(self.lam, self.mu, self.n)

    def _new_x_wave(self, k):
        """Average of the ``k`` plans accumulated so far."""
        return self.x_sum * 1 / k

    def deviation_p_q(self, x, p, q):
        """Euclidean norm of the stacked row-sum and column-sum residuals."""
        return np.sqrt(np.sum((x.sum(1) - p)**2) + np.sum((x.sum(0) - q)**2))

    def fit(self, c, p, q, max_iter=None, log_every=100):
        """Run dual gradient descent until both stopping criteria hold.

        Parameters
        ----------
        c : array_like, shape (n, n)
            Cost matrix.
        p, q : array_like, shape (n,)
            Target row sums and column sums.
        max_iter : int or None, optional
            Upper bound on the number of iterations.  ``None`` (default)
            keeps the original behaviour of iterating until convergence.
        log_every : int, optional
            Print both criteria every ``log_every`` iterations; ``0`` or
            ``None`` disables printing.

        Returns
        -------
        (x_wave, k)
            The averaged primal plan and the number of iterations performed.
            Check :attr:`converged` to see whether the criteria were met.

        Raises
        ------
        ValueError
            If ``max_iter`` is given but smaller than 1, or the inputs do not
            match the size ``n`` given to the constructor.
        """
        if max_iter is not None and max_iter < 1:
            raise ValueError("max_iter must be a positive integer or None")

        self.c = np.asarray(c, dtype=float)
        self.p = np.asarray(p, dtype=float)
        self.q = np.asarray(q, dtype=float)
        if (self.c.shape != (self.n, self.n)
                or self.p.shape != (self.n,)
                or self.q.shape != (self.n,)):
            raise ValueError(
                f"expected c of shape ({self.n}, {self.n}) and p, q of shape ({self.n},)"
            )

        # Start from scratch so that repeated fit() calls do not mix the
        # running sum of a previous run with the restarted counter k.
        self._reset_state()
        self.converged = False

        k = 1
        while True:
            self.lam, self.mu = self._new_lm(self.p, self.q)
            self.x_sum_update()
            self.x_wave = self._new_x_wave(k)

            # The feasibility tolerance is scaled by the size of the duals.
            R = self._dual_radius()
            epsilon_wave = self.epsilon / R if R > 0 else np.inf

            deviation = self.deviation_p_q(self.x_wave, self.p, self.q)
            gap = self.f(self.x_wave) + self.phi(self.lam, self.mu, self.n)

            if log_every and k % log_every == 0:
                print(f'iteration {k}:   criteria 1 = {round(deviation, 7)}, '
                      f'criteria 2 = {round(gap, 7)}')

            if deviation < epsilon_wave and gap < self.epsilon:
                self.converged = True
                return self.x_wave, k

            if max_iter is not None and k >= max_iter:
                # Budget exhausted: hand back the best average we have.
                return self.x_wave, k

            k += 1

In [3]:
def sample_batch(n, rng=None):
    """Draw a random problem instance of size ``n``.

    Parameters
    ----------
    n : int
        Number of rows/columns.
    rng : numpy.random.Generator or numpy.random.RandomState, optional
        Source of randomness.  When omitted, NumPy's global generator
        (``np.random``) is used, exactly as before.

    Returns
    -------
    c : ndarray, shape (n, n)
        Costs drawn uniformly from ``[0, 10)``.
    p, q : ndarray, shape (n,)
        Two independent draws from a flat Dirichlet distribution, so each is
        non-negative and sums to 1.
    """
    source = np.random if rng is None else rng
    c = source.uniform(0, 10, size=[n, n])
    p = source.dirichlet(np.ones(n), size=1).ravel()
    q = source.dirichlet(np.ones(n), size=1).ravel()
    return c, p, q

In [4]:
# Seed NumPy's global generator so the sampled (c, p, q) instance -- and hence
# the whole run below -- is identical after every "Restart Kernel -> Run All".
SEED = 0
np.random.seed(SEED)

n = 40
c, p, q = sample_batch(n)
# NOTE: the recorded run of this cell printed progress past iteration 40000
# and was then stopped with a KeyboardInterrupt, so expect a long runtime.
x, k = DualGradientDescent(gamma=0.005, epsilon=1e-2, n=n).fit(c, p, q)

iteration 100:   criteria 1 = 0.4262442, criteria 2 = -0.1004958
iteration 200:   criteria 1 = 0.3510085, criteria 2 = -0.1328056
iteration 300:   criteria 1 = 0.307726, criteria 2 = -0.1511445
iteration 400:   criteria 1 = 0.274885, criteria 2 = -0.1603802
iteration 500:   criteria 1 = 0.2545935, criteria 2 = -0.1712263
iteration 600:   criteria 1 = 0.2375244, criteria 2 = -0.1779512
iteration 700:   criteria 1 = 0.2243085, criteria 2 = -0.1846188
iteration 800:   criteria 1 = 0.213792, criteria 2 = -0.191193
iteration 900:   criteria 1 = 0.2054145, criteria 2 = -0.1982152
iteration 1000:   criteria 1 = 0.197571, criteria 2 = -0.2033276
iteration 1100:   criteria 1 = 0.1900756, criteria 2 = -0.2066486
iteration 1200:   criteria 1 = 0.1833871, criteria 2 = -0.2095102
iteration 1300:   criteria 1 = 0.177681, criteria 2 = -0.2127582
iteration 1400:   criteria 1 = 0.1728873, criteria 2 = -0.2166361
iteration 1500:   criteria 1 = 0.1686782, criteria 2 = -0.2204825
iteration 1600:   criteri

iteration 13100:   criteria 1 = 0.0635386, criteria 2 = -0.2601402
iteration 13200:   criteria 1 = 0.0633121, criteria 2 = -0.2602139
iteration 13300:   criteria 1 = 0.0630894, criteria 2 = -0.2602987
iteration 13400:   criteria 1 = 0.0628706, criteria 2 = -0.2603943
iteration 13500:   criteria 1 = 0.0626555, criteria 2 = -0.2605002
iteration 13600:   criteria 1 = 0.0624439, criteria 2 = -0.2606164
iteration 13700:   criteria 1 = 0.0622359, criteria 2 = -0.2607424
iteration 13800:   criteria 1 = 0.0620313, criteria 2 = -0.2608783
iteration 13900:   criteria 1 = 0.0618301, criteria 2 = -0.2610238
iteration 14000:   criteria 1 = 0.0616321, criteria 2 = -0.261179
iteration 14100:   criteria 1 = 0.0614373, criteria 2 = -0.2613438
iteration 14200:   criteria 1 = 0.0612457, criteria 2 = -0.2615181
iteration 14300:   criteria 1 = 0.0610572, criteria 2 = -0.2617019
iteration 14400:   criteria 1 = 0.0608718, criteria 2 = -0.2618951
iteration 14500:   criteria 1 = 0.0606894, criteria 2 = -0.2620

iteration 25900:   criteria 1 = 0.0438473, criteria 2 = -0.2397952
iteration 26000:   criteria 1 = 0.0437362, criteria 2 = -0.2394974
iteration 26100:   criteria 1 = 0.0436259, criteria 2 = -0.239203
iteration 26200:   criteria 1 = 0.0435165, criteria 2 = -0.238912
iteration 26300:   criteria 1 = 0.043408, criteria 2 = -0.2386244
iteration 26400:   criteria 1 = 0.0433002, criteria 2 = -0.2383404
iteration 26500:   criteria 1 = 0.0431934, criteria 2 = -0.2380597
iteration 26600:   criteria 1 = 0.0430873, criteria 2 = -0.2377826
iteration 26700:   criteria 1 = 0.0429821, criteria 2 = -0.2375089
iteration 26800:   criteria 1 = 0.0428777, criteria 2 = -0.2372386
iteration 26900:   criteria 1 = 0.0427741, criteria 2 = -0.2369718
iteration 27000:   criteria 1 = 0.0426713, criteria 2 = -0.2367084
iteration 27100:   criteria 1 = 0.0425693, criteria 2 = -0.2364484
iteration 27200:   criteria 1 = 0.0424681, criteria 2 = -0.2361918
iteration 27300:   criteria 1 = 0.0423676, criteria 2 = -0.235938

iteration 38600:   criteria 1 = 0.0335297, criteria 2 = -0.2079809
iteration 38700:   criteria 1 = 0.0334584, criteria 2 = -0.2076319
iteration 38800:   criteria 1 = 0.0333874, criteria 2 = -0.2072837
iteration 38900:   criteria 1 = 0.0333166, criteria 2 = -0.2069363
iteration 39000:   criteria 1 = 0.0332461, criteria 2 = -0.20659
iteration 39100:   criteria 1 = 0.0331759, criteria 2 = -0.2062449
iteration 39200:   criteria 1 = 0.033106, criteria 2 = -0.2059011
iteration 39300:   criteria 1 = 0.0330364, criteria 2 = -0.2055587
iteration 39400:   criteria 1 = 0.0329671, criteria 2 = -0.2052176
iteration 39500:   criteria 1 = 0.0328982, criteria 2 = -0.204878
iteration 39600:   criteria 1 = 0.0328296, criteria 2 = -0.2045398
iteration 39700:   criteria 1 = 0.0327612, criteria 2 = -0.2042032
iteration 39800:   criteria 1 = 0.0326932, criteria 2 = -0.203868
iteration 39900:   criteria 1 = 0.0326256, criteria 2 = -0.2035344
iteration 40000:   criteria 1 = 0.0325582, criteria 2 = -0.2032023


KeyboardInterrupt: 