Copyright 2022 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

     https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

# Super-acceleration with cyclical step-sizes

This colab reproduces the figures from the blog post https://fa.bianp.net/2022/cyclical/ and the paper

> _Super-Acceleration with Cyclical Step-sizes_, Baptiste Goujaud, Damien Scieur, Aymeric Dieuleveut, Adrien Taylor, Fabian Pedregosa. Proceedings of the 25th International Conference on Artificial Intelligence and Statistics, 2022. https://arxiv.org/pdf/2106.09687.pdf

In [None]:
%%capture
import matplotlib.font_manager as fm

# for nicer fonts
!wget https://github.com/openmaptiles/fonts/raw/master/open-sans/OpenSans-Light.ttf
fm.fontManager.ttflist += fm.createFontList(['OpenSans-Light.ttf'])

# install apngasm for creating animated PNGs
!apt-get install apngasm

In [None]:
import matplotlib.pyplot as plt
import matplotlib.patches as patches

from matplotlib import rcParams
from matplotlib.ticker import StrMethodFormatter
# Global figure style: large, light-weight Open Sans text and Computer Modern math.
rcParams.update({
    'font.size': 35,
    'font.family': 'Open Sans',
    'font.weight': 'light',
    'mathtext.fontset': 'cm',
})

import numpy as np
from scipy import special

# Color palette shared by several of the plots below; entries are picked by index.
palette = [
    '#66c2a5', '#fc8d62', '#8da0cb', '#e78ac3', '#a6d854',
    '#e41a1c', '#377eb8', '#4daf4a', '#984ea3', '#ff7f00',
    '#ffff33', '#a65628', '#f781bf',
]

# Cyclical Heavy Ball animation in 2D

The following code generates the iterates of classical and cyclical heavy ball on a 2D problem, which makes them easy to visualize. It generates one PNG per iteration; these frames are then combined into a single animated PNG with apngasm.

To download the generated file, open the "Files" tab, right-click the file and select Download.

In [None]:
# Problem setup: a 2D quadratic f(x) = x^T H x / 2, evaluated on a grid for display.
n_grid = 200
max_iter = 10

x_grid = np.linspace(-2, 5, n_grid)
X, Y = np.meshgrid(x_grid, x_grid)
Z = np.array((X, Y)).T
x_init = np.array([1.5, 2.5])

# Diagonal Hessian whose two eigenvalues differ by a factor of 10.
H = np.array([[2, 0], [0, 0.2]])
eigenvalues = np.linalg.eigvalsh(H)
L, mu = eigenvalues.max(), eigenvalues.min()

# Loss on the grid, displayed later with imshow/contour.
loss_grid = (1/2) * ((Z @ H) * Z).sum(-1)


def _heavy_ball_iterates(x0, hess, n_iter, first_step, step_at, momentum):
  """Run heavy ball on x -> x^T hess x / 2 and return the visited points.

  Row k of the returned (n_iter, len(x0)) array is the k-th iterate, starting
  with x0. The first update is a plain gradient step of size `first_step`;
  update k >= 1 uses the step size `step_at(k)` plus `momentum` times the
  previous displacement.
  """
  path = np.zeros((n_iter, x0.shape[0]))
  current, previous = x0.copy(), x0.copy()
  for k in range(n_iter):
    path[k] = current
    gradient = hess @ current
    if k == 0:
      upcoming = current - first_step * gradient
    else:
      upcoming = current - step_at(k) * gradient + momentum * (current - previous)
    previous, current = current, upcoming
  return path


# Polyak heavy ball: constant step size h and momentum m.
h = (2 / (np.sqrt(L) + np.sqrt(mu))) ** 2
m = ((np.sqrt(L) - np.sqrt(mu)) / (np.sqrt(L) + np.sqrt(mu))) ** 2
all_iterates_momentum = _heavy_ball_iterates(
    x_init, H, max_iter, 2 / (L + mu), lambda k: h, m)

# Cyclical heavy ball: the spectrum [mu, L] is split into [mu1, L1] and
# [mu2, L2], separated by a gap of relative size R.
mu1, L2 = mu, L
rho = (L2 + mu1) / (L2 - mu1)
# A large relative gap makes the super-acceleration effect clearly visible.
R = 0.9
L1 = mu + (1 - R) * (L - mu) / 2
mu2 = L - (1 - R) * (L - mu) / 2
m = ((np.sqrt(rho**2 - R**2) - np.sqrt(rho**2 - 1)) / np.sqrt(1 - R**2)) ** 2
# Even updates use (1 + m) / L1, odd updates use (1 + m) / mu2.
step_sizes = ((1 + m) / L1, (1 + m) / mu2)
all_iterates_cyclical = _heavy_ball_iterates(
    x_init, H, max_iter, 2 / (L + mu), lambda k: step_sizes[k % 2], m)

In [None]:
# One frame per iteration: the loss landscape, the optimum x*, and the first i
# iterates of both methods, saved as comparison_cyclical_XX.png.
for i in range(max_iter):
  fig, ax = plt.subplots(figsize=(20, 10))
  # `linewidths` is the contour keyword for line width.
  ax.contour(X, Y, -loss_grid.T + 0.05, 50, linewidths=5, colors='black')
  # Row j of loss_grid.T corresponds to y = x_grid[j]; origin='lower' draws it
  # at height y so the heat map lines up with the contour lines above.
  ax.imshow(-loss_grid.T / np.max(np.abs(loss_grid)), extent=[-2, 5, -2, 5],
            origin='lower', cmap='gist_heat', alpha=1)
  ax.scatter([0], [0], color='black', s=80)
  # raw string: '\s' is an invalid escape sequence (SyntaxWarning on Python 3.12+)
  ax.text(0.05, 0, r'$x^\star$', color='black')
  ax.plot(all_iterates_momentum[:i, 0], all_iterates_momentum[:i, 1],
          c='teal', lw=3, label='Heavy Ball', marker='d',
          markersize=10)
  ax.plot(all_iterates_cyclical[:i, 0], all_iterates_cyclical[:i, 1],
          c='darkred', lw=3, label='Cyclical Heavy Ball',
          marker='^', markersize=10)

  ax.set_ylim((-0.2, 2.8))
  ax.set_xlim((-2, 2))
  ax.set_xticks(())
  ax.set_yticks(())
  ax.legend(loc='upper center', bbox_to_anchor=(0.5, -0.05),
            frameon=False, ncol=1, fontsize=22)
  # Set the aspect of the axes we drew on; plt.axes() would add a new,
  # empty axes on top of the figure instead.
  ax.set_aspect('equal')

  f_path = 'comparison_cyclical_%02d.png' % i
  fig.savefig(f_path, transparent=True, dpi=100, bbox_inches='tight')
  plt.show()

In [None]:
# convert to animated PNG
%%capture
!apngasm comparison_cyclical.png comparison_cyclical_01.png 1 1

# Residual polynomial

In [None]:
# Residual polynomials of Polyak (constant step-size) and cyclical heavy ball.

colors = [
    '#e41a1c', '#377eb8', '#4daf4a', '#984ea3', '#ff7f00', '#ffff33',
]

# edges of the eigenvalue range used in this section
mu = 0.2
L = 2

def poly_gd(x, t):
  """Residual polynomial (1 - h x)^t of gradient descent with h = 2 / (mu + L).

  `mu` and `L` are read from module scope at call time.
  """
  contraction = 1 - (2 / (mu + L)) * x
  return contraction ** t


def poly_polyak(x, low, high, degree):
  """Residual polynomial of Polyak heavy ball after `degree` iterations.

  With momentum m and step size h tuned for eigenvalues in [low, high],

    P_t(x) = m^{t/2} (2m/(1+m) T_t(s) + (1-m)/(1+m) U_t(s)),
    s = (1 + m - h x) / (2 sqrt(m)),

  where T_t and U_t are Chebyshev polynomials of the first and second kind.

  Args:
    x: eigenvalue(s) at which to evaluate the polynomial (scalar or array).
    low, high: smallest and largest eigenvalue the parameters are tuned for.
    degree: polynomial degree t, i.e. the number of iterations.

  Returns:
    P_t evaluated at x.
  """
  m = ((np.sqrt(high) - np.sqrt(low)) / (np.sqrt(high) + np.sqrt(low))) ** 2
  h = (2 / (np.sqrt(high) + np.sqrt(low))) ** 2
  s = (1 + m - h * x) / (2 * np.sqrt(m))
  cheb1_part = special.eval_chebyt(degree, s)
  cheb2_part = special.eval_chebyu(degree, s)
  # The m^{t/2} prefactor must use `degree`, not a global loop variable `t`.
  return (m ** (degree / 2)) * ((2 * m / (1 + m)) * cheb1_part + ((1 - m) / (1 + m)) * cheb2_part)


def poly_cyclical(x, mu1, L1, mu2, L2, degree):
  """Residual polynomial of cyclical heavy ball after `degree` iterations.

  The method alternates the step sizes h0 = (1 + m) / L1 and h1 = (1 + m) / mu2
  with momentum m computed from the intervals [mu1, L1] and [mu2, L2]. For an
  even degree t,

    P_t(x) = m^{t/2} (2m/(1+m) T_t(s) + (1-m)/(1+m) U_t(s)),
    s^2 = (1 + m - h0 x)(1 + m - h1 x) / (4 m).

  Args:
    x: eigenvalue(s) at which to evaluate the polynomial (scalar or array).
    mu1, L1, mu2, L2: edges of the two eigenvalue intervals.
    degree: number of iterations t. Only even values give the residual
      polynomial of the method.

  Returns:
    The polynomial evaluated at x.
  """
  rho = (L2 + mu1) / (L2 - mu1)
  R = (mu2 - L1) / (L2 - mu1)
  m = ((np.sqrt(rho ** 2 - R**2) - np.sqrt(rho**2 - 1)) / np.sqrt(1 - R**2)) ** 2
  h0 = (1 + m) / L1
  h1 = (1 + m) / mu2
  prod = (1 + m - h0 * x) * (1 + m - h1 * x)
  if degree % 2 == 0:
    # T_{2k}(s) and U_{2k}(s) only depend on s^2, through w = 2 s^2 - 1:
    #   T_{2k}(s) = T_k(w),   U_{2k}(s) = U_k(w) + U_{k-1}(w).
    # This stays exact where prod < 0 (s imaginary, i.e. inside the gap
    # (L1, mu2)). Using s = sign(prod) sqrt(|prod|) / (2 sqrt(m)) there
    # flips the sign of s^2 and does not give the polynomial.
    half = degree // 2
    w = prod / (2 * m) - 1
    cheb1_part = special.eval_chebyt(half, w)
    cheb2_part = special.eval_chebyu(half, w)
    if half > 0:
      cheb2_part = cheb2_part + special.eval_chebyu(half - 1, w)
  else:
    # NOTE(review): odd degrees keep the original sign-based evaluation; the
    # callers in this notebook only use even degrees.
    s = np.sqrt(np.abs(prod)) * np.sign(prod) / (2 * np.sqrt(m))
    cheb1_part = special.eval_chebyt(degree, s)
    cheb2_part = special.eval_chebyu(degree, s)
  # The m^{t/2} prefactor must use `degree`, not a global loop variable `t`.
  return (m ** (degree / 2)) * ((2 * m / (1 + m)) * cheb1_part + ((1 - m) / (1 + m)) * cheb2_part)


# One frame per even degree t = 2, 4, ..., 18, saved as
# CyclicalResidualPolynomial01.png ... CyclicalResidualPolynomial09.png.
x_grid = np.linspace(0, mu + L, 500)
for t in range(2, 19, 2):
  # With L1 == mu2 there is no gap (R = 0): this curve is the Polyak heavy-ball
  # polynomial, as the legend label says.
  acc_y = poly_cyclical(x_grid, mu, (mu + L)/2, (mu + L)/2, L, t)
  acc_cyclical = poly_cyclical(x_grid, mu, mu + 0.12, L - 0.12, L, t)

  f, axarr = plt.subplots(1, 1, figsize=(12, 10))
  plt.title("Degree %s" % t)
  axarr.plot(x_grid, acc_y, '--', lw=5, label='Polyak Heavy Ball $P^{Polyak}_t$', color='#ff7f0e')
  axarr.plot(x_grid, acc_cyclical, lw=5, label='Cyclical Heavy Ball $P^{Cyclical}_t$', color=colors[3])

  # raw strings: the backslashes below are invalid escape sequences in normal
  # string literals (SyntaxWarning on Python 3.12+)
  axarr.set_ylabel(r'$P_{t}(\lambda)$')
  axarr.set_xlabel(r'$\lambda$')

  axarr.axvline(x=mu, color='grey')
  axarr.axvline(x=L, color='grey')
  axarr.set_xticks((0.0, mu, 0.5, 1.0, 1.5, L))
  axarr.set_xticklabels((0.0, r'$\lambda_\min$', None, None, None, r'$\lambda_\max$'))
  axarr.set_yticks((-0.1, 0.0, 0.2, 0.4, 0.6, 0.8, 1.0))
  axarr.set_yticklabels((None, 0, None, None, None, None, 1.0))

  axarr.legend(loc='upper center', frameon=False, bbox_to_anchor=(0.5, -0.1), ncol=1)
  axarr.set_ylim((-0.1, 1))
  axarr.grid()

  f.subplots_adjust(wspace=0.3)  # pad a little

  f_path = 'CyclicalResidualPolynomial%02d.png' % (t // 2)
  plt.savefig(f_path, transparent=True, dpi=100, bbox_inches='tight')

  plt.show()

In [None]:
# convert to animated PNG
%%capture
!apngasm CyclicalResidualPolynomial.png CyclicalResidualPolynomial01.png 1 1

# Link function

In this block we'll plot the link function of both classical and cyclical heavy ball. We'll generate different images for different input parameters, and as before, use apngasm to generate an animated PNG from them.

In [None]:
def sigma(x, m, h):
  """Link function of heavy ball with constant step size h and momentum m.

  sigma(x) = (1 + m - h x) / (2 sqrt(m)).
  """
  scale = 2 * np.sqrt(m)
  return (1 + m - h * x) / scale

def zeta(x, m, h0, h1):
  """Link function of cyclical heavy ball with step sizes h0 and h1.

  With prod = (1 + m - h0 x)(1 + m - h1 x), returns +sqrt(prod / (4 m)) where
  both factors are positive and -sqrt(prod / (4 m)) elsewhere. Where exactly
  one factor is negative, prod < 0 and the result is NaN, so that part of the
  curve is not drawn.
  """
  factor0 = 1 + m - h0 * x
  factor1 = 1 + m - h1 * x
  magnitude = np.sqrt(factor0 * factor1 / (4 * m))
  sign = np.where((factor0 > 0) & (factor1 > 0), 1, -1)
  out = np.zeros_like(x)
  out[...] = sign * magnitude
  return out

n_grid = 1_000  # number of x values used for the link-function plots

In [None]:
# Animate the constant step-size link function sigma as L sweeps 1 -> 2 -> 1;
# each frame is saved as link_function_constant_XX.png.
mu = 0.2  # same value as in the residual-polynomial section
x_grid = np.linspace(0, 2, n_grid)

for i, L in enumerate(np.concatenate((np.linspace(1, 2, 20), np.linspace(2, 1, 20)))):
  # heavy-ball momentum and step size for eigenvalues in [mu, L]
  m = ((np.sqrt(L) - np.sqrt(mu)) / (np.sqrt(L) + np.sqrt(mu))) ** 2
  h = (2 / (np.sqrt(L) + np.sqrt(mu))) ** 2

  # raw strings: a backslash before 's' is an invalid escape sequence in a
  # normal string literal (SyntaxWarning on Python 3.12+)
  plt.plot(x_grid, sigma(x_grid, m, h), lw=3, label=r'link function $\sigma$')

  # sigma maps [mu, L] onto [-1, 1]; highlight that interval on the x axis
  yy = np.linspace(mu, L)
  plt.plot(yy, np.zeros_like(yy), lw=10, alpha=0.5, label=r'$\sigma^{-1}([-1, 1])$')
  plt.title('Constant step-size link function', fontsize=28)
  plt.ylim(-1.5, 1.5)
  plt.yticks((-1, 0, 1), fontsize=22)
  plt.xlim((0, L+mu))
  plt.xticks(())
  plt.grid()
  plt.legend(loc='upper center', bbox_to_anchor=(0.5, -0.05),
             frameon=False, ncol=2, fontsize=22)

  f_path = 'link_function_constant_%02d.png' % i
  plt.savefig(f_path, transparent=True, dpi=100, bbox_inches='tight')

  plt.show()


In [None]:
# convert to animated PNG with
%%capture
!apngasm link_function_constant.png link_function_constant_01.png 1 10

In [None]:
# Animate the cyclical step-size link function zeta as the relative gap R
# sweeps 0 -> 0.5 -> 0; each frame is saved as link_function_cyclical_XX.png.
# The problem constants are set explicitly rather than read from variables
# leaked by earlier cells. They match what a top-to-bottom run leaves behind:
# mu = 0.2 from the residual-polynomial section, and L = 1.0 as the last
# value of the constant step-size loop.
mu, L = 0.2, 1.0
x_grid = np.linspace(0, 2, n_grid)

rho = (L + mu) / (L - mu)
for i, R in enumerate(np.concatenate((np.linspace(0, 0.5, 20), np.linspace(0.5, 0, 20)))):
  # momentum for relative gap R and the two alternating step sizes
  m = (np.sqrt(rho ** 2 - R ** 2) - np.sqrt(rho ** 2 - 1)) ** 2 / (1 - R ** 2)
  L1 = mu + (1 - R) * (L - mu) / 2
  mu2 = L - (1 - R) * (L - mu) / 2
  h0 = (1 + m) / L1
  h1 = (1 + m) / mu2

  # raw strings: a backslash before 'z' is an invalid escape sequence in a
  # normal string literal (SyntaxWarning on Python 3.12+)
  plt.plot(x_grid, zeta(x_grid, m, h0, h1), lw=3, label=r'link function $\zeta$')

  # zeta^{-1}([-1, 1]) = [mu, L1] U [mu2, L]: draw it on the x axis, with NaNs
  # breaking the line over the gap (L1, mu2)
  yy = np.linspace(mu, L)
  idx = (yy > L1) & (yy < mu2)
  yy_img = np.zeros_like(yy)
  yy_img[idx] = np.nan
  plt.plot(yy, yy_img, lw=10, alpha=0.5, label=r'$\zeta^{-1}([-1, 1])$')
  plt.title('Cyclical step-size link function', fontsize=28)
  plt.ylim(-1.5, 1.5)
  plt.yticks((-1, 0, 1), fontsize=22)
  plt.xlim((0, L+mu))
  plt.xticks(())
  plt.grid()
  plt.legend(loc='upper center', bbox_to_anchor=(0.5, -0.05),
             frameon=False, ncol=2, fontsize=22)

  f_path = 'link_function_cyclical_%02d.png' % i
  plt.savefig(f_path, transparent=True, dpi=100, bbox_inches='tight')

  plt.show()


In [None]:
# convert to animated PNG
%%capture
!apngasm link_function_cyclical.png link_function_cyclical_01.png 1 10

# Spectral density

In this section we download the MNIST dataset and plot the Hessian eigenvalues for a quadratic objective. We'll overlay the quantities $\mu_1, \mu_2, L_1, L_2$ that are important for optimization.

In [None]:
# Load the MNIST training split: one flattened image per row, plus its label.
import tensorflow_datasets as tfds

ds = tfds.load(name='mnist', split='train')
ds_numpy = tfds.as_numpy(ds)  # iterate over the dataset as NumPy values
examples = [(ex['image'].ravel(), ex['label']) for ex in ds_numpy]

# Pixel values are divided by 255 (presumably uint8 intensities -> [0, 1]).
mnist_images = np.array([pixels for pixels, _ in examples], dtype=np.float64) / 255.
mnist_target = np.array([label for _, label in examples], dtype=np.float64)

In [None]:
# Hessian X^T X of the quadratic objective on MNIST and its eigenvalues.
H = mnist_images.T @ mnist_images
eigs = np.linalg.eigvalsh(H)  # returned in ascending order

# [mu1, L1] holds every eigenvalue but the largest; [mu2, L2] ends at the
# largest eigenvalue and has the same width as [mu1, L1].
mu1, L1, L2 = eigs[0], eigs[-2], eigs[-1]
mu2 = L2 - (L1 - mu1)

# ratio of smallest to largest eigenvalue
print('condition number', mu1 / L2)

In [None]:
# Histogram of the MNIST Hessian spectrum (scaled by L2) on a log scale,
# annotated with mu1, L1, mu2, L2 and the relative gap R.
fig, axarr = plt.subplots(1, 1, figsize=(1 * 10, 1 * 8))

axarr.hist(eigs / L2, 50)
axarr.set_yscale("log")
axarr.axvline(L1 / L2, color='#4DAF4A', linestyle='--', lw=1)
axarr.axvline(mu1 / L2, color=palette[8], linestyle='--', lw=1)
axarr.axvline(L2 / L2, color=palette[8], linestyle='--', lw=1)
axarr.axvline(mu2 / L2, color='#4DAF4A', linestyle='--', lw=1)

# raw strings: a backslash before 'm' is an invalid escape sequence in a
# normal string literal (SyntaxWarning on Python 3.12+)
axarr.text(L1 * 0.999 / L2, 500, '$L_1$', color='#4DAF4A')
axarr.text(mu2 * 0.999 / L2, 500, r'$\mu_2$', color='#4DAF4A')
axarr.text(L2 * 0.999 / L2, 500, '$L_2$', color=palette[8])
axarr.text(mu1 * 0.999 / L2, 500, r'$\mu_1$', color=palette[8])
axarr.set_xticks(())

axarr.set_ylabel("density")
axarr.set_xlabel("eigenvalue magnitude")

# Arrows for the widths L2 - mu1 and mu2 - L1 that define R.
p1 = patches.FancyArrowPatch((0, 200), (1, 200), arrowstyle='<->',
                             mutation_scale=20, color=palette[8], linewidth=3)
axarr.add_patch(p1)
p2 = patches.FancyArrowPatch((L1 / L2, 60), (mu2 / L2, 60), arrowstyle='<->',
                             mutation_scale=20, color='#4DAF4A', linewidth=3)
axarr.add_patch(p2)
axarr.text(0.4, 270, r"$L_2 - \mu_1$", color=palette[8], fontsize=30)
axarr.text(0.4, 80, r"$\mu_2 - L_1$", color='#4DAF4A', fontsize=30)
axarr.text(0.33, 12, r"$R = \frac{~~~~~~~~~~~~~~~~}{~~~}$", color='k', fontsize=30)
axarr.text(0.44, 16, r"$\mu_2 - L_1$", color='#4DAF4A', fontsize=30)
axarr.text(0.44, 9, r"$L_2 - \mu_1$", color=palette[8], fontsize=30)

fig.tight_layout()

f_path = 'spectrum_mnist.png'
# `fc`, `ec` and `shape` are not savefig options (recent matplotlib versions
# raise on unknown keywords); transparent=True already sets the background.
fig.savefig(f_path, dpi=300, bbox_inches='tight', transparent=True)

plt.show()

# Robust region of cyclical heavy ball

Here we plot how the robust region grows as a function of the relative gap $R = (\mu_2 - L_1)/(L_2 - \mu_1)$.

In [None]:
# Problem constants for the robust-region figure. Other values are fine and
# give slightly different figures.
mu = 0.1
L = 2
n_grid = 1500  # resolution of the parameter grid

In [None]:
# (m, r) parameter grid; momentum m decreases from 1 to ~0 along the x axis.
all_m = np.linspace(1, 1e-12, n_grid)
all_h = np.linspace(0, 1, n_grid)
m_grid, h_grid = np.meshgrid(all_m, all_h)

# Polyak heavy-ball parameters for the same [mu, L], for reference.
sqrt_L, sqrt_mu = np.sqrt(L), np.sqrt(mu)
m_polyak = ((sqrt_L - sqrt_mu) / (sqrt_L + sqrt_mu)) ** 2
h_polyak = (2 / (sqrt_L + sqrt_mu)) ** 2
robust_m = all_m[all_m > m_polyak]
rho = (L + mu) / (L - mu)


def varphi(xi):
  """Return |xi| + sqrt(xi^2 - 1); NaN where |xi| < 1."""
  magnitude = np.abs(xi)
  return magnitude + np.sqrt(xi**2 - 1)


def sigma_r(x, m, r):
  """|(1 + m - h0 x)(1 + m - h1 x) / (2 m) - 1| for relative gap r.

  h1 = (1 + m) / ((L + mu)/2 - r (L - mu)/2) and
  h0 = (1 + m) / ((L + mu)/2 + r (L - mu)/2), using the module-level
  `mu` and `L`.
  """
  center = 0.5 * (L + mu)
  offset = r * 0.5 * (L - mu)
  h1 = (1 + m) / (center - offset)
  h0 = (1 + m) / (center + offset)
  return abs((1 + m - h0 * x) * (1 + m - h1 * x) / (2 * m) - 1)


# One frame per relative gap R (sweeping 0 -> 1 -> 0). A grid point (m, r) is
# colored with the rate sqrt(m) when sigma_r <= 1 at the four interval
# endpoints mu, L1, mu2 and L; all other points are left blank (NaN).
all_R = np.linspace(0, 1.0, 30)

# sigma_r at mu and at L does not depend on R: evaluate it once.
# (sigma_r already returns an absolute value.)
s_mu = sigma_r(mu, m_grid, h_grid)
s_L = sigma_r(L, m_grid, h_grid)

for i_R, R in enumerate(np.concatenate([all_R, all_R[::-1]])):
  L1 = mu + (1 - R) * (L - mu) / 2
  mu2 = L - (1 - R) * (L - mu) / 2

  s_L1 = sigma_r(L1, m_grid, h_grid)
  s_mu2 = sigma_r(mu2, m_grid, h_grid)
  smax = np.max((s_mu, s_L1, s_mu2, s_L), axis=0)
  rate = np.where(smax <= 1, np.sqrt(m_grid), np.nan)

  plt.figure(figsize=(16, 8))
  plt.title(f'R={R:.{2}}')
  plt.pcolor(m_grid, h_grid, rate, vmin=0.55)
  plt.xlabel(r'momentum $m$')
  plt.ylabel(r'parameter $r$')
  cbar = plt.colorbar()
  cbar.ax.set_ylabel('asymptotic rate')

  plt.grid()
  plt.ylim((0, 1))
  f_path = 'robust_region_cyclical_%02d.png' % i_R
  plt.savefig(f_path, transparent=True, dpi=100, bbox_inches='tight')
  plt.show()

In [None]:
%%capture
# Assemble robust_region_cyclical_00.png, _01.png, ... into one animated PNG
# (the trailing `1 5` is presumably the per-frame delay, 1/5 s; see apngasm usage).
!apngasm robust_region_cyclical.png robust_region_cyclical_00.png 1 5

# Landscape

Plot the convergence rate in color as a function of the two step-sizes. The rate that we display is a consequence of Theorem 3 in [the paper](https://arxiv.org/pdf/2106.09687.pdf).



In [None]:
# Spectrum for the step-size landscape: the same Hessian eigenvalues, shifted
# by 1% of the largest one (regularization) and rescaled so the largest is 1.
H = mnist_images.T @ mnist_images

raw_eigs = np.linalg.eigvalsh(H)
eigs = raw_eigs + 1e-2 * np.max(raw_eigs)  # regularization
eigs = eigs / np.max(eigs)  # normalize

L2, mu1 = eigs.max(), eigs.min()

print('condition number', mu1 / L2)

In [None]:
# Resolution and extent of the (h0, h1) step-size grid.
n_grid = 2000
h_max = 12

# Largest relative gap reached by the animation below (R sweeps 0 -> 0.75).
# The best rate achievable at this gap is the lower end of the color scale.
smallest_R = 0.75
# rho is computed from this section's spectrum (mu1, L2) rather than read from
# a `rho` defined earlier in the notebook for a different problem.
rho = (L2 + mu1) / (L2 - mu1)
minimum_rate = ((np.sqrt(rho**2 - smallest_R**2) - np.sqrt(rho**2 - 1)) / np.sqrt(1 - smallest_R**2))

def sigma(x, m, h0, h1):
  """Cyclical link function 2 s0(x) s1(x) - 1, where s_i = (1 + m - h_i x) / (2 sqrt(m)).

  NOTE: this rebinds the name `sigma`; the 3-argument version from the
  link-function section must be re-defined before re-running those plots.
  """
  scale = 2 * np.sqrt(m)
  first = (1 + m - h0 * x) / scale
  second = (1 + m - h1 * x) / scale
  return 2 * first * second - 1


all_R = np.linspace(0, smallest_R, 40)

# Quantities that do not change between frames: the (h0, h1) step-size grid,
# and rho, which only depends on the spectrum edges mu1 and L2.
all_steps = np.linspace(1e-6, h_max, n_grid)
h0_grid, h1_grid = np.meshgrid(all_steps, all_steps)
rho = (L2 + mu1) / (L2 - mu1)

# One frame per R (sweeping 0 -> smallest_R -> 0).
# Left panel: spectrum with the intervals [mu1, L1] and [mu2, L2].
# Right panel: asymptotic rate of cyclical heavy ball, with the momentum set
# to optimal_m, as a function of the two alternating step sizes (h0, h1).
for it, R in enumerate(np.concatenate([all_R, all_R[::-1]])):
  fig, axarr = plt.subplots(1, 2, figsize=(2 * 10, 1 * 8))

  L1 = (mu1 + L2)/2 - (L2 - mu1) * R / 2
  mu2 = (mu1 + L2)/2 + (L2 - mu1) * R / 2

  # ----- left panel: spectrum -----
  axarr[0].hist(eigs / L2, 50)
  axarr[0].set_yscale("log")
  axarr[0].axvline(L1 / L2, color='#4DAF4A', linestyle='--', lw=1)
  axarr[0].axvline(mu1 / L2, color=palette[8], linestyle='--', lw=1)
  axarr[0].axvline(L2 / L2, color=palette[8], linestyle='--', lw=1)
  axarr[0].axvline(mu2 / L2, color='#4DAF4A', linestyle='--', lw=1)

  # raw strings: a backslash before 'm' is an invalid escape sequence in a
  # normal string literal (SyntaxWarning on Python 3.12+)
  axarr[0].text(L1 * 0.999 / L2, 500, '$L_1$', color='#4DAF4A')
  axarr[0].text(mu2 * 0.999 / L2, 500, r'$\mu_2$', color='#4DAF4A')
  axarr[0].text(L2 * 0.999 / L2, 500, '$L_2$', color=palette[8])
  axarr[0].text(mu1 * 0.999 / L2, 500, r'$\mu_1$', color=palette[8])
  axarr[0].set_xticks(())

  axarr[0].set_ylabel("density")
  axarr[0].set_xlabel("eigenvalue magnitude")

  p1 = patches.FancyArrowPatch((0, 200), (1, 200), arrowstyle='<->',
                               mutation_scale=20, color=palette[8], linewidth=3)
  axarr[0].add_patch(p1)
  p2 = patches.FancyArrowPatch((L1 / L2, 60), (mu2 / L2, 60), arrowstyle='<->',
                               mutation_scale=20, color='#4DAF4A', linewidth=3)
  axarr[0].add_patch(p2)
  axarr[0].text(0.4, 270, r"$L_2 - \mu_1$", color=palette[8], fontsize=30)
  axarr[0].text(0.4, 80, r"$\mu_2 - L_1$", color='#4DAF4A', fontsize=30)
  axarr[0].text(0.33, 12, r"$R = \frac{~~~~~~~~~~~~~~~~}{~~~}$ = %.2f" % R, color='k', fontsize=30)
  axarr[0].text(0.44, 16, r"$\mu_2 - L_1$", color='#4DAF4A', fontsize=30)
  axarr[0].text(0.44, 9, r"$L_2 - \mu_1$", color=palette[8], fontsize=30)

  # ----- right panel: rate landscape -----
  optimal_m = (
      (np.sqrt(rho ** 2 - R ** 2) - np.sqrt(rho ** 2 - 1)) /
      np.sqrt(1 - R**2)) ** 2
  optimal_h0 = (1 + optimal_m) / mu2
  optimal_h1 = (1 + optimal_m) / L1

  # NaN marks step sizes for which the method does not converge.
  rate = np.full((n_grid, n_grid), np.nan)
  for i in range(n_grid):
    # sigma is symmetric in (h0, h1); order each pair so that h0 <= h1.
    h0 = np.minimum(h0_grid[i], h1_grid[i])
    h1 = np.maximum(h0_grid[i], h1_grid[i])

    # sigma_star from Theorem 3.1 in https://arxiv.org/pdf/2106.09687.pdf:
    # max |sigma| over [mu1, L1] U [mu2, L2]. sigma is quadratic in lambda,
    # so this max is attained at an interval endpoint or at the vertex of the
    # parabola when the vertex lies inside one of the intervals.
    vertex = (1 + optimal_m) * (h0 + h1) / (2 * h0 * h1)
    at_mu1 = np.abs(sigma(mu1, optimal_m, h0, h1))
    at_L1 = np.abs(sigma(L1, optimal_m, h0, h1))
    at_mu2 = np.abs(sigma(mu2, optimal_m, h0, h1))
    at_L2 = np.abs(sigma(L2, optimal_m, h0, h1))
    vertex_inside = (((mu1 <= vertex) & (vertex <= L1)) |
                     ((mu2 <= vertex) & (vertex <= L2)))
    at_vertex = np.where(vertex_inside, np.abs(sigma(vertex, optimal_m, h0, h1)), 0.0)
    sigma_star = np.max((at_mu1, at_L1, at_mu2, at_L2, at_vertex), axis=0)

    # Robust region (sigma_star <= 1): the rate is sqrt(m).
    robust = sigma_star <= 1
    rate[i, robust] = np.sqrt(optimal_m)
    # Outside it the rate is sqrt(m (sigma* + sqrt(sigma*^2 - 1))), which is
    # <= 1 iff sigma* <= (1 + m^2) / (2 m). The mask must exclude the robust
    # region: there sigma*^2 - 1 < 0, the square root is NaN, and it would
    # overwrite the sqrt(m) values set just above.
    convergent = ~robust & (sigma_star <= (1 + optimal_m ** 2) / (2 * optimal_m))
    s = sigma_star[convergent]
    rate[i, convergent] = np.sqrt(optimal_m * (s + np.sqrt(s ** 2 - 1)))

  pc = axarr[1].pcolor(h0_grid, h1_grid, rate, rasterized=True, cmap='viridis', vmin=minimum_rate, vmax=1)
  axarr[1].set_xticks(())
  axarr[1].set_yticks(())
  axarr[1].set_xlabel(r"First step-size $h_0$")
  axarr[1].set_ylabel(r"Second step-size $h_1$")
  axarr[1].spines['top'].set_visible(False)
  axarr[1].spines['right'].set_visible(False)
  axarr[1].scatter(optimal_h0, optimal_h1, s=400, facecolors='none', edgecolors='#d95f02', lw=3)
  axarr[1].plot(np.linspace(0, optimal_h0, 100), optimal_h1 * np.ones(100), '--', c='#d95f02', lw=2)
  axarr[1].plot(np.linspace(optimal_h0, optimal_h0, 100), np.linspace(0, optimal_h1, 100), '--', c='#d95f02', lw=2)
  axarr[1].text(2, 10, r"$\circ$ optimal parameters", c='#d95f02')

  if it > 0:
    # The landscape is symmetric in (h0, h1): also mark the mirrored optimum.
    axarr[1].scatter(optimal_h1, optimal_h0, s=400, facecolors='none', edgecolors='#d95f02', lw=2)
    axarr[1].plot(np.linspace(0, optimal_h1, 100), optimal_h0 * np.ones(100), '--', c='#d95f02', lw=2)
    axarr[1].plot(np.linspace(optimal_h1, optimal_h1, 100), np.linspace(0, optimal_h0, 100), '--', c='#d95f02', lw=2)

  axarr[1].set_xlim((0, None))
  axarr[1].set_ylim((0, None))

  fig.subplots_adjust(right=0.80)
  cbar_ax = fig.add_axes([0.82, 0.15, 0.02, 0.7])
  fig.colorbar(pc, cax=cbar_ax, ticks=[0.8, 1])
  cbar_ax.set_ylabel(r'asymptotic rate')

  f_path = 'rate_convergence_cyclical_%02d.png' % it
  plt.savefig(f_path, transparent=True, dpi=100, bbox_inches='tight')
  plt.show()


In [None]:
%%capture
# Assemble rate_convergence_cyclical_00.png, _01.png, ... into one animated PNG
# (the trailing `1 3` is presumably the per-frame delay, 1/3 s; see apngasm usage).
!apngasm rate_convergence_cyclical.png rate_convergence_cyclical_00.png 1 3

# Convergence rate comparison

In this section we compare the asymptotic convergence rates for different condition numbers.

In [None]:
# Relative gaps R in [0, 1), and kappa values from 10^(-4/19) down to 1e-4.
# kappa = 1 is dropped: rho = (1 + kappa) / (1 - kappa) is undefined there.
all_R = np.linspace(0, 1, num=50, endpoint=False)
all_kappa = np.logspace(0, -4, 20)[1:]


def cyclical_rate(kappa, R):
    """Asymptotic rates of Polyak and cyclical heavy ball.

    Args:
      kappa: ratio mu / L, in (0, 1); scalar or array.
      R: relative gap (mu2 - L1) / (L2 - mu1), in [0, 1); scalar or array.

    Returns:
      (r_polyak, r_cyclical, r_approx), broadcast to the common shape of
      kappa and R, with rho = (1 + kappa) / (1 - kappa):
        r_polyak   = rho - sqrt(rho^2 - 1)
        r_cyclical = (sqrt(rho^2 - R^2) - sqrt(rho^2 - 1)) / sqrt(1 - R^2)
        r_approx   = 1 - (1 - r_polyak) / sqrt(1 - R^2)
    """
    # Broadcast explicitly so r_polyak has R's shape even though it does not
    # depend on R (it is plotted against R). np.asarray also accepts a plain
    # Python float kappa, which has no `.shape`, and a scalar R, which has no len().
    rho, R = np.broadcast_arrays(
        np.asarray((1 + kappa) / (1 - kappa), dtype=float),
        np.asarray(R, dtype=float))
    r_polyak = rho - np.sqrt(rho ** 2 - 1)
    r_cyclical = (np.sqrt(rho ** 2 - R ** 2) - np.sqrt(rho ** 2 - 1)) / np.sqrt(1 - R ** 2)
    r_approx = 1 - (1 - r_polyak) / np.sqrt(1 - R ** 2)

    return r_polyak, r_cyclical, r_approx


# One frame per kappa (sweeping down to 1e-4 and back). Each frame compares
# 1 - rate for Polyak heavy ball, cyclical heavy ball and the approximation
# 1 - (1 - r_polyak) / sqrt(1 - R^2), as functions of the relative gap R.
for i, kappa in enumerate(np.concatenate([all_kappa, all_kappa[::-1]])):
    plt.figure(figsize=(8, 6))
    plt.title(f'$\\kappa$ = {kappa:.{2}}')

    r_polyak, r_cyclical, r_approx = cyclical_rate(kappa, all_R)

    # Plot 1 - rate: larger values mean faster convergence.
    plt.plot(all_R, (1-r_polyak), lw=4, label='Polyak', marker='d', markevery=20, markersize=10)
    plt.plot(all_R, (1-r_cyclical), lw=4, label='Cyclical', marker='^', markevery=18, markersize=10)
    plt.plot(all_R, (1-r_approx), '--', lw=4, label='Approx', marker='s', markevery=15, markersize=10)

    plt.grid()
    plt.xlabel('R')
    plt.ylabel('Rate factor')
    plt.gca().yaxis.set_major_formatter(StrMethodFormatter('{x:,.1f}'))  # 1 decimal place

    plt.legend(loc='upper center', bbox_to_anchor=(0.5, -0.25), frameon=False, ncol=3, fontsize=26)

    f_path = 'asymptotic_rate_%02d.png' % i
    plt.savefig(f_path, transparent=True, dpi=50, bbox_inches='tight')

    plt.show()


In [None]:
%%capture
# Assemble asymptotic_rate_00.png, _01.png, ... into one animated PNG
# (the trailing `1 5` is presumably the per-frame delay, 1/5 s; see apngasm usage).
!apngasm asymptotic_rate.png asymptotic_rate_00.png 1 5