In [None]:
import torch
import scipy.integrate as integrate
import numpy as np
import sys
sys.path.append('../Optimizers')
sys.path.append('..')
import time
import matplotlib.pyplot as plt
import pickle

In [None]:
# Optimizers compared in the main convergence and timing experiments
# (each name is a key of optimizer_dict and legend_dict).
method_list=['StiefelSGD_ours', 'StiefelAdam_ours', 'ProjectedStiefelSGD', 'ProjectedStiefelAdam', 'MomentumlessStiefelSGD']

# Methods that the experiment cells hand Y = X.t() (shape (m, n)) instead of X (shape (n, m)).
# NOTE(review): presumably Li et al.'s optimizers expect orthonormal rows — confirm.
transpose_needed=['ProjectedStiefelAdam', 'ProjectedStiefelSGD']

# Figure-legend labels keyed by method name.
legend_dict={'RegularizerStiefelSGD':'Regularizer SGD (Cisse et al)',
             'RegularizerStiefelAdam':'Regularizer Adam (Cisse et al)',
             'StiefelSGD_ours':'Stiefel SGD (Ours)',
             'StiefelAdam_ours':'Stiefel Adam (Ours)',
             'ProjectedStiefelSGD':'Projected Stiefel SGD (Li et al)',
             'ProjectedStiefelAdam':'Projected Stiefel Adam (Li et al)',
             'MomentumlessStiefelSGD':'Momentumless Stiefel SGD (Wen & Yin)',
             'LiCombinedOptimizer' : 'Our retraction + Li algo'}

# Optimizer factories: optimizer_dict[name](params) builds a fresh optimizer
# for the given parameter list with the hyperparameters used in the experiments.
optimizer_dict={}

from StiefelRegularizer import RegularizerStiefelSGD
optimizer_dict['RegularizerStiefelSGD']=lambda param: RegularizerStiefelSGD(param, lr=1e-3, momentum=0.9, stiefel_regularizer=1)

from StiefelRegularizer import RegularizerStiefelAdam
optimizer_dict['RegularizerStiefelAdam']=lambda param: RegularizerStiefelAdam(param, lr=1e-3, betas=(0.9, 0.999), stiefel_regularizer=1)

from StiefelOptimizers import StiefelSGD
optimizer_dict['StiefelSGD_ours']=lambda param: StiefelSGD(param, lr=1e-1, momentum=0.9)

from StiefelOptimizers import StiefelAdam
optimizer_dict['StiefelAdam_ours']=lambda param: StiefelAdam(param, lr=1e-3, betas=(0.9, 0.999))

from ProjectedStiefelOptimizer.stiefel_optimizer import SGDG as ProjectedStiefelSGD
optimizer_dict['ProjectedStiefelSGD']=lambda param: ProjectedStiefelSGD(param, lr=2e-1, momentum=0.9, stiefel=True)

from ProjectedStiefelOptimizer.stiefel_optimizer import AdamG as ProjectedStiefelAdam
# lr=5e-0 is 5.0.
optimizer_dict['ProjectedStiefelAdam']=lambda param: ProjectedStiefelAdam(param, lr=5e-0, momentum=0.9, beta2= 0.999, stiefel=True)

from MomentumlessStiefelSGD import MomentumlessStiefelSGD
optimizer_dict['MomentumlessStiefelSGD']=lambda param: MomentumlessStiefelSGD(param, lr=1e-1)

from LiCombinedOptimizer import LiCombinedOptimizer
optimizer_dict['LiCombinedOptimizer']=lambda param: LiCombinedOptimizer(param, lr=1e-1, momentum=0.9)

device=torch.device('cpu')
# Every tensor created below without an explicit dtype is float64.
torch.set_default_dtype(torch.float64)



In [None]:
import torch
import numpy as np
def lev_problem(n,m, device='cpu', dtype=None):
    """Build a random leading-eigenvector (LEV) problem instance.

    Maximizing trace(X^T A X) over n x m matrices X with orthonormal columns
    has optimal value equal to the sum of the m largest eigenvalues of A
    (attained by the top-m eigenvectors). `sol` is that value, so
    lev_loss(A, X) + sol is 0 at the optimum.

    Args:
        n: size of A and number of rows of X.
        m: number of columns of X; must satisfy m <= n.
        device: device on which all returned tensors are created.
        dtype: dtype of the returned tensors; defaults to torch.get_default_dtype().

    Returns:
        A: (n, n) symmetric matrix (G + G^T) / 2 / sqrt(n) with G standard Gaussian.
        X_init: (n, m) starting point from torch.nn.init.orthogonal_.
        sol: 0-dim tensor, sum of the m largest eigenvalues of A.

    Raises:
        AssertionError: if n < m.
    """
    assert n >= m, f'lev_problem needs n >= m, got n={n}, m={m}'
    if dtype is None:
        dtype=torch.get_default_dtype()

    A=torch.randn(n, n, device=device, dtype=dtype)
    A=(A+A.t())/2/np.sqrt(n)
    X_init=torch.zeros(n, m, device=device, dtype=dtype)
    torch.nn.init.orthogonal_(X_init)
    # eigvalsh skips the eigenvectors that eigh would also compute and returns the
    # eigenvalues in ascending order, so the m largest are the last m entries
    # (slicing from n-m also keeps m == 0 correct, unlike [-m:]).
    eig_vals=torch.linalg.eigvalsh(A)
    sol=torch.sum(eig_vals[n-m:])
    return A, X_init, sol


def lev_loss(A, X):
    """Return -trace(X^T A X) as a 0-dim tensor.

    In this notebook A is (n, n) and X is (n, m), as produced by lev_problem.
    Minimizing this loss over X with orthonormal columns is the leading-eigenvector problem.
    """
    rayleigh_block = X.t().matmul(A).matmul(X)
    return torch.trace(rayleigh_block).neg()


# Convergence and deviation

In [None]:


# Train every method in method_list on one LEV instance and record, per iteration,
# the shifted loss (0 at the optimum) and the deviation ||X^T X - I|| from the Stiefel manifold.
torch.manual_seed(0)
n=1000
m=10

A, X_init, sol=lev_problem(n,m, device=device)

num_iter=5000

dev_dict={}
loss_dict={}
# Built once: identity used by the deviation diagnostic.
eye_m=torch.eye(m, device=device)

for method in method_list:
    print(method)
    loss_mem=loss_dict[method]=[]
    dev_mem=dev_dict[method]=[]
    X=X_init.clone().to(device)
    if method in transpose_needed:
        # Optimize the leaf Y = X^T; X = Y.t() is a view, so in-place updates of Y
        # (as torch optimizers do) are visible through X.
        Y=X.t()
        Y.requires_grad=True
        X=Y.t()
        optimizer=optimizer_dict[method]([Y])
    else:
        X.requires_grad=True
        optimizer=optimizer_dict[method]([X])
    for i in range(num_iter):
        optimizer.zero_grad()
        loss=lev_loss(A, X)+sol
        loss.backward()

        optimizer.step()
        # Diagnostic only: no autograd graph is needed for the deviation.
        with torch.no_grad():
            dev=torch.norm(X.t()@X-eye_m)
        dev_mem.append(dev.item())
        loss_mem.append(loss.item())
        if i%100==0:
            print(loss.item())

with open('dev_dict.pkl', 'wb') as handle:
    pickle.dump(dev_dict, handle)
with open('loss_dict.pkl', 'wb') as handle:
    pickle.dump(loss_dict, handle)

# Time per optimizer step as a function of m (fixed n/m)

In [None]:



# Average wall-clock time of one optimizer.step() as m grows with n = nm_ratio * m.
torch.manual_seed(0)

p=200
m_range=[8,16,32,64,128, 256]
nm_ratio=10


mean_iter=100

# NOTE(review): torch.set_num_threads(1) is only called in a later cell, so in
# top-to-bottom order this sweep runs with the default thread count while the
# n-sweep runs single-threaded — confirm this is intended.
time_dict_m={}
for method in method_list:
    time_mem=time_dict_m[method]=[None]*len(m_range)
    for idx, m in enumerate(m_range):
        n=round(nm_ratio*m)
        print((method, m))

        A, X_init, sol=lev_problem(n,m, device=device)
        X=X_init.clone().to(device)
        if method in transpose_needed:
            Y=X.t()
            Y.requires_grad=True
            X=Y.t()
            optimizer=optimizer_dict[method]([Y])
        else:
            X.requires_grad=True
            optimizer=optimizer_dict[method]([X])
        elapsed=0.0
        for i in range(mean_iter):
            optimizer.zero_grad()
            loss=lev_loss(A, X)+sol
            loss.backward()
            # Time only the step. perf_counter is monotonic and high-resolution;
            # time.time() can be too coarse for sub-millisecond steps.
            t=time.perf_counter()
            optimizer.step()
            elapsed+=time.perf_counter()-t
        time_mem[idx]=elapsed/mean_iter


with open('time_dict_m.pkl', 'wb') as handle:
    pickle.dump(time_dict_m, handle)


# Time per optimizer step as a function of n (fixed m)

In [None]:
# Runtime used by the timing benchmark that follows: CPU, float64 default dtype, one intra-op thread.
device=torch.device('cpu')
torch.set_default_dtype(torch.float64)
torch.set_num_threads(1)

In [None]:

# Average wall-clock time of one optimizer.step() as n grows with m fixed.
torch.set_default_dtype(torch.float64)

torch.manual_seed(0)

p=200
n_range=[100, 200, 300, 500, 750, 1000, 2000, 3000, 5000]
m=10

mean_iter=100

time_dict_n={}
for method in method_list:
    time_mem=time_dict_n[method]=[None]*len(n_range)
    for idx, n in enumerate(n_range):
        print((method, n))

        A, X_init, sol=lev_problem(n,m, device=device)
        X=X_init.clone().to(device)
        if method in transpose_needed:
            Y=X.t()
            Y.requires_grad=True
            X=Y.t()
            optimizer=optimizer_dict[method]([Y])
        else:
            X.requires_grad=True
            optimizer=optimizer_dict[method]([X])
        elapsed=0.0
        for i in range(mean_iter):
            optimizer.zero_grad()
            loss=lev_loss(A, X)+sol
            loss.backward()

            # Time only the step. perf_counter is monotonic and high-resolution;
            # time.time() can be too coarse for sub-millisecond steps.
            t=time.perf_counter()
            optimizer.step()
            elapsed+=time.perf_counter()-t
        time_mem[idx]=elapsed/mean_iter


with open('time_dict_n.pkl', 'wb') as handle:
    pickle.dump(time_dict_n, handle)

# Different inner products and matrix-exponential approximations

In [None]:
# Stiefel Adam under every (matrix-exponential approximation, inner product) pair;
# only the shifted loss history is saved, keyed '<expm_method>+<inner_prod>'.
torch.set_default_dtype(torch.float64)

torch.manual_seed(0)
n=1000
m=10
p=200

A, X_init, sol=lev_problem(n,m, device=device)

num_iter=5000

expm_innerprod_adam_dict={}
expm_method_list=['MatrixExp', 'Cayley', 'ForwardEuler']
inner_prod_list=['Euclidean', 'Canonical']
eye_m=torch.eye(m, device=device)
for expm_method in expm_method_list:
    for inner_prod in inner_prod_list:
        name=expm_method+'+'+inner_prod
        print(name)
        loss_mem=[]
        # NOTE(review): dev_mem is collected but not saved to the pickle.
        dev_mem=[]
        X=X_init.clone().to(device)
        X.requires_grad=True
        # Construct directly so the current loop values are bound explicitly.
        optimizer=StiefelAdam([X], lr=1e-3, betas=(0.9, 0.999), inner_prod=inner_prod, expm_method=expm_method)
        t=time.perf_counter()
        for i in range(num_iter):
            optimizer.zero_grad()
            loss=lev_loss(A, X)+sol
            loss.backward()

            optimizer.step()
            with torch.no_grad():
                dev=torch.norm(X.t()@X-eye_m)
            dev_mem.append(dev.item())
            loss_mem.append(loss.item())
            if i%100==0:
                print(loss.item())
        # Total run time, measured once after the loop.
        time_consuming=time.perf_counter()-t
        print(f'{name}: {time_consuming:.1f} s')
        expm_innerprod_adam_dict[name]=loss_mem

with open('expm_innerprod_adam_dict.pkl', 'wb') as handle:
    pickle.dump(expm_innerprod_adam_dict, handle)


In [None]:
# Stiefel SGD under every (matrix-exponential approximation, inner product) pair;
# only the shifted loss history is saved, keyed '<expm_method>+<inner_prod>'.
torch.manual_seed(0)
n=1000
m=10
p=200

A, X_init, sol=lev_problem(n,m, device=device)

num_iter=5000

expm_innerprod_sgd_dict={}
expm_method_list=['MatrixExp', 'Cayley', 'ForwardEuler']
inner_prod_list=['Euclidean', 'Canonical']
eye_m=torch.eye(m, device=device)
for expm_method in expm_method_list:
    for inner_prod in inner_prod_list:
        name=expm_method+'+'+inner_prod
        print(name)
        loss_mem=[]
        # NOTE(review): dev_mem is collected but not saved to the pickle.
        dev_mem=[]
        X=X_init.clone().to(device)
        X.requires_grad=True
        # Construct directly so the current loop values are bound explicitly.
        optimizer=StiefelSGD([X], lr=1e-1, momentum=0.9, inner_prod=inner_prod, expm_method=expm_method)
        t=time.perf_counter()
        for i in range(num_iter):
            optimizer.zero_grad()
            loss=lev_loss(A, X)+sol
            loss.backward()

            optimizer.step()
            with torch.no_grad():
                dev=torch.norm(X.t()@X-eye_m)
            dev_mem.append(dev.item())
            loss_mem.append(loss.item())
            if i%100==0:
                print(loss.item())
        # Total run time, measured once after the loop.
        time_consuming=time.perf_counter()-t
        print(f'{name}: {time_consuming:.1f} s')
        expm_innerprod_sgd_dict[name]=loss_mem

with open('expm_innerprod_sgd_dict.pkl', 'wb') as handle:
    pickle.dump(expm_innerprod_sgd_dict, handle)


# Li et al.'s projected Stiefel SGD with and without our retraction

In [None]:
# Compare Li et al.'s projected SGD with and without our retraction, plus our Stiefel SGD.
method_list=['LiCombinedOptimizer', 'ProjectedStiefelSGD', 'StiefelSGD_ours']
# NOTE(review): 'LiCombinedOptimizer' is not in transpose_needed, so it receives X (n, m)
# directly, unlike ProjectedStiefelSGD — confirm that is what it expects.


torch.manual_seed(0)
n=1000
m=10

A, X_init, sol=lev_problem(n,m, device=device)

num_iter=2000

loss_dict={}
dev_dict={}
eye_m=torch.eye(m, device=device)

for method in method_list:
    print(method)
    loss_mem=loss_dict[method]=[]
    dev_mem=dev_dict[method]=[]
    X=X_init.clone().to(device)
    if method in transpose_needed:
        Y=X.t()
        Y.requires_grad=True
        X=Y.t()
        optimizer=optimizer_dict[method]([Y])
    else:
        X.requires_grad=True
        optimizer=optimizer_dict[method]([X])
    for i in range(num_iter):
        optimizer.zero_grad()
        loss=lev_loss(A, X)+sol
        loss.backward()

        optimizer.step()
        # Diagnostic only: no autograd graph is needed for the deviation.
        with torch.no_grad():
            dev=torch.norm(X.t()@X-eye_m)
        dev_mem.append(dev.item())
        loss_mem.append(loss.item())
        if i%100==0:
            print(loss.item())

with open('dev_dict_LiCombined.pkl', 'wb') as handle:
    pickle.dump(dev_dict, handle)
with open('loss_dict_LiCombined.pkl', 'wb') as handle:
    pickle.dump(loss_dict, handle)

# Number of Cayley inner-loop iterations in projected Stiefel SGD

In [None]:


# Sweep the optimizer's 'Cayley_loop_num' for projected Stiefel SGD. Results are stored
# as dev_loop_dict[str((loop_num, qr_every))][method] (same layout for loss_loop_dict).
torch.manual_seed(0)
n=100
m=10

A, X_init, sol=lev_problem(n,m, device=device)

num_iter=2000

dev_loop_dict={}
loss_loop_dict={}

method_list=['ProjectedStiefelSGD']
loop_list = [1,2,4,6, 8,16]
# qr_every = 1e6 > num_iter, presumably so that the periodic QR never triggers in this run.
qr_every_list = [int(1e6)]
eye_m=torch.eye(m, device=device)

for method in method_list:
    for loop_num in loop_list:
        for qr_every in qr_every_list:
            key=str((loop_num, qr_every))
            dev_dict = dev_loop_dict[key] = {}
            loss_dict = loss_loop_dict[key] = {}
            print(method, key)
            loss_mem=loss_dict[method]=[]
            dev_mem=dev_dict[method]=[]
            X=X_init.clone().to(device)
            if method in transpose_needed:
                Y=X.t()
                Y.requires_grad=True
                X=Y.t()
                optimizer=optimizer_dict[method]([Y])
            else:
                X.requires_grad=True
                optimizer=optimizer_dict[method]([X])

            optimizer.param_groups[0]['QR_every'] = qr_every
            optimizer.param_groups[0]['Cayley_loop_num'] = loop_num
            for i in range(num_iter):
                optimizer.zero_grad()
                loss=lev_loss(A, X)+sol
                loss.backward()

                optimizer.step()
                with torch.no_grad():
                    dev=torch.norm(X.t()@X-eye_m)
                dev_mem.append(dev.item())
                loss_mem.append(loss.item())
                if i%100==0:
                    print(loss.item())
with open('dev_loop_dict_Li.pkl', 'wb') as handle:
    pickle.dump(dev_loop_dict, handle)
with open('loss_loop_dict_Li.pkl', 'wb') as handle:
    pickle.dump(loss_loop_dict, handle)

# QR frequency is important in projected Stiefel SGD

In [None]:


# Projected Stiefel SGD with Cayley_loop_num = 5 and QR every 1 or 2 steps, plus a
# reference run of our Stiefel SGD stored under dev_loop_dict['StiefelSGD_ours']['StiefelSGD_ours'].
torch.manual_seed(0)
n=100
m=10

A, X_init, sol=lev_problem(n,m, device=device)

num_iter=1500

dev_loop_dict={}
loss_loop_dict={}
eye_m=torch.eye(m, device=device)

method_list=['ProjectedStiefelSGD']

loop_list = [5]
qr_every_list = [1,2]
for method in method_list:
    for loop_num in loop_list:
        for qr_every in qr_every_list:
            key=str((loop_num, qr_every))
            dev_dict = dev_loop_dict[key] = {}
            loss_dict = loss_loop_dict[key] = {}
            print(method, key)
            loss_mem=loss_dict[method]=[]
            dev_mem=dev_dict[method]=[]
            X=X_init.clone().to(device)
            if method in transpose_needed:
                Y=X.t()
                Y.requires_grad=True
                X=Y.t()
                optimizer=optimizer_dict[method]([Y])
            else:
                X.requires_grad=True
                optimizer=optimizer_dict[method]([X])

            optimizer.param_groups[0]['QR_every'] = qr_every
            optimizer.param_groups[0]['Cayley_loop_num'] = loop_num
            for i in range(num_iter):
                optimizer.zero_grad()
                loss=lev_loss(A, X)+sol
                loss.backward()

                optimizer.step()
                with torch.no_grad():
                    dev=torch.norm(X.t()@X-eye_m)
                dev_mem.append(dev.item())
                loss_mem.append(loss.item())
                if i%100==0:
                    print(loss.item())


# Reference run with our optimizer and its default settings.
method ='StiefelSGD_ours'

dev_dict = dev_loop_dict[method] = {}
loss_dict = loss_loop_dict[method] = {}
print(method)
loss_mem=loss_dict[method]=[]
dev_mem=dev_dict[method]=[]
X=X_init.clone().to(device)
if method in transpose_needed:
    Y=X.t()
    Y.requires_grad=True
    X=Y.t()
    optimizer=optimizer_dict[method]([Y])
else:
    X.requires_grad=True
    optimizer=optimizer_dict[method]([X])

for i in range(num_iter):
    optimizer.zero_grad()
    loss=lev_loss(A, X)+sol
    loss.backward()

    optimizer.step()
    with torch.no_grad():
        dev=torch.norm(X.t()@X-eye_m)
    dev_mem.append(dev.item())
    loss_mem.append(loss.item())
    if i%100==0:
        print(loss.item())


with open('projected_qr_dev.pkl', 'wb') as handle:
    pickle.dump(dev_loop_dict, handle)
with open('projected_qr_loss.pkl', 'wb') as handle:
    pickle.dump(loss_loop_dict, handle)

# Convergence and deviation from the manifold under different numbers of inner-loop iterations in our Stiefel SGD

In [None]:

# Sweep the 'max_inner_iter' setting of our Stiefel SGD. Results are stored as
# dev_loop_dict[str(loop_num)][method] (same layout for loss_loop_dict).
method_list=['StiefelSGD_ours']

torch.manual_seed(0)
n=1000
m=10

A, X_init, sol=lev_problem(n,m, device=device)

num_iter=1500

dev_loop_dict={}
loss_loop_dict={}
eye_m=torch.eye(m, device=device)

loop_list = [1,2, 3,4, 5, 6, 7, 8, 100]

for loop_num in loop_list:

    dev_dict = dev_loop_dict[str(loop_num)] = {}
    loss_dict = loss_loop_dict[str(loop_num)] ={}
    for method in method_list:
        print(method, loop_num)
        loss_mem=loss_dict[method]=[]
        dev_mem=dev_dict[method]=[]
        X=X_init.clone().to(device)
        if method in transpose_needed:
            Y=X.t()
            Y.requires_grad=True
            X=Y.t()
            optimizer=optimizer_dict[method]([Y])
        else:
            X.requires_grad=True
            optimizer=optimizer_dict[method]([X])
        optimizer.param_groups[0]['max_inner_iter'] = loop_num
        for i in range(num_iter):
            optimizer.zero_grad()
            loss=lev_loss(A, X)+sol
            loss.backward()

            optimizer.step()
            with torch.no_grad():
                dev=torch.norm(X.t()@X-eye_m)
            dev_mem.append(dev.item())
            loss_mem.append(loss.item())
            if i%100==0:
                print(loss.item())

with open('dev_loop_dict_ours.pkl', 'wb') as handle:
    pickle.dump(dev_loop_dict, handle)
with open('loss_loop_dict_ours.pkl', 'wb') as handle:
    pickle.dump(loss_loop_dict, handle)

# Plot

In [None]:

# |loss| per iteration of Stiefel SGD for each (expm approximation, inner product) pair.
with open('expm_innerprod_sgd_dict.pkl', 'rb') as handle:
    expm_innerprod_sgd_dict = pickle.load(handle)


label_dict={'MatrixExp': 'Matrix Exp', 'Cayley':'Cayley map', 'ForwardEuler':'Forward Euler', 'Euclidean':'Euclidean', 'Canonical':'Canonical'}

# Iterate the loaded results directly (keys '<expm_method>+<inner_prod>', in the order
# the experiment cell inserted them), so this cell does not need the experiment cell's
# expm_method_list / inner_prod_list.
fig_sgd, ax_sgd = plt.subplots()
for key, loss_history in expm_innerprod_sgd_dict.items():
    expm_method, inner_prod = key.split('+')
    name=label_dict[expm_method]+'+'+label_dict[inner_prod]
    ax_sgd.plot(np.abs(loss_history), label=name)
ax_sgd.set_xlabel('iter')
ax_sgd.set_ylabel('loss')
ax_sgd.set_yscale('log')
ax_sgd.legend()
fig_sgd.savefig('./lev_compare_SGD.pdf', bbox_inches='tight')
plt.show()

In [None]:

# |loss| per iteration of Stiefel Adam for each (expm approximation, inner product) pair.
with open('expm_innerprod_adam_dict.pkl', 'rb') as handle:
    expm_innerprod_adam_dict = pickle.load(handle)

label_dict={'MatrixExp': 'Matrix Exp', 'Cayley':'Cayley map', 'ForwardEuler':'Forward Euler', 'Euclidean':'Euclidean', 'Canonical':'Canonical'}

# Iterate the loaded results directly (keys '<expm_method>+<inner_prod>', in insertion
# order), so this cell does not need the experiment cell's method lists.
fig_adam, ax_adam = plt.subplots()
for key, loss_history in expm_innerprod_adam_dict.items():
    expm_method, inner_prod = key.split('+')
    name=label_dict[expm_method]+'+'+label_dict[inner_prod]
    ax_adam.plot(np.abs(loss_history), label=name)
ax_adam.set_xlabel('iter')
ax_adam.set_ylabel('loss')
ax_adam.set_yscale('log')
ax_adam.legend()
fig_adam.savefig('./lev_compare_Adam.pdf', bbox_inches='tight')
plt.show()

In [None]:
# Convergence (shifted loss) and manifold deviation for the main comparison.
with open('dev_dict.pkl', 'rb') as handle:
    dev_dict = pickle.load(handle)
with open('loss_dict.pkl', 'rb') as handle:
    loss_dict = pickle.load(handle)

method_list=['StiefelSGD_ours', 'StiefelAdam_ours', 'ProjectedStiefelSGD', 'ProjectedStiefelAdam', 'MomentumlessStiefelSGD']

# `ax` (convergence axes) is also read by the legend cell below.
fig, ax = plt.subplots()
for method in method_list:
    ax.plot(loss_dict[method], label=legend_dict[method])
ax.set_xlabel('iter')
ax.set_ylabel('loss')
ax.set_yscale('log')
ax.text(0.15, 0.85,'Convergence',
     horizontalalignment='left',
     verticalalignment='bottom',
     transform = ax.transAxes,
     size=13)
fig.savefig('./lev_convergence.pdf', bbox_inches='tight')
plt.show()

# Separate axes: the caption must be positioned in this figure's axes coordinates,
# not in those of the convergence figure shown above.
fig_dev, ax_dev = plt.subplots()
for method in method_list:
    ax_dev.plot(dev_dict[method], label=legend_dict[method])
ax_dev.set_xlabel('iter')
ax_dev.set_ylabel('deviation')
ax_dev.set_yscale('log')
ax_dev.text(0.15, 0.85,'Manifold preservance',
     horizontalalignment='left',
     verticalalignment='bottom',
     transform = ax_dev.transAxes,
     size=13)
fig_dev.savefig('./lev_dev.pdf', bbox_inches='tight')
plt.show()

In [None]:

# Standalone legend figure built from the labeled lines on `ax`, the convergence axes
# created in the previous plotting cell.
handles, labels = ax.get_legend_handles_labels()
figlegend = plt.figure(figsize=(3,2))
figlegend.legend(handles, labels)

figlegend.savefig('lev_legend.pdf', bbox_inches='tight')


In [None]:
# Seconds per optimizer step against m (n/m = 10), log-log.
# Uses m_range from the time-vs-m experiment cell and method_list from the convergence plot cell.
with open('time_dict_m.pkl', 'rb') as handle:
    time_dict_m = pickle.load(handle)

# Own axes: previously the caption used a stale `ax` left over from an earlier cell.
fig_m, ax_m = plt.subplots()
for method in method_list:
    ax_m.plot(m_range, time_dict_m[method], label=legend_dict[method])
ax_m.set_xlabel('m (fix n/m=10)')
ax_m.set_ylabel('time consuming on CPU (s per iter)')
ax_m.set_yscale('log')
ax_m.set_xscale('log')
ax_m.text(0.45, 0.05,'Time complexity against m \n (fix n/m=10)',
     horizontalalignment='left',
     verticalalignment='bottom',
     transform = ax_m.transAxes,
     size=13)
fig_m.savefig('./lev_time_m_log.pdf', bbox_inches='tight')
plt.show()

In [None]:
# Seconds per optimizer step against n (m = 10), log-log.
# Uses n_range from the time-vs-n experiment cell and method_list from the convergence plot cell.
with open('time_dict_n.pkl', 'rb') as handle:
    time_dict_n = pickle.load(handle)

# Own axes: previously the caption used a stale `ax` left over from an earlier cell.
fig_n, ax_n = plt.subplots()
for method in method_list:
    ax_n.plot(n_range, time_dict_n[method], label=legend_dict[method])
ax_n.set_xlabel('n (fix m=10)')
# Same units as the time-vs-m figure (mean seconds per optimizer.step()).
ax_n.set_ylabel('time consuming on CPU (s per iter)')
ax_n.set_yscale('log')
ax_n.set_xscale('log')
ax_n.text(0.05, 0.85,'Time complexity against n \n (fix m=10)',
     horizontalalignment='left',
     verticalalignment='bottom',
     transform = ax_n.transAxes,
     size=13)
fig_n.savefig('./lev_time_n_log.pdf', bbox_inches='tight')
plt.show()

In [None]:
def moving_average(x, w):
    """Mean of each length-w window of x ('valid' mode: len(x) - w + 1 values)."""
    return np.convolve(x, np.ones(w), 'valid') / w
window_width=10


with open('projected_qr_dev.pkl', 'rb') as handle:
    dev_loop_dict = pickle.load(handle)
with open('projected_qr_loss.pkl', 'rb') as handle:
    loss_loop_dict = pickle.load(handle)

method_list=['ProjectedStiefelSGD']

loop_list = [5]
qr_every_list = [1, 2]
# Reference run of our optimizer, stored as loop_dict[ours_key][ours_key].
ours_key='StiefelSGD_ours'


fig, ax = plt.subplots()

for loop_num in loop_list:
    for qr_every in qr_every_list:
        loss_dict = loss_loop_dict[str((loop_num, qr_every))]
        for method in method_list:
            # qr_every must be in the label, otherwise both runs share one legend entry.
            ax.plot(moving_average(np.abs(loss_dict[method]), window_width),
                    label=f'{legend_dict[method]}, Cayley iter {loop_num}, QR every {qr_every}')
ax.plot(moving_average(np.abs(loss_loop_dict[ours_key][ours_key]), window_width), label=legend_dict[ours_key])
ax.set_xlabel('iter')
ax.set_ylabel('loss (abs val)')
ax.set_yscale('log')
ax.set_title('Convergence')
ax.legend()
# Not cayley_iter_*.pdf: a later cell writes those names and would overwrite this figure.
fig.savefig('./projected_qr_convergence.pdf', bbox_inches='tight')
plt.show()

fig_dev, ax_dev = plt.subplots()
for loop_num in loop_list:
    for qr_every in qr_every_list:
        dev_dict = dev_loop_dict[str((loop_num, qr_every))]
        for method in method_list:
            ax_dev.plot(moving_average(dev_dict[method], window_width),
                        label=f'{legend_dict[method]}, Cayley iter {loop_num}, QR every {qr_every}')
ax_dev.plot(moving_average(dev_loop_dict[ours_key][ours_key], window_width), label=legend_dict[ours_key])

ax_dev.set_xlabel('iter')
ax_dev.set_ylabel('deviation')
ax_dev.set_yscale('log')
ax_dev.set_title('Deviation from manifold')
ax_dev.legend()
fig_dev.savefig('./projected_qr_dev.pdf', bbox_inches='tight')
plt.show()

In [None]:
def moving_average(x, w):
    """Mean over each length-w window of x, 'valid' mode (len(x) - w + 1 outputs)."""
    kernel = np.ones(w)
    return np.convolve(x, kernel, 'valid') / w
window_width=1  # a width of 1 leaves the curves unsmoothed


with open('dev_dict_LiCombined.pkl', 'rb') as handle:
    dev_dict = pickle.load(handle)
with open('loss_dict_LiCombined.pkl', 'rb') as handle:
    loss_dict = pickle.load(handle)

method_list=['LiCombinedOptimizer', 'ProjectedStiefelSGD', 'StiefelSGD_ours']

# One wide figure: |loss| on the left, deviation from the manifold on the right.
plt.figure(figsize=(15,5))

ax1 = plt.subplot(1, 2, 1)
for method in method_list:
    ax1.plot(moving_average(np.abs(loss_dict[method]), window_width), label=legend_dict[method])
ax1.set_xlabel('iter')
ax1.set_ylabel('loss')
ax1.set_yscale('log')
ax1.set_title('Convergence')
ax1.legend()

ax2 = plt.subplot(1, 2, 2)
for method in method_list:
    ax2.plot(moving_average(dev_dict[method], window_width), label=legend_dict[method])
ax2.set_xlabel('iter')
ax2.set_ylabel('deviation')
ax2.set_yscale('log')
ax2.set_title('Manifold preservance')
ax2.legend()

plt.savefig('./LiWithOurRetraction.pdf', bbox_inches='tight')
plt.show()

In [None]:
def moving_average(x, w):
    """Mean over each length-w window of x, 'valid' mode (len(x) - w + 1 outputs)."""
    kernel = np.ones(w)
    return np.convolve(x, kernel, 'valid') / w
window_width=10


with open('dev_loop_dict_ours.pkl', 'rb') as handle:
    dev_loop_dict = pickle.load(handle)
with open('loss_loop_dict_ours.pkl', 'rb') as handle:
    loss_loop_dict = pickle.load(handle)

method_list=['StiefelSGD_ours']
# Keys written by the experiment cell: str(max_inner_iter).
loop_list = [str(k) for k in (1, 2, 3, 4, 5, 6, 7, 8, 100)]


# Figure 1: smoothed |loss| for each inner-iteration setting.
fig, ax = plt.subplots()
for loop_num in loop_list:
    dev_dict, loss_dict = dev_loop_dict[loop_num], loss_loop_dict[loop_num]
    for method in method_list:
        ax.plot(moving_average(np.abs(loss_dict[method]), window_width),
                label=f'mat root inv {loop_num} iter')
ax.set_xlabel('iter')
ax.set_ylabel('loss')
ax.set_yscale('log')
ax.set_title('Convergence')
ax.legend()
fig.savefig('./ours_inner_iter_convergence.pdf', bbox_inches='tight')
plt.show()

# Figure 2: smoothed deviation ||X^T X - I|| for each inner-iteration setting.
for loop_num in loop_list:
    dev_dict, loss_dict = dev_loop_dict[loop_num], loss_loop_dict[loop_num]
    for method in method_list:
        plt.plot(moving_average(dev_dict[method], window_width),
                 label=f'mat root inv {loop_num} iter')
plt.xlabel('iter')
plt.ylabel('deviation')
plt.yscale('log')
plt.title('Deviation from manifold')
plt.legend()
plt.savefig('./ours_inner_iter_dev.pdf', bbox_inches='tight')
plt.show()

In [None]:
def moving_average(x, w):
    """Mean over each length-w window of x, 'valid' mode (len(x) - w + 1 outputs)."""
    kernel = np.ones(w)
    return np.convolve(x, kernel, 'valid') / w
window_width=10


with open('dev_loop_dict_Li.pkl', 'rb') as handle:
    dev_loop_dict = pickle.load(handle)
with open('loss_loop_dict_Li.pkl', 'rb') as handle:
    loss_loop_dict = pickle.load(handle)

method_list=['ProjectedStiefelSGD']

# NOTE(review): the experiment cell also ran Cayley_loop_num = 6; it is not plotted here — confirm intended.
loop_list = [1,2,4,8,16]
qr_every_list = [int(1e6)]


# Figure 1: smoothed |loss| per Cayley inner-iteration count.
fig, ax = plt.subplots()
for loop_num in loop_list:
    for qr_every in qr_every_list:
        key = str((loop_num, qr_every))
        dev_dict, loss_dict = dev_loop_dict[key], loss_loop_dict[key]
        for method in method_list:
            ax.plot(moving_average(np.abs(loss_dict[method]), window_width),
                    label=f'Cayley iter {loop_num}')
ax.set_xlabel('iter')
ax.set_ylabel('loss (abs val)')
ax.set_yscale('log')
ax.set_title('Convergence')
ax.legend()
fig.savefig('./cayley_iter_convergence.pdf', bbox_inches='tight')
plt.show()

# Figure 2: smoothed deviation from the manifold per Cayley inner-iteration count.
for loop_num in loop_list:
    for qr_every in qr_every_list:
        key = str((loop_num, qr_every))
        dev_dict, loss_dict = dev_loop_dict[key], loss_loop_dict[key]
        for method in method_list:
            plt.plot(moving_average(dev_dict[method], window_width),
                     label=f'Cayley iter {loop_num}')
plt.xlabel('iter')
plt.ylabel('deviation')
plt.yscale('log')
plt.title('Deviation from manifold')
plt.legend()
plt.savefig('./cayley_iter_dev.pdf', bbox_inches='tight')
plt.show()