In [None]:
import numpy as np
import matplotlib.pyplot as plt
import abc
from math import sqrt

In [None]:
# define function classes

class Obj(abc.ABC):
  """Abstract interface for an objective function with a gradient.

  Subclasses must implement both ``fval`` and ``grad``. Because the class
  derives from ``abc.ABC``, instantiating ``Obj`` itself, or a subclass that
  leaves either method unimplemented, raises ``TypeError``.
  """

  @abc.abstractmethod
  def fval(self, x):
    """Return the objective value at the point ``x``."""

  @abc.abstractmethod
  def grad(self, x):
    """Return the gradient of the objective at the point ``x``."""

class QuadObj(Obj):
  """Quadratic objective f(x) = 0.5 * x^T A x - b^T x.

  The methods read ``self.A``, ``self.b`` and ``self.x_opt``; this class does
  not set them, so a subclass (or the caller) has to provide them.
  """

  def fval(self, x):
    """Return the value of the quadratic at ``x``."""
    half_x = 0.5 * x.T
    return half_x @ self.A @ x - self.b @ x

  def grad(self, x):
    """Return the gradient ``A x - b`` at ``x``."""
    Ax = self.A @ x
    return Ax - self.b

  def ferr(self, x):
    """Return the function error ``f(x) - f(x_opt)``."""
    f_x = self.fval(x)
    f_opt = self.fval(self.x_opt)
    return f_x - f_opt
    

class TwoBandQuadObj(QuadObj):
  """Random quadratic objective whose Hessian eigenvalues lie in two bands.

  The Hessian is ``A = Q diag(lambda) Q^T`` where ``Q`` is the orthogonal
  factor of the QR decomposition of a Gaussian random matrix. The first
  ``dim // 2`` eigenvalues are evenly spaced on ``[mu1, L1]`` and the
  remaining ``dim - dim // 2`` on ``[mu2, L2]``. The minimizer ``x_opt`` is
  drawn from a standard normal and ``b = A x_opt``.

  Random numbers come from NumPy's global RNG (``np.random``); seed it before
  constructing the object to get reproducible instances.
  """

  def __init__(self, mu1, L1, mu2, L2, dim):
    """Build the two-band quadratic.

    Args:
      mu1, L1: lower and upper end of the first eigenvalue band.
      mu2, L2: lower and upper end of the second eigenvalue band.
      dim: problem dimension (size of ``A``).
    """
    self.mu1 = mu1
    self.mu2 = mu2
    self.L1 = L1
    self.L2 = L2
    self.dim = dim

    # Split the spectrum so that exactly `dim` eigenvalues are produced even
    # when `dim` is odd; the second band takes the extra one.
    n_band1 = dim // 2
    n_band2 = dim - n_band1
    eigvals = np.concatenate([np.linspace(mu1, L1, n_band1),
                              np.linspace(mu2, L2, n_band2)])
    _Lambd = np.diag(eigvals)
    _Q, _ = np.linalg.qr(np.random.randn(dim, dim))

    self.A = _Q @ _Lambd @ _Q.T
    self.x_opt = np.random.randn(dim)
    self.b = self.A @ self.x_opt

  def agd(self, x_init, v_init, L, mu, T):
    """Run ``T`` iterations of accelerated gradient descent.

    Each iteration evaluates the gradient once, at the mixed point
    ``y = alpha * x + (1 - alpha) * v``, and then updates both sequences
    from that single gradient.

    Args:
      x_init, v_init: starting values of the two iterate sequences.
      L, mu: constants that set the step sizes ``1/L`` and ``1/mu`` and,
        through ``kappa = L / mu``, the mixing weights.
      T: number of iterations; with ``T == 0`` the inputs are returned as is.

    Returns:
      The tuple ``(x, v)`` after ``T`` iterations.
    """
    kappa = L / mu
    sqrt_kappa = sqrt(kappa)
    alpha = sqrt_kappa / (sqrt_kappa + 1)
    beta = 1 - 1/sqrt_kappa

    x, v = x_init, v_init
    for _ in range(T):
      y = alpha * x + (1 - alpha) * v
      gy = self.grad(y)
      v = beta * v + (1 - beta) * (y - 1/mu * gy)
      x = y - 1/L * gy
    return (x, v)

  def acbsls(self, x_init, T1, T2, naive=False):
    """Run the AcBSLS scheme and record the function error per outer step.

    Every outer iteration takes one ``agd`` step with the first band's
    constants ``(L1, mu1)`` followed by ``T2`` ``agd`` steps with the second
    band's constants ``(L2, mu2)``:

    * ``naive=True``: a single inner run started from ``(v, v)`` replaces
      both ``x`` and ``v``.
    * ``naive=False``: two independent inner runs, one started from
      ``(x, x)`` that yields the new ``x`` and one started from ``(v, v)``
      that yields the new ``v``.

    The band constants are taken from the instance (``self.L1`` etc.), not
    from module-level variables.

    Args:
      x_init: starting point; used for both ``x`` and ``v``.
      T1: number of outer iterations.
      T2: number of inner ``agd`` iterations per inner run.
      naive: selects the variant described above.

    Returns:
      ``(cnt_lst, ferr_lst)``: the cumulative number of gradient queries and
      the function error ``ferr(x)`` after each outer iteration.
    """
    cnt_lst = []
    ferr_lst = []
    grad_query_cnt = 0
    x, v = x_init, x_init

    for _ in range(T1):
      (x, v) = self.agd(x, v, self.L1, self.mu1, 1)

      if naive:
        (x, v) = self.agd(v, v, self.L2, self.mu2, T2)
        grad_query_cnt += (1 + T2)
      else:
        (x, _) = self.agd(x, x, self.L2, self.mu2, T2)
        (_, v) = self.agd(v, v, self.L2, self.mu2, T2)
        grad_query_cnt += (1 + 2 * T2)

      cnt_lst.append(grad_query_cnt)
      ferr_lst.append(self.ferr(x))

    return cnt_lst, ferr_lst

## Naive AcBSLS may not converge

In [None]:
# Seed NumPy's global RNG so that the random problem instance and the starting
# point (both drawn via np.random) are the same on every run.
SEED = 0
np.random.seed(SEED)

# Eigenvalue bands of the test problem. These stay module-level names because
# other cells may refer to them.
mu1 = 1e-4
L1 = 1e-3
mu2 = 5e-1
L2 = 1
dim = 128
obj = TwoBandQuadObj(mu1=mu1, L1=L1, mu2=mu2, L2=L2, dim=dim)

x_init = np.random.randn(dim)

# Experiment settings: outer iterations, inner steps for the principled run,
# and the inner-step counts tried for the naive variant.
T1 = 32
T2_PRINCIPLED = 8
naive_T2_lst = [8, 16, 32, 64]
linestyle_lst = ['dotted', 'dashed', 'dashdot', (0, (5, 5))]

fig, ax = plt.subplots()

ax.semilogy(*obj.acbsls(x_init, T1, T2_PRINCIPLED, naive=False),
            label=f'(Principled) AcBSLS ($T_2 = {T2_PRINCIPLED}$)',
            linestyle='solid')
for T2, linestyle in zip(naive_T2_lst, linestyle_lst):
  ax.semilogy(*obj.acbsls(x_init, T1, T2, naive=True),
              label=f'Naive AcBSLS ($T_2 = {T2}$)',
              linestyle=linestyle)

ax.set_xlabel('Number of gradient queries', size='large')
ax.set_ylabel('Function errors', size='large')

ax.legend()

fig.savefig('naive_acbsls.pdf')