In [None]:
import pickle
import matplotlib.pyplot as plt
import numpy as np

# Human-readable legend label for every method key used in the pickled result dicts.
legend_dict = {
    'original_RGAS': 'Original (RGAS)',
    'original_RAGAS': 'Original (RAGAS)',
    'StiefelSGD_ours': 'Stiefel SGD (Ours)',
    'StiefelAdam_ours': 'Stiefel Adam (Ours)',
    'MomentumlessStiefelSGD': 'Momentumless Stiefel SGD',
    'ProjectedStiefelSGD': 'Projected Stiefel SGD',
    'ProjectedStiefelAdam': 'Projected Stiefel Adam',
}

# Methods drawn in the figures, in plotting order; a method's position here also
# selects its colour in the plotting cells. The two Adam variants have labels
# above but are not plotted.
method_list = [
    'original_RGAS',
    'original_RAGAS',
    'StiefelSGD_ours',
    'MomentumlessStiefelSGD',
    'ProjectedStiefelSGD',
]
from mpl_toolkits.axes_grid1.inset_locator import zoomed_inset_axes
from mpl_toolkits.axes_grid1.inset_locator import mark_inset


In [None]:
# MNIST: mean PRW distance over all digit pairs (i, j) with i < j, one curve per method,
# with a zoomed inset; saved to ./PRW_mnist.pdf.
fig, ax = plt.subplots(figsize=(7, 2.5))
# 4x zoomed inset placed at loc=7 ('center right') of the main axes.
axins = zoomed_inset_axes(ax, 4, loc=7)

# Public replacement for the private `ax._get_lines.prop_cycler` (removed in
# matplotlib 3.8). The colour index advances for every entry of `method_list`,
# including methods absent from the results, so each method keeps the same
# colour in every figure.
cycle_colors = plt.rcParams['axes.prop_cycle'].by_key()['color']

# Display-only rescaling of the MNIST values; the y-axis label does not state this factor.
MNIST_OT_SCALE = 1000

# NOTE: pickle.load can execute arbitrary code -- only load result files you produced yourself.
with open('OT_val_memory_dict_mnist.pkl', 'rb') as handle:
    OT_val_memory_dict_mnist = pickle.load(handle)

for method_idx, method in enumerate(method_list):
    color = cycle_colors[method_idx % len(cycle_colors)]
    if method not in OT_val_memory_dict_mnist:
        continue
    # Result keys are strings '(i, j)' for each digit pair with i < j.
    OT_list = [
        OT_val_memory_dict_mnist[method]['({}, {})'.format(i, j)]
        for i in range(10)
        for j in range(i + 1, 10)
    ]
    # Stack the pairs along axis 1 and average them out; presumably each entry is a
    # per-iteration sequence, giving one mean value per iteration -- TODO confirm.
    mean_OT = np.mean(np.stack(OT_list, axis=1), axis=1)
    ax.plot(mean_OT / MNIST_OT_SCALE, label=legend_dict[method], color=color)
    axins.plot(mean_OT / MNIST_OT_SCALE, label=legend_dict[method], color=color)

ax.set_xlabel('iter', fontsize=13)
ax.set_ylabel('PRW distance mean', fontsize=13)
ax.set_ylim(0.0, 0.92)
ax.set_title('PRW distance between MNIST digits', fontsize=15)
ax.tick_params(axis='both', labelsize=13)

# Inset window: iterations 44-49, values 0.80-0.90 (after rescaling).
axins.set_xlim(44, 49)
axins.set_xticks([44, 45, 46, 47, 48])
axins.set_ylim(0.80, 0.9)
axins.set_yticks([0.8, 0.825, 0.85, 0.875, 0.9])
axins.tick_params(axis='both', labelsize=13)

# Outline the zoomed region and connect it to the inset via corners 1 (upper right) and 2 (upper left).
mark_inset(ax, axins, loc1=1, loc2=2, fc="none", ec="0.5")
fig.savefig('./PRW_mnist.pdf', bbox_inches='tight')
plt.show()


    

In [None]:
# Shakespeare: mean PRW distance over all pairs of plays (i, j) with i < j, one curve
# per method, with a zoomed inset; saved to ./PRW_shakespeare.pdf.
fig, ax = plt.subplots(figsize=(7, 2.5))
# 8x zoomed inset placed at loc=7 ('center right') of the main axes.
axins = zoomed_inset_axes(ax, 8, loc=7)

# Public replacement for the private `ax._get_lines.prop_cycler` (removed in
# matplotlib 3.8). The colour index advances for every entry of `method_list`,
# including methods absent from the results, so each method keeps the same
# colour in every figure.
cycle_colors = plt.rcParams['axes.prop_cycle'].by_key()['color']

# Play scripts; result keys refer to plays by their position in this list.
scripts = ['H5.txt', 'Ham.txt', 'JC.txt', 'MV.txt', 'Oth.txt', 'Rom.txt']
n_scripts = len(scripts)

# NOTE: pickle.load can execute arbitrary code -- only load result files you produced yourself.
with open('OT_val_memory_dict_shakespeare.pkl', 'rb') as handle:
    OT_val_memory_dict_shakespeare = pickle.load(handle)

for method_idx, method in enumerate(method_list):
    color = cycle_colors[method_idx % len(cycle_colors)]
    if method not in OT_val_memory_dict_shakespeare:
        continue
    # Result keys are strings '(i, j)' for each pair of play indices with i < j
    # (same order as the former double loop over `scripts` with an i < j filter).
    OT_list = [
        OT_val_memory_dict_shakespeare[method]['({}, {})'.format(i, j)]
        for i in range(n_scripts)
        for j in range(i + 1, n_scripts)
    ]
    # Stack the pairs along axis 1 and average them out; presumably each entry is a
    # per-iteration sequence, giving one mean value per iteration -- TODO confirm.
    mean_OT = np.mean(np.stack(OT_list, axis=1), axis=1)
    ax.plot(mean_OT, label=legend_dict[method], color=color)
    axins.plot(mean_OT, label=legend_dict[method], color=color)

ax.set_xlabel('iter', fontsize=13)
ax.set_ylabel('PRW distance mean', fontsize=13)
ax.set_ylim(0.04, 0.195)
ax.set_title('PRW distance between Shakespeare plays', fontsize=15)
ax.tick_params(axis='both', labelsize=13)

# Inset window: iterations 46-49, values 0.182-0.191.
axins.set_xlim(46, 49)
axins.set_xticks([46, 47, 48])
axins.set_ylim(0.182, 0.191)
axins.set_yticks([0.184, 0.186, 0.188, 0.190])
axins.tick_params(axis='both', labelsize=13)

# Outline the zoomed region and connect it to the inset via corners 1 (upper right) and 2 (upper left).
mark_inset(ax, axins, loc1=1, loc2=2, fc="none", ec="0.5")
fig.savefig('./PRW_shakespeare.pdf', bbox_inches='tight')
plt.show()


    

In [None]:
# Standalone legend figure saved to PRW_legend.pdf.
# NOTE: reuses `ax` left over from whichever plotting cell ran last (the Shakespeare
# cell under Restart & Run All), so the legend lists only the methods plotted there.
handles, labels = ax.get_legend_handles_labels()
figlegend = plt.figure(figsize=(3, 2))
figlegend.legend(handles, labels)

figlegend.savefig('PRW_legend.pdf', bbox_inches='tight')
