# In [2]:
import numpy as np
import matplotlib.pyplot as plt

# In [3]:
class FastGrad:
    """Accelerated (similar-triangles / Nesterov) gradient iterations for an
    entropy-regularised transport problem.

    Every per-iteration quantity is stored in an array whose leading axis is
    the iteration index: ``self.x[k]`` is the primal iterate after ``k`` steps
    and index 0 holds the starting point.
    """

    def __init__(self, c, gamma, p=None, q=None, epochs=1000, debug=False, L=0.0):
        """
        :param c: transport cost matrix of shape (n, m)
        :param gamma: regularizer multiplier
        :param p: row marginal of length n; a random probability vector if None
        :param q: column marginal of length m; a random probability vector if None
        :param epochs: epochs amount (length of the stored iteration history,
                       including the starting point at index 0)
        :param debug: allow debug output
        :param L: Lipschitz constant used for the step sizes in :meth:`run`;
                  must be positive (it may also be assigned to ``self.L``
                  before calling :meth:`run`)
        """
        # Transport problem variables
        self.c = np.asarray(c, dtype=float)  # matrix (n, m)
        n, m = self.c.shape
        # ``p or ...`` raises for NumPy arrays (ambiguous truth value), so test
        # for None explicitly. Random defaults are normalised so both marginals
        # carry the same total mass, which the marginal constraints require.
        if p is None:
            p = np.random.rand(n)
            p = p / p.sum()
        if q is None:
            q = np.random.rand(m)
            q = q / q.sum()
        self.p = np.asarray(p, dtype=float)
        self.q = np.asarray(q, dtype=float)
        # np.random.rand takes the dimensions as separate ints, not a list.
        self.x = np.random.rand(epochs, n, m)  # primal iterates (matrices)
        self.gamma = gamma

        # Dual problem variables. ``lambda`` is a Python keyword, hence ``lambda_``.
        self.lambda_ = np.zeros((epochs, n))  # multipliers paired with p (row sums)
        self.mu = np.zeros((epochs, m))  # multipliers paired with q (column sums)

        # Nesterov momentum variables
        self.alpha = np.zeros(epochs)  # scalar step weights; alpha[0] = 0
        self.A = np.zeros(epochs)  # scalar running sums of alpha; A[0] = 0
        self.y = np.zeros((epochs, n, m))  # extrapolation points (matrices)
        self.U = self.x.copy()  # gradient-step sequence (matrices); U[0] = x[0]
        self.L = float(L)  # scalar
        self.epochs = 1  # index of the next iteration to compute
        self.max_epochs = epochs
        self.debug = debug

    def func(self, epoch):
        """Return the regularised primal objective at iterate ``x[epoch]``.

        Computes ``sum(x * c) + gamma * sum(x * log(x / n**2))`` with
        ``n = c.shape[0]``. Entries of ``x`` must be positive.
        """
        x = self.x[epoch]
        # NOTE(review): a KL divergence to the uniform plan 1/n**2 would use
        # log(x * n**2) rather than log(x / n**2) -- confirm the intended form.
        return np.sum(x * self.c) + \
            self.gamma * np.sum(x * np.log(x / self.c.shape[0] ** 2))

    def gradient(self, x):
        """Return the gradient of the :meth:`func` objective at the matrix ``x``.

        Since d/dx [x * log(x / N)] = log(x / N) + 1, the gradient is
        ``c + gamma * (log(x / n**2) + 1)``. Entries of ``x`` must be positive.
        """
        return self.c + self.gamma * (np.log(x / self.c.shape[0] ** 2) + 1)

    def dual_func(self, epoch):
        """Return the Lagrangian value at iteration ``epoch``.

        ``func(epoch) + <lambda_, p - x 1> + <mu, q - x^T 1>``: the primal
        objective plus the multiplier-weighted residuals of the row-marginal
        (``x.sum(axis=1) == p``) and column-marginal (``x.sum(axis=0) == q``)
        constraints.
        """
        x = self.x[epoch]
        row_residual = self.p - x.sum(axis=1)
        col_residual = self.q - x.sum(axis=0)
        return self.func(epoch) + \
            np.sum(self.lambda_[epoch] * row_residual) + \
            np.sum(self.mu[epoch] * col_residual)

    def run(self):
        """Run the accelerated iterations from ``self.epochs`` up to ``max_epochs``.

        For each step k:
            alpha_k is chosen so that L * alpha_k**2 == A_k, where A_k = A_{k-1} + alpha_k
            y_k = (alpha_k * U_{k-1} + A_{k-1} * x_{k-1}) / A_k
            U_k = U_{k-1} - alpha_k * gradient(y_k)
            x_k = (alpha_k * U_k + A_{k-1} * x_{k-1}) / A_k

        :return: the last computed primal iterate ``x[max_epochs - 1]``
        :raises ValueError: if ``self.L`` is not positive (step sizes divide by it)
        """
        if not self.L > 0:
            raise ValueError("FastGrad.run requires a positive Lipschitz constant L")
        L = self.L
        while self.epochs < self.max_epochs:
            k = self.epochs
            # Positive root of L*a**2 - a - A_{k-1} = 0, using A_{k-1} = L*alpha_{k-1}**2.
            self.alpha[k] = 1 / (2 * L) + np.sqrt(1 / (4 * L ** 2) + self.alpha[k - 1] ** 2)
            self.A[k] = self.A[k - 1] + self.alpha[k]

            # Divide by A_k (not A_{k-1}, which is 0 on the first step).
            self.y[k] = (self.alpha[k] * self.U[k - 1] +
                         self.A[k - 1] * self.x[k - 1]) / self.A[k]

            # The gradient step starts from the previous U with the current weight.
            self.U[k] = self.U[k - 1] - self.alpha[k] * self.gradient(self.y[k])

            self.x[k] = (self.alpha[k] * self.U[k] +
                         self.A[k - 1] * self.x[k - 1]) / self.A[k]

            if self.debug:
                print(f"Epoch {k}:\nx={self.x[k]}")

            self.epochs += 1
        return self.x[self.epochs - 1]

    # TODO: some tests
    # TODO: stopping criteria
    # TODO: calc average steps amount
    # TODO: lambda_ :- u, mu :- v
