In [6]:
import os, yaml, matplotlib.pyplot as plt, numpy as np, collections, torch, PIL
# Values of M swept in this experiment. Each value selects the results directory
# '../results_save/cat[]thinspiral_<M>' and is used as the "M = ..." figure label.
suffixs = [
    250,
    500,
    1000,
    1500,
    2000,
]

In [7]:
main_dir_prefix = '../results_save/cat[]thinspiral_'
# Containers keyed by M (an entry of `suffixs`) and then by method name:
#   loss[M][method]          -> stacked 'w2.npy' arrays, one row per run directory
#   support[M][method]       -> 'support.npy' of the run whose directory name contains 'S[0]'
#   mass[M][method]          -> 'mass.npy' of that same run
#   means[M][method], stds[M][method] -> mean / std of loss over runs (axis 0)
#   final_loss[method]       -> last entry of the mean curve, one value per M in
#                               `suffixs` order (NaN where the method has no runs for that M)
loss, support, mass, stds, means, final_loss = {}, {}, {}, {}, {}, collections.defaultdict(list)
for suffix in suffixs:
    main_dir = main_dir_prefix + str(suffix)
    loss[suffix], support[suffix], mass[suffix] = collections.defaultdict(list), {}, {}
    # sorted(): os.listdir order is arbitrary; sorting makes run order reproducible.
    for sub_dir in sorted(os.listdir(main_dir)):
        run_dir = os.path.join(main_dir, sub_dir)
        # Skip stray files (e.g. .DS_Store); only run directories hold result arrays.
        if not os.path.isdir(run_dir):
            continue
        # Method name is the prefix before the first '_'. partition() keeps the whole
        # name when there is no '_' (slicing with str.find would drop the last char).
        method_name = sub_dir.partition('_')[0]
        if 'S[0]' in sub_dir:
            support[suffix][method_name] = np.load(os.path.join(run_dir, 'support.npy'))
            mass[suffix][method_name] = np.load(os.path.join(run_dir, 'mass.npy'))
        loss[suffix][method_name].append(np.load(os.path.join(run_dir, 'w2.npy')))
for suffix in suffixs:
    stds[suffix], means[suffix] = {}, {}
    for method_name, runs in loss[suffix].items():
        # Assumes every run logged the same number of w2 values, so runs stack into a 2-D array.
        loss[suffix][method_name] = np.array(runs)
        means[suffix][method_name] = np.mean(loss[suffix][method_name], axis = 0)
        stds[suffix][method_name] = np.std(loss[suffix][method_name], axis = 0)
# Build final_loss with exactly one entry per M for every method, so the table columns
# stay aligned even when a method is missing for some M (NaN fills the gap).
# dict.fromkeys keeps first-seen order, matching the original key insertion order.
all_method_names = dict.fromkeys(name for suffix in suffixs for name in means[suffix])
for method_name in all_method_names:
    for suffix in suffixs:
        method_means = means[suffix].get(method_name)
        final_loss[method_name].append(np.nan if method_means is None else method_means[-1])

In [8]:
save_dir = '../figures/morphing'
# Plot configuration. This notebook reads its 'order', 'label', 'color' and 'linestyle' keys.
# Read through a context manager so the file handle is closed.
with open('../load_results/plot_settings.yaml') as settings_file:
    settings = yaml.load(settings_file, Loader = yaml.FullLoader)
os.makedirs(save_dir, exist_ok = True)
# Write LaTeX table rows: "<label> & \t<v1> & <v2> & ... \\", one row per method in
# settings['order'] and one column per entry of final_loss[name]. This is the final mean
# W2 value for each M. Methods without results are skipped.
with open(os.path.join(save_dir, 'sketching_iter.txt'), mode = 'w') as f:
    for name in settings['order']:
        # Membership test does not insert into the defaultdict.
        if name not in final_loss: continue
        cells = ' & '.join('{:.3e}'.format(value) for value in final_loss[name])
        f.write('{} & \t'.format(settings['label'][name]) + cells + r'\\' + '\n')

In [9]:
# plot_result_iter_separate: for each M, one figure with two log-scale panels of the
# 2-Wasserstein distance vs. iteration (mean with std error bars over runs).
# The 'MMDF*' methods go on the left panel and the 'SD*' methods on the right.

# Iteration span used for each panel's x-axis; the recorded curve is spread evenly over it.
# NOTE(review): assumes w2 values were logged at evenly spaced iterations — confirm.
MMDF_TOTAL_ITERS, SD_TOTAL_ITERS = 15000, 300
# Per-M y-limits (left panel, right panel), in the same order as `suffixs`.
ylims_left = [(2.5e-2, 0.15), (1.5e-2, 0.1), (1.0e-2, 0.1), (1.0e-2, 0.1), (1.0e-2, 0.1)]
ylims_right = [(1e-2, 4e-2), (6e-3, 4e-2), (4e-3, 4e-2), (4e-3, 4e-2), (4e-3, 4e-2)]


def _plot_iter_panel(ax, num, method_names, total_iters, ylim):
    """Draw mean +/- std W2 curves for `method_names` at M = `num` onto `ax`.

    Methods with no results for this particular M are skipped.
    Styling comes from the global `settings`; data from the globals `means` and `stds`.
    """
    for method_name in method_names:
        # Check the per-M results: a method present for some M may be absent for this one.
        if method_name not in means[num]: continue
        x_axis = np.linspace(0, total_iters, len(means[num][method_name]))
        ax.errorbar(
            x_axis, means[num][method_name], yerr = stds[num][method_name], capsize = 3,
            color = settings['color'][method_name],
            linestyle = settings['linestyle'][method_name],
            label = settings['label'][method_name],
            alpha = 1.0, linewidth = 1.5
        )
    ax.set_ylim(*ylim)
    ax.set_yscale('log')
    ax.set_xscale('linear')
    ax.set_xlabel('Iterations, M = {}'.format(num), fontsize = 14)
    ax.set_ylabel('2-Wasserstein Distance', fontsize = 14)
    ax.tick_params(labelsize = 12)
    ax.legend(fontsize = 14)


for num, ylim_left, ylim_right in zip(suffixs, ylims_left, ylims_right):
    fig, (ax_left, ax_right) = plt.subplots(1, 2, figsize = (7.2 * 2, 4.0 * 1))
    _plot_iter_panel(ax_left, num, ['MMDF', 'MMDFCA', 'MMDFDK'], MMDF_TOTAL_ITERS, ylim_left)
    _plot_iter_panel(ax_right, num, ['SD', 'SDCA', 'SDDK'], SD_TOTAL_ITERS, ylim_right)
    fig.savefig(os.path.join(save_dir, 'sketching_%s_iter.pdf'%num), dpi = 120, bbox_inches = 'tight')
    # Close this specific figure so figures don't accumulate across the loop.
    plt.close(fig)

### Particles: final support and mass per method

In [10]:
# For each M and each method that has an 'S[0]' run, scatter the last stored particle
# configuration. Marker area is proportional to particle mass. Each plot is saved to
# <save_dir>/num<M>/<label>.pdf.
particle_size = 5  # marker area given to a particle carrying the average mass
for num in suffixs:
    _save_dir = os.path.join(save_dir, 'num%d'%num)
    os.makedirs(_save_dir, exist_ok = True)
    for method_name in support[num].keys():
        # [-1] takes the last entry along axis 0. NOTE(review): presumably the arrays stack
        # snapshots over time, e.g. support (T, N, 2) and mass (T, N) — confirm.
        _support = support[num][method_name][-1]
        _mass = mass[num][method_name][-1]
        # Rescale masses to average 1, so the mean marker area equals `particle_size`.
        _mass = (_mass / _mass.sum()) * len(_mass)
        size_list = _mass * particle_size
        fig, ax = plt.subplots(figsize = (4.0, 4.0))
        ax.scatter(_support[:, 0], _support[:, 1], alpha = 0.5, s = size_list, c = 'r', label = settings['label'][method_name])
        ax.legend(fontsize = 16, loc = 3, bbox_to_anchor=(-0.06, -0.06))
        ax.set_xlim((-0.06, 1.0))
        ax.set_ylim((-0.06, 1.0))
        ax.axis('off')
        fig.savefig(os.path.join(_save_dir, '%s.pdf'%(settings['label'][method_name])), dpi = 120, bbox_inches = 'tight')
        # Close this specific figure so figures don't accumulate across the loop.
        plt.close(fig)