In [None]:
import torch
import matplotlib.pyplot as plt

import numpy as np
import pickle
from tqdm import tqdm
import scipy
from MomentumOptimizer_LieGroup_SOn import LieGroupSGD
# Global configuration: CPU execution, double precision, fixed RNG seed.
device = torch.device('cpu')
torch.set_default_dtype(torch.float64)
torch.manual_seed(0)
# Colours of matplotlib's active property cycle; the plotting cells index into this list.
color_list = plt.rcParams['axes.prop_cycle'].by_key()['color']


In [None]:
from utils import *
from eig_val_decomp import *

# Dependence of the convergence rate on the condition number $\kappa$

In [None]:
def generate_trajectory(scheme, seed, dim, kappa, num_iter=5000):
    """Run one Lie-group optimizer on the eigenvalue-decomposition problem.

    The iterate starts at ``X_sol`` (taken from the problem's ``sol_dict``)
    right-multiplied by ``matrix_exp`` of a small random skew-symmetric matrix.

    Args:
        scheme: one of 'heavy_ball', 'NAG_SC' or 'momentumless'.
        seed: value given to ``torch.manual_seed`` before the problem and the
            initial perturbation are generated.
        dim: forwarded to ``generate_eig_value_artificial_conditional_number``.
        kappa: forwarded to ``generate_eig_value_artificial_conditional_number``.
        num_iter: number of optimizer steps to take.

    Returns:
        A dict with two lists of floats, one entry per iteration:
        'loss_list' holds ``loss - min_val`` evaluated *before* each step;
        'lyap_list' holds the scheme's Lyapunov value evaluated *after* each
        step (for 'momentumless' it is the same pre-step loss value).

    Raises:
        NotImplementedError: if ``scheme`` is not one of the three names above.
    """
    # Honour the caller-supplied seed (it used to be hard-coded to 0).
    torch.manual_seed(seed)
    eig_vals=generate_eig_value_artificial_conditional_number(dim, kappa)
    A, sol_dict=eig_val_decomp_problem(eig_vals, device=device)
    U=lambda X:eig_val_decomp_loss(A, X)
    mu=sol_dict['mu']
    L=sol_dict['L']
    min_val=sol_dict['min_val']
    X_sol=sol_dict['X_sol']

    # M - M^T is skew-symmetric; its matrix exponential perturbs X_sol slightly.
    xi_noise=torch.randn_like(A)*0.01
    xi_noise=xi_noise-xi_noise.T
    g_init=X_sol@torch.matrix_exp(xi_noise)

    g=torch.clone(g_init)
    g.requires_grad_(True)
    # Buffer that receives the pre-step iterate on every iteration.
    g_last=g.clone()
    g_star=X_sol

    # Step size (and friction, where the scheme has one) derived from mu and L.
    if scheme=='heavy_ball':
        parameter_dict=parameter_HB(mu, L)
        optimizer=LieGroupSGD([g], lr=parameter_dict['h'], gamma=parameter_dict['gamma'], scheme=scheme)
    elif scheme=='NAG_SC':
        parameter_dict=parameter_NAG_SC(mu, L)
        optimizer=LieGroupSGD([g], lr=parameter_dict['h'], gamma=parameter_dict['gamma'], scheme=scheme)
    elif scheme=='momentumless':
        parameter_dict=parameter_momentumless(mu, L)
        optimizer=LieGroupSGD([g], lr=parameter_dict['h'], scheme=scheme)
    else:
        raise NotImplementedError('unknown scheme: '+repr(scheme))

    loss_list=[]
    lyap_list=[]
    for i in tqdm(range(num_iter)):
        loss=eig_val_decomp_loss(A, g)-min_val
        optimizer.zero_grad()
        loss.backward()
        # Remember the iterate before the optimizer overwrites it.
        g_last.copy_(g)
        optimizer.step()
        xi=optimizer.state[g]['xi']

        if scheme=='heavy_ball':
            lyap_args={
                'h':parameter_dict['h'],
                'gamma':parameter_dict['gamma'],
                'U':U,
                'g':g,
                'xi':xi,
                'g_star':g_star,
                'g_last':g_last,
            }
            lyap_list.append(lyap_HB(lyap_args).item())
        elif scheme=='NAG_SC':
            lyap_args={
                'h':parameter_dict['h'],
                'gamma':parameter_dict['gamma'],
                'U':U,
                'g':g,
                'xi':xi,
                'g_star':g_star,
                'g_last':g_last,
                'nabla_g_last':optimizer.state[g]['trivialized_grad_last'],
            }
            lyap_list.append(lyap_NAG_SC(lyap_args).item())
        elif scheme=='momentumless':
            # No separate Lyapunov function here: the loss itself is recorded.
            lyap_list.append(loss.item())
        loss_list.append(loss.item())
    return {'loss_list':loss_list, 'lyap_list':lyap_list}

In [None]:
def get_convergence_rate(lyap_list):
    """Return the largest ratio between consecutive entries of ``lyap_list``.

    The sequence is converted to a tensor and the maximum of
    ``lyap_list[k+1] / lyap_list[k]`` over all k is returned as a 0-dim tensor.
    """
    values = torch.Tensor(lyap_list)
    successive_ratios = values[1:].div(values[:-1])
    return successive_ratios.max()

In [None]:
# Sweep kappa for every scheme and store the worst per-step contraction ratio
# of the recorded Lyapunov values.
kappa_list = np.arange(100, 10000, 500)
convergence_rate_dict = {}
for scheme in ['heavy_ball', 'NAG_SC', 'momentumless']:
    # The momentum schemes run for 100 steps, the momentumless one for 3000.
    if scheme != 'momentumless':
        n_steps = 100
    else:
        n_steps = 3000
    rates = np.zeros_like(kappa_list, dtype=np.float64)
    convergence_rate_dict[scheme] = rates
    for i, kappa in enumerate(kappa_list):
        result_dict = generate_trajectory(scheme, 0, 10, kappa, num_iter=n_steps)
        rates[i] = get_convergence_rate(result_dict['lyap_list'])

In [None]:
# Cache the sweep on disk so the plotting cells can run without recomputing it.
with open('convergence_rate_dict.pkl', 'wb') as cache_file:
    pickle.dump(convergence_rate_dict, cache_file)

In [None]:
# Reload the cached sweep. Only unpickle files you produced yourself:
# pickle.load can execute arbitrary code from an untrusted file.
with open('convergence_rate_dict.pkl', 'rb') as cache_file:
    convergence_rate_dict = pickle.load(cache_file)
# Same kappa grid as the one used to compute the cached rates.
kappa_list = np.arange(100, 10000, 500)


In [None]:
from scipy.stats import linregress

# Heavy ball: measured 1 - rate, plus the reference curve 1/(slope*kappa), where
# slope comes from a linear regression of rate/(1-rate) against kappa.
hb_rate = convergence_rate_dict['heavy_ball']
plt.plot(kappa_list, 1-hb_rate, color=color_list[1], label='heavy_ball')
result = linregress(kappa_list, hb_rate/(1-hb_rate))
fit_list = torch.from_numpy(result.slope*kappa_list)
plt.plot(
    kappa_list, 1/fit_list,
    linestyle='--', dashes=(2,2), color=color_list[1],
    label=r'$C \kappa^{-1}$',
)

# NAG-SC: same idea, but the regression is done on (rate/(1-rate))**2, so the
# reference curve is 1/sqrt(slope*kappa).
nag_rate = convergence_rate_dict['NAG_SC']
plt.plot(kappa_list, 1-nag_rate, color=color_list[2], label='NAG_SC')
result = linregress(kappa_list, (nag_rate/(1-nag_rate))**2)
fit_list = torch.from_numpy(result.slope*kappa_list)
plt.plot(
    kappa_list, 1/torch.sqrt(fit_list),
    linestyle='--', dashes=(2,2), color=color_list[2],
    label=r'$C \kappa^{-0.5}$',
)

plt.yscale('log')
plt.legend(fontsize=12)
plt.xlabel('condition number', fontsize=18)
plt.ylabel('1-convergence rate', fontsize=18)
plt.xticks(fontsize=14, rotation=10)
plt.yticks(fontsize=14, rotation=60)

plt.savefig('LEV_conv_kappa.pdf', bbox_inches='tight')

In [None]:
# Condition numbers used in the trajectory plots are 10**e for e in exponent_list.
exponent_list = [2, 3, 4]
# Fixed colour per scheme, taken from the matplotlib colour cycle.
color = {
    'heavy_ball': color_list[0],
    'NAG_SC': color_list[1],
    'momentumless': color_list[2],
}


In [None]:


# Trajectories for every scheme at kappa = 10**e, e in exponent_list.
# Start from a fresh dict: without this the cell either fails with NameError on
# a fresh kernel or silently reuses (and pickles) whatever `result_dict` an
# earlier cell left behind.
result_dict={}
for scheme in ['heavy_ball', 'NAG_SC', 'momentumless']:
    result_dict[scheme]={}
    # `10**exponent_list` is a TypeError (int ** list); exponentiate per element.
    # The keys are plain ints (100, 1000, 10000), matching the `10**i` lookups
    # in the plotting cells.
    for exponent in exponent_list:
        kappa=10**exponent
        result_dict[scheme][kappa]=generate_trajectory(scheme, 0, 10, kappa, num_iter=2000)

with open('result_dict.pkl', 'wb') as f:
    pickle.dump(result_dict, f)

In [None]:
# Reload the cached trajectories. Only unpickle files you produced yourself:
# pickle.load can execute arbitrary code from an untrusted file.
with open('result_dict.pkl', 'rb') as cache_file:
    result_dict = pickle.load(cache_file)

In [None]:

# Momentumless scheme: loss curves (dashed) for kappa = 10**i, i in exponent_list.
scheme = 'momentumless'
for i in exponent_list:
    kappa = 10**i
    plt.plot(
        result_dict[scheme][kappa]['loss_list'],
        linestyle='--', dashes=(2,2),
        color=color_list[i],
        label=f'kappa={kappa}',
    )

plt.yscale('log')
plt.legend(fontsize=12)
plt.xlabel('#iter', fontsize=18)
plt.ylabel('loss value', fontsize=18)
plt.xticks(fontsize=14, rotation=10)
plt.yticks(fontsize=14, rotation=60)
plt.savefig('LEV_momentumless.pdf', bbox_inches='tight')

In [None]:

# Heavy ball: loss (dashed) and Lyapunov value (solid) per kappa = 10**i.
scheme = 'heavy_ball'
for i in exponent_list:
    kappa = 10**i
    trajectory = result_dict[scheme][kappa]
    plt.plot(
        trajectory['loss_list'],
        linestyle='--', dashes=(2,2),
        color=color_list[i],
        label=f'kappa={kappa}',
    )
    plt.plot(trajectory['lyap_list'], color=color_list[i])

plt.yscale('log')
plt.legend(fontsize=12)
plt.xlabel('#iter', fontsize=18)
plt.ylabel('loss value', fontsize=18)
plt.xticks(fontsize=14, rotation=10)
plt.yticks(fontsize=14, rotation=60)
plt.savefig('LEV_heavy_ball.pdf', bbox_inches='tight')

In [None]:

# NAG-SC: loss (dashed) and Lyapunov value (solid) per kappa = 10**i.
scheme = 'NAG_SC'
for i in exponent_list:
    kappa = 10**i
    trajectory = result_dict[scheme][kappa]
    plt.plot(
        trajectory['loss_list'],
        linestyle='--', dashes=(2,2),
        color=color_list[i],
        label=f'kappa={kappa}',
    )
    plt.plot(trajectory['lyap_list'], color=color_list[i])

plt.yscale('log')
plt.legend(fontsize=12)
plt.xlabel('#iter', fontsize=18)
plt.ylabel('loss value', fontsize=18)
plt.xticks(fontsize=14, rotation=10)
plt.yticks(fontsize=14, rotation=60)
plt.savefig('LEV_NAG_SC.pdf', bbox_inches='tight')

# Non-convexity

In [None]:
def generate_trajectory(scheme, seed, dim, kappa, num_iter=5000):
    """Run one Lie-group optimizer from an initial point with reversed columns.

    Same experiment as the earlier ``generate_trajectory`` in this notebook,
    except that the columns of ``X_sol`` are taken in reverse order before the
    small random perturbation is applied. NOTE: this definition reuses the same
    name and therefore shadows the earlier one for every cell executed after it.

    Args:
        scheme: one of 'heavy_ball', 'NAG_SC' or 'momentumless'.
        seed: value given to ``torch.manual_seed`` before the problem and the
            initial perturbation are generated.
        dim: forwarded to ``generate_eig_value_artificial_conditional_number``.
        kappa: forwarded to ``generate_eig_value_artificial_conditional_number``.
        num_iter: number of optimizer steps to take.

    Returns:
        A dict with two lists of floats, one entry per iteration:
        'loss_list' holds ``loss - min_val`` evaluated *before* each step;
        'lyap_list' holds the scheme's Lyapunov value evaluated *after* each
        step (for 'momentumless' it is the same pre-step loss value).

    Raises:
        NotImplementedError: if ``scheme`` is not one of the three names above.
    """
    # Honour the caller-supplied seed (it used to be hard-coded to 0).
    torch.manual_seed(seed)
    eig_vals=generate_eig_value_artificial_conditional_number(dim, kappa)
    A, sol_dict=eig_val_decomp_problem(eig_vals, device=device)
    U=lambda X:eig_val_decomp_loss(A, X)
    mu=sol_dict['mu']
    L=sol_dict['L']
    min_val=sol_dict['min_val']
    X_sol=sol_dict['X_sol']

    # M - M^T is skew-symmetric; its matrix exponential is a small perturbation.
    xi_noise=torch.randn_like(A)*0.01
    xi_noise=xi_noise-xi_noise.T
    # Columns of X_sol in reverse order (last column first), then perturbed.
    reversed_columns=np.arange(X_sol.shape[1]-1, -1, -1)
    g_init=X_sol[:, reversed_columns]@torch.matrix_exp(xi_noise)

    g=torch.clone(g_init)
    g.requires_grad_(True)
    # Buffer that receives the pre-step iterate on every iteration.
    g_last=g.clone()
    g_star=X_sol

    # Step size (and friction, where the scheme has one) derived from mu and L.
    if scheme=='heavy_ball':
        parameter_dict=parameter_HB(mu, L)
        optimizer=LieGroupSGD([g], lr=parameter_dict['h'], gamma=parameter_dict['gamma'], scheme=scheme)
    elif scheme=='NAG_SC':
        parameter_dict=parameter_NAG_SC(mu, L)
        optimizer=LieGroupSGD([g], lr=parameter_dict['h'], gamma=parameter_dict['gamma'], scheme=scheme)
    elif scheme=='momentumless':
        parameter_dict=parameter_momentumless(mu, L)
        optimizer=LieGroupSGD([g], lr=parameter_dict['h'], scheme=scheme)
    else:
        raise NotImplementedError('unknown scheme: '+repr(scheme))

    loss_list=[]
    lyap_list=[]
    for i in tqdm(range(num_iter)):
        loss=eig_val_decomp_loss(A, g)-min_val
        optimizer.zero_grad()
        loss.backward()
        # Remember the iterate before the optimizer overwrites it.
        g_last.copy_(g)
        optimizer.step()
        xi=optimizer.state[g]['xi']

        if scheme=='heavy_ball':
            lyap_args={
                'h':parameter_dict['h'],
                'gamma':parameter_dict['gamma'],
                'U':U,
                'g':g,
                'xi':xi,
                'g_star':g_star,
                'g_last':g_last,
            }
            lyap_list.append(lyap_HB(lyap_args).item())
        elif scheme=='NAG_SC':
            lyap_args={
                'h':parameter_dict['h'],
                'gamma':parameter_dict['gamma'],
                'U':U,
                'g':g,
                'xi':xi,
                'g_star':g_star,
                'g_last':g_last,
                'nabla_g_last':optimizer.state[g]['trivialized_grad_last'],
            }
            lyap_list.append(lyap_NAG_SC(lyap_args).item())
        elif scheme=='momentumless':
            # Kept consistent with the earlier definition: record the loss so
            # 'lyap_list' is never left empty for this scheme.
            lyap_list.append(loss.item())
        loss_list.append(loss.item())
    return {'loss_list':loss_list, 'lyap_list':lyap_list}

In [None]:
# Run both momentum schemes (dim 10, kappa 100, 10000 steps) and cache the result.
result_dict_non_convex = {
    name: generate_trajectory(name, 0, 10, 100, num_iter=10000)
    for name in ('NAG_SC', 'heavy_ball')
}
with open('result_dict_non_convex.pkl', 'wb') as cache_file:
    pickle.dump(result_dict_non_convex, cache_file)

In [None]:
# Reload the cached trajectories. Only unpickle files you produced yourself:
# pickle.load can execute arbitrary code from an untrusted file.
with open('result_dict_non_convex.pkl', 'rb') as cache_file:
    result_dict_non_convex = pickle.load(cache_file)

In [None]:
# Loss curves of both momentum schemes, one colour per scheme.
for scheme in ['heavy_ball', 'NAG_SC']:
    trajectory = result_dict_non_convex[scheme]
    plt.plot(trajectory['loss_list'], color=color[scheme], label=scheme)

plt.yscale('log')
plt.legend(fontsize=12)
plt.xlabel('#iter', fontsize=18)
plt.ylabel('U', fontsize=18)
# Only the first 8000 iterations are shown.
plt.xlim([0, 8000])
plt.xticks(fontsize=14, rotation=10)
plt.yticks(fontsize=14, rotation=60)
plt.savefig('LEV_non_convex.pdf', bbox_inches='tight')