In [59]:
import torch 
from torch.optim import Optimizer 
from torch import Tensor
from typing import List, Optional
from autograd import grad, jacobian, elementwise_grad
import numpy as np
import matplotlib.pyplot as plt
import sys, os

In [60]:
class BaseSaddle(object):
    """Base class holding the pieces of a saddle-point test problem.

    Subclasses assign ``f`` (a function of ``x``), ``g`` (a function of ``y``)
    and the reference optimum ``xopt``/``yopt``.  They may also assign the
    derivative callables ``dfdx``/``dgdy`` directly; when they do not, the
    derivatives are built from ``f``/``g`` with autograd on first use.

    Second-order derivatives and the "follow the ridge" baseline helper are
    currently disabled and therefore not provided here.
    """

    def __init__(self):
        # Reference optimum; stays None when it is not known.
        self.xopt = None
        self.yopt = None
        self.xrange = None
        self.yrange = None
        self.f = None
        self.g = None
        # f and g are still None at this point, so differentiating them here
        # could never yield a usable callable.  The derivatives are created
        # lazily in _ensure_derivatives() unless a subclass sets them itself.
        self.dfdx = None
        self.dfdy = 0

        self.dgdx = 0
        self.dgdy = None

    def _ensure_derivatives(self):
        """Build ``dfdx``/``dgdy`` from ``f``/``g`` if they are not set yet."""
        if self.dfdx is None:
            if self.f is None:
                raise TypeError("BaseSaddle.f must be set before calling grad()")
            self.dfdx = grad(self.f)
        if self.dgdy is None:
            if self.g is None:
                raise TypeError("BaseSaddle.g must be set before calling grad()")
            # g is called with a single argument (y), so differentiate with
            # respect to argument 0.
            self.dgdy = grad(self.g)

    def grad(self, x, y):
        """Return the pair ``(dfdx(x), dgdy(y))``.

        The two derivatives are returned as they are, without stacking them
        into one array, so ``x`` and ``y`` may have different shapes.

        Raises:
            TypeError: if a derivative is missing and the matching function
                (``f`` or ``g``) has not been set either.
        """
        self._ensure_derivatives()
        return self.dfdx(x), self.dgdy(y)

    def loss(self, x, y):
        """Return ``(x - xopt)**2 + (y - yopt)**2``.

        This is the squared distance to the reference optimum, computed
        elementwise for array inputs.
        """
        return (x-self.xopt)**2 + (y-self.yopt)**2

In [61]:
def APDG(problem,
         params: dict,
         x0: np.ndarray,
         y0: np.ndarray,
         A: np.ndarray,
         iter_num: int
        ):
    """Run ``iter_num`` APDG iterations for the objective f(x) - g(y) + y^T A x.

    Args:
        problem: object exposing ``xopt``, ``grad(x, y) -> (grad_f, grad_g)``
            and ``loss(x, y)``.
        params: step-size dictionary with the keys ``theta``, ``tau_x``,
            ``tau_y``, ``eta_x``, ``eta_y``, ``alpha_x``, ``alpha_y``,
            ``beta_x``, ``beta_y``, ``sigma_x`` and ``sigma_y``.
        x0: starting point for ``x``.
        y0: starting point for ``y``.
        A: coupling matrix; must support ``A.dot`` and ``A.T``.
        iter_num: number of iterations to perform.

    Returns:
        ``(loss, x_hist, y_hist)``.  The histories contain the starting point
        followed by one entry per iteration.  ``loss`` holds
        ``problem.loss`` at every recorded iterate (start included) when
        ``problem.xopt`` is not None, and is empty otherwise.
    """
    p = params
    x, y = x0, y0
    x_f, y_f = x0, y0
    y_prev = y0

    # Every loss entry, including the first, comes from problem.loss so the
    # whole curve uses one metric.  Without a reference optimum no loss can
    # be evaluated, so nothing is recorded.
    track_loss = problem.xopt is not None
    loss = [problem.loss(x, y)] if track_loss else []
    x_hist, y_hist = [x], [y]

    for _ in range(iter_num):
        # Extrapolation uses the previous iterate y^{k-1}, not y0.
        y_m = y + p['theta'] * (y - y_prev)
        x_g = p['tau_x'] * x + (1 - p['tau_x']) * x_f
        y_g = p['tau_y'] * y + (1 - p['tau_y']) * y_f

        grad_x, grad_y = problem.grad(x_g, y_g)

        # At a saddle point of f(x) - g(y) + y^T A x we have
        #   grad_f + A^T y = 0   and   A x - grad_g = 0,
        # so each correction term below is written to vanish there.
        x_next = (x + p['eta_x'] * p['alpha_x'] * (x_g - x)
                  - p['eta_x'] * p['beta_x'] * A.T.dot(A.dot(x) - grad_y)
                  - p['eta_x'] * (grad_x + A.T.dot(y_m)))
        y_next = (y + p['eta_y'] * p['alpha_y'] * (y_g - y)
                  - p['eta_y'] * p['beta_y'] * A.dot(A.T.dot(y) + grad_x)
                  - p['eta_y'] * (grad_y - A.dot(x_next)))

        x_f = x_g + p['sigma_x'] * (x_next - x)
        y_f = y_g + p['sigma_y'] * (y_next - y)

        y_prev = y
        x, y = x_next, y_next

        x_hist.append(x)
        y_hist.append(y)
        if track_loss:
            loss.append(problem.loss(x, y))
    return loss, x_hist, y_hist

In [62]:
class func2(BaseSaddle):
    """Quadratic test problem with f(x) = x**2 and g(y) = y**2."""

    def __init__(self):
        super().__init__()
        self.f = lambda x: x**2
        self.g = lambda y: y**2
        # For these f and g the stationarity conditions of
        # f(x) - g(y) + y^T A x are 2x + A^T y = 0 and A x - 2y = 0, which
        # only hold at the origin, whatever the coupling A is.  The previous
        # hard-coded optimum (0.40278777, 0.59721223) satisfies neither.
        self.xopt, self.yopt = 0.0, 0.0
        self.xrange = [-5, 5, .1]
        self.yrange = [-5, 5, .1]
        self.constraint = False
        # Both functions take a single argument, so differentiate w.r.t. it.
        self.dfdx = grad(self.f)
        self.dgdy = grad(self.g)

In [69]:
def plot(loss, xpath, ypath, iteration, k, start, fig_dir=None):
    """Draw the iterate trajectories and the loss curve, then save the figure.

    Besides its arguments this function reads the notebook-level globals
    ``x``, ``y``, ``z`` (contour grid), ``dz_dx``, ``dz_dy`` (quiver field),
    ``markevery``, ``xsol``, ``ysol``, ``xmin``, ``xmax``, ``ymin``, ``ymax``
    and ``figname``; they must be defined before it is called.

    Args:
        loss: eight loss series, in the slot order produced by ``main``.
        xpath: eight x trajectories, same slot order.
        ypath: eight y trajectories, same slot order.
        iteration: number of iterations; the loss curve is expected to hold
            ``iteration + 1`` values.
        k: accepted for signature compatibility; not used.
        start: pair ``(x0, y0)`` marking the starting point.
        fig_dir: directory for the saved figure; when None the figure is
            written to ``figname`` relative to the working directory.
    """
    x0, y0 = start
    # Slot order follows main(): 0 = APDG, 1 = Avg, 2 = EG, 3 = OMD,
    # 4 = SimGDA-RAM, 5 = AltGDA-RAM, 6 = SimGDA, 7 = unused.
    loss_apdg = loss[0]
    (path_apdg, path_avg, path_eg, path_omd,
     path_sim_ram, path_alt_ram, path_simgd, _unused) = zip(xpath, ypath)

    fig, axlist = plt.subplots(1, 2, figsize=(14,5))
    ax1, ax2 = axlist

    # Left panel: objective contours, vector field and trajectories.
    ax1.contourf(x, y, z, 5, cmap=plt.cm.gray)
    ax1.quiver(x, y, x - dz_dx, y - dz_dy, alpha=.5)
    trajectories = [
        (path_simgd, 'm-', 'SimGDA'),
        (path_apdg, 'g--', 'APDG'),
        (path_avg, '--', 'Avg'),
        (path_eg, 'k-^', 'EG'),
        (path_omd, 'c-*', 'OMD'),
        (path_alt_ram, 'b->', 'AltGDA-RAM'),
        (path_sim_ram, 'r-d', 'SimGDA-RAM'),
    ]
    for (xs, ys), fmt, label in trajectories:
        ax1.plot(xs, ys, fmt, linewidth=2, label=label, markevery=markevery)
    x_init = ax1.scatter(x0, y0, marker='s', s=250, c='g',alpha=1,zorder=3, label='Start')
    x_sol = ax1.scatter(xsol, ysol, s=250, marker='*', color='violet', zorder=3, label='Optima')
    # Only the start and optimum markers are listed in this panel's legend.
    ax1.legend([x_init, x_sol],['Start','Optima'], markerscale=1, loc=4, fancybox=True, framealpha=1., fontsize=20)
    ax1.set_xlabel('x')
    ax1.set_ylabel('y')
    ax1.set_xlim([xmin,xmax])
    ax1.set_ylim([ymin,ymax])

    # Right panel: loss per iteration.  Only the APDG slot is filled by
    # main(), so the curves of the other methods stay disabled.
    plot_interval = 1
    ax2.semilogy(np.arange(0, iteration+plot_interval, plot_interval), loss_apdg[::plot_interval], 'g--', markevery=markevery, label='APDG')
    ax2.set_xlabel('Iteration')
    ax2.set_ylim([1e-25,1e4])
    ax2.set_ylabel('Distance to optimal')
    ax2.legend(loc='center left', bbox_to_anchor=(1, 0.5), ncol=1, fancybox=True, framealpha=1., fontsize=20, markerscale=2)

    target = figname if fig_dir is None else os.path.join(fig_dir, figname)
    fig.savefig(target, dpi=300, bbox_inches = 'tight', pad_inches = 0)

        
def main(problem, iteration, x0, y0, A, params, k=5):
    """Run the enabled solvers and gather their outputs into 8-slot lists.

    Only APDG is enabled at the moment; it fills slot 0 using
    ``params['apdg']``.  The remaining slots keep their empty lists because
    the baselines that used to fill them (avg, eg, omd, simGDAAM, altGDAAM,
    simgd) are switched off.  ``k`` was only consumed by those disabled
    baselines and is currently unused.

    Returns:
        ``(allloss, allxpath, allypath)``, three lists of length 8.
    """
    slot_count = 8
    allloss, allxpath, allypath = (
        [[] for _ in range(slot_count)] for _ in range(3)
    )

    apdg_result = APDG(problem, params['apdg'], x0, y0, A, iteration)
    allloss[0], allxpath[0], allypath[0] = apdg_result

    return allloss, allxpath, allypath

In [72]:
# Output location of the figure.
figname = 'APDG_1.png'
FIG_DIR = os.path.join("..", "figures")
# exist_ok avoids the check-then-create race of exists() followed by mkdir().
os.makedirs(FIG_DIR, exist_ok=True)
k = 20
markevery= 10
x0, y0 = np.array([3.]),np.array([3.])
A = np.eye(1, 1)

problem = func2()
xsol, ysol = problem.xopt, problem.yopt
params = {'apdg': {"eta_x": 0.5,
                   "alpha_x": 0.5,
                   "beta_x": 0.5,
                   "tau_x": 0.5,
                   "sigma_x": 0.5,
                   "eta_y": 0.5,
                   "alpha_y": 0.5,
                   "beta_y": 0.5,
                   "tau_y": 0.5,
                   "sigma_y": 0.5,
                   "theta": 0.5}}
# Learning rates of the disabled baselines, kept for reference:
# {'simgd':0.05, 'altgd':0.1, 'avg':1, 'adam':0.01, 'eg':0.6,'omd':0.3, 'fr':0.05,'AA':0.5}
f = problem.f
g = problem.g

type2=True
iteration = 10
loss_f3, xpath_f3, ypath_f3 = main(problem, iteration, x0, y0, A, params, k=k)

# Grid and fields read by plot() through module-level names.
xmin, xmax, xstep = [-3.5, 5, .1]
ymin, ymax, ystep = [-3.5, 5, .1]
x, y = np.meshgrid(np.arange(xmin, xmax + xstep, xstep), np.arange(ymin, ymax + ystep, ystep))
# Objective on the grid; A is 1x1 here, so it broadcasts like a scalar.
z = f(x) - g(y) + x * A * y
# NOTE(review): these are f'(x) and g'(y) only; the coupling term x*A*y of z
# is not differentiated here -- confirm this is the intended quiver field.
dz_dx = elementwise_grad(f, argnum=0)(x)
dz_dy = elementwise_grad(g, argnum=0)(y)
plot(loss_f3, xpath_f3, ypath_f3, iteration, k, [x0, y0], fig_dir=FIG_DIR)